Merge pull request #43 from Benzoin96485/active_learning

Optimizer, scheduler and model resume

Benzoin96485 authored Oct 10, 2024
2 parents fba2df1 + bb024d3 commit ee1b063

Showing 5 changed files with 140 additions and 53 deletions.
27 changes: 22 additions & 5 deletions enerzyme/models/ff.py
@@ -220,7 +220,13 @@ def __init__(self,
     ) -> None:
         super().__init__(datahub, trainer, model_str, loss, architecture, build_params, layers, pretrain_path)
         from .modelhub import get_pretrain_path
-        self.pretrain_path = get_pretrain_path(pretrain_path, "best", None)
+        if self.trainer.resume:
+            if pretrain_path is None:
+                self.pretrain_path = get_pretrain_path(self.dump_dir, "last", None)
+            else:
+                self.pretrain_path = get_pretrain_path(pretrain_path, "last", None)
+        else:
+            self.pretrain_path = get_pretrain_path(pretrain_path, "best", None)
         self.model = self._init_model(self.build_params)
 
     def _train(
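When resuming, the constructor above prefers the "last" checkpoint (from this run's dump directory unless an explicit pretrain path is given), while a fresh run still loads the "best" one. A minimal sketch of that selection, assuming checkpoints are laid out as model_{suffix}.pth under a run directory; pick_checkpoint is a hypothetical stand-in for the real get_pretrain_path in enerzyme/models/modelhub.py:

import os
from typing import Optional

def pick_checkpoint(dump_dir: str, resume: bool, pretrain_path: Optional[str]) -> Optional[str]:
    # Illustrative only: mirrors the branch above under an assumed file layout.
    if resume:
        # Resume from the most recent state; default to this run's own dump_dir.
        root = pretrain_path if pretrain_path is not None else dump_dir
        return os.path.join(root, "model_last.pth")
    if pretrain_path is None:
        return None  # fresh run, nothing to load
    # Fine-tuning from a pretrained run starts at its best-scoring checkpoint.
    return os.path.join(pretrain_path, "model_best.pth")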
@@ -343,7 +349,11 @@ def active_learn(self) -> None:
         from ..tasks.active_learning import max_Fa_norm_std_picking
         partitions = self._init_partition()
         training_set = partitions["training"]
+        validation_set = partitions.get("validation", None)
         withheld_set = partitions["withheld"]
+        len_training = len(training_set)
+        len_validation = len(validation_set) if validation_set is not None else 0
+        ratio_training = len_training / (len_training + len_validation)
         params = self.trainer.active_learning_params
         lb = params["error_lower_bound"]
         ub = params["error_upper_bound"]
@@ -357,7 +367,7 @@ def active_learn(self) -> None:
            for i in range(max_iter):
                if i > 0:
                    self._init_pretrain_path(self.dump_dir)
-               self._train(training_set)
+               self._train(training_set, validation_set)
                unmasked_relative_indices = withheld_mask.nonzero()[0]
                unmasked_size = len(unmasked_relative_indices)
                if unmasked_size == 0:
@@ -377,12 +387,19 @@ def active_learn(self) -> None:
                    break
                masked_relative_indices = masked_relative_indices[:sample_size]
                expand_absolute_indices = withheld_set.raw_indices[masked_relative_indices]
-               training_set.expand_with_indices(expand_absolute_indices)
+               len_expanded = len(expand_absolute_indices)
+               len_expanded_training = int(len_expanded * ratio_training + 0.5)
+               training_set.expand_with_indices(expand_absolute_indices[:len_expanded_training])
+               if validation_set is not None:
+                   validation_set.expand_with_indices(expand_absolute_indices[len_expanded_training:])
                withheld_mask[masked_relative_indices] = False
-               np.savez(os.path.join(self.dump_dir, "active_learning_split.npz"), {
+               new_indices = {
                    "training": training_set.raw_indices,
                    "withheld": withheld_set.raw_indices
-               })
+               }
+               if validation_set is not None:
+                   new_indices["validation"] = validation_set.raw_indices
+               np.savez(os.path.join(self.dump_dir, "active_learning_split.npz"), new_indices)
                logger.info(f"Active learning iteration {i + 1} / {max_iter} finished!")
        else:
            raise NotImplementedError
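The newly picked withheld samples are now divided between the training and validation sets in the same ratio as the current split, with int(x + 0.5) rounding the training share to the nearest integer. A self-contained sketch of that bookkeeping (function name hypothetical):

import numpy as np

def split_new_indices(new_indices: np.ndarray, len_training: int, len_validation: int):
    """Divide freshly picked sample indices between training and validation,
    preserving the existing training:validation ratio (nearest-integer rounding)."""
    ratio_training = len_training / (len_training + len_validation)
    n_train = int(len(new_indices) * ratio_training + 0.5)
    return new_indices[:n_train], new_indices[n_train:]

# e.g. with an 8:2 split, 10 new samples contribute 8 to training and 2 to validation
train_part, valid_part = split_new_indices(np.arange(10), 800, 200)
assert len(train_part) == 8 and len(valid_part) == 2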
3 changes: 1 addition & 2 deletions enerzyme/simulate.py
@@ -1,7 +1,6 @@
 import os
 import torch
 from .models import FF_REGISTER, get_model_str, build_model, get_pretrain_path
-from .tasks import Simulation, _load_state_dict
+from .tasks import Simulation
 from .utils import YamlHandler, logger
 from .data import Transform
 
17 changes: 8 additions & 9 deletions enerzyme/tasks/metrics.py
@@ -1,7 +1,5 @@
 import os
-from typing import Dict
+from typing import Dict, Callable, Tuple
 import numpy as np
-import torch
-from sklearn.metrics import mean_absolute_error, mean_squared_error
 from ..data import is_atomic, get_tensor_rank
 from ..utils.base_logger import logger
@@ -66,20 +64,21 @@ def cal_metric(self, label, predict):
             raw_metric_score["_judge_score"] = self.cal_judge_score(raw_metric_score)
         return raw_metric_score
 
-    def _early_stop_choice(self, wait, min_score, metric_score, max_score, save_handle, patience, epoch):
+    def _early_stop_choice(self, wait: int, best_score: float, metric_score: Dict, save_handle: Callable, patience: int, epoch: int) -> Tuple[bool, float, int]:
         judge_score = metric_score.get("_judge_score", self.cal_judge_score(metric_score))
-        is_early_stop, min_score, wait = self._judge_early_stop_decrease(wait, judge_score, min_score, save_handle, patience, epoch)
-        return is_early_stop, min_score, wait, max_score
+        return self._judge_early_stop_decrease(wait, judge_score, best_score, save_handle, patience, epoch)
 
-    def _judge_early_stop_decrease(self, wait, score, min_score, save_handle, patience, epoch):
+    def _judge_early_stop_decrease(self, wait: int, score: float, min_score: float, save_handle: Callable, patience: int, epoch: int) -> Tuple[bool, float, int]:
         is_early_stop = False
+        saved = False
         if score <= min_score:
             min_score = score
             wait = 0
-            save_handle()
+            save_handle(best_score=score, best_epoch=epoch, epoch=epoch)
+            saved = True
         elif score >= min_score:
             wait += 1
             if wait == patience:
                 logger.warning(f'Early stopping at epoch: {epoch+1}')
                 is_early_stop = True
-        return is_early_stop, min_score, wait
+        return is_early_stop, min_score, wait, saved
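The early-stop helpers now return a fourth saved flag and forward best_score/best_epoch/epoch to the save callback, so the trainer can record which epoch produced the best checkpoint. A toy driver for the new contract; the judge scores are made up and the callback is a stand-in for the partial(...) built in trainer.py:

def judge_early_stop(wait, score, min_score, save_handle, patience, epoch):
    # Plain restatement of _judge_early_stop_decrease above, for illustration.
    is_early_stop, saved = False, False
    if score <= min_score:
        min_score, wait = score, 0
        save_handle(best_score=score, best_epoch=epoch, epoch=epoch)
        saved = True
    else:
        wait += 1
        if wait == patience:
            is_early_stop = True
    return is_early_stop, min_score, wait, saved

wait, best = 0, float("inf")
for epoch, score in enumerate([0.9, 0.7, 0.8, 0.8, 0.8]):
    stop, best, wait, saved = judge_early_stop(
        wait, score, best,
        save_handle=lambda best_score, best_epoch, epoch: None,  # no-op stand-in
        patience=3, epoch=epoch,
    )
    if stop:
        break  # triggers at the third consecutive non-improving epoch (epoch 4)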
2 changes: 1 addition & 1 deletion enerzyme/tasks/simulation.py
@@ -93,7 +93,7 @@ def __init__(self, config, model, model_path, out_dir, transform):
         self.device = torch.device("cuda:0" if torch.cuda.is_available() and self.cuda else "cpu")
         # single ff simulation
         self.model = model.to(self.device).type(self.dtype)
-        _load_state_dict(model, self.device, model_path)
+        _load_state_dict(model, self.device, model_path, inference=True)
         self.model.eval()
         self.out_dir = out_dir
         # self.simulation_config = {k: (v.to_dict() if hasattr(v, "to_dict") else v) for k, v in config.Simulation.items() if not hasattr(self, k)}
144 changes: 108 additions & 36 deletions enerzyme/tasks/trainer.py
@@ -1,4 +1,3 @@
-from ast import dump
 from functools import partial
 from typing import Iterable, Optional, Callable, Tuple, Dict, Any
 from collections import defaultdict
@@ -7,7 +6,8 @@
 import torch
 from torch import Tensor
 from torch.utils.data import DataLoader, Dataset
-from torch.optim import Adam
+from torch.optim import Adam, Optimizer
+from torch.optim.lr_scheduler import LRScheduler
 from torch.nn import Module
 from torch.nn.utils import clip_grad_norm_
 from torch_ema import ExponentialMovingAverage
@@ -111,9 +111,10 @@ def _decorate_batch_output(output, features, targets):
     return y_pred, (y_truth if y_truth else None)
 
 
-def _load_state_dict(model: Module, device: torch.device, pretrain_path: Optional[str], ema: Optional[ExponentialMovingAverage]=None, inference: bool=False) -> None:
+def _load_state_dict(model: Module, device: torch.device, pretrain_path: Optional[str], ema: Optional[ExponentialMovingAverage]=None, inference: bool=False, optimizer: Optional[Optimizer]=None, scheduler: Optional[LRScheduler]=None) -> Dict:
+    other_info = dict()
     if pretrain_path is None:
-        return
+        return other_info
     loaded_info = torch.load(pretrain_path, map_location=device)
     if ema is not None and "ema_state_dict" in loaded_info:
         model.load_state_dict(loaded_info["model_state_dict"])
@@ -128,6 +129,29 @@ def _load_state_dict
     else:
         model.load_state_dict(loaded_info["model_state_dict"])
         logger.info(f"loading model state dict from {pretrain_path}...")
+    if not inference:
+        if optimizer is not None and "optimizer_state_dict" in loaded_info:
+            optimizer.load_state_dict(loaded_info["optimizer_state_dict"])
+            logger.info(f"loading optimizer state dict from {pretrain_path}...")
+        if scheduler is not None and "scheduler_state_dict" in loaded_info:
+            scheduler.load_state_dict(loaded_info["scheduler_state_dict"])
+            logger.info(f"loading scheduler state dict from {pretrain_path}...")
+        if "epoch" in loaded_info:
+            other_info["epoch"] = loaded_info["epoch"]
+        if "best_epoch" in loaded_info:
+            other_info["best_epoch"] = loaded_info["best_epoch"]
+        if "best_score" in loaded_info:
+            other_info["best_score"] = loaded_info["best_score"]
+        # if "torch_rng_state" in loaded_info:
+        #     torch.random.set_rng_state(loaded_info["torch_rng_state"])
+        #     logger.info(f"loading torch random generator state from {pretrain_path}...")
+        # if "torch_cuda_rng_state_all" in loaded_info:
+        #     torch.cuda.random.set_rng_state_all(loaded_info["torch_cuda_rng_state_all"])
+        #     logger.info(f"loading torch cuda random generator state from {pretrain_path}...")
+        # if "np_rng_state" in loaded_info:
+        #     np.random.set_state(loaded_info["np_rng_state"])
+        #     logger.info(f"loading numpy random generator state from {pretrain_path}...")
+    return other_info
 
 
 class Trainer:
@@ -169,6 +193,7 @@ def __init__(self, out_dir: str=None, metric_config: Metrics=dict(), **params)
             self.active_learning = True
         else:
             self.active_learning = False
+        self.resume = params.get("resume", False)
 
     def decorate_batch_input(self, batch):
         return _decorate_batch_input(batch, self.dtype, self.device)
@@ -186,22 +211,41 @@ def _set_seed(self, seed):
             torch.cuda.manual_seed_all(seed)
         np.random.seed(seed)
 
-    def load_state_dict(self, model: Module, pretrain_path: Optional[str], ema: Optional[ExponentialMovingAverage]=None, inference: bool=False) -> None:
-        return _load_state_dict(model, self.device, pretrain_path, ema, inference)
+    def load_state_dict(
+        self,
+        model: Module,
+        optimizer: Optional[Optimizer]=None,
+        scheduler: Optional[LRScheduler]=None,
+        pretrain_path: Optional[str]=None,
+        ema: Optional[ExponentialMovingAverage]=None,
+        inference: bool=False
+    ) -> None:
+        return _load_state_dict(model, self.device, pretrain_path, ema, inference, optimizer, scheduler)
 
-    def save_state_dict(self, model: Module, dump_dir, ema: Optional[ExponentialMovingAverage]=None, suffix="last", model_rank=None):
+    def save_state_dict(self, model: Module, optimizer: Optimizer, scheduler: LRScheduler, dump_dir: str, ema: Optional[ExponentialMovingAverage]=None, suffix="last", model_rank=None, epoch: Optional[int]=None, best_score: Optional[float]=None, best_epoch: Optional[int]=None):
         if model_rank is None:
             model_rank = ''
         os.makedirs(dump_dir, exist_ok=True)
-        if ema is None:
-            info = {'model_state_dict': model.state_dict()}
-        else:
-            info = {'ema_state_dict': ema.state_dict(), 'model_state_dict': model.state_dict()}
+        info = {"model_state_dict": model.state_dict()}
+        if ema is not None:
+            info["ema_state_dict"] = ema.state_dict()
+        info["optimizer_state_dict"] = optimizer.state_dict()
+        info["scheduler_state_dict"] = scheduler.state_dict()
+        if epoch is not None:
+            info["epoch"] = epoch
+        if best_score is not None:
+            info["best_score"] = best_score
+        if best_epoch is not None:
+            info["best_epoch"] = best_epoch
+        # info["torch_rng_state"] = torch.random.get_rng_state()
+        # if self.cuda:
+        #     info["torch_cuda_rng_state_all"] = torch.cuda.random.get_rng_state_all()
+        # info["np_rng_state"] = np.random.get_state()
         torch.save(info, os.path.join(dump_dir, f'model{model_rank}_{suffix}.pth'))
 
     def fit_predict(self,
         model: Module, pretrain_path: Optional[str],
-        train_dataset: Dataset, valid_dataset: Dataset,
+        train_dataset: Dataset, valid_dataset: Optional[Dataset],
         loss_terms: Iterable[Callable], dump_dir: str, transform: Transform,
         test_dataset: Optional[Dataset]=None, model_rank=None) -> Tuple[Optional[defaultdict[Any]], Dict]:
         self._set_seed(self.seed + (model_rank if model_rank is not None else 0))
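Checkpoints written by save_state_dict are now self-contained training states rather than bare weights. A hypothetical round-trip against that layout; the path, model, and numeric values are stand-ins, and the real get_scheduler is replaced with a trivial LambdaLR:

import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

model = nn.Linear(4, 1)
optimizer = Adam(model.parameters(), lr=1e-3)
scheduler = LambdaLR(optimizer, lambda step: 1.0)

# Save: mirrors the keys assembled in save_state_dict above.
info = {
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "scheduler_state_dict": scheduler.state_dict(),
    "epoch": 12, "best_score": 0.0314, "best_epoch": 9,
}
torch.save(info, "model_last.pth")

# Load: what _load_state_dict does for a non-inference call.
ckpt = torch.load("model_last.pth", map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
scheduler.load_state_dict(ckpt["scheduler_state_dict"])
start_epoch = ckpt.get("epoch", -1) + 1  # resume at epoch 13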
@@ -213,20 +257,40 @@ def fit_predict(self,
             collate_fn=self.decorate_batch_input,
             drop_last=True
         )
-        min_val_loss = float("inf")
-        max_score = float("-inf")
-        wait = 0
 
         num_training_steps = len(train_dataloader) * self.max_epochs
         num_warmup_steps = int(num_training_steps * self.warmup_ratio)
         optimizer = Adam(model.parameters(), lr=self.learning_rate, eps=1e-6, weight_decay=self.weight_decay, amsgrad=self.amsgrad)
         if self.use_ema:
             ema = ExponentialMovingAverage(model.parameters(), self.ema_decay, self.ema_use_num_updates)
         else:
             ema = None
-        self.load_state_dict(model, pretrain_path, ema)
         scheduler = get_scheduler(self.schedule, optimizer, num_warmup_steps, num_training_steps)
+        other_info = self.load_state_dict(model, optimizer, scheduler, pretrain_path, ema)
+
+        if self.resume and "best_epoch" in other_info and "epoch" in other_info:
+            wait = other_info["epoch"] - other_info["best_epoch"]
+            best_score = other_info.get("best_score", float("inf"))
+            start_epoch = other_info["epoch"] + 1
+        else:
+            wait = 0
+            start_epoch = other_info.get("epoch", -1) + 1
+            if valid_dataset is not None:
+                best_score = other_info.get("best_score", float("inf"))
+            else:
+                best_score = None
+
+        if self.resume:
+            max_epochs = self.max_epochs
+        else:
+            max_epochs = start_epoch + self.max_epochs
+
+        if valid_dataset is not None:
+            best_epoch = other_info.get("best_epoch", start_epoch)
+        else:
+            best_epoch = None
+
-        for epoch in range(self.max_epochs):
+        for epoch in range(start_epoch, max_epochs):
             model = model.train()
             start_time = time.time()
             batch_bar = tqdm(
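A subtlety in the block above: under resume, self.max_epochs is treated as the absolute epoch budget, while a plain pretrain load trains for self.max_epochs further epochs on top of the checkpoint. A worked example with made-up numbers:

# Checkpoint from a previous run: finished epoch 49, best result at epoch 40.
other_info = {"epoch": 49, "best_epoch": 40, "best_score": 0.123}
max_epochs_cfg = 100

# resume=True: pick up where the run stopped, keeping the patience counter.
start_epoch = other_info["epoch"] + 1                    # 50
wait = other_info["epoch"] - other_info["best_epoch"]    # 9 epochs since last improvement
epochs = range(start_epoch, max_epochs_cfg)              # epochs 50..99

# resume=False (fine-tuning a pretrained model): a fresh full budget.
epochs = range(start_epoch, start_epoch + max_epochs_cfg)  # epochs 50..149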
@@ -246,7 +310,7 @@
                 trn_loss.append(float(loss.data))
 
                 batch_bar.set_postfix(
-                    Epoch="Epoch {}/{}".format(epoch+1, self.max_epochs),
+                    Epoch="Epoch {}/{}".format(epoch+1, max_epochs),
                     loss="{:.04f}".format(float(sum(trn_loss) / (i + 1))),
                     lr="{:.04f}".format(float(optimizer.param_groups[0]['lr']))
                 )
@@ -268,12 +332,13 @@
 
             batch_bar.close()
             total_trn_loss = np.mean(trn_loss)
-            message = f'Epoch [{epoch+1}/{self.max_epochs}] train_loss: {total_trn_loss:.4f}, lr: {optimizer.param_groups[0]["lr"]:.6f}'
+            message = f'Epoch [{epoch+1}/{max_epochs}] train_loss: {total_trn_loss:.4f}, lr: {optimizer.param_groups[0]["lr"]:.6f}'
 
             if self.use_ema:
                 cm = ema.average_parameters()
             else:
                 cm = contextlib.nullcontext()
+
             y_preds = None
             if valid_dataset is not None:
                 with cm:
@@ -289,47 +354,54 @@
                         total_val_loss = np.mean(val_loss)
                     _score = metric_score["_judge_score"]
                     _metric = str(self.metrics)
-                    save_handle = partial(self.save_state_dict, model=model, dump_dir=dump_dir, ema=None, suffix="best", model_rank=model_rank)
-                    is_early_stop, min_val_loss, wait, max_score = self._early_stop_choice(wait, min_val_loss, metric_score, max_score, save_handle, self.patience, epoch)
+                    save_handle = partial(self.save_state_dict, model=model, optimizer=optimizer, scheduler=scheduler, dump_dir=dump_dir, ema=ema, suffix="best", model_rank=model_rank)
+                    is_early_stop, best_score, wait, saved = self._early_stop_choice(wait, best_score, metric_score, save_handle, self.patience, epoch)
+                    if saved:
+                        best_epoch = epoch
                     message += f', val_loss: {total_val_loss:.4f}, ' + \
                         ", ".join([f'val_{k}: {v:.4f}' for k, v in metric_score.items() if k != "_judge_score"]) + \
                         f', val_judge_score ({_metric}): {_score:.4f}' + \
-                        (f', Patience [{wait}/{self.patience}], min_val_judge_score: {min_val_loss:.4f}' if wait else '')
+                        (f', Patience [{wait}/{self.patience}], min_val_judge_score: {best_score:.4f}' if wait else '')
             else:
                 is_early_stop = False
 
             end_time = time.time()
             message += f', {(end_time - start_time):.1f}s'
             logger.info(message)
-            self.save_state_dict(model, dump_dir, ema, "last", model_rank)
+            self.save_state_dict(model, optimizer, scheduler, dump_dir, ema, "last", model_rank, epoch=epoch, best_score=best_score, best_epoch=best_epoch)
             if is_early_stop:
                 break
 
         if test_dataset is not None:
-            y_preds, _, metric_score = self.predict(
-                model=model,
-                dataset=test_dataset,
-                loss_terms=loss_terms,
-                dump_dir=dump_dir,
-                transform=transform,
-                epoch=epoch,
-                load_model=True,
-                model_rank=model_rank
-            )
+            if self.use_ema:
+                cm = ema.average_parameters()
+            else:
+                cm = contextlib.nullcontext()
+            with cm:
+                y_preds, _, metric_score = self.predict(
+                    model=model,
+                    dataset=test_dataset,
+                    loss_terms=loss_terms,
+                    dump_dir=dump_dir,
+                    transform=transform,
+                    epoch=epoch,
+                    load_model=True,
+                    model_rank=model_rank
+                )
         else:
             metric_score = None
         return y_preds, metric_score
 
-    def _early_stop_choice(self, wait, min_loss, metric_score, max_score, save_handle, patience, epoch):
-        return self.metrics._early_stop_choice(wait, min_loss, metric_score, max_score, save_handle, patience, epoch)
+    def _early_stop_choice(self, wait, min_loss, metric_score, save_handle, patience, epoch):
+        return self.metrics._early_stop_choice(wait, min_loss, metric_score, save_handle, patience, epoch)
 
     def predict(self, model, dataset, loss_terms, dump_dir, transform, epoch=1, load_model=False, model_rank=None):
         self._set_seed(self.seed)
         model = model.to(self.device).type(self.dtype)
         if load_model == True:
             from ..models import get_pretrain_path
             pretrain_path = get_pretrain_path(dump_dir, "best", model_rank)
-            self.load_state_dict(model, pretrain_path, inference=True)
+            self.load_state_dict(model, pretrain_path=pretrain_path, inference=True)
 
         dataloader = DataLoader(
             dataset=dataset,
