From 9f8864f2519875646527f370f08f63cbb809b215 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 23 Mar 2021 12:17:18 +0000 Subject: [PATCH 01/27] Add base hook for model parallel --- pytorch_lightning/accelerators/accelerator.py | 15 ++++++++++++++- pytorch_lightning/callbacks/base.py | 3 +++ pytorch_lightning/core/hooks.py | 7 +++++++ .../plugins/training_type/training_type_plugin.py | 14 +++++++++++++- pytorch_lightning/trainer/callback_hook.py | 5 +++++ pytorch_lightning/trainer/trainer.py | 8 ++++++++ tests/callbacks/test_callbacks.py | 3 +++ 7 files changed, 53 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 60e6ea88b4250..c97918e1e407e 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union +import contextlib +from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union import torch from torch.optim import Optimizer @@ -432,3 +433,15 @@ def results(self) -> Any: In distributed training, we make sure to transfer the results to the appropriate master process. """ return self.training_type_plugin.results + + @contextlib.contextmanager + def model_parallel_context(self) -> Generator: + """ + Provide hook to create modules in a parallel aware context. This is useful for when we'd like to + shard the model instantly, which is useful for extremely large models which can save memory and + initialization time. + + Returns: Model parallel context. + """ + with self.training_type_plugin.model_parallel_context(): + yield diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index db507fa991446..d9d056bbc4fee 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -29,6 +29,9 @@ class Callback(abc.ABC): Subclass this class and override any of the relevant hooks """ + def on_model_parallel_setup(self, trainer, pl_module: LightningModule) -> None: + """Called before model parallel accelerator setup""" + def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModule) -> None: """Called before accelerator is being setup""" pass diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 9624f94652713..bbea853c6b0e7 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -334,6 +334,13 @@ def on_post_move_to_device(self): """ + def on_model_parallel_setup(self) -> None: + """ + Hook to create modules in a parallel aware context. This is useful for when using sharded plugins, + where we'd like to shard the model instantly, which is useful for extremely large models + which can save memory and initialization time. 
+ """ + class DataHooks: """Hooks to be used for data related stuff.""" diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 89f27963caadf..a407acd4a6040 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Iterable, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, Generator, Iterable, Optional, TYPE_CHECKING, Union import torch from torch.nn import Module @@ -192,3 +193,14 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: Returns: If True, delay setup optimizers till pre_dispatch, else call within setup. """ return False + + @contextlib.contextmanager + def model_parallel_context(self) -> Generator: + """ + Provide hook to create modules in a parallel aware context. This is useful for when we'd like to + shard the model instantly, which is useful for extremely large models which can save memory and + initialization time. + + Returns: Model parallel context. + """ + yield diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 8823d48a7817e..849f99d4f8e09 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -38,6 +38,11 @@ def on_before_accelerator_backend_setup(self, model: LightningModule) -> None: for callback in self.callbacks: callback.on_before_accelerator_backend_setup(self, model) + def on_model_parallel_setup(self, model: LightningModule, stage: Optional[str]) -> None: + """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" + for callback in self.callbacks: + callback.on_model_parallel_setup(self, model, stage) + def setup(self, model: LightningModule, stage: Optional[str]) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f7bd1757b9bc2..716a1bc3707af 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -433,6 +433,7 @@ def fit( self.accelerator.setup_environment() self.call_setup_hook(model) # allow user to setup lightning_module in accelerator environment self.accelerator.setup(self, model) # note: this sets up self.lightning_module + self.call_model_parallel_hook(model) # allow user to setup in model parallel environment # ---------------------------- # INSPECT THE CORE LOOPS @@ -1075,6 +1076,13 @@ def call_setup_hook(self, model: LightningModule) -> None: self.setup(model, stage=state) model.setup(stage=state) + def call_model_parallel_hook(self, model: LightningModule) -> None: + if not hasattr(self.lightning_module, 'is_model_parallel_setup'): + self.on_model_parallel_setup(model) + with self.accelerator.model_parallel_context(): + model.on_model_parallel_setup() + self.lightning_module.is_model_parallel_setup = True + def call_teardown_hook(self, model: LightningModule) -> None: state = self._teardown_state self.profiler.teardown(stage=state) diff --git a/tests/callbacks/test_callbacks.py 
b/tests/callbacks/test_callbacks.py index fdefc6ae9ef1c..1490a7f53cd77 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -48,6 +48,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir): call.on_init_end(trainer), call.on_before_accelerator_backend_setup(trainer, model), call.setup(trainer, model, 'fit'), + call.on_model_parallel_setup(model), call.on_fit_start(trainer, model), call.on_pretrain_routine_start(trainer, model), call.on_pretrain_routine_end(trainer, model), @@ -117,6 +118,7 @@ def test_trainer_callback_hook_system_test(tmpdir): call.on_init_end(trainer), call.on_before_accelerator_backend_setup(trainer, model), call.setup(trainer, model, 'test'), + call.on_model_parallel_setup(trainer, model), call.on_test_start(trainer, model), call.on_test_epoch_start(trainer, model), call.on_test_batch_start(trainer, model, ANY, 0, 0), @@ -150,6 +152,7 @@ def test_trainer_callback_hook_system_validate(tmpdir): call.on_init_end(trainer), call.on_before_accelerator_backend_setup(trainer, model), call.setup(trainer, model, 'validate'), + call.on_model_parallel_setup(trainer, model), call.on_validation_start(trainer, model), call.on_validation_epoch_start(trainer, model), call.on_validation_batch_start(trainer, model, ANY, 0, 0), From eac5344b077504c1550ceffef859605225102a03 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 25 Mar 2021 20:40:16 +0530 Subject: [PATCH 02/27] fix callback signature --- tests/callbacks/test_callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 1490a7f53cd77..381de5d3e659f 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -48,7 +48,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir): call.on_init_end(trainer), call.on_before_accelerator_backend_setup(trainer, model), call.setup(trainer, model, 'fit'), - call.on_model_parallel_setup(model), + call.on_model_parallel_setup(trainer, model), call.on_fit_start(trainer, model), call.on_pretrain_routine_start(trainer, model), call.on_pretrain_routine_end(trainer, model), From 32df0cb9e6277e0dd9d12d8d1326e41e0dcb9041 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 25 Mar 2021 15:58:09 +0000 Subject: [PATCH 03/27] Simplify hook --- pytorch_lightning/trainer/callback_hook.py | 4 ++-- pytorch_lightning/trainer/trainer.py | 8 +++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 849f99d4f8e09..782921cc06047 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -38,10 +38,10 @@ def on_before_accelerator_backend_setup(self, model: LightningModule) -> None: for callback in self.callbacks: callback.on_before_accelerator_backend_setup(self, model) - def on_model_parallel_setup(self, model: LightningModule, stage: Optional[str]) -> None: + def on_model_parallel_setup(self, model: LightningModule) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: - callback.on_model_parallel_setup(self, model, stage) + callback.on_model_parallel_setup(self, model) def setup(self, model: LightningModule, stage: Optional[str]) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" diff --git a/pytorch_lightning/trainer/trainer.py 
b/pytorch_lightning/trainer/trainer.py index 716a1bc3707af..a20b8618161a4 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1077,11 +1077,9 @@ def call_setup_hook(self, model: LightningModule) -> None: model.setup(stage=state) def call_model_parallel_hook(self, model: LightningModule) -> None: - if not hasattr(self.lightning_module, 'is_model_parallel_setup'): - self.on_model_parallel_setup(model) - with self.accelerator.model_parallel_context(): - model.on_model_parallel_setup() - self.lightning_module.is_model_parallel_setup = True + self.on_model_parallel_setup(model) + with self.accelerator.model_parallel_context(): + model.on_model_parallel_setup() def call_teardown_hook(self, model: LightningModule) -> None: state = self._teardown_state From 282a133dd834fca4a2d419a637b445a94cad7aca Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 25 Mar 2021 16:52:57 +0000 Subject: [PATCH 04/27] Add hook logic --- pytorch_lightning/accelerators/accelerator.py | 9 +++++++++ .../plugins/training_type/training_type_plugin.py | 9 +++++++++ pytorch_lightning/trainer/trainer.py | 9 ++++++--- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index c97918e1e407e..d6362bb103f5c 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -445,3 +445,12 @@ def model_parallel_context(self) -> Generator: """ with self.training_type_plugin.model_parallel_context(): yield + + @property + def call_model_parallel_setup_hook(self) -> bool: + """ + Allow model parallel hook to be called in suitable environments determined by the training type plugin. + This is useful for when we want to shard the model once within fit. + Returns: True if we want to call the model parallel setup hook. + """ + return self.training_type_plugin.call_model_parallel_setup_hook diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index a407acd4a6040..2997b14d37316 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -204,3 +204,12 @@ def model_parallel_context(self) -> Generator: Returns: Model parallel context. """ yield + + @property + def call_model_parallel_setup_hook(self) -> bool: + """ + Allow model parallel hook to be called in suitable environments determined by the training type plugin. + This is useful for when we want to shard the model once within fit. + Returns: True if we want to call the model parallel setup hook. + """ + return True diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index a20b8618161a4..aacc71a823a3b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1077,9 +1077,12 @@ def call_setup_hook(self, model: LightningModule) -> None: model.setup(stage=state) def call_model_parallel_hook(self, model: LightningModule) -> None: - self.on_model_parallel_setup(model) - with self.accelerator.model_parallel_context(): - model.on_model_parallel_setup() + # Call model parallel hook if accelerator requests. In some cases + # we will not call the hook; the hook has initialized the sharded model for example. 
+ if self.accelerator.call_model_parallel_setup_hook: + self.on_model_parallel_setup(model) + with self.accelerator.model_parallel_context(): + model.on_model_parallel_setup() def call_teardown_hook(self, model: LightningModule) -> None: state = self._teardown_state From 7a94e72f183d8b2855f07c5ccf10f85e93b35f2e Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 25 Mar 2021 23:17:49 +0530 Subject: [PATCH 05/27] add tests --- tests/accelerators/test_common.py | 65 +++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index bd8636ba839f9..ee45878863551 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -1,9 +1,24 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pytest import torch import tests.helpers.utils as tutils from pytorch_lightning import Trainer +from pytorch_lightning.plugins import SingleDevicePlugin from tests.accelerators.test_dp import CustomClassificationModelDP +from tests.helpers.boring_model import BoringModel from tests.helpers.datamodules import ClassifDataModule from tests.helpers.runif import RunIf @@ -44,3 +59,53 @@ def test_evaluate(tmpdir, trainer_kwargs): # make sure weights didn't change new_weights = model.layer_0.weight.clone().detach().cpu() torch.testing.assert_allclose(old_weights, new_weights) + + +def test_model_parallel_setup_called(tmpdir): + + class TestModel(BoringModel): + + def on_model_parallel_setup(self): + self.on_model_parallel_setup_called = True + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + ) + trainer.fit(model) + + assert model.on_model_parallel_setup_called + + +def test_model_parallel_setup_false(tmpdir): + """Ensure ``on_model_parallel_setup`` is not called, when turned off""" + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.on_model_parallel_setup_called = False + + def on_model_parallel_setup(self): + self.on_model_parallel_setup_called = True + + class CustomPlugin(SingleDevicePlugin): + + @property + def call_model_parallel_setup_hook(self) -> bool: + return False + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + plugins=CustomPlugin(device=torch.device("cpu")) + ) + trainer.fit(model) + + assert not model.on_model_parallel_setup_called From 809148135218ce102010516e9d064785deae622d Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 25 Mar 2021 23:33:44 +0530 Subject: [PATCH 06/27] add property setter --- .../plugins/training_type/training_type_plugin.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 2997b14d37316..f68510242839b 100644 --- 
a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -34,6 +34,7 @@ class TrainingTypePlugin(Plugin, ABC): def __init__(self) -> None: self._model = None self._results = None + self._call_model_parallel_setup_hook = True def connect(self, model: 'Module') -> None: """Called by the accelerator to connect the accelerator and the model with this plugin""" @@ -212,4 +213,9 @@ def call_model_parallel_setup_hook(self) -> bool: This is useful for when we want to shard the model once within fit. Returns: True if we want to call the model parallel setup hook. """ - return True + return self._call_model_parallel_setup_hook + + @call_model_parallel_setup_hook.setter + def call_model_parallel_setup_hook(self, mode: bool) -> bool: + if isinstance(mode, bool): + self._call_model_parallel_setup_hook = mode From 633fc77148c0538e92479de95bae042bbf430cea Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 25 Mar 2021 23:49:49 +0530 Subject: [PATCH 07/27] add logic for being called once --- pytorch_lightning/accelerators/accelerator.py | 5 ++++ .../training_type/training_type_plugin.py | 3 +-- pytorch_lightning/trainer/trainer.py | 1 + tests/accelerators/test_common.py | 27 +++++++++++++++++++ 4 files changed, 34 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index d6362bb103f5c..6ef3895077892 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -454,3 +454,8 @@ def call_model_parallel_setup_hook(self) -> bool: Returns: True if we want to call the model parallel setup hook. """ return self.training_type_plugin.call_model_parallel_setup_hook + + @call_model_parallel_setup_hook.setter + def call_model_parallel_setup_hook(self, mode: bool) -> bool: + if isinstance(mode, bool): + self.training_type_plugin.call_model_parallel_setup_hook = mode diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index f68510242839b..8431f653955e7 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -217,5 +217,4 @@ def call_model_parallel_setup_hook(self) -> bool: @call_model_parallel_setup_hook.setter def call_model_parallel_setup_hook(self, mode: bool) -> bool: - if isinstance(mode, bool): - self._call_model_parallel_setup_hook = mode + self._call_model_parallel_setup_hook = mode diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index aacc71a823a3b..dd942e040efd3 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1083,6 +1083,7 @@ def call_model_parallel_hook(self, model: LightningModule) -> None: self.on_model_parallel_setup(model) with self.accelerator.model_parallel_context(): model.on_model_parallel_setup() + self.accelerator.call_model_parallel_setup_hook = False def call_teardown_hook(self, model: LightningModule) -> None: state = self._teardown_state diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index ee45878863551..c0b9efee947fb 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -109,3 +109,30 @@ def call_model_parallel_setup_hook(self) -> bool: trainer.fit(model) assert not model.on_model_parallel_setup_called + + +def 
test_model_parallel_setup_called_once(tmpdir): + """Ensure ``on_model_parallel_setup`` is only called once""" + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.on_model_parallel_setup_called = False + + def on_model_parallel_setup(self): + self.on_model_parallel_setup_called = True + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + ) + trainer.fit(model) + + assert model.on_model_parallel_setup_called + model.on_model_parallel_setup_called = False + + assert not model.on_model_parallel_setup_called From c99a36f960edda694d18426c2319e1e2a9dadac1 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 25 Mar 2021 23:53:06 +0530 Subject: [PATCH 08/27] Update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 32cf9122efe34..fc9b18ff3f0b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,6 +57,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120)) +- Added `on_model_parallel_setup` hook ([#6679](https://github.com/PyTorchLightning/pytorch-lightning/pull/6679)) + + ### Changed - Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259)) From 9529a22882057fcd4f43d4842d35799083d1de80 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 25 Mar 2021 23:56:32 +0530 Subject: [PATCH 09/27] Fix --- pytorch_lightning/accelerators/accelerator.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 293a5e28b4c79..d6041fe1fca9f 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -452,20 +452,6 @@ def model_parallel_context(self) -> Generator: with self.training_type_plugin.model_parallel_context(): yield - @property - def call_model_parallel_setup_hook(self) -> bool: - """ - Allow model parallel hook to be called in suitable environments determined by the training type plugin. - This is useful for when we want to shard the model once within fit. - Returns: True if we want to call the model parallel setup hook. - """ - return self.training_type_plugin.call_model_parallel_setup_hook - - @call_model_parallel_setup_hook.setter - def call_model_parallel_setup_hook(self, mode: bool) -> bool: - if isinstance(mode, bool): - self.training_type_plugin.call_model_parallel_setup_hook = mode - # todo: remove in v1.5 def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None: """ @@ -493,3 +479,17 @@ def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None: ' It will be removed in v1.5.' ) self.setup_precision_plugin(plugin) + + @property + def call_model_parallel_setup_hook(self) -> bool: + """ + Allow model parallel hook to be called in suitable environments determined by the training type plugin. + This is useful for when we want to shard the model once within fit. + Returns: True if we want to call the model parallel setup hook. 
+ """ + return self.training_type_plugin.call_model_parallel_setup_hook + + @call_model_parallel_setup_hook.setter + def call_model_parallel_setup_hook(self, mode: bool) -> bool: + if isinstance(mode, bool): + self.training_type_plugin.call_model_parallel_setup_hook = mode From 3c1c782187923c99cf82352819525f45e31ba5a8 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 25 Mar 2021 23:59:06 +0530 Subject: [PATCH 10/27] fix return type --- pytorch_lightning/accelerators/accelerator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index d6041fe1fca9f..d7b7156e4ad96 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -490,6 +490,6 @@ def call_model_parallel_setup_hook(self) -> bool: return self.training_type_plugin.call_model_parallel_setup_hook @call_model_parallel_setup_hook.setter - def call_model_parallel_setup_hook(self, mode: bool) -> bool: + def call_model_parallel_setup_hook(self, mode: bool) -> None: if isinstance(mode, bool): self.training_type_plugin.call_model_parallel_setup_hook = mode From a49ec3b4e0850f8d7501aa59f8e6b3a63e1b9bc3 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Fri, 26 Mar 2021 00:27:05 +0530 Subject: [PATCH 11/27] fix lambda callback test --- pytorch_lightning/callbacks/lambda_function.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index 58324e363cd37..71bfd79ec9427 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -42,6 +42,7 @@ def __init__( self, on_before_accelerator_backend_setup: Optional[Callable] = None, setup: Optional[Callable] = None, + on_model_parallel_setup: Optional[Callable] = None, teardown: Optional[Callable] = None, on_init_start: Optional[Callable] = None, on_init_end: Optional[Callable] = None, @@ -83,6 +84,8 @@ def __init__( self.on_before_accelerator_backend_setup = on_before_accelerator_backend_setup if setup is not None: self.setup = setup + if on_model_parallel_setup is not None: + self.on_model_parallel_setup = on_model_parallel_setup if teardown is not None: self.teardown = teardown if on_init_start is not None: From 4dd55d7887ca5a1748b2e351e1d67d4e30dd8cb6 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Fri, 26 Mar 2021 00:32:14 +0530 Subject: [PATCH 12/27] Fix tests --- .../connectors/logger_connector/callback_hook_validator.py | 5 +++++ tests/trainer/logging_/test_logger_connector.py | 2 ++ 2 files changed, 7 insertions(+) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py index 534dad5199e9b..a447a38e3489c 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py @@ -55,6 +55,11 @@ def _setup_log(): """Called when fit or test begins""" return None + @staticmethod + def _on_model_parallel_setup_log(): + """Called when fit or test begins""" + return None + @staticmethod def _teardown_log(): """Called at the end of fit and test""" diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index d14ed71940328..1a7c2103503e8 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ 
b/tests/trainer/logging_/test_logger_connector.py @@ -280,6 +280,7 @@ def test_call_back_validator(tmpdir): 'on_epoch_end', 'on_epoch_start', 'on_fit_end', + 'on_model_parallel_setup', 'on_fit_start', 'on_init_end', 'on_init_start', @@ -316,6 +317,7 @@ def test_call_back_validator(tmpdir): "on_before_accelerator_backend_setup", "on_fit_end", "on_fit_start", + "on_model_parallel_setup", "on_init_end", "on_init_start", "on_keyboard_interrupt", From caad43c2858c7f8b562d42a9ce88b492540cc243 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Fri, 26 Mar 2021 01:02:58 +0530 Subject: [PATCH 13/27] Apply code suggestions --- pytorch_lightning/accelerators/accelerator.py | 2 +- pytorch_lightning/core/hooks.py | 2 +- pytorch_lightning/plugins/training_type/training_type_plugin.py | 2 +- pytorch_lightning/trainer/trainer.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index d7b7156e4ad96..1573d0f0f3b5d 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -443,7 +443,7 @@ def results(self) -> Any: @contextlib.contextmanager def model_parallel_context(self) -> Generator: """ - Provide hook to create modules in a parallel aware context. This is useful for when we'd like to + Provide hook to create modules in a distributed aware context. This is useful for when we'd like to shard the model instantly, which is useful for extremely large models which can save memory and initialization time. diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index ef5e26cea8484..a36e8c6fc91a7 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -312,7 +312,7 @@ def on_post_move_to_device(self): def on_model_parallel_setup(self) -> None: """ - Hook to create modules in a parallel aware context. This is useful for when using sharded plugins, + Hook to create modules in a distributed aware context. This is useful for when using sharded plugins, where we'd like to shard the model instantly, which is useful for extremely large models which can save memory and initialization time. """ diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 028350c0581eb..de656f8210571 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -198,7 +198,7 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: @contextlib.contextmanager def model_parallel_context(self) -> Generator: """ - Provide hook to create modules in a parallel aware context. This is useful for when we'd like to + Provide hook to create modules in a distributed aware context. This is useful for when we'd like to shard the model instantly, which is useful for extremely large models which can save memory and initialization time. diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6198d4603d555..623066dfbf7f4 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1089,8 +1089,8 @@ def call_model_parallel_hook(self, model: LightningModule) -> None: # Call model parallel hook if accelerator requests. In some cases # we will not call the hook; the hook has initialized the sharded model for example. 
if self.accelerator.call_model_parallel_setup_hook: - self.on_model_parallel_setup(model) with self.accelerator.model_parallel_context(): + self.on_model_parallel_setup(model) model.on_model_parallel_setup() self.accelerator.call_model_parallel_setup_hook = False From a2574bec6173cbb9d31df082b68ab17f8341d4d0 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Fri, 26 Mar 2021 01:36:58 +0530 Subject: [PATCH 14/27] add logic for setup_optimizers_predispatch --- pytorch_lightning/accelerators/accelerator.py | 10 ++++ pytorch_lightning/trainer/trainer.py | 2 +- tests/accelerators/test_common.py | 54 +++++++++++++++++++ 3 files changed, 65 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 1573d0f0f3b5d..8ba87a7020480 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -493,3 +493,13 @@ def call_model_parallel_setup_hook(self) -> bool: def call_model_parallel_setup_hook(self, mode: bool) -> None: if isinstance(mode, bool): self.training_type_plugin.call_model_parallel_setup_hook = mode + + @property + def setup_optimizers_in_pre_dispatch(self) -> bool: + """ + Override to delay setting optimizers and schedulers till after dispatch. + This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model. + However this may break certain precision plugins such as APEX which require optimizers to be set. + Returns: If True, delay setup optimizers till pre_dispatch, else call within setup. + """ + return self.training_type_plugin.setup_optimizers_in_pre_dispatch diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 623066dfbf7f4..fdb88a42f6f9a 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1088,7 +1088,7 @@ def call_setup_hook(self, model: LightningModule) -> None: def call_model_parallel_hook(self, model: LightningModule) -> None: # Call model parallel hook if accelerator requests. In some cases # we will not call the hook; the hook has initialized the sharded model for example. 
- if self.accelerator.call_model_parallel_setup_hook: + if self.accelerator.call_model_parallel_setup_hook and self.accelerator.setup_optimizers_in_pre_dispatch: with self.accelerator.model_parallel_context(): self.on_model_parallel_setup(model) model.on_model_parallel_setup() diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index c0b9efee947fb..b087569af9691 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -65,8 +65,20 @@ def test_model_parallel_setup_called(tmpdir): class TestModel(BoringModel): + def __init__(self): + super().__init__() + self.on_model_parallel_setup_called = False + self.layer = None + def on_model_parallel_setup(self): self.on_model_parallel_setup_called = True + self.layer = torch.nn.Linear(32, 2) + + class CustomPlugin(SingleDevicePlugin): + + @property + def setup_optimizers_in_pre_dispatch(self) -> bool: + return True model = TestModel() trainer = Trainer( @@ -74,6 +86,7 @@ def on_model_parallel_setup(self): limit_train_batches=2, limit_val_batches=2, max_epochs=1, + plugins=CustomPlugin(device=torch.device("cpu")) ) trainer.fit(model) @@ -123,12 +136,19 @@ def __init__(self): def on_model_parallel_setup(self): self.on_model_parallel_setup_called = True + class CustomPlugin(SingleDevicePlugin): + + @property + def setup_optimizers_in_pre_dispatch(self) -> bool: + return True + model = TestModel() trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=1, + plugins=CustomPlugin(device=torch.device("cpu")) ) trainer.fit(model) @@ -136,3 +156,37 @@ def on_model_parallel_setup(self): model.on_model_parallel_setup_called = False assert not model.on_model_parallel_setup_called + + +def test_model_parallel_setup_when_setup_optimizers_pre_dispatch_false(tmpdir): + """ + Ensure ``on_model_parallel_setup`` is not called, + when ``setup_optimizers_in_pre_dispatch`` set False. 
+ """ + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.on_model_parallel_setup_called = False + + def on_model_parallel_setup(self): + self.on_model_parallel_setup_called = True + + class CustomPlugin(SingleDevicePlugin): + + @property + def setup_optimizers_in_pre_dispatch(self) -> bool: + return False + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + plugins=CustomPlugin(device=torch.device("cpu")) + ) + trainer.fit(model) + + assert not model.on_model_parallel_setup_called From 8c2bd6a25be9f06439887ec78a63dff6097f9a81 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Fri, 26 Mar 2021 01:40:26 +0530 Subject: [PATCH 15/27] add common dummy model --- tests/accelerators/test_common.py | 41 +++++++++---------------------- 1 file changed, 12 insertions(+), 29 deletions(-) diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index b087569af9691..548aca5d2f505 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -93,17 +93,18 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: assert model.on_model_parallel_setup_called -def test_model_parallel_setup_false(tmpdir): - """Ensure ``on_model_parallel_setup`` is not called, when turned off""" +class DummyModel(BoringModel): - class TestModel(BoringModel): + def __init__(self): + super().__init__() + self.on_model_parallel_setup_called = False - def __init__(self): - super().__init__() - self.on_model_parallel_setup_called = False + def on_model_parallel_setup(self): + self.on_model_parallel_setup_called = True - def on_model_parallel_setup(self): - self.on_model_parallel_setup_called = True + +def test_model_parallel_setup_false(tmpdir): + """Ensure ``on_model_parallel_setup`` is not called, when turned off""" class CustomPlugin(SingleDevicePlugin): @@ -111,7 +112,7 @@ class CustomPlugin(SingleDevicePlugin): def call_model_parallel_setup_hook(self) -> bool: return False - model = TestModel() + model = DummyModel() trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=2, @@ -127,22 +128,13 @@ def call_model_parallel_setup_hook(self) -> bool: def test_model_parallel_setup_called_once(tmpdir): """Ensure ``on_model_parallel_setup`` is only called once""" - class TestModel(BoringModel): - - def __init__(self): - super().__init__() - self.on_model_parallel_setup_called = False - - def on_model_parallel_setup(self): - self.on_model_parallel_setup_called = True - class CustomPlugin(SingleDevicePlugin): @property def setup_optimizers_in_pre_dispatch(self) -> bool: return True - model = TestModel() + model = DummyModel() trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=2, @@ -164,22 +156,13 @@ def test_model_parallel_setup_when_setup_optimizers_pre_dispatch_false(tmpdir): when ``setup_optimizers_in_pre_dispatch`` set False. 
""" - class TestModel(BoringModel): - - def __init__(self): - super().__init__() - self.on_model_parallel_setup_called = False - - def on_model_parallel_setup(self): - self.on_model_parallel_setup_called = True - class CustomPlugin(SingleDevicePlugin): @property def setup_optimizers_in_pre_dispatch(self) -> bool: return False - model = TestModel() + model = DummyModel() trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=2, From 32405694f9458f607ef1af4ba4a75dfaa425db32 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 25 Mar 2021 23:10:33 +0000 Subject: [PATCH 16/27] Swap call order --- pytorch_lightning/trainer/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index fdb88a42f6f9a..a5e0618192673 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -436,8 +436,8 @@ def fit( self.accelerator.connect(model) self.accelerator.setup_environment() self.call_setup_hook(model) # allow user to setup lightning_module in accelerator environment - self.accelerator.setup(self, model) # note: this sets up self.lightning_module self.call_model_parallel_hook(model) # allow user to setup in model parallel environment + self.accelerator.setup(self, model) # note: this sets up self.lightning_module # ---------------------------- # INSPECT THE CORE LOOPS @@ -1088,7 +1088,7 @@ def call_setup_hook(self, model: LightningModule) -> None: def call_model_parallel_hook(self, model: LightningModule) -> None: # Call model parallel hook if accelerator requests. In some cases # we will not call the hook; the hook has initialized the sharded model for example. - if self.accelerator.call_model_parallel_setup_hook and self.accelerator.setup_optimizers_in_pre_dispatch: + if self.accelerator.call_model_parallel_setup_hook: with self.accelerator.model_parallel_context(): self.on_model_parallel_setup(model) model.on_model_parallel_setup() From 897bdbb8c487f9a1ea9faca286d7edec89ec5ccf Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 25 Mar 2021 23:27:12 +0000 Subject: [PATCH 17/27] Remove test that isn't needed anymore --- tests/accelerators/test_common.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index 548aca5d2f505..829497c5fb161 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -148,28 +148,3 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: model.on_model_parallel_setup_called = False assert not model.on_model_parallel_setup_called - - -def test_model_parallel_setup_when_setup_optimizers_pre_dispatch_false(tmpdir): - """ - Ensure ``on_model_parallel_setup`` is not called, - when ``setup_optimizers_in_pre_dispatch`` set False. 
- """ - - class CustomPlugin(SingleDevicePlugin): - - @property - def setup_optimizers_in_pre_dispatch(self) -> bool: - return False - - model = DummyModel() - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - max_epochs=1, - plugins=CustomPlugin(device=torch.device("cpu")) - ) - trainer.fit(model) - - assert not model.on_model_parallel_setup_called From 626fc7b702718d503cfef79fcc2c49bbd7f8d71a Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Fri, 26 Mar 2021 16:42:37 +0530 Subject: [PATCH 18/27] Update tests --- tests/accelerators/test_common.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index 829497c5fb161..148953e846e1e 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -74,19 +74,12 @@ def on_model_parallel_setup(self): self.on_model_parallel_setup_called = True self.layer = torch.nn.Linear(32, 2) - class CustomPlugin(SingleDevicePlugin): - - @property - def setup_optimizers_in_pre_dispatch(self) -> bool: - return True - model = TestModel() trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=1, - plugins=CustomPlugin(device=torch.device("cpu")) ) trainer.fit(model) @@ -128,19 +121,12 @@ def call_model_parallel_setup_hook(self) -> bool: def test_model_parallel_setup_called_once(tmpdir): """Ensure ``on_model_parallel_setup`` is only called once""" - class CustomPlugin(SingleDevicePlugin): - - @property - def setup_optimizers_in_pre_dispatch(self) -> bool: - return True - model = DummyModel() trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=1, - plugins=CustomPlugin(device=torch.device("cpu")) ) trainer.fit(model) From e94a7ae7ea11b270d15a9b3de61a01ffe53ca6a9 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 26 Mar 2021 15:14:25 +0000 Subject: [PATCH 19/27] Add a bit more doc --- pytorch_lightning/core/hooks.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index a36e8c6fc91a7..9ff77e494f1ef 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -315,6 +315,13 @@ def on_model_parallel_setup(self) -> None: Hook to create modules in a distributed aware context. This is useful for when using sharded plugins, where we'd like to shard the model instantly, which is useful for extremely large models which can save memory and initialization time. + + The accelerator manages whether to call this hook at every given stage. + For sharded plugins where model parallelism is required, the hook is usually on called once + to initialize the sharded parameters, and not called again in the same process. + + By default for accelerators/plugins that do not use model sharding techniques, + this hook is called during each fit/val/test/predict stages. 
""" From 202ef1ab4da611b222c6c702919d79c9d36661d2 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 29 Mar 2021 12:36:07 +0100 Subject: [PATCH 20/27] Few code review fixes --- pytorch_lightning/accelerators/accelerator.py | 3 +-- .../plugins/training_type/training_type_plugin.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 8ba87a7020480..d2201ceb15a64 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -491,8 +491,7 @@ def call_model_parallel_setup_hook(self) -> bool: @call_model_parallel_setup_hook.setter def call_model_parallel_setup_hook(self, mode: bool) -> None: - if isinstance(mode, bool): - self.training_type_plugin.call_model_parallel_setup_hook = mode + self.training_type_plugin.call_model_parallel_setup_hook = mode @property def setup_optimizers_in_pre_dispatch(self) -> bool: diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index de656f8210571..ced959724355f 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -216,5 +216,5 @@ def call_model_parallel_setup_hook(self) -> bool: return self._call_model_parallel_setup_hook @call_model_parallel_setup_hook.setter - def call_model_parallel_setup_hook(self, mode: bool) -> bool: + def call_model_parallel_setup_hook(self, mode: bool) -> None: self._call_model_parallel_setup_hook = mode From 0709baad638a883331aac89ee466f75e670b1977 Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Mon, 29 Mar 2021 12:36:25 +0100 Subject: [PATCH 21/27] Update pytorch_lightning/accelerators/accelerator.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/accelerators/accelerator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index d2201ceb15a64..8f4c65d4b842e 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -444,7 +444,7 @@ def results(self) -> Any: def model_parallel_context(self) -> Generator: """ Provide hook to create modules in a distributed aware context. This is useful for when we'd like to - shard the model instantly, which is useful for extremely large models which can save memory and + shard the model instantly - useful for extremely large models. Can save memory and initialization time. Returns: Model parallel context. 
From 9152d08623f4552a3d6bf5620550e9177d78bb0b Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 29 Mar 2021 16:55:19 +0100 Subject: [PATCH 22/27] Change hook name --- CHANGELOG.md | 2 +- pytorch_lightning/callbacks/base.py | 4 +-- .../callbacks/lambda_function.py | 6 ++--- pytorch_lightning/core/hooks.py | 2 +- pytorch_lightning/trainer/callback_hook.py | 4 +-- .../callback_hook_validator.py | 4 +-- pytorch_lightning/trainer/trainer.py | 4 +-- tests/accelerators/test_common.py | 26 +++++++++---------- tests/callbacks/test_callbacks.py | 6 ++--- .../trainer/logging_/test_logger_connector.py | 4 +-- 10 files changed, 31 insertions(+), 31 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 83960db490b10..cb7ee3fdf5bf2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -64,7 +64,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120)) -- Added `on_model_parallel_setup` hook ([#6679](https://github.com/PyTorchLightning/pytorch-lightning/pull/6679)) +- Added `configure_sharded_model` hook ([#6679](https://github.com/PyTorchLightning/pytorch-lightning/pull/6679)) - Added support for `precision=64`, enabling training with double precision ([#6595](https://github.com/PyTorchLightning/pytorch-lightning/pull/6595)) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 807b178dbe4ba..768e4ebca30ee 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -29,8 +29,8 @@ class Callback(abc.ABC): Subclass this class and override any of the relevant hooks """ - def on_model_parallel_setup(self, trainer, pl_module: LightningModule) -> None: - """Called before model parallel accelerator setup""" + def on_configure_sharded_model(self, trainer, pl_module: LightningModule) -> None: + """Called before configure sharded model""" def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModule) -> None: """Called before accelerator is being setup""" diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index 71bfd79ec9427..707b6694826b5 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -42,7 +42,7 @@ def __init__( self, on_before_accelerator_backend_setup: Optional[Callable] = None, setup: Optional[Callable] = None, - on_model_parallel_setup: Optional[Callable] = None, + configure_sharded_model: Optional[Callable] = None, teardown: Optional[Callable] = None, on_init_start: Optional[Callable] = None, on_init_end: Optional[Callable] = None, @@ -84,8 +84,8 @@ def __init__( self.on_before_accelerator_backend_setup = on_before_accelerator_backend_setup if setup is not None: self.setup = setup - if on_model_parallel_setup is not None: - self.on_model_parallel_setup = on_model_parallel_setup + if configure_sharded_model is not None: + self.configure_sharded_model = configure_sharded_model if teardown is not None: self.teardown = teardown if on_init_start is not None: diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 9ff77e494f1ef..b320a9b223840 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -310,7 +310,7 @@ def on_post_move_to_device(self): """ - def on_model_parallel_setup(self) -> None: + def configure_sharded_model(self) -> None: """ 
Hook to create modules in a distributed aware context. This is useful for when using sharded plugins, where we'd like to shard the model instantly, which is useful for extremely large models diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index b4f70d17ba403..606f6b2e4b52b 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -38,10 +38,10 @@ def on_before_accelerator_backend_setup(self, model: LightningModule) -> None: for callback in self.callbacks: callback.on_before_accelerator_backend_setup(self, model) - def on_model_parallel_setup(self, model: LightningModule) -> None: + def configure_sharded_model(self, model: LightningModule) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: - callback.on_model_parallel_setup(self, model) + callback.on_configure_sharded_model(self, model) def setup(self, model: LightningModule, stage: Optional[str]) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py index a447a38e3489c..87b730403b551 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py @@ -56,8 +56,8 @@ def _setup_log(): return None @staticmethod - def _on_model_parallel_setup_log(): - """Called when fit or test begins""" + def _on_configure_sharded_model_log(): + """Called before configure sharded model""" return None @staticmethod diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 05ab03adcc1ce..66bd2d8a2f990 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1081,8 +1081,8 @@ def call_model_parallel_hook(self, model: LightningModule) -> None: # we will not call the hook; the hook has initialized the sharded model for example. 
if self.accelerator.call_model_parallel_setup_hook: with self.accelerator.model_parallel_context(): - self.on_model_parallel_setup(model) - model.on_model_parallel_setup() + self.configure_sharded_model(model) + model.configure_sharded_model() self.accelerator.call_model_parallel_setup_hook = False def call_teardown_hook(self, model: LightningModule) -> None: diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index 148953e846e1e..e4453194818b6 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -67,11 +67,11 @@ class TestModel(BoringModel): def __init__(self): super().__init__() - self.on_model_parallel_setup_called = False + self.configure_sharded_model_called = False self.layer = None - def on_model_parallel_setup(self): - self.on_model_parallel_setup_called = True + def configure_sharded_model(self): + self.configure_sharded_model_called = True self.layer = torch.nn.Linear(32, 2) model = TestModel() @@ -83,21 +83,21 @@ def on_model_parallel_setup(self): ) trainer.fit(model) - assert model.on_model_parallel_setup_called + assert model.configure_sharded_model_called class DummyModel(BoringModel): def __init__(self): super().__init__() - self.on_model_parallel_setup_called = False + self.configure_sharded_model_called = False - def on_model_parallel_setup(self): - self.on_model_parallel_setup_called = True + def configure_sharded_model(self): + self.configure_sharded_model_called = True def test_model_parallel_setup_false(tmpdir): - """Ensure ``on_model_parallel_setup`` is not called, when turned off""" + """Ensure ``configure_sharded_model`` is not called, when turned off""" class CustomPlugin(SingleDevicePlugin): @@ -115,11 +115,11 @@ def call_model_parallel_setup_hook(self) -> bool: ) trainer.fit(model) - assert not model.on_model_parallel_setup_called + assert not model.configure_sharded_model_called def test_model_parallel_setup_called_once(tmpdir): - """Ensure ``on_model_parallel_setup`` is only called once""" + """Ensure ``configure_sharded_model`` is only called once""" model = DummyModel() trainer = Trainer( @@ -130,7 +130,7 @@ def test_model_parallel_setup_called_once(tmpdir): ) trainer.fit(model) - assert model.on_model_parallel_setup_called - model.on_model_parallel_setup_called = False + assert model.configure_sharded_model_called + model.configure_sharded_model_called = False - assert not model.on_model_parallel_setup_called + assert not model.configure_sharded_model_called diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index b5cca6d0eaff3..a30b4fe0f609b 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -48,7 +48,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir): call.on_init_end(trainer), call.on_before_accelerator_backend_setup(trainer, model), call.setup(trainer, model, 'fit'), - call.on_model_parallel_setup(trainer, model), + call.on_configure_sharded_model(trainer, model), call.on_fit_start(trainer, model), call.on_pretrain_routine_start(trainer, model), call.on_pretrain_routine_end(trainer, model), @@ -120,7 +120,7 @@ def test_trainer_callback_hook_system_test(tmpdir): call.on_init_end(trainer), call.on_before_accelerator_backend_setup(trainer, model), call.setup(trainer, model, 'test'), - call.on_model_parallel_setup(trainer, model), + call.on_configure_sharded_model(trainer, model), call.on_test_start(trainer, model), call.on_epoch_start(trainer, model), call.on_test_epoch_start(trainer, model), @@ -155,7 +155,7 @@ 
def test_trainer_callback_hook_system_validate(tmpdir): call.on_init_end(trainer), call.on_before_accelerator_backend_setup(trainer, model), call.setup(trainer, model, 'validate'), - call.on_model_parallel_setup(trainer, model), + call.on_configure_sharded_model(trainer, model), call.on_validation_start(trainer, model), call.on_epoch_start(trainer, model), call.on_validation_epoch_start(trainer, model), diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 5ca205772fd91..6f15331acaa76 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -280,7 +280,7 @@ def test_call_back_validator(tmpdir): 'on_epoch_end', 'on_epoch_start', 'on_fit_end', - 'on_model_parallel_setup', + 'on_configure_sharded_model', 'on_fit_start', 'on_init_end', 'on_init_start', @@ -317,7 +317,7 @@ def test_call_back_validator(tmpdir): "on_before_accelerator_backend_setup", "on_fit_end", "on_fit_start", - "on_model_parallel_setup", + "on_configure_sharded_model", "on_init_end", "on_init_start", "on_keyboard_interrupt", From fbfe65fc45970c942d0890929173c179531f7e69 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 29 Mar 2021 17:25:14 +0100 Subject: [PATCH 23/27] Fix test --- pytorch_lightning/callbacks/lambda_function.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index 707b6694826b5..a7485814b1b17 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -42,7 +42,7 @@ def __init__( self, on_before_accelerator_backend_setup: Optional[Callable] = None, setup: Optional[Callable] = None, - configure_sharded_model: Optional[Callable] = None, + on_configure_sharded_model: Optional[Callable] = None, teardown: Optional[Callable] = None, on_init_start: Optional[Callable] = None, on_init_end: Optional[Callable] = None, @@ -84,8 +84,8 @@ def __init__( self.on_before_accelerator_backend_setup = on_before_accelerator_backend_setup if setup is not None: self.setup = setup - if configure_sharded_model is not None: - self.configure_sharded_model = configure_sharded_model + if on_configure_sharded_model is not None: + self.on_configure_sharded_model = on_configure_sharded_model if teardown is not None: self.teardown = teardown if on_init_start is not None: From bae858f6b72e311682793ebadbe6b07e78029493 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 29 Mar 2021 18:21:44 +0100 Subject: [PATCH 24/27] Test setup hook, refactor names --- pytorch_lightning/accelerators/accelerator.py | 10 ++++----- .../training_type/training_type_plugin.py | 12 +++++------ pytorch_lightning/trainer/trainer.py | 10 ++++----- tests/accelerators/test_common.py | 21 ++++++++++++++++--- 4 files changed, 34 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 8f4c65d4b842e..93acf54bd8271 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -481,17 +481,17 @@ def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None: self.setup_precision_plugin(plugin) @property - def call_model_parallel_setup_hook(self) -> bool: + def call_configure_sharded_model_hook(self) -> bool: """ Allow model parallel hook to be called in suitable environments determined by the training type plugin. 
This is useful for when we want to shard the model once within fit. Returns: True if we want to call the model parallel setup hook. """ - return self.training_type_plugin.call_model_parallel_setup_hook + return self.training_type_plugin.call_configure_sharded_model_hook - @call_model_parallel_setup_hook.setter - def call_model_parallel_setup_hook(self, mode: bool) -> None: - self.training_type_plugin.call_model_parallel_setup_hook = mode + @call_configure_sharded_model_hook.setter + def call_configure_sharded_model_hook(self, mode: bool) -> None: + self.training_type_plugin.call_configure_sharded_model_hook = mode @property def setup_optimizers_in_pre_dispatch(self) -> bool: diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index ced959724355f..7ac4f6bed6711 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -34,7 +34,7 @@ class TrainingTypePlugin(Plugin, ABC): def __init__(self) -> None: self._model = None self._results = None - self._call_model_parallel_setup_hook = True + self._call_configure_sharded_model_hook = True def connect(self, model: 'Module') -> None: """Called by the accelerator to connect the accelerator and the model with this plugin""" @@ -207,14 +207,14 @@ def model_parallel_context(self) -> Generator: yield @property - def call_model_parallel_setup_hook(self) -> bool: + def call_configure_sharded_model_hook(self) -> bool: """ Allow model parallel hook to be called in suitable environments determined by the training type plugin. This is useful for when we want to shard the model once within fit. Returns: True if we want to call the model parallel setup hook. """ - return self._call_model_parallel_setup_hook + return self._call_configure_sharded_model_hook - @call_model_parallel_setup_hook.setter - def call_model_parallel_setup_hook(self, mode: bool) -> None: - self._call_model_parallel_setup_hook = mode + @call_configure_sharded_model_hook.setter + def call_configure_sharded_model_hook(self, mode: bool) -> None: + self._call_configure_sharded_model_hook = mode diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 66bd2d8a2f990..24bb2bc6e92a9 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -436,7 +436,7 @@ def fit( self.accelerator.connect(model) self.accelerator.setup_environment() self.call_setup_hook(model) # allow user to setup lightning_module in accelerator environment - self.call_model_parallel_hook(model) # allow user to setup in model parallel environment + self.call_configure_sharded_model(model) # allow user to setup in model sharded environment self.accelerator.setup(self, model) # note: this sets up self.lightning_module # ---------------------------- @@ -1076,14 +1076,14 @@ def call_setup_hook(self, model: LightningModule) -> None: self.setup(model, stage=state) model.setup(stage=state) - def call_model_parallel_hook(self, model: LightningModule) -> None: - # Call model parallel hook if accelerator requests. In some cases + def call_configure_sharded_model(self, model: LightningModule) -> None: + # Call configure sharded model hook if accelerator requests. In some cases # we will not call the hook; the hook has initialized the sharded model for example. 
- if self.accelerator.call_model_parallel_setup_hook: + if self.accelerator.call_configure_sharded_model_hook: with self.accelerator.model_parallel_context(): self.configure_sharded_model(model) model.configure_sharded_model() - self.accelerator.call_model_parallel_setup_hook = False + self.accelerator.call_configure_sharded_model_hook = False def call_teardown_hook(self, model: LightningModule) -> None: state = self._teardown_state diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index e4453194818b6..2ad151d75e76c 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -96,13 +96,13 @@ def configure_sharded_model(self): self.configure_sharded_model_called = True -def test_model_parallel_setup_false(tmpdir): +def test_configure_sharded_model_false(tmpdir): """Ensure ``configure_sharded_model`` is not called, when turned off""" class CustomPlugin(SingleDevicePlugin): @property - def call_model_parallel_setup_hook(self) -> bool: + def call_configure_sharded_model_hook(self) -> bool: return False model = DummyModel() @@ -118,7 +118,22 @@ def call_model_parallel_setup_hook(self) -> bool: assert not model.configure_sharded_model_called -def test_model_parallel_setup_called_once(tmpdir): +def test_accelerator_configure_sharded_model_called_once(tmpdir): + """Ensure that the configure sharded model hook is called, and set to False after to ensure not called again.""" + + model = DummyModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + ) + assert trainer.accelerator.call_configure_sharded_model_hook is True + trainer.fit(model) + assert trainer.accelerator.call_configure_sharded_model_hook is False + + +def test_configure_sharded_model_called_once(tmpdir): """Ensure ``configure_sharded_model`` is only called once""" model = DummyModel() From 41e9c22e77856383626b274a8ea150c4add1a7b0 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 29 Mar 2021 19:33:37 +0100 Subject: [PATCH 25/27] Swap call order of callbacks and model initialization --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 24bb2bc6e92a9..f820895ee4ff2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1081,8 +1081,8 @@ def call_configure_sharded_model(self, model: LightningModule) -> None: # we will not call the hook; the hook has initialized the sharded model for example. 
if self.accelerator.call_configure_sharded_model_hook: with self.accelerator.model_parallel_context(): - self.configure_sharded_model(model) model.configure_sharded_model() + self.configure_sharded_model(model) self.accelerator.call_configure_sharded_model_hook = False def call_teardown_hook(self, model: LightningModule) -> None: From 76c7376abe8993dc1297bc919b89d2a456cf82a2 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 29 Mar 2021 20:05:14 +0100 Subject: [PATCH 26/27] Change name of context manager --- pytorch_lightning/accelerators/accelerator.py | 4 ++-- .../plugins/training_type/training_type_plugin.py | 2 +- pytorch_lightning/trainer/trainer.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 93acf54bd8271..102a97e1eecee 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -441,7 +441,7 @@ def results(self) -> Any: return self.training_type_plugin.results @contextlib.contextmanager - def model_parallel_context(self) -> Generator: + def model_sharded_context(self) -> Generator: """ Provide hook to create modules in a distributed aware context. This is useful for when we'd like to shard the model instantly - useful for extremely large models. Can save memory and @@ -449,7 +449,7 @@ def model_parallel_context(self) -> Generator: Returns: Model parallel context. """ - with self.training_type_plugin.model_parallel_context(): + with self.training_type_plugin.model_sharded_context(): yield # todo: remove in v1.5 diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 7ac4f6bed6711..02a49eb22a760 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -196,7 +196,7 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: return False @contextlib.contextmanager - def model_parallel_context(self) -> Generator: + def model_sharded_context(self) -> Generator: """ Provide hook to create modules in a distributed aware context. This is useful for when we'd like to shard the model instantly, which is useful for extremely large models which can save memory and diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f820895ee4ff2..6bf8983867b25 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1080,7 +1080,7 @@ def call_configure_sharded_model(self, model: LightningModule) -> None: # Call configure sharded model hook if accelerator requests. In some cases # we will not call the hook; the hook has initialized the sharded model for example. 
if self.accelerator.call_configure_sharded_model_hook: - with self.accelerator.model_parallel_context(): + with self.accelerator.model_sharded_context(): model.configure_sharded_model() self.configure_sharded_model(model) self.accelerator.call_configure_sharded_model_hook = False From aa35583afb3eae9d309aec4cf01b22802e2342f7 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 29 Mar 2021 20:15:19 +0100 Subject: [PATCH 27/27] add docstring --- pytorch_lightning/accelerators/accelerator.py | 6 ++++++ pytorch_lightning/plugins/training_type/tpu_spawn.py | 3 +-- .../plugins/training_type/training_type_plugin.py | 6 ++++++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 31736c13c6351..7d16d91e3bf82 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -481,6 +481,12 @@ def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None: self.setup_precision_plugin(plugin) def save_checkpoint(self, checkpoint: Dict[str, Any], filepath) -> None: + """Save model/training states as a checkpoint file through state-dump and file-write. + + Args: + checkpoint: dict containing model and trainer state + filepath: write-target file's path + """ self.training_type_plugin.save_checkpoint(checkpoint, filepath) @property diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index aee2b8914b579..ba074e7cfb206 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -300,9 +300,8 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: - trainer: PyTorch Lightning Trainer + checkpoint: dict containing model and trainer state filepath: write-target file's path - weights_only: saving model weights only """ # Todo: TypeError: 'mappingproxy' object does not support item assignment self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 0250933c9da3c..1eac88212e0fb 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -198,6 +198,12 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: return False def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: + """Save model/training states as a checkpoint file through state-dump and file-write. + + Args: + checkpoint: dict containing model and trainer state + filepath: write-target file's path + """ # dump states as a checkpoint dictionary object if self.is_global_zero: checkpoint = self.on_save(checkpoint)
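The patch series above converges on ``configure_sharded_model`` as the user-facing LightningModule hook. Below is a minimal sketch of a module that defers layer creation to that hook so it runs inside the accelerator's ``model_sharded_context()``; the class name, toy data and trainer flags are illustrative and not taken from the patches.

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl

    class LazyLinearModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            # Created lazily in configure_sharded_model rather than here.
            self.layer = None

        def configure_sharded_model(self):
            # Invoked once by the trainer, wrapped in model_sharded_context(),
            # and before accelerator.setup(), so the optimizer still sees the
            # freshly created parameters.
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            (x,) = batch
            return self.layer(x).sum()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

        def train_dataloader(self):
            return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

    trainer = pl.Trainer(max_epochs=1, limit_train_batches=2)
    trainer.fit(LazyLinearModel())

Because the accelerator flips ``call_configure_sharded_model_hook`` to ``False`` after the first call, a second ``fit`` on the same trainer will not re-run the hook, matching ``test_accelerator_configure_sharded_model_called_once`` above.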
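On the callback side the renamed hook is ``on_configure_sharded_model(trainer, pl_module)``, and patch 23 exposes the matching ``on_configure_sharded_model`` argument on ``LambdaCallback``. A sketch follows, reusing ``LazyLinearModel`` and the imports from the previous example; since patch 25 runs the module hook before the callbacks, the callback can already inspect the freshly built layers.

    from pytorch_lightning.callbacks import Callback, LambdaCallback

    class ShardedSetupLogger(Callback):
        def on_configure_sharded_model(self, trainer, pl_module):
            # The module hook has already run, so pl_module.layer exists here.
            print(f"sharded model ready: {pl_module.layer}")

    # The same hook as a one-off lambda:
    lambda_hook = LambdaCallback(
        on_configure_sharded_model=lambda trainer, pl_module: print("on_configure_sharded_model fired")
    )

    trainer = pl.Trainer(
        max_epochs=1,
        limit_train_batches=2,
        callbacks=[ShardedSetupLogger(), lambda_hook],
    )
    trainer.fit(LazyLinearModel())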
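A training type plugin can opt out of the automatic call by overriding the ``call_configure_sharded_model_hook`` property, as ``CustomPlugin`` does in the tests above. The sketch below mirrors that pattern; the ``device`` keyword passed to ``SingleDevicePlugin`` and the eager model are illustrative assumptions rather than code from the patches.

    from pytorch_lightning.plugins import SingleDevicePlugin

    class PreShardedPlugin(SingleDevicePlugin):
        @property
        def call_configure_sharded_model_hook(self) -> bool:
            # The plugin takes responsibility for building/sharding the model,
            # so the trainer should not invoke configure_sharded_model.
            return False

    class EagerModel(LazyLinearModel):
        def __init__(self):
            super().__init__()
            # Layers exist up-front, so skipping the hook is safe.
            self.layer = torch.nn.Linear(32, 2)
            self.configure_sharded_model_called = False

        def configure_sharded_model(self):
            self.configure_sharded_model_called = True

    model = EagerModel()
    trainer = pl.Trainer(
        max_epochs=1,
        limit_train_batches=2,
        plugins=PreShardedPlugin(device=torch.device("cpu")),
    )
    trainer.fit(model)
    assert not model.configure_sharded_model_called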
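Plugins that want module creation to happen inside a sharding scope override the ``model_sharded_context`` context manager (renamed from ``model_parallel_context`` in patch 26) on ``TrainingTypePlugin``. The sketch below only records that the context was entered; a real sharded plugin would enter its backend's partitioning context there, which is backend specific and outside the scope of these patches.

    import contextlib
    from typing import Generator

    from pytorch_lightning.plugins import SingleDevicePlugin

    class InstrumentedShardingPlugin(SingleDevicePlugin):
        def __init__(self, device):
            super().__init__(device)
            self.context_entered = False

        @contextlib.contextmanager
        def model_sharded_context(self) -> Generator:
            # Everything the LightningModule creates in configure_sharded_model
            # happens while this context is active.
            self.context_entered = True
            yield

    trainer = pl.Trainer(
        max_epochs=1,
        limit_train_batches=2,
        plugins=InstrumentedShardingPlugin(device=torch.device("cpu")),
    )
    trainer.fit(LazyLinearModel())
    assert trainer.accelerator.training_type_plugin.context_entered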
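Patch 27 documents ``save_checkpoint(checkpoint, filepath)`` on both the accelerator and the training type plugin. In the same spirit as the TPU spawn override above, a plugin can filter the checkpoint dictionary before the base class performs the state-dump and file-write; the class name and the filtering rule below are illustrative, not part of the series.

    from typing import Any, Dict

    from pytorch_lightning.plugins import SingleDevicePlugin

    class SlimCheckpointPlugin(SingleDevicePlugin):
        def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
            # Drop callback state, then let the base implementation write the file.
            checkpoint = {k: v for k, v in checkpoint.items() if k != "callbacks"}
            super().save_checkpoint(checkpoint, filepath)

Passing such a plugin via ``Trainer(plugins=...)`` leaves the rest of the checkpoint flow untouched, since the base implementation still handles the actual write on the global zero rank.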