Save model info (#336)
* Save model info

* Test singlepoint model info saved

* Fix saving model

* Change info label for consistency

* Describe output info

* Add reference to aiida mlip

* Update docs/source/user_guide/python.rst

Co-authored-by: Jacob Wilkins <46597752+oerc0122@users.noreply.github.com>

---------

Co-authored-by: Jacob Wilkins <46597752+oerc0122@users.noreply.github.com>
ElliottKasoar and oerc0122 authored Nov 14, 2024
1 parent ca1d799 commit 563742a
Showing 8 changed files with 172 additions and 21 deletions.
52 changes: 52 additions & 0 deletions README.md
@@ -226,6 +226,58 @@ Jupyter Notebook tutorials illustrating the use of currently available calculati
- [Phonons](https://colab.research.google.com/github/stfc/janus-tutorials/blob/main/phonons.ipynb)


## Calculation outputs

By default, calculations performed will modify the underlying [ase.Atoms](https://wiki.fysik.dtu.dk/ase/ase/atoms.html) object
to store information in the `Atoms.info` and `Atoms.arrays` dictionaries about the MLIP used.

Additional dictionary keys include `arch`, corresponding to the MLIP architecture used,
and `model_path`, corresponding to the model path, name or label.

Results from the MLIP calculator, which are typically stored in `Atoms.calc.results`, will also, by default,
be copied to these dictionaries, prefixed by the MLIP `arch`.

For example:

```python
from janus_core.calculations.single_point import SinglePoint

single_point = SinglePoint(
    struct_path="tests/data/NaCl.cif",
    arch="mace_mp",
    model_path="tests/models/mace_mp_small.model",
)

single_point.run()
print(single_point.struct.info)
```

will return

```python
{
    'spacegroup': Spacegroup(1, setting=1),
    'unit_cell': 'conventional',
    'occupancy': {'0': {'Na': 1.0}, '1': {'Cl': 1.0}, '2': {'Na': 1.0}, '3': {'Cl': 1.0}, '4': {'Na': 1.0}, '5': {'Cl': 1.0}, '6': {'Na': 1.0}, '7': {'Cl': 1.0}},
    'model_path': 'tests/models/mace_mp_small.model',
    'arch': 'mace_mp',
    'mace_mp_energy': -27.035127799332745,
    'mace_mp_stress': array([-4.78327600e-03, -4.78327600e-03, -4.78327600e-03, 1.08000967e-19, -2.74004242e-19, -2.04504710e-19]),
    'system_name': 'NaCl',
}
```

> [!NOTE]
> If running calculations with multiple MLIPs, `arch` and `model_path` will be overwritten with the most recent MLIP information.
> Results labelled by the architecture (e.g. `mace_mp_energy`) will be saved between MLIPs,
> unless the same `arch` is chosen, in which case these values will also be overwritten.
This is also the case for calculations performed using the CLI, with the same information written to extxyz output files.

> [!TIP]
> For complete provenance tracking, calculations and training can be run using the [aiida-mlip](https://github.com/stfc/aiida-mlip/) AiiDA plugin.
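
As a rough illustration of the note above, the sketch below mimics the labelling scheme with a plain dictionary standing in for `Atoms.info`. The `store_results` helper and the energy values are hypothetical and are not part of `janus-core`; only the key names (`arch`, `model_path`, `<arch>_energy`) follow the description above.

```python
def store_results(info, arch, model_path, results):
    """Mimic how MLIP metadata and arch-prefixed results are stored (illustrative only)."""
    # "arch" and "model_path" always describe the most recent MLIP
    info["arch"] = arch
    info["model_path"] = model_path
    # Results are labelled by architecture, e.g. "mace_mp_energy"
    for key, value in results.items():
        info[f"{arch}_{key}"] = value
    return info


info = {}
# Placeholder energies, chosen only for the example
store_results(info, "mace_mp", "small", {"energy": -27.0})
store_results(info, "chgnet", "0.3.0", {"energy": -26.5})

# Results from both MLIPs are kept, but the metadata describes only the latest
print(sorted(info))
print(info["arch"], info["model_path"])
```

Calling `store_results` again with `arch="mace_mp"` would replace `mace_mp_energy`, which corresponds to the "same `arch`" case in the note.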

## Development

We recommend installing poetry for dependency management when developing for `janus-core`:
27 changes: 27 additions & 0 deletions docs/source/user_guide/command_line.rst
@@ -118,6 +118,33 @@ This will run a singlepoint energy calculation on ``KCl.cif`` using the `MACE-MP
Example configurations for all commands can be found in `janus-tutorials <https://github.com/stfc/janus-tutorials/tree/main/configs>`_


Output files
------------

By default, calculations performed will modify the underlying `ase.Atoms <https://wiki.fysik.dtu.dk/ase/ase/atoms.html>`_ object
to store information in the ``Atoms.info`` and ``Atoms.arrays`` dictionaries about the MLIP used.

Additional dictionary keys include ``arch``, corresponding to the MLIP architecture used,
and ``model_path``, corresponding to the model path, name or label.

Results from the MLIP calculator, which are typically stored in ``Atoms.calc.results``, will also, by default,
be copied to these dictionaries, prefixed by the MLIP ``arch``.

This information is then saved when extxyz files are written. For example:

.. code-block:: bash

    janus singlepoint --struct tests/data/NaCl.cif --arch mace_mp --model-path /path/to/mace/model

This generates an output file, ``NaCl-results.extxyz``, with ``arch``, ``model_path``, ``mace_mp_energy``, ``mace_mp_forces``, and ``mace_mp_stress``.

.. note::
    If running calculations with multiple MLIPs, ``arch`` and ``model_path`` will be overwritten with the most recent MLIP information.
    Results labelled by the architecture (e.g. ``mace_mp_energy``) will be saved between MLIPs,
    unless the same ``arch`` is chosen, in which case these values will also be overwritten.


Single point calculations
-------------------------

48 changes: 48 additions & 0 deletions docs/source/user_guide/python.rst
@@ -9,3 +9,51 @@ Jupyter Notebook tutorials illustrating the use of currently available calculati
- `Molecular Dynamics <https://colab.research.google.com/github/stfc/janus-tutorials/blob/main/md.ipynb>`_
- `Equation of State <https://colab.research.google.com/github/stfc/janus-tutorials/blob/main/eos.ipynb>`_
- `Phonons <https://colab.research.google.com/github/stfc/janus-tutorials/blob/main/phonons.ipynb>`_


Calculation outputs
===================

By default, calculations performed will modify the underlying :class:`ase.Atoms` object
to store information in the ``Atoms.info`` and ``Atoms.arrays`` dictionaries about the MLIP used.

Additional dictionary keys include ``arch``, corresponding to the MLIP architecture used,
and ``model_path``, corresponding to the model path, name or label.

Results from the MLIP calculator, which are typically stored in ``Atoms.calc.results``, will also,
by default, be copied to these dictionaries, prefixed by the MLIP ``arch``.

For example:

.. code-block:: python

    from janus_core.calculations.single_point import SinglePoint

    single_point = SinglePoint(
        struct_path="tests/data/NaCl.cif",
        arch="mace_mp",
        model_path="tests/models/mace_mp_small.model",
    )

    single_point.run()
    print(single_point.struct.info)

will return

.. code-block:: python

    {
        'spacegroup': Spacegroup(1, setting=1),
        'unit_cell': 'conventional',
        'occupancy': {'0': {'Na': 1.0}, '1': {'Cl': 1.0}, '2': {'Na': 1.0}, '3': {'Cl': 1.0}, '4': {'Na': 1.0}, '5': {'Cl': 1.0}, '6': {'Na': 1.0}, '7': {'Cl': 1.0}},
        'model_path': 'tests/models/mace_mp_small.model',
        'arch': 'mace_mp',
        'mace_mp_energy': -27.035127799332745,
        'mace_mp_stress': array([-4.78327600e-03, -4.78327600e-03, -4.78327600e-03, 1.08000967e-19, -2.74004242e-19, -2.04504710e-19]),
        'system_name': 'NaCl',
    }

.. note::
    If running calculations with multiple MLIPs, ``arch`` and ``model_path`` will be overwritten with the most recent MLIP information.
    Results labelled by the architecture (e.g. ``mace_mp_energy``) will be saved between MLIPs,
    unless the same ``arch`` is chosen, in which case these values will also be overwritten.

36 changes: 19 additions & 17 deletions janus_core/helpers/mlip_calculators.py
@@ -126,20 +126,20 @@ def choose_calculator(
         from mace.calculators import mace_mp

         # Default to "small" model and float64 precision
-        model = model_path if model_path else "small"
+        model_path = model_path if model_path else "small"
         kwargs.setdefault("default_dtype", "float64")

-        calculator = mace_mp(model=model, device=device, **kwargs)
+        calculator = mace_mp(model=model_path, device=device, **kwargs)

     elif arch == "mace_off":
         from mace import __version__
         from mace.calculators import mace_off

         # Default to "small" model and float64 precision
-        model = model_path if model_path else "small"
+        model_path = model_path if model_path else "small"
         kwargs.setdefault("default_dtype", "float64")

-        calculator = mace_off(model=model, device=device, **kwargs)
+        calculator = mace_off(model=model_path, device=device, **kwargs)

     elif arch == "m3gnet":
         from matgl import __version__, load_model
@@ -155,14 +155,16 @@ def choose_calculator(
         # Otherwise, load the model if given a path, else use a default model
         if isinstance(model_path, Potential):
             potential = model_path
+            model_path = "loaded_Potential"
         elif isinstance(model_path, Path):
             if model_path.is_file():
                 model_path = model_path.parent
             potential = load_model(model_path)
         elif isinstance(model_path, str):
             potential = load_model(model_path)
         else:
-            potential = load_model("M3GNet-MP-2021.2.8-DIRECT-PES")
+            model_path = "M3GNet-MP-2021.2.8-DIRECT-PES"
+            potential = load_model(model_path)

         calculator = M3GNetCalculator(potential=potential, **kwargs)

@@ -179,11 +181,13 @@ def choose_calculator(
         # Otherwise, load the model if given a path, else use a default model
         if isinstance(model_path, CHGNet):
             model = model_path
+            model_path = "loaded_CHGNet"
         elif isinstance(model_path, Path):
             model = CHGNet.from_file(model_path)
         elif isinstance(model_path, str):
             model = CHGNet.load(model_name=model_path, use_device=device)
         else:
+            model_path = "0.3.0"
             model = None

         calculator = CHGNetCalculator(model=model, use_device=device, **kwargs)
@@ -198,31 +202,28 @@ def choose_calculator(

         # Set default path to directory containing config and model location
         if isinstance(model_path, Path):
-            path = model_path
-            if path.is_file():
-                path = path.parent
+            if model_path.is_file():
+                model_path = model_path.parent
         # If a string, assume referring to model_name e.g. "v5.27.2024"
         elif isinstance(model_path, str):
-            path = get_figshare_model_ff(model_name=model_path)
+            model_path = get_figshare_model_ff(model_name=model_path)
         else:
-            path = default_path()
+            model_path = default_path()

-        calculator = AlignnAtomwiseCalculator(path=path, device=device, **kwargs)
+        calculator = AlignnAtomwiseCalculator(path=model_path, device=device, **kwargs)

     elif arch == "sevennet":
         from sevenn import __version__
         from sevenn.sevennet_calculator import SevenNetCalculator

         if isinstance(model_path, Path):
-            model = str(model_path)
-        elif isinstance(model_path, str):
-            model = model_path
-        else:
-            model = "SevenNet-0_11July2024"
+            model_path = str(model_path)
+        elif not isinstance(model_path, str):
+            model_path = "SevenNet-0_11July2024"

         kwargs.setdefault("file_type", "checkpoint")
         kwargs.setdefault("sevennet_config", None)
-        calculator = SevenNetCalculator(model=model, device=device, **kwargs)
+        calculator = SevenNetCalculator(model=model_path, device=device, **kwargs)

     else:
         raise ValueError(
@@ -232,6 +233,7 @@ def choose_calculator(

     calculator.parameters["version"] = __version__
     calculator.parameters["arch"] = arch
+    calculator.parameters["model_path"] = str(model_path)

     return calculator
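
Note on the pattern in this file: each branch of `choose_calculator` now reassigns `model_path` itself, so a single `str(model_path)` can be stored in `calculator.parameters` at the end. A minimal, dependency-free sketch of the `sevennet` branch is given below; the function name `resolve_sevennet_label` is hypothetical and does not exist in `janus-core`.

```python
from pathlib import Path


def resolve_sevennet_label(model_path=None):
    """Return the string that would be recorded as ``model_path`` (hypothetical helper)."""
    if isinstance(model_path, Path):
        # Paths are converted to strings before being passed on
        model_path = str(model_path)
    elif not isinstance(model_path, str):
        # Anything else (e.g. None) falls back to the default model name
        model_path = "SevenNet-0_11July2024"
    return model_path


print(resolve_sevennet_label())
print(resolve_sevennet_label("my-model"))
print(resolve_sevennet_label(Path("checkpoint.pth")))
```

Reusing one variable avoids tracking a separate `model` or `path` name per branch, at the cost of `model_path` changing type within the function.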

5 changes: 5 additions & 0 deletions janus_core/helpers/struct_io.py
@@ -47,6 +47,9 @@ def results_to_info(
     if not properties:
         properties = get_args(Properties)

+    if struct.calc and "model_path" in struct.calc.parameters:
+        struct.info["model_path"] = struct.calc.parameters["model_path"]
+
     # Only add to info if MLIP calculator with "arch" parameter set
     if struct.calc and "arch" in struct.calc.parameters:
         arch = struct.calc.parameters["arch"]
@@ -268,6 +271,8 @@ def output_structs(
     for image in images:
         if image.calc and "arch" in image.calc.parameters:
             image.info["arch"] = image.calc.parameters["arch"]
+        if image.calc and "model_path" in image.calc.parameters:
+            image.info["model_path"] = image.calc.parameters["model_path"]

     # Add label for system
     for image in images:
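
The two additions in this file follow the same guard-then-copy pattern already used for `arch`. A minimal sketch of that pattern is shown below, assuming `types.SimpleNamespace` objects as stand-ins for an `ase.Atoms` structure and its calculator; the helper `copy_model_info` is hypothetical and is not part of `janus-core`.

```python
from types import SimpleNamespace


def copy_model_info(struct):
    """Copy ``arch`` and ``model_path`` from calculator parameters to ``info``."""
    for key in ("arch", "model_path"):
        # Skip structures without a calculator or without the parameter set
        if struct.calc and key in struct.calc.parameters:
            struct.info[key] = struct.calc.parameters[key]


# Stand-ins for a structure with and without an attached calculator (assumption)
calc = SimpleNamespace(parameters={"arch": "mace_mp", "model_path": "small"})
struct = SimpleNamespace(calc=calc, info={})
no_calc = SimpleNamespace(calc=None, info={})

copy_model_info(struct)
copy_model_info(no_calc)
print(struct.info)
print(no_calc.info)
```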
2 changes: 2 additions & 0 deletions tests/test_mlip_calculators.py
@@ -58,6 +58,7 @@ def test_mlips(arch, device, kwargs):
     """Test mace calculators can be configured."""
     calculator = choose_calculator(arch=arch, device=device, **kwargs)
     assert calculator.parameters["version"] is not None
+    assert calculator.parameters["model_path"] is not None


 def test_invalid_arch():
@@ -129,6 +130,7 @@ def test_extra_mlips(arch, device, kwargs):
             **kwargs,
         )
         assert calculator.parameters["version"] is not None
+        assert calculator.parameters["model_path"] is not None
     except BadZipFile:
         pytest.skip()

14 changes: 10 additions & 4 deletions tests/test_singlepoint_cli.py
@@ -55,17 +55,23 @@ def test_singlepoint():
         assert summary_path.exists

     finally:
-        # Check atoms can read read, then delete file
+        # Ensure files deleted if command fails
+        log_path.unlink(missing_ok=True)
+        summary_path.unlink(missing_ok=True)
+
+        # Check atoms file can be read, then delete
         atoms = read_atoms(results_path)
         assert "mace_mp_energy" in atoms.info

+        assert "arch" in atoms.info
+        assert "model_path" in atoms.info
+        assert atoms.info["arch"] == "mace_mp"
+        assert atoms.info["model_path"] == "small"
+
         assert "mace_mp_forces" in atoms.arrays
         assert "system_name" in atoms.info
         assert atoms.info["system_name"] == "NaCl"

-        # Ensure files deleted if command fails
-        log_path.unlink(missing_ok=True)
-        summary_path.unlink(missing_ok=True)
         clear_log_handlers()


9 changes: 9 additions & 0 deletions tests/test_utils.py
@@ -79,6 +79,13 @@ def test_output_structs(
     else:
         results_keys = {"energy", "forces", "stress"}

+    if arch == "mace_mp":
+        model = "small"
+    if arch == "m3gnet":
+        model = "M3GNet-MP-2021.2.8-DIRECT-PES"
+    if arch == "chgnet":
+        model = "0.3.0"
+
     label_keys = {f"{arch}_{key}" for key in results_keys}

     write_kwargs = {}
@@ -117,6 +124,7 @@ def test_output_structs(
     if "set_info" not in write_kwargs or write_kwargs["set_info"]:
         assert label_keys <= struct.info.keys() | struct.arrays.keys()
         assert struct.info["arch"] == arch
+        assert struct.info["model_path"] == model

     # Check file written correctly if write_results
     if write_results:
@@ -128,6 +136,7 @@ def test_output_structs(
         if "set_info" not in write_kwargs or write_kwargs["set_info"]:
             assert label_keys <= atoms.info.keys() | atoms.arrays.keys()
             assert atoms.info["arch"] == arch
+            assert atoms.info["model_path"] == model

     # Check calculator results depend on invalidate_calc
     if invalidate_calc: