diff --git a/docs/source/callbacks.rst b/docs/source/callbacks.rst
index cfeaa81debfc0..433fce6de045f 100644
--- a/docs/source/callbacks.rst
+++ b/docs/source/callbacks.rst
@@ -9,6 +9,6 @@ Callbacks
       _save_model,
       on_epoch_end,
       on_train_end,
-      on_epoch_begin,
+      on_epoch_start,
       check_monitor_top_k,
-      on_train_begin,
\ No newline at end of file
+      on_train_start,
\ No newline at end of file
diff --git a/docs/source/loggers.rst b/docs/source/loggers.rst
index c030e653fdf4b..cd9314093d1b2 100644
--- a/docs/source/loggers.rst
+++ b/docs/source/loggers.rst
@@ -9,4 +9,4 @@ Loggers
       _save_model,
       on_epoch_end,
       on_train_end,
-      on_epoch_begin,
+      on_epoch_start,
diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py
index aab6dd6137e1d..f7857661939c0 100644
--- a/pytorch_lightning/__init__.py
+++ b/pytorch_lightning/__init__.py
@@ -29,10 +29,12 @@
 from .core import data_loader, LightningModule
 from .trainer import Trainer
+from .callbacks import Callback
 
 __all__ = [
     'Trainer',
     'LightningModule',
+    'Callback',
     'data_loader',
 ]
 # __call__ = __all__
diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index 7150138d66acc..17dda597a3d30 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -8,61 +8,61 @@
 import abc
 
-_NO_TRAINER_ERROR_MSG = ".set_trainer() should be called after the callback initialization"
-
-
 class Callback(abc.ABC):
     """Abstract base class used to build new callbacks."""
 
-    def __init__(self):
-        self._trainer = None
+    def on_init_start(self, trainer, pl_module):
+        """Called when the trainer initialization begins."""
+        pass
+
+    def on_init_end(self, trainer, pl_module):
+        """Called when the trainer initialization ends."""
+        pass
 
-    @property
-    def trainer(self):
-        assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
-        return self._trainer
+    def on_fit_start(self, trainer, pl_module):
+        """Called when the fit begins."""
+        pass
 
-    def set_trainer(self, trainer):
-        """Make a link to the trainer, so different things like `trainer.current_epoch`,
-        `trainer.batch_idx`, `trainer.global_step` can be used."""
-        self._trainer = trainer
+    def on_fit_end(self, trainer, pl_module):
+        """Called when the fit ends."""
+        pass
 
-    def on_epoch_begin(self):
+    def on_epoch_start(self, trainer, pl_module):
         """Called when the epoch begins."""
         pass
 
-    def on_epoch_end(self):
+    def on_epoch_end(self, trainer, pl_module):
         """Called when the epoch ends."""
         pass
 
-    def on_batch_begin(self):
+    def on_batch_start(self, trainer, pl_module):
         """Called when the training batch begins."""
         pass
 
-    def on_batch_end(self):
+    def on_batch_end(self, trainer, pl_module):
         """Called when the training batch ends."""
         pass
 
-    def on_train_begin(self):
+    def on_train_start(self, trainer, pl_module):
         """Called when the train begins."""
         pass
 
-    def on_train_end(self):
+    def on_train_end(self, trainer, pl_module):
         """Called when the train ends."""
         pass
 
-    def on_validation_begin(self):
+    def on_validation_start(self, trainer, pl_module):
         """Called when the validation loop begins."""
         pass
 
-    def on_validation_end(self):
+    def on_validation_end(self, trainer, pl_module):
         """Called when the validation loop ends."""
         pass
 
-    def on_test_begin(self):
+    def on_test_start(self, trainer, pl_module):
         """Called when the test begins."""
         pass
 
-    def on_test_end(self):
+    def on_test_end(self, trainer, pl_module):
         """Called when the test ends."""
         pass
diff --git a/pytorch_lightning/callbacks/early_stopping.py
b/pytorch_lightning/callbacks/early_stopping.py index 645eff2485445..435f0d533c500 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -64,7 +64,7 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: self.monitor_op = mode_dict[mode] self.min_delta *= 1 if self.monitor_op == np.greater else -1 - self.on_train_begin() + self.on_train_start(None, None) def check_metrics(self, logs): monitor_val = logs.get(self.monitor) @@ -82,14 +82,14 @@ def check_metrics(self, logs): return True - def on_train_begin(self): + def on_train_start(self, trainer, pl_module): # Allow instances to be re-used self.wait = 0 self.stopped_epoch = 0 self.best = np.Inf if self.monitor_op == np.less else -np.Inf - def on_epoch_end(self): - logs = self.trainer.callback_metrics + def on_epoch_end(self, trainer, pl_module): + logs = trainer.callback_metrics stop_training = False if not self.check_metrics(logs): return stop_training @@ -101,13 +101,13 @@ def on_epoch_end(self): else: self.wait += 1 if self.wait >= self.patience: - self.stopped_epoch = self.trainer.current_epoch + self.stopped_epoch = trainer.current_epoch stop_training = True - self.on_train_end() + self.on_train_end(trainer, pl_module) return stop_training - def on_train_end(self): + def on_train_end(self, trainer, pl_module): if self.stopped_epoch > 0 and self.verbose > 0: warnings.warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,' ' but will start from "0" in v0.8.0.', DeprecationWarning) diff --git a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py index f4e5ad3764052..9662c0b48348c 100644 --- a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py +++ b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py @@ -44,8 +44,7 @@ def __init__(self, scheduling: dict): self.scheduling = scheduling self.epochs = sorted(scheduling.keys()) - def on_epoch_begin(self): - trainer = self.trainer + def on_epoch_start(self, trainer, pl_module): # indexing epochs from 1 (until v0.6.x) # In v0.8.0, ` + 1` should be removed. 
epoch = trainer.current_epoch + 1 diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index a9f7d65b3d4b2..24727033ff1d3 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -117,9 +117,9 @@ def check_monitor_top_k(self, current): return True return self.monitor_op(current, self.best_k_models[self.kth_best_model]) - def on_validation_end(self): - logs = self.trainer.callback_metrics - epoch = self.trainer.current_epoch + def on_validation_end(self, trainer, pl_module): + logs = trainer.callback_metrics + epoch = trainer.current_epoch self.epochs_since_last_check += 1 if self.save_top_k == 0: diff --git a/pytorch_lightning/trainer/callback_config.py b/pytorch_lightning/trainer/callback_config.py index 4e6f5a9433aee..3756b19e433c0 100644 --- a/pytorch_lightning/trainer/callback_config.py +++ b/pytorch_lightning/trainer/callback_config.py @@ -48,9 +48,6 @@ def configure_checkpoint_callback(self): # if checkpoint callback used, then override the weights path self.weights_save_path = self.checkpoint_callback.filepath - # link to the trainer - self.checkpoint_callback.set_trainer(self) - # if weights_save_path is still none here, set to current working dir if self.weights_save_path is None: self.weights_save_path = self.default_save_path @@ -80,6 +77,3 @@ def configure_early_stopping(self, early_stop_callback): else: self.early_stop_callback = early_stop_callback self.enable_early_stop = True - - if self.early_stop_callback is not None: - self.early_stop_callback.set_trainer(self) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py new file mode 100644 index 0000000000000..66bda345ec5a6 --- /dev/null +++ b/pytorch_lightning/trainer/callback_hook.py @@ -0,0 +1,83 @@ +from typing import Callable +from abc import ABC + +from pytorch_lightning.callbacks import Callback + + +class TrainerCallbackHookMixin(ABC): + + def __init__(self): + # this is just a summary on variables used in this abstract class, + # the proper values/initialisation should be done in child class + self.callbacks: list[Callback] = [] + self.get_model: Callable = ... 
+ + def on_init_start(self): + """Called when the trainer initialization begins.""" + for callback in self.callbacks: + callback.on_init_start(self, None) + + def on_init_end(self): + """Called when the trainer initialization ends.""" + for callback in self.callbacks: + callback.on_init_end(self, self.get_model()) + + def on_fit_start(self): + """Called when the fit begins.""" + for callback in self.callbacks: + callback.on_fit_start(self, self.get_model()) + + def on_fit_end(self): + """Called when the fit ends.""" + for callback in self.callbacks: + callback.on_fit_end(self, self.get_model()) + + def on_epoch_start(self): + """Called when the epoch begins.""" + for callback in self.callbacks: + callback.on_epoch_start(self, self.get_model()) + + def on_epoch_end(self): + """Called when the epoch ends.""" + for callback in self.callbacks: + callback.on_epoch_end(self, self.get_model()) + + def on_train_start(self): + """Called when the train begins.""" + for callback in self.callbacks: + callback.on_train_start(self, self.get_model()) + + def on_train_end(self): + """Called when the train ends.""" + for callback in self.callbacks: + callback.on_train_end(self, self.get_model()) + + def on_batch_start(self): + """Called when the training batch begins.""" + for callback in self.callbacks: + callback.on_batch_start(self, self.get_model()) + + def on_batch_end(self): + """Called when the training batch ends.""" + for callback in self.callbacks: + callback.on_batch_end(self, self.get_model()) + + def on_validation_start(self): + """Called when the validation loop begins.""" + for callback in self.callbacks: + callback.on_validation_start(self, self.get_model()) + + def on_validation_end(self): + """Called when the validation loop ends.""" + for callback in self.callbacks: + callback.on_validation_end(self, self.get_model()) + + def on_test_start(self): + """Called when the test begins.""" + for callback in self.callbacks: + callback.on_test_start(self, self.get_model()) + + def on_test_end(self): + """Called when the test ends.""" + for callback in self.callbacks: + callback.on_test_end(self, self.get_model()) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 1bb0e8853ab9d..bca62836bfdc9 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -123,6 +123,8 @@ """ +from typing import Callable + import sys from abc import ABC, abstractmethod @@ -171,6 +173,12 @@ def __init__(self): self.reload_dataloaders_every_epoch = None self.progress_bar_refresh_rate = None + # Callback system + self.on_validation_start: Callable = ... + self.on_validation_end: Callable = ... + self.on_test_start: Callable = ... + self.on_test_end: Callable = ... 
+
     @abstractmethod
     def copy_trainer_model_properties(self, model):
         # this is just empty shell for code from other class
@@ -302,6 +310,12 @@ def run_evaluation(self, test_mode: bool = False):
                 " Please define and try again"
             raise MisconfigurationException(m)
 
+        # Validation/Test begin callbacks
+        if test_mode:
+            self.on_test_start()
+        else:
+            self.on_validation_start()
+
         # hook
         model = self.get_model()
         model.on_pre_performance_check()
@@ -363,7 +377,13 @@ def run_evaluation(self, test_mode: bool = False):
 
         # model checkpointing
         if self.proc_rank == 0 and self.checkpoint_callback is not None and not test_mode:
-            self.checkpoint_callback.on_validation_end()
+            self.checkpoint_callback.on_validation_end(self, self.get_model())
+
+        # Validation/Test end callbacks
+        if test_mode:
+            self.on_test_end()
+        else:
+            self.on_validation_end()
 
     def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False):
         # make dataloader_idx arg in validation_step optional
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 566c2e734a09a..5116df9fc0786 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -30,8 +30,10 @@
 from pytorch_lightning.trainer.training_io import TrainerIOMixin
 from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
+from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
 from pytorch_lightning.utilities.debugging import MisconfigurationException
 from pytorch_lightning.profiler import Profiler, PassThroughProfiler
+from pytorch_lightning.callbacks import Callback
 
 try:
@@ -62,6 +64,7 @@ class Trainer(TrainerIOMixin,
               TrainerEvaluationLoopMixin,
               TrainerTrainLoopMixin,
               TrainerCallbackConfigMixin,
+              TrainerCallbackHookMixin
               ):
 
     def __init__(
@@ -69,6 +72,7 @@ def __init__(
             logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
             checkpoint_callback: Union[ModelCheckpoint, bool] = True,
             early_stop_callback: Optional[Union[EarlyStopping, bool]] = None,
+            callbacks: List[Callback] = [],
             default_save_path: Optional[str] = None,
             gradient_clip_val: float = 0,
             gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
@@ -171,6 +175,18 @@ def __init__(
 
                 trainer = Trainer(early_stop_callback=early_stop_callback)
 
+        callbacks: Add a list of callbacks.
+            Example::
+                from pytorch_lightning.callbacks import Callback
+                class PrintCallback(Callback):
+                    def on_train_start(self, trainer, pl_module):
+                        print("Training is started!")
+                    def on_train_end(self, trainer, pl_module):
+                        print(f"Training is done. Final epoch: {trainer.current_epoch}")
+                # a list of callbacks
+                callbacks = [PrintCallback()]
+                trainer = Trainer(callbacks=callbacks)
+
         default_save_path: Default path for logs and weights when no logger/ckpt_callback passed
 
             Example::
@@ -599,6 +615,10 @@ def __init__(
 
         """
 
+        # Init callbacks
+        self.callbacks = callbacks
+        self.on_init_start()
+
         # benchmarking
         self.benchmark = benchmark
         if benchmark:
@@ -786,6 +806,9 @@ def __init__(
             use_amp = True
         self.init_amp(use_amp)
 
+        # Callback system
+        self.on_init_end()
+
     @property
     def slurm_job_id(self) -> int:
         try:
@@ -914,6 +937,9 @@ def fit(
 
             # feed to .fit()
 
         """
+        # Fit begin callbacks
+        self.on_fit_start()
+
         # set up the passed in dataloaders (if needed)
         self.__set_fit_dataloaders(model, train_dataloader, val_dataloaders, test_dataloaders)
@@ -957,6 +983,9 @@ def fit(
 
         self.run_pretrain_routine(model)
 
+        # Fit end callbacks
+        self.on_fit_end()
+
         # return 1 when finished
         # used for testing or when we need to know that training succeeded
         return 1
@@ -1090,9 +1119,8 @@ def run_pretrain_routine(self, model: LightningModule):
             self.reset_val_dataloader(ref_model)
 
         # check if we should run validation during training
-        self.disable_validation = ((self.num_val_batches == 0 or
-                                    not self.is_overriden('validation_step')) and
-                                   not self.fast_dev_run)
+        self.disable_validation = self.num_val_batches == 0 or not self.is_overriden('validation_step')
+        self.disable_validation = self.disable_validation and not self.fast_dev_run
 
         # run tiny validation (if validation defined)
         # to make sure program won't crash during val
@@ -1162,3 +1190,51 @@ def test(self, model: Optional[LightningModule] = None):
         if model is not None:
             self.fit(model)
         self.run_evaluation(test_mode=True)
+
+
+def _set_dataloader(model, dataloader, attribute):
+    r'''
+    Check dataloaders passed to .fit() method if they are pytorch DataLoader
+    objects and whether or not we should overwrite the corresponding dataloader
+    in the model
+
+    Args:
+        model (LightningModule): The model to check
+
+        dataloader: If a pytorch dataloader (or a list of pytorch dataloaders)
+            is passed, it will be incorporated into the model as model.attribute.
+            If the attribute already exists it will warn the user. If it is not
+            a dataloader, it will throw an error
+
+        attribute (str): The attribute to save the dataloader under
+
+    '''
+    # Check if attribute comes directly from base class or
+    # derived in user subclass
+    if LightningModule.__qualname__ in getattr(model, attribute).__qualname__:
+        # Val and test should be list of dataloaders
+        dataloader = dataloader if attribute == 'train_dataloader' or \
+            (attribute != 'train_dataloader' and isinstance(dataloader, list)) else [dataloader]
+
+        # Check we are given valid dataloaders
+        is_dataloader = isinstance(dataloader, torch.utils.data.DataLoader)
+        is_dataloader_list = isinstance(dataloader, list)
+        valid_loaders = None
+        if is_dataloader_list:
+            valid_loaders = all(isinstance(d, torch.utils.data.DataLoader) for d in dataloader)
+        if is_dataloader or is_dataloader_list and valid_loaders:
+
+            # Overwrite abstract methods
+            def dl():
+                return dataloader
+            dl.__name__ = attribute
+            setattr(model, attribute, dl)
+
+        elif dataloader and dataloader != [None]:
+            raise ValueError(f'`{attribute}` needs to be an instance of '
+                             '`torch.utils.data.DataLoader` or a list of '
+                             'DataLoaders, instead got %r' % dataloader)
+
+        elif dataloader:  # if default (None) is passed, do not warn the user
+            warnings.warn(f'Model has predefined `{attribute}`,'
+                          f' will skip `{attribute}={dataloader}` passed to fit method.')
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 93466ae6087b2..847690968b842 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -152,6 +152,8 @@ def training_step(self, batch, batch_idx):
 
 """
 
+from typing import Callable
+
 import copy
 import warnings
 from abc import ABC, abstractmethod
@@ -160,6 +162,7 @@ def training_step(self, batch, batch_idx):
 import numpy as np
 
 from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.callbacks.base import Callback
 
 try:
     from apex import amp
@@ -229,6 +232,16 @@ def __init__(self):
         self.max_steps = ...
         self.max_steps = ...
 
+        # Callback system
+        self.callbacks: list[Callback] = []
+        self.max_steps = None
+        self.on_train_start: Callable = ...
+        self.on_train_end: Callable = ...
+        self.on_batch_start: Callable = ...
+        self.on_batch_end: Callable = ...
+        self.on_epoch_start: Callable = ...
+        self.on_epoch_end: Callable = ...
+
     @property
     def max_nb_epochs(self):
         """
@@ -320,6 +333,10 @@ def has_arg(self, f_name, arg_name):
     def train(self):
         warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
                       ' but will start from "0" in v0.8.0.', DeprecationWarning)
+
+        # Train begin callbacks
+        self.on_train_start()
+
         # get model
         model = self.get_model()
         try:
@@ -367,7 +384,7 @@ def train(self):
                 self.main_progress_bar.set_description(desc)
 
                 # changing gradient according accumulation_scheduler
-                self.accumulation_scheduler.on_epoch_begin()
+                self.accumulation_scheduler.on_epoch_start(self, self.get_model())
 
                 # -----------------
                 # RUN TNG EPOCH
@@ -390,20 +407,22 @@ def train(self):
                 if self.max_steps and self.max_steps == self.global_step:
                     self.main_progress_bar.close()
                     model.on_train_end()
+                    self.on_train_end()
                     return
 
                 # early stopping
                 met_min_epochs = epoch >= self.min_epochs - 1
                 met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
 
-                if (self.enable_early_stop and not self.disable_validation and is_val_epoch and
-                        ((met_min_epochs and met_min_steps) or self.fast_dev_run)):
-                    should_stop = self.early_stop_callback.on_epoch_end()
-                    # stop training
-                    stop = should_stop and met_min_epochs
-                    if stop:
-                        self.run_training_teardown()
-                        return
+                if self.enable_early_stop and not self.disable_validation and is_val_epoch:
+                    if ((met_min_epochs and met_min_steps) or self.fast_dev_run):
+                        should_stop = self.early_stop_callback.on_epoch_end(self, self.get_model())
+                        # stop training
+                        stop = should_stop and met_min_epochs
+                        if stop:
+                            self.run_training_teardown()
+                            self.on_train_end()
+                            return
 
             self.run_training_teardown()
@@ -411,7 +430,14 @@ def train(self):
             log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
             self.run_training_teardown()
 
+        # Train end callbacks
+        self.on_train_end()
+
     def run_training_epoch(self):
+
+        # Epoch begin callbacks
+        self.on_epoch_start()
+
         # before epoch hook
         if self.is_function_implemented('on_epoch_start'):
             model = self.get_model()
@@ -455,8 +481,8 @@ def run_training_epoch(self):
             # ---------------
             is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
             can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
-            should_check_val = (not self.disable_validation and can_check_epoch and
-                                (is_val_check_batch or early_stop_epoch))
+            should_check_val = not self.disable_validation and can_check_epoch
+            should_check_val = should_check_val and (is_val_check_batch or early_stop_epoch)
 
             # fast_dev_run always forces val checking after train batch
             if self.fast_dev_run or should_check_val:
@@ -498,6 +524,9 @@ def run_training_epoch(self):
             with self.profiler.profile('on_epoch_end'):
                 model.on_epoch_end()
 
+        # Epoch end callbacks
+        self.on_epoch_end()
+
     def run_training_batch(self, batch, batch_idx):
         # track grad norms
         grad_norm_dic = {}
@@ -511,6 +540,9 @@ def run_training_batch(self, batch, batch_idx):
         if batch is None:
             return 0, grad_norm_dic, {}
 
+        # Batch begin callbacks
+        self.on_batch_start()
+
         # hook
         if self.is_function_implemented('on_batch_start'):
             model_ref = self.get_model()
@@ -619,6 +651,9 @@ def optimizer_closure():
             with self.profiler.profile('on_batch_end'):
                 model.on_batch_end()
 
+        # Batch end callbacks
+        self.on_batch_end()
+
         # update progress bar
         if batch_idx % self.progress_bar_refresh_rate == 0:
             self.main_progress_bar.update(self.progress_bar_refresh_rate)
diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index c62c6f3654549..7fa4059afc3e2 100644
--- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -39,5 +39,3 @@ def configure_accumulated_gradients(self, accumulate_grad_batches): self.accumulation_scheduler = GradientAccumulationScheduler(schedule) else: raise TypeError("Gradient accumulation supports only int and dict types") - - self.accumulation_scheduler.set_trainer(self) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index f33b1f3e1017f..7850638475ad7 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -23,10 +23,13 @@ LightValStepFitSingleDataloaderMixin, LightTrainDataloader, LightTestDataloader, + LightValidationMixin, + LightTestMixin ) from pytorch_lightning.core.lightning import load_hparams_from_tags_csv from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning import Callback def test_no_val_module(tmpdir): @@ -242,13 +245,12 @@ def mock_save_function(filepath): checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=-1, verbose=1) checkpoint_callback.save_function = mock_save_function trainer = Trainer() - checkpoint_callback.set_trainer(trainer) # emulate callback's calls during the training for i, loss in enumerate(losses): - checkpoint_callback._trainer.current_epoch = i - checkpoint_callback._trainer.callback_metrics = {'val_loss': loss} - checkpoint_callback.on_validation_end() + trainer.current_epoch = i + trainer.callback_metrics = {'val_loss': loss} + checkpoint_callback.on_validation_end(trainer, trainer.get_model()) file_lists = set(os.listdir(save_dir)) @@ -266,13 +268,12 @@ def mock_save_function(filepath): checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=0, verbose=1) checkpoint_callback.save_function = mock_save_function trainer = Trainer() - checkpoint_callback.set_trainer(trainer) # emulate callback's calls during the training for i, loss in enumerate(losses): - checkpoint_callback._trainer.current_epoch = i - checkpoint_callback._trainer.callback_metrics = {'val_loss': loss} - checkpoint_callback.on_validation_end() + trainer.current_epoch = i + trainer.callback_metrics = {'val_loss': loss} + checkpoint_callback.on_validation_end(trainer, trainer.get_model()) file_lists = os.listdir(save_dir) @@ -286,13 +287,12 @@ def mock_save_function(filepath): checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=1, verbose=1, prefix='test_prefix') checkpoint_callback.save_function = mock_save_function trainer = Trainer() - checkpoint_callback.set_trainer(trainer) # emulate callback's calls during the training for i, loss in enumerate(losses): - checkpoint_callback._trainer.current_epoch = i - checkpoint_callback._trainer.callback_metrics = {'val_loss': loss} - checkpoint_callback.on_validation_end() + trainer.current_epoch = i + trainer.callback_metrics = {'val_loss': loss} + checkpoint_callback.on_validation_end(trainer, trainer.get_model()) file_lists = set(os.listdir(save_dir)) @@ -310,13 +310,12 @@ def mock_save_function(filepath): open(f'{save_dir}/other_file.ckpt', 'a').close() checkpoint_callback.save_function = mock_save_function trainer = Trainer() - checkpoint_callback.set_trainer(trainer) # emulate callback's calls during the training for i, loss in enumerate(losses): - checkpoint_callback._trainer.current_epoch = i - checkpoint_callback._trainer.callback_metrics = {'val_loss': loss} - checkpoint_callback.on_validation_end() + trainer.current_epoch = i + trainer.callback_metrics = {'val_loss': loss} + 
checkpoint_callback.on_validation_end(trainer, trainer.get_model()) file_lists = set(os.listdir(save_dir)) @@ -335,13 +334,12 @@ def mock_save_function(filepath): checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=4, verbose=1) checkpoint_callback.save_function = mock_save_function trainer = Trainer() - checkpoint_callback.set_trainer(trainer) # emulate callback's calls during the training for loss in losses: - checkpoint_callback._trainer.current_epoch = 0 - checkpoint_callback._trainer.callback_metrics = {'val_loss': loss} - checkpoint_callback.on_validation_end() + trainer.current_epoch = 0 + trainer.callback_metrics = {'val_loss': loss} + checkpoint_callback.on_validation_end(trainer, trainer.get_model()) file_lists = set(os.listdir(save_dir)) @@ -357,13 +355,12 @@ def mock_save_function(filepath): checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=3, verbose=1) checkpoint_callback.save_function = mock_save_function trainer = Trainer() - checkpoint_callback.set_trainer(trainer) # emulate callback's calls during the training for loss in losses: - checkpoint_callback._trainer.current_epoch = 0 - checkpoint_callback._trainer.callback_metrics = {'val_loss': loss} - checkpoint_callback.on_validation_end() + trainer.current_epoch = 0 + trainer.callback_metrics = {'val_loss': loss} + checkpoint_callback.on_validation_end(trainer, trainer.get_model()) file_lists = set(os.listdir(save_dir)) @@ -798,8 +795,9 @@ def test_benchmark_option(tmpdir): tutils.reset_seed() class CurrentTestModel( - LightningValidationMultipleDataloadersMixin, - LightningTestModelBase + LightValidationMultipleDataloadersMixin, + LightTrainDataloader, + TestModelBase ): pass @@ -858,5 +856,120 @@ def test_end(self, outputs): Trainer().test(model) +def test_trainer_callback_system(tmpdir): + """Test the callback system.""" + + class CurrentTestModel( + LightTrainDataloader, + LightTestMixin, + LightValidationMixin, + TestModelBase, + ): + pass + + hparams = tutils.get_hparams() + model = CurrentTestModel(hparams) + + class TestCallback(Callback): + def __init__(self): + super().__init__() + self.on_init_start_called = False + self.on_init_end_called = False + self.on_fit_start_called = False + self.on_fit_end_called = False + self.on_epoch_start_called = False + self.on_epoch_end_called = False + self.on_batch_start_called = False + self.on_batch_end_called = False + self.on_train_start_called = False + self.on_train_end_called = False + self.on_validation_start_called = False + self.on_validation_end_called = False + self.on_test_start_called = False + self.on_test_end_called = False + + def on_init_start(self, trainer, pl_module): + self.on_init_start_called = True + + def on_init_end(self, trainer, pl_module): + self.on_init_end_called = True + + def on_fit_start(self, trainer, pl_module): + self.on_fit_start_called = True + + def on_fit_end(self, trainer, pl_module): + self.on_fit_end_called = True + + def on_epoch_start(self, trainer, pl_module): + self.on_epoch_start_called = True + + def on_epoch_end(self, trainer, pl_module): + self.on_epoch_end_called = True + + def on_batch_start(self, trainer, pl_module): + self.on_batch_start_called = True + + def on_batch_end(self, trainer, pl_module): + self.on_batch_end_called = True + + def on_train_start(self, trainer, pl_module): + self.on_train_start_called = True + + def on_train_end(self, trainer, pl_module): + self.on_train_end_called = True + + def on_validation_start(self, trainer, pl_module): + self.on_validation_start_called = True + + def 
on_validation_end(self, trainer, pl_module):
+            self.on_validation_end_called = True
+
+        def on_test_start(self, trainer, pl_module):
+            self.on_test_start_called = True
+
+        def on_test_end(self, trainer, pl_module):
+            self.on_test_end_called = True
+
+    test_callback = TestCallback()
+
+    trainer_options = {}
+    trainer_options['callbacks'] = [test_callback]
+    trainer_options['max_epochs'] = 1
+    trainer_options['val_percent_check'] = 0.1
+    trainer_options['train_percent_check'] = 0.2
+    trainer_options['show_progress_bar'] = False
+
+    assert not test_callback.on_init_start_called
+    assert not test_callback.on_init_end_called
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+
+    assert trainer.callbacks[0] == test_callback
+    assert test_callback.on_init_start_called
+    assert test_callback.on_init_end_called
+    assert not test_callback.on_fit_start_called
+    assert not test_callback.on_fit_end_called
+
+    trainer.fit(model)
+
+    assert test_callback.on_fit_start_called
+    assert test_callback.on_fit_end_called
+    assert test_callback.on_epoch_start_called
+    assert test_callback.on_epoch_end_called
+    assert test_callback.on_batch_start_called
+    assert test_callback.on_batch_end_called
+    assert test_callback.on_train_start_called
+    assert test_callback.on_train_end_called
+    assert test_callback.on_validation_start_called
+    assert test_callback.on_validation_end_called
+    assert not test_callback.on_test_start_called
+    assert not test_callback.on_test_end_called
+
+    trainer.test()
+
+    assert test_callback.on_test_start_called
+    assert test_callback.on_test_end_called
+
 # if __name__ == '__main__':
 #     pytest.main([__file__])
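
Usage note: a minimal sketch of how the callback API introduced above is meant to be wired up from user code. The
name `model` stands in for any user-defined LightningModule (hypothetical); everything else uses only names that
appear in this diff: the `Callback` base class, the `(trainer, pl_module)` hook signature, `Trainer(callbacks=...)`,
`trainer.current_epoch`, and `trainer.callback_metrics`.

    from pytorch_lightning import Callback, Trainer

    class MetricsPrinter(Callback):
        """Print the logged metrics at the end of every epoch."""

        def on_epoch_end(self, trainer, pl_module):
            # trainer.callback_metrics is the same dict that the built-in
            # EarlyStopping and ModelCheckpoint callbacks read from
            print(f"epoch {trainer.current_epoch}: {trainer.callback_metrics}")

        def on_train_end(self, trainer, pl_module):
            print("Training finished.")

    # `model` is any LightningModule defined by the user (hypothetical here)
    trainer = Trainer(callbacks=[MetricsPrinter()], max_epochs=3)
    trainer.fit(model)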