From 5e3ee755694be69187efb12ea846694911971783 Mon Sep 17 00:00:00 2001 From: Misko Date: Tue, 19 Nov 2024 20:00:54 +0000 Subject: [PATCH] add optional field to calculator to output only requested fix lint undo change to packages undo change to packages --- .../core/common/relaxation/ase_utils.py | 15 +++++++++++++ src/fairchem/core/trainers/ocp_trainer.py | 3 ++- .../__snapshots__/test_ase_calculator.ambr | 3 +++ tests/core/common/test_ase_calculator.py | 22 +++++++++++++++++++ 4 files changed, 42 insertions(+), 1 deletion(-) diff --git a/src/fairchem/core/common/relaxation/ase_utils.py b/src/fairchem/core/common/relaxation/ase_utils.py index 2dacce2cb7..237f35b592 100644 --- a/src/fairchem/core/common/relaxation/ase_utils.py +++ b/src/fairchem/core/common/relaxation/ase_utils.py @@ -80,6 +80,7 @@ def __init__( max_neighbors: int = 50, cpu: bool = True, seed: int | None = None, + only_output: list[str] | None = None, ) -> None: """ OCP-ASE Calculator @@ -170,6 +171,20 @@ def __init__( self.config["checkpoint"] = checkpoint_path del config["dataset"]["src"] + # some models that are published have configs that include tasks + # which are not output by the model + if only_output is not None: + assert isinstance( + only_output, list + ), "only output must be a list of targets to output" + for key in only_output: + assert ( + key in config["outputs"] + ), f"{key} listed in only_outputs is not present in current model outputs {config['outputs'].keys()}" + remove_outputs = set(config["outputs"].keys()) - set(only_output) + for key in remove_outputs: + config["outputs"].pop(key) + self.trainer = registry.get_trainer_class(config["trainer"])( task=config.get("task", {}), model=config["model"], diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py index 9a13faed69..152cddd93b 100644 --- a/src/fairchem/core/trainers/ocp_trainer.py +++ b/src/fairchem/core/trainers/ocp_trainer.py @@ -291,7 +291,8 @@ def _forward(self, batch): ) else: raise AttributeError( - f"Output target: '{target_key}', not found in model outputs: {list(out.keys())}" + f"Output target: '{target_key}', not found in model outputs: {list(out.keys())}\n" + + "If this is being called from OCPCalculator consider using only_output=[..]" ) ### not all models are consistent with the output shape diff --git a/tests/core/common/__snapshots__/test_ase_calculator.ambr b/tests/core/common/__snapshots__/test_ase_calculator.ambr index 71435c0ebd..49d8fb3e07 100644 --- a/tests/core/common/__snapshots__/test_ase_calculator.ambr +++ b/tests/core/common/__snapshots__/test_ase_calculator.ambr @@ -1,4 +1,7 @@ # serializer version: 1 +# name: test_energy_with_is2re_model + 1.09 +# --- # name: test_relaxation_final_energy 0.92 # --- diff --git a/tests/core/common/test_ase_calculator.py b/tests/core/common/test_ase_calculator.py index 3d62c35e1a..d59a506084 100644 --- a/tests/core/common/test_ase_calculator.py +++ b/tests/core/common/test_ase_calculator.py @@ -38,6 +38,7 @@ def atoms() -> Atoms: "PaiNN-S2EF-OC20-All", "GemNet-OC-Large-S2EF-OC20-All+MD", "SCN-S2EF-OC20-All+MD", + "PaiNN-IS2RE-OC20-All", # Equiformer v2 # already tested in test_relaxation_final_energy # "EquiformerV2-153M-S2EF-OC20-All+MD" # eSCNm # already tested in test_random_seed_final_energy @@ -54,6 +55,27 @@ def test_calculator_setup(checkpoint_path): _ = OCPCalculator(checkpoint_path=checkpoint_path, cpu=True) +def test_energy_with_is2re_model(atoms, tmp_path, snapshot): + random.seed(1) + torch.manual_seed(1) + + with pytest.raises(AttributeError): # noqa + calc = OCPCalculator( + checkpoint_path=model_name_to_local_file("PaiNN-IS2RE-OC20-All", tmp_path), + cpu=True, + ) + atoms.set_calculator(calc) + atoms.get_potential_energy() + + calc = OCPCalculator( + checkpoint_path=model_name_to_local_file("PaiNN-IS2RE-OC20-All", tmp_path), + cpu=True, + only_output=["energy"], + ) + atoms.set_calculator(calc) + assert snapshot == round(atoms.get_potential_energy(), 2) + + # test relaxation with EqV2 def test_relaxation_final_energy(atoms, tmp_path, snapshot) -> None: random.seed(1)