diff --git a/mlip_arena/tasks/md.py b/mlip_arena/tasks/md.py index df6b122..7eb1633 100644 --- a/mlip_arena/tasks/md.py +++ b/mlip_arena/tasks/md.py @@ -196,10 +196,10 @@ def run( device: str | None = None, ensemble: Literal["nve", "nvt", "npt"] = "nvt", dynamics: str | MolecularDynamics = "langevin", - time_step: float | None = None, - total_time: float = 1000, - temperature: float | Sequence | np.ndarray | None = 300.0, - pressure: float | Sequence | np.ndarray | None = None, + time_step: float | None = None, # fs + total_time: float = 1000, # fs + temperature: float | Sequence | np.ndarray | None = 300.0, # K + pressure: float | Sequence | np.ndarray | None = None, # eV/A^3 ase_md_kwargs: dict | None = None, md_velocity_seed: int | None = None, zero_linear_momentum: bool = True, @@ -363,6 +363,7 @@ def _callback(dyn: MolecularDynamics = md_runner) -> None: md_runner.run(steps=n_steps) end_time = datetime.now() + if traj_file is not None: traj.close() return { diff --git a/tests/test_eos.py b/tests/test_eos.py index ba50415..91d6857 100644 --- a/tests/test_eos.py +++ b/tests/test_eos.py @@ -9,11 +9,6 @@ atoms = bulk("Cu", "fcc", a=3.6) -# @pytest.fixture(autouse=True, scope="session") -# def prefect_test_fixture(): -# with prefect_test_harness(): -# yield - @pytest.mark.skipif(sys.version_info[:2] != (3,11), reason="avoid prefect race condition on concurrent tasks") @pytest.mark.parametrize("model", [MLIPEnum["MACE-MP(M)"]]) def test_eos(model: MLIPEnum): diff --git a/tests/test_md.py b/tests/test_md.py new file mode 100644 index 0000000..f5751b8 --- /dev/null +++ b/tests/test_md.py @@ -0,0 +1,25 @@ + +import sys + +import pytest +from ase.build import bulk + +from mlip_arena.models import MLIPEnum +from mlip_arena.tasks.md import run as MD + +atoms = bulk("Cu", "fcc", a=3.6) + +@pytest.mark.skipif(sys.version_info[:2] != (3,11), reason="avoid prefect race condition on concurrent tasks") +@pytest.mark.parametrize("model", [MLIPEnum["MACE-MP(M)"]]) +def test_nve(model: MLIPEnum): + + result = MD.fn( + atoms, + calculator_name=model.name, + calculator_kwargs={}, + ensemble="nve", + dynamics="velocityverlet", + total_time=3, + ) + + assert isinstance(result["atoms"].get_potential_energy(), float)