forked from ultralytics/yolov5
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* added callbacks * Update callbacks.py * Update train.py * Update val.py * Fix CamlCase add staticmethod * Refactor logger into callbacks * Cleanup * New callback on_val_image_end() * Add curves and results images to TensorBoard Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
- Loading branch information
1 parent
ba076f0
commit 954c060
Showing
6 changed files
with
230 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
#!/usr/bin/env python | ||
|
||
class Callbacks: | ||
"""" | ||
Handles all registered callbacks for YOLOv5 Hooks | ||
""" | ||
|
||
_callbacks = { | ||
'on_pretrain_routine_start': [], | ||
'on_pretrain_routine_end': [], | ||
|
||
'on_train_start': [], | ||
'on_train_epoch_start': [], | ||
'on_train_batch_start': [], | ||
'optimizer_step': [], | ||
'on_before_zero_grad': [], | ||
'on_train_batch_end': [], | ||
'on_train_epoch_end': [], | ||
|
||
'on_val_start': [], | ||
'on_val_batch_start': [], | ||
'on_val_image_end': [], | ||
'on_val_batch_end': [], | ||
'on_val_end': [], | ||
|
||
'on_fit_epoch_end': [], # fit = train + val | ||
'on_model_save': [], | ||
'on_train_end': [], | ||
|
||
'teardown': [], | ||
} | ||
|
||
def __init__(self): | ||
return | ||
|
||
def register_action(self, hook, name='', callback=None): | ||
""" | ||
Register a new action to a callback hook | ||
Args: | ||
hook The callback hook name to register the action to | ||
name The name of the action | ||
callback The callback to fire | ||
""" | ||
assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}" | ||
assert callable(callback), f"callback '{callback}' is not callable" | ||
self._callbacks[hook].append({'name': name, 'callback': callback}) | ||
|
||
def get_registered_actions(self, hook=None): | ||
"""" | ||
Returns all the registered actions by callback hook | ||
Args: | ||
hook The name of the hook to check, defaults to all | ||
""" | ||
if hook: | ||
return self._callbacks[hook] | ||
else: | ||
return self._callbacks | ||
|
||
@staticmethod | ||
def run_callbacks(register, *args, **kwargs): | ||
""" | ||
Loop through the registered actions and fire all callbacks | ||
""" | ||
for logger in register: | ||
# print(f"Running callbacks.{logger['callback'].__name__}()") | ||
logger['callback'](*args, **kwargs) | ||
|
||
def on_pretrain_routine_start(self, *args, **kwargs): | ||
""" | ||
Fires all registered callbacks at the start of each pretraining routine | ||
""" | ||
self.run_callbacks(self._callbacks['on_pretrain_routine_start'], *args, **kwargs) | ||
|
||
def on_pretrain_routine_end(self, *args, **kwargs): | ||
""" | ||
Fires all registered callbacks at the end of each pretraining routine | ||
""" | ||
self.run_callbacks(self._callbacks['on_pretrain_routine_end'], *args, **kwargs) | ||
|
||
def on_train_start(self, *args, **kwargs): | ||
""" | ||
Fires all registered callbacks at the start of each training | ||
""" | ||
self.run_callbacks(self._callbacks['on_train_start'], *args, **kwargs) | ||
|
||
def on_train_epoch_start(self, *args, **kwargs): | ||
""" | ||
Fires all registered callbacks at the start of each training epoch | ||
""" | ||
self.run_callbacks(self._callbacks['on_train_epoch_start'], *args, **kwargs) | ||
|
||
def on_train_batch_start(self, *args, **kwargs): | ||
""" | ||
Fires all registered callbacks at the start of each training batch | ||
""" | ||
self.run_callbacks(self._callbacks['on_train_batch_start'], *args, **kwargs) | ||
|
||
def optimizer_step(self, *args, **kwargs): | ||
""" | ||
Fires all registered callbacks on each optimizer step | ||
""" | ||
self.run_callbacks(self._callbacks['optimizer_step'], *args, **kwargs) | ||
|
||
def on_before_zero_grad(self, *args, **kwargs): | ||
""" | ||
Fires all registered callbacks before zero grad | ||
""" | ||
self.run_callbacks(self._callbacks['on_before_zero_grad'], *args, **kwargs) | ||
|
||
def on_train_batch_end(self, *args, **kwargs): | ||
""" | ||
Fires all registered callbacks at the end of each training batch | ||
""" | ||
self.run_callbacks(self._callbacks['on_train_batch_end'], *args, **kwargs) | ||
|
||
def on_train_epoch_end(self, *args, **kwargs): | ||
""" | ||
Fires all registered callbacks at the end of each training epoch | ||
""" | ||
self.run_callbacks(self._callbacks['on_train_epoch_end'], *args, **kwargs) | ||
|
||
def on_val_start(self, *args, **kwargs): | ||
""" | ||
Fires all registered callbacks at the start of the validation | ||
""" | ||
self.run_callbacks(self._callbacks['on_val_start'], *args, **kwargs) | ||
|
||
def on_val_batch_start(self, *args, **kwargs): | ||
""" | ||
Fires all registered callbacks at the start of each validation batch | ||
""" | ||
self.run_callbacks(self._callbacks['on_val_batch_start'], *args, **kwargs) | ||
|
||
def on_val_image_end(self, *args, **kwargs): | ||
""" | ||
Fires all registered callbacks at the end of each val image | ||
""" | ||
self.run_callbacks(self._callbacks['on_val_image_end'], *args, **kwargs) | ||
|
||
def on_val_batch_end(self, *args, **kwargs): | ||
""" | ||
Fires all registered callbacks at the end of each validation batch | ||
""" | ||
self.run_callbacks(self._callbacks['on_val_batch_end'], *args, **kwargs) | ||
|
||
def on_val_end(self, *args, **kwargs): | ||
""" | ||
Fires all registered callbacks at the end of the validation | ||
""" | ||
self.run_callbacks(self._callbacks['on_val_end'], *args, **kwargs) | ||
|
||
def on_fit_epoch_end(self, *args, **kwargs): | ||
""" | ||
Fires all registered callbacks at the end of each fit (train+val) epoch | ||
""" | ||
self.run_callbacks(self._callbacks['on_fit_epoch_end'], *args, **kwargs) | ||
|
||
def on_model_save(self, *args, **kwargs): | ||
""" | ||
Fires all registered callbacks after each model save | ||
""" | ||
self.run_callbacks(self._callbacks['on_model_save'], *args, **kwargs) | ||
|
||
def on_train_end(self, *args, **kwargs): | ||
""" | ||
Fires all registered callbacks at the end of training | ||
""" | ||
self.run_callbacks(self._callbacks['on_train_end'], *args, **kwargs) | ||
|
||
def teardown(self, *args, **kwargs): | ||
""" | ||
Fires all registered callbacks before teardown | ||
""" | ||
self.run_callbacks(self._callbacks['teardown'], *args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.