from lightning.fabric.fabric import Fabric
from lightning.fabric.plugins import Precision
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
- from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer, is_wrapped
+ from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer, _unwrap_objects, is_wrapped
from tests_fabric.helpers.runif import RunIf

2929
@@ -358,7 +358,8 @@ def zero_grad(self, set_grads_to_None=False):
358358 custom_zero_grad .assert_called_with (set_grads_to_None = False )
359359
360360
361- def test_is_wrapped ():
361+ @pytest .mark .parametrize ("compile" , [False , pytest .param (True , marks = RunIf (dynamo = True ))])
362+ def test_is_wrapped (compile ):
362363 """Test that the `is_wrapped` utility recognizes when an object was wrapped by Fabric."""
363364 assert not is_wrapped (None )
364365
@@ -368,6 +369,15 @@ def test_is_wrapped():
368369 wrapped = _FabricModule (module , Mock ())
369370 assert is_wrapped (wrapped )
370371
372+ # _FabricModule inside an OptimizedModule
373+ if compile :
374+ from torch ._dynamo import OptimizedModule
375+
376+ module = torch .nn .Linear (2 , 2 )
377+ wrapped = torch .compile (_FabricModule (module , Mock ()))
378+ assert isinstance (wrapped , OptimizedModule )
379+ assert is_wrapped (wrapped )
380+
371381 # _FabricOptimizer
372382 optimizer = torch .optim .Adam (module .parameters ())
373383 assert not is_wrapped (optimizer )
@@ -381,6 +391,43 @@ def test_is_wrapped():
381391 assert is_wrapped (wrapped )
382392
383393
@pytest.mark.parametrize("compile", [False, pytest.param(True, marks=RunIf(dynamo=True))])
def test_unwrap_objects(compile):
    """``_unwrap_objects`` strips Fabric wrappers from objects inside (possibly nested) containers."""
    # An empty container passes through unchanged.
    assert _unwrap_objects({}) == {}

    # Build one plain and one Fabric-wrapped variant of a module, an optimizer, and a dataloader.
    model = torch.nn.Linear(1, 1)
    fabric_model = _FabricModule(model, Mock())
    if compile:
        # When compiled, the _FabricModule is itself wrapped in an OptimizedModule.
        fabric_model = torch.compile(fabric_model)
    opt = torch.optim.Adam(model.parameters())
    fabric_opt = _FabricOptimizer(opt, Mock())
    loader = DataLoader([1, 2, 3])
    fabric_loader = _FabricDataLoader(loader)

    # Mix plain values, wrapped objects, and a nested list in a single container.
    mixed = {
        "int": 1,
        "module": model,
        "wrapped_module": fabric_model,
        "optimizer": opt,
        "wrapped_optimizer": fabric_opt,
        "dataloader": loader,
        "wrapped_dataloader": fabric_loader,
        "nested": [model, fabric_model, opt, fabric_opt, loader, fabric_loader],
    }
    # Plain objects stay as-is; wrapped ones are replaced by their underlying object.
    # NOTE(review): for a compiled module the expected value is the wrapper's
    # `_forward_module` (attribute access is proxied through OptimizedModule).
    unwrapped = {
        "int": 1,
        "module": model,
        "wrapped_module": fabric_model._forward_module,
        "optimizer": opt,
        "wrapped_optimizer": opt,
        "dataloader": loader,
        "wrapped_dataloader": loader,
        "nested": [model, fabric_model._forward_module, opt, opt, loader, loader],
    }
    assert _unwrap_objects(mixed) == unwrapped
430+
384431def test_step_method_redirection ():
385432 """Test that the FabricModule redirects the special `LightningModule.*_step` methods through the forward-
386433 module."""
0 commit comments