Commit 87c6835

add callback system + associated tests
hadim committed Feb 22, 2020
1 parent 2244f8c commit 87c6835
Showing 6 changed files with 297 additions and 0 deletions.
16 changes: 16 additions & 0 deletions pytorch_lightning/callbacks/base.py
@@ -27,6 +27,22 @@ def set_trainer(self, trainer):
        `trainer.batch_idx`, `trainer.global_step` can be used."""
        self._trainer = trainer

    def on_init_begin(self):
        """Called when the trainer initialization begins."""
        pass

    def on_init_end(self):
        """Called when the trainer initialization ends."""
        pass

    def on_fit_begin(self):
        """Called when the fit begins."""
        pass

    def on_fit_end(self):
        """Called when the fit ends."""
        pass

    def on_epoch_begin(self):
        """Called when the epoch begins."""
        pass
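These stubs (together with the epoch, train, batch, validation, and test hooks further down the file) default to no-ops, so a user callback overrides only the events it needs. A minimal sketch of such a subclass, using the hook names the mixin below dispatches to; the timing logic is illustrative, not part of this commit:

    import time

    from pytorch_lightning.callbacks import Callback


    class EpochTimer(Callback):
        """Report how long each epoch takes."""

        def on_epoch_begin(self):
            self._epoch_start = time.monotonic()

        def on_epoch_end(self):
            print(f"epoch finished in {time.monotonic() - self._epoch_start:.2f}s")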
83 changes: 83 additions & 0 deletions pytorch_lightning/trainer/callback_hook.py
@@ -0,0 +1,83 @@
from abc import ABC
from typing import List

from pytorch_lightning.callbacks import Callback


class TrainerCallbackHookMixin(ABC):

    def __init__(self):
        # this is just a summary on variables used in this abstract class,
        # the proper values/initialisation should be done in child class
        self.callbacks: List[Callback] = []

    def on_init_begin(self):
        """Called when the trainer initialization begins."""
        for callback in self.callbacks:
            callback.set_trainer(self)
            callback.on_init_begin()

    def on_init_end(self):
        """Called when the trainer initialization ends."""
        for callback in self.callbacks:
            callback.on_init_end()

    def on_fit_begin(self):
        """Called when the fit begins."""
        for callback in self.callbacks:
            callback.on_fit_begin()

    def on_fit_end(self):
        """Called when the fit ends."""
        for callback in self.callbacks:
            callback.on_fit_end()

    def on_epoch_begin(self):
        """Called when the epoch begins."""
        for callback in self.callbacks:
            callback.on_epoch_begin()

    def on_epoch_end(self):
        """Called when the epoch ends."""
        for callback in self.callbacks:
            callback.on_epoch_end()

    def on_train_begin(self):
        """Called when the train begins."""
        for callback in self.callbacks:
            callback.on_train_begin()

    def on_train_end(self):
        """Called when the train ends."""
        for callback in self.callbacks:
            callback.on_train_end()

    def on_batch_begin(self):
        """Called when the training batch begins."""
        for callback in self.callbacks:
            callback.on_batch_begin()

    def on_batch_end(self):
        """Called when the training batch ends."""
        for callback in self.callbacks:
            callback.on_batch_end()

    def on_validation_begin(self):
        """Called when the validation loop begins."""
        for callback in self.callbacks:
            callback.on_validation_begin()

    def on_validation_end(self):
        """Called when the validation loop ends."""
        for callback in self.callbacks:
            callback.on_validation_end()

    def on_test_begin(self):
        """Called when the test begins."""
        for callback in self.callbacks:
            callback.on_test_begin()

    def on_test_end(self):
        """Called when the test ends."""
        for callback in self.callbacks:
            callback.on_test_end()
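Each hook here is the same fan-out: iterate the registered callbacks and forward the event; only `on_init_begin` additionally binds the trainer to every callback via `set_trainer`. The fourteen near-identical methods could in principle be collapsed into one generic dispatcher; a sketch of that alternative (not what this commit does):

    def _call_callbacks(self, hook_name: str) -> None:
        # forward one lifecycle event, e.g. 'on_epoch_begin', to every callback in order
        for callback in self.callbacks:
            getattr(callback, hook_name)()

Keeping the explicit methods, as the commit does, preserves discoverability and lets each hook carry its own docstring.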
18 changes: 18 additions & 0 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -169,6 +169,12 @@ def __init__(self):
        self.get_val_dataloaders = None
        self.use_tpu = None

        # Callback system
        self.on_validation_begin = None
        self.on_validation_end = None
        self.on_test_begin = None
        self.on_test_end = None

    @abstractmethod
    def copy_trainer_model_properties(self, model):
        # this is just empty shell for code from other class
@@ -293,6 +299,12 @@ def run_evaluation(self, test=False):
            Please define and try again'''
            raise MisconfigurationException(m)

        # Validation/Test begin callbacks
        if test:
            self.on_test_begin()
        else:
            self.on_validation_begin()

        # hook
        model = self.get_model()
        model.on_pre_performance_check()
@@ -353,6 +365,12 @@ def run_evaluation(self, test=False):
        if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
            self.checkpoint_callback.on_validation_end()

        # Validation/Test end callbacks
        if test:
            self.on_test_end()
        else:
            self.on_validation_end()

    def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False):
        # make dataloader_idx arg in validation_step optional
        args = [batch, batch_idx]
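`run_evaluation` now brackets the loop with the matching pair, picking the test or validation hooks from its `test` flag. A recorder callback makes that ordering straightforward to assert; a sketch in the spirit of the tests this commit adds (the class name is hypothetical):

    from pytorch_lightning.callbacks import Callback


    class _HookRecorder(Callback):
        """Append the name of every evaluation hook as it fires."""

        def __init__(self):
            self.calls = []

        def on_validation_begin(self):
            self.calls.append('on_validation_begin')

        def on_validation_end(self):
            self.calls.append('on_validation_end')

        def on_test_begin(self):
            self.calls.append('on_test_begin')

        def on_test_end(self):
            self.calls.append('on_test_end')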
32 changes: 32 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -25,6 +25,7 @@
from pytorch_lightning.trainer.training_io import TrainerIOMixin
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.profiler import Profiler, PassThroughProfiler

@@ -57,13 +58,15 @@ class Trainer(TrainerIOMixin,
              TrainerEvaluationLoopMixin,
              TrainerTrainLoopMixin,
              TrainerCallbackConfigMixin,
              TrainerCallbackHookMixin,
              ):

    def __init__(
            self,
            logger=True,
            checkpoint_callback=True,
            early_stop_callback=None,
            callbacks=None,
            default_save_path=None,
            gradient_clip_val=0,
            gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
@@ -163,6 +166,22 @@ def __init__(
                trainer = Trainer(early_stop_callback=early_stop_callback)

        callbacks (list of :class:`.Callback`): Add a list of callbacks.
            Example::

                from pytorch_lightning.callbacks import Callback

                class PrintCallback(Callback):
                    def on_train_begin(self):
                        print("Training has started!")

                    def on_train_end(self):
                        print("Training is done.")

                # a list of callbacks
                callbacks = [PrintCallback()]
                trainer = Trainer(callbacks=callbacks)

        default_save_path (str): Default path for logs and weights when no logger/ckpt_callback passed
            Example::
@@ -579,6 +598,10 @@ def __init__(
        """

        # Init callbacks
        self.callbacks = callbacks or []
        self.on_init_begin()

        # Transfer params
        # Backward compatibility
        if nb_gpu_nodes is not None:
@@ -761,6 +784,8 @@ def __init__(
            use_amp = True
        self.init_amp(use_amp)

        self.on_init_end()

    @property
    def slurm_job_id(self):
        try:
@@ -890,6 +915,9 @@ def fit(self, model, train_dataloader=None, val_dataloader=None, test_dataloader
        _set_dataloader(model, val_dataloader, 'val_dataloader')
        _set_dataloader(model, test_dataloader, 'test_dataloader')

        # Fit begin callbacks
        self.on_fit_begin()

        # when using multi-node or DDP within a node start each module in a separate process
        if self.use_ddp2:
            task = int(os.environ['SLURM_LOCALID'])
@@ -929,6 +957,9 @@ def fit(self, model, train_dataloader=None, val_dataloader=None, test_dataloader

        self.run_pretrain_routine(model)

        # Fit end callbacks
        self.on_fit_end()

        # return 1 when finished
        # used for testing or when we need to know that training succeeded
        return 1
@@ -1082,6 +1113,7 @@ def test(self, model=None):
            trainer = Trainer()
            trainer.test(model)
        """

        self.testing = True
        if model is not None:
            self.fit(model)
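With this wiring, construction and fitting each bracket their own phase. A usage sketch, reusing the `PrintCallback` from the docstring example above (`MyLightningModule` stands in for any `LightningModule`):

    model = MyLightningModule()

    # on_init_begin fires first (and binds the trainer to each callback);
    # on_init_end fires once __init__ has finished configuring everything
    trainer = Trainer(callbacks=[PrintCallback()])

    # on_fit_begin fires before the DDP/DP dispatch;
    # on_fit_end fires after run_pretrain_routine returns
    trainer.fit(model)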
30 changes: 30 additions & 0 deletions pytorch_lightning/trainer/training_loop.py
@@ -230,6 +230,16 @@ def __init__(self):
        self.profiler = None
        self.batch_idx = None
        self.precision = None
        self.callbacks = []
        self.max_steps = None

        # Callback system
        self.on_train_begin = None
        self.on_train_end = None
        self.on_batch_begin = None
        self.on_batch_end = None
        self.on_epoch_begin = None
        self.on_epoch_end = None

    @property
    def max_nb_epochs(self):
@@ -305,6 +315,10 @@ def process_output(self, output, train):
        pass

    def train(self):

        # Train begin callbacks
        self.on_train_begin()

        warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
                      ' but will start from "0" in v0.8.0.', DeprecationWarning)
        model = self.get_model()
@@ -375,6 +389,7 @@ def train(self):
            if self.max_steps and self.max_steps == self.global_step:
                self.main_progress_bar.close()
                model.on_train_end()
                self.on_train_end()
                return

            # early stopping
@@ -390,17 +405,23 @@ def train(self):
                    self.main_progress_bar.close()
                    with self.profiler.profile('on_train_end'):
                        model.on_train_end()
                    self.on_train_end()
                    return

        self.main_progress_bar.close()

        with self.profiler.profile('on_train_end'):
            model.on_train_end()
        self.on_train_end()

        if self.logger is not None:
            self.logger.finalize("success")

    def run_training_epoch(self):

        # Epoch begin callbacks
        self.on_epoch_begin()

        # before epoch hook
        if self.is_function_implemented('on_epoch_start'):
            model = self.get_model()
@@ -486,6 +507,9 @@ def run_training_epoch(self):
        with self.profiler.profile('on_epoch_end'):
            model.on_epoch_end()

        # Epoch end callbacks
        self.on_epoch_end()

    def run_training_batch(self, batch, batch_idx):
        # track grad norms
        grad_norm_dic = {}
@@ -499,6 +523,9 @@ def run_training_batch(self, batch, batch_idx):
        if batch is None:
            return 0, grad_norm_dic, {}

        # Batch begin callbacks
        self.on_batch_begin()

        # hook
        if self.is_function_implemented('on_batch_start'):
            model_ref = self.get_model()
@@ -610,6 +637,9 @@ def optimizer_closure():
        with self.profiler.profile('on_batch_end'):
            model.on_batch_end()

        # Batch end callbacks
        self.on_batch_end()

        # update progress bar
        self.main_progress_bar.update(1)
        self.main_progress_bar.set_postfix(**self.training_tqdm_dict)
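Taken together, a one-epoch run should produce a fixed trainer-level hook order: init, fit, train, then the batch and epoch hooks per epoch, with validation hooks interleaved when a validation loop runs. A hedged sketch of the kind of check the commit's (unrendered) tests presumably make, extending the `_HookRecorder` above to every hook; `max_epochs` and `MyLightningModule` are assumptions:

    from pytorch_lightning import Trainer


    def test_trainer_callback_order(tmpdir):
        recorder = _HookRecorder()  # assumed extended to record all fourteen hooks
        trainer = Trainer(callbacks=[recorder], max_epochs=1, default_save_path=tmpdir)
        trainer.fit(MyLightningModule())

        # ignore batch/epoch/validation events and check the outer lifecycle order
        lifecycle = [c for c in recorder.calls if c in (
            'on_init_begin', 'on_init_end', 'on_fit_begin',
            'on_train_begin', 'on_train_end', 'on_fit_end')]
        assert lifecycle == ['on_init_begin', 'on_init_end', 'on_fit_begin',
                             'on_train_begin', 'on_train_end', 'on_fit_end']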
(The sixth changed file, presumably the test suite for the new callback system, did not render.)
