From 27c9807410dba272da3db7e826551f945e2642c5 Mon Sep 17 00:00:00 2001
From: Johansmm
Date: Fri, 18 Feb 2022 14:13:59 +0100
Subject: [PATCH 1/2] feat: load best weights, save_top_k formatted

---
 alonet/common/pl_helpers.py | 98 +++++++++++++++++++++++++++++++++----
 1 file changed, 89 insertions(+), 9 deletions(-)

diff --git a/alonet/common/pl_helpers.py b/alonet/common/pl_helpers.py
index fe2e85ea..fd80b69f 100644
--- a/alonet/common/pl_helpers.py
+++ b/alonet/common/pl_helpers.py
@@ -1,13 +1,12 @@
+from argparse import ArgumentParser, ArgumentTypeError, Namespace, _ArgumentGroup
 from pytorch_lightning.callbacks import ModelCheckpoint
 import pytorch_lightning as pl
-
-from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
+from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
 import datetime
 import os
+import re
 import torch
 
-from argparse import ArgumentParser, _ArgumentGroup, Namespace
-
 
 parser = ArgumentParser()
 
@@ -21,6 +20,17 @@ def vb_folder():
     return alofolder
 
 
+def checkpoint_type(arg_value):
+    if arg_value is None:
+        return "last"
+    elif arg_value.lower() == "last" or arg_value.lower() == "best":
+        return arg_value.lower()
+    elif os.path.splitext(arg_value)[-1].lower() == ".ckpt":
+        return arg_value
+    else:
+        raise ArgumentTypeError(f"{arg_value} is not a valid checkpoint. Use 'last', 'best' or a '.ckpt' file")
+
+
 def add_argparse_args(parent_parser, add_pl_args=True, mode="training"):
     """add cli argparse arguments to parent_parser
 
@@ -57,17 +67,33 @@
     if mode == "training":
         parser.add_argument("--resume", action="store_true", help="Resume all the training, not only the weights.")
-        parser.add_argument("--save", action="store_true", help="Save every epoch and keep the top_k=3")
+        parser.add_argument("--save", action="store_true", help="Save every epoch and keep the top save_top_k checkpoints")
+        parser.add_argument(
+            "--save_top_k",
+            type=int,
+            default=3,
+            help="Store up to save_top_k models, by default %(default)s. Use save_top_k=-1 to store all checkpoints",
+        )
 
+    if mode == "eval":  # Only makes sense for eval
+        parser.add_argument(
+            "--checkpoint",
+            type=checkpoint_type,
+            default="last",
+            help="Load the weights from 'best', 'last' or a '.ckpt' file in run_id, by default '%(default)s'",
+        )
     parser.add_argument(
         "--log",
         type=str,
         default=None,
         nargs="?",
         const="wandb",
-        help="Log results, can specify logger (default:None, if set but value not provided:wandb",
+        help="Log results, can specify the logger to use, by default %(default)s. If set without a value, 'wandb' is used",
     )
     parser.add_argument("--cpu", action="store_true", help="Use the CPU instead of scaling on the vaiable GPUs")
     parser.add_argument("--run_id", type=str, help="Load the weights from this saved experiment")
+    parser.add_argument(
+        "--monitor", type=str, default="val_loss", help="Metric used to save/load weights, by default '%(default)s'"
+    )
     parser.add_argument(
         "--no_run_id", action="store_true", help="Skip loading form run_id when an experiment is restored."
     )
@@ -83,6 +109,52 @@ def add_argparse_args(parent_parser, add_pl_args=True, mode="training"):
     return parent_parser
 
 
+def checkpoint_handler(checkpoint_path, rfolder, monitor="val_loss"):
+    if checkpoint_path == "last":
+        return "last.ckpt"
+    elif os.path.splitext(checkpoint_path)[-1].lower():
+        return checkpoint_path
+    elif checkpoint_path.lower() == "best":
+        best_path, best_monitor, best_epoch = None, None, None
+        # First way: load the best model from the file name
+        for fpath in os.listdir(rfolder):
+            try:
+                ck_props = dict(map(lambda y: y.split("="), re.findall(r"[-]*(.*?)(?:-|\.ckpt)", fpath)))
+            except:  # No information in file name, skip it
+                continue
+            if monitor not in ck_props:
+                continue
+            cmonitor, cepoch = float(ck_props[monitor]), int(ck_props["epoch"])
+            if best_monitor is None or cmonitor < best_monitor:
+                best_path, best_monitor, best_epoch = fpath, cmonitor, cepoch
+            elif cmonitor == best_monitor and cepoch < best_epoch:
+                best_path, best_monitor, best_epoch = fpath, cmonitor, cepoch
+
+        if best_path is not None:
+            print(f"Found best model at {best_path}.")
+            return best_path
+
+        # Second way: load the best model using the monitor value saved in the checkpoints
+        for fpath in os.listdir(rfolder):
+            state_dict = torch.load(os.path.join(rfolder, fpath), map_location="cpu")
+            ck_props = [v for k, v in state_dict["callbacks"].items() if ModelCheckpoint.__name__ == k.__name__]
+            if len(ck_props) == 0 or monitor != ck_props[0]["monitor"]:
+                continue
+            cmonitor = ck_props[0]["current_score"].item()
+            if best_monitor is None or cmonitor < best_monitor:
+                best_path, best_monitor = fpath, cmonitor
+
+        if best_path is not None:
+            print(f"Found best model at {best_path}.")
+            return best_path
+        else:
+            raise RuntimeError(
+                f"No '{monitor}' found in the checkpoints. Use '--checkpoint last' or another monitor instead"
+            )
+    else:
+        raise ValueError(f"Unknown checkpoint: {checkpoint_path}")
+
+
 def load_training(
     lit_model_class,
     args: Namespace = None,
@@ -99,8 +171,10 @@ def load_training(
     strict = True if "nostrict" not in args else not args.nostrict
 
     if run_id is not None and project_run_id is not None:
-        run_id_project_dir = os.path.join(vb_folder(), f"project_{project_run_id}")
-        ckpt_path = os.path.join(run_id_project_dir, run_id, "last.ckpt")
+        run_id_project_dir = os.path.join(vb_folder(), f"project_{project_run_id}", run_id)
+        ckpt_path = kwargs.get("checkpoint") or getattr(args, "checkpoint", "last.ckpt")  # Highest priority on kwargs
+        ckpt_path = checkpoint_handler(ckpt_path, run_id_project_dir)
+        ckpt_path = os.path.join(run_id_project_dir, ckpt_path)
         if not os.path.exists(ckpt_path):
             raise Exception(f"Impossible to load the ckpt at the following destination:{ckpt_path}")
         print(f"Loading ckpt from {run_id} at {ckpt_path}")
@@ -180,8 +254,14 @@ def run_pl_training(
            logger = None
 
     if args.save:
+        monitor = getattr(args, "monitor", "val_loss")
         checkpoint_callback = ModelCheckpoint(
-            dirpath=expe_dir, verbose=True, save_last=True, save_top_k=3, monitor="val_loss"
+            dirpath=expe_dir,
+            verbose=True,
+            save_last=True,
+            save_top_k=getattr(args, "save_top_k", 3),
+            monitor=monitor,
+            filename="{epoch}-{step}-{" + monitor + ":.4f}",
         )
         callbacks.append(checkpoint_callback)
 

From 43124184a192c3180e6d2d0fd003dca52e4a76e0 Mon Sep 17 00:00:00 2001
From: Johansmm
Date: Fri, 18 Feb 2022 14:43:24 +0100
Subject: [PATCH 2/2] fix: use monitor on loading method

---
 alonet/common/pl_helpers.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/alonet/common/pl_helpers.py b/alonet/common/pl_helpers.py
index fd80b69f..3c54380b 100644
--- a/alonet/common/pl_helpers.py
+++ b/alonet/common/pl_helpers.py
@@ -173,7 +173,8 @@ def load_training(
     if run_id is not None and project_run_id is not None:
         run_id_project_dir = os.path.join(vb_folder(), f"project_{project_run_id}", run_id)
         ckpt_path = kwargs.get("checkpoint") or getattr(args, "checkpoint", "last.ckpt")  # Highest priority on kwargs
-        ckpt_path = checkpoint_handler(ckpt_path, run_id_project_dir)
+        monitor = kwargs.get("monitor") or getattr(args, "monitor", "val_loss")
+        ckpt_path = checkpoint_handler(ckpt_path, run_id_project_dir, monitor)
         ckpt_path = os.path.join(run_id_project_dir, ckpt_path)
         if not os.path.exists(ckpt_path):
             raise Exception(f"Impossible to load the ckpt at the following destination:{ckpt_path}")
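
Note (illustrative sketch, not part of the patch): the two commits combine the new ModelCheckpoint
filename pattern ("{epoch}-{step}-{<monitor>:.4f}") with a '--checkpoint best' option whose handler
first tries to pick the best checkpoint from the file name and only then falls back to reading the
monitor value stored inside each checkpoint. The snippet below shows the filename-based selection
under those assumptions; the run directory contents are made up for the example.

import re

# Hypothetical contents of a run directory written with the new filename pattern.
filenames = [
    "last.ckpt",
    "epoch=3-step=1200-val_loss=0.4312.ckpt",
    "epoch=7-step=2800-val_loss=0.3987.ckpt",
]

best_path, best_monitor = None, None
for fname in filenames:
    try:
        # Same parsing idea as checkpoint_handler: extract "key=value" tokens from the name.
        props = dict(map(lambda y: y.split("="), re.findall(r"[-]*(.*?)(?:-|\.ckpt)", fname)))
    except ValueError:  # e.g. "last.ckpt" carries no metric information in its name
        continue
    if "val_loss" not in props:
        continue
    if best_monitor is None or float(props["val_loss"]) < best_monitor:
        best_path, best_monitor = fname, float(props["val_loss"])

print(best_path)  # -> epoch=7-step=2800-val_loss=0.3987.ckpt

A lower monitor value always wins, which matches the default 'val_loss' monitor; a metric where
higher is better would need the comparison flipped, and the patch does not expose a flag for that.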