142: Load best model instead of last one #151

Merged (2 commits) on Feb 21, 2022

Changes from all commits
alonet/common/pl_helpers.py (90 additions, 9 deletions)
@@ -1,13 +1,12 @@
+from argparse import ArgumentParser, ArgumentTypeError, Namespace, _ArgumentGroup
 from pytorch_lightning.callbacks import ModelCheckpoint
 import pytorch_lightning as pl
 
-from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
+from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
 import datetime
 import os
 import re
 import torch
 
-from argparse import ArgumentParser, _ArgumentGroup, Namespace
-
 parser = ArgumentParser()
 
@@ -21,6 +20,17 @@ def vb_folder():
     return alofolder
 
 
+def checkpoint_type(arg_value):
+    if arg_value is None:
+        return "last"
+    elif arg_value.lower() == "last" or arg_value.lower() == "best":
+        return arg_value.lower()
+    elif os.path.splitext(arg_value)[-1].lower() == ".ckpt":
+        return arg_value
+    else:
+        raise ArgumentTypeError(f"{arg_value} is not a valid checkpoint. Use 'last', 'best' or a '.ckpt' file")
+
+
 def add_argparse_args(parent_parser, add_pl_args=True, mode="training"):
     """add cli argparse arguments to parent_parser
 
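Note on checkpoint_type: a minimal sketch of how the validator behaves once wired into argparse (the demo parser below is illustrative, not part of the diff):

    from argparse import ArgumentParser
    from alonet.common.pl_helpers import checkpoint_type

    demo = ArgumentParser()
    demo.add_argument("--checkpoint", type=checkpoint_type, default="last")

    demo.parse_args([])                              # Namespace(checkpoint='last')
    demo.parse_args(["--checkpoint", "BEST"])        # normalized to 'best'
    demo.parse_args(["--checkpoint", "model.ckpt"])  # a '.ckpt' path is returned as-is
    demo.parse_args(["--checkpoint", "model.pth"])   # argparse error: not a valid checkpoint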
@@ -57,17 +67,33 @@ def add_argparse_args(parent_parser, add_pl_args=True, mode="training"):

     if mode == "training":
         parser.add_argument("--resume", action="store_true", help="Resume all the training, not only the weights.")
-        parser.add_argument("--save", action="store_true", help="Save every epoch and keep the top_k=3")
+        parser.add_argument("--save", action="store_true", help="Save every epoch and keep the top_k")
+        parser.add_argument(
+            "--save_top_k",
+            type=int,
+            default=3,
+            help="Stores up to top_k models, by default %(default)s. Use save_top_k=-1 to store all checkpoints",
+        )
+    if mode == "eval":  # Only makes sense for eval
+        parser.add_argument(
+            "--checkpoint",
+            type=checkpoint_type,
+            default="last",
+            help="Load the weights from 'best', 'last' or a '.ckpt' file into run_id, by default '%(default)s'",
+        )
     parser.add_argument(
         "--log",
         type=str,
         default=None,
         nargs="?",
         const="wandb",
-        help="Log results, can specify logger (default:None, if set but value not provided:wandb",
+        help="Log results, can specify the logger, by default %(default)s. If set but no value provided: wandb",
     )
     parser.add_argument("--cpu", action="store_true", help="Use the CPU instead of scaling on the available GPUs")
     parser.add_argument("--run_id", type=str, help="Load the weights from this saved experiment")
+    parser.add_argument(
+        "--monitor", type=str, default="val_loss", help="Metric to save/load weights, by default '%(default)s'"
+    )
+    parser.add_argument(
+        "--no_run_id", action="store_true", help="Skip loading from run_id when an experiment is restored."
+    )
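Taken together, the new eval-mode flags look like this in use (a sketch; "my_run" is a placeholder experiment id):

    from argparse import ArgumentParser
    from alonet.common.pl_helpers import add_argparse_args

    parser = add_argparse_args(ArgumentParser(), add_pl_args=False, mode="eval")
    args = parser.parse_args(["--run_id", "my_run", "--checkpoint", "best", "--monitor", "val_loss"])
    # args.checkpoint == 'best' is later resolved to a concrete file by checkpoint_handler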
@@ -83,6 +109,52 @@ def add_argparse_args(parent_parser, add_pl_args=True, mode="training"):
     return parent_parser
 
 
+def checkpoint_handler(checkpoint_path, rfolder, monitor="val_loss"):
+    if checkpoint_path == "last":
+        return "last.ckpt"
+    elif os.path.splitext(checkpoint_path)[-1].lower() == ".ckpt":
+        return checkpoint_path
+    elif checkpoint_path.lower() == "best":
+        best_path, best_monitor, best_epoch = None, None, None
+        # First way: read the monitor value from the checkpoint file names
+        for fpath in os.listdir(rfolder):
+            try:
+                ck_props = dict(map(lambda y: y.split("="), re.findall(r"-*(.*?)(?:-|\.ckpt)", fpath)))
+            except ValueError:  # No information in the file name, skip it
+                continue
+            if monitor not in ck_props:
+                continue
+            cmonitor, cepoch = float(ck_props[monitor]), int(ck_props["epoch"])
+            if best_monitor is None or cmonitor < best_monitor:
+                best_path, best_monitor, best_epoch = fpath, cmonitor, cepoch
+            elif cmonitor == best_monitor and cepoch < best_epoch:
+                best_path, best_monitor, best_epoch = fpath, cmonitor, cepoch
+
+        if best_path is not None:
+            print(f"Found best model at {best_path}.")
+            return best_path
+
+        # Second way: read the monitor value saved inside the checkpoints
+        for fpath in os.listdir(rfolder):
+            state_dict = torch.load(os.path.join(rfolder, fpath), map_location="cpu")
+            ck_props = [v for k, v in state_dict["callbacks"].items() if ModelCheckpoint.__name__ == k.__name__]
+            if len(ck_props) == 0 or monitor != ck_props[0]["monitor"]:
+                continue
+            cmonitor = ck_props[0]["current_score"].item()
+            if best_monitor is None or cmonitor < best_monitor:
+                best_path, best_monitor = fpath, cmonitor
+
+        if best_path is not None:
+            print(f"Found best model at {best_path}.")
+            return best_path
+        else:
+            raise RuntimeError(
+                f"No '{monitor}' found in the checkpoints. Use '--checkpoint last' or another monitor instead"
+            )
+    else:
+        raise ValueError(f"Unknown checkpoint: {checkpoint_path}")
+
+
 def load_training(
     lit_model_class,
     args: Namespace = None,
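For reference, the file-name pass above parses names produced by the ModelCheckpoint filename template added in run_pl_training below; a quick check of the expression (the file name is illustrative):

    import re

    fname = "epoch=3-step=1500-val_loss=0.1234.ckpt"
    props = dict(map(lambda y: y.split("="), re.findall(r"-*(.*?)(?:-|\.ckpt)", fname)))
    print(props)  # {'epoch': '3', 'step': '1500', 'val_loss': '0.1234'}
    # checkpoint_handler keeps the file with the smallest float(props[monitor]),
    # breaking ties with the earliest epoch.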
@@ -99,8 +171,11 @@ def load_training(

     strict = True if "nostrict" not in args else not args.nostrict
     if run_id is not None and project_run_id is not None:
-        run_id_project_dir = os.path.join(vb_folder(), f"project_{project_run_id}")
-        ckpt_path = os.path.join(run_id_project_dir, run_id, "last.ckpt")
+        run_id_project_dir = os.path.join(vb_folder(), f"project_{project_run_id}", run_id)
+        ckpt_path = kwargs.get("checkpoint") or getattr(args, "checkpoint", "last.ckpt")  # Highest priority on kwargs
+        monitor = kwargs.get("monitor") or getattr(args, "monitor", "val_loss")
+        ckpt_path = checkpoint_handler(ckpt_path, run_id_project_dir, monitor)
+        ckpt_path = os.path.join(run_id_project_dir, ckpt_path)
         if not os.path.exists(ckpt_path):
             raise Exception(f"Impossible to load the ckpt at the following destination: {ckpt_path}")
         print(f"Loading ckpt from {run_id} at {ckpt_path}")
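Usage-wise, a sketch of the new resolution path (assuming run_id and project_run_id are accepted as keyword arguments, as the elided signature suggests; MyLitModel and the ids are placeholders):

    from alonet.common.pl_helpers import load_training

    lit_model = load_training(
        MyLitModel,                   # your pl.LightningModule subclass (placeholder)
        args=args,                    # e.g. parsed with add_argparse_args(..., mode="eval")
        run_id="my_run",              # placeholder experiment id
        project_run_id="my_project",  # placeholder project id
        checkpoint="best",            # kwargs take priority over args.checkpoint
    )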
@@ -180,8 +255,14 @@ def run_pl_training(
         logger = None
 
     if args.save:
+        monitor = getattr(args, "monitor", "val_loss")
         checkpoint_callback = ModelCheckpoint(
-            dirpath=expe_dir, verbose=True, save_last=True, save_top_k=3, monitor="val_loss"
+            dirpath=expe_dir,
+            verbose=True,
+            save_last=True,
+            save_top_k=getattr(args, "save_top_k", 3),
+            monitor=monitor,
+            filename="{epoch}-{step}-{" + monitor + ":.4f}",
         )
         callbacks.append(checkpoint_callback)

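With the default monitor, the filename template resolves like this (a sketch with illustrative metric values):

    monitor = "val_loss"
    template = "{epoch}-{step}-{" + monitor + ":.4f}"
    print(template)  # {epoch}-{step}-{val_loss:.4f}
    # ModelCheckpoint then writes names such as 'epoch=3-step=1500-val_loss=0.1234.ckpt',
    # alongside 'last.ckpt' from save_last=True, which is exactly the pattern
    # checkpoint_handler parses back in its file-name pass.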