Skip to content

Commit

Permalink
test md
Browse files Browse the repository at this point in the history
  • Loading branch information
chiang-yuan committed Oct 21, 2024
1 parent dd24ea1 commit 4817e63
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 9 deletions.
9 changes: 5 additions & 4 deletions mlip_arena/tasks/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,10 @@ def run(
device: str | None = None,
ensemble: Literal["nve", "nvt", "npt"] = "nvt",
dynamics: str | MolecularDynamics = "langevin",
time_step: float | None = None,
total_time: float = 1000,
temperature: float | Sequence | np.ndarray | None = 300.0,
pressure: float | Sequence | np.ndarray | None = None,
time_step: float | None = None, # fs
total_time: float = 1000, # fs
temperature: float | Sequence | np.ndarray | None = 300.0, # K
pressure: float | Sequence | np.ndarray | None = None, # eV/A^3
ase_md_kwargs: dict | None = None,
md_velocity_seed: int | None = None,
zero_linear_momentum: bool = True,
Expand Down Expand Up @@ -363,6 +363,7 @@ def _callback(dyn: MolecularDynamics = md_runner) -> None:
md_runner.run(steps=n_steps)
end_time = datetime.now()

if traj_file is not None:
traj.close()

return {
Expand Down
5 changes: 0 additions & 5 deletions tests/test_eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@

atoms = bulk("Cu", "fcc", a=3.6)

# @pytest.fixture(autouse=True, scope="session")
# def prefect_test_fixture():
# with prefect_test_harness():
# yield

@pytest.mark.skipif(sys.version_info[:2] != (3,11), reason="avoid prefect race condition on concurrent tasks")
@pytest.mark.parametrize("model", [MLIPEnum["MACE-MP(M)"]])
def test_eos(model: MLIPEnum):
Expand Down
25 changes: 25 additions & 0 deletions tests/test_md.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@

import sys

import pytest
from ase.build import bulk

from mlip_arena.models import MLIPEnum
from mlip_arena.tasks.md import run as MD

atoms = bulk("Cu", "fcc", a=3.6)

@pytest.mark.skipif(sys.version_info[:2] != (3,11), reason="avoid prefect race condition on concurrent tasks")
@pytest.mark.parametrize("model", [MLIPEnum["MACE-MP(M)"]])
def test_nve(model: MLIPEnum):

result = MD.fn(
atoms,
calculator_name=model.name,
calculator_kwargs={},
ensemble="nve",
dynamics="velocityverlet",
total_time=3,
)

assert isinstance(result["atoms"].get_potential_energy(), float)

0 comments on commit 4817e63

Please sign in to comment.