From 24ef0b5e731ce536daef33bb67a748dad9ff6911 Mon Sep 17 00:00:00 2001 From: Simon Lang Date: Wed, 20 Nov 2024 11:37:25 +0000 Subject: [PATCH 1/2] full shuffle of the dataset --- src/anemoi/training/data/dataset.py | 28 +++++++++------------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index 9e368f9c..5d71ffe2 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -153,6 +153,7 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None: low = shard_start + worker_id * self.n_samples_per_worker high = min(shard_start + (worker_id + 1) * self.n_samples_per_worker, shard_end) + self.chunk_index_range = np.arange(low, high, dtype=np.uint32) LOGGER.debug( "Worker %d (pid %d, global_rank %d, model comm group %d) has low/high range %d / %d", @@ -164,27 +165,17 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None: high, ) - self.chunk_index_range = self.valid_date_indices[np.arange(low, high, dtype=np.uint32)] - - # each worker must have a different seed for its random number generator, - # otherwise all the workers will output exactly the same data - # should we check lightning env variable "PL_SEED_WORKERS" here? - # but we alwyas want to seed these anyways ... - base_seed = get_base_seed() - seed = ( - base_seed * (self.model_comm_group_id + 1) - worker_id - ) # note that test, validation etc. datasets get same seed - torch.manual_seed(seed) - random.seed(seed) - self.rng = np.random.default_rng(seed=seed) + torch.manual_seed(base_seed) + random.seed(base_seed) + self.rng = np.random.default_rng(seed=base_seed) sanity_rnd = self.rng.random(1) LOGGER.debug( ( "Worker %d (%s, pid %d, glob. rank %d, model comm group %d, " - "group_rank %d, base_seed %d) using seed %d, sanity rnd %f" + "group_rank %d, base_seed %d), sanity rnd %f" ), worker_id, self.label, @@ -193,7 +184,6 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None: self.model_comm_group_id, self.model_comm_group_rank, base_seed, - seed, sanity_rnd, ) @@ -208,12 +198,12 @@ def __iter__(self) -> torch.Tensor: """ if self.shuffle: shuffled_chunk_indices = self.rng.choice( - self.chunk_index_range, - size=self.n_samples_per_worker, + self.valid_date_indices, + size=len(self.valid_date_indices), replace=False, - ) + )[self.chunk_index_range] else: - shuffled_chunk_indices = self.chunk_index_range + shuffled_chunk_indices = self.valid_date_indices[self.chunk_index_range] LOGGER.debug( ( From 2aeb594f3411a91b1bc1928e4676098c4f39198c Mon Sep 17 00:00:00 2001 From: Simon Lang Date: Wed, 20 Nov 2024 11:47:37 +0000 Subject: [PATCH 2/2] added changelog entry --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index bc9d40f0..432041f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,9 @@ Keep it human-readable, your future self will thank you! ## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.2.2...HEAD) +### Changed +- Perform full shuffle of training dataset [#153](https://github.com/ecmwf/anemoi-training/pull/153) + ### Fixed - Refactored callbacks. [#60](https://github.com/ecmwf/anemoi-training/pulls/60) - Updated docs [#115](https://github.com/ecmwf/anemoi-training/pull/115)