Skip to content

Commit

Permalink
Config fix (#737)
Browse files Browse the repository at this point in the history
* config fix

* config compatibility in ase calc

* check if eval_metrics is present

* handle existing empty loss_functions dicts
  • Loading branch information
lbluque authored Jun 21, 2024
1 parent fa23491 commit 0866211
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 14 deletions.
8 changes: 1 addition & 7 deletions src/fairchem/core/common/relaxation/ase_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,7 @@ def __init__(
config["model"]["otf_graph"] = True

### backwards compatability with OCP v<2.0
### TODO: better format check for older configs
### Taken from base_trainer
if not config.get("loss_functions"):
logging.warning(
"Detected old config, converting to new format. Consider updating to avoid potential incompatibilities."
)
config = update_config(config)
config = update_config(config)

# Save config so obj can be transported over network (pkl)
self.config = copy.deepcopy(config)
Expand Down
20 changes: 19 additions & 1 deletion src/fairchem/core/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,11 +1176,29 @@ def irreps_sum(ang_mom: int) -> int:

def update_config(base_config):
"""
Configs created prior to OCP 2.0 are organized a little different than they
Configs created prior to FAIRChem/OCP 2.0 are organized a little different than they
are now. Update old configs to fit the new expected structure.
"""
### TODO: better format check for older configs
# some configs have a loss_functions key with an empty dictionary, those need to be updated as well
if len(base_config.get("loss_functions", {})) > 0:
return base_config

logging.warning(
"Detected old config, converting to new format. Consider updating to avoid potential incompatibilities."
)

# do we need a copy?
config = copy.deepcopy(base_config)

# initial fairchem/ocp 2.0 configs renamed loss_functions -> loss_fns and evaluation_metrics -> eval_metrics
# so some checkpoints may have configs in new format with the exception of renamed loss_funs and eval_metrics
if "loss_fns" in config:
config["loss_functions"] = config.pop("loss_fns")
if "eval_metrics" in config:
config["evaluation_metrics"] = config.pop("eval_metrics")
return config

# If config["dataset"]["format"] is missing, get it from the task (legacy location).
# If it is not there either, default to LMDB.
config["dataset"]["format"] = config["dataset"].get(
Expand Down
7 changes: 1 addition & 6 deletions src/fairchem/core/trainers/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,7 @@ def __init__(
os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True)

### backwards compatability with OCP v<2.0
### TODO: better format check for older configs
if not self.config.get("loss_functions"):
logging.warning(
"Detected old config, converting to new format. Consider updating to avoid potential incompatibilities."
)
self.config = update_config(self.config)
self.config = update_config(self.config)

if distutils.is_master():
logging.info(yaml.dump(self.config, default_flow_style=False))
Expand Down

0 comments on commit 0866211

Please sign in to comment.