Add train.py and val.py callbacks #4220

Merged · 9 commits · Jul 31, 2021
29 changes: 19 additions & 10 deletions train.py
@@ -34,14 +34,15 @@
from utils.datasets import create_dataloader
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
check_requirements, print_mutation, set_logging, one_cycle, colorstr
check_requirements, print_mutation, set_logging, one_cycle, colorstr, methods
from utils.downloads import attempt_download
from utils.loss import ComputeLoss
from utils.plots import plot_labels, plot_evolution
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.metrics import fitness
from utils.loggers import Loggers
from utils.callbacks import Callbacks

LOGGER = logging.getLogger(__name__)
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
@@ -52,6 +53,7 @@
def train(hyp, # path/to/hyp.yaml or hyp dictionary
opt,
device,
callbacks=Callbacks()
):
save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
@@ -77,12 +79,16 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary

# Loggers
if RANK in [-1, 0]:
loggers = Loggers(save_dir, weights, opt, hyp, LOGGER).start() # loggers dict
loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance
if loggers.wandb:
data_dict = loggers.wandb.data_dict
if resume:
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp

# Register actions
for k in methods(loggers):
callbacks.register_action(k, callback=getattr(loggers, k))

# Config
plots = not evolve # create plots
cuda = device.type != 'cpu'
@@ -215,13 +221,15 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
# model._initialize_biases(cf.to(device))
if plots:
plot_labels(labels, names, save_dir, loggers)
plot_labels(labels, names, save_dir)

# Anchors
if not opt.noautoanchor:
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
model.half().float() # pre-reduce anchor precision

callbacks.on_pretrain_routine_end()

# DDP mode
if cuda and RANK != -1:
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
@@ -329,8 +337,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
loggers.on_train_batch_end(ni, model, imgs, targets, paths, plots)

callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots)
# end batch ------------------------------------------------------------------------------------------------

# Scheduler
@@ -339,7 +346,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary

if RANK in [-1, 0]:
# mAP
loggers.on_train_epoch_end(epoch)
callbacks.on_train_epoch_end(epoch=epoch)
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
final_epoch = epoch + 1 == epochs
if not noval or final_epoch: # Calculate mAP
@@ -353,14 +360,14 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
save_json=is_coco and final_epoch,
verbose=nc < 50 and final_epoch,
plots=plots and final_epoch,
loggers=loggers,
callbacks=callbacks,
compute_loss=compute_loss)

# Update best mAP
fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
if fi > best_fitness:
best_fitness = fi
loggers.on_train_val_end(mloss, results, lr, epoch, best_fitness, fi)
callbacks.on_fit_epoch_end(mloss, results, lr, epoch, best_fitness, fi)

# Save model
if (not nosave) or (final_epoch and not evolve): # if save
@@ -377,7 +384,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
if best_fitness == fi:
torch.save(ckpt, best)
del ckpt
loggers.on_model_save(last, epoch, final_epoch, best_fitness, fi)
callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi)

# end epoch ----------------------------------------------------------------------------------------------------
# end training -----------------------------------------------------------------------------------------------------
@@ -400,7 +407,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
for f in last, best:
if f.exists():
strip_optimizer(f) # strip optimizers
loggers.on_train_end(last, best, plots)
callbacks.on_train_end(last, best, plots, epoch)
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")

torch.cuda.empty_cache()
return results
@@ -448,6 +456,7 @@ def parse_opt(known=False):


def main(opt):
# Checks
set_logging(RANK)
if RANK in [-1, 0]:
print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
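The pattern above is the core of the PR: train.py no longer calls logger methods directly; instead, every public method on the Loggers instance is registered as an action on the hook of the same name, and the training loop fires hooks. A minimal sketch of that registration pattern, assuming the utils.callbacks and utils.general helpers added in this PR are importable (the ToyLogger class is hypothetical):

from utils.callbacks import Callbacks
from utils.general import methods

class ToyLogger:
    # Any public method named after a Callbacks hook gets auto-registered below.
    def on_train_epoch_end(self, epoch):
        print(f'epoch {epoch} finished')

callbacks = Callbacks()
loggers = ToyLogger()
for k in methods(loggers):  # ['on_train_epoch_end']
    callbacks.register_action(k, callback=getattr(loggers, k))

callbacks.on_train_epoch_end(epoch=0)  # prints: epoch 0 finished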
176 changes: 176 additions & 0 deletions utils/callbacks.py
@@ -0,0 +1,176 @@
#!/usr/bin/env python

class Callbacks:
""""
Handles all registered callbacks for YOLOv5 Hooks
"""

_callbacks = {
'on_pretrain_routine_start': [],
'on_pretrain_routine_end': [],

'on_train_start': [],
'on_train_epoch_start': [],
'on_train_batch_start': [],
'optimizer_step': [],
'on_before_zero_grad': [],
'on_train_batch_end': [],
'on_train_epoch_end': [],

'on_val_start': [],
'on_val_batch_start': [],
'on_val_image_end': [],
'on_val_batch_end': [],
'on_val_end': [],

'on_fit_epoch_end': [], # fit = train + val
'on_model_save': [],
'on_train_end': [],

'teardown': [],
}

def __init__(self):
return

def register_action(self, hook, name='', callback=None):
"""
Register a new action to a callback hook

Args:
hook: The callback hook name to register the action to
name: The name of the action
callback: The callback to fire
"""
assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
assert callable(callback), f"callback '{callback}' is not callable"
self._callbacks[hook].append({'name': name, 'callback': callback})

def get_registered_actions(self, hook=None):
""""
Returns all the registered actions by callback hook

Args:
hook: The name of the hook to check, defaults to all
"""
if hook:
return self._callbacks[hook]
else:
return self._callbacks

@staticmethod
def run_callbacks(register, *args, **kwargs):
"""
Loop through the registered actions and fire all callbacks
"""
for action in register:
# print(f"Running callbacks.{action['callback'].__name__}()")
action['callback'](*args, **kwargs)

def on_pretrain_routine_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each pretraining routine
"""
self.run_callbacks(self._callbacks['on_pretrain_routine_start'], *args, **kwargs)

def on_pretrain_routine_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each pretraining routine
"""
self.run_callbacks(self._callbacks['on_pretrain_routine_end'], *args, **kwargs)

def on_train_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each training
"""
self.run_callbacks(self._callbacks['on_train_start'], *args, **kwargs)

def on_train_epoch_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each training epoch
"""
self.run_callbacks(self._callbacks['on_train_epoch_start'], *args, **kwargs)

def on_train_batch_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each training batch
"""
self.run_callbacks(self._callbacks['on_train_batch_start'], *args, **kwargs)

def optimizer_step(self, *args, **kwargs):
"""
Fires all registered callbacks on each optimizer step
"""
self.run_callbacks(self._callbacks['optimizer_step'], *args, **kwargs)

def on_before_zero_grad(self, *args, **kwargs):
"""
Fires all registered callbacks before zero grad
"""
self.run_callbacks(self._callbacks['on_before_zero_grad'], *args, **kwargs)

def on_train_batch_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each training batch
"""
self.run_callbacks(self._callbacks['on_train_batch_end'], *args, **kwargs)

def on_train_epoch_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each training epoch
"""
self.run_callbacks(self._callbacks['on_train_epoch_end'], *args, **kwargs)

def on_val_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of the validation
"""
self.run_callbacks(self._callbacks['on_val_start'], *args, **kwargs)

def on_val_batch_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each validation batch
"""
self.run_callbacks(self._callbacks['on_val_batch_start'], *args, **kwargs)

def on_val_image_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each val image
"""
self.run_callbacks(self._callbacks['on_val_image_end'], *args, **kwargs)

def on_val_batch_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each validation batch
"""
self.run_callbacks(self._callbacks['on_val_batch_end'], *args, **kwargs)

def on_val_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of the validation
"""
self.run_callbacks(self._callbacks['on_val_end'], *args, **kwargs)

def on_fit_epoch_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each fit (train+val) epoch
"""
self.run_callbacks(self._callbacks['on_fit_epoch_end'], *args, **kwargs)

def on_model_save(self, *args, **kwargs):
"""
Fires all registered callbacks after each model save
"""
self.run_callbacks(self._callbacks['on_model_save'], *args, **kwargs)

def on_train_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of training
"""
self.run_callbacks(self._callbacks['on_train_end'], *args, **kwargs)

def teardown(self, *args, **kwargs):
"""
Fires all registered callbacks before teardown
"""
self.run_callbacks(self._callbacks['teardown'], *args, **kwargs)
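A quick usage sketch of the class above (the print_save action is hypothetical; its signature mirrors the on_model_save call site in train.py):

callbacks = Callbacks()

def print_save(last, epoch, final_epoch, best_fitness, fi):
    print(f'saved {last} at epoch {epoch}')

callbacks.register_action('on_model_save', name='print_save', callback=print_save)
callbacks.on_model_save('last.pt', 0, False, 0.0, 0.0)  # prints: saved last.pt at epoch 0
callbacks.get_registered_actions('on_model_save')  # [{'name': 'print_save', 'callback': <function print_save>}]

One design consequence worth noting: _callbacks is a class-level attribute, so actions registered through one Callbacks instance are visible to every other instance; moving the dict into __init__ would make registrations per-instance.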
5 changes: 5 additions & 0 deletions utils/general.py
@@ -67,6 +67,11 @@ def handler(*args, **kwargs):
return handler


def methods(instance):
# Get class/instance methods
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]


def set_logging(rank=-1, verbose=True):
logging.basicConfig(
format="%(message)s",
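For reference, a sketch of what methods() returns on a toy class (the Demo class is hypothetical): only double-underscore names are filtered out, so single-underscore helpers are still listed, and dir() sorts the result alphabetically:

class Demo:
    def on_train_start(self): pass
    def _helper(self): pass  # kept: only names starting with '__' are excluded
    def __len__(self): return 0  # dunder, excluded

print(methods(Demo()))  # ['_helper', 'on_train_start']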