diff --git a/ignite/distributed/auto.py b/ignite/distributed/auto.py index efe6ec6f4dff..5c4ba1db521c 100644 --- a/ignite/distributed/auto.py +++ b/ignite/distributed/auto.py @@ -238,8 +238,8 @@ def __iter__(self): while len(indices) < self.total_size: indices += list(self.sampler) - if len(indices) != self.total_size: - raise RuntimeError("{} vs {}".format(len(indices), self.total_size)) + if len(indices) > self.total_size: + indices = indices[: self.total_size] # subsample indices = indices[self.rank : self.total_size : self.num_replicas] diff --git a/tests/ignite/distributed/test_auto.py b/tests/ignite/distributed/test_auto.py index 8be766bad8be..1bdb408db240 100644 --- a/tests/ignite/distributed/test_auto.py +++ b/tests/ignite/distributed/test_auto.py @@ -155,18 +155,23 @@ def test_dist_proxy_sampler(): weights = torch.ones(100) weights[:50] += 1 - num_samples = 100 + num_samples = 200 sampler = WeightedRandomSampler(weights, num_samples) - num_replicas = 4 + num_replicas = 8 dist_samplers = [DistributedProxySampler(sampler, num_replicas=num_replicas, rank=i) for i in range(num_replicas)] - torch.manual_seed(0) - true_indices = list(sampler) + for seed in range(100): + torch.manual_seed(seed) + true_indices = list(sampler) - indices_per_rank = [] - for s in dist_samplers: - s.set_epoch(0) - indices_per_rank += list(s) + indices_per_rank = [] + for s in dist_samplers: + s.set_epoch(seed) + indices_per_rank += list(s) - assert set(indices_per_rank) == set(true_indices) + set_indices_per_rank = set(indices_per_rank) + set_true_indices = set(true_indices) + assert set_indices_per_rank == set_true_indices, "{} | {}".format( + set_true_indices - set_indices_per_rank, set_indices_per_rank - set_true_indices + )