From 226f03e2cb8ed68f47d80883051585e48bd03c88 Mon Sep 17 00:00:00 2001
From: Florian Vahl <7vahl@informatik.uni-hamburg.de>
Date: Thu, 23 Mar 2023 16:39:30 +0100
Subject: [PATCH 1/2] Tune hyperparameters for COCO

---
 config/yolov3.cfg         | 6 +++---
 pytorchyolo/train.py      | 2 ++
 pytorchyolo/utils/loss.py | 2 +-
 3 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/config/yolov3.cfg b/config/yolov3.cfg
index 233923af5f..3cb1e780b2 100644
--- a/config/yolov3.cfg
+++ b/config/yolov3.cfg
@@ -3,8 +3,8 @@
 #batch=1
 #subdivisions=1
 # Training
-batch=16
-subdivisions=1
+batch=64
+subdivisions=4
 width=416
 height=416
 channels=3
@@ -15,7 +15,7 @@ saturation = 1.5
 exposure = 1.5
 hue=.1

-learning_rate=0.0001
+learning_rate=0.001
 burn_in=1000
 max_batches = 500200
 policy=steps

diff --git a/pytorchyolo/train.py b/pytorchyolo/train.py
index 2a1ae02549..2f7baf63d0 100755
--- a/pytorchyolo/train.py
+++ b/pytorchyolo/train.py
@@ -164,6 +164,8 @@ def run():

             loss, loss_components = compute_loss(outputs, targets, model)

+            loss *= imgs.shape[0]
+
             loss.backward()

             ###############

diff --git a/pytorchyolo/utils/loss.py b/pytorchyolo/utils/loss.py
index 943f9ed9d4..ba9b2c54f4 100644
--- a/pytorchyolo/utils/loss.py
+++ b/pytorchyolo/utils/loss.py
@@ -168,7 +168,7 @@ def build_targets(p, targets, model):
         gi, gj = gij.T  # grid xy indices

         # Convert anchor indexes to int
-        a = t[:, 6].long()
+        a = t[:, 6].long().view(-1)
         # Add target tensors for this yolo layer to the output lists
         # Add to index list and limit index range to prevent out of bounds
         indices.append((b, a, gj.clamp_(0, gain[3].long() - 1), gi.clamp_(0, gain[2].long() - 1)))

From 9e3dceefff77d0939d5bb9b8d611701b33cbe77c Mon Sep 17 00:00:00 2001
From: Florian Vahl <7vahl@informatik.uni-hamburg.de>
Date: Wed, 5 Apr 2023 09:37:37 +0200
Subject: [PATCH 2/2] Tune hyperparameters and fix lr scheduler

---
 config/yolov3.cfg          |  3 +-
 pytorchyolo/train.py       | 77 +++++++++++++++++++++-----------------
 pytorchyolo/utils/utils.py | 16 ++++++--
 3 files changed, 58 insertions(+), 38 deletions(-)

diff --git a/config/yolov3.cfg b/config/yolov3.cfg
index 3cb1e780b2..5f207c4aaf 100644
--- a/config/yolov3.cfg
+++ b/config/yolov3.cfg
@@ -15,7 +15,8 @@ saturation = 1.5
 exposure = 1.5
 hue=.1

-learning_rate=0.001
+optimizer=sgd
+learning_rate=0.01
 burn_in=1000
 max_batches = 500200
 policy=steps

diff --git a/pytorchyolo/train.py b/pytorchyolo/train.py
index 2f7baf63d0..f101885b5a 100755
--- a/pytorchyolo/train.py
+++ b/pytorchyolo/train.py
@@ -4,15 +4,16 @@
 import os
 import argparse
-import tqdm

+import numpy as np
+import tqdm
 import torch
 from torch.utils.data import DataLoader
 import torch.optim as optim

 from pytorchyolo.models import load_model
 from pytorchyolo.utils.logger import Logger
-from pytorchyolo.utils.utils import to_cpu, load_classes, print_environment_info, provide_determinism, worker_seed_set
+from pytorchyolo.utils.utils import to_cpu, load_classes, print_environment_info, provide_determinism, worker_seed_set, cumprod
 from pytorchyolo.utils.datasets import ListDataset
 from pytorchyolo.utils.augmentations import AUGMENTATION_TRANSFORMS
 #from pytorchyolo.utils.transforms import DEFAULT_TRANSFORMS
@@ -141,7 +142,8 @@ def run():
             params,
             lr=model.hyperparams['learning_rate'],
             weight_decay=model.hyperparams['decay'],
-            momentum=model.hyperparams['momentum'])
+            momentum=model.hyperparams['momentum'],
+            nesterov=True)
     else:
         print("Unknown optimizer. Please choose between (adam, sgd).")

@@ -155,7 +157,8 @@ def run():
         model.train()  # Set model to training mode

         for batch_i, (_, imgs, targets) in enumerate(tqdm.tqdm(dataloader, desc=f"Training Epoch {epoch}")):
-            batches_done = len(dataloader) * epoch + batch_i
+            batches_done = len(dataloader) * (epoch - 1) + batch_i + 1
+            optimizer_steps_done = batches_done // model.hyperparams['subdivisions']

             imgs = imgs.to(device, non_blocking=True)
             targets = targets.to(device)
@@ -176,16 +179,22 @@ def run():
                 # Adapt learning rate
                 # Get learning rate defined in cfg
                 lr = model.hyperparams['learning_rate']
-                if batches_done < model.hyperparams['burn_in']:
-                    # Burn in
-                    lr *= (batches_done / model.hyperparams['burn_in'])
-                else:
-                    # Set and parse the learning rate to the steps defined in the cfg
-                    for threshold, value in model.hyperparams['lr_steps']:
-                        if batches_done > threshold:
-                            lr *= value
+
+                # Get the learning rate schedule hyperparameter
+                # Split it into lists defining the timesteps and the corresponding factors
+                lr_thresholds, lr_factors = zip(*model.hyperparams['lr_steps'])
+                # As the factors are cumulative, calculate the cumulative product to get the factors relative to the base learning rate
+                lr_factors = cumprod(lr_factors)
+
+                # Add burn-in to the lr schedule
+                lr_thresholds = [-1, model.hyperparams['burn_in']] + list(lr_thresholds)
+                lr_factors = [0, 1] + list(lr_factors)
+
+
+                # Multiply the learning rate by a factor based on linear interpolation of the learning rate schedule
+                lr *= np.interp(optimizer_steps_done, lr_thresholds, lr_factors)
                 # Log the learning rate
-                logger.scalar_summary("train/learning_rate", lr, batches_done)
+                logger.scalar_summary("train/learning_rate", lr, optimizer_steps_done)
                 # Set learning rate
                 for g in optimizer.param_groups:
                     g['lr'] = lr
@@ -195,27 +204,27 @@ def run():
                 # Reset gradients
                 optimizer.zero_grad()

-            # ############
-            # Log progress
-            # ############
-            if args.verbose:
-                print(AsciiTable(
-                    [
-                        ["Type", "Value"],
-                        ["IoU loss", float(loss_components[0])],
-                        ["Object loss", float(loss_components[1])],
-                        ["Class loss", float(loss_components[2])],
-                        ["Loss", float(loss_components[3])],
-                        ["Batch loss", to_cpu(loss).item()],
-                    ]).table)
-
-            # Tensorboard logging
-            tensorboard_log = [
-                ("train/iou_loss", float(loss_components[0])),
-                ("train/obj_loss", float(loss_components[1])),
-                ("train/class_loss", float(loss_components[2])),
-                ("train/loss", to_cpu(loss).item())]
-            logger.list_of_scalars_summary(tensorboard_log, batches_done)
+                # ############
+                # Log progress
+                # ############
+                if args.verbose:
+                    print(AsciiTable(
+                        [
+                            ["Type", "Value"],
+                            ["IoU loss", float(loss_components[0])],
+                            ["Object loss", float(loss_components[1])],
+                            ["Class loss", float(loss_components[2])],
+                            ["Loss", float(loss_components[3])],
+                            ["Batch loss", to_cpu(loss).item()],
+                        ]).table)
+
+                # Tensorboard logging
+                tensorboard_log = [
+                    ("train/iou_loss", float(loss_components[0])),
+                    ("train/obj_loss", float(loss_components[1])),
+                    ("train/class_loss", float(loss_components[2])),
+                    ("train/loss", to_cpu(loss).item())]
+                logger.list_of_scalars_summary(tensorboard_log, optimizer_steps_done)

             model.seen += imgs.size(0)

diff --git a/pytorchyolo/utils/utils.py b/pytorchyolo/utils/utils.py
index fdde29aeac..7933e46456 100644
--- a/pytorchyolo/utils/utils.py
+++ b/pytorchyolo/utils/utils.py
@@ -1,14 +1,16 @@
 from __future__ import division

-import time
+import random
 import platform
+import subprocess
+import time
+from functools import reduce
+
 import tqdm
 import torch
 import torch.nn as nn
 import torchvision
 import numpy as np
-import subprocess
-import random

 import imgaug as ia

@@ -396,3 +398,11 @@ def print_environment_info():
         print(f"Current Commit Hash: {subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'], stderr=subprocess.DEVNULL).decode('ascii').strip()}")
     except (subprocess.CalledProcessError, FileNotFoundError):
         print("No git or repo found")
+
+
+def cumprod(lst):
+    """
+    Returns a list where each element is the cumulative product of the input list up to that index.
+    Similar to NumPy cumprod."""
+    return reduce(lambda acc, x: acc + [acc[-1] * x], lst[1:], [lst[0]])
+
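
Notes on the series follow.

Patch 1 raises the effective batch size from 16 to 64 while keeping per-pass memory flat via subdivisions=4: the existing training loop steps the optimizer only once every `subdivisions` batches, so each step integrates 64 images seen across four forward passes of 16. The added `loss *= imgs.shape[0]` weights each mini-batch's gradient by the number of images it covers, so the accumulated gradient behaves like the summed loss of one large batch. A minimal sketch of that interaction, with a stand-in linear model and random tensors rather than the repo's actual loop:

    import torch

    batch, subdivisions = 64, 4            # values from the tuned cfg
    mini_batch = batch // subdivisions     # 16 images per forward pass

    model = torch.nn.Linear(8, 1)          # stand-in for the detector
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01,
                                momentum=0.9, nesterov=True)

    optimizer.zero_grad()
    for _ in range(subdivisions):          # accumulate gradients over 4 mini-batches
        imgs = torch.randn(mini_batch, 8)  # stand-in data
        loss = model(imgs).pow(2).mean()   # stand-in per-image mean loss
        (loss * imgs.shape[0]).backward()  # patch 1: weight each mini-batch's
                                           # gradient by the images it covers
    optimizer.step()                       # one step per effective batch of 64
    optimizer.zero_grad()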
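Patch 2 replaces the hand-rolled burn-in/step logic with a single piecewise-linear schedule evaluated by np.interp: factor 0 at step -1, factor 1 once burn-in completes, then the cumulative products of the configured scales at each step threshold. One behavioral consequence worth noting: np.interp interpolates between thresholds, so the learning rate now decays gradually from one step to the next instead of holding constant and dropping at each threshold, and it is indexed by optimizer steps rather than raw batches. A standalone sketch of the schedule; the steps/scales values (400000/450000 with 0.1/0.1) are assumed from the stock COCO yolov3.cfg and do not appear in this diff:

    import numpy as np

    base_lr = 0.01                             # learning_rate from the tuned cfg
    burn_in = 1000                             # burn_in from the cfg
    lr_steps = [(400000, 0.1), (450000, 0.1)]  # assumed stock COCO steps/scales

    thresholds, factors = zip(*lr_steps)
    factors = list(np.cumprod(factors))        # mirrors the patch's cumprod helper: [0.1, 0.01]

    # Prepend the burn-in ramp: factor 0 at step -1, factor 1 once burn-in ends
    thresholds = [-1, burn_in] + list(thresholds)
    factors = [0, 1] + factors

    for step in (0, 500, 1000, 200000, 400000, 500000):
        lr = base_lr * np.interp(step, thresholds, factors)
        print(f"step {step:>6}: lr = {lr:.6f}")

With these assumed values the learning rate ramps from near 0 to 0.01 over the first 1000 optimizer steps, passes through 0.001 at step 400000 and 0.0001 at step 450000, decaying linearly in between, and stays at the last factor beyond the final threshold.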
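The `cumprod` helper added to utils.py folds the per-step scales into factors relative to the base learning rate. A quick usage check (the helper is copied verbatim from the patch; plain Python floats, so expect floating-point noise in the first result). Note that it assumes a non-empty input list: `lst[0]` raises IndexError on `[]`.

    from functools import reduce

    def cumprod(lst):
        # copy of the helper added in patch 2
        return reduce(lambda acc, x: acc + [acc[-1] * x], lst[1:], [lst[0]])

    print(cumprod([0.1, 0.1]))  # [0.1, 0.010000000000000002]
    print(cumprod([2, 3, 4]))   # [2, 6, 24]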