From fa8f05c4171c71c65afda3a10223004f88ac3082 Mon Sep 17 00:00:00 2001 From: Kalen Michael Date: Thu, 29 Jul 2021 15:25:54 +0200 Subject: [PATCH 1/9] added callbacks --- train.py | 19 ++++- utils/callbacks.py | 183 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 199 insertions(+), 3 deletions(-) create mode 100644 utils/callbacks.py diff --git a/train.py b/train.py index 1d3404ffc41..e3fcd0b3429 100644 --- a/train.py +++ b/train.py @@ -43,6 +43,8 @@ from utils.metrics import fitness from utils.loggers import Loggers +from utils import callbacks + LOGGER = logging.getLogger(__name__) LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html RANK = int(os.getenv('RANK', -1)) @@ -52,6 +54,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary opt, device, + callbacks ): save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \ Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \ @@ -330,6 +333,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % ( f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])) loggers.on_train_batch_end(ni, model, imgs, targets, paths, plots) + callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots) # end batch ------------------------------------------------------------------------------------------------ @@ -340,6 +344,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if RANK in [-1, 0]: # mAP loggers.on_train_epoch_end(epoch) + callbacks.on_train_epoch_end(epoch) + ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights']) final_epoch = epoch + 1 == epochs if not noval or final_epoch: # Calculate mAP @@ -361,6 +367,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if fi > best_fitness: best_fitness = fi loggers.on_train_val_end(mloss, results, lr, epoch, best_fitness, fi) + callbacks.on_val_end(mloss, results, lr, epoch, best_fitness, fi) # Save model if (not nosave) or (final_epoch and not evolve): # if save @@ -378,6 +385,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary torch.save(ckpt, best) del ckpt loggers.on_model_save(last, epoch, final_epoch, best_fitness, fi) + callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi) # end epoch ---------------------------------------------------------------------------------------------------- # end training ----------------------------------------------------------------------------------------------------- @@ -401,6 +409,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if f.exists(): strip_optimizer(f) # strip optimizers loggers.on_train_end(last, best, plots) + callbacks.on_train_end(last, best, plots) torch.cuda.empty_cache() return results @@ -447,7 +456,11 @@ def parse_opt(known=False): return opt -def main(opt): +def main(opt, callback_handler = None): + + # Define new hook handler if one is not passed in + if not callback_handler: callback_handler = callbacks.Callbacks() + set_logging(RANK) if RANK in [-1, 0]: print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items())) @@ -483,7 +496,7 @@ def main(opt): # Train if not opt.evolve: - train(opt.hyp, opt, device) + train(opt.hyp, opt, device, callback_handler) if WORLD_SIZE > 1 and RANK == 0: _ = [print('Destroying process group... 
', end=''), dist.destroy_process_group(), print('Done.')] @@ -563,7 +576,7 @@ def main(opt): hyp[k] = round(hyp[k], 5) # significant digits # Train mutation - results = train(hyp.copy(), opt, device) + results = train(hyp.copy(), opt, device, callback_handler) # Write mutation results print_mutation(hyp.copy(), results, yaml_file, opt.bucket) diff --git a/utils/callbacks.py b/utils/callbacks.py new file mode 100644 index 00000000000..d02f114021b --- /dev/null +++ b/utils/callbacks.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python + +class Callbacks: + """" + Handles all registered callbacks for YOLOv5 Hooks + """ + + _callbacks = { + 'on_pretrain_routine_start' :[], + 'on_pretrain_routine_end' :[], + + 'on_train_start' :[], + 'on_train_end': [], + 'on_train_epoch_start': [], + 'on_train_epoch_end': [], + 'on_train_batch_start': [], + 'on_train_batch_end': [], + + 'on_val_start' :[], + 'on_val_end': [], + 'on_val_epoch_start': [], + 'on_val_epoch_end': [], + 'on_val_batch_start': [], + 'on_val_batch_end': [], + + + 'on_model_save': [], + 'optimizer_step': [], + 'on_before_zero_grad': [], + 'teardown': [], + } + + def __init__(self): + return + + def regsiterAction(self, hook, name, callback): + """ + Register a new action to a callback hook + + Args: + action The callback hook name to register the action to + name The name of the action + callback The callback to fire + + Returns: + (Bool) The success state + """ + if hook in self._callbacks: + self._callbacks[hook].append({'name': name, 'callback': callback}) + return True + else: + return False + + def getRegisteredActions(self, hook=None): + """" + Returns all the registered actions by callback hook + + Args: + hook The name of the hook to check, defaults to all + """ + if hook: + return self._callbacks[hook] + else: + return self._callbacks + + def fireCallbacks(self, register, *args): + """ + Loop throughs the registered actions and fires all callbacks + """ + for logger in register: + logger['callback'](*args) + + + def on_pretrain_routine_start(self, *args): + """ + Fires all registered callbacks at the start of each pretraining routine + """ + self.fireCallbacks(self._callbacks['on_pretrain_routine_start'], *args) + + def on_pretrain_routine_end(self, *args): + """ + Fires all registered callbacks at the end of each pretraining routine + """ + self.fireCallbacks(self._callbacks['on_pretrain_routine_end'], *args) + + def on_train_start(self, *args): + """ + Fires all registered callbacks at the start of each training + """ + self.fireCallbacks(self._callbacks['on_train_start'], *args) + + def on_train_end(self, *args): + """ + Fires all registered callbacks at the end of training + """ + self.fireCallbacks(self._callbacks['on_train_end'], *args) + + def on_train_epoch_start(self, *args): + """ + Fires all registered callbacks at the start of each training epoch + """ + self.fireCallbacks(self._callbacks['on_train_epoch_start'], *args) + + def on_train_epoch_end(self, *args): + """ + Fires all registered callbacks at the end of each training epoch + """ + self.fireCallbacks(self._callbacks['on_train_epoch_end'], *args) + + + def on_train_batch_start(self, *args): + """ + Fires all registered callbacks at the start of each training batch + """ + self.fireCallbacks(self._callbacks['on_train_batch_start'], *args) + + def on_train_batch_end(self, *args): + """ + Fires all registered callbacks at the end of each training batch + """ + self.fireCallbacks(self._callbacks['on_train_batch_end'], *args) + + def on_val_start(self, *args): + """ + 
Fires all registered callbacks at the start of the validation + """ + self.fireCallbacks(self._callbacks['on_val_start'], *args) + + def on_val_end(self, *args): + """ + Fires all registered callbacks at the end of the validation + """ + self.fireCallbacks(self._callbacks['on_val_end'], *args) + + def on_val_epoch_start(self, *args): + """ + Fires all registered callbacks at the start of each validation epoch + """ + self.fireCallbacks(self._callbacks['on_val_epoch_start'], *args) + + def on_val_epoch_end(self, *args): + """ + Fires all registered callbacks at the end of each validation epoch + """ + self.fireCallbacks(self._callbacks['on_val_epoch_end'], *args) + + def on_val_batch_start(self, *args): + """ + Fires all registered callbacks at the start of each validation batch + """ + self.fireCallbacks(self._callbacks['on_val_batch_start'], *args) + + def on_val_batch_end(self, *args): + """ + Fires all registered callbacks at the end of each validation batch + """ + self.fireCallbacks(self._callbacks['on_val_batch_end'], *args) + + def on_model_save(self, *args): + """ + Fires all registered callbacks after each model save + """ + self.fireCallbacks(self._callbacks['on_model_save'], *args) + + def optimizer_step(self, *args): + """ + Fires all registered callbacks on each optimizer step + """ + self.fireCallbacks(self._callbacks['optimizer_step'], *args) + + def on_before_zero_grad(self, *args): + """ + Fires all registered callbacks before zero grad + """ + self.fireCallbacks(self._callbacks['on_before_zero_grad'], *args) + + def teardown(self, *args): + """ + Fires all registered callbacks before teardown + """ + self.fireCallbacks(self._callbacks['teardown'], *args) + + From 17de22e324a6ac73a1261556847930df36c3e813 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 30 Jul 2021 14:21:58 +0200 Subject: [PATCH 2/9] Update callbacks.py --- utils/callbacks.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/utils/callbacks.py b/utils/callbacks.py index d02f114021b..fc482bc61af 100644 --- a/utils/callbacks.py +++ b/utils/callbacks.py @@ -6,24 +6,23 @@ class Callbacks: """ _callbacks = { - 'on_pretrain_routine_start' :[], - 'on_pretrain_routine_end' :[], + 'on_pretrain_routine_start': [], + 'on_pretrain_routine_end': [], - 'on_train_start' :[], + 'on_train_start': [], 'on_train_end': [], 'on_train_epoch_start': [], 'on_train_epoch_end': [], 'on_train_batch_start': [], 'on_train_batch_end': [], - - 'on_val_start' :[], + + 'on_val_start': [], 'on_val_end': [], 'on_val_epoch_start': [], 'on_val_epoch_end': [], 'on_val_batch_start': [], 'on_val_batch_end': [], - 'on_model_save': [], 'optimizer_step': [], 'on_before_zero_grad': [], @@ -70,7 +69,6 @@ def fireCallbacks(self, register, *args): for logger in register: logger['callback'](*args) - def on_pretrain_routine_start(self, *args): """ Fires all registered callbacks at the start of each pretraining routine @@ -107,7 +105,6 @@ def on_train_epoch_end(self, *args): """ self.fireCallbacks(self._callbacks['on_train_epoch_end'], *args) - def on_train_batch_start(self, *args): """ Fires all registered callbacks at the start of each training batch @@ -179,5 +176,3 @@ def teardown(self, *args): Fires all registered callbacks before teardown """ self.fireCallbacks(self._callbacks['teardown'], *args) - - From 64b0f2ecb94de63bcd51439991f9505422c52bf7 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 30 Jul 2021 14:22:52 +0200 Subject: [PATCH 3/9] Update train.py --- train.py | 19 +++++++------------ 1 file 
changed, 7 insertions(+), 12 deletions(-) diff --git a/train.py b/train.py index e3fcd0b3429..eb84494b2c8 100644 --- a/train.py +++ b/train.py @@ -6,19 +6,18 @@ import argparse import logging +import math +import numpy as np import os import random import sys import time -from copy import deepcopy -from pathlib import Path - -import math -import numpy as np import torch import torch.distributed as dist import torch.nn as nn import yaml +from copy import deepcopy +from pathlib import Path from torch.cuda import amp from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Adam, SGD, lr_scheduler @@ -42,8 +41,7 @@ from utils.loggers.wandb.wandb_utils import check_wandb_resume from utils.metrics import fitness from utils.loggers import Loggers - -from utils import callbacks +from utils.callbacks import Callbacks LOGGER = logging.getLogger(__name__) LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html @@ -456,11 +454,8 @@ def parse_opt(known=False): return opt -def main(opt, callback_handler = None): - - # Define new hook handler if one is not passed in - if not callback_handler: callback_handler = callbacks.Callbacks() - +def main(opt, callback_handler=Callbacks()): + # Checks set_logging(RANK) if RANK in [-1, 0]: print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items())) From 3f693b58e8ef3ec14b3a9c9a1465b1f4fccbcaa6 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 30 Jul 2021 14:23:45 +0200 Subject: [PATCH 4/9] Update val.py --- val.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/val.py b/val.py index 86439b1380d..e231f60f6be 100644 --- a/val.py +++ b/val.py @@ -6,13 +6,12 @@ import argparse import json +import numpy as np import os import sys +import torch from pathlib import Path from threading import Thread - -import numpy as np -import torch from tqdm import tqdm FILE = Path(__file__).absolute() @@ -26,6 +25,7 @@ from utils.plots import plot_images, output_to_target, plot_study_txt from utils.torch_utils import select_device, time_sync from utils.loggers import Loggers +from utils.callbacks import Callbacks def save_one_txt(predn, save_conf, shape, file): @@ -98,6 +98,7 @@ def run(data, save_dir=Path(''), plots=True, loggers=Loggers(), + callbacks=Callbacks(), compute_loss=None, ): # Initialize/load model and set device @@ -214,6 +215,7 @@ def run(data, if save_json: save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary loggers.on_val_batch_end(pred, predn, path, names, img[si]) + callbacks.on_val_batch_end(pred, predn, path, names, img[si]) # Plot images if plots and batch_i < 3: @@ -251,6 +253,7 @@ def run(data, if plots: confusion_matrix.plot(save_dir=save_dir, names=list(names.values())) loggers.on_val_end() + callbacks.on_val_end() # Save JSON if save_json and len(jdict): From 59731bc9f8ddd361fceda871f6879a5026dae858 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 31 Jul 2021 20:45:01 +0200 Subject: [PATCH 5/9] Fix CamlCase add staticmethod --- utils/callbacks.py | 45 +++++++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/utils/callbacks.py b/utils/callbacks.py index fc482bc61af..78a4f02be31 100644 --- a/utils/callbacks.py +++ b/utils/callbacks.py @@ -32,7 +32,7 @@ class Callbacks: def __init__(self): return - def regsiterAction(self, hook, name, callback): + def register_action(self, hook, name, callback): """ Register a new action to a callback hook @@ -50,7 +50,7 @@ def 
regsiterAction(self, hook, name, callback): else: return False - def getRegisteredActions(self, hook=None): + def get_registered_actions(self, hook=None): """" Returns all the registered actions by callback hook @@ -62,9 +62,10 @@ def getRegisteredActions(self, hook=None): else: return self._callbacks - def fireCallbacks(self, register, *args): + @staticmethod + def fire_callbacks(register, *args): """ - Loop throughs the registered actions and fires all callbacks + Loop through the registered actions and fire all callbacks """ for logger in register: logger['callback'](*args) @@ -73,106 +74,106 @@ def on_pretrain_routine_start(self, *args): """ Fires all registered callbacks at the start of each pretraining routine """ - self.fireCallbacks(self._callbacks['on_pretrain_routine_start'], *args) + self.fire_callbacks(self._callbacks['on_pretrain_routine_start'], *args) def on_pretrain_routine_end(self, *args): """ Fires all registered callbacks at the end of each pretraining routine """ - self.fireCallbacks(self._callbacks['on_pretrain_routine_end'], *args) + self.fire_callbacks(self._callbacks['on_pretrain_routine_end'], *args) def on_train_start(self, *args): """ Fires all registered callbacks at the start of each training """ - self.fireCallbacks(self._callbacks['on_train_start'], *args) + self.fire_callbacks(self._callbacks['on_train_start'], *args) def on_train_end(self, *args): """ Fires all registered callbacks at the end of training """ - self.fireCallbacks(self._callbacks['on_train_end'], *args) + self.fire_callbacks(self._callbacks['on_train_end'], *args) def on_train_epoch_start(self, *args): """ Fires all registered callbacks at the start of each training epoch """ - self.fireCallbacks(self._callbacks['on_train_epoch_start'], *args) + self.fire_callbacks(self._callbacks['on_train_epoch_start'], *args) def on_train_epoch_end(self, *args): """ Fires all registered callbacks at the end of each training epoch """ - self.fireCallbacks(self._callbacks['on_train_epoch_end'], *args) + self.fire_callbacks(self._callbacks['on_train_epoch_end'], *args) def on_train_batch_start(self, *args): """ Fires all registered callbacks at the start of each training batch """ - self.fireCallbacks(self._callbacks['on_train_batch_start'], *args) + self.fire_callbacks(self._callbacks['on_train_batch_start'], *args) def on_train_batch_end(self, *args): """ Fires all registered callbacks at the end of each training batch """ - self.fireCallbacks(self._callbacks['on_train_batch_end'], *args) + self.fire_callbacks(self._callbacks['on_train_batch_end'], *args) def on_val_start(self, *args): """ Fires all registered callbacks at the start of the validation """ - self.fireCallbacks(self._callbacks['on_val_start'], *args) + self.fire_callbacks(self._callbacks['on_val_start'], *args) def on_val_end(self, *args): """ Fires all registered callbacks at the end of the validation """ - self.fireCallbacks(self._callbacks['on_val_end'], *args) + self.fire_callbacks(self._callbacks['on_val_end'], *args) def on_val_epoch_start(self, *args): """ Fires all registered callbacks at the start of each validation epoch """ - self.fireCallbacks(self._callbacks['on_val_epoch_start'], *args) + self.fire_callbacks(self._callbacks['on_val_epoch_start'], *args) def on_val_epoch_end(self, *args): """ Fires all registered callbacks at the end of each validation epoch """ - self.fireCallbacks(self._callbacks['on_val_epoch_end'], *args) + self.fire_callbacks(self._callbacks['on_val_epoch_end'], *args) def on_val_batch_start(self, *args): 
""" Fires all registered callbacks at the start of each validation batch """ - self.fireCallbacks(self._callbacks['on_val_batch_start'], *args) + self.fire_callbacks(self._callbacks['on_val_batch_start'], *args) def on_val_batch_end(self, *args): """ Fires all registered callbacks at the end of each validation batch """ - self.fireCallbacks(self._callbacks['on_val_batch_end'], *args) + self.fire_callbacks(self._callbacks['on_val_batch_end'], *args) def on_model_save(self, *args): """ Fires all registered callbacks after each model save """ - self.fireCallbacks(self._callbacks['on_model_save'], *args) + self.fire_callbacks(self._callbacks['on_model_save'], *args) def optimizer_step(self, *args): """ Fires all registered callbacks on each optimizer step """ - self.fireCallbacks(self._callbacks['optimizer_step'], *args) + self.fire_callbacks(self._callbacks['optimizer_step'], *args) def on_before_zero_grad(self, *args): """ Fires all registered callbacks before zero grad """ - self.fireCallbacks(self._callbacks['on_before_zero_grad'], *args) + self.fire_callbacks(self._callbacks['on_before_zero_grad'], *args) def teardown(self, *args): """ Fires all registered callbacks before teardown """ - self.fireCallbacks(self._callbacks['teardown'], *args) + self.fire_callbacks(self._callbacks['teardown'], *args) From 60349f2455a471738b9b0e06aa9d8e04870f1c9c Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 31 Jul 2021 22:38:59 +0200 Subject: [PATCH 6/9] Refactor logger into callbacks --- train.py | 30 +++++---- utils/callbacks.py | 129 +++++++++++++++++--------------------- utils/general.py | 5 ++ utils/loggers/__init__.py | 30 ++++----- utils/plots.py | 6 +- val.py | 11 ++-- 6 files changed, 96 insertions(+), 115 deletions(-) diff --git a/train.py b/train.py index eb84494b2c8..f1ce9df7839 100644 --- a/train.py +++ b/train.py @@ -6,18 +6,19 @@ import argparse import logging -import math -import numpy as np import os import random import sys import time +from copy import deepcopy +from pathlib import Path + +import math +import numpy as np import torch import torch.distributed as dist import torch.nn as nn import yaml -from copy import deepcopy -from pathlib import Path from torch.cuda import amp from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Adam, SGD, lr_scheduler @@ -33,7 +34,7 @@ from utils.datasets import create_dataloader from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \ strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \ - check_requirements, print_mutation, set_logging, one_cycle, colorstr + check_requirements, print_mutation, set_logging, one_cycle, colorstr, methods from utils.downloads import attempt_download from utils.loss import ComputeLoss from utils.plots import plot_labels, plot_evolution @@ -78,12 +79,16 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary # Loggers if RANK in [-1, 0]: - loggers = Loggers(save_dir, weights, opt, hyp, LOGGER).start() # loggers dict + loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance if loggers.wandb: data_dict = loggers.wandb.data_dict if resume: weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp + # Register actions + for k in methods(loggers): + callbacks.register_action(k, callback=getattr(loggers, k)) + # Config plots = not evolve # create plots cuda = device.type != 'cpu' @@ -216,7 +221,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary # cf = torch.bincount(c.long(), 
minlength=nc) + 1. # frequency # model._initialize_biases(cf.to(device)) if plots: - plot_labels(labels, names, save_dir, loggers) + plot_labels(labels, names, save_dir) # Anchors if not opt.noautoanchor: @@ -330,9 +335,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB) pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % ( f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])) - loggers.on_train_batch_end(ni, model, imgs, targets, paths, plots) callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots) - # end batch ------------------------------------------------------------------------------------------------ # Scheduler @@ -341,9 +344,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if RANK in [-1, 0]: # mAP - loggers.on_train_epoch_end(epoch) callbacks.on_train_epoch_end(epoch) - ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights']) final_epoch = epoch + 1 == epochs if not noval or final_epoch: # Calculate mAP @@ -357,15 +358,14 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary save_json=is_coco and final_epoch, verbose=nc < 50 and final_epoch, plots=plots and final_epoch, - loggers=loggers, + callbacks=callbacks, compute_loss=compute_loss) # Update best mAP fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95] if fi > best_fitness: best_fitness = fi - loggers.on_train_val_end(mloss, results, lr, epoch, best_fitness, fi) - callbacks.on_val_end(mloss, results, lr, epoch, best_fitness, fi) + callbacks.on_fit_epoch_end(mloss, results, lr, epoch, best_fitness, fi) # Save model if (not nosave) or (final_epoch and not evolve): # if save @@ -382,7 +382,6 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if best_fitness == fi: torch.save(ckpt, best) del ckpt - loggers.on_model_save(last, epoch, final_epoch, best_fitness, fi) callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi) # end epoch ---------------------------------------------------------------------------------------------------- @@ -406,7 +405,6 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary for f in last, best: if f.exists(): strip_optimizer(f) # strip optimizers - loggers.on_train_end(last, best, plots) callbacks.on_train_end(last, best, plots) torch.cuda.empty_cache() diff --git a/utils/callbacks.py b/utils/callbacks.py index 78a4f02be31..8eb3ae26bbf 100644 --- a/utils/callbacks.py +++ b/utils/callbacks.py @@ -10,45 +10,40 @@ class Callbacks: 'on_pretrain_routine_end': [], 'on_train_start': [], - 'on_train_end': [], 'on_train_epoch_start': [], - 'on_train_epoch_end': [], 'on_train_batch_start': [], + 'optimizer_step': [], + 'on_before_zero_grad': [], 'on_train_batch_end': [], + 'on_train_epoch_end': [], 'on_val_start': [], - 'on_val_end': [], - 'on_val_epoch_start': [], - 'on_val_epoch_end': [], 'on_val_batch_start': [], 'on_val_batch_end': [], + 'on_val_end': [], + 'on_fit_epoch_end': [], # fit = train + val 'on_model_save': [], - 'optimizer_step': [], - 'on_before_zero_grad': [], + 'on_train_end': [], + 'teardown': [], } def __init__(self): return - def register_action(self, hook, name, callback): + def register_action(self, hook, name='', callback=None): """ Register a new action to a callback hook Args: - action The callback hook name to register the action to + hook The callback hook name to register the action to name The name of the action callback The callback to fire - - 
Returns: - (Bool) The success state """ - if hook in self._callbacks: - self._callbacks[hook].append({'name': name, 'callback': callback}) - return True - else: - return False + assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}" + assert callable(callback), f"callback '{callback}' is not callable" + self._callbacks[hook].append({'name': name, 'callback': callback}) def get_registered_actions(self, hook=None): """" @@ -63,117 +58,111 @@ def get_registered_actions(self, hook=None): return self._callbacks @staticmethod - def fire_callbacks(register, *args): + def run_callbacks(register, *args, **kwargs): """ Loop through the registered actions and fire all callbacks """ for logger in register: - logger['callback'](*args) + logger['callback'](*args, **kwargs) - def on_pretrain_routine_start(self, *args): + def on_pretrain_routine_start(self, *args, **kwargs): """ Fires all registered callbacks at the start of each pretraining routine """ - self.fire_callbacks(self._callbacks['on_pretrain_routine_start'], *args) + self.run_callbacks(self._callbacks['on_pretrain_routine_start'], *args, **kwargs) - def on_pretrain_routine_end(self, *args): + def on_pretrain_routine_end(self, *args, **kwargs): """ Fires all registered callbacks at the end of each pretraining routine """ - self.fire_callbacks(self._callbacks['on_pretrain_routine_end'], *args) + self.run_callbacks(self._callbacks['on_pretrain_routine_end'], *args, **kwargs) - def on_train_start(self, *args): + def on_train_start(self, *args, **kwargs): """ Fires all registered callbacks at the start of each training """ - self.fire_callbacks(self._callbacks['on_train_start'], *args) + self.run_callbacks(self._callbacks['on_train_start'], *args, **kwargs) - def on_train_end(self, *args): + def on_train_epoch_start(self, *args, **kwargs): """ - Fires all registered callbacks at the end of training + Fires all registered callbacks at the start of each training epoch """ - self.fire_callbacks(self._callbacks['on_train_end'], *args) + self.run_callbacks(self._callbacks['on_train_epoch_start'], *args, **kwargs) - def on_train_epoch_start(self, *args): + def on_train_batch_start(self, *args, **kwargs): """ - Fires all registered callbacks at the start of each training epoch + Fires all registered callbacks at the start of each training batch """ - self.fire_callbacks(self._callbacks['on_train_epoch_start'], *args) + self.run_callbacks(self._callbacks['on_train_batch_start'], *args, **kwargs) - def on_train_epoch_end(self, *args): + def optimizer_step(self, *args, **kwargs): """ - Fires all registered callbacks at the end of each training epoch + Fires all registered callbacks on each optimizer step """ - self.fire_callbacks(self._callbacks['on_train_epoch_end'], *args) + self.run_callbacks(self._callbacks['optimizer_step'], *args, **kwargs) - def on_train_batch_start(self, *args): + def on_before_zero_grad(self, *args, **kwargs): """ - Fires all registered callbacks at the start of each training batch + Fires all registered callbacks before zero grad """ - self.fire_callbacks(self._callbacks['on_train_batch_start'], *args) + self.run_callbacks(self._callbacks['on_before_zero_grad'], *args, **kwargs) - def on_train_batch_end(self, *args): + def on_train_batch_end(self, *args, **kwargs): """ Fires all registered callbacks at the end of each training batch """ - self.fire_callbacks(self._callbacks['on_train_batch_end'], *args) + self.run_callbacks(self._callbacks['on_train_batch_end'], *args, **kwargs) - def 
on_val_start(self, *args): + def on_train_epoch_end(self, *args, **kwargs): """ - Fires all registered callbacks at the start of the validation + Fires all registered callbacks at the end of each training epoch """ - self.fire_callbacks(self._callbacks['on_val_start'], *args) + self.run_callbacks(self._callbacks['on_train_epoch_end'], *args, **kwargs) - def on_val_end(self, *args): + def on_val_start(self, *args, **kwargs): """ - Fires all registered callbacks at the end of the validation + Fires all registered callbacks at the start of the validation """ - self.fire_callbacks(self._callbacks['on_val_end'], *args) + self.run_callbacks(self._callbacks['on_val_start'], *args, **kwargs) - def on_val_epoch_start(self, *args): + def on_val_batch_start(self, *args, **kwargs): """ - Fires all registered callbacks at the start of each validation epoch + Fires all registered callbacks at the start of each validation batch """ - self.fire_callbacks(self._callbacks['on_val_epoch_start'], *args) + self.run_callbacks(self._callbacks['on_val_batch_start'], *args, **kwargs) - def on_val_epoch_end(self, *args): + def on_val_batch_end(self, *args, **kwargs): """ - Fires all registered callbacks at the end of each validation epoch + Fires all registered callbacks at the end of each validation batch """ - self.fire_callbacks(self._callbacks['on_val_epoch_end'], *args) + self.run_callbacks(self._callbacks['on_val_batch_end'], *args, **kwargs) - def on_val_batch_start(self, *args): + def on_val_end(self, *args, **kwargs): """ - Fires all registered callbacks at the start of each validation batch + Fires all registered callbacks at the end of the validation """ - self.fire_callbacks(self._callbacks['on_val_batch_start'], *args) + self.run_callbacks(self._callbacks['on_val_end'], *args, **kwargs) - def on_val_batch_end(self, *args): + def on_fit_epoch_end(self, *args, **kwargs): """ - Fires all registered callbacks at the end of each validation batch + Fires all registered callbacks at the end of each fit (train+val) epoch """ - self.fire_callbacks(self._callbacks['on_val_batch_end'], *args) + self.run_callbacks(self._callbacks['on_fit_epoch_end'], *args, **kwargs) - def on_model_save(self, *args): + def on_model_save(self, *args, **kwargs): """ Fires all registered callbacks after each model save """ - self.fire_callbacks(self._callbacks['on_model_save'], *args) + self.run_callbacks(self._callbacks['on_model_save'], *args, **kwargs) - def optimizer_step(self, *args): - """ - Fires all registered callbacks on each optimizer step + def on_train_end(self, *args, **kwargs): """ - self.fire_callbacks(self._callbacks['optimizer_step'], *args) - - def on_before_zero_grad(self, *args): - """ - Fires all registered callbacks before zero grad + Fires all registered callbacks at the end of training """ - self.fire_callbacks(self._callbacks['on_before_zero_grad'], *args) + self.run_callbacks(self._callbacks['on_train_end'], *args, **kwargs) - def teardown(self, *args): + def teardown(self, *args, **kwargs): """ Fires all registered callbacks before teardown """ - self.fire_callbacks(self._callbacks['teardown'], *args) + self.run_callbacks(self._callbacks['teardown'], *args, **kwargs) diff --git a/utils/general.py b/utils/general.py index a414b391d24..ed028d2b376 100755 --- a/utils/general.py +++ b/utils/general.py @@ -67,6 +67,11 @@ def handler(*args, **kwargs): return handler +def methods(instance): + # Get class/instance methods + return [f for f in dir(instance) if callable(getattr(instance, f)) and not 
f.startswith("__")] + + def set_logging(rank=-1, verbose=True): logging.basicConfig( format="%(message)s", diff --git a/utils/loggers/__init__.py b/utils/loggers/__init__.py index 06d562d60f9..a6ea4d987ad 100644 --- a/utils/loggers/__init__.py +++ b/utils/loggers/__init__.py @@ -29,10 +29,12 @@ def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, logger=None, self.hyp = hyp self.logger = logger # for printing results to console self.include = include + self.keys = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss + 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', # metrics + 'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss + 'x/lr0', 'x/lr1', 'x/lr2'] # params for k in LOGGERS: setattr(self, k, None) # init empty logger dictionary - - def start(self): self.csv = True # always log to csv # Message @@ -57,7 +59,11 @@ def start(self): else: self.wandb = None - return self + def on_pretrain_routine_end(self): + # Callback runs on pre-train routine end + paths = self.save_dir.glob('*labels*.jpg') # training labels + if self.wandb: + self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]}) def on_train_batch_end(self, ni, model, imgs, targets, paths, plots): # Callback runs on train batch end @@ -89,19 +95,14 @@ def on_val_end(self): files = sorted(self.save_dir.glob('val*.jpg')) self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]}) - def on_train_val_end(self, mloss, results, lr, epoch, best_fitness, fi): - # Callback runs on val end during training + def on_fit_epoch_end(self, mloss, results, lr, epoch, best_fitness, fi): + # Callback runs at the end of each fit (train+val) epoch vals = list(mloss) + list(results) + lr - keys = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss - 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', # metrics - 'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss - 'x/lr0', 'x/lr1', 'x/lr2'] # params - x = {k: v for k, v in zip(keys, vals)} # dict - + x = {k: v for k, v in zip(self.keys, vals)} # dict if self.csv: file = self.save_dir / 'results.csv' n = len(x) + 1 # number of cols - s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # add header + s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + self.keys)).rstrip(',') + '\n') # add header with open(file, 'a') as f: f.write(s + ('%20.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n') @@ -131,8 +132,3 @@ def on_train_end(self, last, best, plots): name='run_' + self.wandb.wandb_run.id + '_model', aliases=['latest', 'best', 'stripped']) self.wandb.finish_run() - - def log_images(self, paths): - # Log images - if self.wandb: - self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]}) diff --git a/utils/plots.py b/utils/plots.py index e13e316314d..252e128168e 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -281,7 +281,7 @@ def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_tx plt.savefig(str(Path(path).name) + '.png', dpi=300) -def plot_labels(labels, names=(), save_dir=Path(''), loggers=None): +def plot_labels(labels, names=(), save_dir=Path('')): # plot dataset labels print('Plotting labels... 
') c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes @@ -324,10 +324,6 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None): matplotlib.use('Agg') plt.close() - # loggers - if loggers: - loggers.log_images(save_dir.glob('*labels*.jpg')) - def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution() # Plot hyperparameter evolution results in evolve.txt diff --git a/val.py b/val.py index e231f60f6be..4ab16bd6321 100644 --- a/val.py +++ b/val.py @@ -6,12 +6,13 @@ import argparse import json -import numpy as np import os import sys -import torch from pathlib import Path from threading import Thread + +import numpy as np +import torch from tqdm import tqdm FILE = Path(__file__).absolute() @@ -24,7 +25,6 @@ from utils.metrics import ap_per_class, ConfusionMatrix from utils.plots import plot_images, output_to_target, plot_study_txt from utils.torch_utils import select_device, time_sync -from utils.loggers import Loggers from utils.callbacks import Callbacks @@ -97,7 +97,6 @@ def run(data, dataloader=None, save_dir=Path(''), plots=True, - loggers=Loggers(), callbacks=Callbacks(), compute_loss=None, ): @@ -214,7 +213,6 @@ def run(data, save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt')) if save_json: save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary - loggers.on_val_batch_end(pred, predn, path, names, img[si]) callbacks.on_val_batch_end(pred, predn, path, names, img[si]) # Plot images @@ -252,7 +250,6 @@ def run(data, # Plots if plots: confusion_matrix.plot(save_dir=save_dir, names=list(names.values())) - loggers.on_val_end() callbacks.on_val_end() # Save JSON @@ -297,7 +294,7 @@ def parse_opt(): parser.add_argument('--data', type=str, default='data/coco128.yaml', help='dataset.yaml path') parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)') parser.add_argument('--batch-size', type=int, default=32, help='batch size') - parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)') + parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=256, help='inference size (pixels)') parser.add_argument('--conf-thres', type=float, default=0.001, help='confidence threshold') parser.add_argument('--iou-thres', type=float, default=0.6, help='NMS IoU threshold') parser.add_argument('--task', default='val', help='train, val, test, speed or study') From 668294f98f9ab00d0d2d5181b57b03e0a8c462ee Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 31 Jul 2021 22:45:39 +0200 Subject: [PATCH 7/9] Cleanup --- train.py | 10 ++++++---- val.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/train.py b/train.py index f1ce9df7839..f6736e33814 100644 --- a/train.py +++ b/train.py @@ -53,7 +53,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary opt, device, - callbacks + callbacks=Callbacks() ): save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \ Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \ @@ -228,6 +228,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) model.half().float() # pre-reduce anchor precision + callbacks.on_pretrain_routine_end() + # DDP mode if cuda and RANK != -1: model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK) @@ -452,7 
+454,7 @@ def parse_opt(known=False): return opt -def main(opt, callback_handler=Callbacks()): +def main(opt): # Checks set_logging(RANK) if RANK in [-1, 0]: @@ -489,7 +491,7 @@ def main(opt, callback_handler=Callbacks()): # Train if not opt.evolve: - train(opt.hyp, opt, device, callback_handler) + train(opt.hyp, opt, device) if WORLD_SIZE > 1 and RANK == 0: _ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')] @@ -569,7 +571,7 @@ def main(opt, callback_handler=Callbacks()): hyp[k] = round(hyp[k], 5) # significant digits # Train mutation - results = train(hyp.copy(), opt, device, callback_handler) + results = train(hyp.copy(), opt, device) # Write mutation results print_mutation(hyp.copy(), results, yaml_file, opt.bucket) diff --git a/val.py b/val.py index 4ab16bd6321..fadfced1d71 100644 --- a/val.py +++ b/val.py @@ -294,7 +294,7 @@ def parse_opt(): parser.add_argument('--data', type=str, default='data/coco128.yaml', help='dataset.yaml path') parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)') parser.add_argument('--batch-size', type=int, default=32, help='batch size') - parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=256, help='inference size (pixels)') + parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)') parser.add_argument('--conf-thres', type=float, default=0.001, help='confidence threshold') parser.add_argument('--iou-thres', type=float, default=0.6, help='NMS IoU threshold') parser.add_argument('--task', default='val', help='train, val, test, speed or study') From a5f078684b7b3c7e1be957078eabe1cb04a42d47 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 31 Jul 2021 23:07:25 +0200 Subject: [PATCH 8/9] New callback on_val_image_end() --- train.py | 3 ++- utils/callbacks.py | 8 ++++++++ utils/loggers/__init__.py | 4 ++-- val.py | 4 ++-- 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/train.py b/train.py index f6736e33814..cf57a355191 100644 --- a/train.py +++ b/train.py @@ -346,7 +346,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if RANK in [-1, 0]: # mAP - callbacks.on_train_epoch_end(epoch) + callbacks.on_train_epoch_end(epoch=epoch) ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights']) final_epoch = epoch + 1 == epochs if not noval or final_epoch: # Calculate mAP @@ -408,6 +408,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if f.exists(): strip_optimizer(f) # strip optimizers callbacks.on_train_end(last, best, plots) + LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}") torch.cuda.empty_cache() return results diff --git a/utils/callbacks.py b/utils/callbacks.py index 8eb3ae26bbf..f23d57a6c04 100644 --- a/utils/callbacks.py +++ b/utils/callbacks.py @@ -19,6 +19,7 @@ class Callbacks: 'on_val_start': [], 'on_val_batch_start': [], + 'on_val_image_end': [], 'on_val_batch_end': [], 'on_val_end': [], @@ -63,6 +64,7 @@ def run_callbacks(register, *args, **kwargs): Loop through the registered actions and fire all callbacks """ for logger in register: + # print(f"Running callbacks.{logger['callback'].__name__}()") logger['callback'](*args, **kwargs) def on_pretrain_routine_start(self, *args, **kwargs): @@ -131,6 +133,12 @@ def on_val_batch_start(self, *args, **kwargs): """ self.run_callbacks(self._callbacks['on_val_batch_start'], *args, **kwargs) + def on_val_image_end(self, *args, **kwargs): + """ + Fires all registered callbacks at 
the end of each val image + """ + self.run_callbacks(self._callbacks['on_val_image_end'], *args, **kwargs) + def on_val_batch_end(self, *args, **kwargs): """ Fires all registered callbacks at the end of each validation batch diff --git a/utils/loggers/__init__.py b/utils/loggers/__init__.py index a6ea4d987ad..4b17ee44c18 100644 --- a/utils/loggers/__init__.py +++ b/utils/loggers/__init__.py @@ -84,8 +84,8 @@ def on_train_epoch_end(self, epoch): if self.wandb: self.wandb.current_epoch = epoch + 1 - def on_val_batch_end(self, pred, predn, path, names, im): - # Callback runs on train batch end + def on_val_image_end(self, pred, predn, path, names, im): + # Callback runs on val image end if self.wandb: self.wandb.val_one_image(pred, predn, path, names, im) diff --git a/val.py b/val.py index fadfced1d71..58e8170da86 100644 --- a/val.py +++ b/val.py @@ -213,7 +213,7 @@ def run(data, save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt')) if save_json: save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary - callbacks.on_val_batch_end(pred, predn, path, names, img[si]) + callbacks.on_val_image_end(pred, predn, path, names, img[si]) # Plot images if plots and batch_i < 3: @@ -282,7 +282,7 @@ def run(data, model.float() # for training if not training: s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else '' - print(f"Results saved to {save_dir}{s}") + print(f"Results saved to {colorstr('bold', save_dir)}{s}") maps = np.zeros(nc) + map for i, c in enumerate(ap_class): maps[c] = ap[i] From e6960412fbf86292a5b0e1404991660d338c0e85 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 1 Aug 2021 00:11:27 +0200 Subject: [PATCH 9/9] Add curves and results images to TensorBoard --- train.py | 2 +- utils/loggers/__init__.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index cf57a355191..d4a5495d3b3 100644 --- a/train.py +++ b/train.py @@ -407,7 +407,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary for f in last, best: if f.exists(): strip_optimizer(f) # strip optimizers - callbacks.on_train_end(last, best, plots) + callbacks.on_train_end(last, best, plots, epoch) LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}") torch.cuda.empty_cache() diff --git a/utils/loggers/__init__.py b/utils/loggers/__init__.py index 4b17ee44c18..5d4377d5415 100644 --- a/utils/loggers/__init__.py +++ b/utils/loggers/__init__.py @@ -108,7 +108,7 @@ def on_fit_epoch_end(self, mloss, results, lr, epoch, best_fitness, fi): if self.tb: for k, v in x.items(): - self.tb.add_scalar(k, v, epoch) # TensorBoard + self.tb.add_scalar(k, v, epoch) if self.wandb: self.wandb.log(x) @@ -120,12 +120,19 @@ def on_model_save(self, last, epoch, final_epoch, best_fitness, fi): if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1: self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi) - def on_train_end(self, last, best, plots): + def on_train_end(self, last, best, plots, epoch): # Callback runs on training end if plots: plot_results(dir=self.save_dir) # save results.png files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]] files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()] # filter + + if self.tb: + from PIL import Image + import numpy as np + for f in files: + self.tb.add_image(f.stem, np.asarray(Image.open(f)), epoch, 
dataformats='HWC') + if self.wandb: wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]}) wandb.log_artifact(str(best if best.exists() else last), type='model',
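
Taken together, the series leaves utils/callbacks.py with a small public surface: register_action(hook, name, callback) attaches an action to a named hook (asserting the hook exists and the callback is callable), get_registered_actions(hook) inspects what is attached, and one fire method per hook (on_train_batch_end, on_fit_epoch_end, on_model_save, and so on) is invoked by train.py and val.py at the matching point in the loop. Below is a minimal sketch of how a user-defined action could plug into this API, assuming the post-PATCH-6 signatures shown above; the hook choice, action name, and print body are invented for illustration:

    from utils.callbacks import Callbacks

    callbacks = Callbacks()

    def print_progress(ni, model, imgs, targets, paths, plots):
        # Receives the same positional args train.py passes to
        # callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots)
        print(f'finished batch {ni}: {targets.shape[0]} targets')

    callbacks.register_action('on_train_batch_end',
                              name='print_progress',
                              callback=print_progress)

    # train.train(hyp, opt, device, callbacks=callbacks) would now fire this
    # action once per training batch via Callbacks.run_callbacks()

This is the same pattern PATCH 6 uses to fold the built-in loggers in: methods(loggers) enumerates the public methods of the Loggers instance, each of which is named after the hook it serves, and each is registered with callbacks.register_action(k, callback=getattr(loggers, k)), so built-in logging runs through the same dispatch path as any user-supplied action.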