Skip to content

Commit 237b4d8

Browse files
committed
generalize test
1 parent e96b218 commit 237b4d8

File tree

1 file changed

+19
-11
lines changed

1 file changed

+19
-11
lines changed

tests/overrides/test_data_parallel.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,37 +3,45 @@
33
import pytest
44
import torch
55

6-
from pytorch_lightning.overrides.data_parallel import LightningDistributedModule
6+
from pytorch_lightning.overrides.data_parallel import LightningDistributedModule, LightningParallelModule
77

88

9-
def test_lightning_distributed_module_methods():
9+
@pytest.mark.parametrize(['wrapper_class'], [
10+
LightningParallelModule,
11+
LightningDistributedModule,
12+
])
13+
def test_lightning_distributed_module_methods(wrapper_class):
1014
""" Test that the LightningDistributedModule redirects .forward() to the LightningModule methods. """
1115
pl_module = MagicMock()
12-
dist_module = LightningDistributedModule(pl_module)
16+
wrapped_module = wrapper_class(pl_module)
1317

1418
batch = torch.rand(5)
1519
batch_idx = 3
1620

1721
pl_module.training = True
1822
pl_module.testing = False
19-
dist_module(batch, batch_idx)
23+
wrapped_module(batch, batch_idx)
2024
pl_module.training_step.assert_called_with(batch, batch_idx)
2125

2226
pl_module.training = False
2327
pl_module.testing = True
24-
dist_module(batch, batch_idx)
28+
wrapped_module(batch, batch_idx)
2529
pl_module.test_step.assert_called_with(batch, batch_idx)
2630

2731
pl_module.training = False
2832
pl_module.testing = False
29-
dist_module(batch, batch_idx)
33+
wrapped_module(batch, batch_idx)
3034
pl_module.validation_step.assert_called_with(batch, batch_idx)
3135

3236

33-
def test_lightning_distributed_module_warn_none_output():
37+
@pytest.mark.parametrize(['wrapper_class'], [
38+
LightningParallelModule,
39+
LightningDistributedModule,
40+
])
41+
def test_lightning_distributed_module_warn_none_output(wrapper_class):
3442
""" Test that the LightningDistributedModule warns about forgotten return statement. """
3543
pl_module = MagicMock()
36-
dist_module = LightningDistributedModule(pl_module)
44+
wrapped_module = wrapper_class(pl_module)
3745

3846
pl_module.training_step.return_value = None
3947
pl_module.validation_step.return_value = None
@@ -42,14 +50,14 @@ def test_lightning_distributed_module_warn_none_output():
4250
with pytest.warns(UserWarning, match="Your training_step returned None"):
4351
pl_module.training = True
4452
pl_module.testing = False
45-
dist_module()
53+
wrapped_module()
4654

4755
with pytest.warns(UserWarning, match="Your test_step returned None"):
4856
pl_module.training = False
4957
pl_module.testing = True
50-
dist_module()
58+
wrapped_module()
5159

5260
with pytest.warns(UserWarning, match="Your validation_step returned None"):
5361
pl_module.training = False
5462
pl_module.testing = False
55-
dist_module()
63+
wrapped_module()

0 commit comments

Comments
 (0)