From 230379d5ed83b1c6a57c8e150bd224855fb5f769 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Mon, 21 Aug 2023 14:45:51 -0700 Subject: [PATCH 1/6] py1g in progress --- streaming/base/shuffle/__init__.py | 3 +- streaming/base/shuffle/py1g.py | 125 +++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+), 1 deletion(-) create mode 100644 streaming/base/shuffle/py1g.py diff --git a/streaming/base/shuffle/__init__.py b/streaming/base/shuffle/__init__.py index 286f34bd2..3a505567b 100644 --- a/streaming/base/shuffle/__init__.py +++ b/streaming/base/shuffle/__init__.py @@ -10,15 +10,16 @@ from streaming.base.shuffle.py1b import get_shuffle_py1b from streaming.base.shuffle.py1s import get_shuffle_py1s from streaming.base.shuffle.py2s import get_shuffle_py2s +from streaming.base.shuffle.py1g import get_shuffle_py1g algos = { 'py1b': get_shuffle_py1b, 'py1s': get_shuffle_py1s, 'py2s': get_shuffle_py2s, + 'py1g': get_shuffle_py1g, 'naive': get_shuffle_naive, } - def get_shuffle(algo: str, shard_sizes: NDArray[np.int64], num_canonical_nodes: int, diff --git a/streaming/base/shuffle/py1g.py b/streaming/base/shuffle/py1g.py new file mode 100644 index 000000000..4c4f98956 --- /dev/null +++ b/streaming/base/shuffle/py1g.py @@ -0,0 +1,125 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Shuffling algorithm that shuffles intra-shard in one place. + +This algorithm is roughly twice as fast as algorithm ``py2s``, and ever so slightly biased. + +Bias in this case merely refers to how we assign samples when we split shards at canonical node +boundaries, which is non-random in this algorithm. In practice, we found this does not matter to +convergence, while making us faster. +""" + +from typing import List, Tuple + +import numpy as np +from numpy.typing import NDArray + +from streaming.base.shuffle.py1s import divide_spans + + +def get_shuffle_py1g(shard_sizes: NDArray[np.int64], + num_canonical_nodes: int, + seed: int, + epoch: int, + block_size: int = 1 << 18) -> NDArray[np.int64]: + """Get the shuffled global ordering of samples for an epoch. + + The assignment of shards to nodes is fixed across epochs, but each grouping of shards is + processed concurrently in a different order by each node's workers each epoch. + + Args: + shard_sizes (NDArray[np.int64]): Number of samples contained in each shard, in order. + num_canonical_nodes (int): Number of canonical nodes. + seed (int): Base random seed, which is held constant over an entire training run. + epoch (int): Current epoch, which is added to the seed to get a different deterministic + shuffle each epoch. + block_size (int): Unit of shuffle, used to set the std and clip length for the gaussian + noise to be added to each shard. Defaults to ``1 << 18``. + + Returns: + NDArray[np.int64]: 1:1 mapping of sample ID to shuffled sample ID. + """ + # Create each shard's sample ID span (begin, end excl). + # also get max shard size to calculate the + spans = [] + num_samples = 0 + max_shard_size = 0 + for shard_size in shard_sizes: + span = num_samples, num_samples + shard_size + spans.append(span) + num_samples += shard_size + if shard_size > max_shard_size: + max_shard_size = shard_size + + # Generate the initial ordering of shards, which is fixed over an entire training run. + run_rng = np.random.default_rng(seed) + run_rng.shuffle(spans) + + # Break the shard spans at canonical node boundaries. + # super_spans are the indices of spans that correspond to each canonical node + spans, super_spans = divide_spans(spans, num_samples, num_canonical_nodes) + + # Shuffle the span ordering within each canonical node uniquely to this epoch. + epoch_rng = np.random.default_rng(seed + epoch) + for begin, end in super_spans: + # retrieving the spans (shard parts) associated with this canonical node + part = spans[begin:end] + epoch_rng.shuffle(part) # pyright: ignore + spans[begin:end] = part + + # Populate the global sample ID mapping, shuffling within each span. + ids = np.empty(num_samples, np.int64) + offset = 0 + # iterate through each canonical node's spans because we don't want samples crossing canonical node boundaries + for cn_begin, cn_end in super_spans: + cn_spans = spans[cn_begin:cn_end] + cn_span_sizes = np.array([end - begin for begin, end in cn_spans]) + num_cn_samples = cn_span_sizes.sum() + # the spans of a canonical node are shuffled, so they have sample ids that are + # not contiguous. need to get the correct sample ids for the current canonical node + cn_samples = np.empty(num_cn_samples, np.int64) + samples_inserted = 0 + for begin, end in cn_spans: + # insert span samples into cn_samples array + cn_span_samples = np.arange(begin, end) + epoch_rng.shuffle(cn_span_samples) + cn_samples[samples_inserted:samples_inserted + (end - begin)] = cn_span_samples + + cn_sample_idxs = np.arange(num_cn_samples) + + # iterate over each span and shift sample indices by gaussian noise + cn_sample_offset = 0 + shifted_samples = cn_sample_idxs.copy().astype(np.float64) + for span_size in cn_span_sizes: + + span_std = (block_size * span_size) / (max_shard_size * 3) + # ~0.3% of samples will be clipped and stay where they are + cutoff = 3*span_std + + # sample shifts from gaussian + shifts = epoch_rng.normal(loc=0, scale=span_std, size=span_size) + # if shift is greater than cutoff (outlier), set shift to 0 + shifts = shifts * (np.absolute(shifts) < cutoff) + + # add shifts to shard samples + shifted_samples[cn_sample_offset:cn_sample_offset+span_size] += shifts + + # update offset for next shard + cn_sample_offset += span_size + + # get incides that would sort the shifted_samples array + sort_indices = np.argsort(shifted_samples) + + # apply the sorting to the canonical node sample indices + cn_sample_idxs = cn_sample_idxs[sort_indices] + + # use the shuffled indices to get the shuffled sample ids for the canonical node + cn_samples = cn_samples[cn_sample_idxs] + + # assign the gaussian "shuffled" samples to the global ids array + ids[offset:offset + num_cn_samples] = cn_samples + + offset += num_cn_samples + + return ids From f09986ac99cf61ddadf42459b5b65cbec7b0bd1d Mon Sep 17 00:00:00 2001 From: Saaketh Date: Mon, 21 Aug 2023 18:01:40 -0700 Subject: [PATCH 2/6] py1e algo implementation --- streaming/base/shuffle/__init__.py | 4 ++-- streaming/base/shuffle/{py1g.py => py1e.py} | 20 +++++++++----------- 2 files changed, 11 insertions(+), 13 deletions(-) rename streaming/base/shuffle/{py1g.py => py1e.py} (89%) diff --git a/streaming/base/shuffle/__init__.py b/streaming/base/shuffle/__init__.py index 3a505567b..48a794dc0 100644 --- a/streaming/base/shuffle/__init__.py +++ b/streaming/base/shuffle/__init__.py @@ -10,13 +10,13 @@ from streaming.base.shuffle.py1b import get_shuffle_py1b from streaming.base.shuffle.py1s import get_shuffle_py1s from streaming.base.shuffle.py2s import get_shuffle_py2s -from streaming.base.shuffle.py1g import get_shuffle_py1g +from streaming.base.shuffle.py1e import get_shuffle_py1e algos = { 'py1b': get_shuffle_py1b, 'py1s': get_shuffle_py1s, 'py2s': get_shuffle_py2s, - 'py1g': get_shuffle_py1g, + 'py1g': get_shuffle_py1e, 'naive': get_shuffle_naive, } diff --git a/streaming/base/shuffle/py1g.py b/streaming/base/shuffle/py1e.py similarity index 89% rename from streaming/base/shuffle/py1g.py rename to streaming/base/shuffle/py1e.py index 4c4f98956..d04997b86 100644 --- a/streaming/base/shuffle/py1g.py +++ b/streaming/base/shuffle/py1e.py @@ -10,15 +10,13 @@ convergence, while making us faster. """ -from typing import List, Tuple - import numpy as np from numpy.typing import NDArray from streaming.base.shuffle.py1s import divide_spans -def get_shuffle_py1g(shard_sizes: NDArray[np.int64], +def get_shuffle_py1e(shard_sizes: NDArray[np.int64], num_canonical_nodes: int, seed: int, epoch: int, @@ -78,13 +76,14 @@ def get_shuffle_py1g(shard_sizes: NDArray[np.int64], num_cn_samples = cn_span_sizes.sum() # the spans of a canonical node are shuffled, so they have sample ids that are # not contiguous. need to get the correct sample ids for the current canonical node - cn_samples = np.empty(num_cn_samples, np.int64) + cn_samples = np.empty(num_cn_samples) samples_inserted = 0 for begin, end in cn_spans: # insert span samples into cn_samples array cn_span_samples = np.arange(begin, end) epoch_rng.shuffle(cn_span_samples) cn_samples[samples_inserted:samples_inserted + (end - begin)] = cn_span_samples + samples_inserted += (end - begin) cn_sample_idxs = np.arange(num_cn_samples) @@ -93,14 +92,13 @@ def get_shuffle_py1g(shard_sizes: NDArray[np.int64], shifted_samples = cn_sample_idxs.copy().astype(np.float64) for span_size in cn_span_sizes: - span_std = (block_size * span_size) / (max_shard_size * 3) - # ~0.3% of samples will be clipped and stay where they are - cutoff = 3*span_std + # cutoff is (block_size - span_size)/2, so the span samples + # are only found in a range of size block_size + cutoff = (block_size - span_size)/2 - # sample shifts from gaussian - shifts = epoch_rng.normal(loc=0, scale=span_std, size=span_size) - # if shift is greater than cutoff (outlier), set shift to 0 - shifts = shifts * (np.absolute(shifts) < cutoff) + # sample shifts from uniform distribution + #shifts = epoch_rng.normal(loc=0, scale=span_std, size=span_size) + shifts = epoch_rng.uniform(low=-cutoff, high=cutoff, size=span_size) # add shifts to shard samples shifted_samples[cn_sample_offset:cn_sample_offset+span_size] += shifts From dc1dac5788c2de867296050023fac382296f2f7a Mon Sep 17 00:00:00 2001 From: Saaketh Date: Mon, 21 Aug 2023 18:03:01 -0700 Subject: [PATCH 3/6] py1e algo implementation --- streaming/base/shuffle/__init__.py | 3 ++- streaming/base/shuffle/py1e.py | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/streaming/base/shuffle/__init__.py b/streaming/base/shuffle/__init__.py index 48a794dc0..9385219ca 100644 --- a/streaming/base/shuffle/__init__.py +++ b/streaming/base/shuffle/__init__.py @@ -8,9 +8,9 @@ from streaming.base.shuffle.naive import get_shuffle_naive from streaming.base.shuffle.py1b import get_shuffle_py1b +from streaming.base.shuffle.py1e import get_shuffle_py1e from streaming.base.shuffle.py1s import get_shuffle_py1s from streaming.base.shuffle.py2s import get_shuffle_py2s -from streaming.base.shuffle.py1e import get_shuffle_py1e algos = { 'py1b': get_shuffle_py1b, @@ -20,6 +20,7 @@ 'naive': get_shuffle_naive, } + def get_shuffle(algo: str, shard_sizes: NDArray[np.int64], num_canonical_nodes: int, diff --git a/streaming/base/shuffle/py1e.py b/streaming/base/shuffle/py1e.py index d04997b86..d2efdb306 100644 --- a/streaming/base/shuffle/py1e.py +++ b/streaming/base/shuffle/py1e.py @@ -74,7 +74,7 @@ def get_shuffle_py1e(shard_sizes: NDArray[np.int64], cn_spans = spans[cn_begin:cn_end] cn_span_sizes = np.array([end - begin for begin, end in cn_spans]) num_cn_samples = cn_span_sizes.sum() - # the spans of a canonical node are shuffled, so they have sample ids that are + # the spans of a canonical node are shuffled, so they have sample ids that are # not contiguous. need to get the correct sample ids for the current canonical node cn_samples = np.empty(num_cn_samples) samples_inserted = 0 @@ -94,18 +94,18 @@ def get_shuffle_py1e(shard_sizes: NDArray[np.int64], # cutoff is (block_size - span_size)/2, so the span samples # are only found in a range of size block_size - cutoff = (block_size - span_size)/2 + cutoff = (block_size - span_size) / 2 # sample shifts from uniform distribution #shifts = epoch_rng.normal(loc=0, scale=span_std, size=span_size) shifts = epoch_rng.uniform(low=-cutoff, high=cutoff, size=span_size) # add shifts to shard samples - shifted_samples[cn_sample_offset:cn_sample_offset+span_size] += shifts + shifted_samples[cn_sample_offset:cn_sample_offset + span_size] += shifts # update offset for next shard cn_sample_offset += span_size - + # get incides that would sort the shifted_samples array sort_indices = np.argsort(shifted_samples) From 2b6f79b59cccc84df47a0da060fe5fc69505204f Mon Sep 17 00:00:00 2001 From: Saaketh Date: Tue, 22 Aug 2023 23:15:31 -0700 Subject: [PATCH 4/6] added py1e algorithm -- extended range --- docs/source/fundamentals/sampling.md | 4 ++-- docs/source/fundamentals/shuffling.md | 8 ++++++++ streaming/base/shuffle/__init__.py | 2 +- streaming/base/shuffle/py1e.py | 20 +++++++++----------- tests/test_shuffle.py | 18 +++++++++++++++++- 5 files changed, 37 insertions(+), 15 deletions(-) diff --git a/docs/source/fundamentals/sampling.md b/docs/source/fundamentals/sampling.md index 18d2dae8d..f8b022b1f 100644 --- a/docs/source/fundamentals/sampling.md +++ b/docs/source/fundamentals/sampling.md @@ -2,5 +2,5 @@ You can choose how sampling from your dataset(s) occurs between epochs by specifying the `sampling_method` when instantiating `StreamingDataset`. Currently, this can take on one of two values: --`'balanced'`: (default) Samples are chosen at random from dataset(s) during each epoch according to the proportions specified. --`'fixed`: The same samples from the dataset(s) are chosen during every epoch, still according to the proportions specified. +- `'balanced'`: (default) Samples are chosen at random from dataset(s) during each epoch according to the proportions specified. +- `'fixed`: The same samples from the dataset(s) are chosen during every epoch, still according to the proportions specified. diff --git a/docs/source/fundamentals/shuffling.md b/docs/source/fundamentals/shuffling.md index adc5e67a5..6a337cb61 100644 --- a/docs/source/fundamentals/shuffling.md +++ b/docs/source/fundamentals/shuffling.md @@ -40,6 +40,14 @@ Shuffle block size should be set larger or much larger than a single shard. If s This algorithm requires more shards to be downloaded and stay resident to make progress than py1s or py2s, noticed as longer start/resume latency, as a multiple of shuffle block size divided by samples per shard. If you see step-like burstiness in throughput, your workers may not be downloading far enough ahead – try raising predownload (it should be scaled with block size). Step size scales with shuffle block size. +### py1e + +Globally shuffle shards, divide that sample space over canonical nodes, shuffle the samples in each shard, then randomly distribute the samples from each shard over a larger range (given by `shuffle_block_size`). So named because it shuffles samples by extending the range of each shard, in python. + +Shuffle block size should be set larger or much larger than the number of samples in a single shard. This algorithm provides guaranteed bounds on the range that samples from a shard can appear, allowing for a lower cache limit without decreasing throughput compared to py1b. + +This algorithm requires more shards to be downloaded and stay resident to make progress than py1s or py2s, noticed as longer start/resume latency, as a multiple of shuffle block size divided by samples per shard, similar to py1b. However, these shards will be downloaded in a more balanced fashion, reducing network bandwidth bottlenecks. + ### py1s Globally shuffle shards, divide that sample space over canonical nodes, then shuffle the samples within each shard or shard part. So named because it shuffles samples in python, once, intra-shard. diff --git a/streaming/base/shuffle/__init__.py b/streaming/base/shuffle/__init__.py index 9385219ca..4515f5071 100644 --- a/streaming/base/shuffle/__init__.py +++ b/streaming/base/shuffle/__init__.py @@ -14,9 +14,9 @@ algos = { 'py1b': get_shuffle_py1b, + 'py1e': get_shuffle_py1e, 'py1s': get_shuffle_py1s, 'py2s': get_shuffle_py2s, - 'py1g': get_shuffle_py1e, 'naive': get_shuffle_naive, } diff --git a/streaming/base/shuffle/py1e.py b/streaming/base/shuffle/py1e.py index d2efdb306..05545c570 100644 --- a/streaming/base/shuffle/py1e.py +++ b/streaming/base/shuffle/py1e.py @@ -85,20 +85,21 @@ def get_shuffle_py1e(shard_sizes: NDArray[np.int64], cn_samples[samples_inserted:samples_inserted + (end - begin)] = cn_span_samples samples_inserted += (end - begin) - cn_sample_idxs = np.arange(num_cn_samples) - # iterate over each span and shift sample indices by gaussian noise cn_sample_offset = 0 - shifted_samples = cn_sample_idxs.copy().astype(np.float64) + shifted_samples = np.arange(num_cn_samples).astype(np.float64) for span_size in cn_span_sizes: # cutoff is (block_size - span_size)/2, so the span samples - # are only found in a range of size block_size + # are only found in a range of maximum possible size block_size cutoff = (block_size - span_size) / 2 + # make sure the lower bound doesn't cross the start of the canonical node + lower_bound = max(-cutoff, -cn_sample_offset) + # make sure the upper bound doesn't cross the end of the canonical node + upper_bound = min(cutoff, num_cn_samples - cn_sample_offset - span_size) # sample shifts from uniform distribution - #shifts = epoch_rng.normal(loc=0, scale=span_std, size=span_size) - shifts = epoch_rng.uniform(low=-cutoff, high=cutoff, size=span_size) + shifts = epoch_rng.uniform(low=lower_bound, high=upper_bound, size=span_size) # add shifts to shard samples shifted_samples[cn_sample_offset:cn_sample_offset + span_size] += shifts @@ -109,11 +110,8 @@ def get_shuffle_py1e(shard_sizes: NDArray[np.int64], # get incides that would sort the shifted_samples array sort_indices = np.argsort(shifted_samples) - # apply the sorting to the canonical node sample indices - cn_sample_idxs = cn_sample_idxs[sort_indices] - - # use the shuffled indices to get the shuffled sample ids for the canonical node - cn_samples = cn_samples[cn_sample_idxs] + # apply the sorting to the samples for our canonical node + cn_samples = cn_samples[sort_indices] # assign the gaussian "shuffled" samples to the global ids array ids[offset:offset + num_cn_samples] = cn_samples diff --git a/tests/test_shuffle.py b/tests/test_shuffle.py index b9053a351..cce629260 100644 --- a/tests/test_shuffle.py +++ b/tests/test_shuffle.py @@ -5,7 +5,8 @@ import numpy as np -from streaming.base.shuffle import get_shuffle_py1s, get_shuffle_py2s +from streaming.base.shuffle import (get_shuffle_py1b, get_shuffle_py1e, get_shuffle_py1s, + get_shuffle_py2s) def check(get_shuffle: Callable) -> None: @@ -13,23 +14,38 @@ def check(get_shuffle: Callable) -> None: dataset_size = sum(shard_sizes) for num_canonical_nodes in [1, 2, 3]: for seed in [0, 1, 2]: + # lists is the list of sorted ids seen by every canonical node in every epoch + # for example: [[epoch0_CN_a, epoch0_CN_b], [epoch1_CN_a, epoch1_CN_b], [epoch2_CN_a, epoch2_CN_b]]] lists = [] for epoch in [0, 1, 2]: ids = get_shuffle(shard_sizes, num_canonical_nodes, seed, epoch) assert sorted(ids) == list(range(len(ids))) + # parts is a list of the sorted ids seen by each canonical node in a particular epoch parts = [] for i in range(num_canonical_nodes): begin = dataset_size * i // num_canonical_nodes end = dataset_size * (i + 1) // num_canonical_nodes + # get the section of ids corresponding to this canonical node part = ids[begin:end] parts.append(sorted(part)) lists.append(parts) + # want to make sure the sample ids seen by each canonical node in each epoch is the same lists = list(zip(*lists)) + # each element of lists is now a tuple containing the lists of samples seen by a canonical node over all the epochs for parts in lists: + # make sure all other epochs are the same as epoch 0. for i in range(1, len(parts)): assert parts[0] == parts[i] +def test_shuffle_py1b(): + check(get_shuffle_py1b) + + +def test_shuffle_py1e(): + check(get_shuffle_py1e) + + def test_shuffle_py1s(): check(get_shuffle_py1s) From d580d0fec2e4afeb0edd0c023d5cf4f27df0bb86 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Tue, 29 Aug 2023 11:26:16 -0700 Subject: [PATCH 5/6] merged with main --- tests/test_shuffle.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_shuffle.py b/tests/test_shuffle.py index 954f62378..a3d70220a 100644 --- a/tests/test_shuffle.py +++ b/tests/test_shuffle.py @@ -5,7 +5,8 @@ import numpy as np -from streaming.base.shuffle import get_shuffle_py1b, get_shuffle_py1br, get_shuffle_py1e, get_shuffle_py1s, get_shuffle_py2s +from streaming.base.shuffle import (get_shuffle_py1b, get_shuffle_py1br, get_shuffle_py1e, + get_shuffle_py1s, get_shuffle_py2s) def check(get_shuffle: Callable) -> None: @@ -41,14 +42,18 @@ def check(get_shuffle: Callable) -> None: def test_shuffle_py1b(): check(get_shuffle_py1b) + def test_shuffle_py1br(): check(get_shuffle_py1br) + def test_shuffle_py1e(): check(get_shuffle_py1e) + def test_shuffle_py1s(): check(get_shuffle_py1s) + def test_shuffle_py2s(): check(get_shuffle_py2s) From d4220041a996b2a5e54fd767470c2a4b33abf954 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Tue, 29 Aug 2023 16:49:54 -0700 Subject: [PATCH 6/6] fixed comments, nits, etc --- docs/source/fundamentals/shuffling.md | 2 +- streaming/base/shuffle/py1b.py | 8 ++-- streaming/base/shuffle/py1br.py | 4 +- streaming/base/shuffle/py1e.py | 55 ++++++++++++--------------- 4 files changed, 31 insertions(+), 38 deletions(-) diff --git a/docs/source/fundamentals/shuffling.md b/docs/source/fundamentals/shuffling.md index db70f31a2..f64ce99ff 100644 --- a/docs/source/fundamentals/shuffling.md +++ b/docs/source/fundamentals/shuffling.md @@ -50,7 +50,7 @@ In order to improve shuffle quality, this algorithm requires more shards to be d ### py1e -Globally shuffle shards, divide that sample space over canonical nodes, shuffle the samples in each shard, then randomly distribute the samples from each shard over a larger range (given by `shuffle_block_size`). So named because it shuffles samples by extending the range of each shard, in python. +Globally shuffle shards, divide that sample space over canonical nodes, shuffle the samples in each shard, then randomly distribute the samples from each shard over an expanded range (given by `shuffle_block_size`). So named because it shuffles samples by extending the range of each shard, in python. Shuffle block size should be set larger or much larger than the number of samples in a single shard. This algorithm provides guaranteed bounds on the range that samples from a shard can appear, allowing for a lower cache limit without decreasing throughput compared to py1b. diff --git a/streaming/base/shuffle/py1b.py b/streaming/base/shuffle/py1b.py index 606e28548..bb59f0c73 100644 --- a/streaming/base/shuffle/py1b.py +++ b/streaming/base/shuffle/py1b.py @@ -61,16 +61,16 @@ def get_shuffle_py1b(shard_sizes: NDArray[np.int64], # Populate the global sample ID mapping, shuffling within each block within each super-span. ids = np.empty(num_samples, np.int64) offset = 0 - # loop over each canonical node + # Loop over each canonical node. for super_begin, super_end in super_spans: - # super_offset is the offset of the first sample in the canonical node + # The super_offset is the offset of the first sample in the canonical node. super_offset = offset - # loop over each span contained in the canonical node + # Loop over each span contained in the canonical node. for begin, end in spans[super_begin:super_end]: span_size = end - begin ids[offset:offset + span_size] = np.arange(begin, end) offset += span_size - # shuffle within each block, but don't shuffle past the canonical node boundary + # Shuffle within each block, but don't shuffle past the canonical node boundary. for start in range(super_offset, offset, block_size): stop = min(start + block_size, offset) epoch_rng.shuffle(ids[start:stop]) diff --git a/streaming/base/shuffle/py1br.py b/streaming/base/shuffle/py1br.py index 04404cdd8..c301433a9 100644 --- a/streaming/base/shuffle/py1br.py +++ b/streaming/base/shuffle/py1br.py @@ -76,9 +76,9 @@ def get_shuffle_py1br(shard_sizes: NDArray[np.int64], node_stagger = stagger[node] while blocks_end < node_stop_sample: rand_block_size = epoch_rng.integers(int(0.75 * block_size), int(1.25 * block_size)) - # don't want the block to start before the first sample of the node + # We don't want the block to start before the first sample of the node. staggered_block_start = max(blocks_end - node_stagger, node_start_sample) - # don't want the block to stop after the last sample of the node + # We don't want the block to stop after the last sample of the node. staggered_block_stop = min(blocks_end + rand_block_size - node_stagger, node_stop_sample) block_staggered_ranges.append((staggered_block_start, staggered_block_stop)) diff --git a/streaming/base/shuffle/py1e.py b/streaming/base/shuffle/py1e.py index 05545c570..5f435fdd5 100644 --- a/streaming/base/shuffle/py1e.py +++ b/streaming/base/shuffle/py1e.py @@ -1,13 +1,11 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -"""Shuffling algorithm that shuffles intra-shard in one place. +"""Shuffling algorithm that shuffles by randomly placing shard samples in expanded ranges. -This algorithm is roughly twice as fast as algorithm ``py2s``, and ever so slightly biased. - -Bias in this case merely refers to how we assign samples when we split shards at canonical node -boundaries, which is non-random in this algorithm. In practice, we found this does not matter to -convergence, while making us faster. +This algorithm has more balanced downloading and a lower minimum cache limit than ``py1b`` and +``py1br``, but also slightly lower shuffle quality. The maximum range the samples from each shard +can cover is determined by ``shuffle_block_size``. """ import numpy as np @@ -25,7 +23,6 @@ def get_shuffle_py1e(shard_sizes: NDArray[np.int64], The assignment of shards to nodes is fixed across epochs, but each grouping of shards is processed concurrently in a different order by each node's workers each epoch. - Args: shard_sizes (NDArray[np.int64]): Number of samples contained in each shard, in order. num_canonical_nodes (int): Number of canonical nodes. @@ -39,29 +36,25 @@ def get_shuffle_py1e(shard_sizes: NDArray[np.int64], NDArray[np.int64]: 1:1 mapping of sample ID to shuffled sample ID. """ # Create each shard's sample ID span (begin, end excl). - # also get max shard size to calculate the spans = [] num_samples = 0 - max_shard_size = 0 for shard_size in shard_sizes: span = num_samples, num_samples + shard_size spans.append(span) num_samples += shard_size - if shard_size > max_shard_size: - max_shard_size = shard_size # Generate the initial ordering of shards, which is fixed over an entire training run. run_rng = np.random.default_rng(seed) run_rng.shuffle(spans) # Break the shard spans at canonical node boundaries. - # super_spans are the indices of spans that correspond to each canonical node + # The super_spans are the indices of spans that correspond to each canonical node. spans, super_spans = divide_spans(spans, num_samples, num_canonical_nodes) # Shuffle the span ordering within each canonical node uniquely to this epoch. epoch_rng = np.random.default_rng(seed + epoch) for begin, end in super_spans: - # retrieving the spans (shard parts) associated with this canonical node + # Retrieve the spans (shard parts) associated with this canonical node. part = spans[begin:end] epoch_rng.shuffle(part) # pyright: ignore spans[begin:end] = part @@ -69,51 +62,51 @@ def get_shuffle_py1e(shard_sizes: NDArray[np.int64], # Populate the global sample ID mapping, shuffling within each span. ids = np.empty(num_samples, np.int64) offset = 0 - # iterate through each canonical node's spans because we don't want samples crossing canonical node boundaries + # Iterate through each canonical node's spans because we don't want samples crossing canonical node boundaries for cn_begin, cn_end in super_spans: cn_spans = spans[cn_begin:cn_end] cn_span_sizes = np.array([end - begin for begin, end in cn_spans]) num_cn_samples = cn_span_sizes.sum() - # the spans of a canonical node are shuffled, so they have sample ids that are - # not contiguous. need to get the correct sample ids for the current canonical node + # The spans of a canonical node are shuffled, so they have sample ids that are + # not contiguous. We need to get the correct sample ids for the current canonical node. cn_samples = np.empty(num_cn_samples) samples_inserted = 0 for begin, end in cn_spans: - # insert span samples into cn_samples array + # Inserting span samples into cn_samples array. cn_span_samples = np.arange(begin, end) epoch_rng.shuffle(cn_span_samples) cn_samples[samples_inserted:samples_inserted + (end - begin)] = cn_span_samples samples_inserted += (end - begin) - # iterate over each span and shift sample indices by gaussian noise + # Iterate over each span and shift sample indices by randomly sampled shifts from uniform distribution. cn_sample_offset = 0 - shifted_samples = np.arange(num_cn_samples).astype(np.float64) + sample_positions = np.arange(num_cn_samples).astype(np.float64) for span_size in cn_span_sizes: - # cutoff is (block_size - span_size)/2, so the span samples - # are only found in a range of maximum possible size block_size + # The maximum range on each side of the span is (block_size - span_size) / 2. + # This ensures that the span samples are only found in a range of max possible size block_size. cutoff = (block_size - span_size) / 2 - # make sure the lower bound doesn't cross the start of the canonical node + # Make sure the lower bound of the range doesn't cross the start of the canonical node. lower_bound = max(-cutoff, -cn_sample_offset) - # make sure the upper bound doesn't cross the end of the canonical node + # Make sure the upper bound of the range doesn't cross the end of the canonical node. upper_bound = min(cutoff, num_cn_samples - cn_sample_offset - span_size) - # sample shifts from uniform distribution + # Sample shifts from a uniform distribution with the bounds calculated above. shifts = epoch_rng.uniform(low=lower_bound, high=upper_bound, size=span_size) - # add shifts to shard samples - shifted_samples[cn_sample_offset:cn_sample_offset + span_size] += shifts + # Add shifts to shard sample indices. + sample_positions[cn_sample_offset:cn_sample_offset + span_size] += shifts - # update offset for next shard + # Update sample offset for the next shard. cn_sample_offset += span_size - # get incides that would sort the shifted_samples array - sort_indices = np.argsort(shifted_samples) + # Get incides that would sort the sample_positions array. + sort_indices = np.argsort(sample_positions) - # apply the sorting to the samples for our canonical node + # Apply the sorting to the samples for our canonical node. cn_samples = cn_samples[sort_indices] - # assign the gaussian "shuffled" samples to the global ids array + # Assign the newly shuffled samples to the global ids array. ids[offset:offset + num_cn_samples] = cn_samples offset += num_cn_samples