diff --git a/CHANGELOG.md b/CHANGELOG.md index c2b50c1d41f5b..d1eed9848cdbb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -68,6 +68,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Raised an error if the `batch_size` cannot be inferred from the current batch if it contained a string or was a custom batch object ([#10541](https://github.com/PyTorchLightning/pytorch-lightning/pull/10541)) +- Moved optimizer related logics from `Accelerator` to `TrainingTypePlugin` ([#10596](https://github.com/PyTorchLightning/pytorch-lightning/pull/10596)) + + - Moved `batch_to_device` method from `Accelerator` to `TrainingTypePlugin` ([#10649](https://github.com/PyTorchLightning/pytorch-lightning/pull/10649)) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 8ccc2d86edd9e..b50a19221325e 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -13,21 +13,14 @@ # limitations under the License. import contextlib from abc import abstractmethod -from typing import Any, Callable, Dict, Generator, List, Optional, Union +from typing import Any, Dict, Generator, Optional, Union import torch -from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn import Module -from torch.optim import Optimizer import pytorch_lightning as pl -from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type import TrainingTypePlugin -from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import rank_zero_deprecation -from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device -from pytorch_lightning.utilities.enums import AMPType, LightningEnum from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -62,10 +55,6 @@ def __init__(self, precision_plugin: Optional[PrecisionPlugin], training_type_pl if precision_plugin is not None: self.training_type_plugin._precision_plugin = precision_plugin - self.optimizers: List = [] - self.lr_schedulers: List = [] - self.optimizer_frequencies: List = [] - def setup_environment(self) -> None: """Setup any processes or distributed connections. 
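Note: with the `accelerator.py` hunk above, `Accelerator` no longer owns the `optimizers`, `lr_schedulers`, and `optimizer_frequencies` containers; they are initialized on `TrainingTypePlugin` instead (see the `training_type_plugin.py` hunk later in this diff), and the existing `Trainer` properties keep proxying to them. A minimal sketch of where the containers are reached after this change; the `fast_dev_run` run and the `BoringModel` test helper are illustrative only:

```python
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel

trainer = Trainer(fast_dev_run=True)
trainer.fit(BoringModel())

# The strategy (training type plugin) now owns the optimizer containers; the
# `Trainer.optimizers` / `Trainer.lr_schedulers` properties keep working because
# they proxy to the plugin (see the trainer.py hunk later in this diff).
assert trainer.optimizers is trainer.training_type_plugin.optimizers
assert trainer.lr_schedulers is trainer.training_type_plugin.lr_schedulers

# The removed `trainer.accelerator.optimizers` attribute is no longer available.
```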
@@ -80,28 +69,18 @@ def setup(self, trainer: "pl.Trainer") -> None: Args: trainer: the trainer instance """ - self.setup_training_type_plugin() - if not self.training_type_plugin.setup_optimizers_in_pre_dispatch: - self.setup_optimizers(trainer) - self.setup_precision_plugin() + self.training_type_plugin.setup(trainer) def pre_dispatch(self, trainer: "pl.Trainer") -> None: """Hook to do something before the training/evaluation/prediction starts.""" - self._move_optimizer_state() + self.training_type_plugin._move_optimizer_state() self.training_type_plugin.pre_dispatch() if self.training_type_plugin.setup_optimizers_in_pre_dispatch: - self.setup_optimizers(trainer) + self.training_type_plugin.setup_optimizers(trainer) self.training_type_plugin.precision_plugin.pre_dispatch() - def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: - """Moves the state of the optimizers to the GPU if needed.""" - device = device or self.root_device - for opt in self.optimizers: - for p, v in opt.state.items(): - opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device) - def dispatch(self, trainer: "pl.Trainer") -> None: """Hook to do something before the training/evaluation/prediction starts.""" self.training_type_plugin.dispatch(trainer) @@ -177,115 +156,12 @@ def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: with self.training_type_plugin.precision_plugin.predict_step_context(): return self.training_type_plugin.predict_step(*step_kwargs.values()) - def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor: - """Forwards backward-calls to the precision plugin. - - Args: - closure_loss: a tensor holding the loss value to backpropagate - """ - self.training_type_plugin.pre_backward(closure_loss) - closure_loss = self.training_type_plugin.precision_plugin.pre_backward(self.lightning_module, closure_loss) - - self.training_type_plugin.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs) - - closure_loss = self.training_type_plugin.precision_plugin.post_backward(self.lightning_module, closure_loss) - self.training_type_plugin.post_backward(closure_loss) - - return closure_loss - - def optimizer_step( - self, - optimizer: Optimizer, - opt_idx: int, - closure: Callable[[], Any], - model: Optional[Union["pl.LightningModule", Module]] = None, - **kwargs: Any, - ) -> None: - """performs the actual optimizer step. - - Args: - optimizer: the optimizer performing the step - opt_idx: index of the current optimizer - closure: closure calculating the loss value - model: reference to the model, optionally defining optimizer step related hooks - **kwargs: Any extra arguments to ``optimizer.step`` - """ - model = model or self.lightning_module - self.training_type_plugin.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs) - - def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None: - """Zeros all model parameter's gradients.""" - model_ref = self.lightning_module - model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx) - - def setup_optimizers(self, trainer: "pl.Trainer") -> None: - """Creates optimizers and schedulers. 
- - Args: - trainer: the Trainer, these optimizers should be connected to - """ - if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING): - return - optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers( - trainer=trainer, model=self.lightning_module - ) - self.optimizers = optimizers - self.lr_schedulers = lr_schedulers - self.optimizer_frequencies = optimizer_frequencies - - def setup_training_type_plugin(self) -> None: - """Attaches the training type plugin to the accelerator.""" - self.training_type_plugin.setup() - - def setup_precision_plugin(self) -> None: - """Attaches the precision plugin to the accelerator.""" - model, optimizers, schedulers = self.training_type_plugin.precision_plugin.connect( - self.model, self.optimizers, self.lr_schedulers - ) - self.model = model - self.optimizers = optimizers - self.lr_schedulers = schedulers - - @property - def amp_backend(self) -> Optional[LightningEnum]: - if isinstance(self.training_type_plugin.precision_plugin, ApexMixedPrecisionPlugin): - return AMPType.APEX - if isinstance(self.training_type_plugin.precision_plugin, NativeMixedPrecisionPlugin): - return AMPType.NATIVE - return None - - @property - def precision(self) -> Union[str, int]: - """The type of precision being used with this accelerator. - - .. deprecated:: - This property been deprecated and will be removed soon. - Use ``training_type_plugin.precision_plugin.precision`` instead. - """ - rank_zero_deprecation( - f"`{self.__class__.__name__}.precision` has been deprecated and will be removed soon" - f" Use `training_type_plugin.precision_plugin.precision` instead." - ) - return self.training_type_plugin.precision_plugin.precision - - @property - def scaler(self) -> Optional["GradScaler"]: - return getattr(self.training_type_plugin.precision_plugin, "scaler", None) - - def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: - """Returns state of an optimizer. - - Allows for syncing/collating optimizer state from processes in custom plugins. - """ - return getattr(self.training_type_plugin, "optimizer_state", lambda x: x.state_dict())(optimizer) - @contextlib.contextmanager def model_sharded_context(self) -> Generator[None, None, None]: """Provide hook to create modules in a distributed aware context. This is useful for when we'd like to. shard the model instantly - useful for extremely large models. Can save memory and initialization time. - Returns: Model parallel context. """ diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index 8b18676effb79..7d5786102d0b3 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -29,8 +29,10 @@ def setup(self, trainer: "pl.Trainer") -> None: MisconfigurationException: If the selected device is not CPU. """ - if "cpu" not in str(self.root_device): - raise MisconfigurationException(f"Device should be CPU, got {self.root_device} instead.") + if "cpu" not in str(self.training_type_plugin.root_device): + raise MisconfigurationException( + f"Device should be CPU, got {self.training_type_plugin.root_device} instead." + ) return super().setup(trainer) diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index 62af5f27dcc1c..49d0770e54ff0 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -37,9 +37,11 @@ def setup_environment(self) -> None: If the selected device is not GPU. 
""" super().setup_environment() - if "cuda" not in str(self.root_device): - raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead") - torch.cuda.set_device(self.root_device) + if "cuda" not in str(self.training_type_plugin.root_device): + raise MisconfigurationException( + f"Device should be GPU, got {self.training_type_plugin.root_device} instead" + ) + torch.cuda.set_device(self.training_type_plugin.root_device) def setup(self, trainer: "pl.Trainer") -> None: self.set_nvidia_flags(trainer.local_rank) @@ -77,7 +79,7 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: def teardown(self) -> None: super().teardown() - self._move_optimizer_state(torch.device("cpu")) + self.training_type_plugin._move_optimizer_state(torch.device("cpu")) @staticmethod def auto_device_count() -> int: diff --git a/pytorch_lightning/accelerators/ipu.py b/pytorch_lightning/accelerators/ipu.py index 0f6bdb8270395..155dce5275a9b 100644 --- a/pytorch_lightning/accelerators/ipu.py +++ b/pytorch_lightning/accelerators/ipu.py @@ -15,25 +15,12 @@ import torch -import pytorch_lightning as pl from pytorch_lightning.accelerators.accelerator import Accelerator -from pytorch_lightning.utilities.exceptions import MisconfigurationException class IPUAccelerator(Accelerator): """Accelerator for IPUs.""" - def setup_optimizers(self, trainer: "pl.Trainer") -> None: - """ - Raises: - MisconfigurationException: - If multiple optimizers are provided. - """ - super().setup_optimizers(trainer) - - if len(self.optimizers) > 1: - raise MisconfigurationException("IPUs currently only support one optimizer.") - def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: """IPU device stats aren't supported yet.""" return {} diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 673e8419ca7fb..f116ed7f0f493 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Union import torch @@ -21,7 +21,6 @@ from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin from pytorch_lightning.utilities import _XLA_AVAILABLE -from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device if _XLA_AVAILABLE: import torch_xla.core.xla_model as xm @@ -49,14 +48,6 @@ def setup(self, trainer: "pl.Trainer") -> None: ) return super().setup(trainer) - def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: - """Moves the state of the optimizers to the TPU if needed.""" - # TODO: `self.root_device` would raise error if called outside the spawn process - # while training on 8 and more cores. - for opt in self.optimizers: - for p, v in opt.state.items(): - opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device) - def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: """Gets stats for the given TPU device. 
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 87fa042475697..0c26d3e8af22c 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1347,7 +1347,7 @@ def training_step(...): **kwargs: Additional keyword arguments to be forwarded to :meth:`~torch.Tensor.backward` """ self._verify_is_manual_optimization("manual_backward") - self.trainer.accelerator.backward(loss, None, None, *args, **kwargs) + self.trainer.training_type_plugin.backward(loss, None, None, *args, **kwargs) def backward( self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args, **kwargs diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index ecd62ab81715e..b3f49d393824f 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -161,4 +161,4 @@ def closure_dis(): trainer = self._trainer assert trainer is not None with trainer.profiler.profile(profiler_action): - trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs) + trainer.training_type_plugin.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 0d292dba54176..fede7f5df7291 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -112,7 +112,7 @@ def device(self) -> torch.device: Use this to create tensors directly on the device if needed. """ - return self._accelerator.root_device + return self._strategy.root_device @property def global_rank(self) -> int: diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 26a76e6ed9ccd..6a37453eec93d 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -46,6 +46,8 @@ def __init__(self, optimizer: Optimizer, accelerator: Accelerator) -> None: self.__class__ = type("Lite" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {}) self._optimizer = optimizer self._accelerator = accelerator + # TODO (@awaelchli) refactor to take Strategy as param + self._strategy = self._accelerator.training_type_plugin @property def optimizer(self) -> Optimizer: @@ -56,11 +58,11 @@ def state_dict(self) -> Dict[str, Tensor]: def step(self, closure: Optional[Callable] = None) -> None: closure = closure or _do_nothing_closure - self._accelerator.optimizer_step( + self._strategy.optimizer_step( self.optimizer, opt_idx=0, closure=closure, - model=self._accelerator.model, + model=self._strategy.model, ) diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py index c53b1b87a1c89..301122530b4ad 100644 --- a/pytorch_lightning/loops/optimization/optimizer_loop.py +++ b/pytorch_lightning/loops/optimization/optimizer_loop.py @@ -320,7 +320,7 @@ def _make_backward_fn(self, optimizer: Optimizer, opt_idx: int) -> Optional[Call return None def backward_fn(loss: Tensor) -> None: - self.trainer.accelerator.backward(loss, optimizer, opt_idx) + self.trainer.training_type_plugin.backward(loss, optimizer, opt_idx) # check if model weights are nan if self.trainer._terminate_on_nan: @@ -402,7 +402,7 @@ def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, optimizer: the current optimizer opt_idx: the index of the current optimizer """ - self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) + 
self.trainer.training_type_plugin.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) self.optim_progress.optimizer.zero_grad.increment_completed() def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult: diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index a0bbb4b9211ac..1e448a226a2a1 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -47,9 +47,9 @@ def main_params(self, optimizer: Optimizer) -> _PARAMETERS: def dispatch(self, trainer: "pl.Trainer") -> None: if not self._connected: - accelerator = trainer.accelerator - _, accelerator.optimizers = amp.initialize( - trainer.lightning_module, accelerator.optimizers, opt_level=self.amp_level + strategy = trainer.training_type_plugin + _, strategy.optimizers = amp.initialize( + trainer.lightning_module, strategy.optimizers, opt_level=self.amp_level ) self._connected = True return super().dispatch(trainer) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index e09100b77207d..fad41f12302f2 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -126,11 +126,12 @@ def distributed_sampler_kwargs(self): def _is_single_process_single_device(self): return True - def setup(self) -> None: + def setup(self, trainer: "pl.Trainer") -> None: os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) # pass in a state q smp = mp.get_context("spawn") self.mp_queue = smp.SimpleQueue() + super().setup(trainer) def _setup_model(self, model: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index b02b4bdefaa1b..c12068c025860 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -16,6 +16,7 @@ import torch from torch.nn import DataParallel, Module +import pytorch_lightning as pl from pytorch_lightning.overrides.data_parallel import LightningParallelModule from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin @@ -61,10 +62,11 @@ def node_rank(self) -> int: def world_size(self) -> int: return 1 - def setup(self) -> None: + def setup(self, trainer: "pl.Trainer") -> None: # model needs to be moved to the device before it is wrapped self.model_to_device() self._model = self._setup_model(LightningParallelModule(self._model)) + super().setup(trainer) def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any: """Moves the batch to the correct device. 
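Note: the call-site hunks above are mechanical: optimizer-related calls previously made on `trainer.accelerator` (`backward`, `optimizer_step`, `optimizer_zero_grad`) now target `trainer.training_type_plugin`, and training type plugins receive the trainer in `setup()`. A sketch of what a third-party strategy would change, assuming it subclasses `SingleDevicePlugin`; `MyStrategy` is a hypothetical name:

```python
import pytorch_lightning as pl
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin


class MyStrategy(SingleDevicePlugin):
    # Before this PR the hook was `def setup(self) -> None` and the Accelerator called
    # `setup_optimizers` / `setup_precision_plugin` itself. The trainer is now passed in and
    # the base `TrainingTypePlugin.setup` performs both steps, so overrides should call super().
    def setup(self, trainer: "pl.Trainer") -> None:
        # hypothetical strategy-specific preparation would go here
        super().setup(trainer)


# Call sites migrate the same way, e.g. in custom loops or manual optimization helpers:
#   before: trainer.accelerator.backward(loss, optimizer, opt_idx)
#   after:  trainer.training_type_plugin.backward(loss, optimizer, opt_idx)
#   before: trainer.accelerator.optimizer_step(optimizer, opt_idx, closure)
#   after:  trainer.training_type_plugin.optimizer_step(optimizer, opt_idx, closure)
```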
diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 38fa2942a7819..f50f2ff09dc52 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -157,7 +157,7 @@ def configure_ddp(self) -> None: self.model_to_device() # setup optimizers after fully sharded has wrapped the lightning module - self.lightning_module.trainer.accelerator.setup_optimizers(self.lightning_module.trainer) + self.setup_optimizers(self.lightning_module.trainer) def pre_dispatch(self) -> None: if self.sync_batchnorm: diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 961d2764b8ef3..4aef238abb5db 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -19,6 +19,7 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler +import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin @@ -73,8 +74,9 @@ def distributed_sampler_kwargs(self): distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank) return distributed_sampler_kwargs - def setup(self) -> None: + def setup(self, trainer: "pl.Trainer") -> None: self.model_to_device() + super().setup(trainer) def pre_dispatch(self): @@ -85,7 +87,7 @@ def pre_dispatch(self): def _unpack_lightning_optimizer(opt): return opt._optimizer if isinstance(opt, LightningOptimizer) else opt - optimizers = self.lightning_module.trainer.optimizers + optimizers = self.optimizers optimizers = [_unpack_lightning_optimizer(opt) for opt in optimizers] # Horovod: scale the learning rate by the number of workers to account for @@ -106,7 +108,7 @@ def _unpack_lightning_optimizer(opt): for optimizer in optimizers: hvd.broadcast_optimizer_state(optimizer, root_rank=0) - self.lightning_module.trainer.accelerator.optimizers = self._wrap_optimizers(optimizers) + self.optimizers = self._wrap_optimizers(optimizers) def start_training(self, trainer): with ExitStack() as stack: diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index 8f8f082280156..b072ea8437ea8 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -111,7 +111,7 @@ def __init__( options["autoReport.directory"] = self.autoreport_dir os.environ["POPLAR_ENGINE_OPTIONS"] = json.dumps(options) - def setup(self) -> None: + def setup(self, trainer: "pl.Trainer") -> None: # set the `accumulate_grad_batches` property as early as possible self._handle_gradient_accumulation_steps() @@ -121,6 +121,16 @@ def setup(self) -> None: self._update_dataloader_original = pl.trainer.data_loading._update_dataloader pl.trainer.data_loading._update_dataloader = self._convert_to_poptorch_loader + if not self.setup_optimizers_in_pre_dispatch: + self.setup_optimizers(trainer) + self.setup_precision_plugin() + + def setup_optimizers(self, trainer: "pl.Trainer") -> None: + super().setup_optimizers(trainer) + + if len(self.optimizers) > 1: + raise MisconfigurationException("IPUs currently only support one optimizer.") + def pre_dispatch(self) -> None: model = LightningIPUModule(self.lightning_module, self.precision_plugin.precision) self.model 
= model @@ -315,7 +325,7 @@ def on_predict_end(self): def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: # Updates optimizer stats if LR scheduler modified the optimizer state - optimizer = self.lightning_module.trainer.optimizers[0] + optimizer = self.optimizers[0] self.poptorch_models[RunningStage.TRAINING].setOptimizer(optimizer) @property diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 9f83f0261c3ec..70bc3bc16c428 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -119,7 +119,7 @@ def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: # Ensure that the scaler points to the correct process group # which is re-initialized in a new process if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin): - self.precision_plugin.scaler = ShardedGradScaler() + self._precision_plugin.scaler = ShardedGradScaler() return super().new_process(trainer, mp_queue) @classmethod diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index 12a0f625b64fc..9dde35a589e05 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -15,6 +15,7 @@ import torch +import pytorch_lightning as pl from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin @@ -69,8 +70,9 @@ def root_device(self) -> torch.device: def model_to_device(self) -> None: self._model.to(self.root_device) - def setup(self) -> None: + def setup(self, trainer: "pl.Trainer") -> None: self.model_to_device() + super().setup(trainer) @property def is_global_zero(self) -> bool: diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index e6f6a5f4b26f2..f9fa415e67090 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -14,11 +14,15 @@ import os from typing import Any, Dict, Optional +import torch + +import pytorch_lightning as pl from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters +from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import _PATH @@ -50,7 +54,7 @@ def __init__( def is_distributed(self) -> bool: return False - def setup(self) -> None: + def setup(self, trainer: "pl.Trainer") -> None: shared_params = find_shared_parameters(self.model) self.model_to_device() if is_overridden("on_post_move_to_device", self.lightning_module): @@ -58,6 +62,18 @@ def setup(self) -> None: else: set_shared_parameters(self.model, shared_params) + if not self.setup_optimizers_in_pre_dispatch: + 
self.setup_optimizers(trainer) + self.setup_precision_plugin() + + def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: + """Moves the state of the optimizers to the TPU if needed.""" + # TODO: `self.root_device` would raise error if called outside the spawn process + # while training on 8 and more cores. + for opt in self.optimizers: + for p, v in opt.state.items(): + opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device) + def model_to_device(self) -> None: self.model.to(self.root_device) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 9c8f7f18230b8..5ef8a46d7127f 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -32,7 +32,7 @@ from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters -from pytorch_lightning.utilities.apply_func import move_data_to_device +from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.data import has_len from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -121,8 +121,19 @@ def pre_dispatch(self): if self.debug: os.environ["PT_XLA_DEBUG"] = str(1) - def setup(self) -> None: + def setup(self, trainer: "pl.Trainer") -> None: self.create_mp_queue() + if not self.setup_optimizers_in_pre_dispatch: + self.setup_optimizers(trainer) + self.setup_precision_plugin() + + def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: + """Moves the state of the optimizers to the TPU if needed.""" + # TODO: `self.root_device` would raise error if called outside the spawn process + # while training on 8 and more cores. + for opt in self.optimizers: + for p, v in opt.state.items(): + opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device) def _setup_model(self, model: Module) -> Module: return model @@ -170,8 +181,8 @@ def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: else: set_shared_parameters(self.model.module, shared_params) - trainer.accelerator.setup_optimizers(trainer) - self.precision_plugin.connect(self._model, None, None) + trainer.training_type_plugin.setup_optimizers(trainer) + trainer.precision_plugin.connect(self._model, None, None) self.barrier("pre-run-stage") diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 75ae5592a29ef..9709467ebd1a0 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -13,12 +13,13 @@ # limitations under the License. 
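Note: in the plugin hunks above, strategies that need their optimizers rebuilt after wrapping the model now call `self.setup_optimizers(trainer)` on themselves (e.g. the fully sharded plugin's `configure_ddp`), plugin-specific validation such as the IPU single-optimizer restriction moves onto the plugin, and the TPU plugins override `_move_optimizer_state` because, per the TODO they carry, `root_device` cannot be queried outside the spawned process when training on 8 or more cores. A hedged sketch of hosting such validation on a custom strategy, mirroring the `IPUPlugin.setup_optimizers` override; `MySingleOptimizerStrategy` is a hypothetical name:

```python
import pytorch_lightning as pl
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class MySingleOptimizerStrategy(SingleDevicePlugin):
    """Hypothetical strategy: optimizer validation now lives on the plugin, not the accelerator."""

    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
        super().setup_optimizers(trainer)  # populates `self.optimizers` from the LightningModule
        if len(self.optimizers) > 1:
            raise MisconfigurationException("This strategy currently only supports one optimizer.")
```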
import contextlib from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, Union import torch from torch import Tensor from torch.nn import Module from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader import pytorch_lightning as pl @@ -26,7 +27,8 @@ from pytorch_lightning.plugins import TorchCheckpointIO from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.utilities.apply_func import move_data_to_device +from pytorch_lightning.trainer.states import TrainerFn +from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT @@ -43,6 +45,9 @@ def __init__( checkpoint_io = checkpoint_io if checkpoint_io is not None else TorchCheckpointIO() self._checkpoint_io = checkpoint_io self._precision_plugin = precision_plugin if precision_plugin is not None else PrecisionPlugin() + self.optimizers: List[Optimizer] = [] + self.lr_schedulers: List[_LRScheduler] = [] + self.optimizer_frequencies: List[int] = [] @property def checkpoint_io(self) -> CheckpointIO: @@ -67,8 +72,92 @@ def setup_environment(self) -> None: environment before setup is complete. """ - def setup(self) -> None: - """Called by the accelerator to finish setup.""" + def setup_optimizers(self, trainer: "pl.Trainer") -> None: + """Creates optimizers and schedulers. + + Args: + trainer: the Trainer, these optimizers should be connected to + """ + if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING): + return + optimizers, lr_schedulers, optimizer_frequencies = self.init_optimizers( + trainer=trainer, model=self.lightning_module + ) + self.optimizers = optimizers + self.lr_schedulers = lr_schedulers + self.optimizer_frequencies = optimizer_frequencies + + def setup(self, trainer: "pl.Trainer") -> None: + """Setup plugins for the trainer fit and creates optimizers. + + Args: + trainer: the trainer instance + """ + if not self.setup_optimizers_in_pre_dispatch: + self.setup_optimizers(trainer) + self.setup_precision_plugin() + + def setup_precision_plugin(self) -> None: + """Attaches the precision plugin to the accelerator.""" + model, optimizers, schedulers = self.precision_plugin.connect(self.model, self.optimizers, self.lr_schedulers) + self.model = model + self.optimizers = optimizers + self.lr_schedulers = schedulers + + def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: + """Moves the state of the optimizers to the GPU if needed.""" + device = device or self.root_device + for opt in self.optimizers: + for p, v in opt.state.items(): + opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device) + + def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: + """Returns state of an optimizer. + + Allows for syncing/collating optimizer state from processes in custom plugins. + """ + return optimizer.state_dict() + + def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor: + """Forwards backward-calls to the precision plugin. 
+ + Args: + closure_loss: a tensor holding the loss value to backpropagate + """ + self.pre_backward(closure_loss) + closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss) + + self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs) + + closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss) + self.post_backward(closure_loss) + + return closure_loss + + def optimizer_step( + self, + optimizer: Optimizer, + opt_idx: int, + closure: Callable[[], Any], + model: Optional[Union["pl.LightningModule", Module]] = None, + **kwargs: Any, + ) -> None: + """performs the actual optimizer step. + + Args: + optimizer: the optimizer performing the step + opt_idx: index of the current optimizer + closure: closure calculating the loss value + model: reference to the model, optionally defining optimizer step related hooks + **kwargs: Any extra arguments to ``optimizer.step`` + """ + model = model or self.lightning_module + self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs) + + def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None: + """Zeros all model parameter's gradients.""" + model_ref = self.lightning_module + model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx) def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]: """Setup a model and multiple optimizers together. @@ -220,7 +309,7 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: optimizer_states = checkpoint["optimizer_states"] - for optimizer, opt_state in zip(self.lightning_module.trainer.accelerator.optimizers, optimizer_states): + for optimizer, opt_state in zip(self.optimizers, optimizer_states): optimizer.load_state_dict(opt_state) def start_training(self, trainer: "pl.Trainer") -> None: diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index ba1166a019e6b..2e5e1c48c5785 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -584,7 +584,7 @@ def parallel_devices(self) -> List[Union[torch.device, int]]: @property def root_gpu(self) -> Optional[int]: return ( - self.accelerator.root_device.index + self.training_type_plugin.root_device.index if not isinstance(self.accelerator, (IPUAccelerator, TPUAccelerator)) else None ) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index ab0d3aa4288fa..92cad3b118006 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -382,7 +382,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: optimizer_states = [] for i, optimizer in enumerate(self.trainer.optimizers): # Rely on accelerator to dump optimizer state - optimizer_state = self.trainer.accelerator.optimizer_state(optimizer) + optimizer_state = self.trainer.training_type_plugin.optimizer_state(optimizer) optimizer_states.append(optimizer_state) checkpoint["optimizer_states"] = optimizer_states diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0376a0f745f6f..1941666888520 100644 --- 
a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -38,7 +38,15 @@ from pytorch_lightning.loops import PredictionLoop, TrainingBatchLoop, TrainingEpochLoop from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from pytorch_lightning.loops.fit_loop import FitLoop -from pytorch_lightning.plugins import DDPSpawnPlugin, ParallelPlugin, PLUGIN_INPUT, PrecisionPlugin, TrainingTypePlugin +from pytorch_lightning.plugins import ( + ApexMixedPrecisionPlugin, + DDPSpawnPlugin, + NativeMixedPrecisionPlugin, + ParallelPlugin, + PLUGIN_INPUT, + PrecisionPlugin, + TrainingTypePlugin, +) from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment from pytorch_lightning.profiler import ( AdvancedProfiler, @@ -67,6 +75,7 @@ _IPU_AVAILABLE, _StrategyType, _TPU_AVAILABLE, + AMPType, device_parser, GradClipAlgorithmType, parsing, @@ -1646,7 +1655,7 @@ def lightning_module(self) -> "pl.LightningModule": @property def optimizers(self) -> List[Optimizer]: - return self.accelerator.optimizers + return self.training_type_plugin.optimizers @optimizers.setter def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None: @@ -1655,35 +1664,39 @@ def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None: # the `lightning_optimizers` trainer property self._lightning_optimizers = None - self.accelerator.optimizers = new_optims + self.training_type_plugin.optimizers = new_optims @property def lr_schedulers(self) -> List[LRSchedulerTypeUnion]: - return self.accelerator.lr_schedulers + return self.training_type_plugin.lr_schedulers @lr_schedulers.setter def lr_schedulers(self, new_schedulers: List[LRSchedulerTypeUnion]) -> None: - self.accelerator.lr_schedulers = new_schedulers + self.training_type_plugin.lr_schedulers = new_schedulers @property def optimizer_frequencies(self) -> list: - return self.accelerator.optimizer_frequencies + return self.training_type_plugin.optimizer_frequencies @optimizer_frequencies.setter def optimizer_frequencies(self, new_freqs: list) -> None: - self.accelerator.optimizer_frequencies = new_freqs + self.training_type_plugin.optimizer_frequencies = new_freqs @property - def amp_backend(self) -> Optional[str]: - return self.accelerator.amp_backend + def amp_backend(self) -> Optional[AMPType]: + if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin): + return AMPType.APEX + if isinstance(self.precision_plugin, NativeMixedPrecisionPlugin): + return AMPType.NATIVE + return None @property def precision(self) -> Union[str, int]: return self.training_type_plugin.precision_plugin.precision @property - def scaler(self): - return self.accelerator.scaler + def scaler(self) -> Optional[Any]: + return getattr(self.precision_plugin, "scaler", None) @property def gpus(self) -> Optional[Union[List[int], str, int]]: diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 584e24bb71ed9..02e9aabd0fce1 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -21,9 +21,9 @@ from torch.utils.data import DataLoader from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.callbacks import StochasticWeightAveraging from pytorch_lightning.plugins import DDPSpawnPlugin +from pytorch_lightning.plugins.training_type import TrainingTypePlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException from 
tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset from tests.helpers.runif import RunIf @@ -101,7 +101,7 @@ def on_train_end(self, trainer, pl_module): if not isinstance(trainer.training_type_plugin, DDPSpawnPlugin): # check backward call count. the batchnorm update epoch should not backward - assert trainer.accelerator.backward.call_count == trainer.max_epochs * trainer.limit_train_batches + assert trainer.training_type_plugin.backward.call_count == trainer.max_epochs * trainer.limit_train_batches # check call counts assert self.update_parameters_calls == trainer.max_epochs - (self._swa_epoch_start - 1) @@ -131,7 +131,7 @@ def train_with_swa( num_processes=num_processes, ) - with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward): + with mock.patch.object(TrainingTypePlugin, "backward", wraps=trainer.training_type_plugin.backward): trainer.fit(model) # check the model is the expected diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py index a732390e1d00a..ff9b9e2ddb8ce 100644 --- a/tests/lite/test_wrappers.py +++ b/tests/lite/test_wrappers.py @@ -17,6 +17,7 @@ import torch from torch.utils.data.dataloader import DataLoader +from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin from pytorch_lightning.lite import LightningLite from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer @@ -154,8 +155,10 @@ def test_lite_optimizer_state_dict(): def test_lite_optimizer_steps(): """Test that the LiteOptimizer forwards the step() and zero_grad() calls to the wrapped optimizer.""" optimizer = Mock() - accelerator = Mock() + strategy = Mock() + accelerator = Accelerator(None, strategy) lite_optimizer = _LiteOptimizer(optimizer=optimizer, accelerator=accelerator) lite_optimizer.step() - accelerator.optimizer_step.assert_called_once() - accelerator.optimizer_step.assert_called_with(optimizer, opt_idx=0, closure=ANY, model=accelerator.model) + strategy = accelerator.training_type_plugin + strategy.optimizer_step.assert_called_once() + strategy.optimizer_step.assert_called_with(optimizer, opt_idx=0, closure=ANY, model=accelerator.model) diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index dbbb4d9bdffa7..9620a1b608421 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -22,7 +22,7 @@ import torch.nn.functional as F from pytorch_lightning import seed_everything, Trainer -from pytorch_lightning.accelerators import Accelerator +from pytorch_lightning.plugins.training_type import TrainingTypePlugin from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -128,7 +128,7 @@ def on_train_end(self): ) scaler_step = scaler_step_patch.start() - with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock: + with mock.patch.object(TrainingTypePlugin, "backward", wraps=trainer.training_type_plugin.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 3 @@ -162,7 +162,7 @@ def training_epoch_end(self, outputs) -> None: enable_model_summary=False, ) - with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock: + with mock.patch.object(TrainingTypePlugin, "backward", wraps=trainer.training_type_plugin.backward) as bwd_mock: 
trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 3 @@ -189,7 +189,7 @@ def training_epoch_end(self, outputs) -> None: enable_model_summary=False, ) - with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock: + with mock.patch.object(TrainingTypePlugin, "backward", wraps=trainer.training_type_plugin.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 3 assert set(trainer.logged_metrics) == {"a_step", "a_epoch"} @@ -212,7 +212,7 @@ def test_multiple_optimizers_manual_native_amp(tmpdir): gpus=1, ) - with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock: + with mock.patch.object(TrainingTypePlugin, "backward", wraps=trainer.training_type_plugin.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 3 @@ -470,7 +470,7 @@ def log_grad_norm(self, grad_norm_dict): track_grad_norm=2, ) - with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock: + with mock.patch.object(TrainingTypePlugin, "backward", wraps=trainer.training_type_plugin.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 3 @@ -540,7 +540,7 @@ def configure_optimizers(self): log_every_n_steps=1, ) - with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock: + with mock.patch.object(TrainingTypePlugin, "backward", wraps=trainer.training_type_plugin.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 2 assert trainer.progress_bar_metrics["train_loss_step"] == model._losses[-1] @@ -596,7 +596,7 @@ def configure_optimizers(self): log_every_n_steps=1, ) - with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock: + with mock.patch.object(TrainingTypePlugin, "backward", wraps=trainer.training_type_plugin.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 2
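Note: with the `trainer.py` and test hunks above, the `Trainer` properties (`optimizers`, `lr_schedulers`, `optimizer_frequencies`) proxy to the training type plugin, `amp_backend` is derived from the precision plugin (returning an `AMPType` instead of a string), `scaler` is read off the precision plugin, and the tests patch `TrainingTypePlugin.backward` instead of `Accelerator.backward`. A condensed sketch of the updated mocking pattern, assuming the repo's `BoringModel` test helper:

```python
from unittest import mock

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
from tests.helpers.boring_model import BoringModel

model = BoringModel()
trainer = Trainer(fast_dev_run=True)

# Patch the new owner of `backward`; wrapping keeps the real implementation running.
with mock.patch.object(TrainingTypePlugin, "backward", wraps=trainer.training_type_plugin.backward) as bwd_mock:
    trainer.fit(model)

# one training batch under `fast_dev_run` with a single optimizer -> exactly one backward call
assert bwd_mock.call_count == 1
```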