Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow custom mace model by specifying "model" in calculator kwargs" #1017

Merged
merged 16 commits into from
Oct 31, 2024
19 changes: 16 additions & 3 deletions src/atomate2/forcefields/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import json
from contextlib import contextmanager
from pathlib import Path
from typing import TYPE_CHECKING

from monty.json import MontyDecoder
Expand Down Expand Up @@ -59,9 +60,21 @@ def ase_calculator(calculator_meta: str | dict, **kwargs: Any) -> Calculator | N
calculator = PESCalculator(potential, **kwargs)

elif calculator_name == MLFF.MACE:
from mace.calculators import mace_mp

calculator = mace_mp(**kwargs)
from mace.calculators import MACECalculator, mace_mp

model = kwargs.get("model")
if isinstance(model, str | Path) and Path(model).exists():
model_path = model
device = kwargs.get("device") or "cpu"
if "device" in kwargs:
del kwargs["device"]
calculator = MACECalculator(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think we need to change this to:

device = kwargs.get("device") or "cpu"
calculator =  MACECalculator(
    model_paths=model_path,
    device = device,
    **kwargs
)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. One needs to remove the "device" key from the kwargs. Otherwise, it will complain that it is set twice. Will do that.

model_paths=model_path,
device=device,
**kwargs,
)
else:
calculator = mace_mp(**kwargs)

elif calculator_name == MLFF.GAP:
from quippy.potential import Potential
Expand Down
4 changes: 2 additions & 2 deletions tests/forcefields/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def test_mace_relax_maker(
# NOTE the test model is not trained on Si, so the energy is not accurate
job = ForceFieldRelaxMaker(
force_field_name="MACE",
calculator_kwargs={"model": model},
calculator_kwargs={"model": model, "default_dtype": "float32"},
steps=25,
optimizer_kwargs={"optimizer": "BFGSLineSearch"},
relax_cell=relax_cell,
Expand Down Expand Up @@ -308,7 +308,7 @@ def test_mace_relax_maker(

if fix_symmetry: # if symmetry is fixed, the symmetry should be the same or higher
assert is_subgroup(symmetry_ops_init, symmetry_ops_final)
else: # if symmetry is not fixed, it can both increase or decrease
else: # if symmetry is not fixed, it can both increase or decrease or stay the same
assert not is_subgroup(symmetry_ops_init, symmetry_ops_final)

if relax_cell:
Expand Down
Loading