diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index ab66435a2935d..34072c5e43a61 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -14,15 +14,15 @@ """nn.Module with additional great features.""" +from abc import ABC +from argparse import Namespace import collections import copy import inspect import os +from pathlib import Path import re import tempfile -from abc import ABC -from argparse import Namespace -from pathlib import Path from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import torch @@ -35,9 +35,9 @@ from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO +from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES from pytorch_lightning.core.step_result import Result -from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_warn +from pytorch_lightning.utilities import rank_zero_warn, TPU_AVAILABLE from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args @@ -1252,9 +1252,6 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer.zero_grad() """ - if not isinstance(optimizer, LightningOptimizer): - # wraps into LightingOptimizer only for running step - optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer) optimizer.step(closure=optimizer_closure) def optimizer_zero_grad( diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index c8e9ff8b80a2f..f0b361de6133e 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -103,6 +103,8 @@ def _on_trainer_init(self, trainer): @classmethod def to_lightning_optimizer(cls, optimizer, trainer): + if isinstance(optimizer, LightningOptimizer): + return optimizer optimizer = cls(optimizer) optimizer._on_trainer_init(trainer) return optimizer diff --git a/pytorch_lightning/plugins/native_amp.py b/pytorch_lightning/plugins/native_amp.py index 4df5d128476a4..3d64fe91388b8 100644 --- a/pytorch_lightning/plugins/native_amp.py +++ b/pytorch_lightning/plugins/native_amp.py @@ -16,6 +16,7 @@ import torch from torch.optim import Optimizer +from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.plugins.precision_plugin import PrecisionPlugin @@ -52,7 +53,10 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs): # unscale gradient to allow analyze within `on_after_backward` if not self.trainer.train_loop.should_accumulate() and automatic_optimization: - self.trainer.scaler.unscale_(optimizer) + if isinstance(optimizer, LightningOptimizer): + self.trainer.scaler.unscale_(optimizer._optimizer) + else: + self.trainer.scaler.unscale_(optimizer) return closure_loss diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 5a837956bc4ce..c66cc3a43d0b1 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -15,8 +15,8 @@ """Trainer to automate the training.""" import os -import warnings from typing import Dict, Iterable, List, Optional, 
Union +import warnings import torch from torch.utils.data import DataLoader @@ -24,7 +24,6 @@ from pytorch_lightning import _logger as log from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector -from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule @@ -47,6 +46,7 @@ from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin +from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin @@ -56,7 +56,7 @@ from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.tuning import Tuner -from pytorch_lightning.utilities import rank_zero_warn, DeviceType +from pytorch_lightning.utilities import DeviceType, rank_zero_warn from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 68a0f4781c9a9..fe4525006ebb9 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -26,7 +26,7 @@ from pytorch_lightning.core.step_result import EvalResult, Result from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.supporters import Accumulator, TensorRunningAccum -from pytorch_lightning.utilities import TPU_AVAILABLE, AMPType, parsing +from pytorch_lightning.utilities import AMPType, parsing, TPU_AVAILABLE from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach @@ -489,6 +489,9 @@ def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_ 'native PyTorch amp and lbfgs are not compatible.' 
' To request, please file a Github issue in PyTorch and tag @mcarilli') + + # wraps into LightningOptimizer only for running step + optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer) + + # model hook model_ref.optimizer_step( self.trainer.current_epoch, @@ -831,6 +834,8 @@ def backward(self, result, optimizer, opt_idx, *args, **kwargs): # backward can be called manually in the training loop if isinstance(result, torch.Tensor): + # scale loss under accumulate_grad_batches > 1 and manual_backward + result = self.scale_closure_loss(result) self.trainer.accelerator_backend.backward(result, optimizer, opt_idx, *args, **kwargs) else: result.closure_loss = self.trainer.accelerator_backend.backward( @@ -975,3 +980,9 @@ def update_running_loss(self): # reset for next set of accumulated grads self.accumulated_loss.reset() + + def scale_closure_loss(self, loss: torch.Tensor) -> torch.Tensor: + model_ref = self.trainer.get_model() + if model_ref._running_manual_backward: + loss /= self.trainer.accumulate_grad_batches + return loss diff --git a/tests/base/boring_model.py b/tests/base/boring_model.py index 6ceffe8562372..6fdc3794d05f6 100644 --- a/tests/base/boring_model.py +++ b/tests/base/boring_model.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch -from pytorch_lightning import LightningModule from torch.utils.data import Dataset +from pytorch_lightning import LightningModule + class RandomDictDataset(Dataset): def __init__(self, size, length): diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py index e3a597063d02e..01319365d9051 100644 --- a/tests/core/test_lightning_module.py +++ b/tests/core/test_lightning_module.py @@ -11,17 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import pickle from argparse import ArgumentParser +import pickle from typing import Optional from unittest.mock import MagicMock, patch import pytest import torch -from torch.optim import SGD, Adam +from torch.optim import Adam, SGD from torch.utils.data import DataLoader, random_split -from pytorch_lightning import LightningDataModule, Trainer, seed_everything +from pytorch_lightning import LightningDataModule, seed_everything, Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import BoringModel @@ -75,16 +75,12 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, if batch_idx % 2 == 0: assert isinstance(optimizer, SGD) optimizer.step(closure=optimizer_closure) - if not enable_pl_optimizer: - optimizer.zero_grad() # update discriminator opt every 4 steps if optimizer_idx == 1: if batch_idx % 4 == 0: assert isinstance(optimizer, Adam) optimizer.step(closure=optimizer_closure) - if not enable_pl_optimizer: - optimizer.zero_grad() model = TestModel() model.training_epoch_end = None diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py index a9fcf918cc699..530f20f86a3db 100644 --- a/tests/core/test_lightning_optimizer.py +++ b/tests/core/test_lightning_optimizer.py @@ -14,16 +14,18 @@ import os from unittest.mock import patch +import numpy as np import pytest import torch import torch.nn as nn from torch.optim import Adam, Optimizer import pytorch_lightning as pl -from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning import LightningModule, seed_everything, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.model_utils import is_overridden from tests.base.boring_model import BoringModel, RandomDataset, RandomDictDataset, RandomDictStringDataset diff --git a/tests/trainer/optimization/test_parity_automatic_optimization.py b/tests/trainer/optimization/test_parity_automatic_optimization.py new file mode 100644 index 0000000000000..4a1d6c384cd52 --- /dev/null +++ b/tests/trainer/optimization/test_parity_automatic_optimization.py @@ -0,0 +1,371 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from collections.abc import Callable +from copy import deepcopy +from typing import Optional +from unittest.mock import patch + +import numpy as np +import pytest +import torch +from torch.optim import Optimizer + +from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning.core.optimizer import LightningOptimizer +from tests.base.boring_model import BoringModel + +# TODO: +# For both automatic / manual optimization +# - Test dp, ddp, ddp2 +# - Apex +# - Random accumulated_grad_batches (bug) +# - Multiple optimizers + + +class BaseParityAutomaticOptimizationModel(BoringModel): + + def __init__(self, optimizer_cls, optimizer_is_mocked=False, accumulate_grad_batches=None): + super().__init__() + self.optimizer_cls = optimizer_cls + self.losses = [] + self.grads = [] + self.on_before_zero_grad_count = 0 + self.optimizer_is_mocked = optimizer_is_mocked + self.grad_checked = False + self.accumulate_grad_batches = accumulate_grad_batches + + def on_before_zero_grad(self, optimizer): + self.on_before_zero_grad_count += 1 + if self.layer.weight.grad is not None: + self.grads.append(self.layer.weight.grad.clone()) + + def configure_optimizers(self): + optimizer = self.optimizer_cls(self.layer.parameters(), lr=0.1) + assert isinstance(optimizer, Optimizer) + return optimizer + + def training_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.losses.append(loss.detach().item()) + return {"loss": loss} + + +class AutomaticOptimizationPurePytorchOptimizerModel(BaseParityAutomaticOptimizationModel): + + def training_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.losses.append(loss.detach().item()) + loss /= float(self.accumulate_grad_batches) + return {"loss": loss} + + def optimizer_step( + self, + epoch: int = None, + batch_idx: int = None, + optimizer: Optimizer = None, + optimizer_idx: int = None, + optimizer_closure: Optional[Callable] = None, + on_tpu: bool = None, + using_native_amp: bool = None, + using_lbfgs: bool = None, + ) -> None: + """ + Override the optimizer step to define manual optimizer steps, as we use LightningOptimizer wrapper as standard + """ + # Get the unwrapped optimizer + optimizer = optimizer._optimizer + assert not isinstance(optimizer, LightningOptimizer) + + optimizer_closure() + assert self.trainer.accumulate_grad_batches == 1 + + if should_accumulate(self.trainer, self.accumulate_grad_batches): + return + + self.grad_checked = True + assert torch.abs(self.layer.weight.grad).sum() > 0 + optimizer.step() + + self.on_before_zero_grad_count += 1 + optimizer.zero_grad() + + if not self.optimizer_is_mocked: + assert torch.abs(self.layer.weight.grad).sum() == 0 + + +class AutomaticOptimizationPurePytorchAMPOptimizerModel(BaseParityAutomaticOptimizationModel): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.scaler = torch.cuda.amp.GradScaler() + + def training_step(self, batch, batch_idx): + with torch.cuda.amp.autocast(): + output = self.layer(batch) + loss = self.loss(batch, output) + self.losses.append(loss.detach().item()) + loss /= float(self.accumulate_grad_batches) + loss = self.scaler.scale(loss) + return {"loss": loss} + + def optimizer_step( + self, + epoch: int = None, + batch_idx: int = None, + optimizer: Optimizer = None, + optimizer_idx: int = None, + optimizer_closure: Optional[Callable] = None, + on_tpu: bool = None, + using_native_amp: bool = None, + using_lbfgs: bool = None, + ) -> None: + """ + Override the 
optimizer step to define manual optimizer steps, as we use LightningOptimizer wrapper as standard + """ + # Get the unwrapped optimizer + optimizer = optimizer._optimizer + assert not isinstance(optimizer, LightningOptimizer) + + optimizer_closure() + assert self.trainer.accumulate_grad_batches == 1 + + if should_accumulate(self.trainer, self.accumulate_grad_batches): + return + + self.scaler.unscale_(optimizer) + self.grad_checked = True + assert torch.abs(self.layer.weight.grad).sum() > 0 + self.scaler.step(optimizer) + self.scaler.update() + self.on_before_zero_grad_count += 1 + optimizer.zero_grad() + if not self.optimizer_is_mocked: + assert torch.abs(self.layer.weight.grad).sum() == 0 + + +def should_accumulate(trainer, accumulate_grad_batches): + accumulation_done = (trainer.batch_idx + 1) == trainer.num_training_batches + is_final_batch = (trainer.batch_idx + 1) % accumulate_grad_batches == 0 + return not (accumulation_done or is_final_batch) + + +@pytest.mark.parametrize(["precision", "amp_backend", "gpus"], [ + pytest.param(32, "native", 0), + pytest.param(16, "native", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason='Requires GPU')), +]) +@pytest.mark.parametrize('accumulate_grad_batches', [1, 7]) +def test_lightning_optimizer_and_no_lightning_optimizer_equality( + tmpdir, + precision, + amp_backend, + gpus, + accumulate_grad_batches, +): + + if accumulate_grad_batches > 1: + accumulate_grad_batches = np.random.randint(1, accumulate_grad_batches) + + vanilla_model_cls = AutomaticOptimizationPurePytorchAMPOptimizerModel if precision == 16 \ + else AutomaticOptimizationPurePytorchOptimizerModel + + run_lightning_optimizer_equality( + BaseParityAutomaticOptimizationModel, + vanilla_model_cls, + precision=precision, + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=5, + accumulate_grad_batches=accumulate_grad_batches, + amp_backend=amp_backend, + gpus=gpus + ) + + +@pytest.mark.parametrize(["precision", "amp_backend", "gpus"], [ + pytest.param(32, "native", 0), +]) +@pytest.mark.parametrize('accumulate_grad_batches', [1]) +def test_lightning_optimizer_and_no_lightning_optimizer_equality_check_optim_calls( + tmpdir, + precision, + amp_backend, + gpus, + accumulate_grad_batches, +): + + vanilla_model_cls = AutomaticOptimizationPurePytorchAMPOptimizerModel if precision == 16 \ + else AutomaticOptimizationPurePytorchOptimizerModel + + with patch("torch.optim.SGD.step") as mock_sgd_step, \ + patch("torch.optim.Adam.step") as mock_adam_step, \ + patch("torch.optim.AdamW.step") as mock_adamw_step, \ + patch("torch.optim.SGD.zero_grad") as mock_sgd_zero_grad, \ + patch("torch.optim.Adam.zero_grad") as mock_adam_zero_grad, \ + patch("torch.optim.AdamW.zero_grad") as mock_adamw_zero_grad: + + max_epochs = 2 + limit_train_batches = 10 + + # Run equality test using Lightning Optimizer + run_lightning_optimizer_equality( + BaseParityAutomaticOptimizationModel, + vanilla_model_cls, + default_root_dir=tmpdir, + optimizer_is_mocked=True, + accumulate_grad_batches=accumulate_grad_batches, + max_epochs=max_epochs, + limit_train_batches=limit_train_batches, + amp_backend=amp_backend, + precision=precision, + gpus=gpus + ) + + expected_num_batches = max_epochs * limit_train_batches + assert mock_sgd_step.call_count == (expected_num_batches // accumulate_grad_batches) + assert mock_sgd_zero_grad.call_count == (expected_num_batches // accumulate_grad_batches) + assert mock_sgd_step.call_count == mock_adam_step.call_count + assert mock_sgd_step.call_count == 
mock_adamw_step.call_count + assert mock_sgd_zero_grad.call_count == mock_adam_zero_grad.call_count + assert mock_sgd_zero_grad.call_count == mock_adamw_zero_grad.call_count + + +def run_lightning_optimizer_equality( + lightning_model_cls, + vanilla_model_cls, + optimizer_is_mocked=False, + **trainer_kwargs, +): + + trainer_kwargs = { + "limit_val_batches": 0, + **trainer_kwargs + } + expected_num_batches = trainer_kwargs["max_epochs"] * trainer_kwargs["limit_train_batches"] + accumulate_grad_batches = trainer_kwargs["accumulate_grad_batches"] + + pl_optimizer_initial_model_weights, pl_optimizer_model = train_specific_optimizer_model( + lightning_model_cls, + torch.optim.SGD, + expected_num_batches=expected_num_batches, + optimizer_is_mocked=optimizer_is_mocked, + enable_pl_optimizer=True, + **trainer_kwargs, + ) + + no_pl_optimizer_initial_model_weights, no_pl_optimizer_model = train_specific_optimizer_model( + lightning_model_cls, + torch.optim.Adam if optimizer_is_mocked else torch.optim.SGD, + expected_num_batches=expected_num_batches, + optimizer_is_mocked=optimizer_is_mocked, + enable_pl_optimizer=False, # Disable pl optimizer + **trainer_kwargs, + ) + + pure_pytorch_optimizer_initial_model_weights, pure_pytorch_optimizer_model = train_specific_optimizer_model( + vanilla_model_cls, + torch.optim.AdamW if optimizer_is_mocked else torch.optim.SGD, + expected_num_batches=expected_num_batches, + optimizer_is_mocked=optimizer_is_mocked, + replace_optimizer_step_with_pure_pytorch=True, + **trainer_kwargs, + ) + + if not optimizer_is_mocked: + + assert_model_equality( + pl_optimizer_initial_model_weights=pl_optimizer_initial_model_weights, + pl_optimizer_model=pl_optimizer_model, + no_pl_optimizer_initial_model_weights=no_pl_optimizer_initial_model_weights, + no_pl_optimizer_model=no_pl_optimizer_model, + pure_pytorch_optimizer_initial_model_weights=pure_pytorch_optimizer_initial_model_weights, + pure_pytorch_optimizer_model=pure_pytorch_optimizer_model, + expected_num_batches=expected_num_batches, + precision=trainer_kwargs["precision"] + ) + + +def assert_model_equality( + pl_optimizer_initial_model_weights, + pl_optimizer_model, + no_pl_optimizer_initial_model_weights, + no_pl_optimizer_model, + pure_pytorch_optimizer_initial_model_weights, + pure_pytorch_optimizer_model, + expected_num_batches, + precision, +): + + assert torch.equal(pl_optimizer_initial_model_weights, no_pl_optimizer_initial_model_weights) + assert torch.equal(pl_optimizer_initial_model_weights, pure_pytorch_optimizer_initial_model_weights) + assert len(pl_optimizer_model.losses) == expected_num_batches + assert pure_pytorch_optimizer_model.grad_checked + assert pure_pytorch_optimizer_model.losses == no_pl_optimizer_model.losses + assert not torch.isnan(torch.FloatTensor(no_pl_optimizer_model.losses)).any() + + assert torch.equal(torch.FloatTensor(no_pl_optimizer_model.losses), torch.FloatTensor(pl_optimizer_model.losses)) + assert no_pl_optimizer_model.on_before_zero_grad_count == pl_optimizer_model.on_before_zero_grad_count + + for pytorch_grad, no_pl_optim_grad, pl_optim_grad in zip(pure_pytorch_optimizer_model.grads, + no_pl_optimizer_model.grads, + pl_optimizer_model.grads): + assert torch.equal(no_pl_optim_grad, pl_optim_grad), 'Grad parameters are different' + assert torch.equal(pytorch_grad, no_pl_optim_grad), 'Grad parameters are different' + + for pytorch_weight, no_pl_optim_weight, pl_optim_weight in zip(pure_pytorch_optimizer_model.parameters(), + no_pl_optimizer_model.parameters(), 
pl_optimizer_model.parameters()): + assert torch.equal(no_pl_optim_weight, pl_optim_weight), 'Model parameters are different' + assert torch.equal(pytorch_weight, no_pl_optim_weight), 'Model parameters are different' + + +# train function +def train_specific_optimizer_model( + model_cls, + optimizer_cls, + expected_num_batches, + enable_pl_optimizer=False, + optimizer_is_mocked=False, + replace_optimizer_step_with_pure_pytorch=False, + **trainer_kwargs, +): + + seed_everything(42) + trainer_kwargs = deepcopy(trainer_kwargs) + + model = model_cls( + optimizer_cls=optimizer_cls, + optimizer_is_mocked=optimizer_is_mocked, + accumulate_grad_batches=trainer_kwargs["accumulate_grad_batches"], + ) + + if replace_optimizer_step_with_pure_pytorch: + # When running pure vanilla training, accumulate_grad_batches should be 1. + trainer_kwargs["accumulate_grad_batches"] = 1 + trainer_kwargs["precision"] = 32 + + expected_global_step = expected_num_batches // trainer_kwargs["accumulate_grad_batches"] + + initial_weights = model.layer.weight.clone() + model.training_epoch_end = None + + trainer = Trainer( + enable_pl_optimizer=enable_pl_optimizer, + **trainer_kwargs + ) + trainer.fit(model) + + assert np.abs(trainer.global_step - expected_global_step) <= 2 + return initial_weights, model diff --git a/tests/trainer/optimization/test_parity_manual_optimization.py b/tests/trainer/optimization/test_parity_manual_optimization.py new file mode 100644 index 0000000000000..5d110b2fbdca7 --- /dev/null +++ b/tests/trainer/optimization/test_parity_manual_optimization.py @@ -0,0 +1,211 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from collections.abc import Callable +from copy import deepcopy +from typing import Optional +from unittest.mock import patch + +import numpy as np +import pytest +import torch +from torch.optim import Optimizer + +from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning.core.optimizer import LightningOptimizer +from tests.base.boring_model import BoringModel +from tests.trainer.optimization.test_parity_automatic_optimization import ( + assert_model_equality, + run_lightning_optimizer_equality, + should_accumulate, +) + +""" +TODO: +For both automatic / manual optimization + - Test dp, ddp, ddp2 + - Apex + - Random accumulated_grad_batches (bug) + - Multiple optimizers +""" + + +class BaseParityManualOptimizationModel(BoringModel): + + def __init__(self, optimizer_cls, optimizer_is_mocked=False, accumulate_grad_batches=None): + super().__init__() + self.optimizer_cls = optimizer_cls + self.losses = [] + self.grads = [] + self.on_before_zero_grad_count = 0 + self.optimizer_is_mocked = optimizer_is_mocked + self.grad_checked = False + self.accumulate_grad_batches = accumulate_grad_batches + + def on_before_zero_grad(self, optimizer): + self.on_before_zero_grad_count += 1 + if self.layer.weight.grad is not None: + self.grads.append(self.layer.weight.grad.clone()) + + def configure_optimizers(self): + optimizer = self.optimizer_cls(self.layer.parameters(), lr=0.1) + assert isinstance(optimizer, Optimizer) + return optimizer + + def training_step(self, batch, batch_idx): + opt = self.optimizers() + if not isinstance(opt, LightningOptimizer): + opt = LightningOptimizer.to_lightning_optimizer(opt, self.trainer) + output = self.layer(batch) + loss = self.loss(batch, output) + self.losses.append(loss.detach().item()) + self.manual_backward(loss, opt) + opt.step() + + +class ManualOptimizationPurePytorchOptimizerModel(BaseParityManualOptimizationModel): + + def training_step(self, batch, batch_idx): + optimizer = self.optimizers() + output = self.layer(batch) + loss = self.loss(batch, output) + self.losses.append(loss.detach().item()) + loss /= float(self.accumulate_grad_batches) + loss.backward() + + if should_accumulate(self.trainer, self.accumulate_grad_batches): + return + + self.grad_checked = True + assert torch.abs(self.layer.weight.grad).sum() > 0 + optimizer.step() + + self.on_before_zero_grad_count += 1 + optimizer.zero_grad() + + if not self.optimizer_is_mocked: + assert torch.abs(self.layer.weight.grad).sum() == 0 + + +class ManualOptimizationPurePytorchAMPOptimizerModel(BaseParityManualOptimizationModel): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.scaler = torch.cuda.amp.GradScaler() + + def training_step(self, batch, batch_idx): + optimizer = self.optimizers() + with torch.cuda.amp.autocast(): + output = self.layer(batch) + loss = self.loss(batch, output) + self.losses.append(loss.detach().item()) + loss /= float(self.accumulate_grad_batches) + loss = self.scaler.scale(loss) + loss.backward() + + if should_accumulate(self.trainer, self.accumulate_grad_batches): + return + + self.scaler.unscale_(optimizer) + self.grad_checked = True + + assert torch.abs(self.layer.weight.grad).sum() > 0 + self.scaler.step(optimizer) + self.scaler.update() + self.on_before_zero_grad_count += 1 + optimizer.zero_grad() + + if not self.optimizer_is_mocked: + assert torch.abs(self.layer.weight.grad).sum() == 0 + + +@pytest.mark.parametrize(["precision", "amp_backend", "gpus"], [ + pytest.param(32, "native", 0), + pytest.param(16, "native", 1, 
marks=pytest.mark.skipif(not torch.cuda.is_available(), reason='Requires GPU')), +]) +@pytest.mark.parametrize('accumulate_grad_batches', [1, 7]) +def test_lightning_optimizer_and_no_lightning_optimizer_equality( + tmpdir, + precision, + amp_backend, + gpus, + accumulate_grad_batches): + + if accumulate_grad_batches > 1: + accumulate_grad_batches = np.random.randint(1, accumulate_grad_batches) + + vanilla_model_cls = ManualOptimizationPurePytorchAMPOptimizerModel if precision == 16 \ + else ManualOptimizationPurePytorchOptimizerModel + + run_lightning_optimizer_equality( + BaseParityManualOptimizationModel, + vanilla_model_cls, + precision=precision, + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=5, + accumulate_grad_batches=accumulate_grad_batches, + amp_backend=amp_backend, + gpus=gpus, + automatic_optimization=False + ) + + +@pytest.mark.parametrize(["precision", "amp_backend", "gpus"], [ + pytest.param(32, "native", 0), +]) +@pytest.mark.parametrize('accumulate_grad_batches', [1]) +def test_lightning_optimizer_and_no_lightning_optimizer_equality_check_optim_calls( + tmpdir, + precision, + amp_backend, + gpus, + accumulate_grad_batches, +): + + vanilla_model_cls = ManualOptimizationPurePytorchAMPOptimizerModel if precision == 16 \ + else ManualOptimizationPurePytorchOptimizerModel + + with patch("torch.optim.SGD.step") as mock_sgd_step, \ + patch("torch.optim.Adam.step") as mock_adam_step, \ + patch("torch.optim.AdamW.step") as mock_adamw_step, \ + patch("torch.optim.SGD.zero_grad") as mock_sgd_zero_grad, \ + patch("torch.optim.Adam.zero_grad") as mock_adam_zero_grad, \ + patch("torch.optim.AdamW.zero_grad") as mock_adamw_zero_grad: + + max_epochs = 2 + limit_train_batches = 10 + + # Run equality test using Lightning Optimizer + + run_lightning_optimizer_equality( + BaseParityManualOptimizationModel, + vanilla_model_cls, + default_root_dir=tmpdir, + optimizer_is_mocked=True, + accumulate_grad_batches=accumulate_grad_batches, + max_epochs=max_epochs, + limit_train_batches=limit_train_batches, + amp_backend=amp_backend, + precision=precision, + gpus=gpus, + automatic_optimization=False + ) + + expected_num_batches = max_epochs * limit_train_batches + assert mock_sgd_step.call_count == (expected_num_batches // accumulate_grad_batches) + assert mock_sgd_zero_grad.call_count == (expected_num_batches // accumulate_grad_batches) + assert mock_sgd_step.call_count == mock_adam_step.call_count + assert mock_sgd_step.call_count == mock_adamw_step.call_count + assert mock_sgd_zero_grad.call_count == mock_adam_zero_grad.call_count + assert mock_sgd_zero_grad.call_count == mock_adamw_zero_grad.call_count