Commit

black
LoryWangxx committed Mar 5, 2024
1 parent 82b3530 commit 473c4e5
Showing 1 changed file with 15 additions and 244 deletions.
259 changes: 15 additions & 244 deletions finetuna/finetuner_utils/trainer.py
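This commit runs the Black code formatter over trainer.py (wrapping long signatures, collapsing previously hand-wrapped calls that fit on one line) and also drops a large block of commented-out legacy code. A minimal sketch of how the formatting pass could be reproduced locally, assuming Black's defaults (88-character lines); the exact Black version and options used for this commit are not recorded here:

import black

# Read the module, run it through Black's formatter, and write it back.
# black.Mode() uses Black's default settings (88-character line length).
path = "finetuna/finetuner_utils/trainer.py"
with open(path) as f:
    source = f.read()

formatted = black.format_str(source, mode=black.Mode())

with open(path, "w") as f:
    f.write(formatted)

This is equivalent to invoking the CLI on the file, e.g. black finetuna/finetuner_utils/trainer.py.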
@@ -19,7 +19,9 @@


class Trainer(OCPTrainer):
def __init__(self, config_yml=None, checkpoint_path=None, cutoff=6, max_neighbors=50):
def __init__(
self, config_yml=None, checkpoint_path=None, cutoff=6, max_neighbors=50
):
setup_imports()
setup_logging()

@@ -29,9 +31,7 @@ def __init__(self, config_yml=None, checkpoint_path=None, cutoff=6, max_neighbor
checkpoint = None
if config_yml is not None:
if isinstance(config_yml, str):
config, duplicates_warning, duplicates_error = load_config(
config_yml
)
config, duplicates_warning, duplicates_error = load_config(config_yml)
if len(duplicates_warning) > 0:
logging.warning(
f"Overwritten config parameters from included configs "
@@ -52,9 +52,7 @@ def __init__(self, config_yml=None, checkpoint_path=None, cutoff=6, max_neighbor
# config["dataset"] = config["dataset"].get("train", None)
else:
# Loads the config from the checkpoint directly (always on CPU).
checkpoint = torch.load(
checkpoint_path, map_location=torch.device("cpu")
)
checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
config = checkpoint["config"]

# if trainer is not None:
@@ -169,7 +167,7 @@ def save(
checkpoint["config"]["normalizer"] = self.normalizer
torch.save(checkpoint, checkpoint_path)
return checkpoint_path

def _compute_loss(self, out, batch):
batch_size = batch.natoms.numel()
fixed = batch.fixed
Expand Down Expand Up @@ -219,23 +217,16 @@ def _compute_loss(self, out, batch):

loss = sum(loss)
return loss

def train(self, disable_eval_tqdm: bool = False) -> None:
# ensure_fitted(self._unwrapped_model, warn=True)

eval_every = self.config["optim"].get(
"eval_every", len(self.train_loader)
)
checkpoint_every = self.config["optim"].get(
"checkpoint_every", eval_every
)
eval_every = self.config["optim"].get("eval_every", len(self.train_loader))
checkpoint_every = self.config["optim"].get("checkpoint_every", eval_every)
primary_metric = self.evaluation_metrics.get(
"primary_metric", self.evaluator.task_primary_metric[self.name]
)
if (
not hasattr(self, "primary_metric")
or self.primary_metric != primary_metric
):
if not hasattr(self, "primary_metric") or self.primary_metric != primary_metric:
self.best_val_metric = 1e9 if "mae" in primary_metric else -1.0
else:
primary_metric = self.primary_metric
@@ -245,9 +236,7 @@ def train(self, disable_eval_tqdm: bool = False) -> None:
# to prevent inconsistencies due to different batch size in checkpoint.
start_epoch = self.step // len(self.train_loader)
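# (Illustrative note, not part of the original file: e.g. if self.step == 1250
# and len(self.train_loader) == 400, then start_epoch == 3 and, below,
# skip_steps == 50, so training resumes partway through the fourth epoch.)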

for epoch_int in range(
start_epoch, self.config["optim"]["max_epochs"]
):
for epoch_int in range(start_epoch, self.config["optim"]["max_epochs"]):
self.train_sampler.set_epoch(epoch_int)
skip_steps = self.step % len(self.train_loader)
train_loader_iter = iter(self.train_loader)
@@ -272,9 +261,7 @@ def train(self, disable_eval_tqdm: bool = False) -> None:
self.evaluator,
self.metrics,
)
self.metrics = self.evaluator.update(
"loss", loss.item(), self.metrics
)
self.metrics = self.evaluator.update("loss", loss.item(), self.metrics)

loss = self.scaler.scale(loss) if self.scaler else loss
self._backward(loss)
@@ -292,9 +279,7 @@ def train(self, disable_eval_tqdm: bool = False) -> None:
self.step % self.config["cmd"]["print_every"] == 0
and distutils.is_master()
):
log_str = [
"{}: {:.2e}".format(k, v) for k, v in log_dict.items()
]
log_str = ["{}: {:.2e}".format(k, v) for k, v in log_dict.items()]
logging.info(", ".join(log_str))
self.metrics = {}

@@ -305,13 +290,8 @@ def train(self, disable_eval_tqdm: bool = False) -> None:
split="train",
)

if (
checkpoint_every != -1
and self.step % checkpoint_every == 0
):
self.save(
checkpoint_file="checkpoint.pt", training_state=True
)
if checkpoint_every != -1 and self.step % checkpoint_every == 0:
self.save(checkpoint_file="checkpoint.pt", training_state=True)

# Evaluate on val set every `eval_every` iterations.
if self.step % eval_every == 0:
@@ -359,212 +339,3 @@ def train(self, disable_eval_tqdm: bool = False) -> None:
self.val_dataset.close_db()
if self.config.get("test_dataset", False):
self.test_dataset.close_db()
# def train(self, disable_eval_tqdm=False):
# eval_every = self.config["optim"].get("eval_every", None)
# if eval_every is None:
# eval_every = len(self.train_loader)
# checkpoint_every = self.config["optim"].get("checkpoint_every", eval_every)
# primary_metric = self.config["task"].get(
# "primary_metric", self.evaluator.task_primary_metric[self.name]
# )
# self.best_val_metric = 1e9 if "mae" in primary_metric else -1.0
# self.metrics = {}

# # Calculate start_epoch from step instead of loading the epoch number
# # to prevent inconsistencies due to different batch size in checkpoint.
# start_epoch = self.step // len(self.train_loader)

# for epoch_int in range(start_epoch, self.config["optim"]["max_epochs"]):
# self.train_sampler.set_epoch(epoch_int)
# skip_steps = self.step % len(self.train_loader)
# train_loader_iter = iter(self.train_loader)

# for i in range(skip_steps, len(self.train_loader)):
# self.epoch = epoch_int + (i + 1) / len(self.train_loader)
# self.step = epoch_int * len(self.train_loader) + i + 1
# self.model.train()

# # Get a batch.
# batch = next(train_loader_iter)

# if self.config["optim"]["optimizer"] == "LBFGS":

# def closure():
# self.optimizer.zero_grad()
# with torch.cuda.amp.autocast(enabled=self.scaler is not None):
# out = self._forward(batch)
# loss = self._compute_loss(out, batch)
# loss.backward()
# return loss

# self.optimizer.step(closure)

# self.optimizer.zero_grad()
# with torch.cuda.amp.autocast(enabled=self.scaler is not None):
# out = self._forward(batch)
# loss = self._compute_loss(out, batch)

# else:
# # Forward, loss, backward.
# with torch.cuda.amp.autocast(enabled=self.scaler is not None):
# out = self._forward(batch)
# loss = self._compute_loss(out, batch)
# loss = self.scaler.scale(loss) if self.scaler else loss
# self._backward(loss)

# scale = self.scaler.get_scale() if self.scaler else 1.0

# # Compute metrics.
# self.metrics = self._compute_metrics(
# out,
# batch,
# self.evaluator,
# self.metrics,
# )
# self.metrics = self.evaluator.update(
# "loss", loss.item() / scale, self.metrics
# )

# # Log metrics.
# log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
# log_dict.update(
# {
# "lr": self.scheduler.get_lr(),
# "epoch": self.epoch,
# "step": self.step,
# }
# )
# if (
# self.step % self.config["cmd"]["print_every"] == 0
# and distutils.is_master()
# and not self.is_hpo
# ):
# log_str = ["{}: {:.2e}".format(k, v) for k, v in log_dict.items()]
# logging.info(", ".join(log_str))
# self.metrics = {}

# if self.logger is not None:
# self.logger.log(
# log_dict,
# step=self.step,
# split="train",
# )

# if checkpoint_every != -1 and self.step % checkpoint_every == 0:
# self.save(checkpoint_file="checkpoint.pt", training_state=True)

# # Evaluate on val set every `eval_every` iterations.
# if self.step % eval_every == 0:
# if self.test_loader is not None:
# test_metrics = self.validate(
# split="test",
# disable_tqdm=disable_eval_tqdm,
# )
# if self.val_loader is not None:
# val_metrics = self.validate(
# split="val",
# disable_tqdm=disable_eval_tqdm,
# )
# self.update_best(
# primary_metric,
# val_metrics,
# disable_eval_tqdm=disable_eval_tqdm,
# )
# if self.is_hpo:
# self.hpo_update(
# self.epoch,
# self.step,
# self.metrics,
# val_metrics,
# )

# if self.config["task"].get("eval_relaxations", False):
# if "relax_dataset" not in self.config["task"]:
# logging.warning(
# "Cannot evaluate relaxations, relax_dataset not specified"
# )
# else:
# self.run_relaxations()

# if self.config["optim"].get("print_loss_and_lr", False):
# if self.step % eval_every == 0 or not self.config["optim"].get(
# "print_only_on_eval", True
# ):
# if self.val_loader is not None:
# print(
# "epoch: "
# + "{:.1f}".format(self.epoch)
# + ", \tstep: "
# + str(self.step)
# + ", \tloss: "
# + str(loss.detach().item())
# + ", \tlr: "
# + str(self.scheduler.get_lr())
# + ", \tval: "
# + str(val_metrics["loss"]["metric"])
# )
# else:
# print(
# "epoch: "
# + "{:.1f}".format(self.epoch)
# + ", \tstep: "
# + str(self.step)
# + ", \tloss: "
# + str(loss.detach().item())
# + ", \tlr: "
# + str(self.scheduler.get_lr())
# )

# if self.scheduler.scheduler_type == "ReduceLROnPlateau":
# if (
# self.step % eval_every == 0
# and self.config["optim"].get("scheduler_loss", None) == "train"
# ):
# self.scheduler.step(
# metrics=loss.detach().item(),
# )
# elif self.step % eval_every == 0 and self.val_loader is not None:
# self.scheduler.step(
# metrics=val_metrics[primary_metric]["metric"],
# )
# else:
# self.scheduler.step()

# break_below_lr = (
# self.config["optim"].get("break_below_lr", None) is not None
# ) and (self.scheduler.get_lr() < self.config["optim"]["break_below_lr"])
# if break_below_lr:
# break
# if break_below_lr:
# break

# torch.cuda.empty_cache()

# if checkpoint_every == -1:
# self.save(checkpoint_file="checkpoint.pt", training_state=True)

# self.train_dataset.close_db()
# if "val_dataset" in self.config:
# self.val_dataset.close_db()
# if "test_dataset" in self.config:
# self.test_dataset.close_db()

# def load_loss(self):
# self.loss_fn = {}
# self.loss_fn["energy"] = self.config["optim"].get("loss_energy", "mae")
# self.loss_fn["force"] = self.config["optim"].get("loss_force", "mae")
# for loss, loss_name in self.loss_fn.items():
# if loss_name in ["l1", "mae"]:
# self.loss_fn[loss] = nn.L1Loss()
# elif loss_name == "mse":
# self.loss_fn[loss] = nn.MSELoss()
# elif loss_name == "l2mae":
# self.loss_fn[loss] = L2MAELoss()
# elif loss_name == "rell2mae":
# self.loss_fn[loss] = RelativeL2MAELoss()
# elif loss_name == "atomwisel2":
# self.loss_fn[loss] = AtomwiseL2LossNoBatch()
# else:
# raise NotImplementedError(f"Unknown loss function name: {loss_name}")
# if distutils.initialized():
# self.loss_fn[loss] = DDPLoss(self.loss_fn[loss])
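
For context, a minimal usage sketch of the Trainer class shown above, based only on the signatures visible in this diff; the YAML and checkpoint paths are hypothetical placeholders:

from finetuna.finetuner_utils.trainer import Trainer

# Start from a YAML config (placeholder path) ...
trainer = Trainer(config_yml="configs/finetune.yml", cutoff=6, max_neighbors=50)

# ... or resume from a saved checkpoint, whose config __init__ loads onto CPU:
# trainer = Trainer(checkpoint_path="checkpoints/checkpoint.pt")

trainer.train(disable_eval_tqdm=True)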
