Skip to content

Commit f6f06d4

Browse files
authored
Set default strategy to ddp_fork in interactive environments (#13746)
1 parent 9f51c07 commit f6f06d4

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

src/pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
_HOROVOD_AVAILABLE,
8888
_HPU_AVAILABLE,
8989
_IPU_AVAILABLE,
90+
_IS_INTERACTIVE,
9091
_TORCH_GREATER_EQUAL_1_11,
9192
_TPU_AVAILABLE,
9293
)
@@ -588,7 +589,9 @@ def _choose_strategy(self) -> Union[Strategy, str]:
588589
# TODO: lazy initialized device, then here could be self._strategy_flag = "single_device"
589590
return SingleDeviceStrategy(device=device) # type: ignore
590591
if len(self._parallel_devices) > 1:
591-
return DDPSpawnStrategy.strategy_name
592+
if _IS_INTERACTIVE:
593+
return "ddp_fork"
594+
return "ddp_spawn"
592595

593596
return DDPStrategy.strategy_name
594597

tests/tests_pytorch/accelerators/test_accelerator_connector.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,18 @@ def test_strategy_choice_ddp_spawn_cpu():
424424
assert trainer.strategy.launcher._start_method == "spawn"
425425

426426

427+
@RunIf(skip_windows=True)
428+
@mock.patch("pytorch_lightning.trainer.connectors.accelerator_connector._IS_INTERACTIVE", True)
429+
def test_strategy_choice_ddp_fork_in_interactive():
430+
"""Test that when accelerator and strategy are unspecified, the connector chooses DDP Fork in interactive
431+
environments by default."""
432+
trainer = Trainer(devices=2)
433+
assert isinstance(trainer.accelerator, CPUAccelerator)
434+
assert isinstance(trainer.strategy, DDPSpawnStrategy)
435+
assert isinstance(trainer.strategy.cluster_environment, LightningEnvironment)
436+
assert trainer.strategy.launcher._start_method == "fork"
437+
438+
427439
@RunIf(skip_windows=True)
428440
def test_strategy_choice_ddp_fork_cpu():
429441
trainer = Trainer(strategy="ddp_fork", accelerator="cpu", devices=2)

0 commit comments

Comments
 (0)