Skip to content

Commit

Permalink
Merge pull request #151 from Visual-Behavior/loading_best
Browse files Browse the repository at this point in the history
142: Load best model instead of last one
  • Loading branch information
thibo73800 authored Feb 21, 2022
2 parents ae2232b + 4312418 commit 3e967fd
Showing 1 changed file with 90 additions and 9 deletions.
99 changes: 90 additions & 9 deletions alonet/common/pl_helpers.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from argparse import ArgumentParser, ArgumentTypeError, Namespace, _ArgumentGroup
from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning as pl

from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
import datetime
import os
import re
import torch

from argparse import ArgumentParser, _ArgumentGroup, Namespace

parser = ArgumentParser()


Expand All @@ -21,6 +20,17 @@ def vb_folder():
return alofolder


def checkpoint_type(arg_value):
if arg_value is None:
return "last"
elif arg_value.lower() == "last" or arg_value.lower() == "best":
return arg_value.lower()
elif os.path.splitext(arg_value)[-1].lower() == ".ckpt":
return arg_value
else:
raise ArgumentTypeError(f"{arg_value} is not a valid checkpoint. Use 'last','best' or '.ckpt' file")


def add_argparse_args(parent_parser, add_pl_args=True, mode="training"):
"""add cli argparse arguments to parent_parser
Expand Down Expand Up @@ -57,17 +67,33 @@ def add_argparse_args(parent_parser, add_pl_args=True, mode="training"):

if mode == "training":
parser.add_argument("--resume", action="store_true", help="Resume all the training, not only the weights.")
parser.add_argument("--save", action="store_true", help="Save every epoch and keep the top_k=3")
parser.add_argument("--save", action="store_true", help="Save every epoch and keep the top_k")
parser.add_argument(
"--save_top_k",
type=int,
default=3,
help="Stores up to top_k models, by default %(default)s. Use save_top_k=-1 to store all checkpoints",
)
if mode == "eval": # Only make sense for eval
parser.add_argument(
"--checkpoint",
type=checkpoint_type,
default="last",
help="Load the weights from 'best', 'last' or '.ckpt' file into run_id, by default '%(default)s'",
)
parser.add_argument(
"--log",
type=str,
default=None,
nargs="?",
const="wandb",
help="Log results, can specify logger (default:None, if set but value not provided:wandb",
help="Log results, can specify logger, by default %(default)s. If set but value not provided:wandb",
)
parser.add_argument("--cpu", action="store_true", help="Use the CPU instead of scaling on the vaiable GPUs")
parser.add_argument("--run_id", type=str, help="Load the weights from this saved experiment")
parser.add_argument(
"--monitor", type=str, default="val_loss", help="Metric to save/load weights, by default '%(default)s'"
)
parser.add_argument(
"--no_run_id", action="store_true", help="Skip loading form run_id when an experiment is restored."
)
Expand All @@ -83,6 +109,52 @@ def add_argparse_args(parent_parser, add_pl_args=True, mode="training"):
return parent_parser


def checkpoint_handler(checkpoint_path, rfolder, monitor="val_loss"):
if checkpoint_path == "last":
return "last.ckpt"
elif os.path.splitext(checkpoint_path)[-1].lower():
return checkpoint_path
elif checkpoint_path.lower() == "best":
best_path, best_monitor, best_epoch = None, None, None
# First way: load the best model from file name
for fpath in os.listdir(rfolder):
try:
ck_props = dict(map(lambda y: y.split("="), re.findall("[-]*(.*?)(?:-|.ckpt)", fpath)))
except: # No information in filename, skipping
continue
if monitor not in ck_props:
continue
cmonitor, cepoch = float(ck_props[monitor]), int(ck_props["epoch"])
if best_monitor is None or cmonitor < best_monitor:
best_path, best_monitor, best_epoch = fpath, cmonitor, cepoch
elif cmonitor == best_monitor and cepoch < best_epoch:
best_path, best_monitor, best_epoch = fpath, cmonitor, cepoch

if best_path is not None:
print(f"Found best model at {best_path}.")
return best_path

# Second way: load the best model using monitor saved in checkpoints
for fpath in os.listdir(rfolder):
stact_dict = torch.load(os.path.join(rfolder, fpath), map_location="cpu")
ck_props = [v for k, v in stact_dict["callbacks"].items() if ModelCheckpoint.__name__ == k.__name__]
if len(ck_props) == 0 or monitor != ck_props[0]["monitor"]:
continue
cmonitor = ck_props[0]["current_score"].item()
if best_monitor is None or cmonitor < best_monitor:
best_path, best_monitor = fpath, cmonitor

if best_path is not None:
print(f"Found best model at {best_path}.")
return best_path
else:
raise RuntimeError(
f"Not '{monitor}' found on checkpoints. Use '--checkpoint last' instead or another monitor"
)
else:
raise ValueError(f"Unknown checkpoint: {checkpoint_path}")


def load_training(
lit_model_class,
args: Namespace = None,
Expand All @@ -99,8 +171,11 @@ def load_training(

strict = True if "nostrict" not in args else not args.nostrict
if run_id is not None and project_run_id is not None:
run_id_project_dir = os.path.join(vb_folder(), f"project_{project_run_id}")
ckpt_path = os.path.join(run_id_project_dir, run_id, "last.ckpt")
run_id_project_dir = os.path.join(vb_folder(), f"project_{project_run_id}", run_id)
ckpt_path = kwargs.get("checkpoint") or getattr(args, "checkpoint", "last.ckpt") # Highest priority on kwargs
monitor = kwargs.get("monitor") or getattr(args, "monitor", "val_loss")
ckpt_path = checkpoint_handler(ckpt_path, run_id_project_dir, monitor)
ckpt_path = os.path.join(run_id_project_dir, ckpt_path)
if not os.path.exists(ckpt_path):
raise Exception(f"Impossible to load the ckpt at the following destination:{ckpt_path}")
print(f"Loading ckpt from {run_id} at {ckpt_path}")
Expand Down Expand Up @@ -180,8 +255,14 @@ def run_pl_training(
logger = None

if args.save:
monitor = getattr(args, "monitor", "val_loss")
checkpoint_callback = ModelCheckpoint(
dirpath=expe_dir, verbose=True, save_last=True, save_top_k=3, monitor="val_loss"
dirpath=expe_dir,
verbose=True,
save_last=True,
save_top_k=getattr(args, "save_top_k", 3),
monitor=monitor,
filename="{epoch}-{step}-{" + monitor + ":.4f}",
)
callbacks.append(checkpoint_callback)

Expand Down

0 comments on commit 3e967fd

Please sign in to comment.