diff --git a/pyproject.toml b/pyproject.toml index 2137596..c2d21ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ Source = "https://github.com/aiidateam/aiida-pythonjob" [project.entry-points."aiida.data"] "pythonjob.jsonable_data" = "aiida_pythonjob.data.jsonable_data:JsonableData" "pythonjob.ase.atoms.Atoms" = "aiida_pythonjob.data.atoms:AtomsData" +"pythonjob.ase.atoms.Trajectory" = "aiida_pythonjob.data.atoms:Trajectory" "pythonjob.builtins.NoneType" = "aiida_pythonjob.data.common_data:NoneData" "pythonjob.datetime.datetime" = "aiida_pythonjob.data.common_data:DateTimeData" diff --git a/src/aiida_pythonjob/data/atoms.py b/src/aiida_pythonjob/data/atoms.py index 946cde2..86acf39 100644 --- a/src/aiida_pythonjob/data/atoms.py +++ b/src/aiida_pythonjob/data/atoms.py @@ -1,9 +1,11 @@ +from typing import Iterable + import numpy as np from aiida.orm import Data from ase import Atoms from ase.db.row import atoms2dict -__all__ = ("AtomsData",) +__all__ = ("AtomsData", "Trajectory") class AtomsData(Data): @@ -54,3 +56,58 @@ def value(self): data = self.base.attributes.get_many(keys) data = dict(zip(keys, data)) return Atoms(**data) + + +class Trajectory(Data): + """Data to represent a list of ASE Atoms.""" + + _cached_traj = None + + def __init__(self, value=None, **kwargs): + """Initialise a `Trajectory` node instance. + + :param value: List of ASE Atoms to initialise the `Trajectory` node from + """ + if value and not isinstance(value, Iterable): + raise ValueError("Trajectory must be iterable") + + traj = value or [Atoms()] + super().__init__(**kwargs) + self.set_traj(traj) + + def set_traj(self, traj): + """Convert list of ASE Atoms to lists of dictionaries and keys.""" + dicts = [] + keys_list = [] + + for struct in traj: + data, keys = AtomsData.atoms2dict(struct) + dicts.append(data) + keys_list.append(keys) + + # Store list of atom-dicts and associated keys + self.base.attributes.set("traj", dicts) + self.base.attributes.set("keys_list", keys_list) + + self._cached_traj = None + + def get_traj(self): + """Reconstruct the list of ASE Atoms.""" + if self._cached_traj is not None: + return self._cached_traj + serialised = self.base.attributes.get("traj") + keys_list = self.base.attributes.get("keys_list") + + traj = [] + for data, keys in zip(serialised, keys_list): + # Pick only the real ASE constructor keys + ase_kwargs = {k: data[k] for k in keys} + traj.append(Atoms(**ase_kwargs)) + + self._cached_traj = traj + return traj + + @property + def value(self): + """Get list of atoms.""" + return self.get_traj() diff --git a/src/aiida_pythonjob/data/deserializer.py b/src/aiida_pythonjob/data/deserializer.py index 15a4013..fe3c28e 100644 --- a/src/aiida_pythonjob/data/deserializer.py +++ b/src/aiida_pythonjob/data/deserializer.py @@ -12,6 +12,7 @@ "aiida.orm.nodes.data.dict.Dict": "aiida_pythonjob.data.deserializer.dict_data_to_dict", "aiida.orm.nodes.data.array.array.ArrayData": "aiida_pythonjob.data.deserializer.array_data_to_array", "aiida.orm.nodes.data.structure.StructureData": "aiida_pythonjob.data.deserializer.structure_data_to_atoms", + "aiida.orm.nodes.data.array.trajectory.TrajectoryData": "aiida_pythonjob.data.deserializer.trajectory_data_to_atoms", } @@ -39,6 +40,10 @@ def structure_data_to_atoms(structure): return structure.get_ase() +def trajectory_data_to_atoms(trajectory): + return [trajectory.get_step_structure(i).get_ase() for i in trajectory.get_stepids()] + + def structure_data_to_pymatgen(structure): return structure.get_pymatgen() diff --git a/tests/test_serializer.py b/tests/test_serializer.py index 1bf36f1..fb76f93 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -58,3 +58,48 @@ def test_serialize_json(): serialized_data = general_serializer(data, serializers=all_serializers) assert isinstance(serialized_data, JsonableData) + + +def test_serialize_ase_atoms(): + from ase import Atoms + + from aiida_pythonjob.data.atoms import AtomsData + from aiida_pythonjob.data.serializer import general_serializer + + data = Atoms("C") + serialized_data = general_serializer(data, serializers=all_serializers) + assert isinstance(serialized_data, AtomsData) + + +def test_serialize_ase_traj(): + from ase import Atoms + + from aiida_pythonjob.data.atoms import Trajectory + from aiida_pythonjob.data.serializer import general_serializer + + data = Trajectory([Atoms("C"), Atoms("C")]) + serialized_data = general_serializer(data, serializers=all_serializers) + assert isinstance(serialized_data, Trajectory) + + +def test_deserialize_atoms(): + from ase import Atoms + + from aiida_pythonjob.data.atoms import AtomsData + from aiida_pythonjob.data.deserializer import deserialize_to_raw_python_data + + data = AtomsData(Atoms("C")) + deserialized_data = deserialize_to_raw_python_data(data) + assert isinstance(deserialized_data, Atoms) + + +def test_deserialize_trajectory(): + from ase import Atoms + + from aiida_pythonjob.data.atoms import Trajectory + from aiida_pythonjob.data.deserializer import deserialize_to_raw_python_data + + data = Trajectory([Atoms("C"), Atoms("C")]) + deserialized_data = deserialize_to_raw_python_data(data) + assert isinstance(deserialized_data, list) + assert all(isinstance(struct, Atoms) for struct in deserialized_data)