Skip to content

Commit 1b9c7d8

Browse files
jonb377yeounoh
authored andcommitted
Update device count API
1 parent d140075 commit 1b9c7d8

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

examples/pytorch/language-modeling/run_clm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ def main():
499499
import torch_xla.core.xla_model as xm
500500
import torch_xla.experimental.xla_sharding as xs
501501
import torch_xla.runtime as xr
502-
num_devices = xr.global_device_count()
502+
num_devices = xr.global_runtime_device_count()
503503
device_ids = torch.arange(num_devices)
504504
print('Using dtype', model_args.torch_dtype)
505505
model = model.to(xm.xla_device(), dtype=getattr(torch, model_args.torch_dtype))

src/transformers/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1541,7 +1541,7 @@ def _xla_sharded_dataloader(self, dataloader):
15411541
import torch_xla.experimental.xla_sharding as xs
15421542
import torch_xla.runtime as xr
15431543
import torch_xla.distributed.parallel_loader as pl
1544-
num_devices = xr.global_device_count()
1544+
num_devices = xr.global_runtime_device_count()
15451545
device_ids = np.arange(num_devices)
15461546

15471547
sharding_spec = None

0 commit comments

Comments
 (0)