Skip to content

Commit

Permalink
fix: improve various aspects of stream shuffling
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed Feb 19, 2025
1 parent 18b03a3 commit 2b366fc
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 10 deletions.
5 changes: 5 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
- Support packaging with poetry 2.0
- Solve pickling issues with multiprocessing when pytorch is installed
- Allow deep attributes like `a.b.c` for `span_attributes` in Standoff and OMOP doc2dict converters
- Fixed various aspects of stream shuffling:

- Ensure the Parquet reader shuffles the data when `shuffle=True`
- Ensure we don't overwrite the RNG of the data reader when calling `stream.shuffle()` with no seed
- Raise an error if the batch size in `stream.shuffle(batch_size=...)` is not compatible with the stream

# v0.15.0 (2024-12-13)

Expand Down
26 changes: 17 additions & 9 deletions edsnlp/core/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,18 +785,19 @@ def shuffle(
if batch_by is None and batch_size is None:
batch_by = "dataset"
if shuffle_reader is None or shuffle_reader is True:
shuffle_reader = (
possible_shuffle_reader = (
batch_by
if batch_by in self.reader.emitted_sentinels and not self.reader.shuffle
else False
)
if not possible_shuffle_reader and shuffle_reader:
# Should we be more explicit about why we cannot shuffle?
raise ValueError(
"You cannot shuffle the reader given the current stream and the "
f"batching mode {batch_by!r}."
)
shuffle_reader = possible_shuffle_reader
stream = self
# Ensure that we have a "deterministic" random seed, meaning
# if the user sets a global seed earlier in the program and executes the
# same program twice, the shuffling should be the same in both cases.
# This is not guaranteed by just creating random.Random(), which does not
# account for the global seed.
seed = seed if seed is not None else random.getrandbits(32)
if shuffle_reader:
if shuffle_reader not in self.reader.emitted_sentinels:
raise ValueError(f"Cannot shuffle by {shuffle_reader}")
Expand All @@ -807,8 +808,15 @@ def shuffle(
config=stream.config,
)
stream.reader.shuffle = shuffle_reader
stream.reader.rng = random.Random(seed)
if any(not op.elementwise for op in self.ops) or not shuffle_reader:
# Ensure that we have a "deterministic" random seed, meaning
# if the user sets a global seed earlier in the program and executes the
# same program twice, the shuffling should be the same in both cases.
# This is not guaranteed by just creating random.Random(), which does not
# account for the global seed.
if seed is not None:
stream.reader.rng = random.Random(seed)
# Else, if seed is None, then the reader rng stays the same
if any(not op.elementwise for op in self.ops) or shuffle_reader != batch_by:
stream = stream.map_batches(
pipe=shuffle,
batch_size=batch_size,
Expand Down
4 changes: 3 additions & 1 deletion edsnlp/data/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,10 @@ def extract_task(self, item):
def read_records(self) -> Iterable[Any]:
while True:
files = self.fragments
if self.shuffle:
files = shuffle(files, self.rng)
if self.shuffle == "fragment":
for file in shuffle(files, self.rng):
for file in files:
if self.work_unit == "fragment":
yield file
else:
Expand Down

0 comments on commit 2b366fc

Please sign in to comment.