Skip to content

Commit 674ab35

Browse files
jonb377alanwaketan
authored andcommitted
Update device count API
1 parent ec6c2d6 commit 674ab35

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

examples/pytorch/language-modeling/run_clm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,7 @@ def main():
500500
import torch_xla.core.xla_model as xm
501501
import torch_xla.experimental.xla_sharding as xs
502502
import torch_xla.runtime as xr
503-
num_devices = xr.global_device_count()
503+
num_devices = xr.global_runtime_device_count()
504504
device_ids = torch.arange(num_devices)
505505
print('Using dtype', model_args.torch_dtype)
506506
model = model.to(xm.xla_device(), dtype=getattr(torch, model_args.torch_dtype))

src/transformers/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1420,7 +1420,7 @@ def _xla_sharded_dataloader(self, dataloader):
14201420
import torch_xla.experimental.xla_sharding as xs
14211421
import torch_xla.runtime as xr
14221422
import torch_xla.distributed.parallel_loader as pl
1423-
num_devices = xr.global_device_count()
1423+
num_devices = xr.global_runtime_device_count()
14241424
device_ids = np.arange(num_devices)
14251425

14261426
sharding_spec = None

0 commit comments

Comments
 (0)