Skip to content

Commit

Permalink
changed default trainer to energy (#298)
Browse files Browse the repository at this point in the history
Co-authored-by: Muhammed Shuaibi <45150244+mshuaibii@users.noreply.github.com>
  • Loading branch information
AdeeshKolluru and mshuaibii authored Oct 18, 2021
1 parent 1044e31 commit 458da25
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __call__(self, config):
try:
setup_imports()
self.trainer = registry.get_trainer_class(
config.get("trainer", "simple")
config.get("trainer", "energy")
)(
task=config["task"],
model=config["model"],
Expand Down
2 changes: 1 addition & 1 deletion ocpmodels/common/relaxation/ase_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(
self.config["checkpoint"] = checkpoint

self.trainer = registry.get_trainer_class(
config.get("trainer", "simple")
config.get("trainer", "energy")
)(
task=config["task"],
model=config["model"],
Expand Down
6 changes: 3 additions & 3 deletions ocpmodels/models/gemnet/fit_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
import torch
from tqdm import trange

from ocpmodels.models.gemnet.layers.scaling import AutomaticFit
from ocpmodels.models.gemnet.utils import write_json
from ocpmodels.common.flags import flags
from ocpmodels.common.registry import registry
from ocpmodels.common.utils import build_config, setup_imports, setup_logging
from ocpmodels.models.gemnet.layers.scaling import AutomaticFit
from ocpmodels.models.gemnet.utils import write_json

if __name__ == "__main__":
setup_logging()
Expand Down Expand Up @@ -74,7 +74,7 @@ def initialize_scale_file(scale_file):

AutomaticFit.set2fitmode()

trainer = registry.get_trainer_class(config.get("trainer", "simple"))(
trainer = registry.get_trainer_class(config.get("trainer", "energy"))(
task=config["task"],
model=config["model"],
dataset=config["dataset"],
Expand Down
2 changes: 1 addition & 1 deletion scripts/hpo/run_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
def ocp_trainable(config, checkpoint_dir=None):
setup_imports()
# trainer defaults are changed to run HPO
trainer = registry.get_trainer_class(config.get("trainer", "simple"))(
trainer = registry.get_trainer_class(config.get("trainer", "energy"))(
task=config["task"],
model=config["model"],
dataset=config["dataset"],
Expand Down
2 changes: 1 addition & 1 deletion scripts/hpo/run_tune_pbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def ocp_trainable(config, checkpoint_dir=None):
# update config for PBT learning rate
config["optim"].update(lr_initial=config["lr"])
# trainer defaults are changed to run HPO
trainer = registry.get_trainer_class(config.get("trainer", "simple"))(
trainer = registry.get_trainer_class(config.get("trainer", "energy"))(
task=config["task"],
model=config["model"],
dataset=config["dataset"],
Expand Down

0 comments on commit 458da25

Please sign in to comment.