Add is_distributed to Strategy API and remove accl_conn is_distributed property #11968

Closed · wants to merge 1 commit

2 changes: 1 addition & 1 deletion pytorch_lightning/lite/lite.py
@@ -432,7 +432,7 @@ def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -

     def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
         return (
-            self._accelerator_connector.is_distributed
+            self._strategy.is_distributed
             and not isinstance(dataloader.sampler, DistributedSampler)
             and not has_iterable_dataset(dataloader)
         )

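For orientation, a rough standalone sketch (not part of this diff) of the decision this call site feeds: Lite only swaps in a DistributedSampler for map-style datasets when the strategy reports itself as distributed and no distributed sampler is configured yet. The helper name and the explicit is_distributed argument below are illustrative, not Lite's actual internals.

import torch
from torch.utils.data import DataLoader, DistributedSampler, IterableDataset, TensorDataset


def needs_distributed_sampler(dataloader: DataLoader, is_distributed: bool) -> bool:
    # Iterable-style datasets manage their own sharding, so they are never re-sampled.
    if isinstance(dataloader.dataset, IterableDataset):
        return False
    # Replace only when running distributed and no DistributedSampler is attached yet.
    return is_distributed and not isinstance(dataloader.sampler, DistributedSampler)


loader = DataLoader(TensorDataset(torch.arange(8)), batch_size=2)
print(needs_distributed_sampler(loader, is_distributed=False))  # False: keep the default sampler
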
4 changes: 0 additions & 4 deletions pytorch_lightning/strategies/ddp.py
@@ -96,10 +96,6 @@ def __init__(
         self._sync_dir: Optional[str] = None
         self._rank_0_will_call_children_scripts: bool = False

-    @property
-    def is_distributed(self) -> bool:
-        return True
-
     @property
     def root_device(self) -> torch.device:
         return self.parallel_devices[self.local_rank]

4 changes: 4 additions & 0 deletions pytorch_lightning/strategies/dp.py
@@ -63,6 +63,10 @@ def node_rank(self) -> int:
     def world_size(self) -> int:
         return 1

+    @property
+    def is_distributed(self) -> bool:
+        return False
+
     def setup(self, trainer: "pl.Trainer") -> None:
         # model needs to be moved to the device before it is wrapped
         self.model_to_device()

4 changes: 4 additions & 0 deletions pytorch_lightning/strategies/parallel.py
@@ -80,6 +80,10 @@ def parallel_devices(self):
     def parallel_devices(self, parallel_devices):
         self._parallel_devices = parallel_devices

+    @property
+    def is_distributed(self) -> bool:
+        return True
+
     @property
     def distributed_sampler_kwargs(self):
         distributed_sampler_kwargs = dict(num_replicas=len(self.parallel_devices), rank=self.global_rank)

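Since the property now lives on ParallelStrategy with a default of True, DDP-style subclasses inherit it, which is why the explicit override removed from ddp.py above is no longer needed, while DataParallel and single-device strategies override it to False. A toy sketch of that inheritance pattern (class names are made up, not Lightning's):

class ToyParallelStrategy:
    @property
    def is_distributed(self) -> bool:
        return True


class ToyDDPStrategy(ToyParallelStrategy):
    pass  # inherits is_distributed == True, no per-class override needed


class ToySingleDeviceStrategy:
    @property
    def is_distributed(self) -> bool:
        return False


print(ToyDDPStrategy().is_distributed, ToySingleDeviceStrategy().is_distributed)  # True False
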
4 changes: 4 additions & 0 deletions pytorch_lightning/strategies/single_device.py
@@ -71,6 +71,10 @@ def setup(self, trainer: pl.Trainer) -> None:
         self.model_to_device()
         super().setup(trainer)

+    @property
+    def is_distributed(self) -> bool:
+        return False
+
     @property
     def is_global_zero(self) -> bool:
         return True

5 changes: 5 additions & 0 deletions pytorch_lightning/strategies/strategy.py
@@ -252,6 +252,11 @@ def model_to_device(self) -> None:
     def is_global_zero(self) -> bool:
         """Whether the current process is the rank zero process not only on the local node, but for all nodes."""

+    @property
+    @abstractmethod
+    def is_distributed(self) -> bool:
+        """Whether the strategy is a distributed strategy."""
+
     @abstractmethod
     def reduce(
         self,

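The abstract property means every strategy, including custom ones, now has to state explicitly whether it runs distributed. A minimal standalone sketch of that pattern (it does not import Lightning; class names are illustrative):

from abc import ABC, abstractmethod


class BaseStrategy(ABC):
    @property
    @abstractmethod
    def is_distributed(self) -> bool:
        """Whether the strategy is a distributed strategy."""


class MyCustomStrategy(BaseStrategy):
    @property
    def is_distributed(self) -> bool:
        return True  # a custom multi-process strategy would report True here


print(MyCustomStrategy().is_distributed)  # True; instantiating BaseStrategy() raises TypeError
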
23 changes: 0 additions & 23 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -48,7 +48,6 @@
 )
 from pytorch_lightning.strategies import (
     DataParallelStrategy,
-    DDP2Strategy,
     DDPFullyShardedStrategy,
     DDPShardedStrategy,
     DDPSpawnShardedStrategy,
@@ -818,28 +817,6 @@ def gpus(self) -> Optional[Union[List[int], str, int]]:
     def parallel_device_ids(self) -> List[int]:
         return [i for i in range(len(self.parallel_devices))] if isinstance(self.accelerator, GPUAccelerator) else []

-    @property
-    def is_distributed(self) -> bool:
-        # Used for custom plugins.
-        # Custom plugins should implement is_distributed property.
-        if hasattr(self.strategy, "is_distributed") and not isinstance(self.accelerator, TPUAccelerator):
-            return self.strategy.is_distributed
-        distributed_strategy = (
-            DDP2Strategy,
-            DDPStrategy,
-            DDPSpawnShardedStrategy,
-            DDPShardedStrategy,
-            DDPFullyShardedStrategy,
-            DDPSpawnStrategy,
-            DeepSpeedStrategy,
-            TPUSpawnStrategy,
-            HorovodStrategy,
-        )
-        is_distributed = isinstance(self.strategy, distributed_strategy)
-        if isinstance(self.accelerator, TPUAccelerator):
-            is_distributed |= self.strategy.is_distributed
-        return is_distributed
-
     @property
     def has_ipu(self) -> bool:
         return isinstance(self.accelerator, IPUAccelerator) and isinstance(self.strategy, IPUStrategy)

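The net effect, as a toy illustration (names are not Lightning's): callers no longer go through a connector property backed by a hand-maintained isinstance() allow-list; they read the flag straight off the strategy, which also covers custom strategies automatically.

class ToyDDPStrategy:
    is_distributed = True  # each strategy self-describes


class ToyTrainer:
    def __init__(self, strategy):
        self.strategy = strategy


trainer = ToyTrainer(ToyDDPStrategy())
# Old call sites queried a connector-level is_distributed property (removed above).
# New call sites ask the strategy directly.
print(trainer.strategy.is_distributed)  # True
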
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/connectors/data_connector.py
@@ -266,7 +266,7 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
     def _requires_distributed_sampler(self, dataloader) -> bool:
         return (
             self.trainer._accelerator_connector.replace_sampler_ddp
-            and self.trainer._accelerator_connector.is_distributed
+            and self.trainer.strategy.is_distributed
             and not isinstance(dataloader.sampler, DistributedSampler)
             and not has_iterable_dataset(dataloader)
         )

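On the Trainer side the same decision additionally honors the user-facing replace_sampler_ddp flag. A rough standalone sketch of how the two flags combine (function and argument names are illustrative, and the iterable-dataset check from the diff is omitted):

from torch.utils.data import DataLoader, DistributedSampler


def should_replace_sampler(dataloader: DataLoader, replace_sampler_ddp: bool, is_distributed: bool) -> bool:
    # Replace only if the user has not opted out, the strategy is distributed,
    # and no DistributedSampler has been configured yet.
    return (
        replace_sampler_ddp
        and is_distributed
        and not isinstance(dataloader.sampler, DistributedSampler)
    )


loader = DataLoader(range(4), batch_size=2)
print(should_replace_sampler(loader, replace_sampler_ddp=True, is_distributed=False))  # False on one device
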
4 changes: 2 additions & 2 deletions tests/lite/test_lite.py
@@ -268,7 +268,7 @@ def test_setup_dataloaders_replace_custom_sampler(strategy):

     # explicitly asking to replace when a custom sampler is already configured raises an exception
     lite = EmptyLite(accelerator="cpu", strategy=strategy, devices=2)
-    if lite._accelerator_connector.is_distributed:
+    if lite._strategy.is_distributed:
         with pytest.raises(MisconfigurationException, match="You seem to have configured a sampler in your DataLoader"):
             lite.setup_dataloaders(dataloader, replace_sampler=True)

@@ -292,7 +292,7 @@ def test_setup_dataloaders_replace_standard_sampler(shuffle, strategy):
     """Test that Lite replaces the default samplers with DistributedSampler automatically."""
     lite = EmptyLite(accelerator="cpu", strategy=strategy, devices=2)
-    is_distributed = lite._accelerator_connector.is_distributed
+    is_distributed = lite._strategy.is_distributed
     lite_dataloader = lite.setup_dataloaders(DataLoader(range(3), shuffle=shuffle))
     assert not is_distributed or isinstance(lite_dataloader.sampler, DistributedSampler)

4 changes: 2 additions & 2 deletions tests/trainer/test_trainer.py
@@ -1393,15 +1393,15 @@ def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, *arg
         self.write_on_batch_end_called = True

     def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
-        expected = 1 if trainer._accelerator_connector.is_distributed else 2
+        expected = 1 if trainer.strategy.is_distributed else 2
         assert len(predictions) == 2
         assert len(predictions[0]) == expected
         assert len(batch_indices) == 2
         assert len(batch_indices[0]) == expected
         self.write_on_epoch_end_called = True

     def on_predict_epoch_end(self, trainer, pl_module, outputs):
-        if trainer._accelerator_connector.is_distributed:
+        if trainer.strategy.is_distributed:
             for idx in range(2):
                 assert isinstance(trainer.predict_dataloaders[idx].batch_sampler.sampler, UnrepeatedDistributedSampler)
                 assert isinstance(trainer.predict_dataloaders[idx].batch_sampler, IndexBatchSamplerWrapper)