Skip to content

Commit

Permalink
move callback system to TrainerCallbackHookMixin
Browse files Browse the repository at this point in the history
  • Loading branch information
hadim committed Feb 19, 2020
1 parent b9a04f5 commit bf6b963
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 53 deletions.
16 changes: 0 additions & 16 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,29 +166,17 @@ def on_validation_end(self):

def _do_check_save(self, filepath, current, epoch):
# remove kth
<<<<<<< HEAD
if len(self.best_k_models) == self.save_top_k:
=======
if len(self.best_k_models.keys()) == self.save_top_k:
>>>>>>> fix logic error
delpath = self.kth_best_model
self.best_k_models.pop(self.kth_best_model)
self._del_model(delpath)

self.best_k_models[filepath] = current
<<<<<<< HEAD
if len(self.best_k_models) == self.save_top_k:
# monitor dict has reached k elements
_op = max if self.mode == 'min' else min
self.kth_best_model = _op(self.best_k_models,
key=self.best_k_models.get)
=======
if len(self.best_k_models.keys()) == self.save_top_k:
# monitor dict has reached k elements
_op = max if self.mode == 'min' else min
self.kth_best_model = _op(self.best_k_models,
key=self.best_k_models.get)
>>>>>>> fix logic error
self.kth_value = self.best_k_models[self.kth_best_model]

_op = min if self.mode == 'min' else max
Expand All @@ -199,8 +187,4 @@ def _do_check_save(self, filepath, current, epoch):
f'\nEpoch {epoch:05d}: {self.monitor} reached'
f' {current:0.5f} (best {self.best:0.5f}), saving model to'
f' {filepath} as top {self.save_top_k}')
<<<<<<< HEAD
self._save_model(filepath)
=======
self._save_model(filepath)
>>>>>>> fix logic error
83 changes: 83 additions & 0 deletions pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import os
from abc import ABC

from pytorch_lightning.callbacks import Callback


class TrainerCallbackHookMixin(ABC):

def __init__(self):
# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
self.callbacks: list[Callback] = []

def on_init_begin(self):
"""Called when the trainer initialization begins."""
for callback in self.callbacks:
callback.set_trainer(self)
callback.on_init_begin()

def on_init_end(self):
"""Called when the trainer initialization ends."""
for callback in self.callbacks:
callback.on_init_end()

def on_fit_begin(self):
"""Called when the fit begins."""
for callback in self.callbacks:
callback.on_fit_begin()

def on_fit_end(self):
"""Called when the fit ends."""
for callback in self.callbacks:
callback.on_fit_end()

def on_epoch_begin(self):
"""Called when the epoch begins."""
for callback in self.callbacks:
callback.on_epoch_begin()

def on_epoch_end(self):
"""Called when the epoch ends."""
for callback in self.callbacks:
callback.on_epoch_end()

def on_train_begin(self):
"""Called when the train begins."""
for callback in self.callbacks:
callback.on_train_begin()

def on_train_end(self):
"""Called when the train ends."""
for callback in self.callbacks:
callback.on_train_end()

def on_batch_begin(self):
"""Called when the training batch begins."""
for callback in self.callbacks:
callback.on_batch_begin()

def on_batch_end(self):
"""Called when the training batch ends."""
for callback in self.callbacks:
callback.on_batch_end()

def on_validation_begin(self):
"""Called when the validation loop begins."""
for callback in self.callbacks:
callback.on_validation_begin()

def on_validation_end(self):
"""Called when the validation loop ends."""
for callback in self.callbacks:
callback.on_validation_end()

def on_test_begin(self):
"""Called when the test begins."""
for callback in self.callbacks:
callback.on_test_begin()

def on_test_end(self):
"""Called when the test ends."""
for callback in self.callbacks:
callback.on_test_end()
25 changes: 14 additions & 11 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,12 @@ def __init__(self):
self.get_test_dataloaders = None
self.get_val_dataloaders = None
self.use_tpu = None
self.callbacks = []

# Callback system
self.on_validation_begin = None
self.on_validation_end = None
self.on_test_begin = None
self.on_test_end = None

@abstractmethod
def copy_trainer_model_properties(self, model):
Expand Down Expand Up @@ -295,11 +300,10 @@ def run_evaluation(self, test=False):
raise MisconfigurationException(m)

# Validation/Test begin callbacks
for callback in self.callbacks:
if test:
callback.on_test_begin()
else:
callback.on_validation_begin()
if test:
self.on_test_begin()
else:
self.on_validation_begin()

# hook
model = self.get_model()
Expand Down Expand Up @@ -362,11 +366,10 @@ def run_evaluation(self, test=False):
self.checkpoint_callback.on_validation_end()

# Validation/Test end callbacks
for callback in self.callbacks:
if test:
callback.on_test_end()
else:
callback.on_validation_end()
if test:
self.on_test_end()
else:
self.on_validation_end()

def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False):
# make dataloader_idx arg in validation_step optional
Expand Down
18 changes: 7 additions & 11 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from pytorch_lightning.trainer.training_io import TrainerIOMixin
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.profiler import Profiler, PassThroughProfiler

Expand Down Expand Up @@ -57,6 +58,7 @@ class Trainer(TrainerIOMixin,
TrainerEvaluationLoopMixin,
TrainerTrainLoopMixin,
TrainerCallbackConfigMixin,
TrainerCallbackHookMixin
):

def __init__(
Expand Down Expand Up @@ -596,11 +598,9 @@ def on_train_end(self):
"""

# Init start callbacks
# Init callbacks
self.callbacks = callbacks
for callback in self.callbacks:
callback.set_trainer(self)
callback.on_init_begin()
self.on_init_begin()

# Transfer params
# Backward compatibility
Expand Down Expand Up @@ -784,9 +784,7 @@ def on_train_end(self):
use_amp = True
self.init_amp(use_amp)

# Init end callbacks
for callback in self.callbacks:
callback.on_init_end()
self.on_init_end()

@property
def slurm_job_id(self):
Expand Down Expand Up @@ -918,8 +916,7 @@ def fit(self, model, train_dataloader=None, val_dataloader=None, test_dataloader
_set_dataloader(model, test_dataloader, 'test_dataloader')

# Training begin callbacks
for callback in self.callbacks:
callback.on_fit_begin()
self.on_fit_begin()

# when using multi-node or DDP within a node start each module in a separate process
if self.use_ddp2:
Expand Down Expand Up @@ -961,8 +958,7 @@ def fit(self, model, train_dataloader=None, val_dataloader=None, test_dataloader
self.run_pretrain_routine(model)

# Training end callbacks
for callback in self.callbacks:
callback.on_fit_end()
self.on_fit_end()

# return 1 when finished
# used for testing or when we need to know that training succeeded
Expand Down
27 changes: 15 additions & 12 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,15 @@ def __init__(self):
self.batch_idx = None
self.precision = None
self.callbacks = []
self.max_steps = None

# Callback system
self.on_train_begin = None
self.on_train_end = None
self.on_batch_begin = None
self.on_batch_end = None
self.on_epoch_begin = None
self.on_epoch_end = None

@property
def max_nb_epochs(self):
Expand Down Expand Up @@ -308,8 +317,7 @@ def process_output(self, output, train):
def train(self):

# Train begin callbacks
for callback in self.callbacks:
callback.on_train_begin()
self.on_train_begin()

warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
' but will start from "0" in v0.8.0.', DeprecationWarning)
Expand Down Expand Up @@ -404,17 +412,15 @@ def train(self):
model.on_train_end()

# Train end callbacks
for callback in self.callbacks:
callback.on_train_end()
self.on_train_end()

if self.logger is not None:
self.logger.finalize("success")

def run_training_epoch(self):

# Epoch begin callbacks
for callback in self.callbacks:
callback.on_epoch_begin()
self.on_epoch_begin()

# before epoch hook
if self.is_function_implemented('on_epoch_start'):
Expand Down Expand Up @@ -502,8 +508,7 @@ def run_training_epoch(self):
model.on_epoch_end()

# Epoch begin callbacks
for callback in self.callbacks:
callback.on_epoch_end()
self.on_epoch_end()

def run_training_batch(self, batch, batch_idx):
# track grad norms
Expand All @@ -519,8 +524,7 @@ def run_training_batch(self, batch, batch_idx):
return 0, grad_norm_dic, {}

# Batch begin callbacks
for callback in self.callbacks:
callback.on_batch_begin()
self.on_batch_begin()

# hook
if self.is_function_implemented('on_batch_start'):
Expand Down Expand Up @@ -634,8 +638,7 @@ def optimizer_closure():
model.on_batch_end()

# Batch end callbacks
for callback in self.callbacks:
callback.on_batch_end()
self.on_batch_end()

# update progress bar
self.main_progress_bar.update(1)
Expand Down
3 changes: 0 additions & 3 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,6 @@ class CurrentTestModel(
trainer.test()


<<<<<<< HEAD
def test_train_dataloaders_passed_to_fit(tmpdir):
""" Verify that train dataloader can be passed to fit """
tutils.reset_seed()
Expand Down Expand Up @@ -697,7 +696,6 @@ def test_trainer_min_steps_and_epochs(tmpdir):
assert trainer.global_step >= math.floor(num_train_samples * 1.5) and \
trainer.current_epoch > 0, "Model did not train for at least min_steps"

=======
def test_callbacks():
"""Test callbacks mechanics."""
tutils.reset_seed()
Expand Down Expand Up @@ -808,7 +806,6 @@ def on_test_end(self):

assert test_callback.test_begin_called
assert test_callback.test_end_called
>>>>>>> add callbacks arguments to the trainer + associated tests

# if __name__ == '__main__':
# pytest.main([__file__])

0 comments on commit bf6b963

Please sign in to comment.