Explicitly enable grad in closure (#18268)
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
(cherry picked from commit b88b8b3)
0x404 authored and lexierule committed Aug 14, 2023
1 parent cfefd09 commit 28e0bdd
Showing 3 changed files with 32 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue that would prevent the user to set the multiprocessing start method after importing lightning ([#18177](https://github.com/Lightning-AI/lightning/pull/18177))


- Ensure that the closure running inside the optimizer step has gradients enabled, even if the optimizer step has it disabled ([#18268](https://github.com/Lightning-AI/lightning/pull/18268))


## [2.0.5] - 2023-07-07

### Fixed
1 change: 1 addition & 0 deletions src/lightning/pytorch/loops/optimization/automatic.py
@@ -123,6 +123,7 @@ def __init__(
self._backward_fn = backward_fn
self._zero_grad_fn = zero_grad_fn

@torch.enable_grad()
def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
step_output = self._step_fn()

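
For context (not part of the diff): the one-line decorator above works because `torch.enable_grad()` takes precedence over an enclosing `torch.no_grad()` context. A minimal standalone sketch, with a hypothetical `closure` function standing in for Lightning's `Closure.closure`:

import torch

# Sketch only: torch.enable_grad() re-enables gradient tracking inside an
# outer torch.no_grad() block, mirroring the decorator added to the closure.
@torch.enable_grad()
def closure(param):
    loss = (param * 2).sum()  # graph is built even though the caller disabled grad
    loss.backward()
    return loss

param = torch.ones(3, requires_grad=True)
with torch.no_grad():  # emulates an optimizer step decorated with @torch.no_grad()
    closure(param)

assert param.grad is not None
print(param.grad)  # tensor([2., 2., 2.])
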
28 changes: 28 additions & 0 deletions tests/tests_pytorch/loops/optimization/test_closure.py
@@ -43,3 +43,31 @@ def step(self, closure=None):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
with pytest.raises(MisconfigurationException, match="The closure hasn't been executed"):
trainer.fit(model)


def test_closure_with_no_grad_optimizer(tmpdir):
"""Test that the closure is guaranteed to run with grad enabled.
There are certain third-party library optimizers
(such as Hugging Face Transformers' AdamW) that set `no_grad` during the `step` operation.
"""

class NoGradAdamW(torch.optim.AdamW):
@torch.no_grad()
def step(self, closure):
if closure is not None:
closure()
return super().step()

class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
assert torch.is_grad_enabled()
return super().training_step(batch, batch_idx)

def configure_optimizers(self):
return NoGradAdamW(self.parameters(), lr=0.1)

trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
model = TestModel()
trainer.fit(model)
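
A hedged illustration of the failure mode this test guards against (not part of the commit): if the closure's forward and backward run under the optimizer's `torch.no_grad()` without gradients being re-enabled, the loss has no `grad_fn` and `backward()` raises a RuntimeError.

import torch

def naive_closure(param):
    loss = (param * 2).sum()  # built with grad disabled, so it has no grad_fn
    loss.backward()           # raises RuntimeError: element 0 of tensors does not require grad
    return loss

param = torch.ones(3, requires_grad=True)
with torch.no_grad():
    try:
        naive_closure(param)
    except RuntimeError as err:
        print(f"closure failed without enable_grad: {err}")
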
