Py1br algorithm implementation #373

Merged 8 commits on Aug 29, 2023
4 changes: 2 additions & 2 deletions docs/source/fundamentals/sampling.md
@@ -2,5 +2,5 @@

You can choose how sampling from your dataset(s) occurs between epochs by specifying the `sampling_method` when instantiating `StreamingDataset`. Currently, this can take on one of two values:

-`'balanced'`: (default) Samples are chosen at random from dataset(s) during each epoch according to the proportions specified.
-`'fixed`: The same samples from the dataset(s) are chosen during every epoch, still according to the proportions specified.
- `'balanced'`: (default) Samples are chosen at random from dataset(s) during each epoch according to the proportions specified.
- `'fixed'`: The same samples from the dataset(s) are chosen during every epoch, still according to the proportions specified.
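The behavioral difference between the two values can be sketched with the stdlib `random` module. This is a toy sketch with a hypothetical `pick_dataset_ids` helper, not the real StreamingDataset logic: `'fixed'` reuses one seed across epochs, while `'balanced'` reseeds per epoch.

```python
import random

def pick_dataset_ids(proportions, epoch_size, seed, method, epoch):
    # 'fixed' reuses the same RNG seed every epoch, so the draw repeats;
    # 'balanced' reseeds per epoch, so each epoch draws a fresh mixture.
    rng = random.Random(seed if method == 'fixed' else seed + 1 + epoch)
    return rng.choices(range(len(proportions)), weights=proportions, k=epoch_size)

# 'fixed': identical dataset choices in every epoch.
assert pick_dataset_ids([0.7, 0.3], 50, 9176, 'fixed', 0) == \
       pick_dataset_ids([0.7, 0.3], 50, 9176, 'fixed', 1)

# 'balanced': each epoch redraws, so epochs differ (deterministically, given the seed).
assert pick_dataset_ids([0.7, 0.3], 50, 9176, 'balanced', 0) != \
       pick_dataset_ids([0.7, 0.3], 50, 9176, 'balanced', 1)
```

Either way, the long-run mixture of datasets follows the specified proportions; only the epoch-to-epoch repeatability changes.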
14 changes: 11 additions & 3 deletions docs/source/fundamentals/shuffling.md
@@ -11,7 +11,7 @@ StreamingDataset takes four arguments to directly control shuffling.
| `shuffle` | `bool = False` | turn shuffling on or off |
| `shuffle_algo` | `str = 'py1s'` | which shuffling algorithm to use |
| `shuffle_seed` | `int = 9176` | all randomness in StreamingDataset is derived from this argument |
| `shuffle_block_size` | `int = 1 << 18` | shuffling unit used by py1b algorithm |
| `shuffle_block_size` | `int = 1 << 18` | shuffling unit used by py1b and py1br algorithms |

StreamingDataset also takes two other arguments that shuffling interacts with:

@@ -34,11 +34,19 @@ Statistically, this algorithm will result in all nodes downloading all shards, w

### py1b

Globally shuffle shards, divide that sample space over canonical nodes, then shuffle samples in fixed-size blocks (given by `shuffle_block_size`). So named because it shuffles samples in python, once, intra-block.
Globally shuffle shards, divide that sample space over canonical nodes, then shuffle samples in fixed-size blocks (given by `shuffle_block_size`). So named because it shuffles samples in python, once, intra-block. A canonical node, for the purposes of shuffling, is simply a collection of shards. To keep the shuffle deterministic across different numbers of physical nodes, the shuffle ordering is done over canonical nodes, which are then assigned to physical nodes.

Shuffle block size should be set larger or much larger than a single shard. If so, this algorithm is useful for spacing out the contents of shards to mitigate a bad or non-existent pre-shuffle (i.e. if samples from the same shard are related in some way).

This algorithm requires more shards to be downloaded and stay resident to make progress than py1s or py2s, noticed as longer start/resume latency, as a multiple of shuffle block size divided by samples per shard. If you see step-like burstiness in throughput, your workers may not be downloading far enough ahead – try raising predownload (it should be scaled with block size). Step size scales with shuffle block size.
In order to improve shuffle quality, this algorithm requires more shards to be downloaded and kept resident to make progress than py1s or py2s, which shows up as longer start/resume latency, proportional to shuffle block size divided by samples per shard. If you see step-like burstiness in throughput, your workers may not be downloading far enough ahead – try raising predownload (it should be scaled with block size). Step size scales with shuffle block size.
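The intra-block step at the heart of py1b can be sketched in a few lines of plain Python. This is a toy single-node sketch with a hypothetical `py1b_sketch` helper; the real implementation operates on numpy arrays across multiple canonical nodes:

```python
import random

def py1b_sketch(num_samples, block_size, seed, epoch):
    """Toy sketch of py1b's intra-block step for a single canonical node:
    shuffle sample IDs within consecutive fixed-size blocks only."""
    rng = random.Random(seed + epoch)
    ids = list(range(num_samples))
    for start in range(0, num_samples, block_size):
        block = ids[start:start + block_size]
        rng.shuffle(block)
        ids[start:start + block_size] = block
    return ids

ids = py1b_sketch(10, 4, seed=9176, epoch=0)
# Every sample appears exactly once, and no sample leaves its block.
assert sorted(ids) == list(range(10))
assert all(start <= x < start + 4 for start in (0, 4, 8) for x in ids[start:start + 4])
```

Because samples never cross block boundaries, a bigger `block_size` mixes samples from more shards together, at the cost of needing those shards resident at once.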

### py1br

Globally shuffle shards, divide that sample space over canonical nodes, then shuffle samples in variable-size blocks (sizes are selected uniformly at random from the range `[0.75*shuffle_block_size, 1.25*shuffle_block_size)`). Shuffle blocks are also staggered; together with the variable block sizes, this prevents many shards from being downloaded simultaneously. So named because it shuffles samples in python, once, intra-block, and blocks are randomized.

Shuffle block size should be set larger or much larger than a single shard. If so, this algorithm is useful for spacing out the contents of shards to mitigate a bad or non-existent pre-shuffle (i.e. if samples from the same shard are related in some way).

In order to improve shuffle quality, this algorithm requires more shards to be downloaded and kept resident to make progress than py1s or py2s, which shows up as longer start/resume latency, proportional to shuffle block size divided by samples per shard. However, shard downloads with py1br are more balanced than with py1b, an effect that becomes more apparent when training on a larger number of nodes, resulting in fewer network bottlenecks. The shuffle quality of py1br and py1b is similar.
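The variable block-size draw can be sketched with the stdlib `random` module. The real implementation uses `numpy.random.Generator.integers`, which shares the same half-open `[low, high)` convention as `random.randrange`:

```python
import random

b = 1 << 18  # shuffle_block_size
rng = random.Random(9176)

# Each block's size is drawn uniformly from [0.75*b, 1.25*b).
sizes = [rng.randrange(int(0.75 * b), int(1.25 * b)) for _ in range(1000)]

assert min(sizes) >= int(0.75 * b) and max(sizes) < int(1.25 * b)
# The mean block size stays close to b, so epoch-level download cost is
# comparable to py1b with the same shuffle_block_size.
assert abs(sum(sizes) / len(sizes) - b) < 0.05 * b
```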

### py1s

11 changes: 10 additions & 1 deletion streaming/base/dataset.py
@@ -13,6 +13,7 @@
from threading import Event, Lock
from time import sleep, time_ns
from typing import Any, Dict, Iterator, Optional, Sequence, Tuple, Union
from warnings import warn

import numpy as np
from filelock import FileLock
@@ -260,7 +261,8 @@ class StreamingDataset(Array, IterableDataset):
            ``False``.
        shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``.
        shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``.
        shuffle_block_size (int): Unit of shuffle. A canonical node's samples are split into
            blocks of this size, and samples within each block are shuffled. Defaults to
            ``1 << 18``.
        sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``.
            Defaults to ``balanced``.
    """
Expand Down Expand Up @@ -309,6 +311,13 @@ def __init__(self,
                f'Invalid sampling method: {sampling_method}. Must be one of `balanced` or `fixed`.'
            )

        # Issue deprecation warning for the py1b shuffle algorithm.
        if self.shuffle_algo == 'py1b':
            warn(
                'The \'py1b\' shuffle algorithm will soon be deprecated. Please use the more performant \'py1br\' algorithm instead.',
                DeprecationWarning,
                stacklevel=2)

        # Check that predownload is at least per device batch size.
        if self.predownload is not None and self.batch_size is not None and \
                self.predownload < self.batch_size:
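The effect of the deprecation check can be sketched with a standalone helper. `resolve_shuffle_algo` is a hypothetical name for illustration, not part of the streaming API; the real check runs inside `StreamingDataset.__init__`:

```python
import warnings

def resolve_shuffle_algo(shuffle_algo: str) -> str:
    # Sketch of the check above: warn when the deprecated algorithm is chosen.
    if shuffle_algo == 'py1b':
        warnings.warn(
            "The 'py1b' shuffle algorithm will soon be deprecated. Please use the more "
            "performant 'py1br' algorithm instead.",
            DeprecationWarning,
            stacklevel=2)
    return shuffle_algo

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    assert resolve_shuffle_algo('py1b') == 'py1b'
    assert resolve_shuffle_algo('py1br') == 'py1br'

# Exactly one warning was raised, from the deprecated algorithm only.
assert len(caught) == 1 and caught[0].category is DeprecationWarning
```

Note the algorithm still runs; the warning only nudges users toward py1br ahead of removal.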
2 changes: 2 additions & 0 deletions streaming/base/shuffle/__init__.py
@@ -8,11 +8,13 @@

from streaming.base.shuffle.naive import get_shuffle_naive
from streaming.base.shuffle.py1b import get_shuffle_py1b
from streaming.base.shuffle.py1br import get_shuffle_py1br
from streaming.base.shuffle.py1s import get_shuffle_py1s
from streaming.base.shuffle.py2s import get_shuffle_py2s

algos = {
    'py1b': get_shuffle_py1b,
    'py1br': get_shuffle_py1br,
    'py1s': get_shuffle_py1s,
    'py2s': get_shuffle_py2s,
    'naive': get_shuffle_naive,
6 changes: 6 additions & 0 deletions streaming/base/shuffle/py1b.py
@@ -43,6 +43,8 @@ def get_shuffle_py1b(shard_sizes: NDArray[np.int64],
        num_samples += shard_size

    # Generate the initial ordering of shards, which is fixed over an entire training run.
    # Because the ordering of shards is fixed, the shards downloaded during the first epoch
    # can be persisted and reused on each node in subsequent epochs as well.
    run_rng = np.random.default_rng(seed)
    run_rng.shuffle(spans)

@@ -59,12 +61,16 @@
    # Populate the global sample ID mapping, shuffling within each block within each super-span.
    ids = np.empty(num_samples, np.int64)
    offset = 0
    # Loop over each canonical node.
    for super_begin, super_end in super_spans:
        # super_offset is the offset of the first sample in the canonical node.
        super_offset = offset
        # Loop over each shard span contained in the canonical node.
        for begin, end in spans[super_begin:super_end]:
            span_size = end - begin
            ids[offset:offset + span_size] = np.arange(begin, end)
            offset += span_size
        # Shuffle within each block, but don't shuffle past the canonical node boundary.
        for start in range(super_offset, offset, block_size):
            stop = min(start + block_size, offset)
            epoch_rng.shuffle(ids[start:stop])
91 changes: 91 additions & 0 deletions streaming/base/shuffle/py1br.py
@@ -0,0 +1,91 @@
# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Shuffling algorithm that shuffles in variable-size blocks.

These blocks are presumably larger or much larger than single shards, leading to better
shuffledness at the cost of having to download more shards to make progress.
"""

import numpy as np
from numpy.typing import NDArray

from streaming.base.shuffle.py1s import divide_spans


def get_shuffle_py1br(shard_sizes: NDArray[np.int64],
                      num_canonical_nodes: int,
                      seed: int,
                      epoch: int,
                      block_size: int = 1 << 18) -> NDArray[np.int64]:
    """Get the shuffled global ordering of samples for an epoch.

    The assignment of shards to nodes is fixed across epochs, but each grouping of shards is
    processed concurrently in a different order by each node's workers each epoch.

    Args:
        shard_sizes (NDArray[np.int64]): Number of samples contained in each shard, in order.
        num_canonical_nodes (int): Number of canonical nodes.
        seed (int): Base random seed, which is held constant over an entire training run.
        epoch (int): Current epoch, which is added to the seed to get a different deterministic
            shuffle each epoch.
        block_size (int): Unit of shuffle. For the py1br shuffling method, each block's size is
            chosen uniformly at random in the range
            ``[0.75 * block_size, 1.25 * block_size)``. Defaults to ``1 << 18``.

    Returns:
        NDArray[np.int64]: 1:1 mapping of sample ID to shuffled sample ID.
    """
    # Create each shard's sample ID span (start, stop excl).
    spans = []
    num_samples = 0
    for shard_size in shard_sizes:
        span = num_samples, num_samples + shard_size
        spans.append(span)
        num_samples += shard_size

    # Generate the initial ordering of shards, which is fixed over an entire training run.
    run_rng = np.random.default_rng(seed)
    run_rng.shuffle(spans)

    # Break the shard spans at canonical node boundaries.
    spans, node_spans = divide_spans(spans, num_samples, num_canonical_nodes)

    # Shuffle the span ordering within each canonical node uniquely to this epoch.
    epoch_rng = np.random.default_rng(seed + epoch)
    for node_start_span, node_stop_span in node_spans:
        node_span = spans[node_start_span:node_stop_span]
        epoch_rng.shuffle(node_span)  # pyright: ignore
        spans[node_start_span:node_stop_span] = node_span

    # Populate the global sample ID mapping, shuffling within each block within each node.
    ids = np.empty(num_samples, np.int64)
    node_stop_sample = 0
    stagger = epoch_rng.integers(0, int(0.75 * block_size), (num_canonical_nodes,))
    for node, (node_start_span, node_stop_span) in enumerate(node_spans):
        node_start_sample = node_stop_sample

        # Populate sample IDs given the span ordering for this node.
        for span_start_sample, span_stop_sample in spans[node_start_span:node_stop_span]:
            span_size = span_stop_sample - span_start_sample
            ids[node_stop_sample:node_stop_sample + span_size] = \
                np.arange(span_start_sample, span_stop_sample)
            node_stop_sample += span_size

        # Get randomized and staggered block ranges for the current node.
        block_staggered_ranges = []
        blocks_end = node_start_sample
        node_stagger = stagger[node]
        while blocks_end < node_stop_sample:
            rand_block_size = epoch_rng.integers(int(0.75 * block_size), int(1.25 * block_size))
            # Don't want the block to start before the first sample of the node.
            staggered_block_start = max(blocks_end - node_stagger, node_start_sample)
            # Don't want the block to stop after the last sample of the node.
            staggered_block_stop = min(blocks_end + rand_block_size - node_stagger,
                                       node_stop_sample)
            block_staggered_ranges.append((staggered_block_start, staggered_block_stop))
            blocks_end += staggered_block_stop - staggered_block_start

        # Shuffle within each staggered, randomized block.
        for block_start, block_stop in block_staggered_ranges:
            epoch_rng.shuffle(ids[block_start:block_stop])

    return ids
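The per-node block construction can be exercised in isolation. Below is a toy sketch with a hypothetical `staggered_blocks` helper, mirroring the while-loop above with stdlib `random` in place of numpy; applying a shuffle to each resulting range must still yield a valid permutation of the node's samples:

```python
import random

def staggered_blocks(node_start, node_stop, block_size, stagger, rng):
    """Sketch of py1br's per-node block construction: variable-size blocks,
    shifted left by a per-node stagger and clamped to the node's bounds."""
    ranges = []
    blocks_end = node_start
    while blocks_end < node_stop:
        size = rng.randrange(int(0.75 * block_size), int(1.25 * block_size))
        start = max(blocks_end - stagger, node_start)
        stop = min(blocks_end + size - stagger, node_stop)
        ranges.append((start, stop))
        blocks_end += stop - start
    return ranges

rng = random.Random(9176)
ranges = staggered_blocks(0, 10_000, 1000, stagger=300, rng=rng)

# Shuffle each staggered block of a toy sample-ID array, as py1br does.
ids = list(range(10_000))
for start, stop in ranges:
    block = ids[start:stop]
    rng.shuffle(block)
    ids[start:stop] = block

# Blocks stay inside the node's bounds, and since each shuffle is an in-place
# permutation of a slice, the result is still a permutation of all samples.
assert all(0 <= a < b <= 10_000 for a, b in ranges)
assert sorted(ids) == list(range(10_000))
```

Because every node draws a different stagger, the block boundaries of different nodes are misaligned, which spreads their shard downloads out in time instead of letting them spike together.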
11 changes: 11 additions & 0 deletions streaming/base/shuffle/py1s.py
@@ -38,19 +38,28 @@ def divide_spans(spans: List[Tuple[int, int]], num_samples: int, num_parts: int)
    super_spans = []

    for part in range(num_parts):
        # Note that the size of a part (canonical node) is approximately
        # num_samples // num_parts.
        part_end = num_samples * (part + 1) // num_parts

        # Loop over spans until we've filled up our part (canonical node) completely.
        while True:
            if span_index == len(spans):
                break

            # Input spans are the shard spans. These can be unequally sized and may cross
            # part (canonical node) boundaries.
            span = spans[span_index]
            # Spans are (begin, end excl).
            samples_this_span = span[1] - span[0]
            # Check if the shard span contains more samples than the part (canonical node) can fit.
            if part_end < samples_so_far + samples_this_span:
                # If there is space left in the part, split the span.
                if samples_so_far < part_end:
                    split = part_end - samples_so_far
                    # Create a span, filling the part with as many samples as possible from
                    # the shard span.
                    new_span = span[0], span[0] + split
                    out_spans.append(new_span)
                    # Modify the old shard span to reflect that it's been split.
                    spans[span_index] = span[0] + split, span[1]
                    samples_so_far += split
                break
@@ -59,6 +68,8 @@ def divide_spans(spans: List[Tuple[int, int]], num_samples: int, num_parts: int)
            span_index += 1
            samples_so_far += samples_this_span

        # Super spans tell us which new spans belong to each part (canonical node),
        # as a tuple of (begin span index, end span index excl).
        super_span = begin_part, len(out_spans)
        super_spans.append(super_span)
        begin_part = len(out_spans)
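The splitting logic above can be worked through on toy inputs. Below is a self-contained re-implementation of the same algorithm under a hypothetical name, `divide_spans_sketch`, with two checked examples:

```python
from typing import List, Tuple

def divide_spans_sketch(spans: List[Tuple[int, int]], num_samples: int,
                        num_parts: int) -> Tuple[list, list]:
    """Toy re-implementation of divide_spans: split shard spans at part
    (canonical node) boundaries, returning new spans plus per-part ranges."""
    out_spans, super_spans = [], []
    spans = list(spans)
    span_index = samples_so_far = begin_part = 0
    for part in range(num_parts):
        part_end = num_samples * (part + 1) // num_parts
        while span_index < len(spans):
            begin, end = spans[span_index]
            size = end - begin
            if part_end < samples_so_far + size:
                # Shard span crosses the part boundary: split it.
                split = part_end - samples_so_far
                if split:
                    out_spans.append((begin, begin + split))
                spans[span_index] = (begin + split, end)
                samples_so_far += split
                break
            out_spans.append((begin, end))
            span_index += 1
            samples_so_far += size
        super_spans.append((begin_part, len(out_spans)))
        begin_part = len(out_spans)
    return out_spans, super_spans

# Two 5-sample shards over 2 parts: boundary falls between shards, no split.
out, supers = divide_spans_sketch([(0, 5), (5, 10)], 10, 2)
assert out == [(0, 5), (5, 10)] and supers == [(0, 1), (1, 2)]

# One 10-sample shard over 2 parts: the span is split at sample 5.
out, supers = divide_spans_sketch([(0, 10)], 10, 2)
assert out == [(0, 5), (5, 10)] and supers == [(0, 1), (1, 2)]
```

Each `super_spans` entry indexes into the returned `out_spans` list, which is how the shuffle algorithms recover the set of spans belonging to each canonical node.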
14 changes: 12 additions & 2 deletions tests/test_shuffle.py
@@ -5,17 +5,19 @@

import numpy as np

from streaming.base.shuffle import get_shuffle_py1s, get_shuffle_py2s
from streaming.base.shuffle import (get_shuffle_py1b, get_shuffle_py1br, get_shuffle_py1s,
                                    get_shuffle_py2s)


def check(get_shuffle: Callable) -> None:
    shard_sizes = 1 + np.arange(100)
    dataset_size = sum(shard_sizes)
    block_size = 300
    for num_canonical_nodes in [1, 2, 3]:
        for seed in [0, 1, 2]:
            lists = []
            for epoch in [0, 1, 2]:
                ids = get_shuffle(shard_sizes, num_canonical_nodes, seed, epoch)
                ids = get_shuffle(shard_sizes, num_canonical_nodes, seed, epoch, block_size)
                assert sorted(ids) == list(range(len(ids)))
                parts = []
                for i in range(num_canonical_nodes):
@@ -30,6 +32,14 @@ def check(get_shuffle: Callable) -> None:
                assert parts[0] == parts[i]


def test_shuffle_py1b():
    check(get_shuffle_py1b)


def test_shuffle_py1br():
    check(get_shuffle_py1br)


def test_shuffle_py1s():
    check(get_shuffle_py1s)
