avoid calling xm.xla_device() on multi-tpu
DuYicong515 committed Mar 25, 2022
1 parent 23a2f5e commit 9e26696
Showing 2 changed files with 8 additions and 7 deletions.
pytorch_lightning/trainer/trainer.py (10 changes: 7 additions & 3 deletions)
@@ -59,7 +59,7 @@
     SimpleProfiler,
     XLAProfiler,
 )
-from pytorch_lightning.strategies import ParallelStrategy, Strategy
+from pytorch_lightning.strategies import ParallelStrategy, SingleDeviceStrategy, Strategy
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
 from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
 from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations
@@ -2066,7 +2066,11 @@ def num_nodes(self) -> int:
     @property
     def device_ids(self) -> List[int]:
         """List of device indexes per node."""
-        devices = getattr(self.strategy, "parallel_devices", [self.strategy.root_device])
+        devices = (
+            [self.strategy.root_device]
+            if isinstance(self.strategy, SingleDeviceStrategy)
+            else self.strategy.parallel_devices
+        )
         device_ids = []
         for idx, device in enumerate(devices):
             if isinstance(device, torch.device):
@@ -2078,7 +2082,7 @@ def device_ids(self) -> List[int]:
     @property
     def num_devices(self) -> int:
         """Number of devices the trainer uses per node."""
-        return len(self.device_ids)
+        return 1 if isinstance(self.strategy, SingleDeviceStrategy) else len(self.device_ids)
 
     @property
     def num_processes(self) -> int:
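The core of the change is in `device_ids`: the old code used `getattr(self.strategy, "parallel_devices", [self.strategy.root_device])`, and because Python evaluates the `getattr` default eagerly, `root_device` was read even when the strategy did define `parallel_devices`. On a TPU strategy, reading `root_device` is presumably what triggers the `xm.xla_device()` call the commit title wants to avoid in the main process. Below is a minimal, self-contained sketch of that pitfall, using stand-in classes rather than the real Lightning strategies.

# Minimal sketch (stand-in classes, not the Lightning source) of why the old
# getattr() pattern still evaluated root_device: the default argument of
# getattr() is built eagerly, before the attribute lookup happens.

class SingleDeviceStrategy:  # stand-in for pytorch_lightning.strategies.SingleDeviceStrategy
    root_device = "cpu"


class FakeTPUSpawnStrategy:  # stand-in for a multi-device TPU strategy
    parallel_devices = ["xla:0", "xla:1"]

    @property
    def root_device(self):
        # In the real TPUSpawnStrategy this would call xm.xla_device(), which
        # should not run in the main process before the TPU workers are spawned.
        raise RuntimeError("xm.xla_device() called on multi-TPU")


strategy = FakeTPUSpawnStrategy()

# Old pattern: raises even though parallel_devices exists, because the default
# list [strategy.root_device] is constructed before getattr() runs.
try:
    getattr(strategy, "parallel_devices", [strategy.root_device])
except RuntimeError as err:
    print(f"old pattern: {err}")

# New pattern: root_device is only touched for single-device strategies.
devices = (
    [strategy.root_device]
    if isinstance(strategy, SingleDeviceStrategy)
    else strategy.parallel_devices
)
print(f"new pattern: {devices}")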
tests/accelerators/test_tpu.py (5 changes: 1 addition & 4 deletions)
@@ -94,7 +94,6 @@ def test_accelerator_cpu_with_tpu_cores_flag():
 
 
 @RunIf(tpu=True)
-@pl_multi_process_test
 @pytest.mark.parametrize(["accelerator", "devices"], [("auto", 8), ("auto", "auto"), ("tpu", None)])
 def test_accelerator_tpu(accelerator, devices):
     assert TPUAccelerator.is_available()
@@ -104,13 +103,12 @@ def test_accelerator_tpu(accelerator, devices):
     assert isinstance(trainer.strategy, TPUSpawnStrategy)
     assert trainer.num_devices == 8
     with pytest.deprecated_call(
-        match= "`Trainer.tpu_cores` is deprecated in v1.6 and will be removed in v1.8. "
+        match="`Trainer.tpu_cores` is deprecated in v1.6 and will be removed in v1.8. "
         "Please use `Trainer.devices` instead."
     ):
         trainer.tpu_cores == 8
 
 
-
 @RunIf(tpu=True)
 def test_accelerator_tpu_with_tpu_cores_priority():
     """Test for checking `tpu_cores` flag takes priority over `devices`."""
@@ -124,7 +122,6 @@
 
 
 @RunIf(tpu=True)
-@pl_multi_process_test
 def test_set_devices_if_none_tpu():
     trainer = Trainer(accelerator="tpu", tpu_cores=8)
     assert trainer.num_devices == 8
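With the Trainer no longer touching `xm.xla_device()` just to report device counts, these attribute-only tests apparently no longer need the `@pl_multi_process_test` wrapper. A rough usage sketch of what they exercise, assuming a TPU VM with 8 cores available:

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import TPUSpawnStrategy

# Constructing the Trainer and inspecting device counts runs in the main
# process; xm.xla_device() is now only called inside the spawned TPU workers.
trainer = Trainer(accelerator="tpu", devices=8)
assert isinstance(trainer.strategy, TPUSpawnStrategy)
assert trainer.num_devices == 8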
