From 50e41d911ead8fe272892cb834aae9b444a56e72 Mon Sep 17 00:00:00 2001 From: Adeesh Kolluru Date: Thu, 14 Oct 2021 15:34:50 +0000 Subject: [PATCH] changed default trainer to energy --- main.py | 2 +- ocpmodels/common/relaxation/ase_utils.py | 2 +- ocpmodels/models/gemnet/fit_scaling.py | 6 +++--- scripts/hpo/run_tune.py | 2 +- scripts/hpo/run_tune_pbt.py | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/main.py b/main.py index d49263d65..55ef3ed88 100644 --- a/main.py +++ b/main.py @@ -40,7 +40,7 @@ def __call__(self, config): try: setup_imports() self.trainer = registry.get_trainer_class( - config.get("trainer", "simple") + config.get("trainer", "energy") )( task=config["task"], model=config["model"], diff --git a/ocpmodels/common/relaxation/ase_utils.py b/ocpmodels/common/relaxation/ase_utils.py index 7c64cac06..fd097384d 100644 --- a/ocpmodels/common/relaxation/ase_utils.py +++ b/ocpmodels/common/relaxation/ase_utils.py @@ -95,7 +95,7 @@ def __init__( self.config["checkpoint"] = checkpoint self.trainer = registry.get_trainer_class( - config.get("trainer", "simple") + config.get("trainer", "energy") )( task=config["task"], model=config["model"], diff --git a/ocpmodels/models/gemnet/fit_scaling.py b/ocpmodels/models/gemnet/fit_scaling.py index 8ddafcc1f..97badfd71 100644 --- a/ocpmodels/models/gemnet/fit_scaling.py +++ b/ocpmodels/models/gemnet/fit_scaling.py @@ -20,11 +20,11 @@ import torch from tqdm import trange -from ocpmodels.models.gemnet.layers.scaling import AutomaticFit -from ocpmodels.models.gemnet.utils import write_json from ocpmodels.common.flags import flags from ocpmodels.common.registry import registry from ocpmodels.common.utils import build_config, setup_imports, setup_logging +from ocpmodels.models.gemnet.layers.scaling import AutomaticFit +from ocpmodels.models.gemnet.utils import write_json if __name__ == "__main__": setup_logging() @@ -74,7 +74,7 @@ def initialize_scale_file(scale_file): AutomaticFit.set2fitmode() - trainer = registry.get_trainer_class(config.get("trainer", "simple"))( + trainer = registry.get_trainer_class(config.get("trainer", "energy"))( task=config["task"], model=config["model"], dataset=config["dataset"], diff --git a/scripts/hpo/run_tune.py b/scripts/hpo/run_tune.py index 656b02f52..2d3854c7b 100644 --- a/scripts/hpo/run_tune.py +++ b/scripts/hpo/run_tune.py @@ -14,7 +14,7 @@ def ocp_trainable(config, checkpoint_dir=None): setup_imports() # trainer defaults are changed to run HPO - trainer = registry.get_trainer_class(config.get("trainer", "simple"))( + trainer = registry.get_trainer_class(config.get("trainer", "energy"))( task=config["task"], model=config["model"], dataset=config["dataset"], diff --git a/scripts/hpo/run_tune_pbt.py b/scripts/hpo/run_tune_pbt.py index 8ea2dacb4..c70d72cfc 100644 --- a/scripts/hpo/run_tune_pbt.py +++ b/scripts/hpo/run_tune_pbt.py @@ -17,7 +17,7 @@ def ocp_trainable(config, checkpoint_dir=None): # update config for PBT learning rate config["optim"].update(lr_initial=config["lr"]) # trainer defaults are changed to run HPO - trainer = registry.get_trainer_class(config.get("trainer", "simple"))( + trainer = registry.get_trainer_class(config.get("trainer", "energy"))( task=config["task"], model=config["model"], dataset=config["dataset"],