Skip to content

Commit

Permalink
Callbacks [wip] (#889)
Browse files Browse the repository at this point in the history
* Add callback system + associated test

* Add trainer and pl_module args to callback methods

* typing

* typo in docstring

* Switch to on_.*_start()

* fix on_test_start

* fix the mess after rebasing
  • Loading branch information
hadim authored Feb 26, 2020
1 parent 96b058c commit be24456
Show file tree
Hide file tree
Showing 14 changed files with 407 additions and 87 deletions.
4 changes: 2 additions & 2 deletions docs/source/callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ Callbacks
_save_model,
on_epoch_end,
on_train_end,
on_epoch_begin,
on_epoch_start,
check_monitor_top_k,
on_train_begin,
on_train_start,
2 changes: 1 addition & 1 deletion docs/source/loggers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ Loggers
_save_model,
on_epoch_end,
on_train_end,
on_epoch_begin,
on_epoch_start,
2 changes: 2 additions & 0 deletions pytorch_lightning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@

from .core import data_loader, LightningModule
from .trainer import Trainer
from .callbacks import Callback

__all__ = [
'Trainer',
'LightningModule',
'Callback',
'data_loader',
]
# __call__ = __all__
46 changes: 23 additions & 23 deletions pytorch_lightning/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,61 +8,61 @@
import abc


_NO_TRAINER_ERROR_MSG = ".set_trainer() should be called after the callback initialization"


class Callback(abc.ABC):
"""Abstract base class used to build new callbacks."""

def __init__(self):
self._trainer = None
def on_init_start(self, trainer, pl_module):
"""Called when the trainer initialization begins."""
assert pl_module is None

def on_init_end(self, trainer, pl_module):
"""Called when the trainer initialization ends."""
pass

@property
def trainer(self):
assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
return self._trainer
def on_fit_start(self, trainer, pl_module):
"""Called when the fit begins."""
pass

def set_trainer(self, trainer):
"""Make a link to the trainer, so different things like `trainer.current_epoch`,
`trainer.batch_idx`, `trainer.global_step` can be used."""
self._trainer = trainer
def on_fit_end(self, trainer, pl_module):
"""Called when the fit ends."""
pass

def on_epoch_begin(self):
def on_epoch_start(self, trainer, pl_module):
"""Called when the epoch begins."""
pass

def on_epoch_end(self):
def on_epoch_end(self, trainer, pl_module):
"""Called when the epoch ends."""
pass

def on_batch_begin(self):
def on_batch_start(self, trainer, pl_module):
"""Called when the training batch begins."""
pass

def on_batch_end(self):
def on_batch_end(self, trainer, pl_module):
"""Called when the training batch ends."""
pass

def on_train_begin(self):
def on_train_start(self, trainer, pl_module):
"""Called when the train begins."""
pass

def on_train_end(self):
def on_train_end(self, trainer, pl_module):
"""Called when the train ends."""
pass

def on_validation_begin(self):
def on_validation_start(self, trainer, pl_module):
"""Called when the validation loop begins."""
pass

def on_validation_end(self):
def on_validation_end(self, trainer, pl_module):
"""Called when the validation loop ends."""
pass

def on_test_begin(self):
def on_test_start(self, trainer, pl_module):
"""Called when the test begins."""
pass

def on_test_end(self):
def on_test_end(self, trainer, pl_module):
"""Called when the test ends."""
pass
14 changes: 7 additions & 7 deletions pytorch_lightning/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
self.monitor_op = mode_dict[mode]
self.min_delta *= 1 if self.monitor_op == np.greater else -1

self.on_train_begin()
self.on_train_start(None, None)

def check_metrics(self, logs):
monitor_val = logs.get(self.monitor)
Expand All @@ -82,14 +82,14 @@ def check_metrics(self, logs):

return True

def on_train_begin(self):
def on_train_start(self, trainer, pl_module):
# Allow instances to be re-used
self.wait = 0
self.stopped_epoch = 0
self.best = np.Inf if self.monitor_op == np.less else -np.Inf

def on_epoch_end(self):
logs = self.trainer.callback_metrics
def on_epoch_end(self, trainer, pl_module):
logs = trainer.callback_metrics
stop_training = False
if not self.check_metrics(logs):
return stop_training
Expand All @@ -101,13 +101,13 @@ def on_epoch_end(self):
else:
self.wait += 1
if self.wait >= self.patience:
self.stopped_epoch = self.trainer.current_epoch
self.stopped_epoch = trainer.current_epoch
stop_training = True
self.on_train_end()
self.on_train_end(trainer, pl_module)

return stop_training

def on_train_end(self):
def on_train_end(self, trainer, pl_module):
if self.stopped_epoch > 0 and self.verbose > 0:
warnings.warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
' but will start from "0" in v0.8.0.', DeprecationWarning)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ def __init__(self, scheduling: dict):
self.scheduling = scheduling
self.epochs = sorted(scheduling.keys())

def on_epoch_begin(self):
trainer = self.trainer
def on_epoch_start(self, trainer, pl_module):
# indexing epochs from 1 (until v0.6.x)
# In v0.8.0, ` + 1` should be removed.
epoch = trainer.current_epoch + 1
Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ def check_monitor_top_k(self, current):
return True
return self.monitor_op(current, self.best_k_models[self.kth_best_model])

def on_validation_end(self):
logs = self.trainer.callback_metrics
epoch = self.trainer.current_epoch
def on_validation_end(self, trainer, pl_module):
logs = trainer.callback_metrics
epoch = trainer.current_epoch
self.epochs_since_last_check += 1

if self.save_top_k == 0:
Expand Down
6 changes: 0 additions & 6 deletions pytorch_lightning/trainer/callback_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,6 @@ def configure_checkpoint_callback(self):
# if checkpoint callback used, then override the weights path
self.weights_save_path = self.checkpoint_callback.filepath

# link to the trainer
self.checkpoint_callback.set_trainer(self)

# if weights_save_path is still none here, set to current working dir
if self.weights_save_path is None:
self.weights_save_path = self.default_save_path
Expand Down Expand Up @@ -80,6 +77,3 @@ def configure_early_stopping(self, early_stop_callback):
else:
self.early_stop_callback = early_stop_callback
self.enable_early_stop = True

if self.early_stop_callback is not None:
self.early_stop_callback.set_trainer(self)
83 changes: 83 additions & 0 deletions pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from typing import Callable
from abc import ABC

from pytorch_lightning.callbacks import Callback


class TrainerCallbackHookMixin(ABC):

def __init__(self):
# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
self.callbacks: list[Callback] = []
self.get_model: Callable = ...

def on_init_start(self):
"""Called when the trainer initialization begins."""
for callback in self.callbacks:
callback.on_init_start(self, None)

def on_init_end(self):
"""Called when the trainer initialization ends."""
for callback in self.callbacks:
callback.on_init_end(self, self.get_model())

def on_fit_start(self):
"""Called when the fit begins."""
for callback in self.callbacks:
callback.on_fit_start(self, self.get_model())

def on_fit_end(self):
"""Called when the fit ends."""
for callback in self.callbacks:
callback.on_fit_end(self, self.get_model())

def on_epoch_start(self):
"""Called when the epoch begins."""
for callback in self.callbacks:
callback.on_epoch_start(self, self.get_model())

def on_epoch_end(self):
"""Called when the epoch ends."""
for callback in self.callbacks:
callback.on_epoch_end(self, self.get_model())

def on_train_start(self):
"""Called when the train begins."""
for callback in self.callbacks:
callback.on_train_start(self, self.get_model())

def on_train_end(self):
"""Called when the train ends."""
for callback in self.callbacks:
callback.on_train_end(self, self.get_model())

def on_batch_start(self):
"""Called when the training batch begins."""
for callback in self.callbacks:
callback.on_batch_start(self, self.get_model())

def on_batch_end(self):
"""Called when the training batch ends."""
for callback in self.callbacks:
callback.on_batch_end(self, self.get_model())

def on_validation_start(self):
"""Called when the validation loop begins."""
for callback in self.callbacks:
callback.on_validation_start(self, self.get_model())

def on_validation_end(self):
"""Called when the validation loop ends."""
for callback in self.callbacks:
callback.on_validation_end(self, self.get_model())

def on_test_start(self):
"""Called when the test begins."""
for callback in self.callbacks:
callback.on_test_start(self, self.get_model())

def on_test_end(self):
"""Called when the test ends."""
for callback in self.callbacks:
callback.on_test_end(self, self.get_model())
22 changes: 21 additions & 1 deletion pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@
"""

from typing import Callable

import sys
from abc import ABC, abstractmethod

Expand Down Expand Up @@ -171,6 +173,12 @@ def __init__(self):
self.reload_dataloaders_every_epoch = None
self.progress_bar_refresh_rate = None

# Callback system
self.on_validation_start: Callable = ...
self.on_validation_end: Callable = ...
self.on_test_start: Callable = ...
self.on_test_end: Callable = ...

@abstractmethod
def copy_trainer_model_properties(self, model):
# this is just empty shell for code from other class
Expand Down Expand Up @@ -302,6 +310,12 @@ def run_evaluation(self, test_mode: bool = False):
" Please define and try again"
raise MisconfigurationException(m)

# Validation/Test begin callbacks
if test_mode:
self.on_test_start()
else:
self.on_validation_start()

# hook
model = self.get_model()
model.on_pre_performance_check()
Expand Down Expand Up @@ -363,7 +377,13 @@ def run_evaluation(self, test_mode: bool = False):

# model checkpointing
if self.proc_rank == 0 and self.checkpoint_callback is not None and not test_mode:
self.checkpoint_callback.on_validation_end()
self.checkpoint_callback.on_validation_end(self, self.get_model())

# Validation/Test end callbacks
if test_mode:
self.on_test_end()
else:
self.on_validation_end()

def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False):
# make dataloader_idx arg in validation_step optional
Expand Down
Loading

0 comments on commit be24456

Please sign in to comment.