Fix test_script.py on TPU v2/v3 (#2542)
* fix replication

* Set generator on each thread. The test passed.

* remove comments

* fix up

* fix format

* fix comment

* not setting the dataloader.batch_sampler
vanbasten23 authored Mar 13, 2024
1 parent ee163b6 commit 02a8a9a
Showing 2 changed files with 6 additions and 1 deletion.
5 changes: 5 additions & 0 deletions src/accelerate/data_loader.py
```diff
@@ -882,6 +882,11 @@ def prepare_data_loader(
             generator=getattr(sampler, "generator", torch.Generator()),
         )
 
+    if isinstance(dataloader.sampler, RandomSampler) and state.distributed_type == DistributedType.XLA:
+        # isinstance(dataloader.sampler, RandomSampler) indicates the original dataloader has `shuffle` enabled.
+        generator = torch.Generator().manual_seed(42)
+        dataloader.generator = generator
+        dataloader.sampler.generator = generator
     # No change if no multiprocess
     if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches:
         if isinstance(new_dataset, IterableDataset):
```
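For context on why a fixed-seed generator helps here: under XLA each process builds its own copy of the shuffled sampler and then takes its shard of the indices, so the random permutation must be identical on every process or the shards can overlap or miss samples. Below is a minimal sketch of that idea outside accelerate; `seeded_shuffled_loader` is a hypothetical helper (the seed value 42 mirrors the patch, nothing else here is accelerate API):

```python
import torch
from torch.utils.data import DataLoader, RandomSampler

def seeded_shuffled_loader(dataset, batch_size, seed=42):
    # Hypothetical helper, not accelerate API: every process seeds the same
    # generator, so RandomSampler yields the same permutation everywhere and
    # per-process shards of that permutation stay consistent.
    generator = torch.Generator().manual_seed(seed)
    sampler = RandomSampler(dataset, generator=generator)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler, generator=generator)

# Two independently built loaders (standing in for two processes) agree on the order.
data = list(range(16))
order_a = [i for batch in seeded_shuffled_loader(data, 4) for i in batch.tolist()]
order_b = [i for batch in seeded_shuffled_loader(data, 4) for i in batch.tolist()]
assert order_a == order_b
```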
2 changes: 1 addition & 1 deletion src/accelerate/state.py
```diff
@@ -160,7 +160,7 @@ def __init__(self, cpu: bool = False, **kwargs):
         elif is_torch_xla_available() and not cpu:
             self.distributed_type = DistributedType.XLA
             self.device = xm.xla_device()
-            xm.set_replication(self.device, [self.device])
+            xm.set_replication(self.device, xm.get_xla_supported_devices())
             self.num_processes = xm.xrt_world_size()
             self.process_index = xm.get_ordinal()
             if is_torch_xla_available(check_is_tpu=True):
```
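The state.py change widens replication from only the current device to every XLA device the process can see, which matters on TPU v2/v3 where a single process can drive more than one core. A rough sketch of that pattern in a bare torch_xla worker, using only the calls that appear in the diff plus `xmp.spawn`; `_worker` is illustrative, not accelerate code, and assumes a torch_xla runtime from the same era as this commit:

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _worker(index):
    device = xm.xla_device()
    # Declare replication over all devices this process addresses, not just the
    # current one; on TPU v2/v3 a process may own multiple local cores.
    xm.set_replication(device, xm.get_xla_supported_devices())
    print(f"ordinal {xm.get_ordinal()} / {xm.xrt_world_size()} on {device}")

if __name__ == "__main__":
    xmp.spawn(_worker)  # one worker per XLA process, as torch_xla sets up
```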
