# Enable precision autocast for LightningModule step methods in Fabric #17439

Merged 4 commits on Apr 24, 2023
**src/lightning/fabric/CHANGELOG.md** (4 additions, 0 deletions)

```diff
@@ -40,6 +40,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Increased the minimum XLA requirement to 1.13 ([#17368](https://github.com/Lightning-AI/lightning/pull/17368))

+
+- Enable precision autocast for LightningModule step methods in Fabric ([#17439](https://github.com/Lightning-AI/lightning/pull/17439))
+
+
 ### Deprecated

 -
```
**src/lightning/fabric/wrappers.py** (1 addition, 1 deletion)

```diff
@@ -154,7 +154,7 @@ def wrapped_forward(*args: Any, **kwargs: Any) -> Any:
         def call_forward_module(*args: Any, **kwargs: Any) -> Any:
             # Patch the original_module's forward so we can redirect the arguments back to the real method
             self._original_module.forward = wrapped_forward
-            return self._forward_module(*args, **kwargs)
+            return self.forward(*args, **kwargs)

         return call_forward_module
```
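The one-line change above is the whole fix: routing the redirected step call through `self.forward` (the wrapper's forward) instead of calling the forward module directly means the call now passes through the point where the precision plugin's `forward_context()` is entered, so autocast applies to `training_step` and friends. Below is a torch-free sketch of this redirection pattern; all class names (`SketchPrecision`, `SketchModule`, `SketchFabricModule`) are hypothetical stand-ins for illustration, not the real Fabric API.

```python
from contextlib import contextmanager


class SketchPrecision:
    """Hypothetical stand-in for Fabric's Precision plugin."""

    def __init__(self):
        self.context_calls = 0

    @contextmanager
    def forward_context(self):
        # In Fabric this would open e.g. a torch.autocast(...) region.
        self.context_calls += 1
        yield


class SketchModule:
    """Plays the role of the user's LightningModule."""

    def forward(self, x):
        return 2 * x

    def training_step(self, x):
        # Step methods typically call self.forward() internally.
        return self.forward(x)


class SketchFabricModule:
    """Hypothetical simplification of _FabricModule's redirection logic."""

    def __init__(self, module, precision):
        self._module = module
        self._precision = precision

    def forward(self, *args, **kwargs):
        # The single choke point where the precision context is entered.
        with self._precision.forward_context():
            return self._module.forward(*args, **kwargs)

    def training_step(self, *args, **kwargs):
        real_step = self._module.training_step
        real_forward = self._module.forward

        def wrapped_forward(*a, **kw):
            # Undo the patch, then run the real step method -- it now
            # executes inside the context opened by self.forward above.
            self._module.forward = real_forward
            return real_step(*a, **kw)

        # Redirect through self.forward (the PR's fix): the wrapper's
        # forward opens the context, then lands back in training_step.
        self._module.forward = wrapped_forward
        return self.forward(*args, **kwargs)


precision = SketchPrecision()
fabric_module = SketchFabricModule(SketchModule(), precision)
result = fabric_module.training_step(3)
```

After the call, `result` is the step method's return value and the precision context has been entered exactly once, which mirrors what the new `precision.forward_context.assert_called()` check in the test verifies.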
**tests/tests_fabric/test_wrappers.py** (4 additions, 1 deletion)

```diff
@@ -21,6 +21,7 @@
 from torch.utils.data.dataloader import DataLoader

 from lightning.fabric.fabric import Fabric
+from lightning.fabric.plugins import Precision
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer, is_wrapped
 from tests_fabric.helpers.runif import RunIf
@@ -417,9 +418,10 @@ def validation_step(self, arg, kwarg=None):
         def normal_method(self):
             pass

+    precision = Mock(wraps=Precision())
     original_module = LightningModule()
     forward_module = DDP(original_module)
-    fabric_module = _FabricModule(forward_module=forward_module, precision=Mock(), original_module=original_module)
+    fabric_module = _FabricModule(forward_module=forward_module, precision=precision, original_module=original_module)

     # Regular methods on the original_module are visible and identical on the fabric_module ...
     assert fabric_module.normal_method == original_module.normal_method
@@ -441,6 +443,7 @@ def normal_method(self):
     assert fabric_module.training_step("train_arg", kwarg="train_kwarg") == "training_step_return"
     assert fabric_module.training_step("train_arg", kwarg="train_kwarg") == "training_step_return"  # call 2nd time
     assert fabric_module.validation_step("val_arg", kwarg="val_kwarg") == "validation_step_return"
+    precision.forward_context.assert_called()

     # The forward method remains untouched/unpatched after the special methods have been called
     assert original_module.forward.__name__ == "forward"
```