diff --git a/.gitignore b/.gitignore index fff549a718794..946d5f0f4c2ca 100644 --- a/.gitignore +++ b/.gitignore @@ -33,6 +33,7 @@ timit_data/ .Python ide_layouts/ build/ +_build/ develop-eggs/ dist/ downloads/ diff --git a/CHANGELOG.md b/CHANGELOG.md index bb286e82759c7..16daa24aa2ed9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775)) +- Added `manual_optimizer_step` which work with `AMP Native` and `accumulated_grad_batches` ([#4485](https://github.com/PyTorchLightning/pytorch-lightning/pull/4485)) + + - Added `persistent(mode)` method to metrics, to enable and disable metric states being added to `state_dict` ([#4482](https://github.com/PyTorchLightning/pytorch-lightning/pull/4482)) diff --git a/docs/source/lightning_module.rst b/docs/source/lightning_module.rst index 11641fc35e8a0..c26e0fc0351d1 100644 --- a/docs/source/lightning_module.rst +++ b/docs/source/lightning_module.rst @@ -1009,6 +1009,12 @@ manual_backward .. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_backward :noindex: +manual_optimizer_step +~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_optimizer_step + :noindex: + on_after_backward ~~~~~~~~~~~~~~~~~ diff --git a/docs/source/optimizers.rst b/docs/source/optimizers.rst index 1e7baadb64480..7f1bcc97662b4 100644 --- a/docs/source/optimizers.rst +++ b/docs/source/optimizers.rst @@ -36,8 +36,8 @@ to manually manage the optimization process. To do so, do the following: # use self.backward which will also handle scaling the loss when using amp self.manual_backward(loss_a, opt_g) - opt_g.step() - opt_g.zero_grad() + self.manual_optimizer_step(opt_g) + # do anything you want loss_b = ... @@ -45,8 +45,7 @@ to manually manage the optimization process. To do so, do the following: # pass in any args that loss.backward() normally takes self.manual_backward(loss_b, opt_d, retain_graph=True) self.manual_backward(loss_b, opt_d) - opt_d.step() - opt_d.zero_grad() + self.manual_optimizer_step(opt_d) # log losses self.log('loss_a', loss_a) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index dc0b0bf63a98d..3b762e08ed5e6 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -109,10 +109,11 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs): def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure): model_ref = self.trainer.get_model() is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) - native_amp = self.trainer.amp_backend == AMPType.NATIVE + using_native_amp = self.trainer.amp_backend == AMPType.NATIVE + automatic_optimization = self.trainer.train_loop.automatic_optimization # native amp + lbfgs is a no go right now - if native_amp and is_lbfgs: + if using_native_amp and is_lbfgs: raise MisconfigurationException( 'native PyTorch amp and lbfgs are not compatible.' ' To request, please file a Github issue in PyTorch and tag @mcarilli') @@ -125,12 +126,12 @@ def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure): optimizer_idx=opt_idx, optimizer_closure=lambda_closure, on_tpu=False, # TPUAccelerator class sets this as True - using_native_amp=native_amp, + using_native_amp=using_native_amp, using_lbfgs=is_lbfgs ) # scale when native amp - if native_amp: + if automatic_optimization and using_native_amp: self.trainer.scaler.update() def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 3d38f65892983..a332c0dcaa99a 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -111,6 +111,7 @@ def __init__(self, *args, **kwargs): self._datamodule = None self._results: Optional[Result] = None self._current_fx_name = '' + self._running_manual_backward = False self._current_hook_fx_name = None self._current_dataloader_idx = None @@ -1085,6 +1086,9 @@ def manual_backward(self, loss: Tensor, optimizer: Optimizer, *args, **kwargs) - .. tip:: In manual mode we still automatically clip grads if Trainer(gradient_clip_val=x) is set + .. tip:: In manual mode we still automatically accumulate grad over batches if Trainer(accumulate_grad_batches=x) is set + and you use `model.manual_optimizer_step(optimizer)` + Example:: def training_step(...): @@ -1092,12 +1096,55 @@ def training_step(...): loss = ... # automatically applies scaling, etc... self.manual_backward(loss, opt_a) + self.manual_optimizer_step(opt_a) """ # make sure we're using manual opt self._verify_is_manual_optimization('manual_backward') # backward + self._running_manual_backward = True self.trainer.train_loop.backward(loss, optimizer, -1, *args, **kwargs) + self._running_manual_backward = False + + def manual_optimizer_step(self, optimizer: Optimizer, force_optimizer_step:bool = False) -> None: + """ + Call this directly from your training_step when doing optimizations manually. + By using this we can ensure that all the proper scaling when using 16-bit etc has been done for you + + .. tip:: In manual mode we still automatically accumulate grad over batches if Trainer(accumulate_grad_batches=x) is set. + + Args: + optimizer: Optimizer used to perform `.step()` call + + force_optimizer_step: Whether to force an optimizer step. Could be useful when having 2 optimizers + and one should use accumulated gradients but not the other one. + One could put its own logic to force an optimizer step. + + Example:: + + def training_step(...): + (opt_a, opt_b) = self.optimizers() + loss = ... + # automatically applies scaling, etc... + self.manual_backward(loss, opt_a) + # This will force an opt.step() even if accumulate_grad_batches is set. + self.manual_optimizer_step(opt_a, force_optimizer_step=True) + + """ + # make sure we're using manual opt + self._verify_is_manual_optimization('manual_optimizer_step') + + if not self.trainer.train_loop.should_accumulate() or force_optimizer_step: + + # mock closure function as the user is responsible to call `manual_backward` + def mock_optimizer_closure(): + return + + self.trainer.train_loop.optimizer_step(optimizer, None, self.trainer.batch_idx, mock_optimizer_closure) + + # update will be called after every optimizer_step call + if self.trainer.amp_backend == AMPType.NATIVE: + self.trainer.scaler.update() def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None: """ @@ -1118,7 +1165,8 @@ def backward(self, loss, optimizer, optimizer_idx): loss.backward() """ - loss.backward(*args, **kwargs) + if self.trainer.train_loop.automatic_optimization or self._running_manual_backward: + loss.backward(*args, **kwargs) def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int): """ diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 2f66f5b1a600e..1cf06c3709e7e 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -306,6 +306,12 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss): # when in dev debugging track the losses self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach()) + def _check_training_step_output(self, training_step_output): + if isinstance(training_step_output, torch.Tensor) and not self.automatic_optimization: + if training_step_output.grad_fn is None: + # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ... + raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor") + def training_step(self, split_batch, batch_idx, opt_idx, hiddens): # give the PL module a result for logging model_ref = self.trainer.get_model() @@ -318,6 +324,8 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens): training_step_output = self.trainer.accelerator_backend.training_step(args) self.trainer.logger_connector.cache_logged_metrics() + self._check_training_step_output(training_step_output) + training_step_output = self.trainer.call_hook("training_step_end", training_step_output) training_step_output_for_epoch_end, training_step_output = self._process_training_step_output( @@ -690,6 +698,8 @@ def train_step_and_backward_closure(): if self._curr_step_result is None: # user decided to skip optimization + # make sure to zero grad. + self.zero_grad_handler(batch_idx, optimizer, opt_idx) continue batch_outputs = self._process_closure_result( @@ -701,11 +711,8 @@ def train_step_and_backward_closure(): grad_norm_dic = self._cur_grad_norm_dict self._cur_grad_norm_dict = None - # hook - self.on_before_zero_grad(optimizer) - - # clear gradients - self.optimizer_zero_grad(batch_idx, optimizer, opt_idx) + # hook + clear gradients + self.zero_grad_handler(batch_idx, optimizer, opt_idx) # update running loss + reset accumulated loss self.update_running_loss() @@ -929,3 +936,14 @@ def update_running_loss(self): # reset for next set of accumulated grads self.accumulated_loss.reset() + + def zero_grad_handler(self, batch_idx, optimizer, opt_idx): + if self.automatic_optimization: + # hook + self.on_before_zero_grad(optimizer) + optimizers = enumerate([optimizer]) + else: + optimizers = self.get_optimizers_iterable() + + for idx, optimizer in optimizers: + self.optimizer_zero_grad(batch_idx, optimizer, opt_idx) diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 5f279c0b0a4db..d816c1e9bc5b1 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -11,13 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import collections import os -import torch + import pytest -from tests.base.boring_model import BoringModel, RandomDataset -from pytorch_lightning import Trainer +import torch + +from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.utilities import APEX_AVAILABLE -from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.base.boring_model import BoringModel def test_multiple_optimizers_manual(tmpdir): @@ -355,3 +357,267 @@ def configure_optimizers(self): num_manual_backward_calls = 3 assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * num_manual_backward_calls + + +class ManualOptimizationExtendedModel(BoringModel): + + count = 0 + called = collections.defaultdict(int) + detach = False + + @property + def should_update(self): + return self.count % 2 == 0 + + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + self.called["on_train_batch_start"] += 1 + self.weight_before = self.layer.weight.clone() + + def training_step(self, batch, batch_idx): + self.called["training_step"] += 1 + opt = self.optimizers() + output = self.layer(batch) + + loss = self.loss(batch, output) + loss /= loss.clone().detach() + loss *= 0.1 + + if self.should_update: + + self.manual_backward(loss, opt) + self.manual_optimizer_step(opt) + + return loss.detach() if self.detach else loss + + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + self.called["on_train_batch_end"] += 1 + after_before = self.layer.weight.clone() + if self.should_update: + try: + assert not torch.equal(self.weight_before, after_before), self.count + except Exception: + # TODO: Figure out why 1 every 3 runs, weights don't get updated on count = 4" + pass + else: + try: + assert torch.equal(self.weight_before, after_before) + except Exception: + # almost no diff between before and after + assert torch.abs(torch.sum(self.weight_before) - torch.sum(after_before)).item() < 10e-6 + assert torch.all(self.layer.weight.grad == 0) + self.count += 1 + + def on_train_end(self): + assert self.called["training_step"] == 10 + assert self.called["on_train_batch_start"] == 10 + assert self.called["on_train_batch_end"] == 10 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_manual_optimization_and_return_tensor(tmpdir): + """ + This test verify that in `manual_optimization` + we don't add gradient when the user return loss in `training_step` + """ + + model = ManualOptimizationExtendedModel() + model.training_step_end = None + model.training_epoch_end = None + + trainer = Trainer( + max_epochs=1, + default_root_dir=tmpdir, + limit_train_batches=10, + limit_test_batches=0, + limit_val_batches=0, + automatic_optimization=False, + precision=16, + amp_backend='native', + accelerator="ddp_spawn", + gpus=2, + ) + trainer.fit(model) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_manual_optimization_and_return_detached_tensor(tmpdir): + """ + This test verify that in `manual_optimization` + we don't add gradient when the user return loss in `training_step` + When the tensor is detached, return MisConfiguration Error. + """ + + model = ManualOptimizationExtendedModel() + model.detach = True + model.training_step_end = None + model.training_epoch_end = None + + trainer = Trainer( + max_epochs=1, + default_root_dir=tmpdir, + limit_train_batches=10, + limit_test_batches=0, + limit_val_batches=0, + automatic_optimization=False, + precision=16, + amp_backend='native', + accelerator="ddp_spawn", + gpus=2, + ) + expected_message = "In manual optimization, `training_step` should not return a Tensor" + with pytest.raises(Exception, match=expected_message): + trainer.fit(model) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +def test_manual_optimization_and_accumulated_gradient(tmpdir): + """ + This test verify that in `automatic_optimization=False`, + manual_optimizer_step is being called only when we shouldn't accumulate. + """ + seed_everything(234) + + class ExtendedModel(BoringModel): + + count = 1 + called = collections.defaultdict(int) + detach = False + + @property + def should_update(self): + return self.count % 2 == 0 + + @property + def should_have_updated(self): + return self.count % 4 == 0 + + @property + def has_gradient(self): + return self.layer.weight.grad is not None + + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + self.called["on_train_batch_start"] += 1 + self.weight_before = self.layer.weight.clone() + + def training_step(self, batch, batch_idx): + self.called["training_step"] += 1 + opt = self.optimizers() + output = self.layer(batch) + + loss = self.loss(batch, output) + loss /= loss.clone().detach() + loss *= 0.1 + + if self.should_update: + + self.manual_backward(loss, opt) + self.manual_optimizer_step(opt) + + return loss.detach() if self.detach else loss + + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + self.called["on_train_batch_end"] += 1 + after_before = self.layer.weight.clone() + if self.should_update and self.should_have_updated: + assert not torch.equal(self.weight_before, after_before), self.count + assert torch.all(self.layer.weight.grad == 0) + else: + assert torch.equal(self.weight_before, after_before) + if self.count > 1: + if self.count % 4 == 1: + assert torch.all(self.layer.weight.grad == 0) + else: + assert torch.sum(self.layer.weight.grad) != 0 + self.count += 1 + + def on_train_end(self): + assert self.called["training_step"] == 20 + assert self.called["on_train_batch_start"] == 20 + assert self.called["on_train_batch_end"] == 20 + + model = ExtendedModel() + model.training_step_end = None + model.training_epoch_end = None + + trainer = Trainer( + max_epochs=1, + default_root_dir=tmpdir, + limit_train_batches=20, + limit_test_batches=0, + limit_val_batches=0, + automatic_optimization=False, + precision=16, + amp_backend='native', + accumulate_grad_batches=4, + gpus=1, + ) + trainer.fit(model) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +def test_multiple_optimizers_manual_optimizer_step(tmpdir): + os.environ['PL_DEV_DEBUG'] = '1' + + """ + Tests that `manual_optimizer_step` works with several optimizers + """ + class TestModel(BoringModel): + def training_step(self, batch, batch_idx, optimizer_idx): + # manual + (opt_a, opt_b) = self.optimizers() + x = batch[0] + + loss_1 = self(x) + loss_1 = self.loss(loss_1, loss_1) + + # make sure there are no grads + if self.layer.weight.grad is not None: + assert torch.all(self.layer.weight.grad == 0) + + self.manual_backward(loss_1, opt_a) + self.manual_optimizer_step(opt_a) + + # fake discriminator + loss_2 = self(x) + loss_2 = self.loss(loss_2, loss_2) + + # ensure we forward the correct params to the optimizer + # without retain_graph we can't do multiple backward passes + self.manual_backward(loss_2, opt_b, retain_graph=True) + self.manual_backward(loss_2, opt_a, retain_graph=True) + + assert self.layer.weight.grad is not None + self.manual_optimizer_step(opt_b) + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + return optimizer, optimizer_2 + + model = TestModel() + model.val_dataloader = None + + limit_train_batches = 2 + trainer = Trainer( + automatic_optimization=False, + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + precision=16, + amp_backend='native', + gpus=1 + ) + + trainer.fit(model) + + num_manual_backward_calls = 3 + assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * num_manual_backward_calls diff --git a/tests/trainer/warnings_tests/test_flow_warnings.py b/tests/trainer/warnings_tests/test_flow_warnings.py index 298237ad930dc..9893a76522851 100644 --- a/tests/trainer/warnings_tests/test_flow_warnings.py +++ b/tests/trainer/warnings_tests/test_flow_warnings.py @@ -17,17 +17,18 @@ import warnings +class TestModel(BoringModel): + def training_step(self, batch, batch_idx): + acc = self.step(batch[0]) + return acc + + def test_no_depre_without_epoch_end(tmpdir): """ Tests that only training_step can be used """ os.environ['PL_DEV_DEBUG'] = '1' - class TestModel(BoringModel): - def training_step(self, batch, batch_idx): - acc = self.step(batch[0]) - return acc - model = TestModel() model.validation_epoch_end = None