mark OptimizerLoop.backward method protected (#9514)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
awaelchli and carmocca authored Sep 15, 2021
1 parent 23450e2 commit 200ed9e
Showing 3 changed files with 7 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -77,6 +77,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Removed `TrainingBatchLoop.backward()`; manual optimization now calls directly into `Accelerator.backward()` and automatic optimization handles backward in new `OptimizerLoop` ([#9265](https://github.com/PyTorchLightning/pytorch-lightning/pull/9265))
 * Extracted `ManualOptimization` logic from `TrainingBatchLoop` into its own separate loop class ([#9266](https://github.com/PyTorchLightning/pytorch-lightning/pull/9266))
 * Added `OutputResult` and `ManualResult` classes ([#9437](https://github.com/PyTorchLightning/pytorch-lightning/pull/9437), [#9424](https://github.com/PyTorchLightning/pytorch-lightning/pull/9424))
+* Marked `OptimizerLoop.backward` as protected ([#9514](https://github.com/PyTorchLightning/pytorch-lightning/pull/9514))


 - Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187))
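For context, the `TrainingBatchLoop.backward()` removal referenced in the changelog is only visible to user code through manual optimization, where `training_step` drives the backward pass itself and Lightning routes it into `Accelerator.backward()`. The module below is a minimal, hypothetical sketch of that pattern and is not part of this commit; `automatic_optimization`, `optimizers()`, and `manual_backward()` are standard LightningModule APIs of this release line.

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class ManualOptModel(pl.LightningModule):
    """Illustrative module using manual optimization (not part of this commit)."""

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # user code owns backward/step
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.layer(batch[0]).sum()
        self.manual_backward(loss)  # dispatched to Accelerator.backward() internally
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
    pl.Trainer(max_epochs=1).fit(ManualOptModel(), data)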
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -221,7 +221,7 @@ def on_run_end(self) -> _OUTPUTS_TYPE:
         outputs, self.outputs = self.outputs, []  # free memory
         return outputs

-    def backward(
+    def _backward(
         self, loss: Tensor, optimizer: torch.optim.Optimizer, opt_idx: int, *args: Any, **kwargs: Any
     ) -> Tensor:
         """Performs the backward step.
@@ -337,7 +337,7 @@ def _make_backward_fn(self, optimizer: Optimizer, opt_idx: int) -> Optional[Call
             return None

         def backward_fn(loss: Tensor) -> Tensor:
-            self.backward(loss, optimizer, opt_idx)
+            self._backward(loss, optimizer, opt_idx)

             # check if model weights are nan
             if self.trainer.terminate_on_nan:
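For orientation, here is a trimmed sketch of the closure pattern the two hunks above touch: `_make_backward_fn` captures the optimizer and its index and routes the loss through the now-protected `_backward`. The `_OptimizerLoopSketch` class is a hypothetical stand-in, not the real `OptimizerLoop`; trainer, accelerator, and NaN-detection plumbing are omitted.

from typing import Any, Callable, Optional

import torch
from torch import Tensor


class _OptimizerLoopSketch:
    """Hypothetical stand-in illustrating how the protected backward hook is wired."""

    def _backward(
        self, loss: Tensor, optimizer: torch.optim.Optimizer, opt_idx: int, *args: Any, **kwargs: Any
    ) -> Tensor:
        # In Lightning this call is delegated to the accelerator / precision plugin.
        loss.backward(*args, **kwargs)
        return loss

    def _make_backward_fn(
        self, optimizer: torch.optim.Optimizer, opt_idx: int
    ) -> Optional[Callable[[Tensor], Tensor]]:
        def backward_fn(loss: Tensor) -> Tensor:
            # The closure goes through the protected method only, so the loop no
            # longer exposes a public ``backward`` for external code to call.
            self._backward(loss, optimizer, opt_idx)
            return loss

        return backward_fn


# Usage: the closure is handed to the training-step wrapper and called with the loss.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
backward_fn = _OptimizerLoopSketch()._make_backward_fn(optimizer, opt_idx=0)
backward_fn(model(torch.randn(2, 4)).sum())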
8 changes: 4 additions & 4 deletions tests/trainer/test_trainer.py
@@ -961,7 +961,7 @@ def test_gradient_clipping_by_norm(tmpdir, precision):
         gradient_clip_val=1.0,
     )

-    old_backward = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.backward
+    old_backward = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._backward

     def backward(*args, **kwargs):
         # test that gradient is clipped correctly
@@ -971,7 +971,7 @@ def backward(*args, **kwargs):
         assert (grad_norm - 1.0).abs() < 0.01, f"Gradient norm != 1.0: {grad_norm}"
         return ret_val

-    trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.backward = backward
+    trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._backward = backward
     trainer.fit(model)


@@ -996,7 +996,7 @@ def test_gradient_clipping_by_value(tmpdir, precision):
         default_root_dir=tmpdir,
     )

-    old_backward = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.backward
+    old_backward = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._backward

     def backward(*args, **kwargs):
         # test that gradient is clipped correctly
@@ -1009,7 +1009,7 @@ def backward(*args, **kwargs):
         ), f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val} ."
         return ret_val

-    trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.backward = backward
+    trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._backward = backward
     trainer.fit(model)


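Both tests keep the same structure after the rename: they grab the loop's backward hook, wrap it, and assert on the clipped gradients once the real backward has run. Below is a minimal, self-contained sketch of that wrap-and-patch pattern; `TinyLoop` and `total_grad_norm` are hypothetical stand-ins for the Lightning objects used in the real tests.

import torch
from torch.nn.utils import clip_grad_norm_


class TinyLoop:
    """Hypothetical stand-in for OptimizerLoop, exposing a protected backward hook."""

    def __init__(self, model: torch.nn.Module, clip_val: float):
        self.model = model
        self.clip_val = clip_val

    def _backward(self, loss: torch.Tensor) -> torch.Tensor:
        loss.backward()
        clip_grad_norm_(self.model.parameters(), self.clip_val)
        return loss


def total_grad_norm(parameters) -> float:
    """L2 norm over all parameter gradients."""
    grads = [p.grad.detach().flatten() for p in parameters if p.grad is not None]
    return torch.cat(grads).norm(2).item()


model = torch.nn.Linear(8, 1)
loop = TinyLoop(model, clip_val=1.0)

old_backward = loop._backward  # keep a handle on the original hook


def backward(*args, **kwargs):
    # Mirror of the test's wrapper: run the real backward, then inspect gradients.
    ret_val = old_backward(*args, **kwargs)
    grad_norm = total_grad_norm(model.parameters())
    assert grad_norm <= loop.clip_val + 0.01, f"Gradient norm above clip value: {grad_norm}"
    return ret_val


loop._backward = backward  # patch the protected hook, as the updated tests now do
loop._backward(model(torch.randn(32, 8)).pow(2).sum())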
