From 02a8a9a3a778b18e9ec00ccfdb553504bf23ddb6 Mon Sep 17 00:00:00 2001
From: iefgnoix
Date: Wed, 13 Mar 2024 10:20:16 -0700
Subject: [PATCH] Fix test_script.py on TPU v2/v3 (#2542)

* fix replication

* Set generator on each thread. The test passed.

* remove comments

* fix up

* fix format

* fix comment

* not setting the dataloader.batch_sampler
---
 src/accelerate/data_loader.py | 5 +++++
 src/accelerate/state.py       | 2 +-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index a0160834652..0f596bb29f8 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -882,6 +882,11 @@ def prepare_data_loader(
             generator=getattr(sampler, "generator", torch.Generator()),
         )
 
+    if isinstance(dataloader.sampler, RandomSampler) and state.distributed_type == DistributedType.XLA:
+        # isinstance(dataloader.sampler, RandomSampler) indicates the original dataloader has `shuffle` enabled.
+        generator = torch.Generator().manual_seed(42)
+        dataloader.generator = generator
+        dataloader.sampler.generator = generator
     # No change if no multiprocess
     if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches:
         if isinstance(new_dataset, IterableDataset):
diff --git a/src/accelerate/state.py b/src/accelerate/state.py
index 21d18931bee..682e59654af 100644
--- a/src/accelerate/state.py
+++ b/src/accelerate/state.py
@@ -160,7 +160,7 @@ def __init__(self, cpu: bool = False, **kwargs):
             elif is_torch_xla_available() and not cpu:
                 self.distributed_type = DistributedType.XLA
                 self.device = xm.xla_device()
-                xm.set_replication(self.device, [self.device])
+                xm.set_replication(self.device, xm.get_xla_supported_devices())
                 self.num_processes = xm.xrt_world_size()
                 self.process_index = xm.get_ordinal()
                 if is_torch_xla_available(check_is_tpu=True):
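
For readers of the patch, the minimal sketch below (not part of the diff) illustrates the property the new XLA branch in prepare_data_loader relies on: two RandomSampler instances produce the same shuffled index order only when their generators share a seed, which is what lets every TPU process iterate over identical batches. The toy dataset, the helper name shuffled_indices, and the printed output are illustrative assumptions, not code from the repository.

import torch
from torch.utils.data import RandomSampler, TensorDataset

# Illustrative map-style dataset; any dataset with a length behaves the same.
dataset = TensorDataset(torch.arange(16))

def shuffled_indices(seed):
    # Mirrors the situation in the patch: a RandomSampler means the original
    # DataLoader was built with shuffle=True, and the shuffle order is driven
    # entirely by the sampler's generator.
    generator = torch.Generator().manual_seed(seed)
    sampler = RandomSampler(dataset, generator=generator)
    return list(sampler)

# Seeding every "process" identically (the patch uses manual_seed(42)) makes
# the shuffled order identical everywhere, so per-process shards stay aligned.
assert shuffled_indices(42) == shuffled_indices(42)
print(shuffled_indices(42))

Without a shared seed, each process would draw its own permutation and the processes could disagree on which samples belong to which batch, which is the consistency the seeded generator above restores.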