diff --git a/docs/source/fundamentals/sampling.md b/docs/source/fundamentals/sampling.md
index 18d2dae8d..f8b022b1f 100644
--- a/docs/source/fundamentals/sampling.md
+++ b/docs/source/fundamentals/sampling.md
@@ -2,5 +2,5 @@
 You can choose how sampling from your dataset(s) occurs between epochs by specifying the `sampling_method` when instantiating `StreamingDataset`. Currently, this can take on one of two values:
 
--`'balanced'`: (default) Samples are chosen at random from dataset(s) during each epoch according to the proportions specified.
--`'fixed`: The same samples from the dataset(s) are chosen during every epoch, still according to the proportions specified.
+- `'balanced'`: (default) Samples are chosen at random from dataset(s) during each epoch according to the proportions specified.
+- `'fixed'`: The same samples from the dataset(s) are chosen during every epoch, still according to the proportions specified.
 
diff --git a/docs/source/fundamentals/shuffling.md b/docs/source/fundamentals/shuffling.md
index adc5e67a5..4f5cdc57e 100644
--- a/docs/source/fundamentals/shuffling.md
+++ b/docs/source/fundamentals/shuffling.md
@@ -11,7 +11,7 @@ StreamingDataset takes four arguments to directly control shuffling.
 | `shuffle` | `bool = False` | turn shuffling on or off |
 | `shuffle_algo` | `str = 'py1s'` | which shuffling algorithm to use |
 | `shuffle_seed` | `int = 9176` | all randomness in StreamingDataset is derived from this argument |
-| `shuffle_block_size` | `int = 1 << 18` | shuffling unit used by py1b algorithm |
+| `shuffle_block_size` | `int = 1 << 18` | shuffling unit used by py1b and py1br algorithms |
 
 StreamingDataset also takes two other arguments that shuffling interacts with:
 
@@ -34,11 +34,19 @@ Statistically, this algorithm will result in all nodes downloading all shards, w
 
 ### py1b
 
-Globally shuffle shards, divide that sample space over canonical nodes, then shuffle samples in fixed-size blocks (given by `shuffle_block_size`). So named because it shuffles samples in python, once, intra-block.
+Globally shuffle shards, divide that sample space over canonical nodes, then shuffle samples in fixed-size blocks (given by `shuffle_block_size`). So named because it shuffles samples in python, once, intra-block. A canonical node, for the purposes of shuffling, is simply a collection of shards. In order to have determinism with a different number of physical nodes, the shuffle ordering is done over the canonical nodes and these are then assigned to physical nodes.
 
 Shuffle block size should be set larger or much larger than a single shard. If so, this algorithm is useful for spacing out the contents of shards to mitigate a bad or non-existent pre-shuffle (i.e. if samples from the same shard are related in some way).
 
-This algorithm requires more shards to be downloaded and stay resident to make progress than py1s or py2s, noticed as longer start/resume latency, as a multiple of shuffle block size divided by samples per shard. If you see step-like burstiness in throughput, your workers may not be downloading far enough ahead – try raising predownload (it should be scaled with block size). Step size scales with shuffle block size.
+In order to improve shuffle quality, this algorithm requires more shards to be downloaded and stay resident to make progress than py1s or py2s, noticed as longer start/resume latency, as a multiple of shuffle block size divided by samples per shard. If you see step-like burstiness in throughput, your workers may not be downloading far enough ahead – try raising predownload (it should be scaled with block size). Step size scales with shuffle block size.
+
+### py1br
+
+Globally shuffle shards, divide that sample space over canonical nodes, then shuffle samples in variable-size blocks (uniformly selected within the range `[0.75*shuffle_block_size, 1.25*shuffle_block_size)`). Shuffle blocks are also staggered -- along with variable shuffle block size, this works to prevent many simultaneous shard downloads. So named because it shuffles samples in python, once, intra-block, and blocks are randomized.
+
+Shuffle block size should be set larger or much larger than a single shard. If so, this algorithm is useful for spacing out the contents of shards to mitigate a bad or non-existent pre-shuffle (i.e. if samples from the same shard are related in some way).
+
+In order to improve shuffle quality, this algorithm requires more shards to be downloaded and stay resident to make progress than py1s or py2s, noticed as longer start/resume latency, as a multiple of shuffle block size divided by samples per shard. However, shard downloads with py1br are more balanced than with py1b, and this effect is more apparent when training with a higher number of nodes, resulting in fewer network bottlenecks. The shuffle quality of py1br and py1b is similar.
 
 ### py1s
 
diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py
index ed7dea40c..b4e04b5d1 100644
--- a/streaming/base/dataset.py
+++ b/streaming/base/dataset.py
@@ -13,6 +13,7 @@
 from threading import Event, Lock
 from time import sleep, time_ns
 from typing import Any, Dict, Iterator, Optional, Sequence, Tuple, Union
+from warnings import warn
 
 import numpy as np
 from filelock import FileLock
@@ -260,7 +261,8 @@ class StreamingDataset(Array, IterableDataset):
             ``False``.
         shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``.
         shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
-        shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``.
+        shuffle_block_size (int): Unit of shuffle. A canonical node's samples are split into blocks
+            of this size, and samples within each block are shuffled. Defaults to ``1 << 18``.
         sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``.
             Defaults to ``balanced``.
     """
@@ -309,6 +311,13 @@ def __init__(self,
                 f'Invalid sampling method: {sampling_method}. Must be one of `balanced` or `fixed`.'
             )
 
+        # Issue a deprecation warning for the py1b shuffle algorithm.
+        if self.shuffle_algo == 'py1b':
+            warn(
+                'The \'py1b\' shuffle algorithm will soon be deprecated. Please use the more performant \'py1br\' algorithm instead.',
+                DeprecationWarning,
+                stacklevel=2)
+
         # Check that predownload is at least per device batch size.
         if self.predownload is not None and self.batch_size is not None and \
                 self.predownload < self.batch_size:
diff --git a/streaming/base/shuffle/__init__.py b/streaming/base/shuffle/__init__.py
index 286f34bd2..36c6499f7 100644
--- a/streaming/base/shuffle/__init__.py
+++ b/streaming/base/shuffle/__init__.py
@@ -8,11 +8,13 @@
 from streaming.base.shuffle.naive import get_shuffle_naive
 from streaming.base.shuffle.py1b import get_shuffle_py1b
+from streaming.base.shuffle.py1br import get_shuffle_py1br
 from streaming.base.shuffle.py1s import get_shuffle_py1s
 from streaming.base.shuffle.py2s import get_shuffle_py2s
 
 algos = {
     'py1b': get_shuffle_py1b,
+    'py1br': get_shuffle_py1br,
     'py1s': get_shuffle_py1s,
     'py2s': get_shuffle_py2s,
     'naive': get_shuffle_naive,
diff --git a/streaming/base/shuffle/py1b.py b/streaming/base/shuffle/py1b.py
index a47211e36..606e28548 100644
--- a/streaming/base/shuffle/py1b.py
+++ b/streaming/base/shuffle/py1b.py
@@ -43,6 +43,8 @@ def get_shuffle_py1b(shard_sizes: NDArray[np.int64],
         num_samples += shard_size
 
     # Generate the initial ordering of shards, which is fixed over an entire training run.
+    # Because the ordering of shards is fixed, the shards downloaded during the first epoch
+    # can be persisted and reused by each node in subsequent epochs as well.
     run_rng = np.random.default_rng(seed)
     run_rng.shuffle(spans)
 
@@ -59,12 +61,16 @@
     # Populate the global sample ID mapping, shuffling within each block within each super-span.
     ids = np.empty(num_samples, np.int64)
     offset = 0
+    # loop over each canonical node
     for super_begin, super_end in super_spans:
+        # super_offset is the offset of the first sample in the canonical node
         super_offset = offset
+        # loop over each span contained in the canonical node
         for begin, end in spans[super_begin:super_end]:
             span_size = end - begin
             ids[offset:offset + span_size] = np.arange(begin, end)
             offset += span_size
+        # shuffle within each block, but don't shuffle past the canonical node boundary
         for start in range(super_offset, offset, block_size):
             stop = min(start + block_size, offset)
             epoch_rng.shuffle(ids[start:stop])
diff --git a/streaming/base/shuffle/py1br.py b/streaming/base/shuffle/py1br.py
new file mode 100644
index 000000000..04404cdd8
--- /dev/null
+++ b/streaming/base/shuffle/py1br.py
@@ -0,0 +1,91 @@
+# Copyright 2023 MosaicML Streaming authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Shuffling algorithm that shuffles in variable-size blocks.
+
+These units are presumably larger or much larger than single shards, leading to better shuffledness
+at the cost of having to download more shards to make progress.
+"""
+
+import numpy as np
+from numpy.typing import NDArray
+
+from streaming.base.shuffle.py1s import divide_spans
+
+
+def get_shuffle_py1br(shard_sizes: NDArray[np.int64],
+                      num_canonical_nodes: int,
+                      seed: int,
+                      epoch: int,
+                      block_size: int = 1 << 18) -> NDArray[np.int64]:
+    """Get the shuffled global ordering of samples for an epoch.
+
+    The assignment of shards to nodes is fixed across epochs, but each grouping of shards is
+    processed concurrently in a different order by each node's workers each epoch.
+
+    Args:
+        shard_sizes (NDArray[np.int64]): Number of samples contained in each shard, in order.
+        num_canonical_nodes (int): Number of canonical nodes.
+        seed (int): Base random seed, which is held constant over an entire training run.
+        epoch (int): Current epoch, which is added to the seed to get a different deterministic
+            shuffle each epoch.
+        block_size (int): Unit of shuffle. For the py1br shuffling method, the block size is chosen
+            uniformly at random in the range [0.75*block_size, 1.25*block_size). Defaults to ``1 << 18``.
+
+    Returns:
+        NDArray[np.int64]: 1:1 mapping of sample ID to shuffled sample ID.
+    """
+    # Create each shard's sample ID span (start, stop excl).
+    spans = []
+    num_samples = 0
+    for shard_size in shard_sizes:
+        span = num_samples, num_samples + shard_size
+        spans.append(span)
+        num_samples += shard_size
+
+    # Generate the initial ordering of shards, which is fixed over an entire training run.
+    run_rng = np.random.default_rng(seed)
+    run_rng.shuffle(spans)
+
+    # Break the shard spans at canonical node boundaries.
+    spans, node_spans = divide_spans(spans, num_samples, num_canonical_nodes)
+
+    # Shuffle the span ordering within each canonical node uniquely to this epoch.
+    epoch_rng = np.random.default_rng(seed + epoch)
+    for node_start_span, node_stop_span in node_spans:
+        node_span = spans[node_start_span:node_stop_span]
+        epoch_rng.shuffle(node_span)  # pyright: ignore
+        spans[node_start_span:node_stop_span] = node_span
+
+    # Populate the global sample ID mapping, shuffling within each block within each node.
+    ids = np.empty(num_samples, np.int64)
+    node_stop_sample = 0
+    stagger = epoch_rng.integers(0, int(0.75 * block_size), (num_canonical_nodes,))
+    for node, (node_start_span, node_stop_span) in enumerate(node_spans):
+        node_start_sample = node_stop_sample
+
+        # Populate sample IDs given the span ordering for this node.
+        for span_start_sample, span_stop_sample in spans[node_start_span:node_stop_span]:
+            span_size = span_stop_sample - span_start_sample
+            ids[node_stop_sample:node_stop_sample + span_size] = \
+                np.arange(span_start_sample, span_stop_sample)
+            node_stop_sample += span_size
+
+        # Get randomized and staggered block ranges for the current node.
+        block_staggered_ranges = []
+        blocks_end = node_start_sample
+        node_stagger = stagger[node]
+        while blocks_end < node_stop_sample:
+            rand_block_size = epoch_rng.integers(int(0.75 * block_size), int(1.25 * block_size))
+            # don't want the block to start before the first sample of the node
+            staggered_block_start = max(blocks_end - node_stagger, node_start_sample)
+            # don't want the block to stop after the last sample of the node
+            staggered_block_stop = min(blocks_end + rand_block_size - node_stagger,
+                                       node_stop_sample)
+            block_staggered_ranges.append((staggered_block_start, staggered_block_stop))
+            blocks_end += staggered_block_stop - staggered_block_start
+
+        # Shuffle within each staggered, randomized block.
+        for block_start, block_stop in block_staggered_ranges:
+            epoch_rng.shuffle(ids[block_start:block_stop])
+
+    return ids
diff --git a/streaming/base/shuffle/py1s.py b/streaming/base/shuffle/py1s.py
index 6a99584c7..b491c4a03 100644
--- a/streaming/base/shuffle/py1s.py
+++ b/streaming/base/shuffle/py1s.py
@@ -38,19 +38,28 @@ def divide_spans(spans: List[Tuple[int, int]], num_samples: int, num_parts: int)
     super_spans = []
 
     for part in range(num_parts):
+        # note that the size of a part (canonical node) is num_samples // num_parts.
        part_end = num_samples * (part + 1) // num_parts
 
+        # loop over spans until we've filled up our part (canonical node) completely
        while True:
            if span_index == len(spans):
                break
 
+            # input spans are the shard spans. these can be unequally sized and may cross
+            # part (canonical node) boundaries.
            span = spans[span_index]
+            # spans are (begin, end excl)
            samples_this_span = span[1] - span[0]
+            # check if the shard span contains more samples than the part (canonical node) can fit
            if part_end < samples_so_far + samples_this_span:
+                # if there is space left in the part, split the span
                if samples_so_far < part_end:
                    split = part_end - samples_so_far
+                    # create a span, filling up with as many samples as possible from shard span
                    new_span = span[0], span[0] + split
                    out_spans.append(new_span)
+                    # modify the old shard span to reflect that it's been split
                    spans[span_index] = span[0] + split, span[1]
                    samples_so_far += split
                break
@@ -59,6 +68,8 @@ def divide_spans(spans: List[Tuple[int, int]], num_samples: int, num_parts: int)
            span_index += 1
            samples_so_far += samples_this_span
 
+    # super spans tell us which new spans belong to each part (canonical node)
+    # as a tuple of (begin span index, end span index excl)
        super_span = begin_part, len(out_spans)
        super_spans.append(super_span)
        begin_part = len(out_spans)
diff --git a/tests/test_shuffle.py b/tests/test_shuffle.py
index b9053a351..56e9a041b 100644
--- a/tests/test_shuffle.py
+++ b/tests/test_shuffle.py
@@ -5,17 +5,19 @@
 
 import numpy as np
 
-from streaming.base.shuffle import get_shuffle_py1s, get_shuffle_py2s
+from streaming.base.shuffle import (get_shuffle_py1b, get_shuffle_py1br, get_shuffle_py1s,
+                                    get_shuffle_py2s)
 
 
 def check(get_shuffle: Callable) -> None:
     shard_sizes = 1 + np.arange(100)
     dataset_size = sum(shard_sizes)
+    block_size = 300
     for num_canonical_nodes in [1, 2, 3]:
         for seed in [0, 1, 2]:
             lists = []
             for epoch in [0, 1, 2]:
-                ids = get_shuffle(shard_sizes, num_canonical_nodes, seed, epoch)
+                ids = get_shuffle(shard_sizes, num_canonical_nodes, seed, epoch, block_size)
                 assert sorted(ids) == list(range(len(ids)))
                 parts = []
                 for i in range(num_canonical_nodes):
@@ -30,6 +32,14 @@ def check(get_shuffle: Callable) -> None:
            assert parts[0] == parts[i]
 
+
+def test_shuffle_py1b():
+    check(get_shuffle_py1b)
+
+
+def test_shuffle_py1br():
+    check(get_shuffle_py1br)
+
+
 def test_shuffle_py1s():
     check(get_shuffle_py1s)
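
Below is a minimal sketch (not part of this diff) of how the new `py1br` algorithm and the shuffling arguments documented above might be selected when constructing a `StreamingDataset`; the remote/local paths and batch size are placeholder assumptions.

```python
from streaming import StreamingDataset

# Placeholder locations -- substitute your own dataset paths.
dataset = StreamingDataset(
    remote='s3://my-bucket/my-dataset',  # hypothetical remote path
    local='/tmp/my-dataset',             # hypothetical local cache directory
    batch_size=32,                       # per-device batch size (placeholder)
    shuffle=True,                        # turn shuffling on
    shuffle_algo='py1br',                # block-randomized shuffle added in this change
    shuffle_seed=9176,                   # default seed from the docs table
    shuffle_block_size=1 << 18,          # shuffle unit; set larger than a single shard
    sampling_method='balanced',          # default sampling between epochs
)
```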
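
To make the `divide_spans` comments above concrete, here is a small hand-checked sketch (based on the logic shown in this diff, not taken from it) of how shard spans are split at part (canonical node) boundaries:

```python
from streaming.base.shuffle.py1s import divide_spans

# Three shards of 4 samples each, as (begin, end excl) sample ID spans.
spans = [(0, 4), (4, 8), (8, 12)]
num_samples = 12

# Divide over 2 parts (canonical nodes); each part holds 12 // 2 = 6 samples.
new_spans, super_spans = divide_spans(spans, num_samples, 2)

# The middle shard straddles the boundary at sample 6, so it is split into (4, 6) and (6, 8).
print(new_spans)    # [(0, 4), (4, 6), (6, 8), (8, 12)]

# Each super span is (begin span index, end span index excl) into new_spans, one per part.
print(super_spans)  # [(0, 2), (2, 4)]
```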
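
Mirroring `check()` in the updated test above, a quick sanity check (assumed usage) that a py1br shuffle is a permutation of all sample IDs for a given epoch:

```python
import numpy as np

from streaming.base.shuffle import get_shuffle_py1br

shard_sizes = 1 + np.arange(100)  # 100 shards with sizes 1..100
ids = get_shuffle_py1br(shard_sizes, num_canonical_nodes=2, seed=0, epoch=0, block_size=300)

# Every sample ID appears exactly once in the shuffled ordering.
assert sorted(ids) == list(range(len(ids)))
```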