Skip to content

Commit fc3296c

Browse files
awaelchli authored and Borda committed
Support compiling a module after it was set up by Fabric (#17529)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit a533f68)
1 parent 0b397b5 commit fc3296c

File tree

3 files changed

+82
-6
lines changed

3 files changed

+82
-6
lines changed

src/lightning/fabric/fabric.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,13 @@
4646
from lightning.fabric.utilities.seed import seed_everything
4747
from lightning.fabric.utilities.types import ReduceOp
4848
from lightning.fabric.utilities.warnings import PossibleUserWarning
49-
from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer, _unwrap_objects
49+
from lightning.fabric.wrappers import (
50+
_FabricDataLoader,
51+
_FabricModule,
52+
_FabricOptimizer,
53+
_unwrap_compiled,
54+
_unwrap_objects,
55+
)
5056

5157

5258
class Fabric:
@@ -542,6 +548,7 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Gener
542548
enabled: Whether the context manager is enabled or not. ``True`` means skip the sync, ``False`` means do not
543549
skip.
544550
"""
551+
module = _unwrap_compiled(module)
545552
if not isinstance(module, _FabricModule):
546553
raise TypeError(
547554
"You need to set up the model first before you can call `self.no_backward_sync()`:"

src/lightning/fabric/wrappers.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from lightning.fabric.utilities import move_data_to_device
2929
from lightning.fabric.utilities.data import _set_sampler_epoch
3030
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
31+
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
3132
from lightning.fabric.utilities.types import Optimizable
3233
from lightning.fabric.utilities.warnings import PossibleUserWarning
3334

@@ -218,15 +219,35 @@ def _unwrap_objects(collection: Any) -> Any:
218219
def _unwrap(
219220
obj: Union[_FabricModule, _FabricOptimizer, _FabricDataLoader]
220221
) -> Union[nn.Module, Optimizer, DataLoader]:
221-
if isinstance(obj, _FabricModule):
222-
return obj._forward_module
222+
if isinstance(unwrapped := _unwrap_compiled(obj), _FabricModule):
223+
return unwrapped._forward_module
223224
if isinstance(obj, _FabricOptimizer):
224225
return obj.optimizer
225226
if isinstance(obj, _FabricDataLoader):
226227
return obj._dataloader
227228
return obj
228229

229-
return apply_to_collection(collection, dtype=(_FabricModule, _FabricOptimizer, _FabricDataLoader), function=_unwrap)
230+
types = [_FabricModule, _FabricOptimizer, _FabricDataLoader]
231+
if _TORCH_GREATER_EQUAL_2_0:
232+
from torch._dynamo import OptimizedModule
233+
234+
types.append(OptimizedModule)
235+
236+
return apply_to_collection(collection, dtype=tuple(types), function=_unwrap)
237+
238+
239+
def _unwrap_compiled(obj: Any) -> Any:
240+
"""Removes the :class:`torch._dynamo.OptimizedModule` around the object if it is wrapped.
241+
242+
Use this function before instance checks against e.g. :class:`_FabricModule`.
243+
"""
244+
if not _TORCH_GREATER_EQUAL_2_0:
245+
return obj
246+
from torch._dynamo import OptimizedModule
247+
248+
if isinstance(obj, OptimizedModule):
249+
return obj._orig_mod
250+
return obj
230251

231252

232253
def is_wrapped(obj: object) -> bool:
@@ -239,4 +260,5 @@ def is_wrapped(obj: object) -> bool:
239260
Args:
240261
obj: The object to test.
241262
"""
263+
obj = _unwrap_compiled(obj)
242264
return isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader))

tests/tests_fabric/test_wrappers.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from lightning.fabric.fabric import Fabric
2424
from lightning.fabric.plugins import Precision
2525
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
26-
from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer, is_wrapped
26+
from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer, _unwrap_objects, is_wrapped
2727
from tests_fabric.helpers.runif import RunIf
2828

2929

@@ -358,7 +358,8 @@ def zero_grad(self, set_grads_to_None=False):
358358
custom_zero_grad.assert_called_with(set_grads_to_None=False)
359359

360360

361-
def test_is_wrapped():
361+
@pytest.mark.parametrize("compile", [False, pytest.param(True, marks=RunIf(dynamo=True))])
362+
def test_is_wrapped(compile):
362363
"""Test that the `is_wrapped` utility recognizes when an object was wrapped by Fabric."""
363364
assert not is_wrapped(None)
364365

@@ -368,6 +369,15 @@ def test_is_wrapped():
368369
wrapped = _FabricModule(module, Mock())
369370
assert is_wrapped(wrapped)
370371

372+
# _FabricModule inside an OptimizedModule
373+
if compile:
374+
from torch._dynamo import OptimizedModule
375+
376+
module = torch.nn.Linear(2, 2)
377+
wrapped = torch.compile(_FabricModule(module, Mock()))
378+
assert isinstance(wrapped, OptimizedModule)
379+
assert is_wrapped(wrapped)
380+
371381
# _FabricOptimizer
372382
optimizer = torch.optim.Adam(module.parameters())
373383
assert not is_wrapped(optimizer)
@@ -381,6 +391,43 @@ def test_is_wrapped():
381391
assert is_wrapped(wrapped)
382392

383393

394+
@pytest.mark.parametrize("compile", [False, pytest.param(True, marks=RunIf(dynamo=True))])
395+
def test_unwrap_objects(compile):
396+
# empty container
397+
assert _unwrap_objects({}) == {}
398+
399+
# container with pure objects and wrapped objects
400+
module = torch.nn.Linear(1, 1)
401+
wrapped_module = _FabricModule(module, Mock())
402+
if compile:
403+
wrapped_module = torch.compile(wrapped_module)
404+
optimizer = torch.optim.Adam(module.parameters())
405+
wrapped_optimizer = _FabricOptimizer(optimizer, Mock())
406+
dataloader = DataLoader([1, 2, 3])
407+
wrapped_dataloader = _FabricDataLoader(dataloader)
408+
container = {
409+
"int": 1,
410+
"module": module,
411+
"wrapped_module": wrapped_module,
412+
"optimizer": optimizer,
413+
"wrapped_optimizer": wrapped_optimizer,
414+
"dataloader": dataloader,
415+
"wrapped_dataloader": wrapped_dataloader,
416+
"nested": [module, wrapped_module, optimizer, wrapped_optimizer, dataloader, wrapped_dataloader],
417+
}
418+
expected = {
419+
"int": 1,
420+
"module": module,
421+
"wrapped_module": wrapped_module._forward_module,
422+
"optimizer": optimizer,
423+
"wrapped_optimizer": optimizer,
424+
"dataloader": dataloader,
425+
"wrapped_dataloader": dataloader,
426+
"nested": [module, wrapped_module._forward_module, optimizer, optimizer, dataloader, dataloader],
427+
}
428+
assert _unwrap_objects(container) == expected
429+
430+
384431
def test_step_method_redirection():
385432
"""Test that the FabricModule redirects the special `LightningModule.*_step` methods through the forward-
386433
module."""

0 commit comments

Comments (0)