Skip to content

Commit

Permalink
Remove warning on no_backward_sync with XLA strategy (#17761)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Jun 7, 2023
1 parent 420eb6f commit f3c49b8
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed computing the next version folder in `CSVLogger` ([#17139](https://github.com/Lightning-AI/lightning/pull/17139))


- Removed false positive warning when using `fabric.no_backward_sync` with XLA strategies ([#17761](https://github.com/Lightning-AI/lightning/pull/17761))


## [2.0.1.post0] - 2023-04-11

No changes
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/fabric/fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Gener
"You need to set up the model first before you can call `self.no_backward_sync()`:"
" `model = self.setup(model, ...)`"
)
if not enabled or isinstance(self._strategy, SingleDeviceStrategy):
if not enabled or isinstance(self._strategy, (SingleDeviceStrategy, XLAStrategy)):
context = nullcontext()
elif self._strategy._backward_sync_control is None:
rank_zero_warn(
Expand Down
7 changes: 6 additions & 1 deletion tests/tests_fabric/test_fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,14 +637,19 @@ def test_no_backward_sync():
with fabric.no_backward_sync(model):
pass
fabric._strategy._backward_sync_control.no_backward_sync.assert_not_called()
# same for XLA
fabric._strategy = Mock(spec=XLAStrategy, _backward_sync_control=MagicMock())
with fabric.no_backward_sync(model):
pass
fabric._strategy._backward_sync_control.no_backward_sync.assert_not_called()

# pretend that the strategy supports skipping backward sync
fabric._strategy = Mock(_backward_sync_control=MagicMock())
# disabling the context manager makes it a no-op
with fabric.no_backward_sync(model, enabled=False):
pass
fabric._strategy._backward_sync_control.no_backward_sync.assert_not_called()
# when enabld, the wrapped module gets passed down
# when enabled, the wrapped module gets passed down
with fabric.no_backward_sync(model):
pass
fabric._strategy._backward_sync_control.no_backward_sync.assert_called_once_with(model._forward_module)
Expand Down

0 comments on commit f3c49b8

Please sign in to comment.