From 395ce2a59e27f1312d95d4970faf2e16380db0f2 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 27 Apr 2021 14:12:56 +0200 Subject: [PATCH 01/30] `_fit_impl` refactor and types --- .../training_type/training_type_plugin.py | 18 +-- .../logger_connector/logger_connector.py | 3 +- pytorch_lightning/trainer/predict_loop.py | 4 +- pytorch_lightning/trainer/trainer.py | 113 +++++++++--------- pytorch_lightning/tuner/batch_size_scaling.py | 4 +- pytorch_lightning/tuner/lr_finder.py | 2 +- pytorch_lightning/tuner/tuning.py | 9 +- pytorch_lightning/utilities/types.py | 2 + 8 files changed, 82 insertions(+), 73 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 15be889c85e3e..36a62f7421b3c 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -26,6 +26,7 @@ from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT TBroadcast = TypeVar("T") @@ -37,7 +38,7 @@ class TrainingTypePlugin(Plugin, ABC): def __init__(self) -> None: self._model = None - self._results = None + self.results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None self._call_configure_sharded_model_hook = True def connect(self, model: Module) -> None: @@ -123,30 +124,21 @@ def lightning_module(self) -> 'pl.LightningModule': """Returns the pure LightningModule without potential wrappers""" return unwrap_lightning_module(self._model) - @property - def results(self) -> Any: - """ - The results of the last training/testing run will be cached here. - In distributed training, we make sure to transfer the results to the appropriate master process. 
- """ - # TODO: improve these docs - return self._results - @property def rpc_enabled(self) -> bool: return False def start_training(self, trainer: 'pl.Trainer') -> None: # double dispatch to initiate the training loop - self._results = trainer.run_stage() + self.results = trainer.run_stage() def start_evaluating(self, trainer: 'pl.Trainer') -> None: # double dispatch to initiate the test loop - self._results = trainer.run_stage() + self.results = trainer.run_stage() def start_predicting(self, trainer: 'pl.Trainer') -> None: # double dispatch to initiate the predicting loop - self._results = trainer.run_stage() + self.results = trainer.run_stage() def training_step(self, *args, **kwargs): return self.lightning_module.training_step(*args, **kwargs) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 796b381e95223..932e6a49dcb6b 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -27,6 +27,7 @@ from pytorch_lightning.trainer.states import RunningStage, TrainerState from pytorch_lightning.utilities import DeviceType from pytorch_lightning.utilities.metrics import metrics_to_scalars +from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT class LoggerConnector: @@ -267,7 +268,7 @@ def prepare_eval_loop_results(self): for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders): self.add_to_eval_loop_results(dl_idx, has_been_initialized) - def get_evaluate_epoch_results(self): + def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: if not self.trainer.sanity_checking: # log all the metrics as a single dict metrics_to_log = self.cached_results.get_epoch_log_metrics() diff --git a/pytorch_lightning/trainer/predict_loop.py b/pytorch_lightning/trainer/predict_loop.py index 92ccb10a7f1c5..54c9cb1093cf2 100644 --- a/pytorch_lightning/trainer/predict_loop.py +++ b/pytorch_lightning/trainer/predict_loop.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import torch from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.types import _PREDICT_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache @@ -84,7 +86,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx): self._predictions[dataloader_idx].append(predictions) - def on_predict_epoch_end(self): + def on_predict_epoch_end(self) -> _PREDICT_OUTPUT: self.trainer.profiler.describe() results = self._predictions diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6a3ed24135265..e8b0970c226b2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -63,6 +63,7 @@ from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed +from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT log = logging.getLogger(__name__) # warnings to ignore in trainer @@ -408,36 +409,13 @@ def __init__( # Callback system self.on_init_end() - def fit( + def _fit_impl( self, model: LightningModule, train_dataloader: Any = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, datamodule: Optional[LightningDataModule] = None, - ): - r""" - Runs the full optimization routine. - - Args: - datamodule: A instance of :class:`LightningDataModule`. - - model: Model to fit. - - train_dataloader: Either a single PyTorch DataLoader or a collection of these - (list, dict, nested lists and dicts). In the case of multiple dataloaders, please - see this :ref:`page ` - - val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. - If the model has a predefined val_dataloaders method this will be skipped - - """ - Trainer._log_api_event("fit") - # we reuse fit for other functions. When already set, it shouldn't be modified. - if not self.state.running: - self.state = TrainerState.FITTING - if self._running_stage is None or self.tuning: - self.training = True - + ) -> Union[int, _EVALUATE_OUTPUT, _PREDICT_OUTPUT]: # set local properties on the model self.model_connector.copy_trainer_model_properties(model) @@ -546,17 +524,13 @@ def dispatch(self): self.accelerator.start_training(self) def run_stage(self): - results = None - self.profile_connector.setup() if self.evaluating: - results = self.run_evaluate() - elif self.predicting: - results = self.run_predict() - else: - self.run_train() - return results + return self.run_evaluate() + if self.predicting: + return self.run_predict() + return self.run_train() def _pre_training_routine(self): # wait for all to join if on distributed @@ -586,7 +560,6 @@ def _pre_training_routine(self): ref_model.on_pretrain_routine_end() def run_train(self) -> None: - self._pre_training_routine() if not self.is_global_zero and self.progress_bar_callback is not None: @@ -660,7 +633,7 @@ def run_train(self) -> None: self._running_stage = None raise - def run_evaluation(self, on_epoch=False): + def run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT: if not (self.evaluating or self.sanity_checking): rank_zero_warn( f"`trainer.run_evaluation()` was called but the running stage is set to {self._running_stage}." 
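
For orientation, the two aliases this patch adds in pytorch_lightning/utilities/types.py describe what the simplified `run_stage` above can hand back, now that it returns evaluation/prediction results directly instead of collecting them into a local `results` variable. A minimal sketch of how the pieces fit together (the free function below is illustrative only, not code from the patch)::

    from typing import Any, Dict, List, Optional, Union

    # aliases exactly as added to pytorch_lightning/utilities/types.py
    _EVALUATE_OUTPUT = List[Dict[str, float]]  # 1 dict per DataLoader
    _PREDICT_OUTPUT = Union[Any, List[Any]]

    def run_stage(trainer) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
        # mirrors the refactored Trainer.run_stage: evaluation and prediction
        # results are returned directly, a pure training run yields None
        if trainer.evaluating:
            return trainer.run_evaluate()
        if trainer.predicting:
            return trainer.run_predict()
        trainer.run_train()
        return None
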
@@ -777,7 +750,7 @@ def track_output_for_epoch_end(self, outputs, output): outputs.append(output) return outputs - def run_evaluate(self): + def run_evaluate(self) -> _EVALUATE_OUTPUT: if not self.is_global_zero and self.progress_bar_callback is not None: self.progress_bar_callback.disable() @@ -786,9 +759,6 @@ def run_evaluate(self): with self.profiler.profile(f"run_{self._running_stage}_evaluation"): eval_loop_results = self.run_evaluation() - if len(eval_loop_results) == 0: - return 1 - # remove the tensors from the eval results for i, result in enumerate(eval_loop_results): if isinstance(result, dict): @@ -798,7 +768,7 @@ def run_evaluate(self): return eval_loop_results - def run_predict(self): + def run_predict(self) -> _PREDICT_OUTPUT: # prepare dataloaders dataloaders, max_batches = self.predict_loop.get_predict_dataloaders() @@ -869,6 +839,45 @@ def run_sanity_check(self, ref_model): # prevents sanity check to affect random sampling in training reset_seed() + def fit( + self, + model: LightningModule, + train_dataloader: Any = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + datamodule: Optional[LightningDataModule] = None, + ) -> Optional[int]: + r""" + Runs the full optimization routine. + + Args: + model: Model to fit. + + train_dataloader: Either a single PyTorch DataLoader or a collection of these + (list, dict, nested lists and dicts). In the case of multiple dataloaders, please + see this :ref:`page ` + + val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. + If the model has a predefined val_dataloaders method this will be skipped + + datamodule: A instance of :class:`LightningDataModule`. + + Returns: + Whether the run was successful (1) or not (0) + """ + Trainer._log_api_event("fit") + + self.state = TrainerState.FITTING + self.training = True + + results = self._fit_impl( + model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule + ) + + assert self.state.stopped + self.training = False + + return results + def validate( self, model: Optional[LightningModule] = None, @@ -876,7 +885,7 @@ def validate( ckpt_path: Optional[str] = 'best', verbose: bool = True, datamodule: Optional[LightningDataModule] = None, - ): + ) -> _EVALUATE_OUTPUT: r""" Perform one evaluation epoch over the validation set. @@ -923,10 +932,10 @@ def validate( self.data_connector.attach_dataloaders(model, val_dataloaders=val_dataloaders) if not model_provided: - self.validated_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path) + self.validated_ckpt_path = self.__load_ckpt_weights(ckpt_path) # run validate - results = self.fit(model) + results = self._fit_impl(model) assert self.state.stopped self.validating = False @@ -940,7 +949,7 @@ def test( ckpt_path: Optional[str] = 'best', verbose: bool = True, datamodule: Optional[LightningDataModule] = None, - ): + ) -> _EVALUATE_OUTPUT: r""" Perform one evaluation epoch over the test set. It's separated from fit to make sure you never run on your test set until you want to. 
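
Taken together, the annotations above give every public entry point a concrete return type while `fit`, `validate`, `test` and `predict` all funnel through the shared `_fit_impl`. A rough user-side sketch of what that typing implies (the `model` and `dm` arguments are assumed to be a user-defined LightningModule and LightningDataModule, not part of the patch)::

    from typing import Any, Dict, List, Optional, Union

    import pytorch_lightning as pl

    def run_every_entry_point(model: pl.LightningModule, dm: pl.LightningDataModule) -> None:
        trainer = pl.Trainer(max_epochs=1)
        # annotated Optional[int] at this point in the series
        fit_status: Optional[int] = trainer.fit(model, datamodule=dm)
        # one metrics dict per evaluation dataloader
        val_metrics: List[Dict[str, float]] = trainer.validate(model, datamodule=dm)
        test_metrics: List[Dict[str, float]] = trainer.test(model, datamodule=dm)
        # prediction output stays loosely typed
        predictions: Union[Any, List[Any]] = trainer.predict(model, datamodule=dm)
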
@@ -984,21 +993,17 @@ def test( self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) if not model_provided: - self.tested_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path) + self.tested_ckpt_path = self.__load_ckpt_weights(ckpt_path) # run test - results = self.fit(model) + results = self._fit_impl(model) assert self.state.stopped self.testing = False return results - def __load_ckpt_weights( - self, - model, - ckpt_path: Optional[str] = None, - ) -> Optional[str]: + def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]: if ckpt_path is None: return @@ -1039,7 +1044,7 @@ def predict( model: Optional[LightningModule] = None, dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, datamodule: Optional[LightningDataModule] = None, - ): + ) -> _PREDICT_OUTPUT: r""" Separates from fit to make sure you never run on your predictions set until you want to. @@ -1075,7 +1080,7 @@ def predict( # Attach dataloaders (if given) self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders) - results = self.fit(model) + results = self._fit_impl(model) assert self.state.stopped self.predicting = False @@ -1088,7 +1093,7 @@ def tune( train_dataloader: Optional[DataLoader] = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, datamodule: Optional[LightningDataModule] = None, - ): + ) -> None: r""" Runs routines to tune hyperparameters before training. diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 9c5e966c14cc1..51b60c8ef2c62 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -189,7 +189,7 @@ def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **f trainer.global_step = 0 # reset after each try try: # Try fit - trainer.fit(model, **fit_kwargs) + trainer.tuner._fit(model, **fit_kwargs) # Double in size new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded') except RuntimeError as exception: @@ -218,7 +218,7 @@ def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, trainer.global_step = 0 # reset after each try try: # Try fit - trainer.fit(model, **fit_kwargs) + trainer.tuner._fit(model, **fit_kwargs) count += 1 if count > max_trials: break diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 2d122a3d30cfd..4eda74c6a8359 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -176,7 +176,7 @@ def lr_find( model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers) # Fit, lr & loss logged in callback - trainer.fit(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule) + trainer.tuner._fit(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule) # Prompt if we stopped early if trainer.global_step != num_training: diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index a7aa1ee256a5d..2a1ce3fec48ff 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Optional, Union +from typing import Any, List, Optional, Union from torch.utils.data import DataLoader @@ -71,6 +71,13 @@ def tune(self, model, train_dataloader, val_dataloaders, datamodule): self.trainer.state = TrainerState.FINISHED + def _fit(self, *args: Any, **kwargs: Any) -> Optional[int]: + """`_fit_impl` wrapper to set the proper `RunningStage`""" + self.trainer.training = True + results = self.trainer._fit_impl(*args, **kwargs) + self.trainer.tuning = True + return results + def scale_batch_size( self, model, diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index c1c40b98c71c7..62ee3034338ea 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -10,4 +10,6 @@ _METRIC = Union[Metric, torch.Tensor, int, float] STEP_OUTPUT = Union[torch.Tensor, Dict[str, Any]] EPOCH_OUTPUT = List[STEP_OUTPUT] +_EVALUATE_OUTPUT = List[Dict[str, float]] # 1 dict per DataLoader +_PREDICT_OUTPUT = Union[Any, List[Any]] _PARAMETERS = Iterator[torch.nn.Parameter] From 65aa6909ca7ad205387d7b4dbcb8bffedbdb1c04 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 27 Apr 2021 14:19:16 +0200 Subject: [PATCH 02/30] Fix return --- pytorch_lightning/tuner/tuning.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 2a1ce3fec48ff..1e6f43488c880 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -71,12 +71,11 @@ def tune(self, model, train_dataloader, val_dataloaders, datamodule): self.trainer.state = TrainerState.FINISHED - def _fit(self, *args: Any, **kwargs: Any) -> Optional[int]: + def _fit(self, *args: Any, **kwargs: Any) -> None: """`_fit_impl` wrapper to set the proper `RunningStage`""" self.trainer.training = True - results = self.trainer._fit_impl(*args, **kwargs) + self.trainer._fit_impl(*args, **kwargs) self.trainer.tuning = True - return results def scale_batch_size( self, From b3f331437eae31754949ea6b22a0dcc94f7a14aa Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 27 Apr 2021 14:20:45 +0200 Subject: [PATCH 03/30] Remove return docstring --- pytorch_lightning/trainer/trainer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e8b0970c226b2..17ad320156cd0 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -860,9 +860,6 @@ def fit( If the model has a predefined val_dataloaders method this will be skipped datamodule: A instance of :class:`LightningDataModule`. 
- - Returns: - Whether the run was successful (1) or not (0) """ Trainer._log_api_event("fit") From e7c3657edb28a3b2f5b911075965f7bd179bb3cf Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 27 Apr 2021 14:31:09 +0200 Subject: [PATCH 04/30] Fixes --- pytorch_lightning/plugins/training_type/horovod.py | 6 +++--- pytorch_lightning/trainer/trainer.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 99899aed11753..8bfb6524c290e 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -104,14 +104,14 @@ def start_training(self, trainer): stack.enter_context(optimizer.skip_synchronize()) # set up training routine - self._results = trainer.run_stage() + self.results = trainer.run_stage() # Make sure all workers have finished training before returning to the user self.join() def start_evaluating(self, trainer): with ExitStack(): - self._results = trainer.run_stage() + self.results = trainer.run_stage() # Make sure all workers have finished training before returning to the user self.join() @@ -119,7 +119,7 @@ def start_evaluating(self, trainer): def start_predicting(self, trainer): with ExitStack(): # set up training routine - self._results = trainer.run_stage() + self.results = trainer.run_stage() # Make sure all workers have finished training before returning to the user self.join() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 17ad320156cd0..23cc4c364d428 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -523,7 +523,7 @@ def dispatch(self): else: self.accelerator.start_training(self) - def run_stage(self): + def run_stage(self) -> Optional[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]: self.profile_connector.setup() if self.evaluating: From 8060424d56b9c1af9c09071585021ccaf0026c65 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 27 Apr 2021 14:39:09 +0200 Subject: [PATCH 05/30] Fixes --- pytorch_lightning/trainer/trainer.py | 2 +- pytorch_lightning/tuner/tuning.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 23cc4c364d428..d600c8c25d4d3 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -523,7 +523,7 @@ def dispatch(self): else: self.accelerator.start_training(self) - def run_stage(self) -> Optional[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]: + def run_stage(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: self.profile_connector.setup() if self.evaluating: diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 1e6f43488c880..3088d0b6ba08c 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -72,7 +72,8 @@ def tune(self, model, train_dataloader, val_dataloaders, datamodule): self.trainer.state = TrainerState.FINISHED def _fit(self, *args: Any, **kwargs: Any) -> None: - """`_fit_impl` wrapper to set the proper `RunningStage`""" + """`_fit_impl` wrapper to set the proper state during tuning, as this can be called multiple times""" + self.trainer.state = TrainerState.TUNING # last `_fit_impl` call might have set it to `FINISHED` self.trainer.training = True self.trainer._fit_impl(*args, **kwargs) self.trainer.tuning = True From 67d5ca81db77bb8d6116d19092d8a287f263b152 Mon Sep 17 
00:00:00 2001 From: Carlos Mocholi Date: Tue, 27 Apr 2021 15:51:07 +0200 Subject: [PATCH 06/30] Undo results change --- .../plugins/training_type/horovod.py | 6 +++--- .../training_type/training_type_plugin.py | 17 +++++++++++++---- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 8bfb6524c290e..99899aed11753 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -104,14 +104,14 @@ def start_training(self, trainer): stack.enter_context(optimizer.skip_synchronize()) # set up training routine - self.results = trainer.run_stage() + self._results = trainer.run_stage() # Make sure all workers have finished training before returning to the user self.join() def start_evaluating(self, trainer): with ExitStack(): - self.results = trainer.run_stage() + self._results = trainer.run_stage() # Make sure all workers have finished training before returning to the user self.join() @@ -119,7 +119,7 @@ def start_evaluating(self, trainer): def start_predicting(self, trainer): with ExitStack(): # set up training routine - self.results = trainer.run_stage() + self._results = trainer.run_stage() # Make sure all workers have finished training before returning to the user self.join() diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 36a62f7421b3c..f4cf24b9285b7 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -38,7 +38,7 @@ class TrainingTypePlugin(Plugin, ABC): def __init__(self) -> None: self._model = None - self.results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None + self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None self._call_configure_sharded_model_hook = True def connect(self, model: Module) -> None: @@ -124,21 +124,30 @@ def lightning_module(self) -> 'pl.LightningModule': """Returns the pure LightningModule without potential wrappers""" return unwrap_lightning_module(self._model) + @property + def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: + """ + The results of the last training/testing run will be cached here. + In distributed training, we make sure to transfer the results to the appropriate master process. 
+ """ + # TODO(@awaelchli): improve these docs + return self._results + @property def rpc_enabled(self) -> bool: return False def start_training(self, trainer: 'pl.Trainer') -> None: # double dispatch to initiate the training loop - self.results = trainer.run_stage() + self._results = trainer.run_stage() def start_evaluating(self, trainer: 'pl.Trainer') -> None: # double dispatch to initiate the test loop - self.results = trainer.run_stage() + self._results = trainer.run_stage() def start_predicting(self, trainer: 'pl.Trainer') -> None: # double dispatch to initiate the predicting loop - self.results = trainer.run_stage() + self._results = trainer.run_stage() def training_step(self, *args, **kwargs): return self.lightning_module.training_step(*args, **kwargs) From d7fbc5d00b1b4586c2e990b0fdc99c0a98901b31 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 28 Apr 2021 16:33:05 +0200 Subject: [PATCH 07/30] Revert changes for a separate PR --- pytorch_lightning/trainer/trainer.py | 16 ++++++++-------- pytorch_lightning/tuner/batch_size_scaling.py | 4 ++-- pytorch_lightning/tuner/lr_finder.py | 2 +- pytorch_lightning/tuner/tuning.py | 8 ++++---- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0d360bf3f7117..076b5714fb331 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -409,7 +409,7 @@ def __init__( # Callback system self.on_init_end() - def _fit_impl( + def _launch( self, model: LightningModule, train_dataloader: Any = None, @@ -857,7 +857,7 @@ def fit( self.state = TrainerState.FITTING self.training = True - results = self._fit_impl( + results = self._launch( model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule ) @@ -923,7 +923,7 @@ def validate( self.validated_ckpt_path = self.__load_ckpt_weights(ckpt_path) # run validate - results = self._fit_impl(model) + results = self._launch(model) assert self.state.stopped self.validating = False @@ -984,7 +984,7 @@ def test( self.tested_ckpt_path = self.__load_ckpt_weights(ckpt_path) # run test - results = self._fit_impl(model) + results = self._launch(model) assert self.state.stopped self.testing = False @@ -1041,7 +1041,9 @@ def predict( Args: model: The model to predict with. + dataloaders: Either a single PyTorch DataLoader or a list of them, specifying inference samples. + datamodule: The datamodule with a predict_dataloader method that returns one or more dataloaders. return_predictions: Whether to return predictions. @@ -1065,16 +1067,14 @@ def predict( self.predicting = True if dataloaders is not None and datamodule: - raise MisconfigurationException( - 'You cannot pass dataloaders to trainer.predict if you supply a datamodule.' 
- ) + raise MisconfigurationException('You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`') # Attach datamodule to get setup/prepare_data added to model before the call to it below self.data_connector.attach_datamodule(model, datamodule) # Attach dataloaders (if given) self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders) - results = self._fit_impl(model) + results = self._launch(model) assert self.state.stopped self.predicting = False diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 51b60c8ef2c62..7e9dc524099de 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -189,7 +189,7 @@ def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **f trainer.global_step = 0 # reset after each try try: # Try fit - trainer.tuner._fit(model, **fit_kwargs) + trainer.tuner._launch(model, **fit_kwargs) # Double in size new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded') except RuntimeError as exception: @@ -218,7 +218,7 @@ def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, trainer.global_step = 0 # reset after each try try: # Try fit - trainer.tuner._fit(model, **fit_kwargs) + trainer.tuner._launch(model, **fit_kwargs) count += 1 if count > max_trials: break diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 4eda74c6a8359..971e3667d5ed7 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -176,7 +176,7 @@ def lr_find( model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers) # Fit, lr & loss logged in callback - trainer.tuner._fit(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule) + trainer.tuner._launch(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule) # Prompt if we stopped early if trainer.global_step != num_training: diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 3088d0b6ba08c..9d471e2c5cbca 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -71,11 +71,11 @@ def tune(self, model, train_dataloader, val_dataloaders, datamodule): self.trainer.state = TrainerState.FINISHED - def _fit(self, *args: Any, **kwargs: Any) -> None: - """`_fit_impl` wrapper to set the proper state during tuning, as this can be called multiple times""" - self.trainer.state = TrainerState.TUNING # last `_fit_impl` call might have set it to `FINISHED` + def _launch(self, *args: Any, **kwargs: Any) -> None: + """`_launch` wrapper to set the proper state during tuning, as this can be called multiple times""" + self.trainer.state = TrainerState.TUNING # last `_launch` call might have set it to `FINISHED` self.trainer.training = True - self.trainer._fit_impl(*args, **kwargs) + self.trainer._launch(*args, **kwargs) self.trainer.tuning = True def scale_batch_size( From de7937eb636a4f8a2512539586c4f809c675820d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 28 Apr 2021 16:33:38 +0200 Subject: [PATCH 08/30] WIP --- .../trainer/configuration_validator.py | 30 +++++---- .../trainer/connectors/data_connector.py | 53 +++++++-------- pytorch_lightning/trainer/trainer.py | 64 +++++++++++-------- pytorch_lightning/trainer/training_loop.py | 16 +---- pytorch_lightning/tuner/batch_size_scaling.py 
| 13 ++-- pytorch_lightning/tuner/lr_finder.py | 9 +-- pytorch_lightning/tuner/tuning.py | 8 +-- 7 files changed, 88 insertions(+), 105 deletions(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 55b4ea7fe7692..215fd1353e3f0 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -11,19 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden -class ConfigValidator(object): +class ConfigValidator: - def __init__(self, trainer): + def __init__(self, trainer: 'pl.Trainer') -> None: self.trainer = trainer - def verify_loop_configurations(self, model: LightningModule) -> None: + def verify_loop_configurations(self, model: 'pl.LightningModule') -> None: r""" Checks that the model is configured correctly before the run is started. @@ -31,19 +31,18 @@ def verify_loop_configurations(self, model: LightningModule) -> None: model: The model to check the configuration. """ - if self.trainer.state == TrainerState.FITTING: + if self.trainer.state in (TrainerState.FITTING, TrainerState.TUNING): self.__verify_train_loop_configuration(model) self.__verify_eval_loop_configuration(model, 'val') - elif self.trainer.state == TrainerState.TUNING: - self.__verify_train_loop_configuration(model) elif self.trainer.state == TrainerState.VALIDATING: self.__verify_eval_loop_configuration(model, 'val') elif self.trainer.state == TrainerState.TESTING: self.__verify_eval_loop_configuration(model, 'test') elif self.trainer.state == TrainerState.PREDICTING: self.__verify_predict_loop_configuration(model) + self.__verify_dp_batch_transfer_support(model) - def __verify_train_loop_configuration(self, model): + def __verify_train_loop_configuration(self, model: 'pl.LightningModule') -> None: # ----------------------------------- # verify model has a training step # ----------------------------------- @@ -82,14 +81,14 @@ def __verify_train_loop_configuration(self, model): going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches() has_overriden_optimization_functions = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad - if (has_overriden_optimization_functions) and going_to_accumulate_grad_batches and automatic_optimization: + if has_overriden_optimization_functions and going_to_accumulate_grad_batches and automatic_optimization: raise MisconfigurationException( 'When overriding `LightningModule` optimizer_step or optimizer_zero_grad,' ' `accumulate_grad_batches` in `Trainer` should be 1.' ' It ensures optimizer_step or optimizer_zero_grad are called on every batch.' 
) - def __verify_eval_loop_configuration(self, model: LightningModule, stage: str) -> None: + def __verify_eval_loop_configuration(self, model: 'pl.LightningModule', stage: str) -> None: loader_name = f'{stage}_dataloader' step_name = 'validation_step' if stage == 'val' else 'test_step' @@ -101,8 +100,15 @@ def __verify_eval_loop_configuration(self, model: LightningModule, stage: str) - if has_step and not has_loader: rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {stage} loop') - def __verify_predict_loop_configuration(self, model: LightningModule) -> None: - + def __verify_predict_loop_configuration(self, model: 'pl.LightningModule') -> None: has_predict_dataloader = is_overridden('predict_dataloader', model) if not has_predict_dataloader: raise MisconfigurationException('Dataloader not found for `Trainer.predict`') + + def __verify_dp_batch_transfer_support(self, model: 'pl.LightningModule') -> None: + """Raise Misconfiguration exception since these hooks are not supported in DP mode""" + # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode. + batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer') + for hook in batch_transfer_hooks: + if self.trainer.accelerator_connector.use_dp and is_overridden(hook, model): + raise MisconfigurationException(f'Overriding `{hook}` is not supported in DP mode.') diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 5d2f141dc64a8..e4e8de96e4153 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -16,7 +16,7 @@ from torch.utils.data import DataLoader -from pytorch_lightning.core.datamodule import LightningDataModule +import pytorch_lightning as pl from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -78,42 +78,33 @@ def can_prepare_data(self): else: return self.trainer.node_rank == 0 and self.trainer.local_rank == 0 and should_call_dm_prepare_data - def attach_data(self, model, train_dataloader, val_dataloaders, datamodule): - # if a datamodule comes in as the second arg, then fix it for the user - if isinstance(train_dataloader, LightningDataModule): - datamodule = train_dataloader - train_dataloader = None - - self.__enforce_datamodule_dataloader_override(train_dataloader, val_dataloaders, datamodule) - + def attach_data( + self, + model: 'pl.LightningModule', + train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + datamodule: Optional['pl.LightningDataModule'] = None + ) -> None: # set up the passed in dataloaders (if needed) - self.attach_dataloaders(model, train_dataloader, val_dataloaders) - self.attach_datamodule(model, datamodule) - self._validate_data_hooks(model) - - def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule): - # If you supply a datamodule you can't supply train_dataloader or val_dataloaders - if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None: - raise MisconfigurationException( - 'You cannot pass train_dataloader or val_dataloaders to 
trainer.fit if you supply a datamodule' - ) - - def _validate_data_hooks(self, model): - # Raise Misconfiguration exception since these hooks are not supported in DP mode - # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode. - batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer') - for hook in batch_transfer_hooks: - if self.trainer.accelerator_connector.use_dp and is_overridden(hook, model): - raise MisconfigurationException(f'Overriding `{hook}` is not supported in DP mode.') + self.attach_dataloaders( + model, + train_dataloader=train_dataloader, + val_dataloaders=val_dataloaders, + test_dataloaders=test_dataloaders, + predict_dataloaders=predict_dataloaders, + ) + self.attach_datamodule(model, datamodule=datamodule) def attach_dataloaders( self, - model, + model: 'pl.LightningModule', train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - ): + ) -> None: # when dataloader is passed via fit, patch the train_dataloader # functions to overwrite with these implementations if train_dataloader is not None: @@ -128,7 +119,9 @@ def attach_dataloaders( if predict_dataloaders is not None: model.predict_dataloader = _PatchDataLoader(predict_dataloaders) - def attach_datamodule(self, model, datamodule: Optional[LightningDataModule] = None) -> None: + def attach_datamodule( + self, model: 'pl.LightningModule', datamodule: Optional['pl.LightningDataModule'] = None + ) -> None: # We use datamodule if it's been provided, otherwise we check model for it datamodule = datamodule or getattr(model, 'datamodule', None) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 076b5714fb331..3d4546c60ecd2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -57,7 +57,7 @@ from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.tuning import Tuner -from pytorch_lightning.utilities import DeviceType, rank_zero_warn +from pytorch_lightning.utilities import DeviceType, parsing, rank_zero_warn from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach @@ -409,21 +409,18 @@ def __init__( # Callback system self.on_init_end() - def _launch( - self, - model: LightningModule, - train_dataloader: Any = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - datamodule: Optional[LightningDataModule] = None, - ) -> Union[int, _EVALUATE_OUTPUT, _PREDICT_OUTPUT]: + def _launch(self, model: LightningModule) -> Union[int, _EVALUATE_OUTPUT, _PREDICT_OUTPUT]: # set local properties on the model self.model_connector.copy_trainer_model_properties(model) - # ---------------------------- - # LINK DATA - # ---------------------------- - # setup data, etc... 
- self.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule) + # clean hparams + if hasattr(model, "hparams"): + parsing.clean_namespace(model.hparams) + + self.config_validator.verify_loop_configurations(model) + + # attach model log function to callback + self.callback_connector.attach_model_logging_functions(model) # hook self.data_connector.prepare_data(model) @@ -857,10 +854,23 @@ def fit( self.state = TrainerState.FITTING self.training = True - results = self._launch( + # if a datamodule comes in as the second arg, then fix it for the user + if isinstance(train_dataloader, LightningDataModule): + datamodule = train_dataloader + train_dataloader = None + # If you supply a datamodule you can't supply train_dataloader or val_dataloaders + if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None: + raise MisconfigurationException( + 'You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.fit(datamodule=...)`' + ) + + # links data to the trainer + self.data_connector.attach_data( model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule ) + results = self._launch(model) + assert self.state.stopped self.training = False @@ -914,10 +924,8 @@ def validate( model_provided = model is not None model = model or self.lightning_module - # Attach datamodule to get setup/prepare_data added to model before the call to it below - self.data_connector.attach_datamodule(model, datamodule) - # Attach dataloaders (if given) - self.data_connector.attach_dataloaders(model, val_dataloaders=val_dataloaders) + # links data to the trainer + self.data_connector.attach_data(model, val_dataloaders=val_dataloaders, datamodule=datamodule) if not model_provided: self.validated_ckpt_path = self.__load_ckpt_weights(ckpt_path) @@ -975,10 +983,8 @@ def test( model_provided = model is not None model = model or self.lightning_module - # Attach datamodule to get setup/prepare_data added to model before the call to it below - self.data_connector.attach_datamodule(model, datamodule) - # Attach dataloaders (if given) - self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) + # links data to the trainer + self.data_connector.attach_data(model, test_dataloaders=test_dataloaders, datamodule=datamodule) if not model_provided: self.tested_ckpt_path = self.__load_ckpt_weights(ckpt_path) @@ -1069,10 +1075,8 @@ def predict( if dataloaders is not None and datamodule: raise MisconfigurationException('You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`') - # Attach datamodule to get setup/prepare_data added to model before the call to it below - self.data_connector.attach_datamodule(model, datamodule) - # Attach dataloaders (if given) - self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders) + # links data to the trainer + self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule) results = self._launch(model) @@ -1092,8 +1096,6 @@ def tune( Runs routines to tune hyperparameters before training. Args: - datamodule: A instance of :class:`LightningDataModule`. - model: Model to tune. train_dataloader: A Pytorch DataLoader with training samples. If the model has @@ -1102,11 +1104,17 @@ def tune( val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. If the model has a predefined val_dataloaders method this will be skipped + datamodule: A instance of :class:`LightningDataModule`. 
""" Trainer._log_api_event("tune") self.state = TrainerState.TUNING self.tuning = True + # links data to the trainer + self.data_connector.attach_data( + model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule + ) + self.tuner.tune(model, train_dataloader, val_dataloaders, datamodule) assert self.state.stopped diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 9284c75879270..57bbb3d7362eb 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -25,7 +25,7 @@ from pytorch_lightning.plugins import ParallelPlugin from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.supporters import TensorRunningAccum -from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType, parsing +from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType from pytorch_lightning.utilities.distributed import rank_zero_info from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.finite_checks import detect_nan_parameters @@ -92,20 +92,6 @@ def on_train_start(self): # hook self.trainer.call_hook("on_train_start") - def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None): - # clean hparams - if hasattr(model, "hparams"): - parsing.clean_namespace(model.hparams) - - # links data to the trainer - self.trainer.data_connector.attach_data(model, train_dataloader, val_dataloaders, datamodule) - - # check that model is configured correctly - self.trainer.config_validator.verify_loop_configurations(model) - - # attach model log function to callback - self.trainer.callback_connector.attach_model_logging_functions(model) - def on_train_end(self): if self._teardown_already_run: return diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 7e9dc524099de..53c158b8855e8 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -35,7 +35,6 @@ def scale_batch_size( init_val: int = 2, max_trials: int = 25, batch_arg_name: str = 'batch_size', - **fit_kwargs ): r""" Will iteratively try to find the largest batch size for a given model @@ -115,9 +114,9 @@ def scale_batch_size( # Initially we just double in size until an OOM is encountered new_size = _adjust_batch_size(trainer, batch_arg_name, value=init_val) # initially set to init_val if mode == 'power': - new_size = _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs) + new_size = _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials) elif mode == 'binsearch': - new_size = _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs) + new_size = _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials) else: raise ValueError('mode in method `scale_batch_size` can only be `power` or `binsearch') @@ -181,7 +180,7 @@ def __scale_batch_restore_params(trainer): del trainer.__dumped_params -def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs): +def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials): """ Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered. 
""" for _ in range(max_trials): @@ -189,7 +188,7 @@ def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **f trainer.global_step = 0 # reset after each try try: # Try fit - trainer.tuner._launch(model, **fit_kwargs) + trainer.tuner._launch(model) # Double in size new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded') except RuntimeError as exception: @@ -207,7 +206,7 @@ def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **f return new_size -def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs): +def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials): """ Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is encountered. Hereafter, the batch size is further refined using a binary search """ @@ -218,7 +217,7 @@ def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, trainer.global_step = 0 # reset after each try try: # Try fit - trainer.tuner._launch(model, **fit_kwargs) + trainer.tuner._launch(model) count += 1 if count > max_trials: break diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 971e3667d5ed7..b2b2595294c8f 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -15,16 +15,14 @@ import logging import os from functools import wraps -from typing import Callable, List, Optional, Sequence, Union +from typing import Callable, Sequence import numpy as np import torch from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler -from torch.utils.data import DataLoader from pytorch_lightning.callbacks import Callback -from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers.base import DummyLogger from pytorch_lightning.utilities import DeviceType, rank_zero_warn @@ -65,14 +63,11 @@ def _determine_lr_attr_name(trainer, model: LightningModule) -> str: def lr_find( trainer, model: LightningModule, - train_dataloader: Optional[DataLoader] = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, min_lr: float = 1e-8, max_lr: float = 1, num_training: int = 100, mode: str = 'exponential', early_stop_threshold: float = 4.0, - datamodule: Optional[LightningDataModule] = None, update_attr: bool = False, ): r""" @@ -176,7 +171,7 @@ def lr_find( model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers) # Fit, lr & loss logged in callback - trainer.tuner._launch(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule) + trainer.tuner._launch(model) # Prompt if we stopped early if trainer.global_step != num_training: diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 9d471e2c5cbca..74d3e9ac3be55 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -124,7 +124,7 @@ def scale_batch_size( or datamodule. 
""" - self.setup_trainer(model, **fit_kwargs) + #self.setup_trainer(model, **fit_kwargs) return scale_batch_size( self.trainer, model, @@ -133,7 +133,6 @@ def scale_batch_size( init_val, max_trials, batch_arg_name, - **fit_kwargs, ) def lr_find( @@ -149,18 +148,15 @@ def lr_find( datamodule: Optional[LightningDataModule] = None, update_attr: bool = False, ): - self.setup_trainer(model, train_dataloader, val_dataloaders, datamodule) + #self.setup_trainer(model, train_dataloader, val_dataloaders, datamodule) return lr_find( self.trainer, model, - train_dataloader, - val_dataloaders, min_lr, max_lr, num_training, mode, early_stop_threshold, - datamodule, update_attr, ) From 405082b7c083779955f1d30a2315308f7ca9e690 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 28 Apr 2021 19:08:34 +0200 Subject: [PATCH 09/30] Progress --- pytorch_lightning/trainer/trainer.py | 15 +- pytorch_lightning/tuner/auto_gpu_select.py | 8 +- pytorch_lightning/tuner/batch_size_scaling.py | 38 ++- pytorch_lightning/tuner/lr_finder.py | 315 ++++++++---------- pytorch_lightning/tuner/tuning.py | 141 +++----- 5 files changed, 227 insertions(+), 290 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 84a3f122fa717..5eb28cef0275c 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -56,6 +56,7 @@ from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin +from pytorch_lightning.tuner.lr_finder import _LRFinder from pytorch_lightning.tuner.tuning import Tuner from pytorch_lightning.utilities import DeviceType, parsing, rank_zero_warn from pytorch_lightning.utilities.debugging import InternalDebugger @@ -1091,7 +1092,9 @@ def tune( train_dataloader: Optional[DataLoader] = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, datamodule: Optional[LightningDataModule] = None, - ) -> None: + scale_batch_size_kwargs: Optional[Dict[str, Any]] = None, + lr_find_kwargs: Optional[Dict[str, Any]] = None, + ) -> Optional[Union[int, _LRFinder]]: r""" Runs routines to tune hyperparameters before training. @@ -1105,6 +1108,10 @@ def tune( If the model has a predefined val_dataloaders method this will be skipped datamodule: A instance of :class:`LightningDataModule`. 
+ + scale_batch_size_kwargs: Arguments for :func:`~pytorch_lightning.tuner.lr_finder.lr_find` + + lr_find_kwargs: Arguments for :func:`~pytorch_lightning.tuner.batch_size_scaling.scale_batch_size` """ Trainer._log_api_event("tune") self.state = TrainerState.TUNING @@ -1115,11 +1122,15 @@ def tune( model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule ) - self.tuner.tune(model, train_dataloader, val_dataloaders, datamodule) + result = self.tuner.tune( + model, scale_batch_size_kwargs=scale_batch_size_kwargs or {}, lr_find_kwargs=lr_find_kwargs or {} + ) assert self.state.stopped self.tuning = False + return result + def call_setup_hook(self, model: LightningModule) -> None: assert self.state.running, f"TrainerState: {self.state}" state = self._setup_state diff --git a/pytorch_lightning/tuner/auto_gpu_select.py b/pytorch_lightning/tuner/auto_gpu_select.py index 3bd1ce52b52f4..8e0b5ad68b689 100644 --- a/pytorch_lightning/tuner/auto_gpu_select.py +++ b/pytorch_lightning/tuner/auto_gpu_select.py @@ -17,11 +17,11 @@ def pick_multiple_gpus(nb): - ''' + """ Raises: MisconfigurationException: If ``gpus`` is set to 0, when ``auto_select_gpus=True``. - ''' + """ if nb == 0: raise MisconfigurationException( r"auto_select_gpus=True, gpus=0 is not a valid configuration.\ @@ -38,11 +38,11 @@ def pick_multiple_gpus(nb): def pick_single_gpu(exclude_gpus: list): - ''' + """ Raises: RuntimeError: If you try to allocate a GPU, when no GPUs are available. - ''' + """ for i in range(torch.cuda.device_count()): if i in exclude_gpus: continue diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 53c158b8855e8..4b25587ab707e 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -15,7 +15,7 @@ import os from typing import Optional, Tuple -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.loggers.base import DummyLogger from pytorch_lightning.utilities import DeviceType, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem @@ -28,20 +28,21 @@ def scale_batch_size( - trainer, - model: LightningModule, + trainer: 'pl.Trainer', + model: 'pl.LightningModule', mode: str = 'power', steps_per_trial: int = 3, init_val: int = 2, max_trials: int = 25, batch_arg_name: str = 'batch_size', -): +) -> Optional[int]: r""" Will iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM) error. Args: trainer: The Trainer + model: Model to fit. mode: string setting the search mode. Either `power` or `binsearch`. @@ -52,7 +53,7 @@ def scale_batch_size( batch size that failed. steps_per_trial: number of steps to run with a given batch size. - Idealy 1 should be enough to test if a OOM error occurs, + Ideally 1 should be enough to test if a OOM error occurs, however in practise a few are needed init_val: initial batch size to start the search with @@ -69,9 +70,6 @@ def scale_batch_size( - ``model.datamodule`` - ``trainer.datamodule`` (the datamodule passed to the tune method) - **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader - or datamodule. 
- Raises: MisconfigurationException: If field ``batch_arg_name`` is not found in ``model`` and ``model.hparams``, or @@ -112,7 +110,7 @@ def scale_batch_size( trainer.progress_bar_callback.disable() # Initially we just double in size until an OOM is encountered - new_size = _adjust_batch_size(trainer, batch_arg_name, value=init_val) # initially set to init_val + new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val) # initially set to init_val if mode == 'power': new_size = _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials) elif mode == 'binsearch': @@ -138,7 +136,7 @@ def scale_batch_size( return new_size -def __scale_batch_dump_params(trainer): +def __scale_batch_dump_params(trainer: 'pl.Trainer') -> None: # Prevent going into infinite loop trainer.__dumped_params = { 'auto_lr_find': trainer.auto_lr_find, @@ -154,7 +152,7 @@ def __scale_batch_dump_params(trainer): } -def __scale_batch_reset_params(trainer, model, steps_per_trial): +def __scale_batch_reset_params(trainer: 'pl.Trainer', model: 'pl.LightningModule', steps_per_trial: int) -> None: trainer.auto_scale_batch_size = None # prevent recursion trainer.auto_lr_find = False # avoid lr find being called multiple times trainer.current_epoch = 0 @@ -167,7 +165,7 @@ def __scale_batch_reset_params(trainer, model, steps_per_trial): trainer.model = model # required for saving -def __scale_batch_restore_params(trainer): +def __scale_batch_restore_params(trainer: 'pl.Trainer') -> None: trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find'] trainer.current_epoch = trainer.__dumped_params['current_epoch'] trainer.max_steps = trainer.__dumped_params['max_steps'] @@ -180,9 +178,11 @@ def __scale_batch_restore_params(trainer): del trainer.__dumped_params -def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials): - """ Batch scaling mode where the size is doubled at each iteration until an - OOM error is encountered. """ +# TODO: new_size argument +def _run_power_scaling( + trainer: 'pl.Trainer', model: 'pl.LightningModule', new_size: int, batch_arg_name: str, max_trials: int +) -> int: + """ Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered. """ for _ in range(max_trials): garbage_collection_cuda() trainer.global_step = 0 # reset after each try @@ -199,14 +199,16 @@ def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials): new_size, _ = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc='failed') break else: - raise # some other error not memory related + raise exception # some other error not memory related if not changed: break return new_size -def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials): +def _run_binsearch_scaling( + trainer: 'pl.Trainer', model: 'pl.LightningModule', new_size: int, batch_arg_name: str, max_trials: int +) -> int: """ Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is encountered. 
Hereafter, the batch size is further refined using a binary search """ @@ -251,7 +253,7 @@ def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials) def _adjust_batch_size( - trainer, + trainer: 'pl.Trainer', batch_arg_name: str = 'batch_size', factor: float = 1.0, value: Optional[int] = None, diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index b2b2595294c8f..b93856e44361f 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -15,15 +15,15 @@ import logging import os from functools import wraps -from typing import Callable, Sequence +from typing import Callable, Optional, Sequence import numpy as np import torch from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler +import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers.base import DummyLogger from pytorch_lightning.utilities import DeviceType, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem @@ -40,7 +40,7 @@ log = logging.getLogger(__name__) -def _determine_lr_attr_name(trainer, model: LightningModule) -> str: +def _determine_lr_attr_name(trainer: 'pl.Trainer', model: 'pl.LightningModule') -> str: if isinstance(trainer.auto_lr_find, str): if not lightning_hasattr(model, trainer.auto_lr_find): raise MisconfigurationException( @@ -60,171 +60,6 @@ def _determine_lr_attr_name(trainer, model: LightningModule) -> str: ) -def lr_find( - trainer, - model: LightningModule, - min_lr: float = 1e-8, - max_lr: float = 1, - num_training: int = 100, - mode: str = 'exponential', - early_stop_threshold: float = 4.0, - update_attr: bool = False, -): - r""" - ``lr_find`` enables the user to do a range test of good initial learning rates, - to reduce the amount of guesswork in picking a good starting learning rate. - - Args: - model: Model to do range testing for - - train_dataloader: A PyTorch - ``DataLoader`` with training samples. If the model has - a predefined train_dataloader method, this will be skipped. - - min_lr: minimum learning rate to investigate - - max_lr: maximum learning rate to investigate - - num_training: number of learning rates to test - - mode: Search strategy to update learning rate after each batch: - - - ``'exponential'`` (default): Will increase the learning rate exponentially. - - ``'linear'``: Will increase the learning rate linearly. - - early_stop_threshold: threshold for stopping the search. If the - loss at any point is larger than early_stop_threshold*best_loss - then the search is stopped. To disable, set to None. - - datamodule: An optional ``LightningDataModule`` which holds the training - and validation dataloader(s). Note that the ``train_dataloader`` and - ``val_dataloaders`` parameters cannot be used at the same time as - this parameter, or a ``MisconfigurationException`` will be raised. - - update_attr: Whether to update the learning rate attribute or not. - - Raises: - MisconfigurationException: - If learning rate/lr in ``model`` or ``model.hparams`` isn't overriden when ``auto_lr_find=True``, or - if you are using `more than one optimizer` with learning rate finder. - - Example:: - - # Setup model and trainer - model = MyModelClass(hparams) - trainer = pl.Trainer() - - # Run lr finder - lr_finder = trainer.tuner.lr_find(model, ...) 
- - # Inspect results - fig = lr_finder.plot(); fig.show() - suggested_lr = lr_finder.suggestion() - - # Overwrite lr and create new model - hparams.lr = suggested_lr - model = MyModelClass(hparams) - - # Ready to train with new learning rate - trainer.fit(model) - - """ - if trainer.fast_dev_run: - rank_zero_warn('Skipping learning rate finder since fast_dev_run is enabled.', UserWarning) - return - - # Determine lr attr - if update_attr: - lr_attr_name = _determine_lr_attr_name(trainer, model) - - save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt') - - __lr_finder_dump_params(trainer, model) - - # Prevent going into infinite loop - trainer.auto_lr_find = False - - # Initialize lr finder object (stores results) - lr_finder = _LRFinder(mode, min_lr, max_lr, num_training) - - # Use special lr logger callback - trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)] - - # No logging - trainer.logger = DummyLogger() - - # Max step set to number of iterations - trainer.max_steps = num_training - - # Disable standard progress bar for fit - if trainer.progress_bar_callback: - trainer.progress_bar_callback.disable() - - # Required for saving the model - trainer.optimizers, trainer.schedulers = [], [], - trainer.model = model - - # Dump model checkpoint - trainer.save_checkpoint(str(save_path)) - - # Configure optimizer and scheduler - model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers) - - # Fit, lr & loss logged in callback - trainer.tuner._launch(model) - - # Prompt if we stopped early - if trainer.global_step != num_training: - log.info(f'LR finder stopped early after {trainer.global_step} steps due to diverging loss.') - - # Transfer results from callback to lr finder object - lr_finder.results.update({'lr': trainer.callbacks[0].lrs, 'loss': trainer.callbacks[0].losses}) - lr_finder._total_batch_idx = trainer.total_batch_idx # for debug purpose - - # Reset model state - if trainer.is_global_zero: - trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer._device_type == DeviceType.GPU) - fs = get_filesystem(str(save_path)) - if fs.exists(save_path): - fs.rm(save_path) - - # Finish by resetting variables so trainer is ready to fit model - __lr_finder_restore_params(trainer, model) - if trainer.progress_bar_callback: - trainer.progress_bar_callback.enable() - - # Update lr attr if required - if update_attr: - lr = lr_finder.suggestion() - - # TODO: log lr.results to self.logger - lightning_setattr(model, lr_attr_name, lr) - log.info(f'Learning rate set to {lr}') - - return lr_finder - - -def __lr_finder_dump_params(trainer, model): - # Prevent going into infinite loop - trainer.__dumped_params = { - 'auto_lr_find': trainer.auto_lr_find, - 'callbacks': trainer.callbacks, - 'logger': trainer.logger, - 'max_steps': trainer.max_steps, - 'checkpoint_callback': trainer.checkpoint_callback, - 'configure_optimizers': model.configure_optimizers, - } - - -def __lr_finder_restore_params(trainer, model): - trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find'] - trainer.logger = trainer.__dumped_params['logger'] - trainer.callbacks = trainer.__dumped_params['callbacks'] - trainer.max_steps = trainer.__dumped_params['max_steps'] - model.configure_optimizers = trainer.__dumped_params['configure_optimizers'] - del trainer.__dumped_params - - class _LRFinder(object): """ LR finder object. This object stores the results of Trainer.lr_find(). 
@@ -359,6 +194,143 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1): self._optimal_idx = None +def lr_find( + trainer: 'pl.Trainer', + model: 'pl.LightningModule', + min_lr: float = 1e-8, + max_lr: float = 1, + num_training: int = 100, + mode: str = 'exponential', + early_stop_threshold: float = 4.0, + update_attr: bool = False, +) -> Optional[_LRFinder]: + r""" + ``lr_find`` enables the user to do a range test of good initial learning rates, + to reduce the amount of guesswork in picking a good starting learning rate. + + Args: + trainer: The Trainer + + model: Model to do range testing for + + min_lr: minimum learning rate to investigate + + max_lr: maximum learning rate to investigate + + num_training: number of learning rates to test + + mode: Search strategy to update learning rate after each batch: + + - ``'exponential'`` (default): Will increase the learning rate exponentially. + - ``'linear'``: Will increase the learning rate linearly. + + early_stop_threshold: threshold for stopping the search. If the + loss at any point is larger than early_stop_threshold*best_loss + then the search is stopped. To disable, set to None. + + update_attr: Whether to update the learning rate attribute or not. + + Raises: + MisconfigurationException: + If learning rate/lr in ``model`` or ``model.hparams`` isn't overriden when ``auto_lr_find=True``, or + if you are using `more than one optimizer` with learning rate finder. + """ + if trainer.fast_dev_run: + rank_zero_warn('Skipping learning rate finder since fast_dev_run is enabled.', UserWarning) + return + + # Determine lr attr + if update_attr: + lr_attr_name = _determine_lr_attr_name(trainer, model) + + save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt') + + __lr_finder_dump_params(trainer, model) + + # Prevent going into infinite loop + trainer.auto_lr_find = False + + # Initialize lr finder object (stores results) + lr_finder = _LRFinder(mode, min_lr, max_lr, num_training) + + # Use special lr logger callback + trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)] + + # No logging + trainer.logger = DummyLogger() + + # Max step set to number of iterations + trainer.max_steps = num_training + + # Disable standard progress bar for fit + if trainer.progress_bar_callback: + trainer.progress_bar_callback.disable() + + # Required for saving the model + trainer.optimizers, trainer.schedulers = [], [], + trainer.model = model + + # Dump model checkpoint + trainer.save_checkpoint(str(save_path)) + + # Configure optimizer and scheduler + model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers) + + # Fit, lr & loss logged in callback + trainer.tuner._launch(model) + + # Prompt if we stopped early + if trainer.global_step != num_training: + log.info(f'LR finder stopped early after {trainer.global_step} steps due to diverging loss.') + + # Transfer results from callback to lr finder object + lr_finder.results.update({'lr': trainer.callbacks[0].lrs, 'loss': trainer.callbacks[0].losses}) + lr_finder._total_batch_idx = trainer.total_batch_idx # for debug purpose + + # Reset model state + if trainer.is_global_zero: + trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer._device_type == DeviceType.GPU) + fs = get_filesystem(str(save_path)) + if fs.exists(save_path): + fs.rm(save_path) + + # Finish by resetting variables so trainer is ready to fit model + __lr_finder_restore_params(trainer, model) + if trainer.progress_bar_callback: + 
trainer.progress_bar_callback.enable() + + # Update lr attr if required + if update_attr: + lr = lr_finder.suggestion() + + # TODO: log lr.results to self.logger + lightning_setattr(model, lr_attr_name, lr) + log.info(f'Learning rate set to {lr}') + + return lr_finder + + +def __lr_finder_dump_params(trainer, model): + # Prevent going into infinite loop + trainer.__dumped_params = { + 'auto_lr_find': trainer.auto_lr_find, + 'callbacks': trainer.callbacks, + 'logger': trainer.logger, + 'max_steps': trainer.max_steps, + 'checkpoint_callback': trainer.checkpoint_callback, + 'configure_optimizers': model.configure_optimizers, + } + + +def __lr_finder_restore_params(trainer, model): + trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find'] + trainer.logger = trainer.__dumped_params['logger'] + trainer.callbacks = trainer.__dumped_params['callbacks'] + trainer.max_steps = trainer.__dumped_params['max_steps'] + model.configure_optimizers = trainer.__dumped_params['configure_optimizers'] + del trainer.__dumped_params + + class _LRCallback(Callback): """ Special callback used by the learning rate finder. This callbacks log the learning rate before each batch and log the corresponding loss after @@ -434,9 +406,10 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data class _LinearLR(_LRScheduler): - """Linearly increases the learning rate between two boundaries - over a number of iterations. - Arguments: + """ + Linearly increases the learning rate between two boundaries over a number of iterations. + + Args: optimizer: wrapped optimizer. diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 74d3e9ac3be55..ae7a242fe9bba 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -11,66 +11,48 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Any, List, Optional, Union +from typing import Any, Dict, List, Optional, Union from torch.utils.data import DataLoader -from pytorch_lightning.core.datamodule import LightningDataModule -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.trainer.states import TrainerState -from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size -from pytorch_lightning.tuner.lr_finder import lr_find +from pytorch_lightning.tuner.lr_finder import _LRFinder, lr_find class Tuner: - def __init__(self, trainer): + def __init__(self, trainer: 'pl.Trainer') -> None: self.trainer = trainer - def on_trainer_init(self, auto_lr_find, auto_scale_batch_size): + def on_trainer_init(self, auto_lr_find: Union[str, bool], auto_scale_batch_size: Union[str, bool]) -> None: self.trainer.auto_lr_find = auto_lr_find self.trainer.auto_scale_batch_size = auto_scale_batch_size - def setup_trainer( + def tune( self, - model: LightningModule, - train_dataloader: Optional[DataLoader] = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - datamodule: LightningDataModule = None, - ): - self.trainer.model_connector.copy_trainer_model_properties(model) - # setup data, etc... 
- self.trainer.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule) - # hook - self.trainer.data_connector.prepare_data(model) + model: 'pl.LightningModule', + scale_batch_size_kwargs: Optional[Dict[str, Any]] = None, + lr_find_kwargs: Optional[Dict[str, Any]] = None, + ) -> Optional[Union[int, _LRFinder]]: + scale_batch_size_kwargs = scale_batch_size_kwargs or {} + lr_find_kwargs = lr_find_kwargs or {} + result = None - def tune(self, model, train_dataloader, val_dataloaders, datamodule): # Run auto batch size scaling if self.trainer.auto_scale_batch_size: - if isinstance(self.trainer.auto_scale_batch_size, bool): - self.trainer.auto_scale_batch_size = 'power' - self.scale_batch_size( - model, - mode=self.trainer.auto_scale_batch_size, - train_dataloader=train_dataloader, - val_dataloaders=val_dataloaders, - datamodule=datamodule, - ) + result = scale_batch_size(self.trainer, model, **scale_batch_size_kwargs) # Run learning rate finder: if self.trainer.auto_lr_find: - self.lr_find( - model, - update_attr=True, - train_dataloader=train_dataloader, - val_dataloaders=val_dataloaders, - datamodule=datamodule, - ) + lr_find_kwargs.setdefault('update_attr', True) + lr_find(self.trainer, model, **lr_find_kwargs) self.trainer.state = TrainerState.FINISHED + return result + def _launch(self, *args: Any, **kwargs: Any) -> None: """`_launch` wrapper to set the proper state during tuning, as this can be called multiple times""" self.trainer.state = TrainerState.TUNING # last `_launch` call might have set it to `FINISHED` @@ -80,7 +62,7 @@ def _launch(self, *args: Any, **kwargs: Any) -> None: def scale_batch_size( self, - model, + model: 'pl.LightningModule', mode: str = 'power', steps_per_trial: int = 3, init_val: int = 2, @@ -88,56 +70,23 @@ def scale_batch_size( batch_arg_name: str = 'batch_size', **fit_kwargs ): - r""" - Will iteratively try to find the largest batch size for a given model - that does not give an out of memory (OOM) error. - - Args: - model: Model to fit. - - mode: string setting the search mode. Either `power` or `binsearch`. - If mode is `power` we keep multiplying the batch size by 2, until - we get an OOM error. If mode is 'binsearch', we will initially - also keep multiplying by 2 and after encountering an OOM error - do a binary search between the last successful batch size and the - batch size that failed. - - steps_per_trial: number of steps to run with a given batch size. - Idealy 1 should be enough to test if a OOM error occurs, - however in practise a few are needed - - init_val: initial batch size to start the search with - - max_trials: max number of increase in batch size done before - algorithm is terminated - - batch_arg_name: name of the attribute that stores the batch size. - It is expected that the user has provided a model or datamodule that has a hyperparameter - with that name. We will look for this attribute name in the following places - - - ``model`` - - ``model.hparams`` - - ``model.datamodule`` - - ``trainer.datamodule`` (the datamodule passed to the tune method) - - **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader - or datamodule. 
- - """ - #self.setup_trainer(model, **fit_kwargs) - return scale_batch_size( - self.trainer, + # TODO: deprecate + self.trainer.auto_lr_find = True + return self.trainer.tune( model, - mode, - steps_per_trial, - init_val, - max_trials, - batch_arg_name, + **fit_kwargs, + scale_batch_size_kwargs={ + 'mode': mode, + 'steps_per_trial': steps_per_trial, + 'init_val': init_val, + 'max_trials': max_trials, + 'batch_arg_name': batch_arg_name, + } ) def lr_find( self, - model: LightningModule, + model: 'pl.LightningModule', train_dataloader: Optional[DataLoader] = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, min_lr: float = 1e-8, @@ -145,20 +94,22 @@ def lr_find( num_training: int = 100, mode: str = 'exponential', early_stop_threshold: float = 4.0, - datamodule: Optional[LightningDataModule] = None, + datamodule: Optional['pl.LightningDataModule'] = None, update_attr: bool = False, ): - #self.setup_trainer(model, train_dataloader, val_dataloaders, datamodule) - return lr_find( - self.trainer, + # TODO: deprecate + self.trainer.auto_scale_batch_size = True + return self.trainer.tune( model, - min_lr, - max_lr, - num_training, - mode, - early_stop_threshold, - update_attr, + train_dataloader=train_dataloader, + val_dataloaders=val_dataloaders, + datamodule=datamodule, + lr_find_kwargs={ + 'min_lr': min_lr, + 'max_lr': max_lr, + 'num_training': num_training, + 'mode': mode, + 'early_stop_threshold': early_stop_threshold, + 'update_attr': update_attr + } ) - - def pick_multiple_gpus(self, num_gpus: int): - return pick_multiple_gpus(num_gpus) From c2553887fe8ff6e1c6b3181a9598d69b4d1f0e19 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 28 Apr 2021 19:50:53 +0200 Subject: [PATCH 10/30] Deprecation messages --- pytorch_lightning/tuner/batch_size_scaling.py | 1 - pytorch_lightning/tuner/tuning.py | 11 +++++++++-- tests/deprecated_api/test_remove_1-5.py | 14 ++++++++++++++ 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 4b25587ab707e..98de18906cfba 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -178,7 +178,6 @@ def __scale_batch_restore_params(trainer: 'pl.Trainer') -> None: del trainer.__dumped_params -# TODO: new_size argument def _run_power_scaling( trainer: 'pl.Trainer', model: 'pl.LightningModule', new_size: int, batch_arg_name: str, max_trials: int ) -> int: diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index ae7a242fe9bba..55160467929eb 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -19,6 +19,7 @@ from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size from pytorch_lightning.tuner.lr_finder import _LRFinder, lr_find +from pytorch_lightning.utilities import rank_zero_deprecation class Tuner: @@ -70,7 +71,10 @@ def scale_batch_size( batch_arg_name: str = 'batch_size', **fit_kwargs ): - # TODO: deprecate + rank_zero_deprecation( + "`Tuner.scale_batch_size()` is deprecated in v1.3 and will be removed in v1.5." + " Please use `trainer.tune(scale_batch_size_kwargs={...})` instead." 
+ ) self.trainer.auto_lr_find = True return self.trainer.tune( model, @@ -97,7 +101,10 @@ def lr_find( datamodule: Optional['pl.LightningDataModule'] = None, update_attr: bool = False, ): - # TODO: deprecate + rank_zero_deprecation( + "`Tuner.lr_find()` is deprecated in v1.3 and will be removed in v1.5." + " Please use `trainer.tune(lr_finder_kwargs={...})` instead." + ) self.trainer.auto_scale_batch_size = True return self.trainer.tune( model, diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 6516fbcc18639..0c076dd84e083 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -294,3 +294,17 @@ def test_v1_5_0_trainer_logging_mixin(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, checkpoint_callback=False, logger=False) with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"): trainer.metrics_to_scalars({}) + + +def test_v1_5_0_tuner_scale_batch_size(): + trainer = Trainer(fast_dev_run=True) + model = BoringModel() + with pytest.deprecated_call(match=r"scale_batch_size\(\)` is deprecated in v1.3 and will be removed in v1.5"): + trainer.tuner.scale_batch_size(model) + + +def test_v1_5_0_tuner_lr_find(): + trainer = Trainer(fast_dev_run=True) + model = BoringModel() + with pytest.deprecated_call(match=r"lr_find\(\)` is deprecated in v1.3 and will be removed in v1.5"): + trainer.tuner.lr_find(model) From d2a54d6edf5e9e8631dc36b2b4792acbbf3b44fe Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 28 Apr 2021 23:42:13 +0200 Subject: [PATCH 11/30] Fixes --- pytorch_lightning/tuner/batch_size_scaling.py | 2 +- pytorch_lightning/tuner/tuning.py | 18 +++++++++++------- tests/trainer/test_trainer_tricks.py | 15 +++------------ tests/tuner/test_lr_finder.py | 17 ++++------------- 4 files changed, 19 insertions(+), 33 deletions(-) diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 98de18906cfba..1ce318d9da965 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -198,7 +198,7 @@ def _run_power_scaling( new_size, _ = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc='failed') break else: - raise exception # some other error not memory related + raise # some other error not memory related if not changed: break diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 55160467929eb..656727cd8204c 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -48,7 +48,7 @@ def tune( # Run learning rate finder: if self.trainer.auto_lr_find: lr_find_kwargs.setdefault('update_attr', True) - lr_find(self.trainer, model, **lr_find_kwargs) + result = lr_find(self.trainer, model, **lr_find_kwargs) self.trainer.state = TrainerState.FINISHED @@ -70,13 +70,13 @@ def scale_batch_size( max_trials: int = 25, batch_arg_name: str = 'batch_size', **fit_kwargs - ): + ) -> Optional[int]: rank_zero_deprecation( "`Tuner.scale_batch_size()` is deprecated in v1.3 and will be removed in v1.5." " Please use `trainer.tune(scale_batch_size_kwargs={...})` instead." 
) - self.trainer.auto_lr_find = True - return self.trainer.tune( + self.trainer.auto_scale_batch_size = True + result = self.trainer.tune( model, **fit_kwargs, scale_batch_size_kwargs={ @@ -87,6 +87,8 @@ def scale_batch_size( 'batch_arg_name': batch_arg_name, } ) + self.trainer.auto_scale_batch_size = False + return result def lr_find( self, @@ -100,13 +102,13 @@ def lr_find( early_stop_threshold: float = 4.0, datamodule: Optional['pl.LightningDataModule'] = None, update_attr: bool = False, - ): + ) -> Optional[_LRFinder]: rank_zero_deprecation( "`Tuner.lr_find()` is deprecated in v1.3 and will be removed in v1.5." " Please use `trainer.tune(lr_finder_kwargs={...})` instead." ) - self.trainer.auto_scale_batch_size = True - return self.trainer.tune( + self.trainer.auto_lr_find = True + result = self.trainer.tune( model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, @@ -120,3 +122,5 @@ def lr_find( 'update_attr': update_attr } ) + self.trainer.auto_lr_find = False + return result diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py index 7206d225ab5cd..a7659e3ff1f72 100644 --- a/tests/trainer/test_trainer_tricks.py +++ b/tests/trainer/test_trainer_tricks.py @@ -212,20 +212,11 @@ def test_trainer_reset_correctly(tmpdir): 'limit_train_batches', 'current_epoch', ] - - attributes_before = {} - for ca in changed_attributes: - attributes_before[ca] = getattr(trainer, ca) - + expected = {ca: getattr(trainer, ca) for ca in changed_attributes} trainer.tuner.scale_batch_size(model, max_trials=5) + actual = {ca: getattr(trainer, ca) for ca in changed_attributes} - attributes_after = {} - for ca in changed_attributes: - attributes_after[ca] = getattr(trainer, ca) - - for key in changed_attributes: - assert attributes_before[key] == attributes_after[key], \ - f'Attribute {key} was not reset correctly after learning rate finder' + assert actual == expected @RunIf(min_gpus=1) diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py index 9834c1c8ad09b..e6b530752407f 100644 --- a/tests/tuner/test_lr_finder.py +++ b/tests/tuner/test_lr_finder.py @@ -17,7 +17,7 @@ import pytest import torch -from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate from tests.helpers import BoringModel @@ -79,20 +79,11 @@ def test_trainer_reset_correctly(tmpdir): changed_attributes = [ 'callbacks', 'logger', 'max_steps', 'auto_lr_find', 'accumulate_grad_batches', 'checkpoint_callback' ] - attributes_before = {} - for ca in changed_attributes: - attributes_before[ca] = getattr(trainer, ca) - + expected = {ca: getattr(trainer, ca) for ca in changed_attributes} _ = trainer.tuner.lr_find(model, num_training=5) + actual = {ca: getattr(trainer, ca) for ca in changed_attributes} - attributes_after = {} - for ca in changed_attributes: - attributes_after[ca] = getattr(trainer, ca) - - for key in changed_attributes: - assert attributes_before[key] == attributes_after[key], \ - f'Attribute {key} was not reset correctly after learning rate finder' - + assert actual == expected assert model.trainer == trainer From 8641504f1eca0a0a2cf65d82b06c8382974a88ba Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 29 Apr 2021 00:11:20 +0200 Subject: [PATCH 12/30] Move `copy_trainer_model_properties` --- pytorch_lightning/trainer/connectors/data_connector.py | 2 ++ 
pytorch_lightning/trainer/connectors/model_connector.py | 5 ----- pytorch_lightning/trainer/predict_loop.py | 4 ---- pytorch_lightning/trainer/trainer.py | 3 --- 4 files changed, 2 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index e4e8de96e4153..d5cc17a040320 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -96,6 +96,8 @@ def attach_data( predict_dataloaders=predict_dataloaders, ) self.attach_datamodule(model, datamodule=datamodule) + # set local properties on the model + self.trainer.model_connector.copy_trainer_model_properties(model) def attach_dataloaders( self, diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index 23f8d36a7ba83..d4bdedd31e0f4 100644 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -11,11 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Root module for all distributed operations in Lightning. -Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU. - -""" from weakref import proxy diff --git a/pytorch_lightning/trainer/predict_loop.py b/pytorch_lightning/trainer/predict_loop.py index 4815987e26240..9674921f103ae 100644 --- a/pytorch_lightning/trainer/predict_loop.py +++ b/pytorch_lightning/trainer/predict_loop.py @@ -77,10 +77,6 @@ def on_predict_model_eval(self): model_ref.on_predict_model_eval() def setup(self, model, max_batches, dataloaders): - - # copy properties for forward overrides - self.trainer.model_connector.copy_trainer_model_properties(model) - # convert max_batches to list if isinstance(max_batches, int): max_batches = [max_batches] * len(dataloaders) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f76c899250047..fc792923d79d8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -411,9 +411,6 @@ def __init__( self.on_init_end() def _launch(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: - # set local properties on the model - self.model_connector.copy_trainer_model_properties(model) - # clean hparams if hasattr(model, "hparams"): parsing.clean_namespace(model.hparams) From d8b0bf57258a6c1e62020620c2dc8daaf09fa27a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 29 Apr 2021 00:40:49 +0200 Subject: [PATCH 13/30] Code cleaning in preparation for 7258 --- .../trainer/configuration_validator.py | 30 +- .../trainer/connectors/data_connector.py | 18 +- .../trainer/connectors/model_connector.py | 5 - pytorch_lightning/trainer/predict_loop.py | 6 +- pytorch_lightning/trainer/trainer.py | 5 +- pytorch_lightning/tuner/auto_gpu_select.py | 8 +- pytorch_lightning/tuner/batch_size_scaling.py | 34 +- pytorch_lightning/tuner/lr_finder.py | 290 +++++++++--------- pytorch_lightning/tuner/tuning.py | 21 +- tests/trainer/test_trainer_tricks.py | 212 ------------- tests/tuner/test_lr_finder.py | 17 +- tests/tuner/test_scale_batch_size.py | 217 +++++++++++++ 12 files changed, 427 insertions(+), 436 deletions(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 
55b4ea7fe7692..215fd1353e3f0 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -11,19 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden -class ConfigValidator(object): +class ConfigValidator: - def __init__(self, trainer): + def __init__(self, trainer: 'pl.Trainer') -> None: self.trainer = trainer - def verify_loop_configurations(self, model: LightningModule) -> None: + def verify_loop_configurations(self, model: 'pl.LightningModule') -> None: r""" Checks that the model is configured correctly before the run is started. @@ -31,19 +31,18 @@ def verify_loop_configurations(self, model: LightningModule) -> None: model: The model to check the configuration. """ - if self.trainer.state == TrainerState.FITTING: + if self.trainer.state in (TrainerState.FITTING, TrainerState.TUNING): self.__verify_train_loop_configuration(model) self.__verify_eval_loop_configuration(model, 'val') - elif self.trainer.state == TrainerState.TUNING: - self.__verify_train_loop_configuration(model) elif self.trainer.state == TrainerState.VALIDATING: self.__verify_eval_loop_configuration(model, 'val') elif self.trainer.state == TrainerState.TESTING: self.__verify_eval_loop_configuration(model, 'test') elif self.trainer.state == TrainerState.PREDICTING: self.__verify_predict_loop_configuration(model) + self.__verify_dp_batch_transfer_support(model) - def __verify_train_loop_configuration(self, model): + def __verify_train_loop_configuration(self, model: 'pl.LightningModule') -> None: # ----------------------------------- # verify model has a training step # ----------------------------------- @@ -82,14 +81,14 @@ def __verify_train_loop_configuration(self, model): going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches() has_overriden_optimization_functions = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad - if (has_overriden_optimization_functions) and going_to_accumulate_grad_batches and automatic_optimization: + if has_overriden_optimization_functions and going_to_accumulate_grad_batches and automatic_optimization: raise MisconfigurationException( 'When overriding `LightningModule` optimizer_step or optimizer_zero_grad,' ' `accumulate_grad_batches` in `Trainer` should be 1.' ' It ensures optimizer_step or optimizer_zero_grad are called on every batch.' ) - def __verify_eval_loop_configuration(self, model: LightningModule, stage: str) -> None: + def __verify_eval_loop_configuration(self, model: 'pl.LightningModule', stage: str) -> None: loader_name = f'{stage}_dataloader' step_name = 'validation_step' if stage == 'val' else 'test_step' @@ -101,8 +100,15 @@ def __verify_eval_loop_configuration(self, model: LightningModule, stage: str) - if has_step and not has_loader: rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. 
Skipping {stage} loop') - def __verify_predict_loop_configuration(self, model: LightningModule) -> None: - + def __verify_predict_loop_configuration(self, model: 'pl.LightningModule') -> None: has_predict_dataloader = is_overridden('predict_dataloader', model) if not has_predict_dataloader: raise MisconfigurationException('Dataloader not found for `Trainer.predict`') + + def __verify_dp_batch_transfer_support(self, model: 'pl.LightningModule') -> None: + """Raise Misconfiguration exception since these hooks are not supported in DP mode""" + # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode. + batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer') + for hook in batch_transfer_hooks: + if self.trainer.accelerator_connector.use_dp and is_overridden(hook, model): + raise MisconfigurationException(f'Overriding `{hook}` is not supported in DP mode.') diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 5d2f141dc64a8..fd6c9ea32891c 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -16,6 +16,7 @@ from torch.utils.data import DataLoader +import pytorch_lightning as pl from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -89,7 +90,6 @@ def attach_data(self, model, train_dataloader, val_dataloaders, datamodule): # set up the passed in dataloaders (if needed) self.attach_dataloaders(model, train_dataloader, val_dataloaders) self.attach_datamodule(model, datamodule) - self._validate_data_hooks(model) def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule): # If you supply a datamodule you can't supply train_dataloader or val_dataloaders @@ -98,22 +98,14 @@ def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloa 'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule' ) - def _validate_data_hooks(self, model): - # Raise Misconfiguration exception since these hooks are not supported in DP mode - # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode. 
- batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer') - for hook in batch_transfer_hooks: - if self.trainer.accelerator_connector.use_dp and is_overridden(hook, model): - raise MisconfigurationException(f'Overriding `{hook}` is not supported in DP mode.') - def attach_dataloaders( self, - model, + model: 'pl.LightningModule', train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - ): + ) -> None: # when dataloader is passed via fit, patch the train_dataloader # functions to overwrite with these implementations if train_dataloader is not None: @@ -128,7 +120,9 @@ def attach_dataloaders( if predict_dataloaders is not None: model.predict_dataloader = _PatchDataLoader(predict_dataloaders) - def attach_datamodule(self, model, datamodule: Optional[LightningDataModule] = None) -> None: + def attach_datamodule( + self, model: 'pl.LightningModule', datamodule: Optional['pl.LightningDataModule'] = None + ) -> None: # We use datamodule if it's been provided, otherwise we check model for it datamodule = datamodule or getattr(model, 'datamodule', None) diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index 23f8d36a7ba83..d4bdedd31e0f4 100644 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -11,11 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Root module for all distributed operations in Lightning. -Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU. - -""" from weakref import proxy diff --git a/pytorch_lightning/trainer/predict_loop.py b/pytorch_lightning/trainer/predict_loop.py index 4815987e26240..fb1ad3b054c9e 100644 --- a/pytorch_lightning/trainer/predict_loop.py +++ b/pytorch_lightning/trainer/predict_loop.py @@ -76,11 +76,7 @@ def on_predict_model_eval(self): model_ref = self.trainer.lightning_module model_ref.on_predict_model_eval() - def setup(self, model, max_batches, dataloaders): - - # copy properties for forward overrides - self.trainer.model_connector.copy_trainer_model_properties(model) - + def setup(self, max_batches, dataloaders): # convert max_batches to list if isinstance(max_batches, int): max_batches = [max_batches] * len(dataloaders) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1177c5f4ace7e..a2a7da13985f8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -775,7 +775,7 @@ def run_predict(self) -> Optional[_PREDICT_OUTPUT]: return [] # set up the eval loop - self.predict_loop.setup(self.lightning_module, max_batches, dataloaders) + self.predict_loop.setup(max_batches, dataloaders) # call hook self.predict_loop.on_predict_start() @@ -1086,8 +1086,6 @@ def tune( Runs routines to tune hyperparameters before training. Args: - datamodule: A instance of :class:`LightningDataModule`. - model: Model to tune. train_dataloader: A Pytorch DataLoader with training samples. 
If the model has @@ -1096,6 +1094,7 @@ def tune( val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. If the model has a predefined val_dataloaders method this will be skipped + datamodule: A instance of :class:`LightningDataModule`. """ Trainer._log_api_event("tune") self.state = TrainerState.TUNING diff --git a/pytorch_lightning/tuner/auto_gpu_select.py b/pytorch_lightning/tuner/auto_gpu_select.py index 3bd1ce52b52f4..8e0b5ad68b689 100644 --- a/pytorch_lightning/tuner/auto_gpu_select.py +++ b/pytorch_lightning/tuner/auto_gpu_select.py @@ -17,11 +17,11 @@ def pick_multiple_gpus(nb): - ''' + """ Raises: MisconfigurationException: If ``gpus`` is set to 0, when ``auto_select_gpus=True``. - ''' + """ if nb == 0: raise MisconfigurationException( r"auto_select_gpus=True, gpus=0 is not a valid configuration.\ @@ -38,11 +38,11 @@ def pick_multiple_gpus(nb): def pick_single_gpu(exclude_gpus: list): - ''' + """ Raises: RuntimeError: If you try to allocate a GPU, when no GPUs are available. - ''' + """ for i in range(torch.cuda.device_count()): if i in exclude_gpus: continue diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 7e9dc524099de..45b0ac426e803 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -15,7 +15,7 @@ import os from typing import Optional, Tuple -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.loggers.base import DummyLogger from pytorch_lightning.utilities import DeviceType, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem @@ -28,21 +28,22 @@ def scale_batch_size( - trainer, - model: LightningModule, + trainer: 'pl.Trainer', + model: 'pl.LightningModule', mode: str = 'power', steps_per_trial: int = 3, init_val: int = 2, max_trials: int = 25, batch_arg_name: str = 'batch_size', **fit_kwargs -): +) -> Optional[int]: r""" Will iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM) error. Args: trainer: The Trainer + model: Model to fit. mode: string setting the search mode. Either `power` or `binsearch`. @@ -53,7 +54,7 @@ def scale_batch_size( batch size that failed. steps_per_trial: number of steps to run with a given batch size. 
- Idealy 1 should be enough to test if a OOM error occurs, + Ideally 1 should be enough to test if a OOM error occurs, however in practise a few are needed init_val: initial batch size to start the search with @@ -113,7 +114,7 @@ def scale_batch_size( trainer.progress_bar_callback.disable() # Initially we just double in size until an OOM is encountered - new_size = _adjust_batch_size(trainer, batch_arg_name, value=init_val) # initially set to init_val + new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val) # initially set to init_val if mode == 'power': new_size = _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs) elif mode == 'binsearch': @@ -139,7 +140,7 @@ def scale_batch_size( return new_size -def __scale_batch_dump_params(trainer): +def __scale_batch_dump_params(trainer: 'pl.Trainer') -> None: # Prevent going into infinite loop trainer.__dumped_params = { 'auto_lr_find': trainer.auto_lr_find, @@ -155,7 +156,7 @@ def __scale_batch_dump_params(trainer): } -def __scale_batch_reset_params(trainer, model, steps_per_trial): +def __scale_batch_reset_params(trainer: 'pl.Trainer', model: 'pl.LightningModule', steps_per_trial: int) -> None: trainer.auto_scale_batch_size = None # prevent recursion trainer.auto_lr_find = False # avoid lr find being called multiple times trainer.current_epoch = 0 @@ -168,7 +169,7 @@ def __scale_batch_reset_params(trainer, model, steps_per_trial): trainer.model = model # required for saving -def __scale_batch_restore_params(trainer): +def __scale_batch_restore_params(trainer: 'pl.Trainer') -> None: trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find'] trainer.current_epoch = trainer.__dumped_params['current_epoch'] trainer.max_steps = trainer.__dumped_params['max_steps'] @@ -181,9 +182,11 @@ def __scale_batch_restore_params(trainer): del trainer.__dumped_params -def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs): - """ Batch scaling mode where the size is doubled at each iteration until an - OOM error is encountered. """ +def _run_power_scaling( + trainer: 'pl.Trainer', model: 'pl.LightningModule', new_size: int, batch_arg_name: str, max_trials: int, + **fit_kwargs +) -> int: + """ Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered. """ for _ in range(max_trials): garbage_collection_cuda() trainer.global_step = 0 # reset after each try @@ -207,7 +210,10 @@ def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **f return new_size -def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs): +def _run_binsearch_scaling( + trainer: 'pl.Trainer', model: 'pl.LightningModule', new_size: int, batch_arg_name: str, max_trials: int, + **fit_kwargs +) -> int: """ Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is encountered. 
Hereafter, the batch size is further refined using a binary search """ @@ -252,7 +258,7 @@ def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, def _adjust_batch_size( - trainer, + trainer: 'pl.Trainer', batch_arg_name: str = 'batch_size', factor: float = 1.0, value: Optional[int] = None, diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 14f21da856145..df51637dc9520 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -23,9 +23,8 @@ from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader +import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback -from pytorch_lightning.core.datamodule import LightningDataModule -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers.base import DummyLogger from pytorch_lightning.utilities import DeviceType, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem @@ -42,7 +41,7 @@ log = logging.getLogger(__name__) -def _determine_lr_attr_name(trainer, model: LightningModule) -> str: +def _determine_lr_attr_name(trainer: 'pl.Trainer', model: 'pl.LightningModule') -> str: if isinstance(trainer.auto_lr_find, str): if not lightning_hasattr(model, trainer.auto_lr_find): raise MisconfigurationException( @@ -62,9 +61,143 @@ def _determine_lr_attr_name(trainer, model: LightningModule) -> str: ) +class _LRFinder(object): + """ LR finder object. This object stores the results of Trainer.lr_find(). + + Args: + mode: either `linear` or `exponential`, how to increase lr after each step + + lr_min: lr to start search from + + lr_max: lr to stop search + + num_training: number of steps to take between lr_min and lr_max + + Example:: + # Run lr finder + lr_finder = trainer.lr_find(model) + + # Results stored in + lr_finder.results + + # Plot using + lr_finder.plot() + + # Get suggestion + lr = lr_finder.suggestion() + """ + + def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int): + assert mode in ('linear', 'exponential'), \ + 'mode should be either `linear` or `exponential`' + + self.mode = mode + self.lr_min = lr_min + self.lr_max = lr_max + self.num_training = num_training + + self.results = {} + self._total_batch_idx = 0 # for debug purpose + + def _exchange_scheduler(self, configure_optimizers: Callable): + """ Decorate configure_optimizers methods such that it returns the users + originally specified optimizer together with a new scheduler that + that takes care of the learning rate search. 
+ """ + + @wraps(configure_optimizers) + def func(): + # Decide the structure of the output from configure_optimizers + # Same logic as method `init_optimizers` in trainer/optimizers.py + optim_conf = configure_optimizers() + if isinstance(optim_conf, Optimizer): + optimizers = [optim_conf] + elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 \ + and isinstance(optim_conf[0], list): + optimizers, _ = optim_conf + elif isinstance(optim_conf, dict): + optimizers = [optim_conf["optimizer"]] + elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict): + optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf] + elif isinstance(optim_conf, (list, tuple)): + optimizers = [optim_conf] + + if len(optimizers) != 1: + raise MisconfigurationException( + f'`model.configure_optimizers()` returned {len(optimizers)}, but' + ' learning rate finder only works with single optimizer' + ) + + optimizer = optimizers[0] + + new_lrs = [self.lr_min] * len(optimizer.param_groups) + for param_group, new_lr in zip(optimizer.param_groups, new_lrs): + param_group["lr"] = new_lr + param_group["initial_lr"] = new_lr + + args = (optimizer, self.lr_max, self.num_training) + scheduler = _LinearLR(*args) if self.mode == 'linear' else _ExponentialLR(*args) + + return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}] + + return func + + def plot(self, suggest: bool = False, show: bool = False): + """ Plot results from lr_find run + Args: + suggest: if True, will mark suggested lr to use with a red point + + show: if True, will show figure + """ + import matplotlib.pyplot as plt + + lrs = self.results["lr"] + losses = self.results["loss"] + + fig, ax = plt.subplots() + + # Plot loss as a function of the learning rate + ax.plot(lrs, losses) + if self.mode == 'exponential': + ax.set_xscale("log") + ax.set_xlabel("Learning rate") + ax.set_ylabel("Loss") + + if suggest: + _ = self.suggestion() + if self._optimal_idx: + ax.plot(lrs[self._optimal_idx], losses[self._optimal_idx], markersize=10, marker='o', color='red') + + if show: + plt.show() + + return fig + + def suggestion(self, skip_begin: int = 10, skip_end: int = 1): + """ This will propose a suggestion for choice of initial learning rate + as the point with the steepest negative gradient. + + Returns: + lr: suggested initial learning rate to use + skip_begin: how many samples to skip in the beginning. Prevent too naive estimates + skip_end: how many samples to skip in the end. Prevent too optimistic estimates + + """ + try: + loss = np.array(self.results["loss"][skip_begin:-skip_end]) + loss = loss[np.isfinite(loss)] + min_grad = np.gradient(loss).argmin() + self._optimal_idx = min_grad + skip_begin + return self.results["lr"][self._optimal_idx] + # todo: specify the possible exception + except Exception: + log.exception('Failed to compute suggesting for `lr`. 
There might not be enough points.') + self._optimal_idx = None + + def lr_find( - trainer, - model: LightningModule, + trainer: 'pl.Trainer', + model: 'pl.LightningModule', train_dataloader: Optional[DataLoader] = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, min_lr: float = 1e-8, @@ -72,14 +205,16 @@ def lr_find( num_training: int = 100, mode: str = 'exponential', early_stop_threshold: float = 4.0, - datamodule: Optional[LightningDataModule] = None, + datamodule: Optional['pl.LightningDataModule'] = None, update_attr: bool = False, -): +) -> Optional[_LRFinder]: r""" ``lr_find`` enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in picking a good starting learning rate. Args: + trainer: The Trainer + model: Model to do range testing for train_dataloader: A PyTorch @@ -232,140 +367,6 @@ def __lr_finder_restore_params(trainer, model): del trainer.__dumped_params -class _LRFinder(object): - """ LR finder object. This object stores the results of Trainer.lr_find(). - - Args: - mode: either `linear` or `exponential`, how to increase lr after each step - - lr_min: lr to start search from - - lr_max: lr to stop search - - num_training: number of steps to take between lr_min and lr_max - - Example:: - # Run lr finder - lr_finder = trainer.lr_find(model) - - # Results stored in - lr_finder.results - - # Plot using - lr_finder.plot() - - # Get suggestion - lr = lr_finder.suggestion() - """ - - def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int): - assert mode in ('linear', 'exponential'), \ - 'mode should be either `linear` or `exponential`' - - self.mode = mode - self.lr_min = lr_min - self.lr_max = lr_max - self.num_training = num_training - - self.results = {} - self._total_batch_idx = 0 # for debug purpose - - def _exchange_scheduler(self, configure_optimizers: Callable): - """ Decorate configure_optimizers methods such that it returns the users - originally specified optimizer together with a new scheduler that - that takes care of the learning rate search. 
- """ - - @wraps(configure_optimizers) - def func(): - # Decide the structure of the output from configure_optimizers - # Same logic as method `init_optimizers` in trainer/optimizers.py - optim_conf = configure_optimizers() - if isinstance(optim_conf, Optimizer): - optimizers = [optim_conf] - elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 \ - and isinstance(optim_conf[0], list): - optimizers, _ = optim_conf - elif isinstance(optim_conf, dict): - optimizers = [optim_conf["optimizer"]] - elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict): - optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf] - elif isinstance(optim_conf, (list, tuple)): - optimizers = [optim_conf] - - if len(optimizers) != 1: - raise MisconfigurationException( - f'`model.configure_optimizers()` returned {len(optimizers)}, but' - ' learning rate finder only works with single optimizer' - ) - - optimizer = optimizers[0] - - new_lrs = [self.lr_min] * len(optimizer.param_groups) - for param_group, new_lr in zip(optimizer.param_groups, new_lrs): - param_group["lr"] = new_lr - param_group["initial_lr"] = new_lr - - args = (optimizer, self.lr_max, self.num_training) - scheduler = _LinearLR(*args) if self.mode == 'linear' else _ExponentialLR(*args) - - return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}] - - return func - - def plot(self, suggest: bool = False, show: bool = False): - """ Plot results from lr_find run - Args: - suggest: if True, will mark suggested lr to use with a red point - - show: if True, will show figure - """ - import matplotlib.pyplot as plt - - lrs = self.results["lr"] - losses = self.results["loss"] - - fig, ax = plt.subplots() - - # Plot loss as a function of the learning rate - ax.plot(lrs, losses) - if self.mode == 'exponential': - ax.set_xscale("log") - ax.set_xlabel("Learning rate") - ax.set_ylabel("Loss") - - if suggest: - _ = self.suggestion() - if self._optimal_idx: - ax.plot(lrs[self._optimal_idx], losses[self._optimal_idx], markersize=10, marker='o', color='red') - - if show: - plt.show() - - return fig - - def suggestion(self, skip_begin: int = 10, skip_end: int = 1): - """ This will propose a suggestion for choice of initial learning rate - as the point with the steepest negative gradient. - - Returns: - lr: suggested initial learning rate to use - skip_begin: how many samples to skip in the beginning. Prevent too naive estimates - skip_end: how many samples to skip in the end. Prevent too optimistic estimates - - """ - try: - loss = np.array(self.results["loss"][skip_begin:-skip_end]) - loss = loss[np.isfinite(loss)] - min_grad = np.gradient(loss).argmin() - self._optimal_idx = min_grad + skip_begin - return self.results["lr"][self._optimal_idx] - # todo: specify the possible exception - except Exception: - log.exception('Failed to compute suggesting for `lr`. There might not be enough points.') - self._optimal_idx = None - - class _LRCallback(Callback): """ Special callback used by the learning rate finder. This callbacks log the learning rate before each batch and log the corresponding loss after @@ -441,9 +442,10 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data class _LinearLR(_LRScheduler): - """Linearly increases the learning rate between two boundaries - over a number of iterations. - Arguments: + """ + Linearly increases the learning rate between two boundaries over a number of iterations. + + Args: optimizer: wrapped optimizer. 
diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 9d471e2c5cbca..9822008f07a4f 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -16,20 +16,20 @@ from torch.utils.data import DataLoader +import pytorch_lightning as pl from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.trainer.states import TrainerState -from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size -from pytorch_lightning.tuner.lr_finder import lr_find +from pytorch_lightning.tuner.lr_finder import _LRFinder, lr_find class Tuner: - def __init__(self, trainer): + def __init__(self, trainer: 'pl.Trainer') -> None: self.trainer = trainer - def on_trainer_init(self, auto_lr_find, auto_scale_batch_size): + def on_trainer_init(self, auto_lr_find: Union[str, bool], auto_scale_batch_size: Union[str, bool]) -> None: self.trainer.auto_lr_find = auto_lr_find self.trainer.auto_scale_batch_size = auto_scale_batch_size @@ -80,14 +80,14 @@ def _launch(self, *args: Any, **kwargs: Any) -> None: def scale_batch_size( self, - model, + model: 'pl.LightningModule', mode: str = 'power', steps_per_trial: int = 3, init_val: int = 2, max_trials: int = 25, batch_arg_name: str = 'batch_size', **fit_kwargs - ): + ) -> Optional[int]: r""" Will iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM) error. @@ -138,7 +138,7 @@ def scale_batch_size( def lr_find( self, - model: LightningModule, + model: 'pl.LightningModule', train_dataloader: Optional[DataLoader] = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, min_lr: float = 1e-8, @@ -146,9 +146,9 @@ def lr_find( num_training: int = 100, mode: str = 'exponential', early_stop_threshold: float = 4.0, - datamodule: Optional[LightningDataModule] = None, + datamodule: Optional['pl.LightningDataModule'] = None, update_attr: bool = False, - ): + ) -> Optional[_LRFinder]: self.setup_trainer(model, train_dataloader, val_dataloaders, datamodule) return lr_find( self.trainer, @@ -163,6 +163,3 @@ def lr_find( datamodule, update_attr, ) - - def pick_multiple_gpus(self, num_gpus: int): - return pick_multiple_gpus(num_gpus) diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py index 7206d225ab5cd..85aa7aa937740 100644 --- a/tests/trainer/test_trainer_tricks.py +++ b/tests/trainer/test_trainer_tricks.py @@ -11,21 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os -from copy import deepcopy - -import pytest import torch from torch.utils.data import DataLoader, RandomSampler, SequentialSampler -import tests.helpers.utils as tutils from pytorch_lightning import Trainer -from pytorch_lightning.utilities import AMPType -from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate -from tests.helpers import BoringModel -from tests.helpers.datamodules import MNISTDataModule -from tests.helpers.runif import RunIf def test_num_training_batches(tmpdir): @@ -166,205 +156,3 @@ def test_overfit_batch_limits(tmpdir): loader_num_batches, dataloaders = Trainer(limit_test_batches=10)._reset_eval_dataloader(model, split) assert loader_num_batches[0] == 10 - - -def test_model_reset_correctly(tmpdir): - """ Check that model weights are correctly reset after scaling batch size. """ - tutils.reset_seed() - - model = EvalModelTemplate() - - # logger file to get meta - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - ) - - before_state_dict = deepcopy(model.state_dict()) - - trainer.tuner.scale_batch_size(model, max_trials=5) - - after_state_dict = model.state_dict() - - for key in before_state_dict.keys(): - assert torch.all(torch.eq(before_state_dict[key], after_state_dict[key])), \ - 'Model was not reset correctly after scaling batch size' - - -def test_trainer_reset_correctly(tmpdir): - """ Check that all trainer parameters are reset correctly after scaling batch size. """ - tutils.reset_seed() - - model = EvalModelTemplate() - - # logger file to get meta - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - ) - - changed_attributes = [ - 'max_steps', - 'weights_summary', - 'logger', - 'callbacks', - 'checkpoint_callback', - 'limit_train_batches', - 'current_epoch', - ] - - attributes_before = {} - for ca in changed_attributes: - attributes_before[ca] = getattr(trainer, ca) - - trainer.tuner.scale_batch_size(model, max_trials=5) - - attributes_after = {} - for ca in changed_attributes: - attributes_after[ca] = getattr(trainer, ca) - - for key in changed_attributes: - assert attributes_before[key] == attributes_after[key], \ - f'Attribute {key} was not reset correctly after learning rate finder' - - -@RunIf(min_gpus=1) -@pytest.mark.parametrize('scale_arg', ['power', 'binsearch', True]) -def test_auto_scale_batch_size_trainer_arg(tmpdir, scale_arg): - """ Test possible values for 'batch size auto scaling' Trainer argument. """ - tutils.reset_seed() - hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(**hparams) - before_batch_size = hparams.get('batch_size') - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - auto_scale_batch_size=scale_arg, - gpus=1, - ) - trainer.tune(model) - after_batch_size = model.batch_size - assert before_batch_size != after_batch_size, \ - 'Batch size was not altered after running auto scaling of batch size' - - assert not os.path.exists(tmpdir / 'scale_batch_size_temp_model.ckpt') - - -@RunIf(min_gpus=1) -@pytest.mark.parametrize('use_hparams', [True, False]) -def test_auto_scale_batch_size_set_model_attribute(tmpdir, use_hparams): - """ Test that new batch size gets written to the correct hyperparameter attribute. 
""" - tutils.reset_seed() - - hparams = EvalModelTemplate.get_default_hparams() - before_batch_size = hparams.get('batch_size') - - class HparamsEvalModelTemplate(EvalModelTemplate): - - def dataloader(self, *args, **kwargs): - # artificially set batch_size so we can get a dataloader - # remove it immediately after, because we want only self.hparams.batch_size - setattr(self, "batch_size", before_batch_size) - dataloader = super().dataloader(*args, **kwargs) - del self.batch_size - return dataloader - - datamodule_model = MNISTDataModule(data_dir=tmpdir, batch_size=111) # this datamodule should get ignored! - datamodule_fit = MNISTDataModule(data_dir=tmpdir, batch_size=before_batch_size) - - model_class = HparamsEvalModelTemplate if use_hparams else EvalModelTemplate - model = model_class(**hparams) - model.datamodule = datamodule_model # unused when another module gets passed to .tune() / .fit() - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - auto_scale_batch_size=True, - gpus=1, - ) - trainer.tune(model, datamodule_fit) - after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size - assert trainer.datamodule == datamodule_fit - assert before_batch_size != after_batch_size - assert after_batch_size <= len(trainer.train_dataloader.dataset) - assert datamodule_fit.batch_size == after_batch_size - # should be left unchanged, since it was not passed to .tune() - assert datamodule_model.batch_size == 111 - - -def test_auto_scale_batch_size_duplicate_attribute_warning(tmpdir): - """ Test for a warning when model.batch_size and model.hparams.batch_size both present. """ - - class TestModel(BoringModel): - - def __init__(self, batch_size=1): - super().__init__() - # now we have model.batch_size and model.hparams.batch_size - self.batch_size = 1 - self.save_hyperparameters() - - model = TestModel() - trainer = Trainer(default_root_dir=tmpdir, max_steps=1, max_epochs=1000, auto_scale_batch_size=True) - expected_message = "Field `model.batch_size` and `model.hparams.batch_size` are mutually exclusive!" - with pytest.warns(UserWarning, match=expected_message): - trainer.tune(model) - - -@pytest.mark.parametrize('scale_method', ['power', 'binsearch']) -def test_call_to_trainer_method(tmpdir, scale_method): - """ Test that calling the trainer method itself works. 
""" - tutils.reset_seed() - - hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(**hparams) - - before_batch_size = hparams.get('batch_size') - # logger file to get meta - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - ) - - after_batch_size = trainer.tuner.scale_batch_size(model, mode=scale_method, max_trials=5) - model.batch_size = after_batch_size - trainer.fit(model) - - assert before_batch_size != after_batch_size, \ - 'Batch size was not altered after running auto scaling of batch size' - - -def test_error_on_dataloader_passed_to_fit(tmpdir): - """Verify that when the auto scale batch size feature raises an error - if a train dataloader is passed to fit """ - - # only train passed to fit - model = EvalModelTemplate() - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_val_batches=0.1, - limit_train_batches=0.2, - auto_scale_batch_size='power', - ) - fit_options = dict(train_dataloader=model.dataloader(train=True)) - - with pytest.raises(MisconfigurationException): - trainer.tune(model, **fit_options) - - -@RunIf(min_gpus=1, amp_native=True) -def test_auto_scale_batch_size_with_amp(tmpdir): - model = EvalModelTemplate() - batch_size_before = model.batch_size - trainer = Trainer( - default_root_dir=tmpdir, - max_steps=1, - auto_scale_batch_size=True, - gpus=1, - precision=16, - ) - trainer.tune(model) - batch_size_after = model.batch_size - assert trainer.amp_backend == AMPType.NATIVE - assert trainer.scaler is not None - assert batch_size_after != batch_size_before diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py index 9834c1c8ad09b..e6b530752407f 100644 --- a/tests/tuner/test_lr_finder.py +++ b/tests/tuner/test_lr_finder.py @@ -17,7 +17,7 @@ import pytest import torch -from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate from tests.helpers import BoringModel @@ -79,20 +79,11 @@ def test_trainer_reset_correctly(tmpdir): changed_attributes = [ 'callbacks', 'logger', 'max_steps', 'auto_lr_find', 'accumulate_grad_batches', 'checkpoint_callback' ] - attributes_before = {} - for ca in changed_attributes: - attributes_before[ca] = getattr(trainer, ca) - + expected = {ca: getattr(trainer, ca) for ca in changed_attributes} _ = trainer.tuner.lr_find(model, num_training=5) + actual = {ca: getattr(trainer, ca) for ca in changed_attributes} - attributes_after = {} - for ca in changed_attributes: - attributes_after[ca] = getattr(trainer, ca) - - for key in changed_attributes: - assert attributes_before[key] == attributes_after[key], \ - f'Attribute {key} was not reset correctly after learning rate finder' - + assert actual == expected assert model.trainer == trainer diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py index ad7fc57092f32..e61cafec568ef 100644 --- a/tests/tuner/test_scale_batch_size.py +++ b/tests/tuner/test_scale_batch_size.py @@ -11,12 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import os +from copy import deepcopy + import pytest +import torch from torch.utils.data import DataLoader +import tests.helpers.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.tuner.tuning import Tuner +from pytorch_lightning.utilities import AMPType +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.base import EvalModelTemplate from tests.helpers import BoringDataModule, BoringModel +from tests.helpers.datamodules import MNISTDataModule +from tests.helpers.runif import RunIf class BatchSizeDataModule(BoringDataModule): @@ -63,3 +73,210 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamod assert model.batch_size == 16 if datamodule is not None and hasattr(datamodule, "batch_size"): assert datamodule.batch_size == 16 + + +def test_model_reset_correctly(tmpdir): + """ Check that model weights are correctly reset after scaling batch size. """ + tutils.reset_seed() + + model = EvalModelTemplate() + + # logger file to get meta + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + ) + + before_state_dict = deepcopy(model.state_dict()) + + trainer.tuner.scale_batch_size(model, max_trials=5) + + after_state_dict = model.state_dict() + + for key in before_state_dict.keys(): + assert torch.all(torch.eq(before_state_dict[key], after_state_dict[key])), \ + 'Model was not reset correctly after scaling batch size' + + +def test_trainer_reset_correctly(tmpdir): + """ Check that all trainer parameters are reset correctly after scaling batch size. """ + tutils.reset_seed() + + model = EvalModelTemplate() + + # logger file to get meta + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + ) + + changed_attributes = [ + 'max_steps', + 'weights_summary', + 'logger', + 'callbacks', + 'checkpoint_callback', + 'limit_train_batches', + 'current_epoch', + ] + expected = {ca: getattr(trainer, ca) for ca in changed_attributes} + trainer.tuner.scale_batch_size(model, max_trials=5) + actual = {ca: getattr(trainer, ca) for ca in changed_attributes} + + assert actual == expected + + +@RunIf(min_gpus=1) +@pytest.mark.parametrize('scale_arg', ['power', 'binsearch', True]) +def test_auto_scale_batch_size_trainer_arg(tmpdir, scale_arg): + """ Test possible values for 'batch size auto scaling' Trainer argument. """ + tutils.reset_seed() + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + before_batch_size = hparams.get('batch_size') + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + auto_scale_batch_size=scale_arg, + gpus=1, + ) + trainer.tune(model) + after_batch_size = model.batch_size + assert before_batch_size != after_batch_size, \ + 'Batch size was not altered after running auto scaling of batch size' + + assert not os.path.exists(tmpdir / 'scale_batch_size_temp_model.ckpt') + + +@RunIf(min_gpus=1) +@pytest.mark.parametrize('use_hparams', [True, False]) +def test_auto_scale_batch_size_set_model_attribute(tmpdir, use_hparams): + """ Test that new batch size gets written to the correct hyperparameter attribute. 
""" + tutils.reset_seed() + + hparams = EvalModelTemplate.get_default_hparams() + before_batch_size = hparams.get('batch_size') + + class HparamsEvalModelTemplate(EvalModelTemplate): + + def dataloader(self, *args, **kwargs): + # artificially set batch_size so we can get a dataloader + # remove it immediately after, because we want only self.hparams.batch_size + setattr(self, "batch_size", before_batch_size) + dataloader = super().dataloader(*args, **kwargs) + del self.batch_size + return dataloader + + datamodule_model = MNISTDataModule(data_dir=tmpdir, batch_size=111) # this datamodule should get ignored! + datamodule_fit = MNISTDataModule(data_dir=tmpdir, batch_size=before_batch_size) + + model_class = HparamsEvalModelTemplate if use_hparams else EvalModelTemplate + model = model_class(**hparams) + model.datamodule = datamodule_model # unused when another module gets passed to .tune() / .fit() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + auto_scale_batch_size=True, + gpus=1, + ) + trainer.tune(model, datamodule_fit) + after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size + assert trainer.datamodule == datamodule_fit + assert before_batch_size != after_batch_size + assert after_batch_size <= len(trainer.train_dataloader.dataset) + assert datamodule_fit.batch_size == after_batch_size + # should be left unchanged, since it was not passed to .tune() + assert datamodule_model.batch_size == 111 + + +def test_auto_scale_batch_size_duplicate_attribute_warning(tmpdir): + """ Test for a warning when model.batch_size and model.hparams.batch_size both present. """ + + class TestModel(BoringModel): + + def __init__(self, batch_size=1): + super().__init__() + # now we have model.batch_size and model.hparams.batch_size + self.batch_size = 1 + self.save_hyperparameters() + + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, max_steps=1, max_epochs=1000, auto_scale_batch_size=True) + expected_message = "Field `model.batch_size` and `model.hparams.batch_size` are mutually exclusive!" + with pytest.warns(UserWarning, match=expected_message): + trainer.tune(model) + + +@pytest.mark.parametrize('scale_method', ['power', 'binsearch']) +def test_call_to_trainer_method(tmpdir, scale_method): + """ Test that calling the trainer method itself works. 
""" + tutils.reset_seed() + + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + + before_batch_size = hparams.get('batch_size') + # logger file to get meta + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + ) + + after_batch_size = trainer.tuner.scale_batch_size(model, mode=scale_method, max_trials=5) + model.batch_size = after_batch_size + trainer.fit(model) + + assert before_batch_size != after_batch_size, \ + 'Batch size was not altered after running auto scaling of batch size' + + +def test_error_on_dataloader_passed_to_fit(tmpdir): + """Verify that when the auto scale batch size feature raises an error + if a train dataloader is passed to fit """ + + # only train passed to fit + model = EvalModelTemplate() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=0.1, + limit_train_batches=0.2, + auto_scale_batch_size='power', + ) + fit_options = dict(train_dataloader=model.dataloader(train=True)) + + with pytest.raises(MisconfigurationException): + trainer.tune(model, **fit_options) + + +@RunIf(min_gpus=1, amp_native=True) +def test_auto_scale_batch_size_with_amp(tmpdir): + model = EvalModelTemplate() + batch_size_before = model.batch_size + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + auto_scale_batch_size=True, + gpus=1, + precision=16, + ) + trainer.tune(model) + batch_size_after = model.batch_size + assert trainer.amp_backend == AMPType.NATIVE + assert trainer.scaler is not None + assert batch_size_after != batch_size_before + + +def test_scale_batch_size_no_trials(tmpdir): + """Check the result is correct even when no trials are run""" + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=1, + limit_train_batches=1, + auto_scale_batch_size='power', + ) + model = BatchSizeModel(batch_size=2) + result = trainer.tuner.scale_batch_size(model, max_trials=0) + assert result == 2 From 1eeedac22126fe8c131c90a442caa95100ca311e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 29 Apr 2021 00:45:59 +0200 Subject: [PATCH 14/30] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ee81e4f6514a3..b2777e49ada3a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -380,6 +380,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed resetting device after `fitting/evaluating/predicting` ([#7188](https://github.com/PyTorchLightning/pytorch-lightning/pull/7188)) +- Fixed bug where `trainer.tuner.scale_batch_size(max_trials=0)` would not return the correct batch size result ([#7262](https://github.com/PyTorchLightning/pytorch-lightning/pull/7262)) + + - Fixed metrics not being properly logged with `precision=16` and `manual_optimization` ([#7228](https://github.com/PyTorchLightning/pytorch-lightning/pull/7228)) From 5c62b3ad22f7b3ab9247cf28d205be267b4ac103 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 29 Apr 2021 00:54:53 +0200 Subject: [PATCH 15/30] Fix test --- tests/trainer/test_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index cb98c660d2dee..e97e79df97b4a 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1977,6 +1977,7 @@ def on_fit_start(self, trainer, pl_module: LightningModule) -> None: def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.fit(BoringModel()) with pytest.raises(MisconfigurationException, match=r"\.validate\(\)` with `fast_dev_run=True"): trainer.validate() From d30467d77895cae57bc81b96fdb411e5cf17d81f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 29 Apr 2021 01:08:10 +0200 Subject: [PATCH 16/30] Update CHANGELOG --- CHANGELOG.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b2777e49ada3a..b06c6f6275c2b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -133,6 +133,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `BasePredictionWriter` callback to implement prediction saving ([#7127](https://github.com/PyTorchLightning/pytorch-lightning/pull/7127)) +- Added `trainer.tune(scale_batch_size_kwargs, lr_find_kwargs)` arguments to configure the tuning algorithms ([#7258](https://github.com/PyTorchLightning/pytorch-lightning/pull/7258)) + + - Added `tpu_distributed` check for TPU Spawn barrier ([#7241](https://github.com/PyTorchLightning/pytorch-lightning/pull/7241)) @@ -171,6 +174,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed default setting for communication of multi-node training using `DDPShardedPlugin` ([#6937](https://github.com/PyTorchLightning/pytorch-lightning/pull/6937)) +- `trainer.tune()` now returns the tuning result ([#7258](https://github.com/PyTorchLightning/pytorch-lightning/pull/7258)) + + - `LightningModule.from_datasets()` now accepts `IterableDataset` instances as training datasets. ([#7503](https://github.com/PyTorchLightning/pytorch-lightning/pull/7503)) @@ -203,6 +209,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `@auto_move_data` in favor of `trainer.predict` ([#6993](https://github.com/PyTorchLightning/pytorch-lightning/pull/6993)) +- Deprecated `trainer.tuner.{lr_find,scale_batch_size}` in favor of `trainer.tune()` ([#7258](https://github.com/PyTorchLightning/pytorch-lightning/pull/7258)) + + - Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505), [#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530), [#6540](https://github.com/PyTorchLightning/pytorch-lightning/pull/6540), @@ -311,6 +320,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506)) +- Fixed `trainer.tuner.{lr_find,scale_batch_size}` not setting the `Trainer` state properly ([#7258](https://github.com/PyTorchLightning/pytorch-lightning/pull/7258)) + + - Fixed bug where `BaseFinetuning.flatten_modules()` was duplicating leaf node parameters ([#6879](https://github.com/PyTorchLightning/pytorch-lightning/pull/6879)) From 08e7a54b9e62835f19b8c322dbed15554001bb4a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 29 Apr 2021 01:17:17 +0200 Subject: [PATCH 17/30] Update docs --- docs/source/advanced/lr_finder.rst | 9 ++++----- docs/source/advanced/training_tricks.rst | 9 ++++----- docs/source/api_references.rst | 12 ------------ pytorch_lightning/trainer/trainer.py | 2 +- pytorch_lightning/tuner/tuning.py | 2 +- 5 files changed, 10 insertions(+), 24 deletions(-) diff --git a/docs/source/advanced/lr_finder.rst b/docs/source/advanced/lr_finder.rst index 9a0749b36ad4a..8b62186d54198 100644 --- a/docs/source/advanced/lr_finder.rst +++ b/docs/source/advanced/lr_finder.rst @@ -73,17 +73,16 @@ If your model is using an arbitrary value instead of ``self.lr`` or ``self.learn trainer.tune(model) -If you want to inspect the results of the learning rate finder or just play around -with the parameters of the algorithm, this can be done by invoking the ``lr_find`` -method of the trainer. A typical example of this would look like +You can also inspect the results of the learning rate finder or just play around +with the parameters of the algorithm. A typical example of this would look like: .. code-block:: python model = MyModelClass(hparams) - trainer = Trainer() + trainer = Trainer(auto_lr_find=True) # Run learning rate finder - lr_finder = trainer.tuner.lr_find(model) + lr_finder = trainer.tune(model, lr_find_kwargs={...}) # Results can be found in lr_finder.results diff --git a/docs/source/advanced/training_tricks.rst b/docs/source/advanced/training_tricks.rst index c3b232b41c13c..0bd3f1ce1cffa 100644 --- a/docs/source/advanced/training_tricks.rst +++ b/docs/source/advanced/training_tricks.rst @@ -112,18 +112,17 @@ search for batch sizes larger than the size of the training dataset. to `.fit()`. The scaling algorithm has a number of parameters that the user can control by -invoking the trainer method `.scale_batch_size` themself (see description below). +invoking passing the ``scale_batch_size`` argument: .. code-block:: python # Use default in trainer construction - trainer = Trainer() - tuner = Tuner(trainer) + trainer = Trainer(auto_scale_batch_size=True) # Invoke method - new_batch_size = tuner.scale_batch_size(model, *extra_parameters_here) + new_batch_size = trainer.tune(model, scale_batch_size_kwargs={...}) - # Override old batch size + # Override old batch size (this is done automatically) model.hparams.batch_size = new_batch_size # Fit as normal diff --git a/docs/source/api_references.rst b/docs/source/api_references.rst index c954db735c282..84e1f47f570d9 100644 --- a/docs/source/api_references.rst +++ b/docs/source/api_references.rst @@ -150,18 +150,6 @@ Trainer API trainer -Tuner API ---------- - -.. currentmodule:: pytorch_lightning.tuner - -.. 
autosummary:: - :toctree: api - :nosignatures: - - batch_size_scaling - lr_finder - Utilities API ------------- diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7c90d423af9a5..8cd7b6dff231e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1115,7 +1115,7 @@ def tune( model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule ) - result = self.tuner.tune( + result = self.tuner._tune( model, scale_batch_size_kwargs=scale_batch_size_kwargs or {}, lr_find_kwargs=lr_find_kwargs or {} ) diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 656727cd8204c..793aebedca09b 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -31,7 +31,7 @@ def on_trainer_init(self, auto_lr_find: Union[str, bool], auto_scale_batch_size: self.trainer.auto_lr_find = auto_lr_find self.trainer.auto_scale_batch_size = auto_scale_batch_size - def tune( + def _tune( self, model: 'pl.LightningModule', scale_batch_size_kwargs: Optional[Dict[str, Any]] = None, From 236b9f4ec42e85a8fbafdd2c81e736529c8b0c66 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 29 Apr 2021 01:44:40 +0200 Subject: [PATCH 18/30] Fix test? --- pytorch_lightning/trainer/trainer.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 8cd7b6dff231e..32804f7125108 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1110,6 +1110,16 @@ def tune( self.state = TrainerState.TUNING self.tuning = True + # if a datamodule comes in as the second arg, then fix it for the user + if isinstance(train_dataloader, LightningDataModule): + datamodule = train_dataloader + train_dataloader = None + # If you supply a datamodule you can't supply train_dataloader or val_dataloaders + if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None: + raise MisconfigurationException( + 'You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.tune(datamodule=...)`' + ) + # links data to the trainer self.data_connector.attach_data( model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule From e2d90c3d17f97c5e844f5f858de9a29087e2d60a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 29 Apr 2021 14:45:43 +0200 Subject: [PATCH 19/30] Fix docs --- pytorch_lightning/tuner/lr_finder.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index b93856e44361f..4f804f1632261 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -61,7 +61,7 @@ def _determine_lr_attr_name(trainer: 'pl.Trainer', model: 'pl.LightningModule') class _LRFinder(object): - """ LR finder object. This object stores the results of Trainer.lr_find(). + """ LR finder object. This object stores the results of lr_find(). 
Args: mode: either `linear` or `exponential`, how to increase lr after each step @@ -73,8 +73,10 @@ class _LRFinder(object): num_training: number of steps to take between lr_min and lr_max Example:: + trainer = Trainer(auto_lr_find=True) + # Run lr finder - lr_finder = trainer.lr_find(model) + lr_finder = trainer.tune(model) # Results stored in lr_finder.results From 869c857dbce583bab31a26576e6d51c434e72f65 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 29 Apr 2021 14:57:00 +0200 Subject: [PATCH 20/30] Dict return for trainer.tune --- docs/source/advanced/lr_finder.rst | 3 +- docs/source/advanced/training_tricks.rst | 3 +- pytorch_lightning/trainer/trainer.py | 2 +- pytorch_lightning/tuner/lr_finder.py | 3 +- pytorch_lightning/tuner/tuning.py | 13 +++++---- tests/tuner/test_lr_finder.py | 35 +++++++++++++----------- tests/tuner/test_scale_batch_size.py | 3 +- 7 files changed, 35 insertions(+), 27 deletions(-) diff --git a/docs/source/advanced/lr_finder.rst b/docs/source/advanced/lr_finder.rst index 8b62186d54198..26efdf64817e7 100644 --- a/docs/source/advanced/lr_finder.rst +++ b/docs/source/advanced/lr_finder.rst @@ -82,7 +82,8 @@ with the parameters of the algorithm. A typical example of this would look like: trainer = Trainer(auto_lr_find=True) # Run learning rate finder - lr_finder = trainer.tune(model, lr_find_kwargs={...}) + result = trainer.tune(model, lr_find_kwargs={...}) + lr_finder = result['lr_find'] # Results can be found in lr_finder.results diff --git a/docs/source/advanced/training_tricks.rst b/docs/source/advanced/training_tricks.rst index 0bd3f1ce1cffa..09620e8c478d8 100644 --- a/docs/source/advanced/training_tricks.rst +++ b/docs/source/advanced/training_tricks.rst @@ -120,7 +120,8 @@ invoking passing the ``scale_batch_size`` argument: trainer = Trainer(auto_scale_batch_size=True) # Invoke method - new_batch_size = trainer.tune(model, scale_batch_size_kwargs={...}) + result = trainer.tune(model, scale_batch_size_kwargs={...}) + new_batch_size = result['new_batch_size'] # Override old batch size (this is done automatically) model.hparams.batch_size = new_batch_size diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 32804f7125108..87f75c5108c09 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1087,7 +1087,7 @@ def tune( datamodule: Optional[LightningDataModule] = None, scale_batch_size_kwargs: Optional[Dict[str, Any]] = None, lr_find_kwargs: Optional[Dict[str, Any]] = None, - ) -> Optional[Union[int, _LRFinder]]: + ) -> Dict[str, Optional[Union[int, _LRFinder]]]: r""" Runs routines to tune hyperparameters before training. 
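The dict-style return of ``trainer.tune()`` introduced here can be consumed as in the
following sketch; ``MyModel`` and the keyword values passed through the ``*_kwargs``
dicts are illustrative assumptions::

    from pytorch_lightning import Trainer

    model = MyModel()  # placeholder LightningModule
    trainer = Trainer(auto_scale_batch_size=True, auto_lr_find=True)

    result = trainer.tune(
        model,
        scale_batch_size_kwargs={"mode": "binsearch", "max_trials": 5},
        lr_find_kwargs={"num_training": 100},
    )

    # one key per tuning routine that actually ran
    new_batch_size = result["scale_batch_size"]  # int (or None if skipped)
    lr_finder = result["lr_find"]                # _LRFinder (or None if skipped)

    if lr_finder is not None:
        model.learning_rate = lr_finder.suggestion()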
diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 4f804f1632261..d8c501678b38a 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -76,7 +76,8 @@ class _LRFinder(object): trainer = Trainer(auto_lr_find=True) # Run lr finder - lr_finder = trainer.tune(model) + result = trainer.tune(model) + lr_finder = result['lr_find'] # Results stored in lr_finder.results diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 793aebedca09b..ce4183b206c9b 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -36,19 +36,20 @@ def _tune( model: 'pl.LightningModule', scale_batch_size_kwargs: Optional[Dict[str, Any]] = None, lr_find_kwargs: Optional[Dict[str, Any]] = None, - ) -> Optional[Union[int, _LRFinder]]: + ) -> Dict[str, Optional[Union[int, _LRFinder]]]: scale_batch_size_kwargs = scale_batch_size_kwargs or {} lr_find_kwargs = lr_find_kwargs or {} - result = None + # return a dict instead of a tuple so BC is not broken if a new tuning procedure is added + result = {} # Run auto batch size scaling if self.trainer.auto_scale_batch_size: - result = scale_batch_size(self.trainer, model, **scale_batch_size_kwargs) + result['scale_batch_size'] = scale_batch_size(self.trainer, model, **scale_batch_size_kwargs) # Run learning rate finder: if self.trainer.auto_lr_find: lr_find_kwargs.setdefault('update_attr', True) - result = lr_find(self.trainer, model, **lr_find_kwargs) + result['lr_find'] = lr_find(self.trainer, model, **lr_find_kwargs) self.trainer.state = TrainerState.FINISHED @@ -88,7 +89,7 @@ def scale_batch_size( } ) self.trainer.auto_scale_batch_size = False - return result + return result['scale_batch_size'] def lr_find( self, @@ -123,4 +124,4 @@ def lr_find( } ) self.trainer.auto_lr_find = False - return result + return result['lr_find'] diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py index e6b530752407f..c4cb5b1aa0cd0 100644 --- a/tests/tuner/test_lr_finder.py +++ b/tests/tuner/test_lr_finder.py @@ -54,7 +54,7 @@ def test_model_reset_correctly(tmpdir): before_state_dict = deepcopy(model.state_dict()) - _ = trainer.tuner.lr_find(model, num_training=5) + trainer.tuner.lr_find(model, num_training=5) after_state_dict = model.state_dict() @@ -80,7 +80,7 @@ def test_trainer_reset_correctly(tmpdir): 'callbacks', 'logger', 'max_steps', 'auto_lr_find', 'accumulate_grad_batches', 'checkpoint_callback' ] expected = {ca: getattr(trainer, ca) for ca in changed_attributes} - _ = trainer.tuner.lr_find(model, num_training=5) + trainer.tuner.lr_find(model, num_training=5) actual = {ca: getattr(trainer, ca) for ca in changed_attributes} assert actual == expected @@ -159,7 +159,8 @@ def test_call_to_trainer_method(tmpdir, optimizer): max_epochs=2, ) - lrfinder = trainer.tuner.lr_find(model, mode='linear') + result = trainer.tuner.lr_find(model, mode='linear') + lrfinder = result['lr_find'] after_lr = lrfinder.suggestion() model.learning_rate = after_lr trainer.tune(model) @@ -182,7 +183,8 @@ def test_datamodule_parameter(tmpdir): max_epochs=2, ) - lrfinder = trainer.tuner.lr_find(model, datamodule=dm) + result = trainer.tuner.lr_find(model, datamodule=dm) + lrfinder = result['lr_find'] after_lr = lrfinder.suggestion() model.lr = after_lr @@ -204,7 +206,8 @@ def test_accumulation_and_early_stopping(tmpdir): accumulate_grad_batches=2, ) - lrfinder = trainer.tuner.lr_find(model, early_stop_threshold=None) + result = 
trainer.tuner.lr_find(model, early_stop_threshold=None) + lrfinder = result['lr_find'] after_lr = lrfinder.suggestion() expected_num_lrs = 100 @@ -230,7 +233,8 @@ def test_suggestion_parameters_work(tmpdir): max_epochs=3, ) - lrfinder = trainer.tuner.lr_find(model, datamodule=dm) + result = trainer.tuner.lr_find(model, datamodule=dm) + lrfinder = result['lr_find'] lr1 = lrfinder.suggestion(skip_begin=10) # default lr2 = lrfinder.suggestion(skip_begin=150) # way too high, should have an impact @@ -249,7 +253,8 @@ def test_suggestion_with_non_finite_values(tmpdir): max_epochs=3, ) - lrfinder = trainer.tuner.lr_find(model) + result = trainer.tuner.lr_find(model) + lrfinder = result['lr_find'] before_lr = lrfinder.suggestion() lrfinder.results['loss'][-1] = float('nan') after_lr = lrfinder.suggestion() @@ -278,12 +283,10 @@ def __init__(self, learning_rate=0.1, batch_size=2): before_lr = model.hparams.learning_rate # logger file to get meta - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=3, - ) - bs = trainer.tuner.scale_batch_size(model) - lr = trainer.tuner.lr_find(model).suggestion() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=3, auto_lr_find=True, auto_scale_batch_size=True) + result = trainer.tune(model) + bs = result['scale_batch_size'] + lr = result['lr_find'].suggestion() assert lr != before_lr assert isinstance(bs, int) @@ -303,13 +306,13 @@ def __init__(self, learning_rate=0.1): lr_min = 1e-8 lr_max = 1.0 - lr_finder = trainer.tuner.lr_find( + result = trainer.tuner.lr_find( model, max_lr=lr_min, min_lr=lr_max, num_training=3, ) - lr_candidates = lr_finder.results["lr"] + lr_candidates = result['lr_find'].results["lr"] assert all([lr_min <= lr <= lr_max for lr in lr_candidates]) @@ -329,7 +332,7 @@ def training_step_end(self, outputs): model = TestModel() trainer = Trainer(default_root_dir=tmpdir) num_training = 3 - _ = trainer.tuner.lr_find( + trainer.tuner.lr_find( model=model, num_training=num_training, ) diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py index e61cafec568ef..c6deab8699a04 100644 --- a/tests/tuner/test_scale_batch_size.py +++ b/tests/tuner/test_scale_batch_size.py @@ -223,7 +223,8 @@ def test_call_to_trainer_method(tmpdir, scale_method): max_epochs=1, ) - after_batch_size = trainer.tuner.scale_batch_size(model, mode=scale_method, max_trials=5) + result = trainer.tuner.scale_batch_size(model, mode=scale_method, max_trials=5) + after_batch_size = result['scale_batch_size'] model.batch_size = after_batch_size trainer.fit(model) From 668e68e7dda1615592d86a5577b57bdd589bb551 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 29 Apr 2021 15:19:47 +0200 Subject: [PATCH 21/30] Undo some changes --- tests/tuner/test_lr_finder.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py index c4cb5b1aa0cd0..e1a7d6edcfee3 100644 --- a/tests/tuner/test_lr_finder.py +++ b/tests/tuner/test_lr_finder.py @@ -159,8 +159,7 @@ def test_call_to_trainer_method(tmpdir, optimizer): max_epochs=2, ) - result = trainer.tuner.lr_find(model, mode='linear') - lrfinder = result['lr_find'] + lrfinder = trainer.tuner.lr_find(model, mode='linear') after_lr = lrfinder.suggestion() model.learning_rate = after_lr trainer.tune(model) @@ -183,8 +182,7 @@ def test_datamodule_parameter(tmpdir): max_epochs=2, ) - result = trainer.tuner.lr_find(model, datamodule=dm) - lrfinder = result['lr_find'] + lrfinder = trainer.tuner.lr_find(model, datamodule=dm) 
after_lr = lrfinder.suggestion() model.lr = after_lr @@ -206,8 +204,7 @@ def test_accumulation_and_early_stopping(tmpdir): accumulate_grad_batches=2, ) - result = trainer.tuner.lr_find(model, early_stop_threshold=None) - lrfinder = result['lr_find'] + lrfinder = trainer.tuner.lr_find(model, early_stop_threshold=None) after_lr = lrfinder.suggestion() expected_num_lrs = 100 From 617e9e59e8c7cb96da7c30a7b8548f3c0638326d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 29 Apr 2021 15:22:26 +0200 Subject: [PATCH 22/30] Undo some changes --- tests/tuner/test_lr_finder.py | 10 ++++------ tests/tuner/test_scale_batch_size.py | 3 +-- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py index e1a7d6edcfee3..ff6df389cb488 100644 --- a/tests/tuner/test_lr_finder.py +++ b/tests/tuner/test_lr_finder.py @@ -230,8 +230,7 @@ def test_suggestion_parameters_work(tmpdir): max_epochs=3, ) - result = trainer.tuner.lr_find(model, datamodule=dm) - lrfinder = result['lr_find'] + lrfinder = trainer.tuner.lr_find(model, datamodule=dm) lr1 = lrfinder.suggestion(skip_begin=10) # default lr2 = lrfinder.suggestion(skip_begin=150) # way too high, should have an impact @@ -250,8 +249,7 @@ def test_suggestion_with_non_finite_values(tmpdir): max_epochs=3, ) - result = trainer.tuner.lr_find(model) - lrfinder = result['lr_find'] + lrfinder = trainer.tuner.lr_find(model) before_lr = lrfinder.suggestion() lrfinder.results['loss'][-1] = float('nan') after_lr = lrfinder.suggestion() @@ -303,13 +301,13 @@ def __init__(self, learning_rate=0.1): lr_min = 1e-8 lr_max = 1.0 - result = trainer.tuner.lr_find( + lrfinder = trainer.tuner.lr_find( model, max_lr=lr_min, min_lr=lr_max, num_training=3, ) - lr_candidates = result['lr_find'].results["lr"] + lr_candidates = lrfinder.results["lr"] assert all([lr_min <= lr <= lr_max for lr in lr_candidates]) diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py index c6deab8699a04..e61cafec568ef 100644 --- a/tests/tuner/test_scale_batch_size.py +++ b/tests/tuner/test_scale_batch_size.py @@ -223,8 +223,7 @@ def test_call_to_trainer_method(tmpdir, scale_method): max_epochs=1, ) - result = trainer.tuner.scale_batch_size(model, mode=scale_method, max_trials=5) - after_batch_size = result['scale_batch_size'] + after_batch_size = trainer.tuner.scale_batch_size(model, mode=scale_method, max_trials=5) model.batch_size = after_batch_size trainer.fit(model) From d60eb201710f51f8c4dcc79ee57c2933fa240a2d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 29 Apr 2021 15:25:09 +0200 Subject: [PATCH 23/30] Undo some changes --- tests/tuner/test_lr_finder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py index ff6df389cb488..641196eda466f 100644 --- a/tests/tuner/test_lr_finder.py +++ b/tests/tuner/test_lr_finder.py @@ -301,13 +301,13 @@ def __init__(self, learning_rate=0.1): lr_min = 1e-8 lr_max = 1.0 - lrfinder = trainer.tuner.lr_find( + lr_finder = trainer.tuner.lr_find( model, max_lr=lr_min, min_lr=lr_max, num_training=3, ) - lr_candidates = lrfinder.results["lr"] + lr_candidates = lr_finder.results["lr"] assert all([lr_min <= lr <= lr_max for lr in lr_candidates]) From 45023295d3e57ee1a23d895f514784297f138753 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 29 Apr 2021 17:03:02 +0200 Subject: [PATCH 24/30] Apply suggestions from code review Co-authored-by: Nicki Skafte --- 
docs/source/advanced/training_tricks.rst | 2 +- pytorch_lightning/trainer/trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/advanced/training_tricks.rst b/docs/source/advanced/training_tricks.rst index 09620e8c478d8..77d425be9530e 100644 --- a/docs/source/advanced/training_tricks.rst +++ b/docs/source/advanced/training_tricks.rst @@ -112,7 +112,7 @@ search for batch sizes larger than the size of the training dataset. to `.fit()`. The scaling algorithm has a number of parameters that the user can control by -invoking passing the ``scale_batch_size`` argument: +passing the ``scale_batch_size_kwargs`` argument to ``trainer.tune``: .. code-block:: python diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 87f75c5108c09..68b1d8d73ab58 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1126,7 +1126,7 @@ def tune( ) result = self.tuner._tune( - model, scale_batch_size_kwargs=scale_batch_size_kwargs or {}, lr_find_kwargs=lr_find_kwargs or {} + model, scale_batch_size_kwargs=scale_batch_size_kwargs, lr_find_kwargs=lr_find_kwargs ) assert self.state.stopped From f2ccf216d203c05c199368e52e1d644e5fccac7e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 30 Apr 2021 13:40:17 +0200 Subject: [PATCH 25/30] Undo deprecation --- CHANGELOG.md | 3 - docs/source/advanced/lr_finder.rst | 8 +- docs/source/advanced/training_tricks.rst | 8 +- docs/source/api_references.rst | 12 +++ pytorch_lightning/tuner/batch_size_scaling.py | 42 +------- pytorch_lightning/tuner/lr_finder.py | 37 +------ pytorch_lightning/tuner/tuning.py | 97 ++++++++++++++++--- tests/deprecated_api/test_remove_1-5.py | 14 --- 8 files changed, 108 insertions(+), 113 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 875ce795d8189..79f33123b9b4a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -212,9 +212,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `@auto_move_data` in favor of `trainer.predict` ([#6993](https://github.com/PyTorchLightning/pytorch-lightning/pull/6993)) -- Deprecated `trainer.tuner.{lr_find,scale_batch_size}` in favor of `trainer.tune()` ([#7258](https://github.com/PyTorchLightning/pytorch-lightning/pull/7258)) - - - Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505), [#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530), [#6540](https://github.com/PyTorchLightning/pytorch-lightning/pull/6540), diff --git a/docs/source/advanced/lr_finder.rst b/docs/source/advanced/lr_finder.rst index 26efdf64817e7..ad91b3d356872 100644 --- a/docs/source/advanced/lr_finder.rst +++ b/docs/source/advanced/lr_finder.rst @@ -74,16 +74,16 @@ If your model is using an arbitrary value instead of ``self.lr`` or ``self.learn You can also inspect the results of the learning rate finder or just play around -with the parameters of the algorithm. A typical example of this would look like: +with the parameters of the algorithm. This can be done by invoking the +:meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find` method .A typical example of this would look like: .. 
code-block:: python model = MyModelClass(hparams) - trainer = Trainer(auto_lr_find=True) + trainer = Trainer() # Run learning rate finder - result = trainer.tune(model, lr_find_kwargs={...}) - lr_finder = result['lr_find'] + lr_finder = trainer.tuner.lr_find(model) # Results can be found in lr_finder.results diff --git a/docs/source/advanced/training_tricks.rst b/docs/source/advanced/training_tricks.rst index 77d425be9530e..3845d71a3968f 100644 --- a/docs/source/advanced/training_tricks.rst +++ b/docs/source/advanced/training_tricks.rst @@ -112,16 +112,16 @@ search for batch sizes larger than the size of the training dataset. to `.fit()`. The scaling algorithm has a number of parameters that the user can control by -passing the ``scale_batch_size_kwargs`` argument to ``trainer.tune``: +invoking the :meth:`~pytorch_lightning.tuner.tuning.Tuner.scale_batch_size` method: .. code-block:: python # Use default in trainer construction - trainer = Trainer(auto_scale_batch_size=True) + trainer = Trainer() + tuner = Tuner(trainer) # Invoke method - result = trainer.tune(model, scale_batch_size_kwargs={...}) - new_batch_size = result['new_batch_size'] + new_batch_size = tuner.scale_batch_size(model, *extra_parameters_here) # Override old batch size (this is done automatically) model.hparams.batch_size = new_batch_size diff --git a/docs/source/api_references.rst b/docs/source/api_references.rst index 84e1f47f570d9..642b11b5bdad4 100644 --- a/docs/source/api_references.rst +++ b/docs/source/api_references.rst @@ -150,6 +150,18 @@ Trainer API trainer +Tuner API +--------- + +.. currentmodule:: pytorch_lightning.tuner.tuning + +.. autosummary:: + :toctree: api + :nosignatures: + :template: classtemplate.rst + + Tuner + Utilities API ------------- diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 1ce318d9da965..fe196f760be05 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -36,47 +36,7 @@ def scale_batch_size( max_trials: int = 25, batch_arg_name: str = 'batch_size', ) -> Optional[int]: - r""" - Will iteratively try to find the largest batch size for a given model - that does not give an out of memory (OOM) error. - - Args: - trainer: The Trainer - - model: Model to fit. - - mode: string setting the search mode. Either `power` or `binsearch`. - If mode is `power` we keep multiplying the batch size by 2, until - we get an OOM error. If mode is 'binsearch', we will initially - also keep multiplying by 2 and after encountering an OOM error - do a binary search between the last successful batch size and the - batch size that failed. - - steps_per_trial: number of steps to run with a given batch size. - Ideally 1 should be enough to test if a OOM error occurs, - however in practise a few are needed - - init_val: initial batch size to start the search with - - max_trials: max number of increase in batch size done before - algorithm is terminated - - batch_arg_name: name of the attribute that stores the batch size. - It is expected that the user has provided a model or datamodule that has a hyperparameter - with that name. 
We will look for this attribute name in the following places - - - ``model`` - - ``model.hparams`` - - ``model.datamodule`` - - ``trainer.datamodule`` (the datamodule passed to the tune method) - - Raises: - MisconfigurationException: - If field ``batch_arg_name`` is not found in ``model`` and ``model.hparams``, or - if batch scaling feature is used with dataloaders passed directly to ``.fit()``. - ValueError: - If mode in method ``scale_batch_size`` is neither ``power`` nor ``binsearch``. - """ + """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.scale_batch_size`""" if trainer.fast_dev_run: rank_zero_warn('Skipping batch size scaler since fast_dev_run is enabled.', UserWarning) return diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index d8c501678b38a..83af1fa802942 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -73,11 +73,8 @@ class _LRFinder(object): num_training: number of steps to take between lr_min and lr_max Example:: - trainer = Trainer(auto_lr_find=True) - # Run lr finder - result = trainer.tune(model) - lr_finder = result['lr_find'] + lr_finder = trainer.lr_find(model) # Results stored in lr_finder.results @@ -207,37 +204,7 @@ def lr_find( early_stop_threshold: float = 4.0, update_attr: bool = False, ) -> Optional[_LRFinder]: - r""" - ``lr_find`` enables the user to do a range test of good initial learning rates, - to reduce the amount of guesswork in picking a good starting learning rate. - - Args: - trainer: The Trainer - - model: Model to do range testing for - - min_lr: minimum learning rate to investigate - - max_lr: maximum learning rate to investigate - - num_training: number of learning rates to test - - mode: Search strategy to update learning rate after each batch: - - - ``'exponential'`` (default): Will increase the learning rate exponentially. - - ``'linear'``: Will increase the learning rate linearly. - - early_stop_threshold: threshold for stopping the search. If the - loss at any point is larger than early_stop_threshold*best_loss - then the search is stopped. To disable, set to None. - - update_attr: Whether to update the learning rate attribute or not. - - Raises: - MisconfigurationException: - If learning rate/lr in ``model`` or ``model.hparams`` isn't overriden when ``auto_lr_find=True``, or - if you are using `more than one optimizer` with learning rate finder. 
- """ + """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`""" if trainer.fast_dev_run: rank_zero_warn('Skipping learning rate finder since fast_dev_run is enabled.', UserWarning) return diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index ce4183b206c9b..994d87085b493 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -19,10 +19,10 @@ from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size from pytorch_lightning.tuner.lr_finder import _LRFinder, lr_find -from pytorch_lightning.utilities import rank_zero_deprecation class Tuner: + """Tuner class to tune your model""" def __init__(self, trainer: 'pl.Trainer') -> None: self.trainer = trainer @@ -65,21 +65,61 @@ def _launch(self, *args: Any, **kwargs: Any) -> None: def scale_batch_size( self, model: 'pl.LightningModule', + train_dataloader: Optional[DataLoader] = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + datamodule: Optional['pl.LightningDataModule'] = None, mode: str = 'power', steps_per_trial: int = 3, init_val: int = 2, max_trials: int = 25, batch_arg_name: str = 'batch_size', - **fit_kwargs ) -> Optional[int]: - rank_zero_deprecation( - "`Tuner.scale_batch_size()` is deprecated in v1.3 and will be removed in v1.5." - " Please use `trainer.tune(scale_batch_size_kwargs={...})` instead." - ) + """ + Iteratively try to find the largest batch size for a given model + that does not give an out of memory (OOM) error. + + Args: + model: Model to tune. + + train_dataloader: A Pytorch DataLoader with training samples. If the model has + a predefined train_dataloader method this will be skipped. + + val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. + If the model has a predefined val_dataloaders method this will be skipped + + datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. + + mode: string setting the search mode. Either ``power`` or ``binsearch``. + If mode is ``power`` we keep multiplying the batch size by 2, until + we get an OOM error. If mode is ``binsearch``, we will initially + also keep multiplying by 2 and after encountering an OOM error + do a binary search between the last successful batch size and the + batch size that failed. + + steps_per_trial: number of steps to run with a given batch size. + Ideally 1 should be enough to test if a OOM error occurs, + however in practise a few are needed + + init_val: initial batch size to start the search with + + max_trials: max number of increase in batch size done before + algorithm is terminated + + batch_arg_name: name of the attribute that stores the batch size. + It is expected that the user has provided a model or datamodule that has a hyperparameter + with that name. 
We will look for this attribute name in the following places + + - ``model`` + - ``model.hparams`` + - ``model.datamodule`` + - ``trainer.datamodule`` (the datamodule passed to the tune method) + """ self.trainer.auto_scale_batch_size = True result = self.trainer.tune( model, - **fit_kwargs, + train_dataloader=train_dataloader, + val_dataloaders=val_dataloaders, + datamodule=datamodule, scale_batch_size_kwargs={ 'mode': mode, 'steps_per_trial': steps_per_trial, @@ -96,18 +136,51 @@ def lr_find( model: 'pl.LightningModule', train_dataloader: Optional[DataLoader] = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + datamodule: Optional['pl.LightningDataModule'] = None, min_lr: float = 1e-8, max_lr: float = 1, num_training: int = 100, mode: str = 'exponential', early_stop_threshold: float = 4.0, - datamodule: Optional['pl.LightningDataModule'] = None, update_attr: bool = False, ) -> Optional[_LRFinder]: - rank_zero_deprecation( - "`Tuner.lr_find()` is deprecated in v1.3 and will be removed in v1.5." - " Please use `trainer.tune(lr_finder_kwargs={...})` instead." - ) + """ + Enables the user to do a range test of good initial learning rates, + to reduce the amount of guesswork in picking a good starting learning rate. + + Args: + model: Model to tune. + + train_dataloader: A Pytorch DataLoader with training samples. If the model has + a predefined train_dataloader method this will be skipped. + + val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. + If the model has a predefined val_dataloaders method this will be skipped + + datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. + + min_lr: minimum learning rate to investigate + + max_lr: maximum learning rate to investigate + + num_training: number of learning rates to test + + mode: Search strategy to update learning rate after each batch: + + - ``'exponential'`` (default): Will increase the learning rate exponentially. + - ``'linear'``: Will increase the learning rate linearly. + + early_stop_threshold: threshold for stopping the search. If the + loss at any point is larger than early_stop_threshold*best_loss + then the search is stopped. To disable, set to None. + + update_attr: Whether to update the learning rate attribute or not. + + Raises: + MisconfigurationException: + If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden when ``auto_lr_find=True``, + or if you are using more than one optimizer. 
+ """ self.trainer.auto_lr_find = True result = self.trainer.tune( model, diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 0c076dd84e083..6516fbcc18639 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -294,17 +294,3 @@ def test_v1_5_0_trainer_logging_mixin(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, checkpoint_callback=False, logger=False) with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"): trainer.metrics_to_scalars({}) - - -def test_v1_5_0_tuner_scale_batch_size(): - trainer = Trainer(fast_dev_run=True) - model = BoringModel() - with pytest.deprecated_call(match=r"scale_batch_size\(\)` is deprecated in v1.3 and will be removed in v1.5"): - trainer.tuner.scale_batch_size(model) - - -def test_v1_5_0_tuner_lr_find(): - trainer = Trainer(fast_dev_run=True) - model = BoringModel() - with pytest.deprecated_call(match=r"lr_find\(\)` is deprecated in v1.3 and will be removed in v1.5"): - trainer.tuner.lr_find(model) From 9cbc15434116b6e4e67ced2cdc7ed0efa6b39bd2 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 30 Apr 2021 13:41:01 +0200 Subject: [PATCH 26/30] Fix docs --- pytorch_lightning/core/datamodule.py | 2 +- pytorch_lightning/trainer/trainer.py | 12 +++++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index a1f1c02ef498d..9550ceae4a9cc 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -319,7 +319,7 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): Args: args: The parser or namespace to take arguments from. Only known arguments will be - parsed and passed to the :class:`LightningDataModule`. + parsed and passed to the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. **kwargs: Additional keyword arguments that may override ones in the parser or namespace. These must be valid DataModule arguments. diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 68b1d8d73ab58..019fd2a9984be 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -843,7 +843,7 @@ def fit( val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. If the model has a predefined val_dataloaders method this will be skipped - datamodule: A instance of :class:`LightningDataModule`. + datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. """ Trainer._log_api_event("fit") @@ -893,7 +893,7 @@ def validate( verbose: If True, prints the validation results. - datamodule: A instance of :class:`LightningDataModule`. + datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. Returns: The dictionary with final validation results returned by validation_epoch_end. @@ -956,7 +956,7 @@ def test( verbose: If True, prints the test results. - datamodule: A instance of :class:`LightningDataModule`. + datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. Returns: Returns a list of dictionaries, one for each test dataloader containing their respective metrics. @@ -1100,7 +1100,7 @@ def tune( val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. 
If the model has a predefined val_dataloaders method this will be skipped - datamodule: A instance of :class:`LightningDataModule`. + datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. scale_batch_size_kwargs: Arguments for :func:`~pytorch_lightning.tuner.lr_finder.lr_find` @@ -1125,9 +1125,7 @@ def tune( model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule ) - result = self.tuner._tune( - model, scale_batch_size_kwargs=scale_batch_size_kwargs, lr_find_kwargs=lr_find_kwargs - ) + result = self.tuner._tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs, lr_find_kwargs=lr_find_kwargs) assert self.state.stopped self.tuning = False From 46a4e925184bad9a8d53ab4b3bd2f28c3fe5e6b3 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 30 Apr 2021 13:43:52 +0200 Subject: [PATCH 27/30] Fix docs --- docs/source/advanced/training_tricks.rst | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/docs/source/advanced/training_tricks.rst b/docs/source/advanced/training_tricks.rst index 3845d71a3968f..25a96ea481153 100644 --- a/docs/source/advanced/training_tricks.rst +++ b/docs/source/advanced/training_tricks.rst @@ -1,6 +1,7 @@ .. testsetup:: * from pytorch_lightning.trainer.trainer import Trainer + from pytorch_lightning.core.lightning import LightningModule .. _training_tricks: @@ -142,10 +143,6 @@ The algorithm in short works by: 3. The found batch size is saved to either `model.batch_size` or `model.hparams.batch_size` 4. Restore the initial state of model and trainer -.. autoclass:: pytorch_lightning.tuner.tuning.Tuner - :noindex: - :members: scale_batch_size - .. warning:: Batch size finder is not supported for DDP yet, it is coming soon. From 175e2cd1b09a494ca3fb67d06eb6fded6cb86bc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 30 Apr 2021 14:01:14 +0200 Subject: [PATCH 28/30] Apply suggestions from code review --- docs/source/advanced/lr_finder.rst | 2 +- docs/source/advanced/training_tricks.rst | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/advanced/lr_finder.rst b/docs/source/advanced/lr_finder.rst index ad91b3d356872..fe2c82c661872 100644 --- a/docs/source/advanced/lr_finder.rst +++ b/docs/source/advanced/lr_finder.rst @@ -75,7 +75,7 @@ If your model is using an arbitrary value instead of ``self.lr`` or ``self.learn You can also inspect the results of the learning rate finder or just play around with the parameters of the algorithm. This can be done by invoking the -:meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find` method .A typical example of this would look like: +:meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find` method. A typical example of this would look like: .. code-block:: python diff --git a/docs/source/advanced/training_tricks.rst b/docs/source/advanced/training_tricks.rst index 25a96ea481153..dd16f7c914107 100644 --- a/docs/source/advanced/training_tricks.rst +++ b/docs/source/advanced/training_tricks.rst @@ -1,7 +1,6 @@ .. testsetup:: * from pytorch_lightning.trainer.trainer import Trainer - from pytorch_lightning.core.lightning import LightningModule .. 
_training_tricks: From 23f2da2760e328dd55117df42fbb7ace4a98e526 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 30 Apr 2021 14:42:55 +0200 Subject: [PATCH 29/30] Apply suggestions from code review Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com> --- pytorch_lightning/trainer/trainer.py | 4 ++-- tests/trainer/test_trainer.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 019fd2a9984be..f897d72354a26 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1102,9 +1102,9 @@ def tune( datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. - scale_batch_size_kwargs: Arguments for :func:`~pytorch_lightning.tuner.lr_finder.lr_find` + scale_batch_size_kwargs: Arguments for :func:`~pytorch_lightning.tuner.batch_size_scaling.scale_batch_size` - lr_find_kwargs: Arguments for :func:`~pytorch_lightning.tuner.batch_size_scaling.scale_batch_size` + lr_find_kwargs: Arguments for :func:`~pytorch_lightning.tuner.lr_finder.lr_find` """ Trainer._log_api_event("tune") self.state = TrainerState.TUNING diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index d7da3f58b6956..b3473b5847f4e 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1977,7 +1977,8 @@ def on_fit_start(self, trainer, pl_module: LightningModule) -> None: def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - trainer.fit(BoringModel()) + model = BoringModel() + trainer.fit(model) with pytest.raises(MisconfigurationException, match=r"\.validate\(\)` with `fast_dev_run=True"): trainer.validate() From 841381c080a398ce15c2e984a73bb28d04dfb711 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 30 Apr 2021 14:48:28 +0200 Subject: [PATCH 30/30] Update scale_batch_size mode docstring --- pytorch_lightning/tuner/tuning.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 155a003ae26c7..8e3862b195cd6 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -89,12 +89,11 @@ def scale_batch_size( datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. - mode: string setting the search mode. Either ``power`` or ``binsearch``. - If mode is ``power`` we keep multiplying the batch size by 2, until - we get an OOM error. If mode is ``binsearch``, we will initially - also keep multiplying by 2 and after encountering an OOM error - do a binary search between the last successful batch size and the - batch size that failed. + mode: Search strategy to update the batch size: + + - ``'power'`` (default): Keep multiplying the batch size by 2, until we get an OOM error. + - ``'binsearch'``: Initially keep multiplying by 2 and after encountering an OOM error + do a binary search between the last successful batch size and the batch size that failed. steps_per_trial: number of steps to run with a given batch size. Ideally 1 should be enough to test if a OOM error occurs,