[AIR] Add option for per-epoch preprocessor (#31739)
This adds an option to the AIR DatasetConfig for a preprocessor that is reapplied on each epoch. The current implementation uses DatasetPipeline to ensure that the extra preprocessing step overlaps with training.

Signed-off-by: Stephanie Wang <swang@cs.berkeley.edu>
stephanie-wang authored Jan 31, 2023
1 parent e3001e9 commit ae167f0
Showing 7 changed files with 243 additions and 10 deletions.
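
As a rough illustration of the DatasetPipeline approach described in the commit message, here is a minimal hedged sketch (not part of this diff): a non-fittable preprocessor is reapplied to every epoch of a pipeline so the extra transform is pipelined with consumption. The calls repeat(), foreach_window(), and iter_epochs() are standard Ray Data DatasetPipeline APIs of this era but do not appear in the diff below, and the actual AIR integration may differ.

import random

import ray
from ray.data.preprocessors import BatchMapper

ds = ray.data.range_tensor(8)

# A non-fittable preprocessor to reapply on every epoch.
add_noise = BatchMapper(lambda df: df + random.random(), batch_format="pandas")

# repeat(2) produces one window per epoch; foreach_window reapplies the
# transform to each window as the pipeline executes, so the extra step is
# pipelined with downstream consumption.
pipe = ds.repeat(2).foreach_window(lambda window: add_noise.transform(window))

for epoch_pipe in pipe.iter_epochs():
    for batch in epoch_pipe.iter_batches(batch_format="pandas"):
        pass  # a training step would consume `batch` here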
19 changes: 19 additions & 0 deletions doc/source/ray-air/check-ingest.rst
@@ -178,6 +178,25 @@ Shuffling or data randomization is important for training high-quality models.
* you suspect high-quality shuffles may significantly improve model quality; and
* absolute ingest performance is less of a concern

.. _air-per-epoch-preprocessing:

Applying randomized preprocessing (experimental)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The standard preprocessor passed to the ``Trainer`` is only applied once to the initial dataset when using :ref:`bulk ingest <air-streaming-ingest>`.
However, in some cases you may want to reapply a preprocessor on each epoch, for example to augment your training dataset with a randomized transform.

To support this use case, AIR offers an additional *per-epoch preprocessor* that gets reapplied on each epoch, after all other preprocessors and right before dataset consumption (e.g., using :meth:`~ray.data.DatasetIterator.iter_batches()`).
Per-epoch preprocessing also executes in parallel with dataset consumption to reduce pauses in the training loop.

This example shows how to use this feature to apply a randomized preprocessor on top of the standard preprocessor.

.. literalinclude:: doc_code/air_ingest.py
:language: python
:start-after: __config_6__
:end-before: __config_6_end__


.. _air-splitting-aux-datasets:

Splitting Auxiliary Datasets
50 changes: 50 additions & 0 deletions doc/source/ray-air/doc_code/air_ingest.py
@@ -158,6 +158,56 @@ def train_loop_per_worker():
my_trainer.fit()
# __config_5_end__

# __config_6__
import random

import ray
from ray.air import session
from ray.data import DatasetIterator
from ray.data.preprocessors import BatchMapper
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig, DatasetConfig

# A simple preprocessor that just scales all values by 2.0.
preprocessor = BatchMapper(lambda df: df * 2, batch_format="pandas")

# A randomized preprocessor that adds a random float to all values, to be
# reapplied on each epoch after `preprocessor`. Each epoch will therefore add a
# different random float to the scaled dataset.
add_noise = BatchMapper(lambda df: df + random.random(), batch_format="pandas")


def train_loop_per_worker():
# Get a handle to the worker's assigned DatasetIterator shard.
data_shard: DatasetIterator = session.get_dataset_shard("train")

# Manually iterate over the data 10 times (10 epochs).
for _ in range(10):
for batch in data_shard.iter_batches():
print("Do some training on batch", batch)

# Print the stats for performance debugging.
print(data_shard.stats())


my_trainer = TorchTrainer(
train_loop_per_worker,
scaling_config=ScalingConfig(num_workers=1),
datasets={
"train": ray.data.range_tensor(100),
},
dataset_config={
"train": DatasetConfig(
# Don't randomize order, just to make it easier to read the results.
randomize_block_order=False,
per_epoch_preprocessor=add_noise,
),
},
preprocessor=preprocessor,
)
my_trainer.fit()
# __config_6_end__

# __global_shuffling_start__
import ray
from ray.air import session
25 changes: 25 additions & 0 deletions python/ray/air/config.py
@@ -15,6 +15,7 @@
from ray.air.constants import WILDCARD_KEY
from ray.util.annotations import PublicAPI
from ray.widgets import Template, make_table_html_repr
from ray.data.preprocessor import Preprocessor

if TYPE_CHECKING:
from ray.data import Dataset
@@ -326,6 +327,11 @@ class DatasetConfig:
The main purpose of this is to prevent data fetching hotspots in the
cluster when running many parallel workers / trials on the same data.
We recommend enabling it always. True by default.
per_epoch_preprocessor [Experimental]: A preprocessor to re-apply on
each pass of the dataset. The main use case for this is to apply a
random transform on a training dataset on each epoch. The
per-epoch preprocessor will be applied *after* all other
preprocessors and in parallel with the dataset consumer.
use_stream_api: Deprecated. Use max_object_store_memory_fraction instead.
stream_window_size: Deprecated. Use max_object_store_memory_fraction instead.
"""
@@ -340,6 +346,7 @@ class DatasetConfig:
max_object_store_memory_fraction: Optional[float] = None
global_shuffle: Optional[bool] = None
randomize_block_order: Optional[bool] = None
per_epoch_preprocessor: Optional["Preprocessor"] = None
# Deprecated.
use_stream_api: Optional[int] = None
stream_window_size: Optional[int] = None
@@ -377,6 +384,7 @@ def fill_defaults(self) -> "DatasetConfig":
randomize_block_order=self.randomize_block_order
if self.randomize_block_order is not None
else True,
per_epoch_preprocessor=self.per_epoch_preprocessor,
)

@staticmethod
@@ -444,6 +452,20 @@ def validated(
"must be None or a float with value -1 or >=0, but got "
f"{v.max_object_store_memory_fraction}."
)
if v.per_epoch_preprocessor is not None:
if not isinstance(v.per_epoch_preprocessor, Preprocessor):
raise ValueError(
"`per_epoch_preprocessor` must be a ray.data.Preprocessor "
f"but got {v.per_epoch_preprocessor}."
)
if (
v.per_epoch_preprocessor.fit_status()
!= Preprocessor.FitStatus.NOT_FITTABLE
):
raise ValueError(
"`per_epoch_preprocessor` currently does not support "
"fittable ray.data.Preprocessors."
)

if len(fittable) > 1:
raise ValueError(
@@ -474,6 +496,9 @@ def _merge(self, other: "DatasetConfig") -> "DatasetConfig":
randomize_block_order=self.randomize_block_order
if other.randomize_block_order is None
else other.randomize_block_order,
per_epoch_preprocessor=self.per_epoch_preprocessor
if other.per_epoch_preprocessor is None
else other.per_epoch_preprocessor,
)
return new_config
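
As a hedged illustration of the resolution rule implemented above (a field set on other wins, otherwise the value from self carries over, and per_epoch_preprocessor follows the same pattern), here is a small sketch. Note that _merge is an internal helper, so this is only to make the semantics concrete, not a recommended public API.

from ray.air.config import DatasetConfig
from ray.data.preprocessors import BatchMapper

add_noise = BatchMapper(lambda df: df + 1, batch_format="pandas")

base = DatasetConfig(randomize_block_order=False)
override = DatasetConfig(per_epoch_preprocessor=add_noise)

merged = base._merge(override)
# per_epoch_preprocessor comes from `override`; randomize_block_order is kept
# from `base` because `override` leaves it as None.
assert merged.per_epoch_preprocessor is add_noise
assert merged.randomize_block_order is False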

119 changes: 119 additions & 0 deletions python/ray/air/tests/test_dataset_config.py
@@ -9,6 +9,7 @@
from ray.air.config import DatasetConfig, ScalingConfig
from ray.air.util.check_ingest import make_local_dataset_iterator
from ray.data import DatasetIterator
from ray.data.preprocessor import Preprocessor
from ray.data.preprocessors import BatchMapper
from ray.train.data_parallel_trainer import DataParallelTrainer

@@ -421,6 +422,124 @@ def check_error(shard, results):
test.fit()


@pytest.mark.parametrize("max_object_store_memory_fraction", [None, 1, 0.3])
def test_deterministic_per_epoch_preprocessor(
ray_start_4_cpus, max_object_store_memory_fraction
):
ds = ray.data.range_table(5)

def multiply(x):
return x * 2

it = make_local_dataset_iterator(
ds,
# Add some random noise to each integer.
preprocessor=BatchMapper(
lambda x: x + 0.1 * random.random(), batch_format="pandas"
),
dataset_config=DatasetConfig(
randomize_block_order=False,
max_object_store_memory_fraction=max_object_store_memory_fraction,
per_epoch_preprocessor=BatchMapper(multiply, batch_format="pandas"),
),
)

def checker(shard, results):
assert len(results[0]) == 5, (max_object_store_memory_fraction, results)
if max_object_store_memory_fraction is None:
assert results[0] == results[1], (
max_object_store_memory_fraction,
results,
)
else:
# Windowed pipelined ingest also reapplies the base
# preprocessor on every epoch, so we get a random dataset each
# time.
assert results[0] != results[1], (
max_object_store_memory_fraction,
results,
)
# Per-epoch preprocessor was applied at least once.
assert all(int(x) % 2 == 0 for x in results[0]), (
max_object_store_memory_fraction,
results,
)
# Per-epoch preprocessor was applied no more than once.
assert any(int(x) % 4 != 0 for x in results[0]), (
max_object_store_memory_fraction,
results,
)

TestStream.train_loop_per_worker(it, checker)


@pytest.mark.parametrize("max_object_store_memory_fraction", [None, 1, 0.3])
def test_nondeterministic_per_epoch_preprocessor(
ray_start_4_cpus, max_object_store_memory_fraction
):
ds = ray.data.range_table(5)

# Use randomized per-epoch preprocessor to check that it gets applied once
# per epoch.
def rand(x):
return x * random.random()

it = make_local_dataset_iterator(
ds,
preprocessor=None,
dataset_config=DatasetConfig(
randomize_block_order=False,
max_object_store_memory_fraction=max_object_store_memory_fraction,
per_epoch_preprocessor=BatchMapper(rand, batch_format="pandas"),
),
)

def checker(shard, results):
assert len(results[0]) == 5, (max_object_store_memory_fraction, results)
# Per-epoch preprocessor is randomized, so we should get a random
# dataset on each epoch.
assert results[0] != results[1], (max_object_store_memory_fraction, results)

TestStream.train_loop_per_worker(it, checker)


def test_validate_per_epoch_preprocessor(ray_start_4_cpus):
ds = ray.data.range_table(5)

def multiply(x):
return x * 2

dataset_config = DatasetConfig(
per_epoch_preprocessor=BatchMapper(multiply, batch_format="pandas")
)
DatasetConfig.validated(
{
"train": dataset_config,
},
{"train": ds},
)

with pytest.raises(ValueError):
# Must specify a ray.data.Preprocessor.
dataset_config = DatasetConfig(per_epoch_preprocessor=multiply)
DatasetConfig.validated(
{
"train": dataset_config,
},
{"train": ds},
)

with pytest.raises(ValueError):
# Must specify a non-fittable ray.data.Preprocessor.
dataset_config = DatasetConfig(per_epoch_preprocessor=Preprocessor())
DatasetConfig.validated(
{
"train": dataset_config,
},
{"train": ds},
)


if __name__ == "__main__":
import sys

16 changes: 8 additions & 8 deletions python/ray/data/preprocessor.py
@@ -4,10 +4,10 @@
from typing import TYPE_CHECKING, Optional, Union, Dict, Any

from ray.air.util.data_batch_conversion import BatchFormat, BlockFormat
from ray.data import Dataset, DatasetPipeline
from ray.util.annotations import DeveloperAPI, PublicAPI

if TYPE_CHECKING:
from ray.data import Dataset, DatasetPipeline
import pandas as pd
import numpy as np
from ray.air.data_batch_type import DataBatchType
@@ -73,7 +73,7 @@ def transform_stats(self) -> Optional[str]:
return None
return self._transform_stats

def fit(self, dataset: Dataset) -> "Preprocessor":
def fit(self, dataset: "Dataset") -> "Preprocessor":
"""Fit this Preprocessor to the Dataset.
Fitted state attributes will be directly set in the Preprocessor.
@@ -104,7 +104,7 @@ def fit(self, dataset: Dataset) -> "Preprocessor":

return self._fit(dataset)

def fit_transform(self, dataset: Dataset) -> Dataset:
def fit_transform(self, dataset: "Dataset") -> "Dataset":
"""Fit this Preprocessor to the Dataset and then transform the Dataset.
Calling it more than once will overwrite all previously fitted state:
@@ -120,7 +120,7 @@ def fit_transform(self, dataset: Dataset) -> Dataset:
self.fit(dataset)
return self.transform(dataset)

def transform(self, dataset: Dataset) -> Dataset:
def transform(self, dataset: "Dataset") -> "Dataset":
"""Transform the given dataset.
Args:
@@ -170,7 +170,7 @@ def transform_batch(self, data: "DataBatchType") -> "DataBatchType":
)
return self._transform_batch(data)

def _transform_pipeline(self, pipeline: DatasetPipeline) -> DatasetPipeline:
def _transform_pipeline(self, pipeline: "DatasetPipeline") -> "DatasetPipeline":
"""Transform the given DatasetPipeline.
Args:
@@ -205,7 +205,7 @@ def _check_is_fitted(self) -> bool:
return bool(fitted_vars)

@DeveloperAPI
def _fit(self, dataset: Dataset) -> "Preprocessor":
def _fit(self, dataset: "Dataset") -> "Preprocessor":
"""Sub-classes should override this instead of fit()."""
raise NotImplementedError()

@@ -259,8 +259,8 @@ def _determine_transform_to_use(self, data_format: BlockFormat) -> BatchFormat:
return transform_type

def _transform(
self, dataset: Union[Dataset, DatasetPipeline]
) -> Union[Dataset, DatasetPipeline]:
self, dataset: Union["Dataset", "DatasetPipeline"]
) -> Union["Dataset", "DatasetPipeline"]:
# TODO(matt): Expose `batch_size` or similar configurability.
# The default may be too small for some datasets and too large for others.

4 changes: 4 additions & 0 deletions python/ray/data/preprocessors/utils.py
@@ -1,12 +1,16 @@
import hashlib
from typing import List

from ray.util.annotations import DeveloperAPI


@DeveloperAPI
def simple_split_tokenizer(value: str) -> List[str]:
"""Tokenize a string using a split on spaces."""
return value.split(" ")


@DeveloperAPI
def simple_hash(value: object, num_features: int) -> int:
"""Deterministically hash a value into the integer space."""
encoded_value = str(value).encode()