Tune for COCO #826

Open · wants to merge 2 commits into base: master
7 changes: 4 additions & 3 deletions config/yolov3.cfg
@@ -3,8 +3,8 @@
 #batch=1
 #subdivisions=1
 # Training
-batch=16
-subdivisions=1
+batch=64
+subdivisions=4
 width=416
 height=416
 channels=3
@@ -15,7 +15,8 @@ saturation = 1.5
 exposure = 1.5
 hue=.1
 
-learning_rate=0.0001
+optimizer=sgd
+learning_rate=0.01
 burn_in=1000
 max_batches = 500200
 policy=steps
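For readers unfamiliar with the Darknet cfg convention these keys follow: `batch` is the number of images contributing to one optimizer step, and `subdivisions` splits that batch into smaller forward/backward passes whose gradients are accumulated. A minimal sketch of the arithmetic (variable names are illustrative, not from the cfg parser):

```python
# Darknet-style batch/subdivisions arithmetic (illustrative sketch).
batch = 64         # images per optimizer step (new cfg value)
subdivisions = 4   # forward/backward passes per optimizer step (new cfg value)

mini_batch = batch // subdivisions
print(mini_batch)  # 16 images per forward pass: same memory footprint as the
                   # old batch=16/subdivisions=1, but gradients now accumulate
                   # over 4 passes, for an effective batch of 64
```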
79 changes: 45 additions & 34 deletions pytorchyolo/train.py
@@ -4,15 +4,16 @@
 
 import os
 import argparse
-import tqdm
 
+import numpy as np
+import tqdm
 import torch
 from torch.utils.data import DataLoader
 import torch.optim as optim
 
 from pytorchyolo.models import load_model
 from pytorchyolo.utils.logger import Logger
-from pytorchyolo.utils.utils import to_cpu, load_classes, print_environment_info, provide_determinism, worker_seed_set
+from pytorchyolo.utils.utils import to_cpu, load_classes, print_environment_info, provide_determinism, worker_seed_set, cumprod
 from pytorchyolo.utils.datasets import ListDataset
 from pytorchyolo.utils.augmentations import AUGMENTATION_TRANSFORMS
 #from pytorchyolo.utils.transforms import DEFAULT_TRANSFORMS
@@ -141,7 +142,8 @@ def run():
             params,
             lr=model.hyperparams['learning_rate'],
             weight_decay=model.hyperparams['decay'],
-            momentum=model.hyperparams['momentum'])
+            momentum=model.hyperparams['momentum'],
+            nesterov=True)
     else:
         print("Unknown optimizer. Please choose between (adam, sgd).")
 
@@ -155,7 +157,8 @@
         model.train()  # Set model to training mode
 
         for batch_i, (_, imgs, targets) in enumerate(tqdm.tqdm(dataloader, desc=f"Training Epoch {epoch}")):
-            batches_done = len(dataloader) * epoch + batch_i
+            batches_done = len(dataloader) * (epoch - 1) + batch_i + 1
+            optimizer_steps_done = batches_done // model.hyperparams['subdivisions']
 
             imgs = imgs.to(device, non_blocking=True)
             targets = targets.to(device)
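The rewritten counter makes `batches_done` 1-based (which assumes the epoch loop itself counts from 1) and derives completed optimizer steps from it. A quick illustrative check with made-up sizes:

```python
# Illustrative check of the new counters (sizes are made up).
n_batches, subdivisions = 100, 4   # stand-ins for len(dataloader) and the cfg value

def batches_done(epoch, batch_i):  # epoch assumed 1-based
    return n_batches * (epoch - 1) + batch_i + 1

assert batches_done(1, 0) == 1                   # first mini-batch of training
assert batches_done(2, 0) == 101                 # counts continue across epochs
assert batches_done(1, 3) // subdivisions == 1   # first optimizer step completed
```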
@@ -164,6 +167,8 @@
 
             loss, loss_components = compute_loss(outputs, targets, model)
 
+            loss *= imgs.shape[0]
+
             loss.backward()
 
             ###############
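Scaling the mean loss by the mini-batch size presumably turns it into a per-image sum, so gradients accumulated across subdivisions weight every image equally even when a mini-batch is short. A toy illustration with made-up numbers:

```python
# Two accumulated mini-batches with different sizes and mean losses (made up).
mean_losses = [0.5, 0.7]
sizes = [16, 8]          # e.g. a short final mini-batch

unscaled = sum(mean_losses)                              # 1.2: both weighted equally
scaled = sum(l * n for l, n in zip(mean_losses, sizes))  # 13.6: the per-image sum
```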
@@ -174,16 +179,22 @@
                 # Adapt learning rate
                 # Get learning rate defined in cfg
                 lr = model.hyperparams['learning_rate']
-                if batches_done < model.hyperparams['burn_in']:
-                    # Burn in
-                    lr *= (batches_done / model.hyperparams['burn_in'])
-                else:
-                    # Set and parse the learning rate to the steps defined in the cfg
-                    for threshold, value in model.hyperparams['lr_steps']:
-                        if batches_done > threshold:
-                            lr *= value
+
+                # Get learning rate schedule hyperparameter
+                # Split it into lists defining the thresholds and corresponding factors
+                lr_thresholds, lr_factors = zip(*model.hyperparams['lr_steps'])
+                # As the factors are cumulative, calculate the cumulative product
+                # to get the factors relative to the base learning rate
+                lr_factors = cumprod(lr_factors)
+
+                # Add burn in to the lr schedule
+                lr_thresholds = [-1, model.hyperparams['burn_in']] + list(lr_thresholds)
+                lr_factors = [0, 1] + list(lr_factors)
+
+                # Multiply learning rate by a factor based on linear interpolation of the learning rate schedule
+                lr *= np.interp(optimizer_steps_done, lr_thresholds, lr_factors)
                 # Log the learning rate
-                logger.scalar_summary("train/learning_rate", lr, batches_done)
+                logger.scalar_summary("train/learning_rate", lr, optimizer_steps_done)
                 # Set learning rate
                 for g in optimizer.param_groups:
                     g['lr'] = lr
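A worked example of the new schedule, using `burn_in=1000` and `learning_rate=0.01` from the cfg hunk above plus the stock yolov3.cfg `steps=400000,450000` / `scales=.1,.1` (assumed here; the parser exposes them as `lr_steps`). Note that `np.interp` makes the warm-up linear and, unlike the old step schedule, also decays linearly *between* thresholds rather than holding a plateau:

```python
import numpy as np

base_lr = 0.01
lr_steps = [(400000, 0.1), (450000, 0.1)]   # (threshold, factor) pairs
burn_in = 1000

lr_thresholds, lr_factors = zip(*lr_steps)
lr_factors = [0.1, 0.01]                    # cumprod((0.1, 0.1))

lr_thresholds = [-1, burn_in] + list(lr_thresholds)  # [-1, 1000, 400000, 450000]
lr_factors = [0, 1] + lr_factors                     # [0, 1, 0.1, 0.01]

for step in (0, 500, 1000, 225000, 450000, 500200):
    print(step, base_lr * np.interp(step, lr_thresholds, lr_factors))
# step 0      -> ~1e-5   (linear warm-up from ~0)
# step 500    -> ~0.005  (halfway through burn in)
# step 1000   -> 0.01    (full base lr at the end of burn in)
# step 225000 -> ~0.0049 (linear decay toward factor 0.1, not a flat plateau)
# step 450000 -> 0.0001  (factor 0.01)
# step 500200 -> 0.0001  (np.interp clamps beyond the last threshold)
```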
@@ -193,27 +204,27 @@
                 # Reset gradients
                 optimizer.zero_grad()
 
-            # ############
-            # Log progress
-            # ############
-            if args.verbose:
-                print(AsciiTable(
-                    [
-                        ["Type", "Value"],
-                        ["IoU loss", float(loss_components[0])],
-                        ["Object loss", float(loss_components[1])],
-                        ["Class loss", float(loss_components[2])],
-                        ["Loss", float(loss_components[3])],
-                        ["Batch loss", to_cpu(loss).item()],
-                    ]).table)
-
-            # Tensorboard logging
-            tensorboard_log = [
-                ("train/iou_loss", float(loss_components[0])),
-                ("train/obj_loss", float(loss_components[1])),
-                ("train/class_loss", float(loss_components[2])),
-                ("train/loss", to_cpu(loss).item())]
-            logger.list_of_scalars_summary(tensorboard_log, batches_done)
+                # ############
+                # Log progress
+                # ############
+                if args.verbose:
+                    print(AsciiTable(
+                        [
+                            ["Type", "Value"],
+                            ["IoU loss", float(loss_components[0])],
+                            ["Object loss", float(loss_components[1])],
+                            ["Class loss", float(loss_components[2])],
+                            ["Loss", float(loss_components[3])],
+                            ["Batch loss", to_cpu(loss).item()],
+                        ]).table)
+
+                # Tensorboard logging
+                tensorboard_log = [
+                    ("train/iou_loss", float(loss_components[0])),
+                    ("train/obj_loss", float(loss_components[1])),
+                    ("train/class_loss", float(loss_components[2])),
+                    ("train/loss", to_cpu(loss).item())]
+                logger.list_of_scalars_summary(tensorboard_log, optimizer_steps_done)
 
             model.seen += imgs.size(0)
 
2 changes: 1 addition & 1 deletion pytorchyolo/utils/loss.py
@@ -168,7 +168,7 @@ def build_targets(p, targets, model):
         gi, gj = gij.T  # grid xy indices
 
         # Convert anchor indexes to int
-        a = t[:, 6].long()
+        a = t[:, 6].long().view(-1)
         # Add target tensors for this yolo layer to the output lists
         # Add to index list and limit index range to prevent out of bounds
         indices.append((b, a, gj.clamp_(0, gain[3].long() - 1), gi.clamp_(0, gain[2].long() - 1)))
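The `.view(-1)` presumably guards against the anchor-index tensor arriving with an extra dimension, which would break the tuple indexing in `indices.append(...)`. A tiny demonstration with a made-up tensor:

```python
import torch

a = torch.tensor([[0], [1], [2]])  # column-shaped indices, shape (3, 1)
print(a.view(-1))                  # tensor([0, 1, 2]), flattened to shape (3,)
```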
16 changes: 13 additions & 3 deletions pytorchyolo/utils/utils.py
@@ -1,14 +1,16 @@
 from __future__ import division
 
-import time
+import random
 import platform
+import subprocess
+import time
+from functools import reduce
+
 import tqdm
 import torch
 import torch.nn as nn
 import torchvision
 import numpy as np
-import subprocess
-import random
 import imgaug as ia
 
 
@@ -396,3 +398,11 @@ def print_environment_info():
         print(f"Current Commit Hash: {subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'], stderr=subprocess.DEVNULL).decode('ascii').strip()}")
     except (subprocess.CalledProcessError, FileNotFoundError):
         print("No git or repo found")
+
+
+def cumprod(lst):
+    """
+    Returns a list where each element is the cumulative product of the
+    input list up to that index. Similar to numpy.cumprod.
+    """
+    return reduce(lambda acc, x: acc + [acc[-1] * x], lst[1:], [lst[0]])
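A usage sketch of the helper (outputs verified by hand). One caveat: seeding `reduce` with `lst[0]` means an empty list raises `IndexError`, which matters only if a cfg defines no `steps`:

```python
print(cumprod([2, 3, 4]))    # [2, 6, 24]
print(cumprod([0.1, 0.1]))   # [0.1, 0.010000000000000002] (float rounding)
# cumprod([]) raises IndexError, since reduce is seeded with lst[0].
```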