33import pytest
44import torch
55
6- from pytorch_lightning .overrides .data_parallel import LightningDistributedModule
6+ from pytorch_lightning .overrides .data_parallel import LightningDistributedModule , LightningParallelModule
77
88
9- def test_lightning_distributed_module_methods ():
9+ @pytest .mark .parametrize (['wrapper_class' ], [
10+ LightningParallelModule ,
11+ LightningDistributedModule ,
12+ ])
13+ def test_lightning_distributed_module_methods (wrapper_class ):
1014 """ Test that the LightningDistributedModule redirects .forward() to the LightningModule methods. """
1115 pl_module = MagicMock ()
12- dist_module = LightningDistributedModule (pl_module )
16+ wrapped_module = wrapper_class (pl_module )
1317
1418 batch = torch .rand (5 )
1519 batch_idx = 3
1620
1721 pl_module .training = True
1822 pl_module .testing = False
19- dist_module (batch , batch_idx )
23+ wrapped_module (batch , batch_idx )
2024 pl_module .training_step .assert_called_with (batch , batch_idx )
2125
2226 pl_module .training = False
2327 pl_module .testing = True
24- dist_module (batch , batch_idx )
28+ wrapped_module (batch , batch_idx )
2529 pl_module .test_step .assert_called_with (batch , batch_idx )
2630
2731 pl_module .training = False
2832 pl_module .testing = False
29- dist_module (batch , batch_idx )
33+ wrapped_module (batch , batch_idx )
3034 pl_module .validation_step .assert_called_with (batch , batch_idx )
3135
3236
33- def test_lightning_distributed_module_warn_none_output ():
37+ @pytest .mark .parametrize (['wrapper_class' ], [
38+ LightningParallelModule ,
39+ LightningDistributedModule ,
40+ ])
41+ def test_lightning_distributed_module_warn_none_output (wrapper_class ):
3442 """ Test that the LightningDistributedModule warns about forgotten return statement. """
3543 pl_module = MagicMock ()
36- dist_module = LightningDistributedModule (pl_module )
44+ wrapped_module = wrapper_class (pl_module )
3745
3846 pl_module .training_step .return_value = None
3947 pl_module .validation_step .return_value = None
@@ -42,14 +50,14 @@ def test_lightning_distributed_module_warn_none_output():
4250 with pytest .warns (UserWarning , match = "Your training_step returned None" ):
4351 pl_module .training = True
4452 pl_module .testing = False
45- dist_module ()
53+ wrapped_module ()
4654
4755 with pytest .warns (UserWarning , match = "Your test_step returned None" ):
4856 pl_module .training = False
4957 pl_module .testing = True
50- dist_module ()
58+ wrapped_module ()
5159
5260 with pytest .warns (UserWarning , match = "Your validation_step returned None" ):
5361 pl_module .training = False
5462 pl_module .testing = False
55- dist_module ()
63+ wrapped_module ()
0 commit comments