Skip to content

Commit 0cfc53d

Browse files
authored
Fix regression on default value for find_unused_parameters (#14095)
1 parent 82d2d1d commit 0cfc53d

File tree

5 files changed

+57
-2
lines changed

5 files changed

+57
-2
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
7070
- Fixed dtype inference during gradient norm computation ([#14051](https://github.com/Lightning-AI/lightning/pull/14051))
7171

7272

73+
- Fixed a bug that caused `ddp_find_unused_parameters` to be set `False`, whereas the intended default is `True` ([#14095](https://github.com/Lightning-AI/lightning/pull/14095))
74+
75+
7376
## [1.7.0] - 2022-08-02
7477

7578
### Added

src/pytorch_lightning/strategies/ddp_spawn.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,10 +315,20 @@ def post_training_step(self) -> None:
315315
def register_strategies(cls, strategy_registry: Dict) -> None:
316316
entries = (
317317
("ddp_spawn", "spawn"),
318-
("ddp_spawn_find_unused_parameters_false", "spawn"),
319318
("ddp_fork", "fork"),
320-
("ddp_fork_find_unused_parameters_false", "fork"),
321319
("ddp_notebook", "fork"),
320+
)
321+
for name, start_method in entries:
322+
strategy_registry.register(
323+
name,
324+
cls,
325+
description=f"DDP strategy with `start_method` '{start_method}'",
326+
start_method=start_method,
327+
)
328+
329+
entries = (
330+
("ddp_spawn_find_unused_parameters_false", "spawn"),
331+
("ddp_fork_find_unused_parameters_false", "fork"),
322332
("ddp_notebook_find_unused_parameters_false", "fork"),
323333
)
324334
for name, start_method in entries:

tests/tests_pytorch/strategies/test_ddp.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,3 +194,15 @@ def root_device(self):
194194
assert strategy._get_process_group_backend() == expected_process_group_backend
195195
else:
196196
assert strategy._get_process_group_backend() == expected_process_group_backend
197+
198+
199+
@pytest.mark.parametrize(
200+
"strategy_name,expected_ddp_kwargs",
201+
[
202+
("ddp", {}),
203+
("ddp_find_unused_parameters_false", {"find_unused_parameters": False}),
204+
],
205+
)
206+
def test_ddp_kwargs_from_registry(strategy_name, expected_ddp_kwargs):
207+
trainer = Trainer(strategy=strategy_name)
208+
assert trainer.strategy._ddp_kwargs == expected_ddp_kwargs

tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,19 @@ def test_ddp_spawn_strategy_set_timeout(mock_init_process_group):
178178
mock_init_process_group.assert_called_with(
179179
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
180180
)
181+
182+
183+
@pytest.mark.parametrize(
184+
"strategy_name,expected_ddp_kwargs",
185+
[
186+
("ddp_spawn", {}),
187+
("ddp_fork", {}),
188+
("ddp_notebook", {}),
189+
("ddp_spawn_find_unused_parameters_false", {"find_unused_parameters": False}),
190+
("ddp_fork_find_unused_parameters_false", {"find_unused_parameters": False}),
191+
("ddp_notebook_find_unused_parameters_false", {"find_unused_parameters": False}),
192+
],
193+
)
194+
def test_ddp_kwargs_from_registry(strategy_name, expected_ddp_kwargs):
195+
trainer = Trainer(strategy=strategy_name)
196+
assert trainer.strategy._ddp_kwargs == expected_ddp_kwargs

tests/tests_pytorch/strategies/test_sharded_strategy.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,3 +300,17 @@ def test_block_backward_sync():
300300
with strategy.block_backward_sync():
301301
pass
302302
model.no_sync.assert_called_once()
303+
304+
305+
@pytest.mark.parametrize(
306+
"strategy_name,expected_ddp_kwargs",
307+
[
308+
("ddp_sharded", {}),
309+
("ddp_sharded_find_unused_parameters_false", {"find_unused_parameters": False}),
310+
("ddp_sharded_spawn", {}),
311+
("ddp_sharded_spawn_find_unused_parameters_false", {"find_unused_parameters": False}),
312+
],
313+
)
314+
def test_ddp_kwargs_from_registry(strategy_name, expected_ddp_kwargs):
315+
trainer = Trainer(strategy=strategy_name)
316+
assert trainer.strategy._ddp_kwargs == expected_ddp_kwargs

0 commit comments

Comments
 (0)