Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Source = "https://github.com/aiidateam/aiida-pythonjob"
[project.entry-points."aiida.data"]
"pythonjob.jsonable_data" = "aiida_pythonjob.data.jsonable_data:JsonableData"
"pythonjob.ase.atoms.Atoms" = "aiida_pythonjob.data.atoms:AtomsData"
"pythonjob.ase.atoms.Trajectory" = "aiida_pythonjob.data.atoms:Trajectory"
"pythonjob.builtins.NoneType" = "aiida_pythonjob.data.common_data:NoneData"
"pythonjob.datetime.datetime" = "aiida_pythonjob.data.common_data:DateTimeData"

Expand Down
59 changes: 58 additions & 1 deletion src/aiida_pythonjob/data/atoms.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Iterable

import numpy as np
from aiida.orm import Data
from ase import Atoms
from ase.db.row import atoms2dict

__all__ = ("AtomsData",)
__all__ = ("AtomsData", "Trajectory")


class AtomsData(Data):
Expand Down Expand Up @@ -54,3 +56,58 @@ def value(self):
data = self.base.attributes.get_many(keys)
data = dict(zip(keys, data))
return Atoms(**data)


class Trajectory(Data):
"""Data to represent a list of ASE Atoms."""

_cached_traj = None

def __init__(self, value=None, **kwargs):
"""Initialise a `Trajectory` node instance.

:param value: List of ASE Atoms to initialise the `Trajectory` node from
"""
if value and not isinstance(value, Iterable):
raise ValueError("Trajectory must be iterable")

traj = value or [Atoms()]
super().__init__(**kwargs)
self.set_traj(traj)

def set_traj(self, traj):
"""Convert list of ASE Atoms to lists of dictionaries and keys."""
dicts = []
keys_list = []

for struct in traj:
data, keys = AtomsData.atoms2dict(struct)
dicts.append(data)
keys_list.append(keys)

# Store list of atom-dicts and associated keys
self.base.attributes.set("traj", dicts)
self.base.attributes.set("keys_list", keys_list)

self._cached_traj = None

def get_traj(self):
"""Reconstruct the list of ASE Atoms."""
if self._cached_traj is not None:
return self._cached_traj
serialised = self.base.attributes.get("traj")
keys_list = self.base.attributes.get("keys_list")

traj = []
for data, keys in zip(serialised, keys_list):
# Pick only the real ASE constructor keys
ase_kwargs = {k: data[k] for k in keys}
traj.append(Atoms(**ase_kwargs))

self._cached_traj = traj
return traj

@property
def value(self):
"""Get list of atoms."""
return self.get_traj()
5 changes: 5 additions & 0 deletions src/aiida_pythonjob/data/deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"aiida.orm.nodes.data.dict.Dict": "aiida_pythonjob.data.deserializer.dict_data_to_dict",
"aiida.orm.nodes.data.array.array.ArrayData": "aiida_pythonjob.data.deserializer.array_data_to_array",
"aiida.orm.nodes.data.structure.StructureData": "aiida_pythonjob.data.deserializer.structure_data_to_atoms",
"aiida.orm.nodes.data.array.trajectory.TrajectoryData": "aiida_pythonjob.data.deserializer.trajectory_data_to_atoms",
}


Expand Down Expand Up @@ -39,6 +40,10 @@ def structure_data_to_atoms(structure):
return structure.get_ase()


def trajectory_data_to_atoms(trajectory):
return [trajectory.get_step_structure(i).get_ase() for i in trajectory.get_stepids()]


def structure_data_to_pymatgen(structure):
return structure.get_pymatgen()

Expand Down
45 changes: 45 additions & 0 deletions tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,48 @@ def test_serialize_json():

serialized_data = general_serializer(data, serializers=all_serializers)
assert isinstance(serialized_data, JsonableData)


def test_serialize_ase_atoms():
from ase import Atoms

from aiida_pythonjob.data.atoms import AtomsData
from aiida_pythonjob.data.serializer import general_serializer

data = Atoms("C")
serialized_data = general_serializer(data, serializers=all_serializers)
assert isinstance(serialized_data, AtomsData)


def test_serialize_ase_traj():
from ase import Atoms

from aiida_pythonjob.data.atoms import Trajectory
from aiida_pythonjob.data.serializer import general_serializer

data = Trajectory([Atoms("C"), Atoms("C")])
serialized_data = general_serializer(data, serializers=all_serializers)
assert isinstance(serialized_data, Trajectory)
Comment on lines +80 to +82
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am confused on this part. The input data is Trajectory, and then it is serialized to a Trajectory again?



def test_deserialize_atoms():
from ase import Atoms

from aiida_pythonjob.data.atoms import AtomsData
from aiida_pythonjob.data.deserializer import deserialize_to_raw_python_data

data = AtomsData(Atoms("C"))
deserialized_data = deserialize_to_raw_python_data(data)
assert isinstance(deserialized_data, Atoms)


def test_deserialize_trajectory():
from ase import Atoms

from aiida_pythonjob.data.atoms import Trajectory
from aiida_pythonjob.data.deserializer import deserialize_to_raw_python_data

data = Trajectory([Atoms("C"), Atoms("C")])
deserialized_data = deserialize_to_raw_python_data(data)
assert isinstance(deserialized_data, list)
assert all(isinstance(struct, Atoms) for struct in deserialized_data)