Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Callbacks continued #889

Merged
merged 7 commits into the base branch from the contributor's branch
Feb 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ Callbacks
_save_model,
on_epoch_end,
on_train_end,
on_epoch_begin,
on_epoch_start,
check_monitor_top_k,
on_train_begin,
on_train_start,
2 changes: 1 addition & 1 deletion docs/source/loggers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ Loggers
_save_model,
on_epoch_end,
on_train_end,
on_epoch_begin,
on_epoch_start,
2 changes: 2 additions & 0 deletions pytorch_lightning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@

from .core import data_loader, LightningModule
from .trainer import Trainer
from .callbacks import Callback

__all__ = [
'Trainer',
'LightningModule',
'Callback',
'data_loader',
]
# __call__ = __all__
46 changes: 23 additions & 23 deletions pytorch_lightning/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,61 +8,61 @@
import abc


_NO_TRAINER_ERROR_MSG = ".set_trainer() should be called after the callback initialization"


class Callback(abc.ABC):
hadim marked this conversation as resolved.
Show resolved Hide resolved
"""Abstract base class used to build new callbacks."""

def __init__(self):
self._trainer = None
def on_init_start(self, trainer, pl_module):
"""Called when the trainer initialization begins."""
assert pl_module is None

def on_init_end(self, trainer, pl_module):
"""Called when the trainer initialization ends."""
pass

@property
def trainer(self):
assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
return self._trainer
def on_fit_start(self, trainer, pl_module):
"""Called when the fit begins."""
pass

def set_trainer(self, trainer):
"""Make a link to the trainer, so different things like `trainer.current_epoch`,
`trainer.batch_idx`, `trainer.global_step` can be used."""
self._trainer = trainer
def on_fit_end(self, trainer, pl_module):
"""Called when the fit ends."""
pass

def on_epoch_begin(self):
def on_epoch_start(self, trainer, pl_module):
"""Called when the epoch begins."""
pass

def on_epoch_end(self):
def on_epoch_end(self, trainer, pl_module):
"""Called when the epoch ends."""
pass

def on_batch_begin(self):
def on_batch_start(self, trainer, pl_module):
"""Called when the training batch begins."""
pass

def on_batch_end(self):
def on_batch_end(self, trainer, pl_module):
"""Called when the training batch ends."""
pass

def on_train_begin(self):
def on_train_start(self, trainer, pl_module):
"""Called when the train begins."""
pass

def on_train_end(self):
def on_train_end(self, trainer, pl_module):
"""Called when the train ends."""
pass

def on_validation_begin(self):
def on_validation_start(self, trainer, pl_module):
"""Called when the validation loop begins."""
pass

def on_validation_end(self):
def on_validation_end(self, trainer, pl_module):
"""Called when the validation loop ends."""
pass

def on_test_begin(self):
def on_test_start(self, trainer, pl_module):
"""Called when the test begins."""
pass

def on_test_end(self):
def on_test_end(self, trainer, pl_module):
"""Called when the test ends."""
pass
14 changes: 7 additions & 7 deletions pytorch_lightning/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
self.monitor_op = mode_dict[mode]
self.min_delta *= 1 if self.monitor_op == np.greater else -1

self.on_train_begin()
self.on_train_start(None, None)

def check_metrics(self, logs):
monitor_val = logs.get(self.monitor)
Expand All @@ -82,14 +82,14 @@ def check_metrics(self, logs):

return True

def on_train_begin(self):
def on_train_start(self, trainer, pl_module):
# Allow instances to be re-used
self.wait = 0
self.stopped_epoch = 0
self.best = np.Inf if self.monitor_op == np.less else -np.Inf

def on_epoch_end(self):
logs = self.trainer.callback_metrics
def on_epoch_end(self, trainer, pl_module):
logs = trainer.callback_metrics
stop_training = False
if not self.check_metrics(logs):
return stop_training
Expand All @@ -101,13 +101,13 @@ def on_epoch_end(self):
else:
self.wait += 1
if self.wait >= self.patience:
self.stopped_epoch = self.trainer.current_epoch
self.stopped_epoch = trainer.current_epoch
stop_training = True
self.on_train_end()
self.on_train_end(trainer, pl_module)

return stop_training

def on_train_end(self):
def on_train_end(self, trainer, pl_module):
if self.stopped_epoch > 0 and self.verbose > 0:
warnings.warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
' but will start from "0" in v0.8.0.', DeprecationWarning)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ def __init__(self, scheduling: dict):
self.scheduling = scheduling
self.epochs = sorted(scheduling.keys())

def on_epoch_begin(self):
trainer = self.trainer
def on_epoch_start(self, trainer, pl_module):
# indexing epochs from 1 (until v0.6.x)
# In v0.8.0, ` + 1` should be removed.
epoch = trainer.current_epoch + 1
Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ def check_monitor_top_k(self, current):
return True
return self.monitor_op(current, self.best_k_models[self.kth_best_model])

def on_validation_end(self):
logs = self.trainer.callback_metrics
epoch = self.trainer.current_epoch
def on_validation_end(self, trainer, pl_module):
logs = trainer.callback_metrics
epoch = trainer.current_epoch
self.epochs_since_last_check += 1

if self.save_top_k == 0:
Expand Down
6 changes: 0 additions & 6 deletions pytorch_lightning/trainer/callback_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,6 @@ def configure_checkpoint_callback(self):
# if checkpoint callback used, then override the weights path
self.weights_save_path = self.checkpoint_callback.filepath

# link to the trainer
self.checkpoint_callback.set_trainer(self)

# if weights_save_path is still none here, set to current working dir
if self.weights_save_path is None:
self.weights_save_path = self.default_save_path
Expand Down Expand Up @@ -80,6 +77,3 @@ def configure_early_stopping(self, early_stop_callback):
else:
self.early_stop_callback = early_stop_callback
self.enable_early_stop = True

if self.early_stop_callback is not None:
self.early_stop_callback.set_trainer(self)
83 changes: 83 additions & 0 deletions pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from typing import Callable
from abc import ABC

from pytorch_lightning.callbacks import Callback


class TrainerCallbackHookMixin(ABC):
    """Mixin that relays trainer lifecycle events to every registered callback.

    Each ``on_*`` hook walks ``self.callbacks`` and invokes the matching
    callback method, passing the trainer (``self``) and the current model.
    """

    def __init__(self):
        # Placeholders only — this is a summary of the attributes the mixin
        # relies on; the concrete trainer subclass must assign real values.
        self.callbacks: list[Callback] = []
        self.get_model: Callable = ...

    def _notify(self, hook_name, with_model=True):
        # Dispatch `hook_name` on every registered callback. The model is
        # looked up via `self.get_model()` once per callback, mirroring the
        # original inlined loops exactly.
        for cb in self.callbacks:
            hook = getattr(cb, hook_name)
            hook(self, self.get_model() if with_model else None)

    def on_init_start(self):
        """Called when the trainer initialization begins."""
        # No model exists yet at this point, so pl_module is None.
        self._notify('on_init_start', with_model=False)

    def on_init_end(self):
        """Called when the trainer initialization ends."""
        self._notify('on_init_end')

    def on_fit_start(self):
        """Called when the fit begins."""
        self._notify('on_fit_start')

    def on_fit_end(self):
        """Called when the fit ends."""
        self._notify('on_fit_end')

    def on_epoch_start(self):
        """Called when the epoch begins."""
        self._notify('on_epoch_start')

    def on_epoch_end(self):
        """Called when the epoch ends."""
        self._notify('on_epoch_end')

    def on_train_start(self):
        """Called when the train begins."""
        self._notify('on_train_start')

    def on_train_end(self):
        """Called when the train ends."""
        self._notify('on_train_end')

    def on_batch_start(self):
        """Called when the training batch begins."""
        self._notify('on_batch_start')

    def on_batch_end(self):
        """Called when the training batch ends."""
        self._notify('on_batch_end')

    def on_validation_start(self):
        """Called when the validation loop begins."""
        self._notify('on_validation_start')

    def on_validation_end(self):
        """Called when the validation loop ends."""
        self._notify('on_validation_end')

    def on_test_start(self):
        """Called when the test begins."""
        self._notify('on_test_start')

    def on_test_end(self):
        """Called when the test ends."""
        self._notify('on_test_end')
22 changes: 21 additions & 1 deletion pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@
"""

from typing import Callable

import sys
from abc import ABC, abstractmethod

Expand Down Expand Up @@ -171,6 +173,12 @@ def __init__(self):
self.reload_dataloaders_every_epoch = None
self.progress_bar_refresh_rate = None

# Callback system
self.on_validation_start: Callable = ...
self.on_validation_end: Callable = ...
self.on_test_start: Callable = ...
self.on_test_end: Callable = ...

@abstractmethod
def copy_trainer_model_properties(self, model):
# this is just empty shell for code from other class
Expand Down Expand Up @@ -302,6 +310,12 @@ def run_evaluation(self, test_mode: bool = False):
" Please define and try again"
raise MisconfigurationException(m)

# Validation/Test begin callbacks
if test_mode:
self.on_test_start()
else:
self.on_validation_start()

# hook
model = self.get_model()
model.on_pre_performance_check()
Expand Down Expand Up @@ -363,7 +377,13 @@ def run_evaluation(self, test_mode: bool = False):

# model checkpointing
if self.proc_rank == 0 and self.checkpoint_callback is not None and not test_mode:
self.checkpoint_callback.on_validation_end()
self.checkpoint_callback.on_validation_end(self, self.get_model())

# Validation/Test end callbacks
if test_mode:
self.on_test_end()
else:
self.on_validation_end()

def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False):
# make dataloader_idx arg in validation_step optional
Expand Down
Loading