diff --git a/CHANGELOG.md b/CHANGELOG.md index 703a2c7eec19f..5f8f7a08b089b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072)) +- Added `RunningStage.SANITY_CHECKING` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) + + +- Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) + + - Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915)) @@ -26,9 +32,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259)) +- Refactor `RunningStage` and `TrainerState` usage ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) + + +- Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) + + ### Deprecated +- Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) + + ### Removed - Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164)) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index ea9cb03d18366..06191dcff6d80 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -20,6 +20,7 @@ from pytorch_lightning.core import LightningModule from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin from pytorch_lightning.plugins.training_type import TrainingTypePlugin +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available from pytorch_lightning.utilities.enums import AMPType, LightningEnum @@ -80,8 +81,8 @@ def setup(self, trainer: 'Trainer', model: LightningModule) -> None: def start_training(self, trainer: 'Trainer') -> None: self.training_type_plugin.start_training(trainer) - def start_testing(self, trainer: 'Trainer') -> None: - self.training_type_plugin.start_testing(trainer) + def start_evaluating(self, trainer: 'Trainer') -> None: + self.training_type_plugin.start_evaluating(trainer) def start_predicting(self, trainer: 'Trainer') -> None: self.training_type_plugin.start_predicting(trainer) @@ -323,7 +324,7 @@ def setup_optimizers(self, trainer: 'Trainer') -> None: trainer: the Trainer, these optimizers should be connected to model: the model to be optimized by the created optimizers """ - if trainer.testing: + if trainer.state not in (TrainerState.FITTING, TrainerState.TUNING): return optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers( trainer=trainer, model=self.lightning_module @@ -417,7 +418,7 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I @property def results(self) -> Any: """ 
- The results of the last training/testing run will be cached within the training type plugin. + The results of the last run will be cached within the training type plugin. In distributed training, we make sure to transfer the results to the appropriate master process. """ return self.training_type_plugin.results diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index ff5cfeaf1bb96..38ccce648502a 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -137,12 +137,13 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]): self.patience = callback_state['patience'] def on_validation_end(self, trainer, pl_module): - if trainer.running_sanity_check: + from pytorch_lightning.trainer.states import TrainerState + if trainer.state != TrainerState.FITTING or trainer.sanity_checking: return - self._run_early_stopping_check(trainer, pl_module) + self._run_early_stopping_check(trainer) - def _run_early_stopping_check(self, trainer, pl_module): + def _run_early_stopping_check(self, trainer): """ Checks whether the early stopping condition is met and if so tells the trainer to stop the training. diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index d552560191a35..4233933af1b1a 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -213,12 +213,14 @@ def save_checkpoint(self, trainer, pl_module): epoch = trainer.current_epoch global_step = trainer.global_step + from pytorch_lightning.trainer.states import TrainerState if ( trainer.fast_dev_run # disable checkpointing with fast_dev_run + or trainer.state != TrainerState.FITTING # don't save anything during non-fit + or trainer.sanity_checking # don't save anything during sanity check or self.save_top_k == 0 # no models are saved or self.period < 1 # no models are saved or (epoch + 1) % self.period # skip epoch - or trainer.running_sanity_check # don't save anything during sanity check or self._last_global_step_saved == global_step # already saved at the last step ): return diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 2f133eaccf512..c382e67b21a64 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -380,7 +380,6 @@ def init_test_tqdm(self) -> tqdm: def on_sanity_check_start(self, trainer, pl_module): super().on_sanity_check_start(trainer, pl_module) self.val_progress_bar = self.init_sanity_tqdm() - reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches)) self.main_progress_bar = tqdm(disable=True) # dummy progress bar def on_sanity_check_end(self, trainer, pl_module): @@ -412,7 +411,9 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data def on_validation_start(self, trainer, pl_module): super().on_validation_start(trainer, pl_module) - if not trainer.running_sanity_check: + if trainer.sanity_checking: + reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches)) + else: self._update_bar(self.main_progress_bar) # fill up remaining self.val_progress_bar = self.init_validation_tqdm() reset(self.val_progress_bar, self.total_val_batches) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c13f6a226e83b..4c839f3a6c906 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -25,7 +25,7 @@ 
from argparse import Namespace from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch from torch import ScriptModule, Tensor @@ -44,8 +44,6 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args -if TYPE_CHECKING: - from pytorch_lightning.trainer.states import RunningStage log = logging.getLogger(__name__) @@ -69,7 +67,6 @@ class LightningModule( "on_gpu", "current_epoch", "global_step", - "running_stage", "global_rank", "local_rank", "logger", @@ -172,10 +169,6 @@ def automatic_optimization(self) -> bool: """ return self._automatic_optimization - @property - def running_stage(self) -> Optional["RunningStage"]: - return self.trainer._running_stage if self.trainer else None - @automatic_optimization.setter def automatic_optimization(self, automatic_optimization: bool) -> None: self._automatic_optimization = automatic_optimization diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py index c0b691bb07cb8..170cdc4600bb4 100644 --- a/pytorch_lightning/overrides/base.py +++ b/pytorch_lightning/overrides/base.py @@ -18,7 +18,6 @@ from torch.nn.parallel import DistributedDataParallel from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.warnings import WarningCache @@ -43,9 +42,9 @@ def __init__(self, pl_module: LightningModule): self.module = pl_module def forward(self, *inputs, **kwargs): - running_stage = self.module.running_stage + trainer = self.module.trainer - if running_stage == RunningStage.TRAINING: + if trainer and trainer.training: output = self.module.training_step(*inputs, **kwargs) # In manual_optimization, we need to prevent DDP reducer as @@ -53,18 +52,18 @@ def forward(self, *inputs, **kwargs): # `require_backward_grad_sync` will be reset in the # ddp_plugin ``post_training_step`` hook if not self.module.automatic_optimization: - self.module.trainer.model.require_backward_grad_sync = False + trainer.model.require_backward_grad_sync = False warn_if_output_is_none(output, "training_step") - elif running_stage == RunningStage.TESTING: + elif trainer and trainer.testing: output = self.module.test_step(*inputs, **kwargs) warn_if_output_is_none(output, "test_step") - elif running_stage == RunningStage.EVALUATING: + elif trainer and (trainer.sanity_checking or trainer.validating): output = self.module.validation_step(*inputs, **kwargs) warn_if_output_is_none(output, "validation_step") - elif running_stage == RunningStage.PREDICTING: + elif trainer and trainer.predicting: output = self.module.predict(*inputs, **kwargs) warn_if_output_is_none(output, "predict") diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index d699dcb690d88..3dace06cbf825 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -27,6 +27,7 @@ from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.parallel import 
ParallelPlugin +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7 from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load @@ -103,7 +104,7 @@ def start_training(self, trainer): # reset optimizers, since main process is never used for training and thus does not have a valid optim state trainer.optimizers = [] - def start_testing(self, trainer): + def start_evaluating(self, trainer): mp.spawn(self.new_process, **self.mp_spawn_kwargs) def start_predicting(self, trainer): @@ -152,7 +153,7 @@ def new_process(self, process_idx, trainer, mp_queue): self.barrier() - results = trainer.train_or_test_or_predict() + results = trainer.run_stage() # persist info in ddp_spawn self.transfer_distrib_spawn_state_on_fit_end(results) @@ -204,7 +205,6 @@ def on_save(self, checkpoint: dict) -> dict: return checkpoint def transfer_distrib_spawn_state_on_fit_end(self, results): - # TODO: is there a better way than accessing callback through model -> trainer -> callback? checkpoint_callback = self.lightning_module.trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None @@ -213,8 +213,11 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): # save the last weights last_path = None - # TODO: is there a better way than accessing trainer through model -> trainer? - if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0: + if ( + self.lightning_module.trainer.state == TrainerState.FITTING + and best_model_path is not None + and len(best_model_path) > 0 + ): last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) atomic_save(self.on_save(self.lightning_module.state_dict()), last_path) @@ -224,14 +227,13 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): self.mp_queue.put(results) def __recover_child_process_weights(self, best_path, last_path): - # TODO: is there a better way than accessing callback through model -> trainer -> callback? 
# transfer back the best path to the trainer if self.lightning_module.trainer.checkpoint_callback: self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path # todo, pass also best score # load last weights - if last_path is not None and not self.lightning_module.trainer.testing: + if last_path is not None and self.lightning_module.trainer.state == TrainerState.FITTING: ckpt = pl_load(last_path, map_location=lambda storage, loc: storage) self.lightning_module.load_state_dict(ckpt) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 3f9eccce7073c..a481c0c2e206b 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -213,7 +213,7 @@ def init_deepspeed(self): precision = self.lightning_module.trainer.accelerator.precision model = LightningDeepSpeedModule(pl_module=self.model, precision=precision) - if self.lightning_module.trainer.training: + if self.lightning_module.trainer and self.lightning_module.trainer.training: self._initialize_deepspeed_train(model) else: self._initialize_deepspeed_inference(model) @@ -249,8 +249,7 @@ def _initialize_deepspeed_train(self, model): ) # set optimizer for save/load, but deepspeed manages the specific optimizer logic - trainer = self.lightning_module.trainer - trainer.optimizers = [optimizer] + self.lightning_module.trainer.optimizers = [optimizer] self.model = model def _initialize_deepspeed_inference(self, model): diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 8fe52190fd7bb..2fe3906cb01d0 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -101,9 +101,9 @@ def start_training(self, trainer): # Make sure all workers have finished training before returning to the user hvd.join() - def start_testing(self, trainer): + def start_evaluating(self, trainer): with ExitStack(): - self._results = trainer.run_test() + self._results = trainer.run_evaluate() # Make sure all workers have finished training before returning to the user hvd.join() diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py index 329c82b5ed7f8..8fd75555ecd14 100644 --- a/pytorch_lightning/plugins/training_type/rpc_sequential.py +++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -24,7 +24,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.distributed import LightningDistributedModule from pytorch_lightning.plugins.training_type.rpc import DEFAULT_RPC_TIMEOUT_SEC, RPCPlugin -from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -208,7 +208,7 @@ def _skip_init_connections(self): Returns: Whether to skip initialization """ - return torch_distrib.is_initialized() and self.lightning_module.running_stage == RunningStage.TESTING + return torch_distrib.is_initialized() and self.lightning_module.trainer.state != TrainerState.FITTING def init_model_parallel_groups(self): num_model_parallel = 1 # TODO currently no support for vertical model parallel @@ -231,7 +231,7 @@ def _infer_check_num_gpus(self): return self.world_size def 
handle_transferred_pipe_module(self) -> None: - if not self.lightning_module.running_stage == RunningStage.TESTING: + if self.lightning_module.trainer.state == TrainerState.FITTING: torch_distrib.barrier() # Ensure we await main process initialization # Add trainer/configure_optimizers to the pipe model for access in all worker processes rpc_pipe.PipeModel.trainer = self.lightning_module.trainer @@ -243,7 +243,7 @@ def init_pipe_module(self) -> None: # Create pipe_module model = self.lightning_module self._find_and_init_pipe_module(model) - if not self.lightning_module.running_stage == RunningStage.TESTING: + if self.lightning_module.trainer.state == TrainerState.FITTING: torch_distrib.barrier() # Ensure we join main process initialization model.sequential_module.foreach_worker(register_optimizers, include_self=True) @@ -333,9 +333,9 @@ def start_training(self, trainer) -> None: if self.main_rpc_process: super().start_training(trainer) - def start_testing(self, trainer) -> None: + def start_evaluating(self, trainer) -> None: if self.main_rpc_process: - super().start_testing(trainer) + super().start_evaluating(trainer) class LightningPipeModule(nn.Module): diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py index 0f86a01128dc3..7536ef9b1d856 100644 --- a/pytorch_lightning/plugins/training_type/sharded.py +++ b/pytorch_lightning/plugins/training_type/sharded.py @@ -16,6 +16,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.optimizer import is_lightning_optimizer from pytorch_lightning.plugins.training_type.ddp import DDPPlugin +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only if _FAIRSCALE_AVAILABLE: @@ -48,8 +49,7 @@ def _reinit_optimizers_with_oss(self): trainer.convert_to_lightning_optimizers() def _wrap_optimizers(self): - trainer = self.model.trainer - if trainer.testing is True: + if self.model.trainer.state != TrainerState.FITTING: return self._reinit_optimizers_with_oss() diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 512b40fa9093b..7aadf797e160a 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -15,6 +15,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only if _FAIRSCALE_AVAILABLE: @@ -44,8 +45,7 @@ def _reinit_optimizers_with_oss(self): trainer.optimizers = optimizers def _wrap_optimizers(self): - trainer = self.model.trainer - if trainer.testing: + if self.model.trainer.state != TrainerState.FITTING: return self._reinit_optimizers_with_oss() diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 9639a17e637bb..efada181ca9a6 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -23,6 +23,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle +from pytorch_lightning.trainer.states import 
TrainerState from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -111,7 +112,7 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None: trainer.save_checkpoint = self.save_checkpoint self.barrier() - results = trainer.train_or_test_or_predict() + results = trainer.run_stage() self.__save_end_of_training_weights(self.lightning_module) self.transfer_distrib_spawn_state_on_fit_end(results) @@ -130,7 +131,6 @@ def barrier(self, name: Optional[str] = None) -> None: rendezvous(f"pl.Trainer.{name}") def transfer_distrib_spawn_state_on_fit_end(self, results): - # TODO: is there a better way than accessing callback through model -> trainer -> callback? best_model_path = self.lightning_module.trainer.checkpoint_callback.best_model_path if self.mp_queue is not None: @@ -138,8 +138,11 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): # save the last weights last_path = None - # TODO: is there a better way than accessing trainer through model -> trainer? - if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0: + if ( + self.lightning_module.trainer.state == TrainerState.FITTING + and best_model_path is not None + and len(best_model_path) > 0 + ): last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) self.save(self.lightning_module.state_dict(), last_path) @@ -241,7 +244,7 @@ def post_dispatch(self) -> None: # todo, pass also bets score # load last weights - if last_path and not self.lightning_module.trainer.testing: + if last_path and model.trainer.state == TrainerState.FITTING: ckpt = torch.load(last_path, map_location=lambda storage, loc: storage) model.load_state_dict(ckpt) @@ -254,8 +257,7 @@ def __load_weights_on_main_process(self) -> None: model = self.lightning_module # load weights if not interrupted - # TODO: check for trainer reference - if on_colab_kaggle() and not model.trainer.testing: + if on_colab_kaggle() and model.trainer.state == TrainerState.FITTING: self.load_spawn_weights(model) self._model = model @@ -279,7 +281,7 @@ def start_training(self, trainer) -> None: self._close_logger(trainer) xmp.spawn(self.new_process, **self.xmp_spawn_kwargs) - def start_testing(self, trainer) -> None: + def start_evaluating(self, trainer) -> None: self._close_logger(trainer) xmp.spawn(self.new_process, **self.xmp_spawn_kwargs) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index de799b394fe69..7783f066dbc61 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -121,9 +121,9 @@ def start_training(self, trainer: 'Trainer') -> None: # double dispatch to initiate the training loop self._results = trainer.run_train() - def start_testing(self, trainer: 'Trainer') -> None: + def start_evaluating(self, trainer: 'Trainer') -> None: # double dispatch to initiate the test loop - self._results = trainer.run_test() + self._results = trainer.run_evaluate() def start_predicting(self, trainer: 'Trainer') -> None: # double dispatch to initiate the predicting loop diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 9cb22f39b7228..1bf38048ee159 100644 --- 
a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -24,18 +24,16 @@ def __init__(self, trainer): def verify_loop_configurations(self, model: LightningModule): r""" - Checks that the model is configured correctly before training or testing is started. + Checks that the model is configured correctly before the run is started. Args: model: The model to check the configuration. """ - if not self.trainer.testing: + if self.trainer.training: self.__verify_train_loop_configuration(model) - self.__verify_eval_loop_configuration(model, 'validation') - else: - # check test loop configuration - self.__verify_eval_loop_configuration(model, 'test') + elif self.trainer.evaluating: + self.__verify_eval_loop_configuration(model) def __verify_train_loop_configuration(self, model): # ----------------------------------- @@ -83,18 +81,16 @@ def __verify_train_loop_configuration(self, model): ' It ensures optimizer_step or optimizer_zero_grad are called on every batch.' ) - def __verify_eval_loop_configuration(self, model, eval_loop_name): - step_name = f'{eval_loop_name}_step' + def __verify_eval_loop_configuration(self, model): + stage = "val" if self.trainer.validating else "test" - # map the dataloader name - loader_name = f'{eval_loop_name}_dataloader' - if eval_loop_name == 'validation': - loader_name = 'val_dataloader' + loader_name = f'{stage}_dataloader' + step_name = f'{stage}_step' has_loader = is_overridden(loader_name, model) has_step = is_overridden(step_name, model) if has_loader and not has_step: - rank_zero_warn(f'you passed in a {loader_name} but have no {step_name}. Skipping {eval_loop_name} loop') + rank_zero_warn(f'you passed in a {loader_name} but have no {step_name}. Skipping {stage} loop') if has_step and not has_loader: - rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {eval_loop_name} loop') + rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. 
Skipping {stage} loop') diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 97484e5f473fd..9e08cf031175f 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -113,8 +113,7 @@ def attach_dataloaders( model.predict_dataloader = _PatchDataLoader(predict_dataloaders) def attach_datamodule(self, model, datamodule: Optional[LightningDataModule]) -> None: - - # We use datamodule if it's been provided on .fit or .test, otherwise we check model for it + # We use datamodule if it's been provided, otherwise we check model for it datamodule = datamodule or getattr(model, 'datamodule', None) # If we have a datamodule, attach necessary hooks + dataloaders diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index a547144c8a6f3..223216846758f 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -17,7 +17,7 @@ import torch from pytorch_lightning.core.step_result import Result -from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import DistributedType, LightningEnum @@ -222,9 +222,8 @@ class EpochResultStore: ``` """ - def __init__(self, trainer, stage): + def __init__(self, trainer) -> None: self.trainer = trainer - self._stage = stage self.reset() def __getitem__(self, key: str) -> Any: @@ -309,7 +308,6 @@ def update_logger_connector(self) -> Tuple[Dict, Dict]: callback_metrics = {} batch_pbar_metrics = {} batch_log_metrics = {} - is_train = self._stage in RunningStage.TRAINING if not self._has_batch_loop_finished: # get pbar @@ -317,8 +315,7 @@ def update_logger_connector(self) -> Tuple[Dict, Dict]: logger_connector.add_progress_bar_metrics(batch_pbar_metrics) batch_log_metrics = self.get_latest_batch_log_metrics() - if is_train: - # Only log and add to callback epoch step during evaluation, test. 
+ if self.trainer.training: logger_connector._logged_metrics.update(batch_log_metrics) callback_metrics.update(batch_pbar_metrics) callback_metrics.update(batch_log_metrics) @@ -339,7 +336,9 @@ def update_logger_connector(self) -> Tuple[Dict, Dict]: callback_metrics.update(epoch_log_metrics) callback_metrics.update(forked_metrics) - if not is_train and self.trainer.testing: + # TODO(carmocca): when we implement flushing the logger connector metrics after + # the trainer.state changes, this should check trainer.evaluating instead + if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING): logger_connector.evaluation_callback_metrics.update(callback_metrics) # update callback_metrics @@ -484,4 +483,4 @@ def __call__( return result def __repr__(self): - return f"{self.__class__.__name__}(stage={self._stage}, internals={self._internals})" + return f"{self.__class__.__name__}(internals={self._internals})" diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 45cdecfdc8515..2c6a0d613e648 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -24,7 +24,7 @@ from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder -from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.trainer.states import RunningStage, TrainerState from pytorch_lightning.utilities import DeviceType, flatten_dict from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -40,8 +40,8 @@ def __init__(self, trainer, log_gpu_memory: Optional[str] = None): self._logged_metrics = MetricsHolder() self._progress_bar_metrics = MetricsHolder(to_float=True) self.eval_loop_results = [] - self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in RunningStage} - self._cached_results[None] = EpochResultStore(trainer, None) + self._cached_results = {stage: EpochResultStore(trainer) for stage in RunningStage} + self._cached_results[None] = EpochResultStore(trainer) self._callback_hook_validator = CallbackHookNameValidator() @property @@ -287,7 +287,7 @@ def prepare_eval_loop_results(self): self.add_to_eval_loop_results(dl_idx, has_been_initialized) def get_evaluate_epoch_results(self): - if not self.trainer.running_sanity_check: + if not self.trainer.sanity_checking: # log all the metrics as a single dict metrics_to_log = self.cached_results.get_epoch_log_metrics() if len(metrics_to_log) > 0: @@ -295,11 +295,16 @@ def get_evaluate_epoch_results(self): self.prepare_eval_loop_results() - # log results of test - if self.trainer.testing and self.trainer.is_global_zero and self.trainer.verbose_test: + # log results of evaluation + if ( + self.trainer.state != TrainerState.FITTING + and self.trainer.evaluating + and self.trainer.is_global_zero + and self.trainer.verbose_evaluate + ): print('-' * 80) for result_idx, results in enumerate(self.eval_loop_results): - print(f'DATALOADER:{result_idx} TEST RESULTS') + print(f'DATALOADER:{result_idx} {self.trainer._running_stage.upper()} RESULTS') pprint({ k: (v.item() if v.numel() == 1 
else v.tolist()) if isinstance(v, torch.Tensor) else v for k, v in results.items() @@ -330,7 +335,7 @@ def _track_callback_metrics(self, eval_results): flat['checkpoint_on'] = flat['val_loss'] flat['early_stop_on'] = flat['val_loss'] self.trainer.logger_connector.callback_metrics.update(flat) - if self.trainer.testing: + if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING): self.trainer.logger_connector.evaluation_callback_metrics.update(flat) else: # with a scalar return, auto set it to "val_loss" for callbacks @@ -345,7 +350,7 @@ def _track_callback_metrics(self, eval_results): flat['early_stop_on'] = flat['val_loss'] self.trainer.logger_connector.callback_metrics.update(flat) - if self.trainer.testing: + if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING): self.trainer.logger_connector.evaluation_callback_metrics.update(flat) def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metrics, log_metrics, callback_metrics): @@ -363,14 +368,14 @@ def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metric callback_metrics.update(log_metrics) callback_metrics.update(prog_bar_metrics) self.trainer.logger_connector.callback_metrics.update(callback_metrics) - if self.trainer.testing: + if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING): self.trainer.logger_connector.evaluation_callback_metrics.update(callback_metrics) if len(dataloader_result_metrics) > 0: self.eval_loop_results.append(dataloader_result_metrics) def __process_eval_epoch_end_results_and_log_legacy(self, eval_results): - if self.trainer.running_sanity_check: + if self.trainer.sanity_checking: return if eval_results is not None and len(eval_results) > 0: diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index 4a0c565d78be0..cdaab6248f006 100644 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -35,5 +35,4 @@ def copy_trainer_model_properties(self, model): m._device_type = str(self.trainer._device_type) m._distrib_type = str(self.trainer._distrib_type) m.use_amp = self.trainer.amp_backend is not None - m.testing = self.trainer.testing m.precision = self.trainer.precision diff --git a/pytorch_lightning/trainer/connectors/optimizer_connector.py b/pytorch_lightning/trainer/connectors/optimizer_connector.py index 1a1a992758dc8..a50603bb58dbf 100644 --- a/pytorch_lightning/trainer/connectors/optimizer_connector.py +++ b/pytorch_lightning/trainer/connectors/optimizer_connector.py @@ -51,7 +51,7 @@ def update_learning_rates(self, interval: str, monitor_metrics=None): ) if monitor_val is None: if lr_scheduler.get('strict', True): - avail_metrics = self.trainer.logger_connector.callback_metrics.keys() + avail_metrics = list(self.trainer.logger_connector.callback_metrics.keys()) raise MisconfigurationException( f'ReduceLROnPlateau conditioned on metric {monitor_key}' f' which is not available. Available metrics are: {avail_metrics}.' 
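
The hunks above replace checks on `trainer.running_sanity_check` and `trainer.testing` with the new `trainer.sanity_checking`, `trainer.evaluating`, and `trainer.state` accessors. A minimal sketch of how user callback code is expected to guard on these after the refactor — the callback class itself is hypothetical, only the flags and `TrainerState` values come from this diff:

```python
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer.states import TrainerState


class EvaluationMetricsPrinter(Callback):
    """Hypothetical callback illustrating the new state checks."""

    def on_validation_end(self, trainer, pl_module):
        # mirror the updated EarlyStopping/ModelCheckpoint guard: only act
        # while fitting, and never during the sanity check
        if trainer.state != TrainerState.FITTING or trainer.sanity_checking:
            return
        # `trainer.sanity_checking` replaces the deprecated
        # `trainer.running_sanity_check`; `trainer.evaluating` is True for
        # both validation and test, so use TrainerState to tell them apart
        print(trainer.callback_metrics)
```
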
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 06a3da750032c..95bd8b3f8cc44 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -372,8 +372,7 @@ def reset_test_dataloader(self, model) -> None: has_loader = is_overridden('test_dataloader', model) has_step = is_overridden('test_step', model) if has_loader and has_step: - self.num_test_batches, self.test_dataloaders =\ - self._reset_eval_dataloader(model, 'test') + self.num_test_batches, self.test_dataloaders = self._reset_eval_dataloader(model, 'test') def reset_predict_dataloader(self, model) -> None: """Resets the predict dataloader and determines the number of batches. @@ -383,8 +382,7 @@ def reset_predict_dataloader(self, model) -> None: """ has_loader = is_overridden('predict_dataloader', model) if has_loader: - self.num_predict_batches, self.predict_dataloaders =\ - self._reset_eval_dataloader(model, 'predict') + self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader(model, 'predict') def request_dataloader(self, dataloader_fx: Callable) -> DataLoader: """Handles downloading data in the GPU or TPU case. diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index 46cfc545c889d..70db8b36814ca 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -14,7 +14,6 @@ from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector -from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn @@ -22,7 +21,6 @@ class DeprecatedDistDeviceAttributes: _distrib_type: DistributedType _device_type: DeviceType - _running_stage: RunningStage num_gpus: int accelerator_connector: AcceleratorConnector @@ -138,6 +136,7 @@ class DeprecatedTrainerAttributes: accelerator: Accelerator lightning_module = LightningModule + sanity_checking: bool @property def accelerator_backend(self) -> Accelerator: @@ -153,3 +152,11 @@ def get_model(self) -> LightningModule: " and will be removed in v1.4.", DeprecationWarning ) return self.lightning_module + + @property + def running_sanity_check(self) -> bool: + rank_zero_warn( + "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking`" + " and will be removed in v1.5.", DeprecationWarning + ) + return self.sanity_checking diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 24c8be9dc9b37..d5047ce57858a 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -32,21 +32,20 @@ def __init__(self, trainer): self.num_dataloaders = None def on_trainer_init(self): - self.trainer.num_val_batches = [] self.trainer.num_sanity_val_batches = [] self.trainer.num_test_batches = [] + self.trainer.num_val_batches = [] self.trainer.test_dataloaders = None self.trainer.val_dataloaders = None - self.trainer.running_sanity_check = False - # when .test() is called, it sets this + # .validate() and .test() set this when they load a checkpoint + self.trainer.validated_ckpt_path = None self.trainer.tested_ckpt_path = None - # when true, prints test results - self.trainer.verbose_test = True + # when true, print evaluation results in .validate() and .test() + 
self.trainer.verbose_evaluate = True - def get_evaluation_dataloaders(self, max_batches): - # select dataloaders + def get_evaluation_dataloaders(self): model = self.trainer.lightning_module # select dataloaders @@ -54,20 +53,20 @@ def get_evaluation_dataloaders(self, max_batches): self.trainer.reset_test_dataloader(model) dataloaders = self.trainer.test_dataloaders - new_max_batches = self.trainer.num_test_batches + max_batches = self.trainer.num_test_batches else: # val - in_sanity_check = self.trainer.running_sanity_check - should_reload_every_epoch = self.trainer.reload_dataloaders_every_epoch - if (self.trainer.val_dataloaders is None or should_reload_every_epoch) and not in_sanity_check: + if self.trainer.val_dataloaders is None or self.trainer.reload_dataloaders_every_epoch: self.trainer.reset_val_dataloader(model) - + if self.trainer.sanity_checking: + self.trainer.num_sanity_val_batches = [ + min(self.trainer.num_sanity_val_steps, val_batches) + for val_batches in self.trainer.num_val_batches + ] + max_batches = self.trainer.num_sanity_val_batches + else: + max_batches = self.trainer.num_val_batches dataloaders = self.trainer.val_dataloaders - new_max_batches = self.trainer.num_val_batches - - if max_batches is None: - max_batches = new_max_batches - return dataloaders, max_batches def should_skip_evaluation(self, max_batches): @@ -154,7 +153,7 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx): model_ref = self.trainer.lightning_module model_ref._results = Result() - if self.testing: + if self.trainer.testing: model_ref._current_fx_name = "test_step" with self.trainer.profiler.profile("test_step"): output = self.trainer.accelerator.test_step(args) @@ -323,7 +322,7 @@ def on_evaluation_epoch_end(self, *args, **kwargs): self.trainer.call_hook('on_epoch_end') def log_evaluation_step_metrics(self, output, batch_idx): - if self.trainer.running_sanity_check: + if self.trainer.sanity_checking: return step_log_metrics = {} diff --git a/pytorch_lightning/trainer/predict_loop.py b/pytorch_lightning/trainer/predict_loop.py index 6b801cc7f5dea..40507a1bc03f4 100644 --- a/pytorch_lightning/trainer/predict_loop.py +++ b/pytorch_lightning/trainer/predict_loop.py @@ -27,9 +27,8 @@ def on_trainer_init(self): self.trainer.num_predict_batches = [] def get_predict_dataloaders(self, max_batches): - # select dataloaders - model = self.trainer.lightning_module - self.trainer.reset_predict_dataloader(model) + self.trainer.reset_predict_dataloader(self.trainer.lightning_module) + dataloaders = self.trainer.predict_dataloaders if max_batches is None: max_batches = self.trainer.num_predict_batches diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index c061c6ef28d4c..8cbd53d93f37f 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -31,7 +31,7 @@ from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector -from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.trainer.states import RunningStage, TrainerState from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn from pytorch_lightning.utilities.argparse import ( add_argparse_args, @@ -48,6 +48,7 @@ class TrainerProperties(ABC): _default_root_dir: str _lightning_optimizers = None 
_progress_bar_callback: ProgressBarBase + _running_stage: Optional[RunningStage] = None _state: TrainerState _weights_save_path: str @@ -168,6 +169,14 @@ def progress_bar_metrics(self, x: dict) -> None: def state(self) -> TrainerState: return self._state + @state.setter + def state(self, state: TrainerState) -> None: + self._state = state + + @property + def interrupted(self) -> bool: + return self._state == TrainerState.INTERRUPTED + @property def is_global_zero(self) -> bool: return self.global_rank == 0 @@ -412,6 +421,76 @@ def distributed_sampler_kwargs(self) -> Optional[dict]: if isinstance(self.training_type_plugin, ParallelPlugin): return self.training_type_plugin.distributed_sampler_kwargs + @property + def training(self) -> bool: + return self._running_stage == RunningStage.TRAINING + + @training.setter + def training(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.TRAINING + elif self.training: + self._running_stage = None + + @property + def testing(self) -> bool: + return self._running_stage == RunningStage.TESTING + + @testing.setter + def testing(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.TESTING + elif self.testing: + self._running_stage = None + + @property + def predicting(self) -> bool: + return self._running_stage == RunningStage.PREDICTING + + @predicting.setter + def predicting(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.PREDICTING + elif self.predicting: + self._running_stage = None + + @property + def tuning(self) -> bool: + return self._running_stage == RunningStage.TUNING + + @tuning.setter + def tuning(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.TUNING + elif self.tuning: + self._running_stage = None + + @property + def validating(self) -> bool: + return self._running_stage == RunningStage.VALIDATING + + @validating.setter + def validating(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.VALIDATING + elif self.validating: + self._running_stage = None + + @property + def evaluating(self) -> bool: + return self._running_stage and self._running_stage.evaluating + + @property + def sanity_checking(self) -> bool: + return self._running_stage == RunningStage.SANITY_CHECKING + + @sanity_checking.setter + def sanity_checking(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.SANITY_CHECKING + elif self.sanity_checking: + self._running_stage = None + # Used to represent the concrete type TrainerProperties class methods are called on. _T = TypeVar('_T', bound=TrainerProperties) diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index 1758cb41ee780..d0c2ded659f67 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -12,72 +12,57 @@ # See the License for the specific language governing permissions and # limitations under the License. -from functools import wraps -from typing import Callable, Optional - -import pytorch_lightning from pytorch_lightning.utilities import LightningEnum class TrainerState(LightningEnum): - """ State which is set in the :class:`~pytorch_lightning.trainer.trainer.Trainer` - to indicate what is currently or was executed. + """ State for the :class:`~pytorch_lightning.trainer.trainer.Trainer` + to indicate what is currently or was executed. It follows the user-called + functions such as `trainer.fit()` and `trainer.test(). 
>>> # you can compare the type with a string - >>> TrainerState.RUNNING == 'RUNNING' + >>> TrainerState.FITTING == 'FITTING' True >>> # which is case insensitive >>> TrainerState.FINISHED == 'finished' True """ - INITIALIZING = 'INITIALIZING' - RUNNING = 'RUNNING' + INITIALIZING = 'INITIALIZING' # trainer creation + FITTING = 'FITTING' # trainer.fit() + VALIDATING = 'VALIDATING' # trainer.validate() + TESTING = 'TESTING' # trainer.test() + PREDICTING = 'PREDICTING' # trainer.predict() + TUNING = 'TUNING' # trainer.tune() FINISHED = 'FINISHED' INTERRUPTED = 'INTERRUPTED' + @property + def stopped(self) -> bool: + return self in (self.FINISHED, self.INTERRUPTED) + + @property + def running(self) -> bool: + return self in (self.FITTING, self.VALIDATING, self.TESTING, self.PREDICTING, self.TUNING) + class RunningStage(LightningEnum): - """Type of train phase. + """Current running stage. + + This stage complements :class:`TrainerState` for example to indicate that + `RunningStage.VALIDATING` will be set both during `TrainerState.FITTING` + and `TrainerState.VALIDATING`. It follows the internal code logic. >>> # you can match the Enum with string >>> RunningStage.TRAINING == 'train' True """ TRAINING = 'train' - EVALUATING = 'eval' + SANITY_CHECKING = 'sanity_check' + VALIDATING = 'validation' TESTING = 'test' PREDICTING = 'predict' TUNING = 'tune' - -def trainer_state(*, entering: Optional[TrainerState] = None, exiting: Optional[TrainerState] = None) -> Callable: - """ Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods - which changes state to `entering` before the function execution and `exiting` - after the function is executed. If `None` is passed to `entering`, the state is not changed. - If `None` is passed to `exiting`, the state is restored to the state before function execution. - If `INTERRUPTED` state is set inside a run function, the state remains `INTERRUPTED`. - """ - - def wrapper(fn) -> Callable: - - @wraps(fn) - def wrapped_fn(self, *args, **kwargs): - if not isinstance(self, pytorch_lightning.Trainer): - return fn(self, *args, **kwargs) - - state_before = self._state - if entering is not None: - self._state = entering - result = fn(self, *args, **kwargs) - - # The INTERRUPTED state can be set inside the run function. 
To indicate that run was interrupted - # we retain INTERRUPTED state - if self._state == TrainerState.INTERRUPTED: - return result - - self._state = exiting if exiting is not None else state_before - return result - - return wrapped_fn - - return wrapper + @property + def evaluating(self) -> bool: + return self in (self.VALIDATING, self.TESTING) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 73e830f839dbe..cc1964f07039b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -16,6 +16,7 @@ import warnings from itertools import count from pathlib import Path +from traceback import print_exc from typing import Any, Dict, Iterable, List, Optional, Union import torch @@ -52,7 +53,7 @@ from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin from pytorch_lightning.trainer.predict_loop import PredictLoop from pytorch_lightning.trainer.properties import TrainerProperties -from pytorch_lightning.trainer.states import RunningStage, TrainerState +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.tuning import Tuner @@ -290,7 +291,6 @@ def __init__( """ super().__init__() - self._running_stage = None distributed_backend = distributed_backend or accelerator @@ -419,13 +419,11 @@ def fit( If the model has a predefined val_dataloaders method this will be skipped """ - # bookkeeping - self._state = TrainerState.RUNNING - - # bookkeeping - # we reuse fit in .test() and .predict(). When already set, it shouldn't be modified. + # we reuse fit for other functions. When already set, it shouldn't be modified. + if not self.state.running: + self.state = TrainerState.FITTING if self._running_stage is None: - self._running_stage = RunningStage.TRAINING + self.training = True # set local properties on the model self.model_connector.copy_trainer_model_properties(model) @@ -459,21 +457,17 @@ def fit( # | || # trainer.dispatch || LIGHTNING # | || - # start_training or start_testing or start_predicting call || FLOW + # start_training or start_evaluating or start_predicting call || FLOW # from `accelerator` || # | || DIRECTION - # run_train or run_test or run_predict call || + # run_train or run_evaluate or run_predict call || # from `trainer` || # | || # results \/ # This is used to guide readers to the core loops: train, test, predict. # `run_predict` is the simplest to understand, use `Go to Definition` to read it :) - # Search for `start_training` or `start_testing` or `start_predicting` in + # Search for `start_training` or `start_evaluating` or `start_predicting` in # `pytorch_lightning/plugins/training_type` folder to find accelerator dispatch functions. - self.accelerator.train_loop = self.run_train - self.accelerator.validation_loop = self.run_evaluation - self.accelerator.test_loop = self.run_evaluation - self.accelerator.predict_loop = self.run_predict # ---------------------------- # TRAIN @@ -484,7 +478,7 @@ def fit( # plugin will setup fitting (e.g. ddp will launch child processes) self.pre_dispatch() - # dispath `start_training` or `start_testing` or `start_predicting` + # dispatch `start_training` or `start_evaluating` or `start_predicting` self.dispatch() # plugin will finalized fitting (e.g. 
ddp_spawn will load trained model) @@ -501,13 +495,12 @@ def fit( if self.is_function_implemented('teardown'): model.teardown('fit') - # return 1 when finished - # used for testing or when we need to know that training succeeded - if self._state != TrainerState.INTERRUPTED: - self._state = TrainerState.FINISHED - + if self.state != TrainerState.INTERRUPTED: + self.state = TrainerState.FINISHED self._running_stage = None + # return 1 when finished + # used for testing or when we need to know that training succeeded return self.accelerator.results or 1 def pre_dispatch(self): @@ -518,25 +511,21 @@ def post_dispatch(self): self.accelerator.teardown() def dispatch(self): - if self.testing: - self.accelerator.start_testing(self) - + if self.evaluating: + self.accelerator.start_evaluating(self) elif self.predicting: self.accelerator.start_predicting(self) - else: self.accelerator.start_training(self) - def train_or_test_or_predict(self): - if self.testing: - results = self.run_test() - + def run_stage(self): + results = None + if self.evaluating: + results = self.run_evaluate() elif self.predicting: results = self.run_predict() - else: - results = self.run_train() - + self.run_train() return results def _pre_training_routine(self): @@ -571,7 +560,7 @@ def _pre_training_routine(self): if self.is_function_implemented("on_pretrain_routine_end"): ref_model.on_pretrain_routine_end() - def run_train(self): + def run_train(self) -> None: self._pre_training_routine() @@ -580,9 +569,6 @@ def run_train(self): self.run_sanity_check(self.lightning_module) - # set stage for logging - self._running_stage = RunningStage.TRAINING - self.checkpoint_connector.has_trained = False # enable train mode @@ -632,27 +618,32 @@ def run_train(self): except KeyboardInterrupt: rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...') - - # user could press ctrl+c many times... only shutdown once + # user could press Ctrl+c many times... only shutdown once if not self.interrupted: - self.interrupted = True - self._state = TrainerState.INTERRUPTED + self.state = TrainerState.INTERRUPTED self.on_keyboard_interrupt() + except (RuntimeError, AssertionError): + # if an exception is raised, the finally block is executed and can hide the actual exception + # that was initially raised if `on_train_end` also raises an exception. we want to avoid that + # for assertions and other runtime errors so we aren't misled while debugging + print_exc() finally: # hook self.train_loop.on_train_end() - def run_evaluation(self, max_batches=None, on_epoch=False): + def run_evaluation(self, on_epoch=False): + if not (self.evaluating or self.sanity_checking): + rank_zero_warn( + f"`trainer.run_evaluation()` was called but the running stage is set to {self._running_stage}." + " This should not happen normally. 
Setting it to `RunningStage.VALIDATING`", RuntimeWarning + ) + self.validating = True - # used to know if we are logging for val, test + reset cached results - self._running_stage = RunningStage.TESTING if self.testing else RunningStage.EVALUATING + # reset cached results self.logger_connector.reset() - # bookkeeping - self.evaluation_loop.testing = self.testing - # prepare dataloaders - dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(max_batches) + dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders() # check if we want to skip this evaluation if self.evaluation_loop.should_skip_evaluation(max_batches): @@ -748,13 +739,13 @@ def track_output_for_epoch_end(self, outputs, output): outputs.append(output) return outputs - def run_test(self): + def run_evaluate(self): if not self.is_global_zero and self.progress_bar_callback is not None: self.progress_bar_callback.disable() - # only load test dataloader for testing - # self.reset_test_dataloader(ref_model) - with self.profiler.profile("run_test_evaluation"): + assert self.evaluating + + with self.profiler.profile(f"run_{self._running_stage}_evaluation"): eval_loop_results, _ = self.run_evaluation() if len(eval_loop_results) == 0: @@ -815,17 +806,14 @@ def run_sanity_check(self, ref_model): # run tiny validation (if validation defined) # to make sure program won't crash during val if should_sanity_check: - self.reset_val_dataloader(ref_model) - self.num_sanity_val_batches = [ - min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches - ] + stage = self._running_stage + self.sanity_checking = True # hook and callback - self.running_sanity_check = True self.on_sanity_check_start() # run eval step - _, eval_results = self.run_evaluation(max_batches=self.num_sanity_val_batches) + _, eval_results = self.run_evaluation() # allow no returns from eval if eval_results is not None and len(eval_results) > 0: @@ -837,7 +825,8 @@ def run_sanity_check(self, ref_model): self.logger_connector.callback_metrics = callback_metrics self.on_sanity_check_end() - self.running_sanity_check = False + + self._running_stage = stage def test( self, @@ -848,21 +837,20 @@ def test( datamodule: Optional[LightningDataModule] = None, ): r""" - - Separates from fit to make sure you never run on your test set until you want to. + Perform one evaluation epoch over the test set. It's separated from + fit to make sure you never run on your test set until you want to. Args: ckpt_path: Either ``best`` or path to the checkpoint you wish to test. - If ``None``, use the weights from the last epoch to test. Default to ``best``. - + If ``None``, use the current weights of the model. Default to ``best``. datamodule: A instance of :class:`LightningDataModule`. model: The model to test. - test_dataloaders: Either a single - Pytorch Dataloader or a list of them, specifying validation samples. + test_dataloaders: Either a single PyTorch DataLoader or a list of them, + specifying test samples. - verbose: If True, prints the test results + verbose: If True, prints the test results. Returns: Returns a list of dictionaries, one for each test dataloader containing their respective metrics. 
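
Given the updated `test()` semantics described in this docstring (one evaluation epoch over the test set, `ckpt_path='best'` by default, `ckpt_path=None` for the current weights), a minimal usage sketch might look like the following; `model`, `dm`, and the Trainer arguments are placeholders, not part of the diff:

```python
import pytorch_lightning as pl

# `model` is assumed to be a LightningModule and `dm` a LightningDataModule
trainer = pl.Trainer(max_epochs=3)
trainer.fit(model, datamodule=dm)

# evaluate the best checkpoint written by ModelCheckpoint (the default)
best_results = trainer.test(ckpt_path='best', datamodule=dm)

# or evaluate the current in-memory weights of a given model
current_results = trainer.test(model, datamodule=dm, ckpt_path=None)

# the trainer records the checkpoint path it loaded for the run
# (stays None when the in-memory weights were used)
print(trainer.tested_ckpt_path)
```
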
@@ -870,35 +858,45 @@ def test( # -------------------- # SETUP HOOK # -------------------- - self.verbose_test = verbose + self.verbose_evaluate = verbose - self._running_stage = RunningStage.TESTING + self.state = TrainerState.TESTING + self.testing = True - # If you supply a datamodule you can't supply train_dataloader or val_dataloaders + # If you supply a datamodule you can't supply test_dataloaders if test_dataloaders and datamodule: raise MisconfigurationException( 'You cannot pass test_dataloaders to trainer.test if you supply a datamodule' ) - # Attach datamodule to get setup/prepare_data added to model before the call to it below - self.data_connector.attach_datamodule(model or self.lightning_module, datamodule) + model_provided = model is not None + model = model or self.lightning_module - if model is not None: - results = self.__test_given_model(model, test_dataloaders) - else: - results = self.__test_using_best_weights(ckpt_path, test_dataloaders) + # Attach datamodule to get setup/prepare_data added to model before the call to it below + self.data_connector.attach_datamodule(model, datamodule) + results = ( + self.__evaluate_given_model(model, dataloaders=test_dataloaders) + if model_provided else + self.__evaluate_using_weights(model, ckpt_path=ckpt_path, dataloaders=test_dataloaders) + ) self.teardown('test') - self._running_stage = None - return results - def __test_using_best_weights(self, ckpt_path, test_dataloaders): - model = self.lightning_module + assert self.state.stopped + self.testing = False + return results + + def __evaluate_using_weights( + self, + model, + ckpt_path: Optional[str] = None, + dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None + ): # if user requests the best checkpoint but we don't have it, error if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path: raise MisconfigurationException( - 'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.' + 'ckpt_path is "best", but `ModelCheckpoint` is not configured to save the best model.' ) # load best weights @@ -909,8 +907,8 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): if len(ckpt_path) == 0: rank_zero_warn( - f'.test() found no path for the best weights, {ckpt_path}. Please ' - f'specify a path for a checkpoint .test(ckpt_path=PATH)' + f'`.test()` found no path for the best weights, {ckpt_path}. 
Please' + ' specify a path for a checkpoint `.test(ckpt_path=PATH)`' ) return {} @@ -920,32 +918,34 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): model.load_state_dict(ckpt['state_dict']) # attach dataloaders - if test_dataloaders is not None: - self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) + if dataloaders is not None: + self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders) + + if self.validating: + self.validated_ckpt_path = ckpt_path + else: + self.tested_ckpt_path = ckpt_path - # run tests - self.tested_ckpt_path = ckpt_path + # run test results = self.fit(model) # teardown - if self.is_function_implemented('teardown'): - model_ref = self.lightning_module - model_ref.teardown('test') + if self.is_function_implemented('teardown', model=model): + model.teardown('test') return results - def __test_given_model(self, model, test_dataloaders): - + def __evaluate_given_model(self, model, dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None): # attach data - if test_dataloaders is not None: - self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) + if dataloaders is not None: + self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders) # run test # sets up testing so we short circuit to eval results = self.fit(model) # teardown - if self.is_function_implemented('teardown'): + if self.is_function_implemented('teardown', model=model): model.teardown('test') return results @@ -981,7 +981,8 @@ def predict( model = model or self.lightning_module - self._running_stage = RunningStage.PREDICTING + self.state = TrainerState.PREDICTING + self.predicting = True if dataloaders and datamodule: raise MisconfigurationException( @@ -998,7 +999,9 @@ def predict( self.model = model results = self.fit(model) - self._running_stage = None + + assert self.state.stopped + self.predicting = False return results @@ -1024,15 +1027,23 @@ def tune( If the model has a predefined val_dataloaders method this will be skipped """ + self.state = TrainerState.TUNING + self.tuning = True + self.tuner.tune(model, train_dataloader, val_dataloaders, datamodule) + assert self.state.stopped + self.tuning = False + def call_setup_hook(self, model): # call setup after the ddp process has connected - stage_name = 'test' if self.testing else 'fit' + stage_name = 'test' if self.evaluating else 'fit' + if self.datamodule is not None: - called = self.datamodule.has_setup_test if self.testing else self.datamodule.has_setup_fit + called = getattr(self.datamodule, f'has_setup_{stage_name}') if not called: self.datamodule.setup(stage_name) + self.setup(model, stage_name) model.setup(stage_name) @@ -1081,58 +1092,3 @@ def call_hook(self, hook_name, *args, **kwargs): if not skip: self._cache_logged_metrics() return output - - @property - def training(self) -> bool: - return self._running_stage == RunningStage.TRAINING - - @training.setter - def training(self, val: bool) -> None: - if val: - self._running_stage = RunningStage.TRAINING - elif self.training: - self._running_stage = None - - @property - def testing(self) -> bool: - return self._running_stage == RunningStage.TESTING - - @testing.setter - def testing(self, val: bool) -> None: - if val: - self._running_stage = RunningStage.TESTING - elif self.testing: - self._running_stage = None - - @property - def predicting(self) -> bool: - return self._running_stage == RunningStage.PREDICTING - - @predicting.setter - def predicting(self, val: bool) -> None: - if 
val: - self._running_stage = RunningStage.PREDICTING - elif self.predicting: - self._running_stage = None - - @property - def tuning(self) -> bool: - return self._running_stage == RunningStage.TUNING - - @tuning.setter - def tuning(self, val: bool) -> None: - if val: - self._running_stage = RunningStage.TUNING - elif self.tuning: - self._running_stage = None - - @property - def evaluating(self) -> bool: - return self._running_stage == RunningStage.EVALUATING - - @evaluating.setter - def evaluating(self, val: bool) -> None: - if val: - self._running_stage = RunningStage.EVALUATING - elif self.evaluating: - self._running_stage = None diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 97814bb912fbe..57ae9557139e0 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -23,7 +23,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.step_result import Result from pytorch_lightning.plugins import ParallelPlugin -from pytorch_lightning.trainer.states import RunningStage, TrainerState +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.supporters import Accumulator, TensorRunningAccum from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType, parsing from pytorch_lightning.utilities.distributed import rank_zero_info @@ -62,7 +62,6 @@ def on_trainer_init( ): self.trainer.global_step = 0 self.trainer.current_epoch = 0 - self.trainer.interrupted = False self.trainer.should_stop = False self.trainer._state = TrainerState.INITIALIZING @@ -123,7 +122,6 @@ def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule): def on_train_end(self): if self._teardown_already_run: return - self._teardown_already_run = True # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates @@ -148,6 +146,9 @@ def on_train_end(self): # give accelerators a chance to finish self.trainer.accelerator.on_train_end() + # reset bookkeeping + self.trainer._running_stage = None + def check_checkpoint_callback(self, should_update, is_last=False): # TODO bake this logic into the ModelCheckpoint callback if should_update and self.trainer.checkpoint_connector.has_trained: @@ -477,7 +478,6 @@ def run_training_epoch(self): train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) dataloader_idx = 0 - should_check_val = False val_loop_called = False for batch_idx, (batch, is_last_batch) in train_dataloader: @@ -513,12 +513,11 @@ def run_training_epoch(self): # ----------------------------------------- should_check_val = self.should_check_val_fx(batch_idx, is_last_batch) if should_check_val: + self.trainer.validating = True self.trainer.run_evaluation() + self.trainer.training = True val_loop_called = True - # reset stage to train - self.trainer._running_stage = RunningStage.TRAINING - # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) 
# ----------------------------------------- @@ -572,10 +571,9 @@ def run_training_epoch(self): self.check_early_stopping_callback(True) if should_check_val: + self.trainer.validating = True self.trainer.run_evaluation(on_epoch=True) - - # reset stage to train - self.trainer._running_stage = RunningStage.TRAINING + self.trainer.training = True # increment the global step once # progress global step according to grads progress diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 06475547b03f2..78810141b1369 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -18,6 +18,7 @@ from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size from pytorch_lightning.tuner.lr_finder import lr_find @@ -55,6 +56,8 @@ def tune(self, model, train_dataloader, val_dataloaders, datamodule): if self.trainer.auto_lr_find: self.lr_find(model, update_attr=True) + self.trainer.state = TrainerState.FINISHED + def scale_batch_size( self, model, diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py index 65cf4472d156c..56833fd03735a 100644 --- a/pytorch_lightning/utilities/debugging.py +++ b/pytorch_lightning/utilities/debugging.py @@ -139,7 +139,7 @@ def track_lr_schedulers_update( @enabled_only def track_eval_loss_history(self, batch_idx, dataloader_idx, output): loss_dict = { - 'sanity_check': self.trainer.running_sanity_check, + 'sanity_check': self.trainer.sanity_checking, 'dataloader_idx': dataloader_idx, 'batch_idx': batch_idx, 'epoch': self.trainer.current_epoch, diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 8d01841f3636c..8a25ecc9f983b 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -19,7 +19,7 @@ @mock.patch("torch.save") # need to mock torch.save or we get pickle error -def test_trainer_callback_system(torch_save, tmpdir): +def test_trainer_callback_system(_, tmpdir): """Test the callback system.""" model = BoringModel() diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index f16d8afd9cffd..e4171a8520353 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -218,31 +218,32 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal assert progress_bar.test_batches_seen == progress_bar.total_test_batches -@pytest.mark.parametrize(['limit_val_batches', 'expected'], [ - pytest.param(0, 0), - pytest.param(5, 7), -]) -def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches, expected): +@pytest.mark.parametrize('limit_val_batches', (0, 5)) +def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches): """ Test val_progress_bar total with 'num_sanity_val_steps' Trainer argument. 
""" class CurrentProgressBar(ProgressBar): + val_pbar_total = 0 + sanity_pbar_total = 0 - def __init__(self): - super().__init__() - self.val_progress_bar_total = 0 + def on_sanity_check_end(self, *args): + self.sanity_pbar_total = self.val_progress_bar.total + super().on_sanity_check_end(*args) - def on_validation_epoch_end(self, trainer, pl_module): - self.val_progress_bar_total += trainer.progress_bar_callback.val_progress_bar.total + def on_validation_epoch_end(self, *args): + self.val_pbar_total = self.val_progress_bar.total + super().on_validation_epoch_end(*args) model = BoringModel() progress_bar = CurrentProgressBar() + num_sanity_val_steps = 2 trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - num_sanity_val_steps=2, + num_sanity_val_steps=num_sanity_val_steps, limit_train_batches=1, limit_val_batches=limit_val_batches, callbacks=[progress_bar], @@ -250,7 +251,9 @@ def on_validation_epoch_end(self, trainer, pl_module): checkpoint_callback=False, ) trainer.fit(model) - assert trainer.progress_bar_callback.val_progress_bar_total == expected + + assert progress_bar.sanity_pbar_total == min(num_sanity_val_steps, limit_val_batches) + assert progress_bar.val_pbar_total == limit_val_batches def test_progress_bar_default_value(tmpdir): @@ -426,7 +429,7 @@ def test_progress_bar_print(tqdm_write, tmpdir): @mock.patch('builtins.print') @mock.patch("pytorch_lightning.callbacks.progress.tqdm.write") def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): - """ Test that printing in LightningModule goes through built-in print functin when progress bar is disabled. """ + """ Test that printing in LightningModule goes through built-in print function when progress bar is disabled. """ model = PrintModel() bar = ProgressBar() trainer = Trainer( diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index cb1d461414603..6385e02af33a6 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -68,3 +68,9 @@ def on_save_checkpoint(self, *args): trainer.callbacks = [NewSignature(), ValidSignature1(), ValidSignature2()] with no_warning_call(DeprecationWarning): trainer.save_checkpoint(filepath) + + +def test_v1_5_0_running_sanity_check(): + trainer = Trainer() + with pytest.deprecated_call(match='has been renamed to `Trainer.sanity_checking`'): + assert not trainer.running_sanity_check diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index f35fccc8735d6..7a43b2d0832f9 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -450,7 +450,7 @@ def on_train_start(self): # if model and state loaded correctly, predictions will be good even though we # haven't trained with the new loaded model - new_trainer._running_stage = RunningStage.EVALUATING + new_trainer._running_stage = RunningStage.VALIDATING dataloader = self.train_dataloader() tpipes.run_prediction_eval_model_template(self.trainer.lightning_module, dataloader=dataloader) diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py index 08da907d6f16f..5008ec798f7ec 100644 --- a/tests/overrides/test_data_parallel.py +++ b/tests/overrides/test_data_parallel.py @@ -20,7 +20,13 @@ LightningParallelModule, LightningDistributedModule, ]) -def test_lightning_wrapper_module_methods(wrapper_class): +@pytest.mark.parametrize("stage", [ + ("training", "training_step"), + ("testing", "test_step"), + ("validating", "validation_step"), + ("predicting", "predict"), +]) +def 
test_lightning_wrapper_module_methods(wrapper_class, stage): """ Test that the LightningWrapper redirects .forward() to the LightningModule methods. """ pl_module = MagicMock() wrapped_module = wrapper_class(pl_module) @@ -28,52 +34,62 @@ def test_lightning_wrapper_module_methods(wrapper_class): batch = torch.rand(5) batch_idx = 3 - pl_module.running_stage = RunningStage.TRAINING - wrapped_module(batch, batch_idx) - pl_module.training_step.assert_called_with(batch, batch_idx) + prop, step = stage + pl_module.trainer.sanity_checking = False + for p in ("training", "testing", "validating", "predicting"): + setattr(pl_module.trainer, p, p == prop) - pl_module.running_stage = RunningStage.TESTING wrapped_module(batch, batch_idx) - pl_module.test_step.assert_called_with(batch, batch_idx) - pl_module.running_stage = RunningStage.EVALUATING - wrapped_module(batch, batch_idx) - pl_module.validation_step.assert_called_with(batch, batch_idx) - - pl_module.running_stage = RunningStage.PREDICTING - wrapped_module(batch) - pl_module.predict.assert_called_with(batch) + getattr(pl_module, step).assert_called_with(batch, batch_idx) @pytest.mark.parametrize("wrapper_class", [ LightningParallelModule, LightningDistributedModule, ]) -def test_lightning_wrapper_module_warn_none_output(wrapper_class): +@pytest.mark.parametrize("stage", [ + ("training", "training_step"), + ("testing", "test_step"), + ("validating", "validation_step"), +]) +def test_lightning_wrapper_module_warn_none_output(wrapper_class, stage): """ Test that the LightningWrapper module warns about forgotten return statement. """ warning_cache.clear() pl_module = MagicMock() + + prop, step = stage + pl_module.trainer.sanity_checking = False + for p in ("training", "testing", "validating", "predicting"): + setattr(pl_module.trainer, p, p == prop) + wrapped_module = wrapper_class(pl_module) - pl_module.training_step.return_value = None - pl_module.validation_step.return_value = None - pl_module.test_step.return_value = None + getattr(pl_module, step).return_value = None - with pytest.warns(UserWarning, match="Your training_step returned None"): - pl_module.running_stage = RunningStage.TRAINING + with pytest.warns(UserWarning, match=f"Your {step} returned None"): wrapped_module() - with pytest.warns(UserWarning, match="Your test_step returned None"): - pl_module.running_stage = RunningStage.TESTING - wrapped_module() - with pytest.warns(UserWarning, match="Your validation_step returned None"): - pl_module.running_stage = RunningStage.EVALUATING - wrapped_module() +@pytest.mark.parametrize("wrapper_class", [ + LightningParallelModule, + LightningDistributedModule, +]) +def test_lightning_wrapper_module_no_warn(wrapper_class): + warning_cache.clear() + pl_module = MagicMock() + + pl_module.trainer.sanity_checking = False + pl_module.trainer.training = False + pl_module.trainer.testing = False + pl_module.trainer.validating = False + pl_module.trainer.predicting = False + + wrapped_module = wrapper_class(pl_module) with pytest.warns(None) as record: - pl_module.running_stage = None wrapped_module() + pl_module.assert_called() assert not record diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py index 221951e788284..09c5b58d363d9 100644 --- a/tests/trainer/flags/test_fast_dev_run.py +++ b/tests/trainer/flags/test_fast_dev_run.py @@ -20,8 +20,8 @@ def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg): trainer = Trainer( default_root_dir=tmpdir, max_epochs=2, - auto_scale_batch_size=True if tuner_alg == 
'batch size scaler' else False, - auto_lr_find=True if tuner_alg == 'learning rate finder' else False, + auto_scale_batch_size=(tuner_alg == 'batch size scaler'), + auto_lr_find=(tuner_alg == 'learning rate finder'), fast_dev_run=True ) expected_message = f'Skipping {tuner_alg} since fast_dev_run is enabled.' diff --git a/tests/trainer/flags/test_val_check_interval.py b/tests/trainer/flags/test_val_check_interval.py index af9ae06e8445b..7f3e9f6287cd8 100644 --- a/tests/trainer/flags/test_val_check_interval.py +++ b/tests/trainer/flags/test_val_check_interval.py @@ -18,7 +18,8 @@ @pytest.mark.parametrize('max_epochs', [1, 2, 3]) -def test_val_check_interval_1(tmpdir, max_epochs): +@pytest.mark.parametrize('denominator', [1, 3, 4]) +def test_val_check_interval(tmpdir, max_epochs, denominator): class TestModel(BoringModel): @@ -31,71 +32,16 @@ def on_train_epoch_start(self) -> None: self.train_epoch_calls += 1 def on_validation_epoch_start(self) -> None: - if not self.trainer.running_sanity_check: + if not self.trainer.sanity_checking: self.val_epoch_calls += 1 model = TestModel() trainer = Trainer( max_epochs=max_epochs, - val_check_interval=1.0, + val_check_interval=1 / denominator, logger=False, ) trainer.fit(model) - assert model.val_epoch_calls == max_epochs - - -@pytest.mark.parametrize('max_epochs', [1, 2, 3]) -def test_val_check_interval_quarter(tmpdir, max_epochs): - - class TestModel(BoringModel): - - def __init__(self): - super().__init__() - self.train_epoch_calls = 0 - self.val_epoch_calls = 0 - - def on_train_epoch_start(self) -> None: - self.train_epoch_calls += 1 - - def on_validation_epoch_start(self) -> None: - if not self.trainer.running_sanity_check: - self.val_epoch_calls += 1 - - model = TestModel() - trainer = Trainer( - max_epochs=max_epochs, - val_check_interval=0.25, - logger=False, - ) - trainer.fit(model) - - assert model.val_epoch_calls == max_epochs * 4 - - -@pytest.mark.parametrize('max_epochs', [1, 2, 3]) -def test_val_check_interval_third(tmpdir, max_epochs): - - class TestModel(BoringModel): - - def __init__(self): - super().__init__() - self.train_epoch_calls = 0 - self.val_epoch_calls = 0 - - def on_train_epoch_start(self) -> None: - self.train_epoch_calls += 1 - - def on_validation_epoch_start(self) -> None: - if not self.trainer.running_sanity_check: - self.val_epoch_calls += 1 - - model = TestModel() - trainer = Trainer( - max_epochs=max_epochs, - val_check_interval=0.33, - logger=False, - ) - trainer.fit(model) - - assert model.val_epoch_calls == max_epochs * 3 + assert model.train_epoch_calls == max_epochs + assert model.val_epoch_calls == max_epochs * denominator diff --git a/tests/trainer/legacy_deprecate_flow_log/test_eval_loop_dict_return.py b/tests/trainer/legacy_deprecate_flow_log/test_eval_loop_dict_return.py index 87cab653de6aa..2aac7354c38f6 100644 --- a/tests/trainer/legacy_deprecate_flow_log/test_eval_loop_dict_return.py +++ b/tests/trainer/legacy_deprecate_flow_log/test_eval_loop_dict_return.py @@ -14,6 +14,8 @@ """ Tests to ensure that the training loop works with a dict """ +import pytest + from pytorch_lightning import Trainer from pytorch_lightning.core.lightning import LightningModule from tests.helpers.deterministic_model import DeterministicModel @@ -44,7 +46,8 @@ def backward(self, loss, optimizer, optimizer_idx): # out are the results of the full loop # eval_results are output of _evaluate - out, eval_results = trainer.run_evaluation() + with pytest.warns(RuntimeWarning, match="the running stage is set to None"): + out, 
eval_results = trainer.run_evaluation() assert len(out) == 1 assert len(eval_results) == 0 diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 69bd3411570c3..5530779b4f77d 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -600,24 +600,20 @@ def test_error_on_zero_len_dataloader(tmpdir): @RunIf(skip_windows=True) -@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) +@pytest.mark.parametrize('ckpt_path', (None, 'best', 'specific')) +@pytest.mark.parametrize('stage', ('train', 'test', 'val')) @patch('pytorch_lightning.trainer.data_loading.multiprocessing.cpu_count', return_value=4) -def test_warning_with_few_workers(mock, tmpdir, ckpt_path): +def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage): """ Test that error is raised if dataloader with only a few workers is used """ - model = EvalModelTemplate() + model = BoringModel() - # logger file to get meta - train_dl = model.dataloader(train=True) + train_dl = model.train_dataloader() train_dl.num_workers = 0 - val_dl = model.dataloader(train=False) + val_dl = model.val_dataloader() val_dl.num_workers = 0 - train_dl = model.dataloader(train=False) - train_dl.num_workers = 0 - - fit_options = dict(train_dataloader=train_dl, val_dataloaders=val_dl) trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, @@ -625,30 +621,22 @@ def test_warning_with_few_workers(mock, tmpdir, ckpt_path): limit_train_batches=0.2, ) - # fit model with pytest.warns( - UserWarning, match='The dataloader, train dataloader, does not have many workers which may be a bottleneck.' + UserWarning, + match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers' ): - trainer.fit(model, **fit_options) - - with pytest.warns( - UserWarning, match='The dataloader, val dataloader 0, does not have many workers which may be a bottleneck.' - ): - trainer.fit(model, **fit_options) - - if ckpt_path == 'specific': - ckpt_path = trainer.checkpoint_callback.best_model_path - test_options = dict(test_dataloaders=train_dl, ckpt_path=ckpt_path) - with pytest.warns( - UserWarning, match='The dataloader, test dataloader 0, does not have many workers which may be a bottleneck.' 
- ): - trainer.test(**test_options) + if stage == 'test': + ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == 'specific' else ckpt_path + trainer.test(model, test_dataloaders=train_dl, ckpt_path=ckpt_path) + else: + trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) @RunIf(skip_windows=True) -@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) +@pytest.mark.parametrize('ckpt_path', (None, 'best', 'specific')) +@pytest.mark.parametrize('stage', ('train', 'test', 'val')) @patch('pytorch_lightning.trainer.data_loading.multiprocessing.cpu_count', return_value=4) -def test_warning_with_few_workers_multi_loader(mock, tmpdir, ckpt_path): +def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage): """ Test that error is raised if dataloader with only a few workers is used """ model = EvalModelTemplate() @@ -658,10 +646,6 @@ def test_warning_with_few_workers_multi_loader(mock, tmpdir, ckpt_path): model.test_step = model.test_step__multiple_dataloaders model.test_epoch_end = model.test_epoch_end__multiple_dataloaders - # logger file to get meta - train_dl = model.dataloader(train=True) - train_dl.num_workers = 0 - val_dl = model.dataloader(train=False) val_dl.num_workers = 0 @@ -672,7 +656,6 @@ def test_warning_with_few_workers_multi_loader(mock, tmpdir, ckpt_path): val_multi_dl = [val_dl, val_dl] test_multi_dl = [train_dl, train_dl] - fit_options = dict(train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl) trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, @@ -680,24 +663,15 @@ def test_warning_with_few_workers_multi_loader(mock, tmpdir, ckpt_path): limit_train_batches=0.2, ) - # fit model - with pytest.warns( - UserWarning, match='The dataloader, train dataloader, does not have many workers which may be a bottleneck.' - ): - trainer.fit(model, **fit_options) - - with pytest.warns( - UserWarning, match='The dataloader, val dataloader 0, does not have many workers which may be a bottleneck.' - ): - trainer.fit(model, **fit_options) - - if ckpt_path == 'specific': - ckpt_path = trainer.checkpoint_callback.best_model_path - test_options = dict(test_dataloaders=test_multi_dl, ckpt_path=ckpt_path) with pytest.warns( - UserWarning, match='The dataloader, test dataloader 0, does not have many workers which may be a bottleneck.' + UserWarning, + match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers' ): - trainer.test(**test_options) + if stage == 'test': + ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == 'specific' else ckpt_path + trainer.test(model, test_dataloaders=test_multi_dl, ckpt_path=ckpt_path) + else: + trainer.fit(model, train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl) def test_warning_with_iterable_dataset_and_len(tmpdir): diff --git a/tests/trainer/test_states.py b/tests/trainer/test_states.py index 4e067fe22feb6..bedaef6d1ffb8 100644 --- a/tests/trainer/test_states.py +++ b/tests/trainer/test_states.py @@ -14,99 +14,8 @@ import pytest from pytorch_lightning import Callback, Trainer -from pytorch_lightning.trainer.states import trainer_state, TrainerState -from tests.base import EvalModelTemplate - - -class StateSnapshotCallback(Callback): - """ Allows to shapshot the state inside a particular trainer method. 
""" - - def __init__(self, snapshot_method: str): - super().__init__() - assert snapshot_method in ['on_batch_start', 'on_test_batch_start'] - self.snapshot_method = snapshot_method - self.trainer_state = None - - def on_batch_start(self, trainer, pl_module): - if self.snapshot_method == 'on_batch_start': - self.trainer_state = trainer.state - - def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): - if self.snapshot_method == 'on_test_batch_start': - self.trainer_state = trainer.state - - -def test_state_decorator_nothing_passed(tmpdir): - """ Test that state is not changed if nothing is passed to a decorator""" - - @trainer_state() - def test_method(self): - return self.state - - trainer = Trainer(default_root_dir=tmpdir) - - snapshot_state = test_method(trainer) - - assert snapshot_state == TrainerState.INITIALIZING - assert trainer.state == TrainerState.INITIALIZING - - -def test_state_decorator_entering_only(tmpdir): - """ Tests that state is set to entering inside a run function and restored to the previous value after. """ - - @trainer_state(entering=TrainerState.RUNNING) - def test_method(self): - return self.state - - trainer = Trainer(default_root_dir=tmpdir) - - snapshot_state = test_method(trainer) - - assert snapshot_state == TrainerState.RUNNING - assert trainer.state == TrainerState.INITIALIZING - - -def test_state_decorator_exiting_only(tmpdir): - """ Tests that state is not changed inside a run function and set to `exiting` after. """ - - @trainer_state(exiting=TrainerState.FINISHED) - def test_method(self): - return self.state - - trainer = Trainer(default_root_dir=tmpdir) - - snapshot_state = test_method(trainer) - - assert snapshot_state == TrainerState.INITIALIZING - assert trainer.state == TrainerState.FINISHED - - -def test_state_decorator_entering_and_exiting(tmpdir): - """ Tests that state is set to `entering` inside a run function and set ot `exiting` after. """ - - @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED) - def test_method(self): - return self.state - - trainer = Trainer(default_root_dir=tmpdir) - - snapshot_state = test_method(trainer) - - assert snapshot_state == TrainerState.RUNNING - assert trainer.state == TrainerState.FINISHED - - -def test_state_decorator_interrupt(tmpdir): - """ Tests that state remains `INTERRUPTED` is its set in run function. 
""" - - @trainer_state(exiting=TrainerState.FINISHED) - def test_method(self): - self._state = TrainerState.INTERRUPTED - - trainer = Trainer(default_root_dir=tmpdir) - - test_method(trainer) - assert trainer.state == TrainerState.INTERRUPTED +from pytorch_lightning.trainer.states import TrainerState +from tests.helpers import BoringModel def test_initialize_state(tmpdir): @@ -121,71 +30,41 @@ def test_initialize_state(tmpdir): pytest.param(dict(max_steps=1), id='Single-Step'), ] ) -def test_running_state_during_fit(tmpdir, extra_params): - """ Tests that state is set to RUNNING during fit """ +def test_trainer_state_while_running(tmpdir, extra_params): + trainer = Trainer(default_root_dir=tmpdir, **extra_params, auto_lr_find=True) - hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(**hparams) - - snapshot_callback = StateSnapshotCallback(snapshot_method='on_batch_start') + class TestModel(BoringModel): + def __init__(self, expected_state): + super().__init__() + self.expected_state = expected_state + self.lr = 0.1 - trainer = Trainer(callbacks=[snapshot_callback], default_root_dir=tmpdir, **extra_params) + def on_batch_start(self, *_): + assert self.trainer.state == self.expected_state - trainer.fit(model) + def on_train_batch_start(self, *_): + assert self.trainer.training - assert snapshot_callback.trainer_state == TrainerState.RUNNING + def on_sanity_check_start(self, *_): + assert self.trainer.sanity_checking + def on_validation_batch_start(self, *_): + assert self.trainer.validating or self.trainer.sanity_checking -@pytest.mark.parametrize( - "extra_params", [ - pytest.param(dict(fast_dev_run=True), id='Fast-Run'), - pytest.param(dict(max_steps=1), id='Single-Step'), - ] -) -def test_finished_state_after_fit(tmpdir, extra_params): - """ Tests that state is FINISHED after fit """ - hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(**hparams) + def on_test_batch_start(self, *_): + assert self.trainer.testing - trainer = Trainer(default_root_dir=tmpdir, **extra_params) + model = TestModel(TrainerState.TUNING) + trainer.tune(model) + assert trainer.state == TrainerState.FINISHED + model = TestModel(TrainerState.FITTING) trainer.fit(model) + assert trainer.state == TrainerState.FINISHED - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" - - -def test_running_state_during_test(tmpdir): - """ Tests that state is set to RUNNING during test """ - - hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(**hparams) - - snapshot_callback = StateSnapshotCallback(snapshot_method='on_test_batch_start') - - trainer = Trainer( - callbacks=[snapshot_callback], - default_root_dir=tmpdir, - fast_dev_run=True, - ) - - trainer.test(model) - - assert snapshot_callback.trainer_state == TrainerState.RUNNING - - -def test_finished_state_after_test(tmpdir): - """ Tests that state is FINISHED after fit """ - hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(**hparams) - - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=True, - ) - + model = TestModel(TrainerState.TESTING) trainer.test(model) - - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED @pytest.mark.parametrize( @@ -196,19 +75,13 @@ def test_finished_state_after_test(tmpdir): ) def test_interrupt_state_on_keyboard_interrupt(tmpdir, extra_params): """ Tests that state is set to INTERRUPTED on KeyboardInterrupt """ - 
hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(**hparams) + model = BoringModel() class InterruptCallback(Callback): - - def __init__(self): - super().__init__() - def on_batch_start(self, trainer, pl_module): raise KeyboardInterrupt trainer = Trainer(callbacks=[InterruptCallback()], default_root_dir=tmpdir, **extra_params) trainer.fit(model) - assert trainer.state == TrainerState.INTERRUPTED diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 1cd979c863d37..3e090fb44943e 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -352,10 +352,11 @@ def mock_save_function(filepath, *args): monitor='checkpoint_on', save_top_k=save_top_k, save_last=save_last, - verbose=1 + verbose=True ) checkpoint_callback.save_function = mock_save_function trainer = Trainer() + trainer.state = TrainerState.FITTING # emulate callback's calls during the training for i, loss in enumerate(losses): @@ -600,7 +601,7 @@ def test_benchmark_option(tmpdir): @pytest.mark.parametrize("ckpt_path", [None, "best", "specific"]) @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2]) -def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k): +def test_tested_checkpoint_path(tmpdir, ckpt_path, save_top_k): hparams = EvalModelTemplate.get_default_hparams() model = EvalModelTemplate(**hparams)
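The rewritten state tests above assert the per-stage boolean flags (`trainer.training`, `trainer.validating`, `trainer.testing`, `trainer.sanity_checking`) from within hooks and check that `trainer.state` ends up as `TrainerState.FINISHED`. A condensed sketch of how a callback could query the same flags after this change (`StageLogger` is a name invented here; the assertions mirror those in the tests, not an API contract):

    from pytorch_lightning import Callback, Trainer
    from pytorch_lightning.trainer.states import TrainerState

    class StageLogger(Callback):
        # read-only checks of the per-stage boolean flags
        def on_sanity_check_start(self, trainer, pl_module):
            assert trainer.sanity_checking  # replaces the deprecated trainer.running_sanity_check

        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
            assert trainer.training

        def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
            # during the sanity check the validating flag is not set
            assert trainer.validating or trainer.sanity_checking

        def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
            assert trainer.testing

    # after fit/test/tune completes, the coarse state is FINISHED, e.g.:
    #   trainer = Trainer(callbacks=[StageLogger()]); trainer.fit(model)
    #   assert trainer.state == TrainerState.FINISHED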