Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve shard efficiency of sampling for fractional stream repeats. #391

Merged
merged 26 commits into from
Sep 9, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
e8d3bf9
Improve shard-efficiency of sampling fractional stream repeats.
knighton Aug 21, 2023
831ce45
Merge branch 'main' into james/sparse-sampling
knighton Aug 28, 2023
48cd9f0
Fix (np.expand_dims).
knighton Aug 28, 2023
436f0d3
Merge branch 'james/sparse-sampling' of github.com:mosaicml/streaming…
knighton Aug 28, 2023
071ae24
Merge branch 'main' into james/sparse-sampling
knighton Aug 29, 2023
c4e5623
Merge branch 'main' into james/sparse-sampling
knighton Aug 29, 2023
1a165a3
Merge branch 'main' into james/sparse-sampling
knighton Aug 29, 2023
94b3841
Merge branch 'main' into james/sparse-sampling
knighton Sep 5, 2023
f4e5ae5
Merge branch 'james/sparse-sampling' of github.com:mosaicml/streaming…
Sep 5, 2023
1f73e50
Redesign sampling for sparsity.
Sep 5, 2023
8ec04c6
Hook up args.
knighton Sep 5, 2023
7eae135
Rename get_shard_sampling -> get_sampling.
knighton Sep 6, 2023
23392bf
Fix bug, add some tests.
knighton Sep 6, 2023
3c6343b
Test sampling balance.
knighton Sep 6, 2023
a40302b
Merge branch 'main' into james/sparse-sampling
knighton Sep 7, 2023
57579b2
Merge branch 'main' into james/sparse-sampling
knighton Sep 8, 2023
d4d92cf
Misc.
knighton Sep 8, 2023
189b3f2
Fix statistics test.
knighton Sep 8, 2023
1b5c5ff
Elaborate on docstrings.
knighton Sep 8, 2023
b460e50
Rewrite algorithm to be more performant.
knighton Sep 8, 2023
c391bcf
Merge branch 'main' into james/sparse-sampling
knighton Sep 8, 2023
bcd5ecf
Merge branch 'main' into james/sparse-sampling
karan6181 Sep 8, 2023
9e4ea89
Pyright.
knighton Sep 8, 2023
7d27c4a
Merge branch 'james/sparse-sampling' of github.com:mosaicml/streaming…
knighton Sep 8, 2023
09562c3
fstring the ValueErrors.
knighton Sep 8, 2023
794aa3f
Merge branch 'main' into james/sparse-sampling
karan6181 Sep 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions scripts/samples/bench_and_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,11 +341,11 @@ def bench(args: Namespace, bench_name: str, desc: str, generate: Callable,
plt.grid(which='minor', ls=':', c='#ddd', lw=0.5)
ax = plt.gca()
ax.xaxis.set_major_formatter(ScalarFormatter())
ax.xaxis.get_major_formatter().set_scientific(False)
ax.xaxis.get_major_formatter().set_useOffset(False)
ax.xaxis.get_major_formatter().set_scientific(False) # pyright: ignore
karan6181 marked this conversation as resolved.
Show resolved Hide resolved
ax.xaxis.get_major_formatter().set_useOffset(False) # pyright: ignore
ax.xaxis.set_minor_formatter(ScalarFormatter())
ax.xaxis.get_minor_formatter().set_scientific(False)
ax.xaxis.get_minor_formatter().set_useOffset(False)
ax.xaxis.get_minor_formatter().set_scientific(False) # pyright: ignore
ax.xaxis.get_minor_formatter().set_useOffset(False) # pyright: ignore
ax.xaxis.set_tick_params(which='minor', pad=5)
print(' Stats')
for (format_name, writer_class, color), seq, rand in zip(format_infos, seqs, rands):
Expand Down
25 changes: 14 additions & 11 deletions streaming/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from streaming.base.distributed import maybe_init_dist
from streaming.base.format import get_index_basename
from streaming.base.partition import get_partitions
from streaming.base.sampling import get_sampling
from streaming.base.shared import (SharedArray, SharedBarrier, SharedMemory, SharedScalar,
_get_path, get_shm_prefix)
from streaming.base.shuffle import get_shuffle
Expand Down Expand Up @@ -203,6 +204,7 @@ class StreamingDataset(Array, IterableDataset):
* Sampling:

* ``sampling_method``
* ``sampling_granularity``


Args:
Expand Down Expand Up @@ -265,6 +267,9 @@ class StreamingDataset(Array, IterableDataset):
of this size, and samples within each block are shuffled. Defaults to ``1 << 18``.
sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``.
Defaults to ``balanced``.
sampling_granularity (int): When picking samples for a stream's final partial repeat,
karan6181 marked this conversation as resolved.
Show resolved Hide resolved
how many samples to pick from the same shard at a time (``1`` for evenly balanced
across shards). Defaults to ``1``.
"""

def __init__(self,
Expand All @@ -287,7 +292,8 @@ def __init__(self,
shuffle_algo: str = 'py1s',
shuffle_seed: int = 9176,
shuffle_block_size: int = 1 << 18,
sampling_method: str = 'balanced') -> None:
sampling_method: str = 'balanced',
sampling_granularity: int = 1) -> None:
# Global arguments (which do not live in Streams).
self.predownload = predownload
self.cache_limit = cache_limit
Expand All @@ -299,6 +305,7 @@ def __init__(self,
self.shuffle_seed = shuffle_seed
self.shuffle_block_size = shuffle_block_size
self.sampling_method = sampling_method.lower().strip()
self.sampling_granularity = sampling_granularity

# Check streams vs remote/local.
if bool(streams) == (bool(remote) or bool(local)):
Expand Down Expand Up @@ -716,17 +723,13 @@ def _resample_streams(self, epoch: int) -> Tuple[NDArray[np.int64], NDArray[np.i

# Calculate choose per stream shard.
samples_per_stream_shard = self.samples_per_shard[stream_shard_ids]
stream_samples = sum(samples_per_stream_shard)
# the number of items to choose from each stream (calculated during dataset initialization)
# the number of items to choose from each stream (calculated during dataset
# initialization)
stream_choose = self.streams[stream_id].choose
if stream_choose == stream_samples:
choose_per_stream_shard = samples_per_stream_shard
else:
choose_per_stream_shard = \
samples_per_stream_shard * stream_choose // stream_samples
shortfall = stream_choose - choose_per_stream_shard.sum()
indices = rng.choice(num_stream_shards, shortfall, False)
choose_per_stream_shard[indices] += 1
use_epoch = self.sampling_method == 'balanced'
choose_per_stream_shard = get_sampling(samples_per_stream_shard, stream_choose,
self.sampling_granularity, self.shuffle_seed,
epoch, use_epoch)

# Iterate over each shard of this stream.
for shard_id, shard_samples, shard_choose in zip(stream_shard_ids,
Expand Down
75 changes: 75 additions & 0 deletions streaming/base/sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Functionality relating to sampling."""

import numpy as np
from numpy.typing import NDArray


def get_sampling(samples_per_shard: NDArray[np.int64], choose: int, granularity: int, seed: int,
                 epoch: int, use_epoch: bool) -> NDArray[np.int64]:
    """Get how many samples to draw from each shard of the given stream.

    Every whole repeat of the stream takes each shard in full. The final partial repeat (if
    any) is drawn in "granules": each shard is cut into runs of ``granularity`` samples (plus
    one shorter tail run when the shard size is not a multiple), the granules are shuffled, and
    they are consumed in shuffled order until ``choose`` samples have been picked.

    Args:
        samples_per_shard (NDArray[np.int64]): Array of underlying shard sizes.
        choose (int): How many samples to draw in total over all shards.
        granularity (int): How many samples to draw at a time from the same shard.
        seed (int): Seed for shuffling sampling granules.
        epoch (int): Which epoch we are sampling for.
        use_epoch (bool): Whether to factor epoch into the base seed, or use the same seed across
            epochs.

    Returns:
        NDArray[np.int64]: Array of ephemeral samples chosen per shard. Always a new array; the
            input is not modified.

    Raises:
        ValueError: If ``choose``, ``seed`` or ``epoch`` is negative, if ``granularity`` is not
            positive, or if ``choose`` is positive but the stream contains no samples.
    """
    if choose < 0:
        raise ValueError('Choose must be a non-negative integer.')

    if granularity <= 0:
        raise ValueError('Granularity must be a positive integer.')

    if seed < 0:
        raise ValueError('Seed must be a non-negative integer.')

    if epoch < 0:
        raise ValueError('Epoch must be a non-negative integer.')

    samples_per_shard = np.asarray(samples_per_shard, dtype=np.int64)
    num_shards = len(samples_per_shard)
    num_samples = int(samples_per_shard.sum())
    choose = int(choose)
    granularity = int(granularity)

    # A stream without samples cannot be divided by; handle it explicitly.
    if not num_samples:
        if choose:
            raise ValueError('Cannot choose samples from a stream that has no samples.')
        return np.zeros(num_shards, np.int64)

    # Whole repeats take every shard in full. Multiplying by the repeat count (rather than by
    # ``choose`` and dividing afterwards) keeps the intermediate values small.
    num_repeats, remainder = divmod(choose, num_samples)
    choose_per_shard = samples_per_shard * num_repeats
    if not remainder:
        return choose_per_shard

    # Fractional repeat case.

    # Cut each shard into granules, laid out shard after shard: full granules of size
    # ``granularity``, then one shorter tail granule if the shard size is not a multiple.
    granules_per_shard = (samples_per_shard + granularity - 1) // granularity
    granule_shard_ids = np.repeat(np.arange(num_shards), granules_per_shard)
    granule_counts = np.full(len(granule_shard_ids), granularity, np.int64)
    tail_sizes = samples_per_shard % granularity
    has_tail = tail_sizes > 0
    last_granules = granules_per_shard.cumsum() - 1  # Index of each shard's final granule.
    granule_counts[last_granules[has_tail]] = tail_sizes[has_tail]

    # Get the ordering by which we will exhaust the granules.
    epoch_seed = seed + epoch if use_epoch else seed
    rng = np.random.default_rng(epoch_seed)
    ordering = rng.permutation(len(granule_shard_ids))

    # Walk the granules in shuffled order, taking each in full until the remainder runs out.
    # The granule that crosses the limit is truncated and every later granule contributes zero.
    ordered_counts = granule_counts[ordering]
    drawn_before = ordered_counts.cumsum() - ordered_counts
    taken = np.clip(remainder - drawn_before, 0, ordered_counts)

    # Unbuffered add, because many granules map onto the same shard.
    np.add.at(choose_per_shard, granule_shard_ids[ordering], taken)

    return choose_per_shard
58 changes: 58 additions & 0 deletions tests/test_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

import numpy as np

from streaming.base.sampling import get_sampling


def test_choose_per_shard_adds_up():
    """Check that per-shard picks sum to ``choose`` and stay within the expected bounds.

    ``choose`` ranges over up to three passes of the stream so that both the whole-repeat and
    the fractional-repeat paths are exercised. With ``r`` whole repeats, every shard must
    contribute at least ``r`` and at most ``r + 1`` times its own size.
    """
    # Locally seeded generator: any failure can be replayed exactly.
    rng = np.random.default_rng(1337)
    for granularity in range(1, 100):
        for _ in range(10):
            samples_per_shard = rng.integers(0, 100, size=10)
            num_samples = int(samples_per_shard.sum())
            if not num_samples:
                # Nothing to sample from; not the case under test here.
                continue
            choose = int(rng.integers(0, 3 * num_samples))
            seed = int(rng.integers(0, 31337))
            epoch = int(rng.integers(0, 42))
            use_epoch = bool(rng.integers(0, 2))
            choose_per_shard = get_sampling(samples_per_shard, choose, granularity, seed, epoch,
                                            use_epoch)
            num_repeats = choose // num_samples
            assert (num_repeats * samples_per_shard <= choose_per_shard).all()
            assert (choose_per_shard <= (num_repeats + 1) * samples_per_shard).all()
            assert choose_per_shard.sum() == choose


def test_is_deterministic():
    """Check that two calls with identical arguments return identical per-shard picks."""
    for granularity in range(1, 100):
        for _trial in range(3):
            # Draw one random configuration...
            shard_sizes = np.random.choice(100, 10)
            total = sum(shard_sizes)
            choose = np.random.choice(total)
            seed = np.random.choice(31337)
            epoch = np.random.choice(42)
            use_epoch = bool(np.random.choice(2))
            args = shard_sizes, choose, granularity, seed, epoch, use_epoch

            # ...and evaluate it twice; the results must match element for element.
            first = get_sampling(*args)
            second = get_sampling(*args)
            assert (first == second).all()


def test_balance():
    """Check that, averaged over many seeds and granularities, shards are sampled evenly.

    The per-shard sampling rate (mean samples chosen / shard size) should be nearly the same for
    every shard; the test bounds the coefficient of variation of those rates.
    """
    # Locally seeded generator: the outcome is reproducible from run to run.
    rng = np.random.default_rng(1337)

    # Shards are kept reasonably large. An empty shard would make its rate 0 / 0 = NaN, and
    # very small shards make the averaged rate too noisy for the 5% tolerance below.
    samples_per_shard = rng.integers(500, 1_000, size=10)
    num_samples = int(samples_per_shard.sum())

    # Keep the drawn fraction away from zero for the same reason: the relative noise of the
    # estimate grows without bound as ``choose`` approaches zero.
    choose = int(rng.integers(num_samples // 4, 3 * num_samples // 4))

    granularities = range(1, 100)
    trials_per_granularity = 10
    choose_per_shard = np.zeros(len(samples_per_shard))
    for granularity in granularities:
        for _ in range(trials_per_granularity):
            seed = int(rng.integers(0, 31337))
            epoch = int(rng.integers(0, 42))
            use_epoch = bool(rng.integers(0, 2))
            choose_per_shard += get_sampling(samples_per_shard, choose, granularity, seed, epoch,
                                             use_epoch)

    # Average over every trial that was actually run.
    choose_per_shard /= len(granularities) * trials_per_granularity
    rates = choose_per_shard / samples_per_shard
    imbalance = rates.std() / rates.mean()
    assert imbalance < 0.05