Skip to content

Commit

Permalink
add optional field to calculator to output only requested
Browse files Browse the repository at this point in the history
fix lint

undo change to packages

undo change to packages
  • Loading branch information
misko committed Nov 21, 2024
1 parent aa298ac commit 5e3ee75
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 1 deletion.
15 changes: 15 additions & 0 deletions src/fairchem/core/common/relaxation/ase_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(
max_neighbors: int = 50,
cpu: bool = True,
seed: int | None = None,
only_output: list[str] | None = None,
) -> None:
"""
OCP-ASE Calculator
Expand Down Expand Up @@ -170,6 +171,20 @@ def __init__(
self.config["checkpoint"] = checkpoint_path
del config["dataset"]["src"]

# some models that are published have configs that include tasks
# which are not output by the model
if only_output is not None:
assert isinstance(
only_output, list
), "only output must be a list of targets to output"
for key in only_output:
assert (
key in config["outputs"]
), f"{key} listed in only_outputs is not present in current model outputs {config['outputs'].keys()}"
remove_outputs = set(config["outputs"].keys()) - set(only_output)
for key in remove_outputs:
config["outputs"].pop(key)

self.trainer = registry.get_trainer_class(config["trainer"])(
task=config.get("task", {}),
model=config["model"],
Expand Down
3 changes: 2 additions & 1 deletion src/fairchem/core/trainers/ocp_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,8 @@ def _forward(self, batch):
)
else:
raise AttributeError(
f"Output target: '{target_key}', not found in model outputs: {list(out.keys())}"
f"Output target: '{target_key}', not found in model outputs: {list(out.keys())}\n"
+ "If this is being called from OCPCalculator consider using only_output=[..]"
)

### not all models are consistent with the output shape
Expand Down
3 changes: 3 additions & 0 deletions tests/core/common/__snapshots__/test_ase_calculator.ambr
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# serializer version: 1
# name: test_energy_with_is2re_model
1.09
# ---
# name: test_relaxation_final_energy
0.92
# ---
22 changes: 22 additions & 0 deletions tests/core/common/test_ase_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def atoms() -> Atoms:
"PaiNN-S2EF-OC20-All",
"GemNet-OC-Large-S2EF-OC20-All+MD",
"SCN-S2EF-OC20-All+MD",
"PaiNN-IS2RE-OC20-All",
# Equiformer v2 # already tested in test_relaxation_final_energy
# "EquiformerV2-153M-S2EF-OC20-All+MD"
# eSCNm # already tested in test_random_seed_final_energy
Expand All @@ -54,6 +55,27 @@ def test_calculator_setup(checkpoint_path):
_ = OCPCalculator(checkpoint_path=checkpoint_path, cpu=True)


def test_energy_with_is2re_model(atoms, tmp_path, snapshot):
random.seed(1)
torch.manual_seed(1)

with pytest.raises(AttributeError): # noqa
calc = OCPCalculator(
checkpoint_path=model_name_to_local_file("PaiNN-IS2RE-OC20-All", tmp_path),
cpu=True,
)
atoms.set_calculator(calc)
atoms.get_potential_energy()

calc = OCPCalculator(
checkpoint_path=model_name_to_local_file("PaiNN-IS2RE-OC20-All", tmp_path),
cpu=True,
only_output=["energy"],
)
atoms.set_calculator(calc)
assert snapshot == round(atoms.get_potential_energy(), 2)


# test relaxation with EqV2
def test_relaxation_final_energy(atoms, tmp_path, snapshot) -> None:
random.seed(1)
Expand Down

0 comments on commit 5e3ee75

Please sign in to comment.