Add inference_only option to trainers #850

Merged: 9 commits, Sep 13, 2024
Changes from 7 commits
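In short, `inference_only=True` builds a trainer that skips dataset, normalizer, loss, and optimizer setup, which is what checkpoint-based prediction paths such as the ASE calculator need. A minimal sketch of how the new flag might be used; the config dicts and checkpoint path below are placeholders, not values prescribed by this PR:

```python
from fairchem.core.trainers import OCPTrainer

model_config = {}    # stand-in: a real model config is still required
outputs_config = {}  # stand-in: a real outputs config is still required

# With inference_only=True the dataset/optimizer/loss sections can stay empty,
# since load() skips them (see base_trainer.py below).
trainer = OCPTrainer(
    task={},
    model=model_config,
    outputs=outputs_config,
    dataset={},
    optimizer={},
    loss_functions={},
    evaluation_metrics={},
    identifier="inference-only-demo",
    local_rank=0,
    is_debug=True,
    inference_only=True,
)
trainer.load_checkpoint("path/to/checkpoint.pt")  # placeholder path
```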
5 changes: 4 additions & 1 deletion src/fairchem/core/common/data_parallel.py
@@ -16,7 +16,7 @@
import torch
import torch.distributed
from torch.utils.data import BatchSampler, Dataset, DistributedSampler
from typing_extensions import override
from typing_extensions import deprecated, override

from fairchem.core.common import distutils, gp_utils
from fairchem.core.datasets import data_list_collater
@@ -29,6 +29,9 @@
from torch_geometric.data import Batch, Data


@deprecated(
"OCPColatter is deprecated. Please use data_list_collater optionally with functools.partial to set defaults"
)
class OCPCollater:
def __init__(self, otf_graph: bool = False) -> None:
self.otf_graph = otf_graph
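As the new deprecation message suggests, the collater can now be built by binding defaults onto `data_list_collater` with `functools.partial`, mirroring the change in `base_trainer.py` below. A sketch; the `DataLoader` wiring and `dataset` are illustrative assumptions:

```python
from functools import partial

from torch.utils.data import DataLoader

from fairchem.core.datasets import data_list_collater

# Equivalent of the deprecated OCPCollater(otf_graph=True)
collate_fn = partial(data_list_collater, otf_graph=True)

# "dataset" is assumed to be a Dataset yielding torch_geometric Data objects
loader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn)
```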
6 changes: 1 addition & 5 deletions src/fairchem/core/common/relaxation/ase_utils.py
@@ -159,11 +159,6 @@ def __init__(
config["model_attributes"]["name"] = config.pop("model")
config["model"] = config["model_attributes"]

# for checkpoints with relaxation datasets defined, remove to avoid
# unnecesarily trying to load that dataset
if "relax_dataset" in config.get("task", {}):
del config["task"]["relax_dataset"]

# Calculate the edge indices on the fly
config["model"]["otf_graph"] = True

@@ -189,6 +184,7 @@ def __init__(
is_debug=config.get("is_debug", True),
cpu=cpu,
amp=config.get("amp", False),
inference_only=True,
)

if checkpoint_path is not None:
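With this change the ASE calculator builds its trainer with `inference_only=True` instead of pruning `relax_dataset` from the task config. A hedged usage sketch; the checkpoint path and toy structure are placeholders:

```python
from ase.build import add_adsorbate, fcc111

from fairchem.core.common.relaxation.ase_utils import OCPCalculator

calc = OCPCalculator(checkpoint_path="path/to/checkpoint.pt", cpu=True)  # placeholder path

# Toy Cu(111) slab with an O adsorbate, just to exercise the calculator
slab = fcc111("Cu", size=(2, 2, 3), vacuum=10.0)
add_adsorbate(slab, "O", height=1.2, position="fcc")
slab.calc = calc

print(slab.get_potential_energy())
```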
(EquiformerV2 model file; per-file change count not captured in this view)
@@ -17,7 +17,6 @@
pass



from .edge_rot_mat import init_edge_rot_mat
from .gaussian_rbf import GaussianRadialBasisLayer
from .input_block import EdgeDegreeEmbedding
@@ -155,7 +154,7 @@ def __init__(
load_energy_lin_ref: bool | None = False,
):
logging.warning(
"equiformer_v2 (EquiformerV2) class is deprecaed in favor of equiformer_v2_backbone_and_heads (EquiformerV2BackboneAndHeads)"
"equiformer_v2 (EquiformerV2) class is deprecated in favor of equiformer_v2_backbone_and_heads (EquiformerV2BackboneAndHeads)"
)
if mmax_list is None:
mmax_list = [2]
56 changes: 32 additions & 24 deletions src/fairchem/core/trainers/base_trainer.py
@@ -15,8 +15,9 @@
import random
import sys
from abc import ABC, abstractmethod
from functools import partial
from itertools import chain
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import numpy as np
import numpy.typing as npt
@@ -29,7 +30,7 @@

from fairchem.core import __version__
from fairchem.core.common import distutils, gp_utils
from fairchem.core.common.data_parallel import BalancedBatchSampler, OCPCollater
from fairchem.core.common.data_parallel import BalancedBatchSampler
from fairchem.core.common.registry import registry
from fairchem.core.common.slurm import (
add_timestamp_id_to_submission_pickle,
@@ -44,6 +45,7 @@
save_checkpoint,
update_config,
)
from fairchem.core.datasets import data_list_collater
from fairchem.core.datasets.base_dataset import create_dataset
from fairchem.core.modules.evaluator import Evaluator
from fairchem.core.modules.exponential_moving_average import ExponentialMovingAverage
@@ -65,13 +67,13 @@
class BaseTrainer(ABC):
def __init__(
self,
task,
model,
outputs,
dataset,
optimizer,
loss_functions,
evaluation_metrics,
task: dict[str, str | Any],
model: dict[str, Any],
outputs: dict[str, str | int],
dataset: dict[str, str | float],
optimizer: dict[str, str | float],
loss_functions: dict[str, str | float],
evaluation_metrics: dict[str, str],
identifier: str,
# TODO: dealing with local rank is dangerous
# T201111838 remove this and use CUDA_VISIBLE_DEVICES instead so trainers don't need to know about which device to use
@@ -87,6 +89,7 @@ def __init__(
name: str = "ocp",
slurm=None,
gp_gpus: int | None = None,
inference_only: bool = False,
) -> None:
if slurm is None:
slurm = {}
@@ -205,7 +208,7 @@ def __init__(
self.train_dataset = None
self.val_dataset = None
self.test_dataset = None
self.load()
self.load(inference_only)

@abstractmethod
def train(self, disable_eval_tqdm: bool = False) -> None:
@@ -224,16 +227,19 @@ def _get_timestamp(device: torch.device, suffix: str | None) -> str:
timestamp_str += "-" + suffix
return timestamp_str

def load(self) -> None:
def load(self, inference_only: bool) -> None:
self.load_seed_from_config()
self.load_logger()
self.load_datasets()
self.load_references_and_normalizers()
self.load_task()
self.load_model()
self.load_loss()
self.load_optimizer()
self.load_extras()

if inference_only is False:
self.load_datasets()
self.load_references_and_normalizers()
self.load_loss()
self.load_optimizer()
self.load_extras()

if self.config["optim"].get("load_datasets_and_model_then_exit", False):
sys.exit(0)

@@ -298,14 +304,16 @@ def get_sampler(
def get_dataloader(self, dataset, sampler) -> DataLoader:
return DataLoader(
dataset,
collate_fn=self.ocp_collater,
collate_fn=self.collater,
num_workers=self.config["optim"]["num_workers"],
pin_memory=True,
batch_sampler=sampler,
)

def load_datasets(self) -> None:
self.ocp_collater = OCPCollater(self.config["model"].get("otf_graph", False))
self.collater = partial(
data_list_collater, otf_graph=self.config["model"].get("otf_graph", False)
)
self.train_loader = None
self.val_loader = None
self.test_loader = None
@@ -498,15 +506,15 @@ def load_task(self):
][target_name].get("level", "system")
if "train_on_free_atoms" not in self.output_targets[subtarget]:
self.output_targets[subtarget]["train_on_free_atoms"] = (
self.config["outputs"][target_name].get(
"train_on_free_atoms", True
)
self.config[
"outputs"
][target_name].get("train_on_free_atoms", True)
)
if "eval_on_free_atoms" not in self.output_targets[subtarget]:
self.output_targets[subtarget]["eval_on_free_atoms"] = (
self.config["outputs"][target_name].get(
"eval_on_free_atoms", True
)
self.config[
"outputs"
][target_name].get("eval_on_free_atoms", True)
)

# TODO: Assert that all targets, loss fn, metrics defined are consistent
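For reference, the `train_on_free_atoms` / `eval_on_free_atoms` defaults applied above come from the per-target `outputs` section of the config. A hypothetical, incomplete illustration of the keys involved:

```python
# Hypothetical "outputs" section; omitted keys fall back to the defaults
# applied in load_task (both free-atom flags default to True).
outputs_config = {
    "energy": {"level": "system"},
    "forces": {
        "level": "atom",
        "train_on_free_atoms": True,
        "eval_on_free_atoms": True,
    },
}
```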
51 changes: 26 additions & 25 deletions src/fairchem/core/trainers/ocp_trainer.py
@@ -11,7 +11,7 @@
import os
from collections import defaultdict
from itertools import chain
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import numpy as np
import torch
@@ -71,28 +71,29 @@ class OCPTrainer(BaseTrainer):

def __init__(
self,
task,
model,
outputs,
dataset,
optimizer,
loss_functions,
evaluation_metrics,
identifier,
task: dict[str, str | Any],
model: dict[str, Any],
outputs: dict[str, str | int],
dataset: dict[str, str | float],
optimizer: dict[str, str | float],
loss_functions: dict[str, str | float],
evaluation_metrics: dict[str, str],
identifier: str,
# TODO: dealing with local rank is dangerous
# T201111838 remove this and use CUDA_VISIBLE_DEVICES instead so trainers don't need to know about which device to use
local_rank,
timestamp_id=None,
run_dir=None,
is_debug=False,
print_every=100,
seed=None,
logger="wandb",
amp=False,
cpu=False,
local_rank: int,
timestamp_id: str | None = None,
run_dir: str | None = None,
is_debug: bool = False,
print_every: int = 100,
seed: int | None = None,
logger: str = "wandb",
amp: bool = False,
cpu: bool = False,
name: str = "ocp",
slurm=None,
name="ocp",
gp_gpus=None,
gp_gpus: int | None = None,
inference_only: bool = False,
):
if slurm is None:
slurm = {}
@@ -117,6 +118,7 @@ def __init__(
slurm=slurm,
name=name,
gp_gpus=gp_gpus,
inference_only=inference_only,
)

def train(self, disable_eval_tqdm: bool = False) -> None:
@@ -253,8 +255,9 @@ def _forward(self, batch):
elif isinstance(out[target_key], dict):
# if output is a nested dictionary (in the case of hydra models), we attempt to retrieve it using the property name
# ie: "output_head_name.property"
assert "property" in self.output_targets[target_key], \
f"we need to know which property to match the target to, please specify the property field in the task config, current config: {self.output_targets[target_key]}"
assert (
"property" in self.output_targets[target_key]
), f"we need to know which property to match the target to, please specify the property field in the task config, current config: {self.output_targets[target_key]}"
property = self.output_targets[target_key]["property"]
pred = out[target_key][property]
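For hydra-style models the forward pass returns a nested dict keyed by head name, so the task config has to say which property under that head corresponds to the target; the reformatted assert above guards exactly that lookup. A small illustration with made-up names:

```python
import torch

# Hypothetical nested output of a hydra-style model
out = {"energy_head": {"energy": torch.tensor([-1.23])}}

# Hypothetical output_targets entry; "property" must be present or the assert fires
target_cfg = {"property": "energy"}

pred = out["energy_head"][target_cfg["property"]]
```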

@@ -661,9 +664,7 @@ def run_relaxations(self, split="val"):
)
gather_results["chunk_idx"] = np.cumsum(
[gather_results["chunk_idx"][i] for i in idx]
)[
:-1
] # np.split does not need last idx, assumes n-1:end
)[:-1] # np.split does not need last idx, assumes n-1:end

full_path = os.path.join(
self.config["cmd"]["results_dir"], "relaxed_positions.npz"
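The reflowed block above keeps the existing behaviour: gathered chunk lengths are turned into split indices with `np.cumsum`, and the final index is dropped because `np.split` infers the last segment on its own. A small worked example:

```python
import numpy as np

chunk_lengths = [3, 2, 4]                  # per-image chunk sizes after gathering
split_idx = np.cumsum(chunk_lengths)[:-1]  # array([3, 5]); the trailing 9 is not needed
positions = np.arange(9)

print(np.split(positions, split_idx))
# [array([0, 1, 2]), array([3, 4]), array([5, 6, 7, 8])]
```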