Expanded range shuffle #394

Merged Aug 30, 2023 (8 commits)
8 changes: 8 additions & 0 deletions docs/source/fundamentals/shuffling.md
@@ -48,6 +48,14 @@ Shuffle block size should be set larger or much larger than a single shard. If s

To improve shuffle quality, this algorithm requires more shards to be downloaded and kept resident to make progress than py1s or py2s, observed as longer start/resume latency that scales with the shuffle block size divided by the samples per shard. However, shard downloads with py1br are more balanced than with py1b, an effect that is more apparent when training with a higher number of nodes, resulting in fewer network bottlenecks. The shuffle quality of py1br and py1b is similar.
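The intra-node block shuffle that py1b performs (and that py1br randomizes the block sizes of) can be sketched in a few lines of NumPy. This is a toy illustration, not the library's API; `block_size` here plays the role of `shuffle_block_size`:

```python
import numpy as np

rng = np.random.default_rng(0)

block_size = 4
node_ids = np.arange(12)  # sample ids assigned to one canonical node

# Shuffle within each fixed-size block, but never across the node boundary.
# NumPy slices are views, so shuffling the slice shuffles node_ids in place.
for start in range(0, len(node_ids), block_size):
    stop = min(start + block_size, len(node_ids))
    rng.shuffle(node_ids[start:stop])
```

Because each block is shuffled independently, every sample stays inside its original block of `block_size` positions, which is what bounds how many shards must be resident at once.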

### py1e

Globally shuffle shards, divide that sample space over canonical nodes, shuffle the samples within each shard, then randomly distribute the samples from each shard over an expanded range (given by `shuffle_block_size`). So named because it shuffles samples in Python by expanding the range of each shard.

Shuffle block size should be set larger or much larger than the number of samples in a single shard. This algorithm provides guaranteed bounds on the range in which samples from a shard can appear, allowing for a lower cache limit than py1b without decreasing throughput.

Like py1b, this algorithm requires more shards to be downloaded and kept resident to make progress than py1s or py2s, observed as longer start/resume latency that scales with the shuffle block size divided by the samples per shard. However, these shards are downloaded in a more balanced fashion, reducing network bandwidth bottlenecks.
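The expanded-range idea for a single shard can be sketched as follows; the names (`span_size`, `start`) are illustrative, not part of the library API:

```python
import numpy as np

rng = np.random.default_rng(0)

block_size = 16   # shuffle_block_size: the max window a shard's samples may occupy
span_size = 4     # samples in this shard (span)
num_samples = 32  # samples in the canonical node

# Nominal positions of this shard's samples within the canonical node.
start = 12
positions = np.arange(start, start + span_size).astype(np.float64)

# Each sample may drift up to (block_size - span_size) / 2 in either direction,
# clipped so it cannot leave the canonical node, which guarantees the shard's
# samples land in a window of at most block_size positions.
cutoff = (block_size - span_size) / 2
lower = max(-cutoff, -start)
upper = min(cutoff, num_samples - start - span_size)
positions += rng.uniform(lower, upper, span_size)

# Sorting the noised positions yields the shard's shuffled placement.
order = np.argsort(positions)
```

In the real algorithm, the noised positions of all shards in a canonical node are concatenated and argsorted together; the bounded noise is what gives the guaranteed cache-limit bound.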

### py1s

Globally shuffle shards, divide that sample space over canonical nodes, then shuffle the samples within each shard or shard part. So named because it shuffles samples in Python, once, intra-shard.
2 changes: 2 additions & 0 deletions streaming/base/shuffle/__init__.py
@@ -9,12 +9,14 @@
from streaming.base.shuffle.naive import get_shuffle_naive
from streaming.base.shuffle.py1b import get_shuffle_py1b
from streaming.base.shuffle.py1br import get_shuffle_py1br
from streaming.base.shuffle.py1e import get_shuffle_py1e
from streaming.base.shuffle.py1s import get_shuffle_py1s
from streaming.base.shuffle.py2s import get_shuffle_py2s

algos = {
'py1b': get_shuffle_py1b,
'py1br': get_shuffle_py1br,
'py1e': get_shuffle_py1e,
'py1s': get_shuffle_py1s,
'py2s': get_shuffle_py2s,
'naive': get_shuffle_naive,
8 changes: 4 additions & 4 deletions streaming/base/shuffle/py1b.py
@@ -61,16 +61,16 @@ def get_shuffle_py1b(shard_sizes: NDArray[np.int64],
# Populate the global sample ID mapping, shuffling within each block within each super-span.
ids = np.empty(num_samples, np.int64)
offset = 0
# loop over each canonical node
# Loop over each canonical node.
for super_begin, super_end in super_spans:
# super_offset is the offset of the first sample in the canonical node
# The super_offset is the offset of the first sample in the canonical node.
super_offset = offset
# loop over each span contained in the canonical node
# Loop over each span contained in the canonical node.
for begin, end in spans[super_begin:super_end]:
span_size = end - begin
ids[offset:offset + span_size] = np.arange(begin, end)
offset += span_size
# shuffle within each block, but don't shuffle past the canonical node boundary
# Shuffle within each block, but don't shuffle past the canonical node boundary.
for start in range(super_offset, offset, block_size):
stop = min(start + block_size, offset)
epoch_rng.shuffle(ids[start:stop])
4 changes: 2 additions & 2 deletions streaming/base/shuffle/py1br.py
@@ -76,9 +76,9 @@ def get_shuffle_py1br(shard_sizes: NDArray[np.int64],
node_stagger = stagger[node]
while blocks_end < node_stop_sample:
rand_block_size = epoch_rng.integers(int(0.75 * block_size), int(1.25 * block_size))
# don't want the block to start before the first sample of the node
# We don't want the block to start before the first sample of the node.
staggered_block_start = max(blocks_end - node_stagger, node_start_sample)
# don't want the block to stop after the last sample of the node
# We don't want the block to stop after the last sample of the node.
staggered_block_stop = min(blocks_end + rand_block_size - node_stagger,
node_stop_sample)
block_staggered_ranges.append((staggered_block_start, staggered_block_stop))
114 changes: 114 additions & 0 deletions streaming/base/shuffle/py1e.py
@@ -0,0 +1,114 @@
# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Shuffling algorithm that shuffles by randomly placing shard samples in expanded ranges.

This algorithm has more balanced downloading and a lower minimum cache limit than ``py1b`` and
``py1br``, but also slightly lower shuffle quality. The maximum range the samples from each shard
can cover is determined by ``shuffle_block_size``.
"""

import numpy as np
from numpy.typing import NDArray

from streaming.base.shuffle.py1s import divide_spans


def get_shuffle_py1e(shard_sizes: NDArray[np.int64],
num_canonical_nodes: int,
seed: int,
epoch: int,
block_size: int = 1 << 18) -> NDArray[np.int64]:
"""Get the shuffled global ordering of samples for an epoch.

The assignment of shards to nodes is fixed across epochs, but each grouping of shards is
processed concurrently in a different order by each node's workers each epoch.

Args:
shard_sizes (NDArray[np.int64]): Number of samples contained in each shard, in order.
num_canonical_nodes (int): Number of canonical nodes.
seed (int): Base random seed, which is held constant over an entire training run.
epoch (int): Current epoch, which is added to the seed to get a different deterministic
shuffle each epoch.
block_size (int): Unit of shuffle, used to set the std and clip length for the gaussian
noise to be added to each shard. Defaults to ``1 << 18``.

Returns:
NDArray[np.int64]: 1:1 mapping of sample ID to shuffled sample ID.
"""
# Create each shard's sample ID span (begin, end excl).
spans = []
num_samples = 0
for shard_size in shard_sizes:
span = num_samples, num_samples + shard_size
spans.append(span)
num_samples += shard_size

# Generate the initial ordering of shards, which is fixed over an entire training run.
run_rng = np.random.default_rng(seed)
run_rng.shuffle(spans)

# Break the shard spans at canonical node boundaries.
# The super_spans are the indices of spans that correspond to each canonical node.
spans, super_spans = divide_spans(spans, num_samples, num_canonical_nodes)

# Shuffle the span ordering within each canonical node uniquely to this epoch.
epoch_rng = np.random.default_rng(seed + epoch)
for begin, end in super_spans:
# Retrieve the spans (shard parts) associated with this canonical node.
part = spans[begin:end]
epoch_rng.shuffle(part) # pyright: ignore
spans[begin:end] = part

# Populate the global sample ID mapping, shuffling within each span.
ids = np.empty(num_samples, np.int64)
offset = 0
# Iterate through each canonical node's spans, because we don't want samples to cross
# canonical node boundaries.
for cn_begin, cn_end in super_spans:
cn_spans = spans[cn_begin:cn_end]
cn_span_sizes = np.array([end - begin for begin, end in cn_spans])
num_cn_samples = cn_span_sizes.sum()
# The spans of a canonical node are shuffled, so they have sample ids that are
# not contiguous. We need to get the correct sample ids for the current canonical node.
cn_samples = np.empty(num_cn_samples, np.int64)
samples_inserted = 0
for begin, end in cn_spans:
# Insert this span's samples into the cn_samples array.
cn_span_samples = np.arange(begin, end)
epoch_rng.shuffle(cn_span_samples)
cn_samples[samples_inserted:samples_inserted + (end - begin)] = cn_span_samples
samples_inserted += (end - begin)

# Iterate over each span and shift its sample indices by shifts drawn from a uniform
# distribution.
cn_sample_offset = 0
sample_positions = np.arange(num_cn_samples).astype(np.float64)
for span_size in cn_span_sizes:
# The maximum range on each side of the span is (block_size - span_size) / 2.
# This ensures that the span's samples fall within a range of at most block_size.
cutoff = (block_size - span_size) / 2

# Make sure the lower bound of the range doesn't cross the start of the canonical node.
lower_bound = max(-cutoff, -cn_sample_offset)
# Make sure the upper bound of the range doesn't cross the end of the canonical node.
upper_bound = min(cutoff, num_cn_samples - cn_sample_offset - span_size)
# Sample shifts from a uniform distribution with the bounds calculated above.
shifts = epoch_rng.uniform(low=lower_bound, high=upper_bound, size=span_size)

# Add shifts to shard sample indices.
sample_positions[cn_sample_offset:cn_sample_offset + span_size] += shifts

# Update sample offset for the next shard.
cn_sample_offset += span_size

# Get the indices that would sort the sample_positions array.
sort_indices = np.argsort(sample_positions)

# Apply the sorting to the samples for our canonical node.
cn_samples = cn_samples[sort_indices]

# Assign the newly shuffled samples to the global ids array.
ids[offset:offset + num_cn_samples] = cn_samples

offset += num_cn_samples

return ids
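A minimal, self-contained sketch of the py1e loop above for one canonical node (toy shard sizes, not the library's API) shows that the result is a permutation and that no sample leaves the node:

```python
import numpy as np

rng = np.random.default_rng(1234)

shard_sizes = [8, 8, 8, 8]  # toy shards within one canonical node
block_size = 16
num_samples = sum(shard_sizes)

# Shuffle samples within each shard, then concatenate in shard order.
parts = []
offset = 0
for size in shard_sizes:
    ids = np.arange(offset, offset + size)
    rng.shuffle(ids)
    parts.append(ids)
    offset += size
samples = np.concatenate(parts)

# Add bounded uniform noise to each shard's nominal positions, clipped to the node.
positions = np.arange(num_samples).astype(np.float64)
offset = 0
for size in shard_sizes:
    cutoff = (block_size - size) / 2
    lower = max(-cutoff, -offset)
    upper = min(cutoff, num_samples - offset - size)
    positions[offset:offset + size] += rng.uniform(lower, upper, size)
    offset += size

# Sorting the noised positions gives the shuffled order for this node.
shuffled = samples[np.argsort(positions)]
```

The clipping of `lower` and `upper` mirrors the bounds in the function above: it keeps every sample strictly inside the canonical node, so canonical node partitions stay identical across epochs.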
15 changes: 13 additions & 2 deletions tests/test_shuffle.py
Expand Up @@ -5,8 +5,8 @@

import numpy as np

from streaming.base.shuffle import (get_shuffle_py1b, get_shuffle_py1br, get_shuffle_py1s,
get_shuffle_py2s)
from streaming.base.shuffle import (get_shuffle_py1b, get_shuffle_py1br, get_shuffle_py1e,
get_shuffle_py1s, get_shuffle_py2s)


def check(get_shuffle: Callable) -> None:
@@ -15,19 +15,26 @@ def check(get_shuffle: Callable) -> None:
block_size = 300
for num_canonical_nodes in [1, 2, 3]:
for seed in [0, 1, 2]:
# lists is the list of sorted ids seen by every canonical node in every epoch.
# For example: [[epoch0_CN_a, epoch0_CN_b], [epoch1_CN_a, epoch1_CN_b], [epoch2_CN_a, epoch2_CN_b]]
lists = []
for epoch in [0, 1, 2]:
ids = get_shuffle(shard_sizes, num_canonical_nodes, seed, epoch, block_size)
assert sorted(ids) == list(range(len(ids)))
# parts is a list of the sorted ids seen by each canonical node in a particular epoch.
parts = []
for i in range(num_canonical_nodes):
begin = dataset_size * i // num_canonical_nodes
end = dataset_size * (i + 1) // num_canonical_nodes
# Get the section of ids corresponding to this canonical node.
part = ids[begin:end]
parts.append(sorted(part))
lists.append(parts)
# We want to make sure the sample ids seen by each canonical node are the same in
# every epoch.
lists = list(zip(*lists))
# Each element of lists is now a tuple containing the lists of samples seen by one
# canonical node over all the epochs.
for parts in lists:
# Make sure all other epochs are the same as epoch 0.
for i in range(1, len(parts)):
assert parts[0] == parts[i]

Expand All @@ -40,6 +47,10 @@ def test_shuffle_py1br():
check(get_shuffle_py1br)


def test_shuffle_py1e():
check(get_shuffle_py1e)


def test_shuffle_py1s():
check(get_shuffle_py1s)
