Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
full shuffle of the dataset (#153)
Browse files Browse the repository at this point in the history
* full shuffle of the dataset

* added changelog entry

---------

Co-authored-by: Ana Prieto Nemesio <91897203+anaprietonem@users.noreply.github.com>
  • Loading branch information
ssmmnn11 and anaprietonem authored Nov 29, 2024
1 parent 6b45f8d commit 7e363fa
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 19 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ Keep it human-readable, your future self will thank you!

## [0.3.1 - AIFS v0.3 Compatibility](https://github.com/ecmwf/anemoi-training/compare/0.3.0...0.3.1) - 2024-11-28

### Changed
- Perform full shuffle of training dataset [#153](https://github.com/ecmwf/anemoi-training/pull/153)

### Fixed

- Update `n_pixel` used by datashader to better adapt across resolutions #152
Expand Down
28 changes: 9 additions & 19 deletions src/anemoi/training/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None:

low = shard_start + worker_id * self.n_samples_per_worker
high = min(shard_start + (worker_id + 1) * self.n_samples_per_worker, shard_end)
self.chunk_index_range = np.arange(low, high, dtype=np.uint32)

LOGGER.debug(
"Worker %d (pid %d, global_rank %d, model comm group %d) has low/high range %d / %d",
Expand All @@ -212,27 +213,17 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None:
high,
)

self.chunk_index_range = self.valid_date_indices[np.arange(low, high, dtype=np.uint32)]

# each worker must have a different seed for its random number generator,
# otherwise all the workers will output exactly the same data
# should we check lightning env variable "PL_SEED_WORKERS" here?
# but we alwyas want to seed these anyways ...

base_seed = get_base_seed()

seed = (
base_seed * (self.model_comm_group_id + 1) - worker_id
) # note that test, validation etc. datasets get same seed
torch.manual_seed(seed)
random.seed(seed)
self.rng = np.random.default_rng(seed=seed)
torch.manual_seed(base_seed)
random.seed(base_seed)
self.rng = np.random.default_rng(seed=base_seed)
sanity_rnd = self.rng.random(1)

LOGGER.debug(
(
"Worker %d (%s, pid %d, glob. rank %d, model comm group %d, "
"group_rank %d, base_seed %d) using seed %d, sanity rnd %f"
"group_rank %d, base_seed %d), sanity rnd %f"
),
worker_id,
self.label,
Expand All @@ -241,7 +232,6 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None:
self.model_comm_group_id,
self.model_comm_group_rank,
base_seed,
seed,
sanity_rnd,
)

Expand All @@ -256,12 +246,12 @@ def __iter__(self) -> torch.Tensor:
"""
if self.shuffle:
shuffled_chunk_indices = self.rng.choice(
self.chunk_index_range,
size=self.n_samples_per_worker,
self.valid_date_indices,
size=len(self.valid_date_indices),
replace=False,
)
)[self.chunk_index_range]
else:
shuffled_chunk_indices = self.chunk_index_range
shuffled_chunk_indices = self.valid_date_indices[self.chunk_index_range]

LOGGER.debug(
(
Expand Down

0 comments on commit 7e363fa

Please sign in to comment.