
[AIR] Add option for per-epoch preprocessor #31739

Merged

Changes from 3 commits
19 changes: 19 additions & 0 deletions doc/source/ray-air/check-ingest.rst
@@ -178,6 +178,25 @@ Shuffling or data randomization is important for training high-quality models. B
* you suspect high-quality shuffles may significantly improve model quality; and
* absolute ingest performance is less of a concern

.. _air-per-epoch-preprocessing:

Applying randomized preprocessing (experimental)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The standard preprocessor passed to the ``Trainer`` is only applied once to the initial dataset when using :ref:`bulk ingest <air-streaming-ingest>`.
However, in some cases you may want to reapply a preprocessor on each epoch, for example to augment your training dataset with a randomized transform.

To support this use case, AIR offers an additional *per-epoch preprocessor* that gets reapplied on each epoch, after all other preprocessors and right before dataset consumption (e.g., using :meth:`~ray.data.DatasetIterator.iter_batches()`).
Contributor

Is the "executes in parallel" part only true for the pipeline-enabled version?

Contributor Author
stephanie-wang Jan 18, 2023

I'm using DatasetPipeline under the hood so actually it is always true. I figure this is OK since the implementation detail is now hidden under DatasetIterator and the feature is experimental anyway. Long-term, I imagine we want to switch to the fully pipelined Datasets backend or we cache the preprocessed dataset and run the per-epoch preprocessing on the pipelined backend.

Per-epoch preprocessing also executes in parallel with dataset consumption, reducing pauses while the trainer iterates over the data.
Contributor

Do we also fit() the per-epoch preprocessor?

Contributor Author

Actually wasn't sure about this part because I don't really understand how fit() works...

  • when do we need to call fit()?
  • if the standard preprocessor is defined, do we need to fit() on the preprocessed dataset or the input dataset?

Contributor
ericl Jan 18, 2023

Right now, we call fit on start (actually fit_transform(), I believe, to create the original preprocessed dataset). A fittable preprocessor isn't usable for transformation until it is fitted.

> if the standard preprocessor is defined, do we need to fit() on the preprocessed dataset or the input dataset?

Hmm, I'd think you would fit the preprocessed dataset, since this preprocessor is logically consuming the output of the previous one-time preprocessor.

Perhaps we should just raise ValueError if the per-epoch preprocessor requires fitting?

Contributor Author

That's a good idea, thanks!

Contributor

+1 on not allowing fittable per-epoch preprocessors

Contributor

Yep +1 on not allowing fittable per-epoch preprocessors
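
A minimal, standalone sketch of the check this thread converges on. It mirrors the validation the PR adds to DatasetConfig.validated() further down; the helper name here is only for illustration and is not part of the PR.

```python
from ray.data.preprocessor import Preprocessor


def check_per_epoch_preprocessor(per_epoch_preprocessor) -> None:
    # Per-epoch preprocessors are never fit(), so reject anything that is not
    # a Preprocessor instance or that reports a fittable status.
    if not isinstance(per_epoch_preprocessor, Preprocessor):
        raise ValueError(
            "`per_epoch_preprocessor` must be a ray.data.Preprocessor "
            f"but got {per_epoch_preprocessor}."
        )
    if per_epoch_preprocessor.fit_status() != Preprocessor.FitStatus.NOT_FITTABLE:
        raise ValueError(
            "`per_epoch_preprocessor` currently does not support "
            "fittable ray.data.Preprocessors."
        )
```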


This example shows how to use this feature to apply a randomized preprocessor on top of the standard preprocessor.

.. literalinclude:: doc_code/air_ingest.py
:language: python
:start-after: __config_6__
:end-before: __config_6_end__


.. _air-splitting-aux-datasets:

Splitting Auxiliary Datasets
50 changes: 50 additions & 0 deletions doc/source/ray-air/doc_code/air_ingest.py
@@ -158,6 +158,56 @@ def train_loop_per_worker():
my_trainer.fit()
# __config_5_end__

# __config_6__
import random

import ray
from ray.air import session
from ray.data import DatasetIterator
from ray.data.preprocessors import BatchMapper
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig, DatasetConfig

# A simple preprocessor that just scales all values by 2.0.
preprocessor = BatchMapper(lambda df: df * 2, batch_format="pandas")

# A randomized preprocessor that adds a random float to all values, to be
# reapplied on each epoch after `preprocessor`. Each epoch will therefore add a
# different random float to the scaled dataset.
add_noise = BatchMapper(lambda df: df + random.random(), batch_format="pandas")


def train_loop_per_worker():
    # Get a handle to the worker's assigned DatasetIterator shard.
    data_shard: DatasetIterator = session.get_dataset_shard("train")

    # Manually iterate over the data 10 times (10 epochs).
    for _ in range(10):
        for batch in data_shard.iter_batches():
            print("Do some training on batch", batch)

    # Print the stats for performance debugging.
    print(data_shard.stats())


my_trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=1),
    datasets={
        "train": ray.data.range_tensor(100),
    },
    dataset_config={
        "train": DatasetConfig(
            # Don't randomize order, just to make it easier to read the results.
            randomize_block_order=False,
            per_epoch_preprocessor=add_noise,
        ),
    },
    preprocessor=preprocessor,
)
my_trainer.fit()
# __config_6_end__

# __global_shuffling_start__
import ray
from ray.air import session
25 changes: 25 additions & 0 deletions python/ray/air/config.py
@@ -15,6 +15,7 @@
from ray.air.constants import WILDCARD_KEY
from ray.util.annotations import PublicAPI
from ray.widgets import Template, make_table_html_repr
from ray.data.preprocessor import Preprocessor

if TYPE_CHECKING:
from ray.data import Dataset
@@ -326,6 +327,11 @@ class DatasetConfig:
The main purpose of this is to prevent data fetching hotspots in the
cluster when running many parallel workers / trials on the same data.
We recommend enabling it always. True by default.
per_epoch_preprocessor [Experimental]: A preprocessor to re-apply on
each pass of the dataset. The main use case for this is to apply a
random transform on a training dataset on each epoch. The
per-epoch preprocessor will be applied *after* all other
preprocessors and in parallel with the dataset consumer.
use_stream_api: Deprecated. Use max_object_store_memory_fraction instead.
stream_window_size: Deprecated. Use max_object_store_memory_fraction instead.
"""
@@ -340,6 +346,7 @@ class DatasetConfig:
max_object_store_memory_fraction: Optional[float] = None
global_shuffle: Optional[bool] = None
randomize_block_order: Optional[bool] = None
per_epoch_preprocessor: Optional["Preprocessor"] = None
# Deprecated.
use_stream_api: Optional[int] = None
stream_window_size: Optional[int] = None
@@ -377,6 +384,7 @@ def fill_defaults(self) -> "DatasetConfig":
randomize_block_order=self.randomize_block_order
if self.randomize_block_order is not None
else True,
per_epoch_preprocessor=self.per_epoch_preprocessor,
)

@staticmethod
@@ -444,6 +452,20 @@ def validated(
"must be None or a float with value -1 or >=0, but got "
f"{v.max_object_store_memory_fraction}."
)
if v.per_epoch_preprocessor is not None:
if not isinstance(v.per_epoch_preprocessor, Preprocessor):
raise ValueError(
"`per_epoch_preprocessor` must be a ray.data.Preprocessor "
f"but got {v.per_epoch_preprocessor}."
)
if (
v.per_epoch_preprocessor.fit_status()
!= Preprocessor.FitStatus.NOT_FITTABLE
):
raise ValueError(
"`per_epoch_preprocessor` currently does not support "
"fittable ray.data.Preprocessors."
)

if len(fittable) > 1:
raise ValueError(
@@ -474,6 +496,9 @@ def _merge(self, other: "DatasetConfig") -> "DatasetConfig":
randomize_block_order=self.randomize_block_order
if other.randomize_block_order is None
else other.randomize_block_order,
per_epoch_preprocessor=self.per_epoch_preprocessor
if other.per_epoch_preprocessor is None
else other.per_epoch_preprocessor,
)
return new_config

106 changes: 106 additions & 0 deletions python/ray/air/tests/test_dataset_config.py
@@ -8,6 +8,7 @@
from ray.air.config import DatasetConfig, ScalingConfig
from ray.air.util.check_ingest import make_local_dataset_iterator
from ray.data import DatasetIterator
from ray.data.preprocessor import Preprocessor
from ray.data.preprocessors import BatchMapper
from ray.train.data_parallel_trainer import DataParallelTrainer

@@ -398,6 +399,111 @@ def check_error(shard, results):
test.fit()


def test_deterministic_per_epoch_preprocessor(ray_start_4_cpus):
    ds = ray.data.range_table(5)

    def multiply(x):
        return x * 2

    for max_object_store_memory_fraction in [None, 1, 0.3]:
Contributor
amogkam Jan 18, 2023

should we use @pytest.mark.parametrize for this so that it will be easier to identify which case fails?
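
A sketch of this suggestion, assuming the fixtures already used in this test file; the existing loop body would move into the test unchanged, with one level of indentation removed.

```python
import pytest


@pytest.mark.parametrize("max_object_store_memory_fraction", [None, 1, 0.3])
def test_deterministic_per_epoch_preprocessor(
    ray_start_4_cpus, max_object_store_memory_fraction
):
    # pytest now reports each windowing setting as its own test case, so a
    # failure directly identifies the offending max_object_store_memory_fraction.
    ds = ray.data.range_table(5)
    ...
```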

        it = make_local_dataset_iterator(
            ds,
            preprocessor=BatchMapper(
                lambda x: x * int(10 * random.random()), batch_format="pandas"
            ),
            dataset_config=DatasetConfig(
                randomize_block_order=False,
                max_object_store_memory_fraction=max_object_store_memory_fraction,
                per_epoch_preprocessor=BatchMapper(multiply, batch_format="pandas"),
            ),
        )

        def checker(shard, results):
            assert len(results[0]) == 5, (max_object_store_memory_fraction, results)
            if max_object_store_memory_fraction is None:
                assert results[0] == results[1], (
                    max_object_store_memory_fraction,
                    results,
                )
            else:
                # Windowed pipelined ingest also reapplies the base
                # preprocessor on every epoch, so we get a random dataset each
                # time.
                assert results[0] != results[1], (
                    max_object_store_memory_fraction,
                    results,
                )
            assert all(x % 2 == 0 for x in results[0]), (
                max_object_store_memory_fraction,
                results,
            )

        TestStream.train_loop_per_worker(it, checker)


def test_nondeterministic_per_epoch_preprocessor(ray_start_4_cpus):
    ds = ray.data.range_table(5)

    # Use randomized per-epoch preprocessor to check that it gets applied once
    # per epoch.
    def rand(x):
        return x * random.random()

    for max_object_store_memory_fraction in [None, 1, 0.3]:
        it = make_local_dataset_iterator(
            ds,
            preprocessor=None,
            dataset_config=DatasetConfig(
                randomize_block_order=False,
                max_object_store_memory_fraction=max_object_store_memory_fraction,
                per_epoch_preprocessor=BatchMapper(rand, batch_format="pandas"),
            ),
        )

        def checker(shard, results):
            assert len(results[0]) == 5, (max_object_store_memory_fraction, results)
            # Per-epoch preprocessor is randomized, so we should get a random
            # dataset on each epoch.
            assert results[0] != results[1], (max_object_store_memory_fraction, results)

        TestStream.train_loop_per_worker(it, checker)


def test_validate_per_epoch_preprocessor(ray_start_4_cpus):
    ds = ray.data.range_table(5)

    def multiply(x):
        return x * 2

    dataset_config = DatasetConfig(
        per_epoch_preprocessor=BatchMapper(multiply, batch_format="pandas")
    )
    DatasetConfig.validated(
        {
            "train": dataset_config,
        },
        {"train": ds},
    )

    with pytest.raises(ValueError):
        dataset_config = DatasetConfig(per_epoch_preprocessor=multiply)
        DatasetConfig.validated(
            {
                "train": dataset_config,
            },
            {"train": ds},
        )

Member

Nit: could you add a comment here and maybe on line 487 describing why DatasetConfig raises a ValueError? It wasn't obvious to me why Preprocessor() is invalid from reading the test.
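
One way to address this nit; the comment wording below is my own reading of the validation logic added in this PR, not text from the change itself.

```python
    with pytest.raises(ValueError):
        # Invalid: `per_epoch_preprocessor` must be a ray.data.Preprocessor
        # instance, not a bare callable.
        dataset_config = DatasetConfig(per_epoch_preprocessor=multiply)
        DatasetConfig.validated({"train": dataset_config}, {"train": ds})

    with pytest.raises(ValueError):
        # Invalid: a plain Preprocessor() is treated as fittable (just not yet
        # fitted), and fittable per-epoch preprocessors are rejected.
        dataset_config = DatasetConfig(per_epoch_preprocessor=Preprocessor())
        DatasetConfig.validated({"train": dataset_config}, {"train": ds})
```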

    with pytest.raises(ValueError):
        dataset_config = DatasetConfig(per_epoch_preprocessor=Preprocessor())
        DatasetConfig.validated(
            {
                "train": dataset_config,
            },
            {"train": ds},
        )


if __name__ == "__main__":
import sys

12 changes: 6 additions & 6 deletions python/ray/data/preprocessor.py
@@ -4,10 +4,10 @@
from typing import TYPE_CHECKING, Optional, Union, Dict, Any

from ray.air.util.data_batch_conversion import BatchFormat, BlockFormat
from ray.data import Dataset
from ray.util.annotations import DeveloperAPI, PublicAPI

if TYPE_CHECKING:
from ray.data import Dataset
import pandas as pd
import numpy as np
from ray.air.data_batch_type import DataBatchType
@@ -73,7 +73,7 @@ def transform_stats(self) -> Optional[str]:
return None
return self._transform_stats

def fit(self, dataset: Dataset) -> "Preprocessor":
def fit(self, dataset: "Dataset") -> "Preprocessor":
"""Fit this Preprocessor to the Dataset.

Fitted state attributes will be directly set in the Preprocessor.
@@ -104,7 +104,7 @@ def fit(self, dataset: Dataset) -> "Preprocessor":

return self._fit(dataset)

def fit_transform(self, dataset: Dataset) -> Dataset:
def fit_transform(self, dataset: "Dataset") -> "Dataset":
"""Fit this Preprocessor to the Dataset and then transform the Dataset.

Calling it more than once will overwrite all previously fitted state:
@@ -120,7 +120,7 @@ def fit_transform(self, dataset: Dataset) -> Dataset:
self.fit(dataset)
return self.transform(dataset)

def transform(self, dataset: Dataset) -> Dataset:
def transform(self, dataset: "Dataset") -> "Dataset":
"""Transform the given dataset.

Args:
@@ -180,7 +180,7 @@ def _check_is_fitted(self) -> bool:
return bool(fitted_vars)

@DeveloperAPI
def _fit(self, dataset: Dataset) -> "Preprocessor":
def _fit(self, dataset: "Dataset") -> "Preprocessor":
"""Sub-classes should override this instead of fit()."""
raise NotImplementedError()

@@ -233,7 +233,7 @@ def _determine_transform_to_use(self, data_format: BlockFormat) -> BatchFormat:

return transform_type

def _transform(self, dataset: Dataset) -> Dataset:
def _transform(self, dataset: "Dataset") -> "Dataset":
# TODO(matt): Expose `batch_size` or similar configurability.
# The default may be too small for some datasets and too large for others.

22 changes: 20 additions & 2 deletions python/ray/train/_internal/dataset_spec.py
@@ -5,6 +5,7 @@
from ray.air.config import DatasetConfig

from ray.data import Dataset, DatasetPipeline
from ray.data.preprocessors import Chain
from ray.air._internal.util import _estimate_avail_object_store_memory

if TYPE_CHECKING:
@@ -188,17 +189,34 @@ def get_dataset_shards(
)
dataset = dataset.window(bytes_per_window=stream_window_size).repeat()
# In windowed mode, we re-apply the preprocessor on each iteration.
if self.preprocessor:
if self.preprocessor or config.per_epoch_preprocessor:
if self.preprocessor is not None:
preprocessor = self.preprocessor
if config.per_epoch_preprocessor is not None:
preprocessor = Chain(
preprocessor, config.per_epoch_preprocessor
)
else:
preprocessor = config.per_epoch_preprocessor

# TODO: Replace with self.preprocessor.transform when possible.
prep = self.preprocessor.transform_batch
prep = preprocessor.transform_batch
dataset = dataset.map_batches(prep, batch_format="pandas")

# Always re-randomize each window; this doesn't help with reducing
# cluster hot-spots since we already randomized the based blocks, but
# can help with improving randomness in combination with local shuffle.
if config.randomize_block_order and not config.global_shuffle:
# TODO(swang): Should randomize block order across the
# original dataset, not the window.
dataset = dataset.randomize_block_order_each_window()
if config.per_epoch_preprocessor is not None:
# Reapply the per epoch preprocessor on each epoch.
if isinstance(dataset, Dataset):
dataset = dataset.repeat()
# TODO: Replace with preprocessor.transform when possible.
Contributor

we can use preprocessor._transform_pipeline here now

per_epoch_prep = config.per_epoch_preprocessor.transform_batch
dataset = dataset.map_batches(per_epoch_prep)
Member

Not sure if I'm just being dumb here, but doesn't this apply the per-epoch preprocessor twice on the first epoch? Like, if config.per_epoch_preprocessor isn't None, then we apply it on both line 204 and line 219?

Contributor

that's a good point...I think lines 218-219 need to be moved inside the if statement.

Contributor Author

Ah thanks you're right, it should be under an elif. I'll add a test to make sure we're only applying it once.
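
A sketch of the fix being discussed: the per-epoch branch moves under an elif so the windowed path, which already chains the per-epoch preprocessor into its re-applied transform, does not apply it a second time. The windowed-mode condition name is a placeholder, not the actual variable in this file.

```python
if windowed_ingest:  # placeholder for the existing windowed-mode condition
    dataset = dataset.window(bytes_per_window=stream_window_size).repeat()
    # The chained (base + per-epoch) preprocessor is re-applied on every
    # window above, so the per-epoch transform is already covered here.
elif config.per_epoch_preprocessor is not None:
    # Bulk-ingest path: repeat the dataset and reapply only the per-epoch
    # preprocessor on each epoch.
    if isinstance(dataset, Dataset):
        dataset = dataset.repeat()
    per_epoch_prep = config.per_epoch_preprocessor.transform_batch
    dataset = dataset.map_batches(per_epoch_prep)
```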


if config.global_shuffle:
# If global shuffle is requested, then we should try to overlap