diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index a0160834652..0f596bb29f8 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -882,6 +882,11 @@ def prepare_data_loader(
             generator=getattr(sampler, "generator", torch.Generator()),
         )
 
+    if isinstance(dataloader.sampler, RandomSampler) and state.distributed_type == DistributedType.XLA:
+        # isinstance(dataloader.sampler, RandomSampler) indicates the original dataloader has `shuffle` enabled.
+        generator = torch.Generator().manual_seed(42)
+        dataloader.generator = generator
+        dataloader.sampler.generator = generator
     # No change if no multiprocess
     if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches:
         if isinstance(new_dataset, IterableDataset):
diff --git a/src/accelerate/state.py b/src/accelerate/state.py
index 21d18931bee..682e59654af 100644
--- a/src/accelerate/state.py
+++ b/src/accelerate/state.py
@@ -160,7 +160,7 @@ def __init__(self, cpu: bool = False, **kwargs):
         elif is_torch_xla_available() and not cpu:
             self.distributed_type = DistributedType.XLA
             self.device = xm.xla_device()
-            xm.set_replication(self.device, [self.device])
+            xm.set_replication(self.device, xm.get_xla_supported_devices())
             self.num_processes = xm.xrt_world_size()
             self.process_index = xm.get_ordinal()
             if is_torch_xla_available(check_is_tpu=True):
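
A minimal sketch (not part of the diff) of the idea behind the data_loader.py change: when a dataloader and its RandomSampler share a torch.Generator seeded with the same fixed value, every process reproduces the same shuffled order, which is what keeps XLA replicas aligned before batches are sharded. The tiny dataset and the build_loader helper below are illustrative assumptions, not accelerate APIs.

import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))

def build_loader(seed: int = 42) -> DataLoader:
    # Mirror what the patch does: one generator, shared by the dataloader and
    # its RandomSampler, seeded identically on every process.
    generator = torch.Generator().manual_seed(seed)
    sampler = RandomSampler(dataset, generator=generator)
    return DataLoader(dataset, sampler=sampler, batch_size=2, generator=generator)

# Simulate two "processes": both loaders yield batches in the same order,
# so each XLA replica would see an identical shuffle.
order_a = [batch[0].tolist() for batch in build_loader()]
order_b = [batch[0].tolist() for batch in build_loader()]
assert order_a == order_b

Without the shared manual seed, each process would draw its own permutation and the replicas would iterate over different sample orders, which is the behavior the patch guards against when `shuffle` is enabled under DistributedType.XLA.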