Per Stream Batching (#407)
* in progress

* uniform batch sampling

* uniform batch sampling

* more changes

* added padding to end of partition

* why it working doe

* removing extraneous variables

* added under batching_method arg

* added under batching_method arg

* added under batching_method arg

* changed to stratified

* addressed PR comments

* addressed PR comments

* small fix
snarayan21 authored Sep 7, 2023
1 parent 7851bd6 commit 9eeccae
Showing 8 changed files with 363 additions and 54 deletions.
7 changes: 7 additions & 0 deletions docs/source/fundamentals/batching.md
@@ -0,0 +1,7 @@
# Batching

You can choose how batches are constructed by specifying the `batching_method` argument when instantiating `StreamingDataset`. Currently, this can take on one of three values:

- `'random'`: (default) Samples for each batch are chosen at random from input streams. While stream proportions hold in aggregate over the course of training, this batching method does not guarantee that stream proportions hold for each batch.
- `'stratified'`: Every batch is divided up between streams in the same proportions. Unlike in the default case, where stream proportions hold only in aggregate, they hold for every single batch.
- `'per_stream'`: Each batch has samples from just one stream. In aggregate over all batches, stream proportions still hold.
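For reference, a minimal usage sketch of the new argument (the `remote`/`local` paths and `batch_size` here are placeholders, not values from this PR):

```python
from torch.utils.data import DataLoader

from streaming import StreamingDataset

# Placeholder paths; any supported remote/local pair works.
dataset = StreamingDataset(remote='s3://my-bucket/my-dataset',
                           local='/tmp/my-dataset-cache',
                           batch_size=32,
                           batching_method='per_stream')  # or 'random' (default) / 'stratified'

dataloader = DataLoader(dataset, batch_size=32)
```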
2 changes: 1 addition & 1 deletion docs/source/fundamentals/sampling.md
@@ -3,4 +3,4 @@
You can choose how sampling from your dataset(s) occurs between epochs by specifying the `sampling_method` when instantiating `StreamingDataset`. Currently, this can take on one of two values:

- `'balanced'`: (default) Samples are chosen at random from dataset(s) during each epoch according to the proportions specified.
- - `'fixed`: The same samples from the dataset(s) are chosen during every epoch, still according to the proportions specified.
+ - `'fixed'`: The same samples from the dataset(s) are chosen during every epoch, still according to the proportions specified.
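A hedged sketch of how `sampling_method` might be combined with weighted streams (stream names, paths, and proportions are invented for illustration):

```python
from streaming import Stream, StreamingDataset

# Two hypothetical streams mixed 70/30; all paths are placeholders.
streams = [
    Stream(remote='s3://my-bucket/stream_a', local='/tmp/stream_a', proportion=0.7),
    Stream(remote='s3://my-bucket/stream_b', local='/tmp/stream_b', proportion=0.3),
]

# 'fixed' draws the same samples (in those proportions) every epoch.
dataset = StreamingDataset(streams=streams, batch_size=32, sampling_method='fixed')
```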
1 change: 1 addition & 0 deletions docs/source/index.md
@@ -59,6 +59,7 @@ If you have any questions, please feel free to reach out to us on [Twitter](htt
fundamentals/environments.md
fundamentals/shuffling.md
fundamentals/sampling.md
fundamentals/batching.md
.. toctree::
:hidden:
41 changes: 41 additions & 0 deletions streaming/base/batching/__init__.py
@@ -0,0 +1,41 @@
# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Apportion shards/samples to nodes/ranks/workers for elastically deterministic sample order."""
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
from numpy.typing import NDArray

from streaming.base.batching.per_stream import generate_work_per_stream_batching
from streaming.base.batching.random import generate_work_random_batching
from streaming.base.world import World

if TYPE_CHECKING:
from streaming.base.dataset import StreamingDataset

batching_methods = {
    'random': generate_work_random_batching,
    'per_stream': generate_work_per_stream_batching,
}


def generate_work(batching_method: str, dataset: StreamingDataset, world: World, epoch: int,
                  sample_in_epoch: int) -> NDArray[np.int64]:
    """Apportion shards/samples to nodes/ranks/workers for elastically deterministic sample order.

    Args:
        batching_method (str): The batching method to use.
        dataset (StreamingDataset): Dataset to generate the partition for.
        world (World): World state.
        epoch (int): Which epoch it is.
        sample_in_epoch (int): Where we are in the epoch.

    Returns:
        NDArray[np.int64]: The epoch (num physical nodes, ranks per node, workers per rank,
            batches per worker, batch size).
    """
    get = batching_methods[batching_method]
    return get(dataset, world, epoch, sample_in_epoch)
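To make the returned layout concrete, here is a small self-contained NumPy sketch of the five-dimensional shape convention documented above and how one (node, rank, worker) slot might read its batches; it uses dummy integers rather than a real `StreamingDataset`:

```python
import numpy as np

# Dummy "work" tensor in the layout generate_work returns:
# (num physical nodes, ranks per node, workers per rank, batches per worker, batch size).
nodes, ranks, workers, batches, batch_size = 2, 4, 2, 3, 8
work = np.arange(nodes * ranks * workers * batches * batch_size,
                 dtype=np.int64).reshape(nodes, ranks, workers, batches, batch_size)

# A given (node, rank, worker) slot reads its own row of batches; -1 entries would be padding.
node, rank, worker = 1, 2, 0
my_batches = work[node, rank, worker]
print(my_batches.shape)  # (3, 8) -> (batches per worker, batch size)
```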
143 changes: 143 additions & 0 deletions streaming/base/batching/per_stream.py
@@ -0,0 +1,143 @@
# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Apportion shards/samples such that batches have samples only from a single stream."""
from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import numpy as np
from numpy.typing import NDArray

from streaming.base.partition import get_partitions
from streaming.base.shuffle import get_shuffle
from streaming.base.world import World

if TYPE_CHECKING:
from streaming.base.dataset import StreamingDataset

logger = logging.getLogger(__name__)


def generate_work_per_stream_batching(dataset: StreamingDataset, world: World, epoch: int,
                                      sample_in_epoch: int) -> NDArray[np.int64]:
    """Generate this epoch's arrangement of samples for ``per_stream`` batching.

    This is only called on local rank zero. When ``batching_method`` is set to ``per_stream``,
    each batch consists of samples from only one stream.

    Args:
        dataset (StreamingDataset): Dataset to generate the partition for.
        world (World): World state.
        epoch (int): Which epoch it is.
        sample_in_epoch (int): Where we are in the epoch.

    Returns:
        NDArray[np.int64]: The epoch (num physical nodes, ranks per node, workers per rank,
            batches per worker, batch size).
    """
    # Ensure that num_canonical_nodes has been set.
    if dataset.num_canonical_nodes is None:
        raise RuntimeError(f'`num_canonical_nodes` can never be None. ' +
                           f'Provide a positive integer.')

    # First, for each stream, sample each shard of the stream according to
    # proportions/repeats/samples. We obtain the resampled size of each shard in the stream and
    # a mapping from the training "big" sample ID to the underlying shard "small" sample ID.
    # Then, we also partition each stream's samples over nodes/devices/workers.
    # We handle sample_in_epoch (for resumption) at the end.
    partition_per_stream = []

    batch_size = dataset.batch_size or 1

    for stream_id, stream in enumerate(dataset.streams):
        shuffle_units, small_per_big = dataset.resample_streams(epoch, stream_id)
        samples_in_stream = len(small_per_big)
        stream_partition = get_partitions(dataset.partition_algo, samples_in_stream,
                                          dataset.num_canonical_nodes, world.num_nodes,
                                          world.ranks_per_node, world.workers_per_rank,
                                          batch_size, 0)
        if dataset.shuffle:
            # Ratio of the stream's shuffle block size to the overall shuffle block size should
            # be the same as the ratio of the stream's samples to overall samples. This ensures
            # that the overall training shuffle block size is still approximately equal to what
            # is set by the user, and allows for reasoning about cache_limit as well.
            shuffle_block_portion = int(dataset.shuffle_block_size * stream.proportion)
            stream_shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units,
                                         dataset.num_canonical_nodes, dataset.shuffle_seed, epoch,
                                         shuffle_block_portion)
            stream_partition = np.where(stream_partition != -1, stream_shuffle[stream_partition],
                                        -1)
        # The small_per_big array already corresponds to indices of samples per shard of each
        # stream, so each sample ID in the stream's partition already corresponds to the sample
        # ID in the right shard.
        partition_per_stream.append(
            np.where(stream_partition != -1, small_per_big[stream_partition], -1))

    # We now merge the partitions from each stream to get our final partition over all streams,
    # where each global batch has samples only from a single stream. Partitions are arranged as
    # (physical nodes, ranks per node, workers per rank, batches per worker, batch size).
    batches_per_stream = []
    batches_from_partitions = []
    for stream_idx, partition in enumerate(partition_per_stream):
        # Reshape the partition to be global batches in order of traversal, and count only
        # batches without -1 in them.
        global_batches_inorder = partition.transpose(3, 2, 0, 1, 4).reshape(
            -1, batch_size * world.ranks_per_node * world.num_nodes)
        num_full_batches = np.count_nonzero(np.min(global_batches_inorder, axis=1) >= 0)
        batches_per_stream.append(num_full_batches)
        if num_full_batches != global_batches_inorder.shape[0]:
            logger.warning(
                'Because of the `per_stream` batching method, some batches with an inadequate ' +
                'number of samples from stream with index ' + str(stream_idx) +
                ' are being dropped.')
        batches_from_partitions.append(global_batches_inorder[:num_full_batches])

    # Combine all global batches from all streams into one array.
    all_partition_batches = np.concatenate(batches_from_partitions)

    # The shuffle seed changes with every epoch so that the order of streams in our batches
    # changes as well.
    epoch_rng = np.random.default_rng(dataset.shuffle_seed + epoch)

    # stream_origins is an array that tells us which stream each batch is using.
    stream_origins = np.concatenate(
        [np.full(n_batch, i) for i, n_batch in enumerate(batches_per_stream)])
    epoch_rng.shuffle(stream_origins)

    # Now, we want the batch_indices array to correctly index into the all_partition_batches
    # array according to stream_origins, in order to get our final batch order. For each stream,
    # we want to traverse its batches in the same order as given in its partition.
    batch_indices = np.zeros(stream_origins.shape[0]).astype(np.int64)
    batch_offset = 0
    for i, n_batch in enumerate(batches_per_stream):
        # Update batch_indices for one stream at a time.
        batch_indices[stream_origins == i] += batch_offset + np.arange(n_batch)
        batch_offset += n_batch

    # Rearrange all_partition_batches by the batch_indices we have obtained.
    all_partition_batches = all_partition_batches[batch_indices]

    # If applicable, we resume right after the most recently used full global batch.
    global_batch_size = batch_size * world.num_nodes * world.ranks_per_node
    if sample_in_epoch % global_batch_size != 0:
        logger.warning(
            'Because of the `per_stream` batching method, resumption may only occur on a ' +
            'sample that is a multiple of the current global batch size of ' +
            str(global_batch_size) + '. Resuming training after the most recently finished ' +
            'global batch.')

    # Discard previous batches that may have already finished.
    resumption_batch = sample_in_epoch // global_batch_size
    all_partition_batches = all_partition_batches[resumption_batch:]

    # Add padding batches if necessary to ensure that we have an even number of batches per
    # worker/rank/node.
    current_samples = all_partition_batches.size
    divisibility_requirement = (world.num_nodes * world.ranks_per_node *
                                world.workers_per_rank * batch_size)
    if current_samples % divisibility_requirement != 0:
        samples_needed = divisibility_requirement - (current_samples % divisibility_requirement)
        padding_batches_needed = samples_needed // global_batch_size
        all_partition_batches = np.concatenate(
            (all_partition_batches, np.full((padding_batches_needed, global_batch_size), -1)))

    # Reverse the transposition and reshape from earlier. The final result is
    # (physical nodes, ranks per node, workers per rank, batches per worker, batch size).
    return all_partition_batches.reshape(-1, world.workers_per_rank, world.num_nodes,
                                         world.ranks_per_node,
                                         batch_size).transpose(2, 3, 1, 0, 4)
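As a sanity check on the interleaving bookkeeping above (`stream_origins` and `batch_indices`), here is a toy, self-contained NumPy sketch with made-up batch contents; it illustrates the idea and is not code from this PR:

```python
import numpy as np

# Toy setup: 3 global batches from stream 0 and 2 from stream 1, with fake contents.
batches_per_stream = [3, 2]
all_partition_batches = np.array([[0, 0], [1, 1], [2, 2],      # stream 0's batches, in order
                                  [100, 100], [200, 200]])     # stream 1's batches, in order

# Randomly decide which stream each output batch comes from (a permutation of [0, 0, 0, 1, 1]).
rng = np.random.default_rng(1234)
stream_origins = np.concatenate(
    [np.full(n_batch, i) for i, n_batch in enumerate(batches_per_stream)])
rng.shuffle(stream_origins)

# For each stream, keep its batches in their original order, but place them wherever that
# stream appears in stream_origins.
batch_indices = np.zeros(stream_origins.shape[0], dtype=np.int64)
batch_offset = 0
for i, n_batch in enumerate(batches_per_stream):
    batch_indices[stream_origins == i] += batch_offset + np.arange(n_batch)
    batch_offset += n_batch

# Each row of the result still comes entirely from one stream, and each stream's batches
# appear in their original relative order.
print(all_partition_batches[batch_indices])
```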
66 changes: 66 additions & 0 deletions streaming/base/batching/random.py
@@ -0,0 +1,66 @@
# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Apportion shards/samples such that batches have samples randomly selected from streams."""
from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import numpy as np
from numpy.typing import NDArray

from streaming.base.partition import get_partitions
from streaming.base.shuffle import get_shuffle
from streaming.base.world import World

if TYPE_CHECKING:
from streaming.base.dataset import StreamingDataset

logger = logging.getLogger(__name__)


def generate_work_random_batching(dataset: StreamingDataset, world: World, epoch: int,
                                  sample_in_epoch: int) -> NDArray[np.int64]:
    """Generate this epoch's arrangement of samples for ``random`` batching.

    This is only called on local rank zero. When ``batching_method`` is set to ``random``,
    which is the default case, each batch consists of samples selected at random from across
    all streams.

    Args:
        dataset (StreamingDataset): Dataset to generate the partition for.
        world (World): World state.
        epoch (int): Which epoch it is.
        sample_in_epoch (int): Where we are in the epoch.

    Returns:
        NDArray[np.int64]: The epoch (num physical nodes, ranks per node, workers per rank,
            batches per worker, batch size).
    """
    # Ensure that num_canonical_nodes has been set.
    if dataset.num_canonical_nodes is None:
        raise RuntimeError(f'`num_canonical_nodes` can never be None. ' +
                           f'Provide a positive integer.')

    # Sample each shard of each stream according to their proportions/repeats/samples. This
    # gives us the resampled size of each underlying shard, and a mapping from each fake "big"
    # sample ID to its underlying "small" sample ID.
    shuffle_units, small_per_big = dataset.resample_streams(epoch)

    # Partition the global sample space (of resampled "big" sample IDs) into a tensor of shape
    # (num physical nodes, ranks per node, workers per rank, batches per worker, samples per
    # batch) such that we have an elastically deterministic sample order.
    big_ids = get_partitions(dataset.partition_algo, dataset.epoch_size,
                             dataset.num_canonical_nodes, world.num_nodes, world.ranks_per_node,
                             world.workers_per_rank, dataset.batch_size, sample_in_epoch)

    # If we need to shuffle, shuffle in a node-aware and *underlying* shard-aware way.
    if dataset.shuffle:
        shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units, dataset.num_canonical_nodes,
                              dataset.shuffle_seed, epoch, dataset.shuffle_block_size)
        big_ids = np.where(big_ids != -1, shuffle[big_ids], -1)

    # Now that we have partitioned and shuffled with hallucinated "big" sample IDs, we don't
    # need them anymore, and can convert back to underlying "small" sample IDs.
    return np.where(big_ids != -1, small_per_big[big_ids], -1)
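A tiny NumPy sketch (toy values, not from the PR) of that final ``big`` → ``small`` sample-ID conversion, showing that -1 padding entries survive the mapping:

```python
import numpy as np

small_per_big = np.array([10, 11, 12, 13, 14], dtype=np.int64)  # fake big -> small mapping
big_ids = np.array([[3, 0, -1], [4, 2, 1]], dtype=np.int64)     # -1 marks padding

# Fancy-index through the mapping, then restore the -1 padding with np.where.
small_ids = np.where(big_ids != -1, small_per_big[big_ids], -1)
print(small_ids)
# [[13 10 -1]
#  [14 12 11]]
```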