Skip to content

Commit

Permalink
add ReadData
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed Dec 11, 2023
1 parent 83d0ed2 commit bbeed50
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 2 deletions.
4 changes: 2 additions & 2 deletions ipsuite/data_loading/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""ipsuite data loading module."""

from ipsuite.data_loading.add_data_ase import AddData
from ipsuite.data_loading.add_data_ase import AddData, ReadData
from ipsuite.data_loading.add_data_h5md import AddDataH5MD

__all__ = ["AddData", "AddDataH5MD"]
__all__ = ["AddData", "AddDataH5MD", "ReadData"]
27 changes: 27 additions & 0 deletions ipsuite/data_loading/add_data_ase.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""ipsuite data loading with ASE."""

import functools
import logging
import pathlib
import typing
Expand Down Expand Up @@ -41,6 +42,32 @@ def load_data(
return atoms


class ReadData(base.IPSNode):
"""Read data without converting it to H5MD.
This Node can be used instead of `AddData` to avoid
initial conversion to H5MD. Later Nodes might still
convert the data to H5MD.
Attributes
----------
file: str|Path
path to the file that should be read.
lines_to_read: int, optional
maximal number of lines/configurations to read, None for read all
"""

file: typing.Union[str, pathlib.Path] = zntrack.deps_path()
lines_to_read: int = zntrack.params(None)

def run(self):
pass

@functools.cached_property
def atoms(self) -> typing.List[ase.Atoms]:
return load_data(self.file, self.lines_to_read)


class AddData(base.IPSNode):
"""Add data using ASE.
Expand Down
1 change: 1 addition & 0 deletions ipsuite/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class _Nodes:
# Data
AddData = "ipsuite.data_loading.AddData"
AddDataH5MD = "ipsuite.data_loading.AddDataH5MD"
ReadData = "ipsuite.data_loading.ReadData"

# Bootstrap
RattleAtoms = "ipsuite.bootstrap.RattleAtoms"
Expand Down
7 changes: 7 additions & 0 deletions tests/integration/test_AddData.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,21 @@ def test_AddData(proj_path, traj_file, atoms_list, eager):
subprocess.check_call(["dvc", "add", traj_file.name])
with ipsuite.Project() as project:
data = ipsuite.AddData(file=traj_file.name)
data2 = ipsuite.data_loading.ReadData(file=traj_file.name)

project.run(eager=eager)
if not eager:
data.load()
data2.load()

assert isinstance(data.atoms, list)
assert isinstance(data.atoms[0], ase.Atoms)

assert isinstance(data2.atoms, list)
assert isinstance(data2.atoms[0], ase.Atoms)

assert data.atoms == data2.atoms

for loaded, given in zip(data.atoms[:], atoms_list):
# Check that the atoms match
assert loaded.get_potential_energy() == given.get_potential_energy()
Expand Down

0 comments on commit bbeed50

Please sign in to comment.