From 788f6864d935238232e75994f1cfb3a5e42fb459 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 18 Oct 2021 02:23:51 +0200 Subject: [PATCH 01/22] Fix `LightningOptimizer` step and toggling logic (#9958) --- pytorch_lightning/core/optimizer.py | 33 ++++++++++----------------- pytorch_lightning/profiler/pytorch.py | 24 +++++++++---------- tests/profiler/test_profiler.py | 2 +- 3 files changed, 25 insertions(+), 34 deletions(-) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index ba81644b9bd9a..dd074450f0897 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -93,14 +93,6 @@ def _to_lightning_optimizer(cls, optimizer, trainer, opt_idx): optimizer = trainer.lightning_optimizers[opt_idx] return optimizer - def _toggle_model(self): - model_ref = self._trainer.lightning_module - model_ref.toggle_optimizer(self, self._optimizer_idx) - - def _untoggle_model(self): - model_ref = self._trainer.lightning_module - model_ref.untoggle_optimizer(self) - @contextmanager def toggle_model(self, sync_grad: bool = True): """This function is just a helper for advanced users. @@ -116,16 +108,12 @@ def toggle_model(self, sync_grad: bool = True): # local import here to avoid circular import from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior + lightning_module = self._trainer.lightning_module + with _block_parallel_sync_behavior(self._trainer, block=(not sync_grad)): - self._toggle_model() + lightning_module.toggle_optimizer(self, self._optimizer_idx) yield - self._untoggle_model() - - def __optimizer_step(self, closure: Callable, profiler_name: str = None, **kwargs): - trainer = self._trainer - - with trainer.profiler.profile(profiler_name): - trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs) + lightning_module.untoggle_optimizer(self._optimizer_idx) def step(self, closure: Optional[Callable] = None, **kwargs): """Call this directly from your training_step when doing optimizations manually. By using this we can @@ -193,14 +181,17 @@ def closure_dis(): opt_dis.step(closure=closure_dis) """ if closure is None: - profiler_name = f"closure_{self._optimizer_idx}" closure = do_nothing_closure + profiler_action = "optimizer_step_without_closure" + elif not callable(closure): + raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable") else: - if not callable(closure): - raise MisconfigurationException("When closure is provided, it should be a function") - profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}" + profiler_action = "optimizer_step_with_closure" + profiler_action += f"_{self._optimizer_idx}" - self.__optimizer_step(closure=closure, profiler_name=profiler_name, **kwargs) + trainer = self._trainer + with trainer.profiler.profile(profiler_action): + trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs) self._total_optimizer_step_calls += 1 def __repr__(self): diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py index 8bdbadffec15b..58f4a18895498 100644 --- a/pytorch_lightning/profiler/pytorch.py +++ b/pytorch_lightning/profiler/pytorch.py @@ -117,11 +117,11 @@ def pre_step(self, current_action: str) -> None: def reset(self): # handle properly `fast_dev_run`. PyTorch Profiler will fail otherwise. 
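# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of this patch): how the renamed
# profiler actions surface when stepping a `LightningOptimizer` manually.
# The class below and its `compute_loss` helper are placeholders, not code
# from this repository; `configure_optimizers` and dataloaders are omitted.
import pytorch_lightning as pl


class ManualOptModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # we drive the optimizer ourselves

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()  # returns the LightningOptimizer wrapper

        def closure():
            loss = self.compute_loss(batch)  # placeholder loss computation
            opt.zero_grad()
            self.manual_backward(loss)
            return loss

        # recorded by the profiler as "optimizer_step_with_closure_0";
        # calling `opt.step()` without a closure is recorded as
        # "optimizer_step_without_closure_0" after this patch
        opt.step(closure=closure)
# ---------------------------------------------------------------------------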
- self._num_optimizer_step_and_closure = 0 + self._num_optimizer_step_with_closure = 0 self._num_validation_step = 0 self._num_test_step = 0 self._num_predict_step = 0 - self._optimizer_step_and_closure_reached_end = False + self._optimizer_step_with_closure_reached_end = False self._validation_step_reached_end = False self._test_step_reached_end = False self._predict_step_reached_end = False @@ -132,13 +132,13 @@ def reset(self): @property def is_training(self) -> bool: return self._current_action is not None and ( - self._current_action.startswith("optimizer_step_and_closure_") or self._current_action == "training_step" + self._current_action.startswith("optimizer_step_with_closure_") or self._current_action == "training_step" ) @property def num_step(self) -> int: if self.is_training: - return self._num_optimizer_step_and_closure + return self._num_optimizer_step_with_closure if self._current_action == "validation_step": return self._num_validation_step if self._current_action == "test_step": @@ -149,10 +149,10 @@ def num_step(self) -> int: def _step(self) -> None: if self.is_training: - self._num_optimizer_step_and_closure += 1 + self._num_optimizer_step_with_closure += 1 elif self._current_action == "validation_step": if self._start_action_name == "on_fit_start": - if self._num_optimizer_step_and_closure > 0: + if self._num_optimizer_step_with_closure > 0: self._num_validation_step += 1 else: self._num_validation_step += 1 @@ -164,7 +164,7 @@ def _step(self) -> None: @property def has_finished(self) -> bool: if self.is_training: - return self._optimizer_step_and_closure_reached_end + return self._optimizer_step_with_closure_reached_end if self._current_action == "validation_step": return self._validation_step_reached_end if self._current_action == "test_step": @@ -182,7 +182,7 @@ def __call__(self, num_step: int) -> "ProfilerAction": action = self._schedule(max(self.num_step, 0)) if action == ProfilerAction.RECORD_AND_SAVE: if self.is_training: - self._optimizer_step_and_closure_reached_end = True + self._optimizer_step_with_closure_reached_end = True elif self._current_action == "validation_step": self._validation_step_reached_end = True elif self._current_action == "test_step": @@ -202,9 +202,9 @@ class PyTorchProfiler(BaseProfiler): "test_step", "predict_step", } - RECORD_FUNCTION_PREFIX = "optimizer_step_and_closure_" + RECORD_FUNCTION_PREFIX = "optimizer_step_with_closure_" STEP_FUNCTIONS = {"training_step", "validation_step", "test_step", "predict_step"} - STEP_FUNCTION_PREFIX = "optimizer_step_and_closure_" + STEP_FUNCTION_PREFIX = "optimizer_step_with_closure_" AVAILABLE_SORT_KEYS = { "cpu_time", "cuda_time", @@ -383,8 +383,8 @@ def start(self, action_name: str) -> None: self._register.__enter__() if self._lightning_module is not None: - # when the model is used in automatic optimization, - # we use `optimizer_step_and_closure` to step the model. + # when the model is used in automatic optimization, we use `optimizer_step_with_closure` to step the model. 
+ # this profiler event is generated in the `LightningOptimizer.step` method if self._lightning_module.automatic_optimization and "training_step" in self.STEP_FUNCTIONS: self.STEP_FUNCTIONS.remove("training_step") diff --git a/tests/profiler/test_profiler.py b/tests/profiler/test_profiler.py index e00c6707dfa1c..9e22c3f8ec9f7 100644 --- a/tests/profiler/test_profiler.py +++ b/tests/profiler/test_profiler.py @@ -285,7 +285,7 @@ def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler): files = [file for file in files if file.endswith(".json")] assert len(files) == 2, files local_rank = trainer.local_rank - assert any(f"{local_rank}-optimizer_step_and_closure_" in f for f in files) + assert any(f"{local_rank}-optimizer_step_with_closure_" in f for f in files) assert any(f"{local_rank}-validation_step" in f for f in files) From 01b304ec574ecd06b9ad73e64cf37747d7e713b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 18 Oct 2021 03:10:48 +0200 Subject: [PATCH 02/22] Update accelerator connector messages after the addition of strategy (#9937) --- .../connectors/accelerator_connector.py | 59 ++++++++----------- .../test_accelerator_connector.py | 8 +-- tests/accelerators/test_tpu.py | 1 - 3 files changed, 30 insertions(+), 38 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index c25cf47c24ba4..53f95ae4c8a14 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -244,12 +244,13 @@ def _validate_accelerator_and_devices(self) -> None: raise MisconfigurationException( f"You passed `devices={self.devices}` but haven't specified" " `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu')` for the devices mapping," - f" got `accelerator={self.distributed_backend}`." + f" got `accelerator={self.distributed_backend!r}`." ) def _validate_accelerator_type(self) -> None: if self._accelerator_type and self._accelerator_type != self._device_type: - raise MisconfigurationException( + # internal error: should not happen. + raise ValueError( f"Mismatch between the requested accelerator type ({self._accelerator_type})" f" and assigned device type ({self._device_type})." 
) @@ -259,25 +260,16 @@ def _warn_if_devices_flag_ignored(self) -> None: if self.devices is None: return devices_warning = f"The flag `devices={self.devices}` will be ignored, as you have set" - if self.distributed_backend == "auto": + if self.distributed_backend in ("auto", DeviceType.TPU): if self.tpu_cores is not None: rank_zero_warn(f"{devices_warning} `tpu_cores={self.tpu_cores}`") - elif self.ipus is not None: - rank_zero_warn(f"{devices_warning} `ipus={self.ipus}`") - elif self.gpus is not None: - rank_zero_warn(f"{devices_warning} `gpus={self.gpus}`") - elif self.num_processes != 1: - rank_zero_warn(f"{devices_warning} `num_processes={self.num_processes}`") - elif self.distributed_backend == DeviceType.TPU: - if self.tpu_cores is not None: - rank_zero_warn(f"{devices_warning} `tpu_cores={self.tpu_cores}`") - elif self.distributed_backend == DeviceType.IPU: + elif self.distributed_backend in ("auto", DeviceType.IPU): if self.ipus is not None: rank_zero_warn(f"{devices_warning} `ipus={self.ipus}`") - elif self.distributed_backend == DeviceType.GPU: + elif self.distributed_backend in ("auto", DeviceType.GPU): if self.gpus is not None: rank_zero_warn(f"{devices_warning} `gpus={self.gpus}`") - elif self.distributed_backend == DeviceType.CPU: + elif self.distributed_backend in ("auto", DeviceType.CPU): if self.num_processes != 1: rank_zero_warn(f"{devices_warning} `num_processes={self.num_processes}`") @@ -298,26 +290,27 @@ def _handle_accelerator_and_distributed_backend( ) -> None: if distributed_backend is not None: rank_zero_deprecation( - f"`Trainer(distributed_backend={distributed_backend})` has been deprecated and will be removed in v1.5." - f" Use `Trainer(strategy={distributed_backend})` instead." + f"`Trainer(distributed_backend={distributed_backend!r})` " + "has been deprecated and will be removed in v1.5." + f" Use `Trainer(strategy={distributed_backend!r})` instead." ) if self.strategy is not None: raise MisconfigurationException( - f"You have passed `Trainer(strategy={self.strategy})` but have" - f" also passed `Trainer(distributed_backend={distributed_backend})`." - f"HINT: Use just `Trainer(strategy={self.strategy})` instead." + f"You have passed `Trainer(strategy={self.strategy!r})` but have" + f" also passed `Trainer(distributed_backend={distributed_backend!r})`." + f" HINT: Use just `Trainer(strategy={self.strategy!r})` instead." ) if accelerator is not None and accelerator in list(DistributedType): rank_zero_deprecation( - f"Passing {accelerator} `strategy` to the `accelerator` flag in Trainer has been deprecated" - f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={accelerator})` instead." + f"Passing `Trainer(accelerator={accelerator!r})` has been deprecated" + f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={accelerator!r})` instead." ) if self.strategy is not None: raise MisconfigurationException( - f"You have passed `Trainer(strategy={self.strategy})` but have" - f" also passed `Trainer(accelerator={accelerator})`." - f"HINT: Use just `Trainer(strategy={self.strategy})` instead." + f"You have passed `Trainer(strategy={self.strategy!r})` but have" + f" also passed `Trainer(accelerator={accelerator!r})`." + f" HINT: Use just `Trainer(strategy={self.strategy!r})` instead." 
) def _set_training_type_plugin(self) -> None: @@ -333,7 +326,7 @@ def handle_given_plugins(self) -> None: for plug in self.plugins: if self.strategy is not None and self._is_plugin_training_type(plug): raise MisconfigurationException( - f"You have passed `Trainer(strategy={self.strategy})`" + f"You have passed `Trainer(strategy={self.strategy!r})`" f" and you can only specify one training type plugin, but you have passed {plug} as a plugin." ) if self._is_plugin_training_type(plug): @@ -507,7 +500,7 @@ def _map_devices_to_accelerator(self, accelerator: str) -> bool: if accelerator == DeviceType.CPU: if not isinstance(self.devices, int): raise MisconfigurationException( - "The flag `devices` only supports integer for `accelerator='cpu'`," + "The flag `devices` must be an int with `accelerator='cpu'`," f" got `devices={self.devices}` instead." ) self.num_processes = self.devices @@ -816,7 +809,7 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None): elif self.num_gpus > 1 and not _use_cpu: rank_zero_warn( "You requested multiple GPUs but did not specify a backend, e.g." - ' `Trainer(accelerator="dp"|"ddp"|"ddp2")`. Setting `accelerator="ddp_spawn"` for you.' + ' `Trainer(strategy="dp"|"ddp"|"ddp2")`. Setting `strategy="ddp_spawn"` for you.' ) self.distributed_backend = DistributedType.DDP_SPAWN @@ -833,7 +826,7 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None): self._distrib_type = DistributedType.DDP_SPAWN if self.num_gpus > 0: rank_zero_warn( - "You requested one or more GPUs, but set the backend to `ddp_cpu`. Training will not use GPUs." + "You requested one or more GPUs, but set `accelerator='ddp_cpu'`. Training will not use GPUs." ) self.parallel_device_ids = None if self.num_processes is None: @@ -859,7 +852,7 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None): if (self.num_nodes and self.num_nodes > 1) or (self.num_processes and self.num_processes > 1): if self._distrib_type in (DistributedType.DP, DistributedType.DDP2): rank_zero_warn( - f"{self._distrib_type} is not supported on CPUs, hence setting the distributed type to `ddp`." + f"{self._distrib_type.value!r} is not supported on CPUs, hence setting `strategy='ddp'`." ) self._distrib_type = DistributedType.DDP else: @@ -887,8 +880,7 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None): if self.num_nodes > 1 and not using_valid_distributed: # throw error to force user to choose a supported distributed type such as ddp or ddp2 raise MisconfigurationException( - "Your chosen distributed type does not support num_nodes > 1. " - "Please set accelerator=ddp or accelerator=ddp2." + "Your chosen strategy does not support `num_nodes > 1`. Please set `strategy=('ddp'|'ddp2')`." ) def _set_horovod_backend(self): @@ -910,7 +902,8 @@ def check_interactive_compatibility(self): if _IS_INTERACTIVE and self._distrib_type is not None and not self._distrib_type.is_interactive_compatible(): raise MisconfigurationException( - f"Selected distributed backend {self._distrib_type} is not compatible with an interactive" + f"`Trainer(strategy={self._distrib_type.value!r})` or" + f" `Trainer(accelerator={self._distrib_type.value!r})` is not compatible with an interactive" " environment. Run your code as a script, or choose one of the compatible backends:" f" {', '.join(DistributedType.interactive_compatible_types())}." 
" In case you are spawning processes yourself, make sure to include the Trainer" diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index 014e577d7c97c..d39eb82cc95cb 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -447,10 +447,10 @@ def on_fit_start(self, trainer, pl_module): @mock.patch("pytorch_lightning.utilities._IS_INTERACTIVE", return_value=True) @mock.patch("torch.cuda.device_count", return_value=2) def test_ipython_incompatible_backend_error(*_): - with pytest.raises(MisconfigurationException, match="backend ddp is not compatible"): + with pytest.raises(MisconfigurationException, match=r"strategy='ddp'\)`.*is not compatible"): Trainer(accelerator="ddp", gpus=2) - with pytest.raises(MisconfigurationException, match="backend ddp2 is not compatible"): + with pytest.raises(MisconfigurationException, match=r"strategy='ddp2'\)`.*is not compatible"): Trainer(accelerator="ddp2", gpus=2) @@ -615,14 +615,14 @@ def test_set_devices_if_none_gpu(): def test_devices_with_cpu_only_supports_integer(): - with pytest.raises(MisconfigurationException, match="The flag `devices` only supports integer"): + with pytest.raises(MisconfigurationException, match="The flag `devices` must be an int"): Trainer(accelerator="cpu", devices="1,3") @pytest.mark.parametrize("training_type", ["ddp2", "dp"]) def test_unsupported_distrib_types_on_cpu(training_type): - with pytest.warns(UserWarning, match="is not supported on CPUs, hence setting the distributed type to `ddp`."): + with pytest.warns(UserWarning, match="is not supported on CPUs, hence setting `strategy='ddp"): trainer = Trainer(accelerator=training_type, num_processes=2) assert trainer._distrib_type == DistributedType.DDP diff --git a/tests/accelerators/test_tpu.py b/tests/accelerators/test_tpu.py index df5444ac776a6..622ead614b1c7 100644 --- a/tests/accelerators/test_tpu.py +++ b/tests/accelerators/test_tpu.py @@ -222,7 +222,6 @@ def on_train_end(self, trainer, pl_module): @RunIf(tpu=True) def test_ddp_cpu_not_supported_on_tpus(): - with pytest.raises(MisconfigurationException, match="`accelerator='ddp_cpu'` is not supported on TPU machines"): Trainer(accelerator="ddp_cpu") From 0b6f679bc567987a80c123e0e56f00d90b4c673b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 18 Oct 2021 11:34:48 +0200 Subject: [PATCH 03/22] update type for setup --- pytorch_lightning/lite/lite.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 76d5a6607c8e2..1d17d44e9765a 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -301,8 +301,8 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> None: def _setup_model_and_optimizers( self, model: nn.Module, - optimizers: Union[Optimizer, List[Optimizer]], - ) -> Tuple[_LiteModule, Union[_LiteOptimizer, List[_LiteOptimizer]]]: + optimizers: List[Optimizer], + ) -> Tuple[_LiteModule, List[_LiteOptimizer]]: # Let accelerator/plugin wrap and connect the models and optimizers [model], optimizers = self._strategy.setup_models_and_optimizers([model], optimizers) model = _LiteModule(module=model, accelerator=self._accelerator) From 7a9151637cea5a4f1ac64f13cbd1e16b28332a53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 18 Oct 2021 11:43:11 +0200 Subject: [PATCH 04/22] loop customization docs (#9609) MIME-Version: 
1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Carlos MocholĂ­ Co-authored-by: thomas chaton Co-authored-by: edenlightning <66261195+edenlightning@users.noreply.github.com> --- docs/source/advanced/sequences.rst | 2 +- docs/source/api_references.rst | 65 +++ docs/source/extensions/loops.rst | 403 ++++++++++++++++++ docs/source/extensions/loops_advanced.rst | 41 ++ docs/source/index.rst | 2 +- docs/source/starter/new-project.rst | 5 + pytorch_lightning/loops/base.py | 2 +- .../loops/optimization/__init__.py | 1 + 8 files changed, 518 insertions(+), 3 deletions(-) create mode 100644 docs/source/extensions/loops.rst create mode 100644 docs/source/extensions/loops_advanced.rst diff --git a/docs/source/advanced/sequences.rst b/docs/source/advanced/sequences.rst index 8e50de49933eb..2d8d770cbb850 100644 --- a/docs/source/advanced/sequences.rst +++ b/docs/source/advanced/sequences.rst @@ -1,6 +1,6 @@ Sequential Data -================ +=============== Truncated Backpropagation Through Time -------------------------------------- diff --git a/docs/source/api_references.rst b/docs/source/api_references.rst index df70b2b0a3944..7bc4d8b460e8d 100644 --- a/docs/source/api_references.rst +++ b/docs/source/api_references.rst @@ -67,6 +67,71 @@ Loggers API test_tube wandb +Loop API +-------- + +Base Classes +^^^^^^^^^^^^ + +.. currentmodule:: pytorch_lightning.loops + +.. autosummary:: + :toctree: api + :nosignatures: + :template: classtemplate.rst + + ~base.Loop + ~dataloader.dataloader_loop.DataLoaderLoop + + +Default Loop Implementations +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Training +"""""""" + +.. currentmodule:: pytorch_lightning.loops + +.. autosummary:: + :toctree: api + :nosignatures: + :template: classtemplate.rst + + FitLoop + ~epoch.TrainingEpochLoop + ~batch.TrainingBatchLoop + ~optimization.OptimizerLoop + ~optimization.ManualOptimization + + +Validation and Testing +"""""""""""""""""""""" + +.. currentmodule:: pytorch_lightning.loops + +.. autosummary:: + :toctree: api + :nosignatures: + :template: classtemplate.rst + + ~dataloader.EvaluationLoop + ~epoch.EvaluationEpochLoop + + +Prediction +"""""""""" + +.. currentmodule:: pytorch_lightning.loops + +.. autosummary:: + :toctree: api + :nosignatures: + :template: classtemplate.rst + + ~dataloader.PredictionLoop + ~epoch.PredictionEpochLoop + + Plugins API ----------- diff --git a/docs/source/extensions/loops.rst b/docs/source/extensions/loops.rst new file mode 100644 index 0000000000000..b83a64d2f6b74 --- /dev/null +++ b/docs/source/extensions/loops.rst @@ -0,0 +1,403 @@ +.. _loop_customization: + +Loops +===== + +Loops let advanced users swap out the default gradient descent optimization loop at the core of Lightning with a different optimization paradigm. + +The Lightning Trainer is built on top of the standard gradient descent optimization loop which works for 90%+ of machine learning use cases: + +.. code-block:: python + + for i, batch in enumerate(dataloader): + x, y = batch + y_hat = model(x) + loss = loss_function(y_hat, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + +However, some new research use cases such as meta-learning, active learning, recommendation systems, etc., require a different loop structure. +For example here is a simple loop that guides the weight updates with a loss from a special validation split: + +.. 
code-block:: python + + for i, batch in enumerate(train_dataloader): + x, y = batch + y_hat = model(x) + loss = loss_function(y_hat, y) + optimizer.zero_grad() + loss.backward() + + val_loss = 0 + for i, val_batch in enumerate(val_dataloader): + x, y = val_batch + y_hat = model(x) + val_loss += loss_function(y_hat, y) + + scale_gradients(model, 1 / val_loss) + optimizer.step() + + +With Lightning Loops, you can customize to non-standard gradient descent optimizations to get the same loop above: + +.. code-block:: python + + trainer = Trainer() + trainer.fit_loop.epoch_loop = MyGradientDescentLoop() + +Think of this as swapping out the engine in a car! + +Understanding the default Trainer loop +-------------------------------------- + +The Lightning :class:`~pytorch_lightning.trainer.trainer.Trainer` automates the standard optimization loop which every PyTorch user is familiar with: + +.. code-block:: python + + for i, batch in enumerate(dataloader): + x, y = batch + y_hat = model(x) + loss = loss_function(y_hat, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + +The core research logic is simply shifted to the :class:`~pytorch_lightning.core.lightning.LightningModule`: + +.. code-block:: python + + for i, batch in enumerate(dataloader): + # x, y = batch moved to training_step + # y_hat = model(x) moved to training_step + # loss = loss_function(y_hat, y) moved to training_step + loss = lightning_module.training_step(batch, i) + + # Lighting handles automatically: + optimizer.zero_grad() + loss.backward() + optimizer.step() + +Under the hood, the above loop is implemented using the :class:`~pytorch_lightning.loops.base.Loop` API like so: + +.. code-block:: python + + class DefaultLoop(Loop): + def advance(self, batch, i): + loss = lightning_module.training_step(batch, i) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + def run(self, dataloader): + for i, batch in enumerate(dataloader): + self.advance(batch, i) + +Defining a loop within a class interface instead of hard-coding a raw Python for/while loop has several benefits: + +1. You can have full control over the data flow through loops. +2. You can add new loops and nest as many of them as you want. +3. If needed, the state of a loop can be :ref:`saved and resumed `. +4. New hooks can be injected at any point. + +.. image:: https://pl-public-data.s3.amazonaws.com/docs/static/images/loops/epoch-loop-steps.gif + :alt: Animation showing how to convert a standard training loop to a Lightning loop + + +.. _override default loops: + +Overriding the default loops +---------------------------- + +The fastest way to get started with loops, is to override functionality of an existing loop. +Lightning has 4 main loops it uses: :class:`~pytorch_lightning.loops.fit_loop.FitLoop` for training and validating, +:class:`~pytorch_lightning.loops.dataloader.evaluation_loop.EvaluationLoop` for testing, +:class:`~pytorch_lightning.loops.dataloader.prediction_loop.PredictionLoop` for predicting. + +For simple changes that don't require a custom loop, you can modify each of these loops. + +Each loop has a series of methods that can be modified. +For example with the :class:`~pytorch_lightning.loops.fit_loop.FitLoop`: + +.. code-block:: + + from pytorch_lightning.loops import FitLoop + + class MyLoop(FitLoop): + + def advance(): + ... + + def on_advance_end(self) + ... + + def on_run_end(self): + ... + +A full list with all built-in loops and subloops can be found :ref:`here `. 
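As a concrete illustration of overriding such hooks (a sketch, not code from this repository; the ``val_loss`` key and the threshold are assumptions of the example), a subclass that stops training once a logged metric drops below a threshold only needs one override:

.. code-block:: python

    from pytorch_lightning.loops import FitLoop


    class EarlyExitFitLoop(FitLoop):
        def on_advance_end(self) -> None:
            super().on_advance_end()
            # `callback_metrics` holds whatever the LightningModule logged;
            # using a "val_loss" entry here is an assumption of this sketch
            val_loss = self.trainer.callback_metrics.get("val_loss")
            if val_loss is not None and val_loss < 0.1:
                self.trainer.should_stop = True

It is attached with ``trainer.fit_loop = EarlyExitFitLoop()``, exactly as shown for the hook example that follows.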
+ +To add your own modifications to a loop, simply subclass an existing loop class and override what you need. +Here is a simple example how to add a new hook: + +.. code-block:: python + + from pytorch_lightning.loops import FitLoop + + + class CustomFitLoop(FitLoop): + def advance(self): + # ... whatever code before + + # pass anything you want to the hook + self.trainer.call_hook("my_new_hook", *args, **kwargs) + + # ... whatever code after + +Now simply attach the correct loop in the trainer directly: + +.. code-block:: python + + trainer = Trainer(...) + trainer.fit_loop = CustomFitLoop() + + # fit() now uses the new FitLoop! + trainer.fit(...) + + # the equivalent for validate(), test(), predict() + val_loop = CustomValLoop() + trainer = Trainer() + trainer.validate_loop = val_loop + trainer.validate(model) + +Now your code is FULLY flexible and you can still leverage ALL the best parts of Lightning! + +.. image:: https://pl-public-data.s3.amazonaws.com/docs/static/images/loops/replace-fit-loop.gif + :alt: Animation showing how to replace a loop on the Trainer + +Creating a new loop from scratch +-------------------------------- + +You can also go wild and implement a full loop from scratch by sub-classing the :class:`~pytorch_lightning.loops.base.Loop` base class. +You will need to override a minimum of two things: + +.. code-block:: + + from pytorch_lightning.loop import Loop + + class MyFancyLoop(Loop): + + @property + def done(self): + # provide condition to stop the loop + + def advance(self): + # access your dataloader/s in whatever way you want + # do your fancy optimization things + # call the lightning module methods at your leisure + +Finally, attach it into the :class:`~pytorch_lightning.trainer.trainer.Trainer`: + +.. code-block:: python + + trainer = Trainer(...) + trainer.fit_loop = MyFancyLoop() + + # fit() now uses your fancy loop! + trainer.fit(...) + +Now you have full control over the Trainer. +But beware: The power of loop customization comes with great responsibility. +We recommend that you familiarize yourself with :ref:`overriding the default loops ` first before you start building a new loop from the ground up. + +Loop API +-------- +Here is the full API of methods available in the Loop base class. + +The :class:`~pytorch_lightning.loops.base.Loop` class is the base for all loops in Lighting just like the :class:`~pytorch_lightning.core.lightning.LightningModule` is the base for all models. +It defines a public interface that each loop implementation must follow, the key ones are: + +Properties +^^^^^^^^^^ + +done +~~~~ + +.. autoattribute:: pytorch_lightning.loops.base.Loop.done + :noindex: + +skip (optional) +~~~~~~~~~~~~~~~ + +.. autoattribute:: pytorch_lightning.loops.base.Loop.skip + :noindex: + +Methods +^^^^^^^ + +reset (optional) +~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.loops.base.Loop.reset + :noindex: + +advance +~~~~~~~ + +.. automethod:: pytorch_lightning.loops.base.Loop.advance + :noindex: + +run (optional) +~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.loops.base.Loop.run + :noindex: + + +Subloops +-------- + +When you want to customize nested loops within loops, use the :meth:`~pytorch_lightning.loops.base.Loop.connect` method: + +.. code-block:: python + + # Step 1: create your loop + my_epoch_loop = MyEpochLoop() + + # Step 2: use connect() + trainer.fit_loop.connect(epoch_loop=my_epoch_loop) + + # Trainer runs the fit loop with your new epoch loop! 
+ trainer.fit(model) + +More about the built-in loops and how they are composed is explained in the next section. + +.. image:: https://pl-public-data.s3.amazonaws.com/docs/static/images/loops/connect-epoch-loop.gif + :alt: Animation showing how to connect a custom subloop + +.. _loop structure: + +Built-in Loops +-------------- + +The training loop in Lightning is called *fit loop* and is actually a combination of several loops. +Here is what the structure would look like in plain Python: + +.. code-block:: python + + # FitLoop + for epoch in range(max_epochs): + + # TrainingEpochLoop + for batch_idx, batch in enumerate(train_dataloader): + + # TrainingBatchLoop + for split_batch in tbptt_split(batch): + + # OptimizerLoop + for optimizer_idx, opt in enumerate(optimizers): + + loss = lightning_module.training_step(batch, batch_idx, optimizer_idx) + ... + + # ValidationEpochLoop + for batch_idx, batch in enumerate(val_dataloader): + lightning_module.validation_step(batch, batch_idx, optimizer_idx) + ... + + +Each of these :code:`for`-loops represents a class implementing the :class:`~pytorch_lightning.loops.base.Loop` interface. + + +.. list-table:: Trainer entry points and associated loops + :widths: 25 75 + :header-rows: 1 + + * - Built-in loop + - Description + * - :class:`~pytorch_lightning.loops.fit_loop.FitLoop` + - The :class:`~pytorch_lightning.loops.fit_loop.FitLoop` is the top-level loop where training starts. + It simply counts the epochs and iterates from one to the next by calling :code:`TrainingEpochLoop.run()` in its :code:`advance()` method. + * - :class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop` + - The :class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop` is the one that iterates over the dataloader that the user returns in their :meth:`~pytorch_lightning.core.lightning.LightningModule.train_dataloader` method. + Its main responsibilities are calling the :code:`*_epoch_start` and :code:`*_epoch_end` hooks, accumulating outputs if the user request them in one of these hooks, and running validation at the requested interval. + The validation is carried out by yet another loop, :class:`~pytorch_lightning.loops.epoch.validation_epoch_loop.ValidationEpochLoop`. + + In the :code:`run()` method, the training epoch loop could in theory simply call the :code:`LightningModule.training_step` already and perform the optimization. + However, Lightning has built-in support for automatic optimization with multiple optimizers and on top of that also supports :doc:`truncated back-propagation through time <../advanced/sequences>`. + For this reason there are actually two more loops nested under :class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop`. + * - :class:`~pytorch_lightning.loops.batch.training_batch_loop.TrainingBatchLoop` + - The responsibility of the :class:`~pytorch_lightning.loops.batch.training_batch_loop.TrainingBatchLoop` is to split a batch given by the :class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop` along the time-dimension and iterate over the list of splits. + It also keeps track of the hidden state *hiddens* returned by the training step. + By default, when truncated back-propagation through time (TBPTT) is turned off, this loop does not do anything except redirect the call to the :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop`. + Read more about :doc:`TBPTT <../advanced/sequences>`. 
+ * - :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop` + - The :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop` iterates over one or multiple optimizers and for each one it calls the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` method with the batch, the current batch index and the optimizer index if multiple optimizers are requested. + It is the leaf node in the tree of loops and performs the actual optimization (forward, zero grad, backward, optimizer step). + * - :class:`~pytorch_lightning.loops.optimization.manual_loop.ManualOptimization` + - Substitutes the :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop` in case of :ref:`manual_optimization` and implements the manual optimization step. + + +Available Loops in Lightning Flash +---------------------------------- + +`Active Learning `__ is a machine learning practice in which the user interacts with the learner in order to provide new labels when required. + +You can find a real use case in `Lightning Flash `_. + +Flash implements the :code:`ActiveLearningLoop` that you can use together with the :code:`ActiveLearningDataModule` to label new data on the fly. +To run the following demo, install Flash and `BaaL `__ first: + +.. code-block:: bash + + pip install lightning-flash baal + +.. code-block:: python + + import torch + + import flash + from flash.core.classification import Probabilities + from flash.core.data.utils import download_data + from flash.image import ImageClassificationData, ImageClassifier + from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop + + # 1. Create the DataModule + download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data") + + # Implement the research use-case where we mask labels from labelled dataset. + datamodule = ActiveLearningDataModule( + ImageClassificationData.from_folders(train_folder="data/hymenoptera_data/train/", batch_size=2), + val_split=0.1, + ) + + # 2. Build the task + head = torch.nn.Sequential( + torch.nn.Dropout(p=0.1), + torch.nn.Linear(512, datamodule.num_classes), + ) + model = ImageClassifier(backbone="resnet18", head=head, num_classes=datamodule.num_classes, serializer=Probabilities()) + + # 3.1 Create the trainer + trainer = flash.Trainer(max_epochs=3) + + # 3.2 Create the active learning loop and connect it to the trainer + active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1) + active_learning_loop.connect(trainer.fit_loop) + trainer.fit_loop = active_learning_loop + + # 3.3 Finetune + trainer.finetune(model, datamodule=datamodule, strategy="freeze") + + # 4. Predict what's on a few images! ants or bees? + predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg") + print(predictions) + + # 5. Save the model! + trainer.save_checkpoint("image_classification_model.pt") + +Here is the `runnable example `_ and the `code for the active learning loop `_. + +Advanced Topics and Examples +---------------------------- + +Next: :doc:`Advanced loop features and examples <../extensions/loops_advanced>` diff --git a/docs/source/extensions/loops_advanced.rst b/docs/source/extensions/loops_advanced.rst new file mode 100644 index 0000000000000..6cf8ceb72b98b --- /dev/null +++ b/docs/source/extensions/loops_advanced.rst @@ -0,0 +1,41 @@ +:orphan: + +Loops (Advanced) +================ + +.. 
_persisting loop state: + +Persisting the state of loops +----------------------------- + +.. note:: + + This is an experimental feature and is not activated by default. + Set the environment variable `PL_FAULT_TOLERANT_TRAINING = 1` to enable saving the progress of loops. + Read more about :doc:`fault-tolerant training <../advanced/fault_tolerant_training>`. + +A powerful property of the class-based loop interface is that it can own an internal state. +Loop instances can save their state to the checkpoint through corresponding hooks and if implemented accordingly, resume the state of exectuion at the appropriate place. +This design is particularly interesting for fault-tolerant training which is an experimental feature released in Lightning v1.5. + +The two hooks :class:`~pytorch_lightning.loops.base.Loop.on_save_checkpoint` and :class:`~pytorch_lightning.loops.base.Loop.on_load_checkpoint` function very similarly to how LightningModules and Callbacks save and load state. + +.. code-block:: python + + def on_save_checkpoint(self): + state_dict["iteration"] = self.iteration + return state_dict + + + def on_load_checkpoint(self, state_dict): + self.iteration = state_dict["iteration"] + +When the Trainer is restarting from a checkpoint (e.g., through :code:`Trainer(resume_from_checkpoint=...)`), the loop exposes a boolean attribute :attr:`~pytorch_lightning.loops.base.Loop.restarting`. +Based around the value of this variable, the user can write the loop in such a way that it can restart from an arbitrary point given the state loaded from the checkpoint. +For example, the implementation of the :meth:`~pytorch_lightning.loops.base.Loop.reset` method could look like this given our previous example: + +.. code-block:: python + + def reset(self): + if not self.restarting: + self.iteration = 0 diff --git a/docs/source/index.rst b/docs/source/index.rst index f0eb5c05af4d1..ea3e606d72849 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -85,7 +85,7 @@ PyTorch Lightning extensions/logging extensions/metrics extensions/plugins - + extensions/loops .. toctree:: :maxdepth: 1 diff --git a/docs/source/starter/new-project.rst b/docs/source/starter/new-project.rst index 88213637d44a9..0626aa09db871 100644 --- a/docs/source/starter/new-project.rst +++ b/docs/source/starter/new-project.rst @@ -272,6 +272,11 @@ Turn off automatic optimization and you control the train loop! self.manual_backward(loss_b) opt_b.step() +Loop customization +================== + +If you need even more flexibility, you can fully customize the training loop to its core. +Learn more about loops :doc:`here <../extensions/loops>`. Predict or Deploy ================= diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 1a19c753b0e2b..ef53df92c6bb0 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -35,7 +35,7 @@ class Loop(ABC, Generic[T]): This class implements the following loop structure: - .. codeblock:: python + .. code-block:: python on_run_start() diff --git a/pytorch_lightning/loops/optimization/__init__.py b/pytorch_lightning/loops/optimization/__init__.py index 17e96c49d30da..07249b6a130c1 100644 --- a/pytorch_lightning/loops/optimization/__init__.py +++ b/pytorch_lightning/loops/optimization/__init__.py @@ -12,4 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. 
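Putting the loop-state hooks documented above together, a minimal sketch of a loop that persists its progress across restarts could look like the following. The names ``CountingLoop`` and ``max_iterations`` are illustrative only; the hook signatures match the ``Loop`` base class added in this patch.

.. code-block:: python

    from typing import Any, Dict

    from pytorch_lightning.loops.base import Loop


    class CountingLoop(Loop):
        def __init__(self, max_iterations: int) -> None:
            super().__init__()
            self.max_iterations = max_iterations
            self.iteration = 0

        @property
        def done(self) -> bool:
            return self.iteration >= self.max_iterations

        def reset(self) -> None:
            # only start over on a fresh run, not when restarting from a checkpoint
            if not self.restarting:
                self.iteration = 0

        def advance(self) -> None:
            # ... perform one unit of work here ...
            self.iteration += 1

        def on_save_checkpoint(self) -> Dict[str, Any]:
            return {"iteration": self.iteration}

        def on_load_checkpoint(self, state_dict: Dict[str, Any]) -> None:
            self.iteration = state_dict["iteration"]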
+from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization # noqa: F401 from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop # noqa: F401 From 6ed19461e3e0dd9dba5c64377d90e641b70a4a13 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 Oct 2021 10:38:15 +0000 Subject: [PATCH 05/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pl_examples/lite_examples/gan/gan_example.py | 3 +-- pl_examples/lite_examples/gan/models.py | 4 +-- .../lite_examples/simple/mnist_example.py | 5 ++-- pytorch_lightning/accelerators/accelerator.py | 2 +- pytorch_lightning/lite/lite.py | 27 +++++++++++-------- pytorch_lightning/lite/wrappers.py | 3 ++- .../plugins/precision/native_amp.py | 4 +-- .../plugins/precision/precision_plugin.py | 2 +- .../plugins/training_type/ddp_spawn.py | 2 +- .../plugins/training_type/sharded.py | 2 +- .../training_type/training_type_plugin.py | 9 ++++--- 11 files changed, 36 insertions(+), 27 deletions(-) diff --git a/pl_examples/lite_examples/gan/gan_example.py b/pl_examples/lite_examples/gan/gan_example.py index 67d15eda98126..9cf4a91ed3ff3 100644 --- a/pl_examples/lite_examples/gan/gan_example.py +++ b/pl_examples/lite_examples/gan/gan_example.py @@ -6,7 +6,6 @@ python -m torch.distributed.run --nproc_per_node=2 gan_example.py """ -from __future__ import print_function import argparse import os @@ -25,7 +24,7 @@ from pl_examples.lite_examples.gan.models import Discriminator, Generator, weights_init from pytorch_lightning import seed_everything from pytorch_lightning.lite import LightningLite -from pytorch_lightning.lite.wrappers import _LiteOptimizer, _LiteModule +from pytorch_lightning.lite.wrappers import _LiteModule, _LiteOptimizer parser = argparse.ArgumentParser() parser.add_argument("--workers", type=int, help="number of data loading workers", default=0) diff --git a/pl_examples/lite_examples/gan/models.py b/pl_examples/lite_examples/gan/models.py index 76f1608bfc5a1..5ccdec18aebc2 100644 --- a/pl_examples/lite_examples/gan/models.py +++ b/pl_examples/lite_examples/gan/models.py @@ -18,7 +18,7 @@ def weights_init(m): class Generator(nn.Module): def __init__(self): - super(Generator, self).__init__() + super().__init__() self.main = nn.Sequential( # input is Z, going into a convolution nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False), @@ -49,7 +49,7 @@ def forward(self, input): class Discriminator(nn.Module): def __init__(self): - super(Discriminator, self).__init__() + super().__init__() self.main = nn.Sequential( # input is (nc) x 64 x 64 nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), diff --git a/pl_examples/lite_examples/simple/mnist_example.py b/pl_examples/lite_examples/simple/mnist_example.py index c0a1931280891..1d9cb715c4137 100644 --- a/pl_examples/lite_examples/simple/mnist_example.py +++ b/pl_examples/lite_examples/simple/mnist_example.py @@ -1,11 +1,12 @@ import argparse + import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim +from torch.optim.lr_scheduler import StepLR from torch.utils.data import DistributedSampler from torchvision import datasets, transforms -from torch.optim.lr_scheduler import StepLR from pytorch_lightning import seed_everything from pytorch_lightning.lite import LightningLite @@ -13,7 +14,7 @@ class Net(nn.Module): def __init__(self): - super(Net, self).__init__() + super().__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) self.conv2 = 
nn.Conv2d(32, 64, 3, 1) self.dropout1 = nn.Dropout(0.25) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index edd80e2747813..691b117185b0d 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -315,7 +315,7 @@ def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor: return closure_loss def run_backward(self, tensor: Tensor, model, *args, **kwargs) -> None: - """Lightning-independent backward logic""" + """Lightning-independent backward logic.""" self.precision_plugin.run_backward(tensor, model, *args, **kwargs) def optimizer_step( diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 1d17d44e9765a..aa8ea606cfc5e 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -12,26 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from abc import abstractmethod, ABC +from abc import ABC, abstractmethod from collections import Callable from contextlib import contextmanager from functools import partial from pathlib import Path -from typing import Any, Optional, Sequence, Union, List, Dict, Tuple, Generator +from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple, Union import torch import torch.nn as nn from torch import Tensor from torch.optim import Optimizer -from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler, RandomSampler, Sampler +from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler from pytorch_lightning import Trainer from pytorch_lightning.accelerators import Accelerator, TPUAccelerator -from pytorch_lightning.lite.wrappers import _LiteOptimizer, _LiteModule, _LiteDataLoader -from pytorch_lightning.plugins import PLUGIN_INPUT, DDPSpawnPlugin, TrainingTypePlugin, DeepSpeedPlugin +from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer +from pytorch_lightning.plugins import DDPSpawnPlugin, DeepSpeedPlugin, PLUGIN_INPUT, TrainingTypePlugin from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin -from pytorch_lightning.utilities import move_data_to_device, DistributedType, DeviceType +from pytorch_lightning.utilities import DeviceType, DistributedType, move_data_to_device from pytorch_lightning.utilities.data import has_iterable_dataset from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -102,7 +102,10 @@ def __init__( @property def device(self) -> torch.device: - """The current device this process runs on. Use this to create tensors directly on the device if needed.""" + """The current device this process runs on. + + Use this to create tensors directly on the device if needed. + """ return self._accelerator.root_device @property @@ -233,8 +236,8 @@ def backward(self, tensor: Tensor, *args: Any, **kwargs: Any) -> None: def cast(self) -> Generator[None, None, None]: """A context manager to automatically convert operations for the chosen precision. - Use this only if the `forward` method of your model does not cover all operations you wish to run with - the chosen precision setting. + Use this only if the `forward` method of your model does not cover all operations you wish to run with the + chosen precision setting. 
""" with self._accelerator.forward_context(): yield @@ -255,8 +258,10 @@ def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tens return move_data_to_device(obj, device=self.device) def print(self, *args: Any, **kwargs: Any) -> None: - """Print something only on the first process. Arguments passed to this method are forwarded to the - Python built-in :func:`print` function.""" + """Print something only on the first process. + + Arguments passed to this method are forwarded to the Python built-in :func:`print` function. + """ if self.local_rank == 0: print(*args, **kwargs) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 94b77e8834d58..7a332952ee15e 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -14,7 +14,8 @@ from typing import Any, Callable, Optional import torch -from torch import nn as nn, Tensor +from torch import nn as nn +from torch import Tensor from torch.optim import Optimizer from torch.utils.data import DataLoader diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 27f0fcbccf4c5..97aa134f5f99b 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -100,7 +100,7 @@ def pre_optimizer_step( return False def post_optimizer_step(self, optimizer: "Optimizer", optimizer_idx: int) -> None: - """Updates the GradScaler""" + """Updates the GradScaler.""" self.run_post_optimizer_step(optimizer) def run_pre_optimizer_step(self, optimizer: "Optimizer") -> None: @@ -119,7 +119,7 @@ def autocast_context_manager(self) -> torch.cuda.amp.autocast: @contextmanager def forward_context(self) -> Generator[None, None, None]: - """Enable autocast context""" + """Enable autocast context.""" with self.autocast_context_manager(): yield diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index e06f7ac7ccc99..fd7ee03f7154d 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -158,7 +158,7 @@ def post_dispatch(self) -> None: @contextlib.contextmanager def forward_context(self) -> Generator: - """A contextmanager for managing model forward/training_step/evaluation_step/predict_step""" + """A contextmanager for managing model forward/training_step/evaluation_step/predict_step.""" yield @contextlib.contextmanager diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index bfa1c51867fb7..07fa4e455df26 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -16,7 +16,7 @@ import re from functools import partial from multiprocessing.queues import SimpleQueue -from typing import Any, Dict, List, Optional, Union, Callable +from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import torch diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py index 670edcfb096cf..2b194fe462cf4 100644 --- a/pytorch_lightning/plugins/training_type/sharded.py +++ b/pytorch_lightning/plugins/training_type/sharded.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from contextlib import contextmanager -from typing import Dict, Generator, Optional, Tuple, List, Union +from typing import Dict, Generator, List, Optional, Tuple, Union import torch from torch.nn import Module diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 230c341f5213d..092a46c79531f 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -13,7 +13,7 @@ # limitations under the License. import contextlib from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Sequence, Union, Tuple, List +from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Sequence, Tuple, Union import torch from torch import Tensor @@ -62,7 +62,10 @@ def setup(self) -> None: """Called by the accelerator to finish setup.""" def setup_dataloader(self, dataloader: DataLoader) -> DataLoader: - """Called by the accelerator. The plugin wraps and modifies the dataloader as needed.""" + """Called by the accelerator. + + The plugin wraps and modifies the dataloader as needed. + """ return dataloader def setup_models_and_optimizers( @@ -306,7 +309,7 @@ def remove_checkpoint(self, filepath: _PATH) -> None: @contextlib.contextmanager def forward_context(self) -> Generator: - """A contextmanager for managing model forward/training_step/evaluation_step/predict_step""" + """A contextmanager for managing model forward/training_step/evaluation_step/predict_step.""" yield @contextlib.contextmanager From 8c76cf5ae1db078dcd201b08de788ac6769c894d Mon Sep 17 00:00:00 2001 From: Elad Segal Date: Mon, 18 Oct 2021 13:54:26 +0300 Subject: [PATCH 06/22] reset val dataloader for binsearch (#9975) --- CHANGELOG.md | 2 ++ pytorch_lightning/tuner/batch_size_scaling.py | 1 + tests/tuner/test_scale_batch_size.py | 5 +++-- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 37945300774d5..f809e66c6b7ad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -546,6 +546,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
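For reference, the fix above is exercised through the tuner roughly as in the updated test below (a sketch; ``BatchSizeModel`` is the test helper model whose dataloaders read ``self.batch_size``):

.. code-block:: python

    from pytorch_lightning import Trainer
    from tests.helpers import BatchSizeModel  # test helper exposing `batch_size`

    model = BatchSizeModel(batch_size=16)
    trainer = Trainer(max_epochs=2, auto_scale_batch_size=True)
    new_batch_size = trainer.tune(
        model,
        scale_batch_size_kwargs={"max_trials": 5, "init_val": 4, "mode": "binsearch"},
    )["scale_batch_size"]
    # with this patch, both the train and the val dataloader are rebuilt after
    # every batch-size update, also in "binsearch" mode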
- Fixed issue with non-init dataclass fields in `apply_to_collection` ([#9963](https://github.com/PyTorchLightning/pytorch-lightning/issues/9963)) +- Reset `val_dataloader` in `tuner/batch_size_scaling` for binsearch ([#9975](https://github.com/PyTorchLightning/pytorch-lightning/pull/9975)) + ## [1.4.9] - 2021-09-30 diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index d3fd0822aa39f..42f9ce084a43c 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -205,6 +205,7 @@ def _run_binsearch_scaling( if changed: # Force the train dataloader to reset as the batch size has changed trainer.reset_train_dataloader(model) + trainer.reset_val_dataloader(model) else: break diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py index 5e4d1af1277c7..9dbb24d9edf30 100644 --- a/tests/tuner/test_scale_batch_size.py +++ b/tests/tuner/test_scale_batch_size.py @@ -274,10 +274,11 @@ def __init__(self): trainer.tuner.scale_batch_size(model, mode="ThisModeDoesNotExist") -def test_dataloader_reset_with_scale_batch_size(tmpdir): +@pytest.mark.parametrize("scale_method", ["power", "binsearch"]) +def test_dataloader_reset_with_scale_batch_size(tmpdir, scale_method): """Test that train and val dataloaders are reset at every update in scale batch size.""" model = BatchSizeModel(batch_size=16) - scale_batch_size_kwargs = {"max_trials": 5, "init_val": 4} + scale_batch_size_kwargs = {"max_trials": 5, "init_val": 4, "mode": scale_method} trainer = Trainer(max_epochs=2, auto_scale_batch_size=True) new_batch_size = trainer.tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs)["scale_batch_size"] From f0d9452f177b6af7f621a2701c8166666bb75c73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 18 Oct 2021 13:01:57 +0200 Subject: [PATCH 07/22] remove unused setup method --- pytorch_lightning/plugins/training_type/ddp.py | 6 ------ .../plugins/training_type/training_type_plugin.py | 7 ------- 2 files changed, 13 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 6e10ec93495f8..7fc53b2b05b3d 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -191,12 +191,6 @@ def setup_model(self, model: Module) -> Module: ) return model - def setup_dataloader(self, dataloader: DataLoader) -> DataLoader: - kwargs = self.distributed_sampler_kwargs - sampler = DistributedSampler(dataloader.dataset, **kwargs) - # dataloader = replace_sampler(dataloader, sampler) - return dataloader - def _call_children_scripts(self): # bookkeeping of spawned processes self._check_can_spawn_children() diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 092a46c79531f..64a1ed6b32c2e 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -61,13 +61,6 @@ def setup_environment(self) -> None: def setup(self) -> None: """Called by the accelerator to finish setup.""" - def setup_dataloader(self, dataloader: DataLoader) -> DataLoader: - """Called by the accelerator. - - The plugin wraps and modifies the dataloader as needed. 
- """ - return dataloader - def setup_models_and_optimizers( self, models: List[Module], optimizers: List[Optimizer] ) -> Tuple[List[Module], List[Optimizer]]: From 06d04095e52d3ad3678f36f9ceec3dd726f98b67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 18 Oct 2021 13:16:39 +0200 Subject: [PATCH 08/22] remove unused forward_context() usages --- pytorch_lightning/accelerators/accelerator.py | 13 ++++--------- pytorch_lightning/lite/lite.py | 2 +- .../plugins/training_type/training_type_plugin.py | 5 ----- 3 files changed, 5 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 691b117185b0d..2771a3813bce0 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -211,7 +211,7 @@ def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details """ - with self.precision_plugin.forward_context(), self.training_type_plugin.forward_context(): + with self.precision_plugin.forward_context(): return self.training_type_plugin.training_step(*step_kwargs.values()) def post_training_step(self) -> None: @@ -231,7 +231,7 @@ def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[S See :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` for more details """ - with self.precision_plugin.forward_context(), self.training_type_plugin.forward_context(): + with self.precision_plugin.forward_context(): return self.training_type_plugin.validation_step(*step_kwargs.values()) def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]: @@ -239,7 +239,7 @@ def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OU See :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` for more details """ - with self.precision_plugin.forward_context(), self.training_type_plugin.forward_context(): + with self.precision_plugin.forward_context(): return self.training_type_plugin.test_step(*step_kwargs.values()) def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: @@ -247,7 +247,7 @@ def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: See :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` for more details """ - with self.precision_plugin.forward_context(), self.training_type_plugin.forward_context(): + with self.precision_plugin.forward_context(): return self.training_type_plugin.predict_step(*step_kwargs.values()) def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT: @@ -709,8 +709,3 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = "`on_train_batch_start` logic is implemented directly in the `TrainingTypePlugin` implementations." 
) return self.training_type_plugin.on_train_batch_start(batch, batch_idx) - - @contextlib.contextmanager - def forward_context(self): - with self.precision_plugin.forward_context(), self.training_type_plugin.forward_context(): - yield diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index aa8ea606cfc5e..c613d6744ab89 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -239,7 +239,7 @@ def cast(self) -> Generator[None, None, None]: Use this only if the `forward` method of your model does not cover all operations you wish to run with the chosen precision setting. """ - with self._accelerator.forward_context(): + with self._precision_plugin.forward_context(): yield def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tensor, Any]: diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 64a1ed6b32c2e..e0bc057cf4b41 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -300,11 +300,6 @@ def remove_checkpoint(self, filepath: _PATH) -> None: if self.should_rank_save_checkpoint: return self.checkpoint_io.remove_checkpoint(filepath) - @contextlib.contextmanager - def forward_context(self) -> Generator: - """A contextmanager for managing model forward/training_step/evaluation_step/predict_step.""" - yield - @contextlib.contextmanager def model_sharded_context(self) -> Generator: """Provide hook to create modules in a distributed aware context. This is useful for when we'd like to From e03499aea8f8fe6bb7d7513ac47088dabc13377e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 18 Oct 2021 13:20:51 +0200 Subject: [PATCH 09/22] refactor forward context --- pytorch_lightning/plugins/precision/double.py | 33 +------------------ .../plugins/precision/native_amp.py | 24 -------------- .../plugins/precision/precision_plugin.py | 22 ++++++++----- 3 files changed, 14 insertions(+), 65 deletions(-) diff --git a/pytorch_lightning/plugins/precision/double.py b/pytorch_lightning/plugins/precision/double.py index 179daf9e91db8..2b104f321ad38 100644 --- a/pytorch_lightning/plugins/precision/double.py +++ b/pytorch_lightning/plugins/precision/double.py @@ -91,38 +91,7 @@ def connect( return super().connect(model, optimizers, lr_schedulers) - @contextmanager - def train_step_context(self) -> Generator[None, None, None]: - """A context manager to change the default tensor type. - - See: :meth:`torch.set_default_tensor_type` - """ - torch.set_default_tensor_type(torch.DoubleTensor) - yield - torch.set_default_tensor_type(torch.FloatTensor) - - @contextmanager - def val_step_context(self) -> Generator[None, None, None]: - """A context manager to change the default tensor type. - - See: :meth:`torch.set_default_tensor_type` - """ - torch.set_default_tensor_type(torch.DoubleTensor) - yield - torch.set_default_tensor_type(torch.FloatTensor) - - @contextmanager - def test_step_context(self) -> Generator[None, None, None]: - """A context manager to change the default tensor type. 
- - See: :meth:`torch.set_default_tensor_type` - """ - torch.set_default_tensor_type(torch.DoubleTensor) - yield - torch.set_default_tensor_type(torch.FloatTensor) - - @contextmanager - def predict_step_context(self) -> Generator[None, None, None]: + def forward_context(self) -> Generator[None, None, None]: """A context manager to change the default tensor type. See: :meth:`torch.set_default_tensor_type` diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 97aa134f5f99b..92c6d32d8ce58 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -123,30 +123,6 @@ def forward_context(self) -> Generator[None, None, None]: with self.autocast_context_manager(): yield - @contextmanager - def train_step_context(self) -> Generator[None, None, None]: - """Enable autocast context.""" - with self.autocast_context_manager(): - yield - - @contextmanager - def val_step_context(self) -> Generator[None, None, None]: - """Enable autocast context.""" - with self.autocast_context_manager(): - yield - - @contextmanager - def test_step_context(self) -> Generator[None, None, None]: - """Enable autocast context.""" - with self.autocast_context_manager(): - yield - - @contextmanager - def predict_step_context(self) -> Generator[None, None, None]: - """Enable autocast context.""" - with self.autocast_context_manager(): - yield - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: if "native_amp_scaling_state" in checkpoint and not self.is_bfloat16: self.scaler.load_state_dict(checkpoint["native_amp_scaling_state"]) diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index fd7ee03f7154d..6192d51c86f27 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -157,26 +157,30 @@ def post_dispatch(self) -> None: """Hook to do something after the training/evaluation/prediction finishes.""" @contextlib.contextmanager - def forward_context(self) -> Generator: + def forward_context(self) -> Generator[None, None, None]: """A contextmanager for managing model forward/training_step/evaluation_step/predict_step.""" yield @contextlib.contextmanager - def train_step_context(self) -> Generator: + def train_step_context(self) -> Generator[None, None, None]: """A contextmanager for the training step.""" - yield + with self.forward_context(): + yield @contextlib.contextmanager - def val_step_context(self) -> Generator: + def val_step_context(self) -> Generator[None, None, None]: """A contextmanager for the validation step.""" - yield + with self.forward_context(): + yield @contextlib.contextmanager - def test_step_context(self) -> Generator: + def test_step_context(self) -> Generator[None, None, None]: """A contextmanager for the test step.""" - yield + with self.forward_context(): + yield @contextlib.contextmanager - def predict_step_context(self) -> Generator: + def predict_step_context(self) -> Generator[None, None, None]: """A contextmanager for the predict step.""" - yield + with self.forward_context(): + yield From 2daa95d3d43c8c1d20311912a3a91368d61ca818 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 18 Oct 2021 13:22:26 +0200 Subject: [PATCH 10/22] revert accelerator changes of context managers --- pytorch_lightning/accelerators/accelerator.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) 
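With `Accelerator.forward_context()` removed in PATCH 08 and the per-step precision contexts restored below, callers that previously held the accelerator-level context now go through the precision plugin instead; PATCH 11 later applies exactly this to the lite wrapper. A minimal sketch under that assumption (`run_forward` is a hypothetical helper written for illustration, not part of these patches):

    from typing import Any

    from torch import nn

    from pytorch_lightning.accelerators import Accelerator


    def run_forward(accelerator: Accelerator, module: nn.Module, batch: Any) -> Any:
        # previously: `with accelerator.forward_context(): ...` (removed in PATCH 08);
        # the precision plugin now owns the forward context for all step types
        with accelerator.precision_plugin.forward_context():
            return module(batch)
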
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 2771a3813bce0..96f1aea73dd14 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -211,7 +211,7 @@ def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details """ - with self.precision_plugin.forward_context(): + with self.precision_plugin.train_step_context(): return self.training_type_plugin.training_step(*step_kwargs.values()) def post_training_step(self) -> None: @@ -231,7 +231,7 @@ def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[S See :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` for more details """ - with self.precision_plugin.forward_context(): + with self.precision_plugin.val_step_context(): return self.training_type_plugin.validation_step(*step_kwargs.values()) def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]: @@ -239,7 +239,7 @@ def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OU See :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` for more details """ - with self.precision_plugin.forward_context(): + with self.precision_plugin.test_step_context(): return self.training_type_plugin.test_step(*step_kwargs.values()) def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: @@ -247,7 +247,7 @@ def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: See :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` for more details """ - with self.precision_plugin.forward_context(): + with self.precision_plugin.predict_step_context(): return self.training_type_plugin.predict_step(*step_kwargs.values()) def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT: From 8b5eb80ae64f0ef8bc74b6666d116c338c063535 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 18 Oct 2021 13:24:25 +0200 Subject: [PATCH 11/22] fix model wrapper call to forward context --- pytorch_lightning/lite/wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 7a332952ee15e..130b495604c4a 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -77,7 +77,7 @@ def module(self) -> nn.Module: return self._module def forward(self, *args: Any, **kwargs: Any) -> Any: - with self._accelerator.forward_context(): + with self._accelerator.precision_plugin.forward_context(): output = self.module.forward(*args, **kwargs) output = apply_to_collection(output, function=lambda t: t.to(torch.get_default_dtype()), dtype=Tensor) From 0284098ed4ccc6433e6dba53856d968597ba3efa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 18 Oct 2021 14:00:32 +0200 Subject: [PATCH 12/22] Update pytorch_lightning/lite/wrappers.py Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> --- pytorch_lightning/lite/wrappers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 130b495604c4a..ad2804469428b 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -89,8 +89,7 @@ def __init__(self, device: Optional[torch.device] = None, 
**dl_kwargs: Any) -> N super().__init__(**dl_kwargs) self._device = device - # TODO: how to type this *angry face" - def __iter__(self): # type: ignore + def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]] iterator = super().__iter__() if self._device is None: return iterator From 09eb2066995ac416277bdaedd52de1cf9382f433 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 18 Oct 2021 14:02:02 +0200 Subject: [PATCH 13/22] add missing imports for type --- pytorch_lightning/lite/wrappers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index ad2804469428b..4d50f9d19c912 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union, Iterator, Generator import torch from torch import nn as nn @@ -89,7 +89,7 @@ def __init__(self, device: Optional[torch.device] = None, **dl_kwargs: Any) -> N super().__init__(**dl_kwargs) self._device = device - def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]] + def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]: iterator = super().__iter__() if self._device is None: return iterator From c69a79c86fae7d56b4f6dbe410b912f6793c595d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 18 Oct 2021 14:02:16 +0200 Subject: [PATCH 14/22] Fix `self.log(on_epoch=True)` on_batch_start (#9780) --- CHANGELOG.md | 4 ++ .../loops/batch/training_batch_loop.py | 49 ++----------------- .../loops/epoch/evaluation_epoch_loop.py | 4 +- .../loops/epoch/training_epoch_loop.py | 41 +++++++++++++--- .../logger_connector/logger_connector.py | 10 ++-- tests/loops/test_evaluation_loop_flow.py | 8 +-- tests/loops/test_training_loop_flow_scalar.py | 30 ++++-------- .../logging_/test_train_loop_logging.py | 12 +++++ 8 files changed, 72 insertions(+), 86 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f809e66c6b7ad..43ebc6464ac51 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -527,11 +527,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed `broadcast` in `DDPPlugin` and ``DDPSpawnPlugin` to respect the `src` input ([#9691](https://github.com/PyTorchLightning/pytorch-lightning/pull/9691)) +- Fixed `self.log(on_epoch=True)` for the `on_batch_start` and `on_train_batch_start` hooks ([#9780](https://github.com/PyTorchLightning/pytorch-lightning/pull/9780)) + + - Fixed restoring training state during `trainer.fit` only ([#9413](https://github.com/PyTorchLightning/pytorch-lightning/pull/9413)) - Fixed DeepSpeed and Lightning both calling the scheduler ([#9788](https://github.com/PyTorchLightning/pytorch-lightning/pull/9788)) + - Fixed missing arguments when saving hyperparameters from the parent class but not from the child class ([#9800](https://github.com/PyTorchLightning/pytorch-lightning/pull/9800)) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 93e156070d3d1..c1d800c42d853 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -23,9 +23,6 @@ from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop from pytorch_lightning.loops.utilities import _get_active_optimizers from pytorch_lightning.trainer.supporters import TensorRunningAccum -from pytorch_lightning.utilities import AttributeDict -from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature -from pytorch_lightning.utilities.warnings import WarningCache _OUTPUTS_TYPE = List[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] @@ -43,7 +40,6 @@ def __init__(self) -> None: self.manual_loop = ManualOptimization() self._outputs: _OUTPUTS_TYPE = [] - self._warning_cache: WarningCache = WarningCache() self._remaining_splits: Optional[List[Any]] = None @property @@ -59,42 +55,6 @@ def connect( if manual_loop is not None: self.manual_loop = manual_loop - def run(self, batch: Any, batch_idx: int) -> AttributeDict: - """Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks. - - Args: - batch: the current batch to run the train step on - batch_idx: the index of the current batch - """ - if batch is None: - self._warning_cache.warn("train_dataloader yielded None. 
If this was on purpose, ignore this warning...") - return AttributeDict(signal=0, outputs=[]) - - # hook - self.trainer.logger_connector.on_batch_start() - response = self.trainer.call_hook("on_batch_start") - if response == -1: - return AttributeDict(signal=-1) - - # hook - # TODO: Update this in v1.7 (deprecation: #9816) - model_fx = self.trainer.lightning_module.on_train_batch_start - extra_kwargs = ( - {"dataloader_idx": 0} - if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True) - else {} - ) - response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs) - if response == -1: - return AttributeDict(signal=-1) - - self.trainer.fit_loop.epoch_loop.batch_progress.increment_started() - - super().run(batch, batch_idx) - - output, self._outputs = AttributeDict(signal=0, outputs=self._outputs), None # free memory - return output - def reset(self) -> None: """Resets the loop state.""" self._outputs = [] @@ -117,11 +77,10 @@ def advance(self, batch, batch_idx): batch_idx: the index of the current batch """ void(batch) - split_idx, split_batch = self._remaining_splits.pop(0) - self.split_idx = split_idx + self.split_idx, split_batch = self._remaining_splits.pop(0) # let logger connector extract current batch size - self.trainer.logger_connector.on_train_split_start(batch_idx, split_idx, split_batch) + self.trainer.logger_connector.on_train_split_start(self.split_idx, split_batch) # choose which loop will run the optimization if self.trainer.lightning_module.automatic_optimization: @@ -135,10 +94,12 @@ def advance(self, batch, batch_idx): # then `advance` doesn't finish and an empty dict is returned self._outputs.append(outputs) - def on_run_end(self) -> None: + def on_run_end(self) -> _OUTPUTS_TYPE: self.optimizer_loop._hiddens = None # this is not necessary as the manual loop runs for only 1 iteration, but just in case self.manual_loop._hiddens = None + output, self._outputs = self._outputs, None # free memory + return output def teardown(self) -> None: # release memory diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 3e1b88a2d41c3..d666cc2ad0d59 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -233,10 +233,10 @@ def _on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: Raises: AssertionError: If the number of dataloaders is None (has not yet been set). 
""" - self.trainer.logger_connector.on_batch_start() + self.trainer.logger_connector.on_batch_start(batch_idx) assert self._num_dataloaders is not None - self.trainer.logger_connector.on_evaluation_batch_start(batch, batch_idx, dataloader_idx, self._num_dataloaders) + self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self._num_dataloaders) if self.trainer.testing: self.trainer.call_hook("on_test_batch_start", batch, batch_idx, dataloader_idx) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index fe3a2dc7431cc..4cc8eaa811231 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -28,6 +28,7 @@ from pytorch_lightning.utilities.fetching import AbstractDataFetcher from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature +from pytorch_lightning.utilities.warnings import WarningCache _OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE] @@ -57,6 +58,7 @@ def __init__(self, min_steps: int, max_steps: int): self._results = ResultCollection(training=True) self._outputs: _OUTPUTS_TYPE = [] + self._warning_cache = WarningCache() self._dataloader_iter: Optional[Iterator] = None # caches the loaded dataloader state until dataloader objects are available self._dataloader_state_dict: Dict[str, Any] = {} @@ -151,14 +153,37 @@ def advance(self, *args: Any, **kwargs: Any) -> None: self.batch_progress.increment_ready() - with self.trainer.profiler.profile("run_training_batch"): - batch_output = self.batch_loop.run(batch, batch_idx) + if batch is None: + self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...") + batch_output = [] + else: + # hook + self.trainer.logger_connector.on_batch_start(batch_idx) + response = self.trainer.call_hook("on_batch_start") + if response == -1: + self.batch_progress.increment_processed() + raise StopIteration + + # TODO: Update this in v1.7 (deprecation: #9816) + model_fx = self.trainer.lightning_module.on_train_batch_start + extra_kwargs = ( + {"dataloader_idx": 0} + if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True) + else {} + ) - self.batch_progress.increment_processed() + # hook + response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs) + if response == -1: + self.batch_progress.increment_processed() + raise StopIteration - # when returning -1 from train_step, we end epoch early - if batch_output.signal == -1: - raise StopIteration + self.batch_progress.increment_started() + + with self.trainer.profiler.profile("run_training_batch"): + batch_output = self.batch_loop.run(batch, batch_idx) + + self.batch_progress.increment_processed() # update non-plateau LR schedulers # update epoch-interval ones only when we are at the end of training epoch @@ -167,7 +192,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: self.update_lr_schedulers("epoch", update_plateau_schedulers=False) batch_end_outputs = self._prepare_outputs_training_batch_end( - batch_output.outputs, + batch_output, automatic=self.trainer.lightning_module.trainer.lightning_module.automatic_optimization, num_optimizers=len(self.trainer.optimizers), ) @@ -186,7 +211,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: self.batch_progress.increment_completed() if is_overridden("training_epoch_end", 
self.trainer.lightning_module): - self._outputs.append(batch_output.outputs) + self._outputs.append(batch_output) # ----------------------------------------- # SAVE METRICS TO LOGGERS AND PROGRESS_BAR diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 21684d6831a65..cb01e7edbc97a 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -138,7 +138,7 @@ def _increment_eval_log_step(self) -> None: elif self.trainer.state.stage is RunningStage.TESTING: self._test_log_step += 1 - def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None: + def on_evaluation_batch_start(self, batch: Any, dataloader_idx: int, num_dataloaders: int) -> None: model = self.trainer.lightning_module # set dataloader_idx only if multiple ones model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None @@ -146,7 +146,6 @@ def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: # track batch_size assert self.trainer._results is not None self.trainer._results.extract_batch_size(batch) - self._batch_idx = batch_idx def update_eval_step_metrics(self) -> None: if self.trainer.sanity_checking: @@ -213,14 +212,12 @@ def update_eval_epoch_metrics(self) -> List[_OUT_DICT]: Train metric updates """ - def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None: + def on_train_split_start(self, split_idx: int, split_batch: Any) -> None: assert self.trainer._results is not None # when the user requests `dataloader_iter`, we can't track the batch_size # and this is left to user responsibility. 
if isinstance(split_batch, pl.utilities.fetching.DataLoaderIterDataFetcher): self.trainer._results.extract_batch_size(split_batch) - - self._batch_idx = batch_idx self._split_idx = split_idx def update_train_step_metrics(self) -> None: @@ -267,7 +264,8 @@ def _log_gpus_metrics(self) -> None: def on_epoch_start(self) -> None: self._epoch_end_reached = False - def on_batch_start(self) -> None: + def on_batch_start(self, batch_idx: int) -> None: + self._batch_idx = batch_idx self._epoch_end_reached = False def epoch_end_reached(self) -> None: diff --git a/tests/loops/test_evaluation_loop_flow.py b/tests/loops/test_evaluation_loop_flow.py index 5a9d0a737350c..0fe90557b3530 100644 --- a/tests/loops/test_evaluation_loop_flow.py +++ b/tests/loops/test_evaluation_loop_flow.py @@ -64,10 +64,8 @@ def backward(self, loss, optimizer, optimizer_idx): # simulate training manually trainer.state.stage = RunningStage.TRAINING batch_idx, batch = 0, next(iter(model.train_dataloader())) - out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx) - assert out.signal == 0 + train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx) - train_step_out = out.outputs assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] assert isinstance(train_step_out["loss"], torch.Tensor) @@ -129,10 +127,8 @@ def backward(self, loss, optimizer, optimizer_idx): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected batch_idx, batch = 0, next(iter(model.train_dataloader())) - out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx) - assert out.signal == 0 + train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx) - train_step_out = out.outputs assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] assert isinstance(train_step_out["loss"], torch.Tensor) diff --git a/tests/loops/test_training_loop_flow_scalar.py b/tests/loops/test_training_loop_flow_scalar.py index 0501cbdf529db..f7f539efef8cd 100644 --- a/tests/loops/test_training_loop_flow_scalar.py +++ b/tests/loops/test_training_loop_flow_scalar.py @@ -147,10 +147,8 @@ def backward(self, loss, optimizer, optimizer_idx): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected batch_idx, batch = 0, next(iter(model.train_dataloader())) - out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx) - assert out.signal == 0 + train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx) - train_step_out = out.outputs assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] assert isinstance(train_step_out["loss"], torch.Tensor) @@ -221,10 +219,8 @@ def backward(self, loss, optimizer, optimizer_idx): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected batch_idx, batch = 0, next(iter(model.train_dataloader())) - out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx) - assert out.signal == 0 + train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx) - train_step_out = out.outputs assert len(train_step_out) == 1 train_step_out = train_step_out[0][0] assert isinstance(train_step_out["loss"], torch.Tensor) @@ -311,8 +307,7 @@ def training_step(self, batch, batch_idx): for batch_idx, batch in enumerate(model.train_dataloader()): out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx) if not batch_idx % 2: - assert out.outputs == [] - assert out.signal == 0 + assert out == [] def test_training_step_none_batches(tmpdir): 
@@ -321,7 +316,6 @@ def test_training_step_none_batches(tmpdir): class TestModel(BoringModel): def __init__(self): super().__init__() - self.counter = 0 def collate_none_when_even(self, batch): @@ -333,12 +327,17 @@ def collate_none_when_even(self, batch): return result def train_dataloader(self): - return DataLoader(RandomDataset(32, 64), collate_fn=self.collate_none_when_even) + return DataLoader(RandomDataset(32, 4), collate_fn=self.collate_none_when_even) + + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + if batch_idx % 2 == 0: + assert outputs == [] + else: + assert outputs model = TestModel() trainer = Trainer( default_root_dir=tmpdir, - limit_train_batches=4, limit_val_batches=1, max_epochs=4, enable_model_summary=False, @@ -348,12 +347,3 @@ def train_dataloader(self): with pytest.warns(UserWarning, match=r".*train_dataloader yielded None.*"): trainer.fit(model) - - trainer.state.stage = RunningStage.TRAINING - - # manually check a few batches - for batch_idx, batch in enumerate(model.train_dataloader()): - out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx) - if not batch_idx % 2: - assert out.outputs == [] - assert out.signal == 0 diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 5e9db17b2de62..f7f7190adb9bd 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -276,11 +276,21 @@ def on_train_epoch_start(self, _, pl_module): pl_module, "on_train_epoch_start", on_steps=self.choices, on_epochs=[True], prob_bars=self.choices ) + def on_batch_start(self, _, pl_module, *__): + self.make_logging( + pl_module, "on_batch_start", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices + ) + def on_batch_end(self, _, pl_module): self.make_logging( pl_module, "on_batch_end", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices ) + def on_train_batch_start(self, _, pl_module, *__): + self.make_logging( + pl_module, "on_train_batch_start", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices + ) + def on_train_batch_end(self, _, pl_module, *__): self.make_logging( pl_module, "on_train_batch_end", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices @@ -323,7 +333,9 @@ def training_step(self, batch, batch_idx): "on_train_start": 1, "on_epoch_start": 1, "on_train_epoch_start": 1, + "on_train_batch_start": 2, "on_train_batch_end": 2, + "on_batch_start": 2, "on_batch_end": 2, "on_train_epoch_end": 1, "on_epoch_end": 1, From 1b6db9bcd96e89387b03871a58410d1f3e186d8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 18 Oct 2021 14:03:07 +0200 Subject: [PATCH 15/22] add override ignore type for __iter__ --- pytorch_lightning/lite/wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 4d50f9d19c912..0f1233f1b5f92 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -89,7 +89,7 @@ def __init__(self, device: Optional[torch.device] = None, **dl_kwargs: Any) -> N super().__init__(**dl_kwargs) self._device = device - def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]: + def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]: # type: ignore[override] iterator = super().__iter__() if self._device is None: return iterator From 53c1748a1a0a0a7ae9d1b1ad6f9a6ef7ebff3da4 
Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 Oct 2021 12:04:29 +0000 Subject: [PATCH 16/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/lite/wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 0f1233f1b5f92..94cd8ecb89e5d 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Union, Iterator, Generator +from typing import Any, Callable, Generator, Iterator, Optional, Union import torch from torch import nn as nn From 0684e5295ff05c51807433aebe3f6d5bcb863d83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 18 Oct 2021 14:05:41 +0200 Subject: [PATCH 17/22] Remove deprecated `DataModule.dims` usage in tests (#9948) --- tests/helpers/boring_model.py | 3 --- tests/helpers/datamodules.py | 5 ----- 2 files changed, 8 deletions(-) diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py index 4036d34663a9f..d51fb44bff0d2 100644 --- a/tests/helpers/boring_model.py +++ b/tests/helpers/boring_model.py @@ -158,18 +158,15 @@ def prepare_data(self): def setup(self, stage: Optional[str] = None): if stage == "fit" or stage is None: self.random_train = Subset(self.random_full, indices=range(64)) - self.dims = self.random_train[0].shape if stage in ("fit", "validate") or stage is None: self.random_val = Subset(self.random_full, indices=range(64, 64 * 2)) if stage == "test" or stage is None: self.random_test = Subset(self.random_full, indices=range(64 * 2, 64 * 3)) - self.dims = getattr(self, "dims", self.random_test[0].shape) if stage == "predict" or stage is None: self.random_predict = Subset(self.random_full, indices=range(64 * 3, 64 * 4)) - self.dims = getattr(self, "dims", self.random_predict[0].shape) def train_dataloader(self): return DataLoader(self.random_train) diff --git a/tests/helpers/datamodules.py b/tests/helpers/datamodules.py index 08fa3c6d214fd..0cb178a749a09 100644 --- a/tests/helpers/datamodules.py +++ b/tests/helpers/datamodules.py @@ -40,11 +40,6 @@ def __init__(self, data_dir: str = "./", batch_size: int = 32, use_trials: bool # TrialMNIST is a constrained MNIST dataset self.dataset_cls = TrialMNIST if use_trials else MNIST - # self.dims is returned when you call dm.size() - # Setting default dims here because we know them. 
- # Could optionally be assigned dynamically in dm.setup() - self.dims = (1, 28, 28) - def prepare_data(self): # download only self.dataset_cls(self.data_dir, train=True, download=True) From e0470cc2444d8fd6cd3fb91dc12dbbe9af8ce66f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 18 Oct 2021 14:10:47 +0200 Subject: [PATCH 18/22] Update `resume_from_checkpoint` docs (#9952) --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 3 --- pytorch_lightning/trainer/trainer.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index da6a81e8add44..2bde85de052ca 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -59,9 +59,6 @@ def resume_start(self) -> None: 1. from HPC weights if found 2. from `resume_from_checkpoint` file if provided 3. don't restore - - Raises: - FileNotFoundError: If the path to the checkpoint file is provided but the file does not exist. """ self.resume_checkpoint_path = self.hpc_resume_path or self.resume_checkpoint_path checkpoint_path = self.resume_checkpoint_path diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e6d8ccde91d71..2990792502c0b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -357,7 +357,7 @@ def __init__( you can set ``replace_sampler_ddp=False`` and add your own distributed sampler. resume_from_checkpoint: Path/URL of the checkpoint from which training is resumed. If there is - no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint, + no checkpoint file at the path, an exception is raised. If resuming from mid-epoch checkpoint, training will start from the beginning of the next epoch. 
strategy: Supports different training strategies with aliases From ae12a4d42e0e303e54f50b159fd5f77a938937b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 18 Oct 2021 14:32:35 +0200 Subject: [PATCH 19/22] update setup logic for ddp --- pytorch_lightning/lite/lite.py | 2 +- pytorch_lightning/plugins/training_type/ddp.py | 11 ++--------- pytorch_lightning/plugins/training_type/ddp_spawn.py | 8 +++++--- 3 files changed, 8 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index c613d6744ab89..0db786b4a93db 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -156,11 +156,11 @@ def setup( """ # wrap all objects passed in and return them in the same order optimizers = [optimizers] if isinstance(optimizers, Optimizer) else optimizers - model, optimizers = self._setup_model_and_optimizers(model, optimizers) if move_to_device: model = self.to_device(model) + model, optimizers = self._setup_model_and_optimizers(model, optimizers) optimizers = optimizers[0] if len(optimizers) == 1 else optimizers return model, optimizers diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 7fc53b2b05b3d..2bcb032e3adcf 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -184,12 +184,7 @@ def setup_environment(self) -> None: self.setup_distributed() def setup_model(self, model: Module) -> Module: - model = DistributedDataParallel( - module=model.to(self.root_device), - device_ids=self.determine_ddp_device_ids(), - **self._ddp_kwargs, - ) - return model + return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs) def _call_children_scripts(self): # bookkeeping of spawned processes @@ -365,9 +360,7 @@ def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int): def configure_ddp(self) -> None: self.pre_configure_ddp() - self._model = DistributedDataParallel( - LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs - ) + self._model = self.setup_model(LightningDistributedModule(self.model)) self._register_ddp_hooks() def determine_ddp_device_ids(self): diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 07fa4e455df26..7e5433cb4beba 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -22,6 +22,7 @@ import torch import torch.distributed import torch.multiprocessing as mp +from torch.nn import Module from torch.nn.parallel.distributed import DistributedDataParallel import pytorch_lightning as pl @@ -148,6 +149,9 @@ def setup(self) -> None: smp = mp.get_context("spawn") self.mp_queue = smp.SimpleQueue() + def setup_model(self, model: Module) -> Module: + return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs) + def set_world_ranks(self, process_idx: int = 0) -> None: self._local_rank = process_idx if self.cluster_environment is None: @@ -259,9 +263,7 @@ def _register_ddp_hooks(self) -> None: def configure_ddp(self) -> None: self.pre_configure_ddp() - self._model = DistributedDataParallel( - LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs - ) + self._model = self.setup_model(LightningDistributedModule(self.model)) 
self._register_ddp_hooks() def determine_ddp_device_ids(self): From 3f355d0eb774bd591f6d8b5d48c6d7f58453da5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 18 Oct 2021 14:43:06 +0200 Subject: [PATCH 20/22] Remove manual tracking of optimizer steps (#9957) --- pytorch_lightning/core/optimizer.py | 2 -- tests/accelerators/test_tpu.py | 15 ++++++------ tests/core/test_lightning_optimizer.py | 1 - .../optimization/test_manual_optimization.py | 24 +++++++++++-------- 4 files changed, 22 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index dd074450f0897..65e0d7dc7b8ab 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -46,7 +46,6 @@ def __init__(self, optimizer: Optimizer): self._optimizer = optimizer self._trainer = None self._optimizer_idx = None - self._total_optimizer_step_calls = 0 @property def optimizer(self): @@ -192,7 +191,6 @@ def closure_dis(): trainer = self._trainer with trainer.profiler.profile(profiler_action): trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs) - self._total_optimizer_step_calls += 1 def __repr__(self): groups = [ diff --git a/tests/accelerators/test_tpu.py b/tests/accelerators/test_tpu.py index 622ead614b1c7..25743d5b3bc3d 100644 --- a/tests/accelerators/test_tpu.py +++ b/tests/accelerators/test_tpu.py @@ -13,6 +13,7 @@ # limitations under the License import collections from copy import deepcopy +from unittest.mock import patch import pytest import torch @@ -21,7 +22,6 @@ from pytorch_lightning import Trainer from pytorch_lightning.accelerators.cpu import CPUAccelerator from pytorch_lightning.accelerators.tpu import TPUAccelerator -from pytorch_lightning.callbacks import Callback from pytorch_lightning.plugins import TPUSpawnPlugin from pytorch_lightning.utilities import find_shared_parameters from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -189,16 +189,18 @@ def on_train_batch_end(self, outputs, batch, batch_idx): assert torch.all(self.layer.weight.grad == 0) self.count += 1 + def on_train_start(self): + opt = self.optimizers() + self.opt_step_patch = patch.object(opt, "step", wraps=opt.step) + self.opt_step_mock = self.opt_step_patch.start() + def on_train_end(self): assert self.called["training_step"] == 5 assert self.called["on_train_batch_start"] == 5 assert self.called["on_train_batch_end"] == 5 - class TestManualOptimizationCallack(Callback): - def on_train_end(self, trainer, pl_module): - - opt = pl_module.optimizers() - assert opt._total_optimizer_step_calls == 3 + self.opt_step_patch.stop() + assert self.opt_step_mock.call_count == 3 model = ManualOptimizationModel() model_copy = deepcopy(model) @@ -212,7 +214,6 @@ def on_train_end(self, trainer, pl_module): limit_test_batches=0, limit_val_batches=0, tpu_cores=8, - callbacks=[TestManualOptimizationCallack()], ) trainer.fit(model) diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py index f4f1287c122d8..05de6f44b9e44 100644 --- a/tests/core/test_lightning_optimizer.py +++ b/tests/core/test_lightning_optimizer.py @@ -161,7 +161,6 @@ def test_state(tmpdir): "zero_grad", "__setstate__", "add_param_group", - "_total_optimizer_step_calls", ] for k, v in lightning_optimizer.__dict__.items(): diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 1fff9c5a4715c..0be43ee8f670e 100644 --- 
a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -23,7 +23,6 @@ from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.accelerators import Accelerator -from pytorch_lightning.callbacks import Callback from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -706,14 +705,6 @@ def configure_optimizers(self): mock_adam_step.assert_has_calls(expected_calls) -class TestManualOptimizationDDPCallack(Callback): - def on_train_end(self, trainer, pl_module): - - opt_a, opt_b = pl_module.optimizers() - assert opt_a._total_optimizer_step_calls == 4 - assert opt_b._total_optimizer_step_calls == 2 - - class TesManualOptimizationDDPModel(BoringModel): def __init__(self): super().__init__() @@ -787,6 +778,20 @@ def configure_optimizers(self): optimizer_dis = torch.optim.Adam(self.layer.parameters(), lr=0.001) return [optimizer_gen, optimizer_dis] + def on_train_start(self): + # this is done here instead of in the calling function due to `spawn` + sgd, adam = self.optimizers() + self.sgd_step_patch = patch.object(sgd, "step", wraps=sgd.step) + self.sgd_step_mock = self.sgd_step_patch.start() + self.adam_step_patch = patch.object(adam, "step", wraps=adam.step) + self.adam_step_mock = self.adam_step_patch.start() + + def on_train_end(self): + self.sgd_step_patch.stop() + assert self.sgd_step_mock.call_count == 4 + self.adam_step_patch.stop() + assert self.adam_step_mock.call_count == 2 + def train_manual_optimization(tmpdir, strategy, model_cls=TesManualOptimizationDDPModel): @@ -806,7 +811,6 @@ def train_manual_optimization(tmpdir, strategy, model_cls=TesManualOptimizationD log_every_n_steps=1, gpus=2, strategy=strategy, - callbacks=[TestManualOptimizationDDPCallack()], ) trainer.fit(model) From 9e4a9d0c758a829dc892fcafae77ebc9ad86ee3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 18 Oct 2021 14:50:24 +0200 Subject: [PATCH 21/22] remove proxy changes from connector --- pytorch_lightning/lite/lite.py | 2 +- pytorch_lightning/trainer/connectors/accelerator_connector.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 0db786b4a93db..8029ed9a25dce 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -93,7 +93,7 @@ def __init__( amp_level=None, plugins=plugins, ) - self._accelerator = self._accelerator_connector.select_accelerator() + self._accelerator = self._accelerator_connector.accelerator self._strategy = self._accelerator.training_type_plugin self._precision_plugin = self._accelerator.precision_plugin diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 7ebc183e624e8..53f95ae4c8a14 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -763,8 +763,8 @@ def select_accelerator(self) -> Accelerator: # that we first select training_type_plugin, then precision_plugin accelerator = acc_cls(training_type_plugin=self.training_type_plugin, precision_plugin=self.precision_plugin) # transfer ownership of the plugins to the accelerator - self._training_type_plugin = self.training_type_plugin - self._precision_plugin = self.precision_plugin + self._training_type_plugin = proxy(self.training_type_plugin) + self._precision_plugin = proxy(self.precision_plugin) 
return accelerator From 10d0b41977be99ed0d71d5c5eba2bce19b21f149 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 18 Oct 2021 14:58:19 +0200 Subject: [PATCH 22/22] Introduce `PrecisionPlugin.forward_context()` (#9988) Co-authored-by: thomas chaton --- CHANGELOG.md | 4 +++ pytorch_lightning/plugins/precision/double.py | 32 +------------------ .../plugins/precision/native_amp.py | 20 +----------- .../plugins/precision/precision_plugin.py | 25 ++++++++++----- 4 files changed, 23 insertions(+), 58 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 43ebc6464ac51..a10c5cb41cc18 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -194,6 +194,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `strategy` argument to Trainer ([#8597](https://github.com/PyTorchLightning/pytorch-lightning/pull/8597)) +- LightningLite: + * Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988)) + + ### Changed - Setting `Trainer(accelerator="ddp_cpu")` now does not spawn a subprocess if `num_processes` is kept `1` along with `num_nodes > 1` ([#9603](https://github.com/PyTorchLightning/pytorch-lightning/pull/9603)). diff --git a/pytorch_lightning/plugins/precision/double.py b/pytorch_lightning/plugins/precision/double.py index 179daf9e91db8..5e9e8bd43b820 100644 --- a/pytorch_lightning/plugins/precision/double.py +++ b/pytorch_lightning/plugins/precision/double.py @@ -92,37 +92,7 @@ def connect( return super().connect(model, optimizers, lr_schedulers) @contextmanager - def train_step_context(self) -> Generator[None, None, None]: - """A context manager to change the default tensor type. - - See: :meth:`torch.set_default_tensor_type` - """ - torch.set_default_tensor_type(torch.DoubleTensor) - yield - torch.set_default_tensor_type(torch.FloatTensor) - - @contextmanager - def val_step_context(self) -> Generator[None, None, None]: - """A context manager to change the default tensor type. - - See: :meth:`torch.set_default_tensor_type` - """ - torch.set_default_tensor_type(torch.DoubleTensor) - yield - torch.set_default_tensor_type(torch.FloatTensor) - - @contextmanager - def test_step_context(self) -> Generator[None, None, None]: - """A context manager to change the default tensor type. - - See: :meth:`torch.set_default_tensor_type` - """ - torch.set_default_tensor_type(torch.DoubleTensor) - yield - torch.set_default_tensor_type(torch.FloatTensor) - - @contextmanager - def predict_step_context(self) -> Generator[None, None, None]: + def forward_context(self) -> Generator[None, None, None]: """A context manager to change the default tensor type. 
See: :meth:`torch.set_default_tensor_type` diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 8f93b63588c19..50c527f5f407d 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -102,25 +102,7 @@ def autocast_context_manager(self) -> torch.cuda.amp.autocast: return torch.cuda.amp.autocast() @contextmanager - def train_step_context(self) -> Generator[None, None, None]: - """Enable autocast context.""" - with self.autocast_context_manager(): - yield - - @contextmanager - def val_step_context(self) -> Generator[None, None, None]: - """Enable autocast context.""" - with self.autocast_context_manager(): - yield - - @contextmanager - def test_step_context(self) -> Generator[None, None, None]: - """Enable autocast context.""" - with self.autocast_context_manager(): - yield - - @contextmanager - def predict_step_context(self) -> Generator[None, None, None]: + def forward_context(self) -> Generator[None, None, None]: """Enable autocast context.""" with self.autocast_context_manager(): yield diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 5138bb660b9cd..c81a474faad34 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -143,21 +143,30 @@ def post_dispatch(self) -> None: """Hook to do something after the training/evaluation/prediction finishes.""" @contextlib.contextmanager - def train_step_context(self) -> Generator: - """A contextmanager for the training step.""" + def forward_context(self) -> Generator[None, None, None]: + """A contextmanager for managing model forward/training_step/evaluation_step/predict_step.""" yield @contextlib.contextmanager - def val_step_context(self) -> Generator: + def train_step_context(self) -> Generator[None, None, None]: + """A contextmanager for the training step.""" + with self.forward_context(): + yield + + @contextlib.contextmanager + def val_step_context(self) -> Generator[None, None, None]: """A contextmanager for the validation step.""" - yield + with self.forward_context(): + yield @contextlib.contextmanager - def test_step_context(self) -> Generator: + def test_step_context(self) -> Generator[None, None, None]: """A contextmanager for the test step.""" - yield + with self.forward_context(): + yield @contextlib.contextmanager - def predict_step_context(self) -> Generator: + def predict_step_context(self) -> Generator[None, None, None]: """A contextmanager for the predict step.""" - yield + with self.forward_context(): + yield
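
Taken together, the final patch makes `forward_context()` the single hook behind all four step contexts, so a custom precision plugin only has to override it once, mirroring what this series does for the double-precision and native-AMP plugins. A minimal sketch (`DoubleLikePrecisionPlugin` and the assertions are illustrative assumptions of this note, not part of the patches):

    import contextlib
    from typing import Generator

    import torch

    from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin


    class DoubleLikePrecisionPlugin(PrecisionPlugin):
        """Hypothetical plugin: run every step with float64 as the default dtype."""

        @contextlib.contextmanager
        def forward_context(self) -> Generator[None, None, None]:
            # the base class now enters this manager from each *_step_context()
            torch.set_default_tensor_type(torch.DoubleTensor)
            yield
            torch.set_default_tensor_type(torch.FloatTensor)


    plugin = DoubleLikePrecisionPlugin()
    with plugin.train_step_context():
        # entered via forward_context(), so new tensors default to float64 here
        assert torch.tensor([0.0]).dtype is torch.float64
    assert torch.tensor([0.0]).dtype is torch.float32

Plugins that need behaviour specific to one phase can still override the individual `*_step_context()` managers, which remain separate hooks after this change.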