diff --git a/CHANGELOG.md b/CHANGELOG.md index a949c1d5..dc2e9f14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ Keep it human-readable, your future self will thank you! ## [0.3.1 - AIFS v0.3 Compatibility](https://github.com/ecmwf/anemoi-training/compare/0.3.0...0.3.1) - 2024-11-28 +### Changed +- Perform full shuffle of training dataset [#153](https://github.com/ecmwf/anemoi-training/pull/153) + ### Fixed - Update `n_pixel` used by datashader to better adapt across resolutions #152 diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index 40065e06..062d2d4d 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -201,6 +201,7 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None: low = shard_start + worker_id * self.n_samples_per_worker high = min(shard_start + (worker_id + 1) * self.n_samples_per_worker, shard_end) + self.chunk_index_range = np.arange(low, high, dtype=np.uint32) LOGGER.debug( "Worker %d (pid %d, global_rank %d, model comm group %d) has low/high range %d / %d", @@ -212,27 +213,17 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None: high, ) - self.chunk_index_range = self.valid_date_indices[np.arange(low, high, dtype=np.uint32)] - - # each worker must have a different seed for its random number generator, - # otherwise all the workers will output exactly the same data - # should we check lightning env variable "PL_SEED_WORKERS" here? - # but we alwyas want to seed these anyways ... - base_seed = get_base_seed() - seed = ( - base_seed * (self.model_comm_group_id + 1) - worker_id - ) # note that test, validation etc. datasets get same seed - torch.manual_seed(seed) - random.seed(seed) - self.rng = np.random.default_rng(seed=seed) + torch.manual_seed(base_seed) + random.seed(base_seed) + self.rng = np.random.default_rng(seed=base_seed) sanity_rnd = self.rng.random(1) LOGGER.debug( ( "Worker %d (%s, pid %d, glob. rank %d, model comm group %d, " - "group_rank %d, base_seed %d) using seed %d, sanity rnd %f" + "group_rank %d, base_seed %d), sanity rnd %f" ), worker_id, self.label, @@ -241,7 +232,6 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None: self.model_comm_group_id, self.model_comm_group_rank, base_seed, - seed, sanity_rnd, ) @@ -256,12 +246,12 @@ def __iter__(self) -> torch.Tensor: """ if self.shuffle: shuffled_chunk_indices = self.rng.choice( - self.chunk_index_range, - size=self.n_samples_per_worker, + self.valid_date_indices, + size=len(self.valid_date_indices), replace=False, - ) + )[self.chunk_index_range] else: - shuffled_chunk_indices = self.chunk_index_range + shuffled_chunk_indices = self.valid_date_indices[self.chunk_index_range] LOGGER.debug( (