diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index 7150138d66acc6..413189eb59b92f 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -27,6 +27,22 @@ def set_trainer(self, trainer):
         `trainer.batch_idx`, `trainer.global_step` can be used."""
         self._trainer = trainer
 
+    def on_init_begin(self):
+        """Called when the trainer initialization begins."""
+        pass
+
+    def on_init_end(self):
+        """Called when the trainer initialization ends."""
+        pass
+
+    def on_fit_begin(self):
+        """Called when the fit begins."""
+        pass
+
+    def on_fit_end(self):
+        """Called when the fit ends."""
+        pass
+
     def on_epoch_begin(self):
         """Called when the epoch begins."""
         pass
diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
new file mode 100644
index 00000000000000..cc84e4cf5bc2e3
--- /dev/null
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -0,0 +1,83 @@
+from typing import List
+from abc import ABC
+
+from pytorch_lightning.callbacks import Callback
+
+
+class TrainerCallbackHookMixin(ABC):
+
+    def __init__(self):
+        # this is just a summary of the variables used in this abstract class,
+        # the proper values/initialisation should be done in the child class
+        self.callbacks: List[Callback] = []
+
+    def on_init_begin(self):
+        """Called when the trainer initialization begins."""
+        for callback in self.callbacks:
+            callback.set_trainer(self)
+            callback.on_init_begin()
+
+    def on_init_end(self):
+        """Called when the trainer initialization ends."""
+        for callback in self.callbacks:
+            callback.on_init_end()
+
+    def on_fit_begin(self):
+        """Called when the fit begins."""
+        for callback in self.callbacks:
+            callback.on_fit_begin()
+
+    def on_fit_end(self):
+        """Called when the fit ends."""
+        for callback in self.callbacks:
+            callback.on_fit_end()
+
+    def on_epoch_begin(self):
+        """Called when the epoch begins."""
+        for callback in self.callbacks:
+            callback.on_epoch_begin()
+
+    def on_epoch_end(self):
+        """Called when the epoch ends."""
+        for callback in self.callbacks:
+            callback.on_epoch_end()
+
+    def on_train_begin(self):
+        """Called when the train begins."""
+        for callback in self.callbacks:
+            callback.on_train_begin()
+
+    def on_train_end(self):
+        """Called when the train ends."""
+        for callback in self.callbacks:
+            callback.on_train_end()
+
+    def on_batch_begin(self):
+        """Called when the training batch begins."""
+        for callback in self.callbacks:
+            callback.on_batch_begin()
+
+    def on_batch_end(self):
+        """Called when the training batch ends."""
+        for callback in self.callbacks:
+            callback.on_batch_end()
+
+    def on_validation_begin(self):
+        """Called when the validation loop begins."""
+        for callback in self.callbacks:
+            callback.on_validation_begin()
+
+    def on_validation_end(self):
+        """Called when the validation loop ends."""
+        for callback in self.callbacks:
+            callback.on_validation_end()
+
+    def on_test_begin(self):
+        """Called when the test begins."""
+        for callback in self.callbacks:
+            callback.on_test_begin()
+
+    def on_test_end(self):
+        """Called when the test ends."""
+        for callback in self.callbacks:
+            callback.on_test_end()
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index f5d2b9327f9fa3..da9134ec5ad9f6 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -169,6 +169,12 @@ def __init__(self):
         self.get_val_dataloaders = None
         self.use_tpu = None
 
+        # Callback system
+        self.on_validation_begin = None
+        self.on_validation_end = None
+        self.on_test_begin = None
+        self.on_test_end = None
+
     @abstractmethod
     def copy_trainer_model_properties(self, model):
         # this is just empty shell for code from other class
@@ -293,6 +299,12 @@ def run_evaluation(self, test=False):
                 Please define and try again'''
             raise MisconfigurationException(m)
 
+        # Validation/Test begin callbacks
+        if test:
+            self.on_test_begin()
+        else:
+            self.on_validation_begin()
+
         # hook
         model = self.get_model()
         model.on_pre_performance_check()
@@ -353,6 +365,12 @@ def run_evaluation(self, test=False):
         if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
             self.checkpoint_callback.on_validation_end()
 
+        # Validation/Test end callbacks
+        if test:
+            self.on_test_end()
+        else:
+            self.on_validation_end()
+
     def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False):
         # make dataloader_idx arg in validation_step optional
         args = [batch, batch_idx]
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 6cd87b17e04d63..b75ab5e8544f9d 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -25,6 +25,7 @@
 from pytorch_lightning.trainer.training_io import TrainerIOMixin
 from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
+from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
 from pytorch_lightning.utilities.debugging import MisconfigurationException
 from pytorch_lightning.profiler import Profiler, PassThroughProfiler
 
@@ -57,6 +58,7 @@ class Trainer(TrainerIOMixin,
               TrainerEvaluationLoopMixin,
               TrainerTrainLoopMixin,
               TrainerCallbackConfigMixin,
+              TrainerCallbackHookMixin,
               ):
 
     def __init__(
@@ -64,6 +66,7 @@ def __init__(
             logger=True,
             checkpoint_callback=True,
             early_stop_callback=None,
+            callbacks=None,
             default_save_path=None,
             gradient_clip_val=0,
             gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
@@ -163,6 +166,22 @@ def __init__(
 
                     trainer = Trainer(early_stop_callback=early_stop_callback)
 
+            callbacks (list): Add a list of callbacks.
+                Example::
+
+                    from pytorch_lightning.callbacks import Callback
+
+                    class PrintCallback(Callback):
+                        def on_train_begin(self):
+                            print("Training has started!")
+
+                        def on_train_end(self):
+                            print("Training is done.")
+
+                    # a list of callbacks
+                    callbacks = [PrintCallback()]
+                    trainer = Trainer(callbacks=callbacks)
+
             default_save_path (str): Default path for logs and weights when no logger/ckpt_callback passed
                 Example::
 
@@ -579,6 +598,10 @@ def __init__(
 
         """
 
+        # Init callbacks
+        self.callbacks = callbacks or []  # default to a fresh list; a mutable default argument would be shared across instances
+        self.on_init_begin()
+
         # Transfer params
         # Backward compatibility
         if nb_gpu_nodes is not None:
@@ -761,6 +784,8 @@ def __init__(
                 use_amp = True
             self.init_amp(use_amp)
 
+        self.on_init_end()
+
     @property
     def slurm_job_id(self):
         try:
@@ -890,6 +915,9 @@ def fit(self, model, train_dataloader=None, val_dataloader=None, test_dataloader
             _set_dataloader(model, val_dataloader, 'val_dataloader')
             _set_dataloader(model, test_dataloader, 'test_dataloader')
 
+        # Fit begin callbacks
+        self.on_fit_begin()
+
         # when using multi-node or DDP within a node start each module in a separate process
         if self.use_ddp2:
             task = int(os.environ['SLURM_LOCALID'])
@@ -929,6 +957,9 @@ def fit(self, model, train_dataloader=None, val_dataloader=None, test_dataloader
 
         self.run_pretrain_routine(model)
 
+        # Fit end callbacks
+        self.on_fit_end()
+
         # return 1 when finished
         # used for testing or when we need to know that training succeeded
         return 1
@@ -1082,6 +1113,7 @@ def test(self, model=None):
             trainer = Trainer()
             trainer.test(model)
         """
+        self.testing = True
         if model is not None:
             self.fit(model)
 
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index e59183aa6b94e4..e192fe15226e53 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -230,6 +230,16 @@ def __init__(self):
         self.profiler = None
         self.batch_idx = None
         self.precision = None
+        self.callbacks = []
+        self.max_steps = None
+
+        # Callback system
+        self.on_train_begin = None
+        self.on_train_end = None
+        self.on_batch_begin = None
+        self.on_batch_end = None
+        self.on_epoch_begin = None
+        self.on_epoch_end = None
 
     @property
     def max_nb_epochs(self):
@@ -305,6 +315,10 @@ def process_output(self, output, train):
         pass
 
     def train(self):
+
+        # Train begin callbacks
+        self.on_train_begin()
+
         warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
                       ' but will start from "0" in v0.8.0.', DeprecationWarning)
         model = self.get_model()
@@ -375,6 +389,7 @@ def train(self):
                 if self.max_steps and self.max_steps == self.global_step:
                     self.main_progress_bar.close()
                     model.on_train_end()
+                    self.on_train_end()
                     return
 
                 # early stopping
@@ -390,17 +405,23 @@ def train(self):
                     self.main_progress_bar.close()
                     with self.profiler.profile('on_train_end'):
                         model.on_train_end()
+                    self.on_train_end()
                     return
 
         self.main_progress_bar.close()
         with self.profiler.profile('on_train_end'):
             model.on_train_end()
+        self.on_train_end()
 
         if self.logger is not None:
             self.logger.finalize("success")
 
     def run_training_epoch(self):
+
+        # Epoch begin callbacks
+        self.on_epoch_begin()
+
         # before epoch hook
         if self.is_function_implemented('on_epoch_start'):
             model = self.get_model()
@@ -486,6 +507,9 @@ def run_training_epoch(self):
             with self.profiler.profile('on_epoch_end'):
                 model.on_epoch_end()
 
+        # Epoch end callbacks
+        self.on_epoch_end()
+
     def run_training_batch(self, batch, batch_idx):
         # track grad norms
         grad_norm_dic = {}
@@ -499,6 +523,9 @@ def run_training_batch(self, batch, batch_idx):
         if batch is None:
             return 0, grad_norm_dic, {}
 
+        # Batch begin callbacks
+        self.on_batch_begin()
+
         # hook
         if self.is_function_implemented('on_batch_start'):
             model_ref = self.get_model()
@@ -610,6 +637,9 @@ def optimizer_closure():
             with self.profiler.profile('on_batch_end'):
                 model.on_batch_end()
 
+        # Batch end callbacks
+        self.on_batch_end()
+
         # update progress bar
         self.main_progress_bar.update(1)
         self.main_progress_bar.set_postfix(**self.training_tqdm_dict)
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index 283b332aa0d7fe..684b2312f9481b 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -17,9 +17,12 @@
     LightningValidationStepMixin,
     LightningValidationMultipleDataloadersMixin,
     LightningTestMultipleDataloadersMixin,
+    LightningTestMixin,
+    LightningValidationMixin,
 )
 from pytorch_lightning.core.lightning import load_hparams_from_tags_csv
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
+from pytorch_lightning.callbacks import Callback
 
 
 def test_no_val_module(tmpdir):
@@ -694,5 +697,120 @@ def test_trainer_min_steps_and_epochs(tmpdir):
         trainer.current_epoch > 0, "Model did not train for at least min_steps"
 
 
+def test_trainer_callback_system(tmpdir):
+    """Test the callback system."""
+
+    class CurrentTestModel(
+        LightningTestMixin,
+        LightningValidationMixin,
+        LightningTestModelBase,
+    ):
+        pass
+
+    hparams = tutils.get_hparams()
+    model = CurrentTestModel(hparams)
+
+    class TestCallback(Callback):
+        def __init__(self):
+            super().__init__()
+            self.on_init_begin_called = False
+            self.on_init_end_called = False
+            self.on_fit_begin_called = False
+            self.on_fit_end_called = False
+            self.on_epoch_begin_called = False
+            self.on_epoch_end_called = False
+            self.on_batch_begin_called = False
+            self.on_batch_end_called = False
+            self.on_train_begin_called = False
+            self.on_train_end_called = False
+            self.on_validation_begin_called = False
+            self.on_validation_end_called = False
+            self.on_test_begin_called = False
+            self.on_test_end_called = False
+
+        def on_init_begin(self):
+            self.on_init_begin_called = True
+
+        def on_init_end(self):
+            self.on_init_end_called = True
+
+        def on_fit_begin(self):
+            self.on_fit_begin_called = True
+
+        def on_fit_end(self):
+            self.on_fit_end_called = True
+
+        def on_epoch_begin(self):
+            self.on_epoch_begin_called = True
+
+        def on_epoch_end(self):
+            self.on_epoch_end_called = True
+
+        def on_batch_begin(self):
+            self.on_batch_begin_called = True
+
+        def on_batch_end(self):
+            self.on_batch_end_called = True
+
+        def on_train_begin(self):
+            self.on_train_begin_called = True
+
+        def on_train_end(self):
+            self.on_train_end_called = True
+
+        def on_validation_begin(self):
+            self.on_validation_begin_called = True
+
+        def on_validation_end(self):
+            self.on_validation_end_called = True
+
+        def on_test_begin(self):
+            self.on_test_begin_called = True
+
+        def on_test_end(self):
+            self.on_test_end_called = True
+
+    test_callback = TestCallback()
+
+    trainer_options = {}
+    trainer_options['callbacks'] = [test_callback]
+    trainer_options['max_epochs'] = 1
+    trainer_options['val_percent_check'] = 0.1
+    trainer_options['train_percent_check'] = 0.2
+    trainer_options['show_progress_bar'] = False
+
+    assert not test_callback.on_init_begin_called
+    assert not test_callback.on_init_end_called
+
+    # init trainer
+    trainer = Trainer(**trainer_options)
+
+    assert trainer.callbacks[0] == test_callback
+    assert test_callback.on_init_begin_called
+    assert test_callback.on_init_end_called
+    assert not test_callback.on_fit_begin_called
+    assert not test_callback.on_fit_end_called
+
+    trainer.fit(model)
+
+    assert test_callback.on_fit_begin_called
+    assert test_callback.on_fit_end_called
+    assert test_callback.on_epoch_begin_called
+    assert test_callback.on_epoch_end_called
+    assert test_callback.on_batch_begin_called
+    assert test_callback.on_batch_end_called
+    assert test_callback.on_train_begin_called
+    assert test_callback.on_train_end_called
+    assert test_callback.on_validation_begin_called
+    assert test_callback.on_validation_end_called
+    assert not test_callback.on_test_begin_called
+    assert not test_callback.on_test_end_called
+
+    trainer.test()
+
+    assert test_callback.on_test_begin_called
+    assert test_callback.on_test_end_called
+
+
 # if __name__ == '__main__':
 #     pytest.main([__file__])
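
Taken together, the pieces above form a single dispatch chain: `Trainer.__init__` fires
`on_init_begin`/`on_init_end`, `fit()` brackets the run with `on_fit_begin`/`on_fit_end`,
and the training and evaluation loops forward the epoch, batch, train, validation, and
test events to every callback in `trainer.callbacks`. The sketch below is a minimal usage
example written against only the hooks introduced in this diff; `MyLightningModel` and
the timing logic are illustrative assumptions, not part of the change. Example::

    import time

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import Callback


    class TimerCallback(Callback):
        """Time the whole fit() call using the new begin/end hooks."""

        def __init__(self):
            super().__init__()
            self._start = None

        def on_fit_begin(self):
            # fired once from Trainer.fit(), before any training work starts
            self._start = time.time()

        def on_fit_end(self):
            # fired when fit() is done, after run_pretrain_routine() returns
            print(f'fit() finished after {time.time() - self._start:.1f}s')


    model = MyLightningModel()  # stand-in for any fully defined LightningModule
    trainer = Trainer(callbacks=[TimerCallback()], max_epochs=1)
    trainer.fit(model)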