[AIR] Add option for per-epoch preprocessor #31739
Changes from 3 commits
@@ -178,6 +178,25 @@ Shuffling or data randomization is important for training high-quality models.

* you suspect high-quality shuffles may significantly improve model quality; and
* absolute ingest performance is less of a concern

.. _air-per-epoch-preprocessing:

Applying randomized preprocessing (experimental)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The standard preprocessor passed to the ``Trainer`` is applied only once to the initial dataset when using :ref:`bulk ingest <air-streaming-ingest>`.
However, in some cases you may want to reapply a preprocessor on each epoch, for example to augment your training dataset with a randomized transform.

To support this use case, AIR offers an additional *per-epoch preprocessor* that is reapplied on each epoch, after all other preprocessors and right before dataset consumption (e.g., with :meth:`~ray.data.DatasetIterator.iter_batches`).
Per-epoch preprocessing also executes in parallel with dataset consumption, reducing pauses in consumption.
Review thread:
- Do we also fit() the per-epoch preprocessor?
- Actually wasn't sure about this part because I don't really understand how fit() works...
- Right now, we call fit on start (actually fit_transform(), I believe) to create the original preprocessed dataset. A fittable preprocessor isn't usable for transformation until it is fitted.
- Hmm, I'd think you would fit the preprocessed dataset, since this preprocessor is logically consuming the output of the previous one-time preprocessor. Perhaps we should just raise ValueError if the per-epoch preprocessor requires fitting?
- That's a good idea, thanks!
- +1 on not allowing fittable per-epoch preprocessors
- Yep, +1 on not allowing fittable per-epoch preprocessors
This example shows how to use this feature to apply a randomized preprocessor on top of the standard preprocessor.

.. literalinclude:: doc_code/air_ingest.py
    :language: python
    :start-after: __config_6__
    :end-before: __config_6_end__
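(The doc_code snippet itself is not part of this diff. As a rough illustration only, here is a minimal sketch of how the option might be wired up, based on the ``DatasetConfig(per_epoch_preprocessor=...)`` usage exercised in the tests below; the dataset, lambdas, worker count, and training loop are illustrative assumptions, not the referenced example.)

    import random

    import ray
    from ray.air import session
    from ray.air.config import DatasetConfig, ScalingConfig
    from ray.data.preprocessors import BatchMapper
    from ray.train.data_parallel_trainer import DataParallelTrainer

    # Standard preprocessor: applied once to the dataset during bulk ingest.
    base_prep = BatchMapper(lambda df: df * 2, batch_format="pandas")

    # Per-epoch preprocessor: reapplied on every epoch, e.g. for random augmentation.
    augment = BatchMapper(lambda df: df * random.random(), batch_format="pandas")

    def train_loop_per_worker():
        it = session.get_dataset_shard("train")
        for _ in range(2):  # two epochs; the augmentation differs between them
            for batch in it.iter_batches():
                pass  # train on the freshly augmented batch here

    trainer = DataParallelTrainer(
        train_loop_per_worker,
        scaling_config=ScalingConfig(num_workers=1),
        datasets={"train": ray.data.range_table(8)},
        preprocessor=base_prep,
        dataset_config={"train": DatasetConfig(per_epoch_preprocessor=augment)},
    )
    trainer.fit()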
.. _air-splitting-aux-datasets:

Splitting Auxiliary Datasets
@@ -8,6 +8,7 @@
from ray.air.config import DatasetConfig, ScalingConfig
from ray.air.util.check_ingest import make_local_dataset_iterator
from ray.data import DatasetIterator
from ray.data.preprocessor import Preprocessor
from ray.data.preprocessors import BatchMapper
from ray.train.data_parallel_trainer import DataParallelTrainer
@@ -398,6 +399,111 @@ def check_error(shard, results):
    test.fit()


def test_deterministic_per_epoch_preprocessor(ray_start_4_cpus):
    ds = ray.data.range_table(5)

    def multiply(x):
        return x * 2

    for max_object_store_memory_fraction in [None, 1, 0.3]:

Review comment: should we use […]

        it = make_local_dataset_iterator(
            ds,
            preprocessor=BatchMapper(
                lambda x: x * int(10 * random.random()), batch_format="pandas"
            ),
            dataset_config=DatasetConfig(
                randomize_block_order=False,
                max_object_store_memory_fraction=max_object_store_memory_fraction,
                per_epoch_preprocessor=BatchMapper(multiply, batch_format="pandas"),
            ),
        )

        def checker(shard, results):
            assert len(results[0]) == 5, (max_object_store_memory_fraction, results)
            if max_object_store_memory_fraction is None:
                assert results[0] == results[1], (
                    max_object_store_memory_fraction,
                    results,
                )
            else:
                # Windowed pipelined ingest also reapplies the base
                # preprocessor on every epoch, so we get a random dataset each
                # time.
                assert results[0] != results[1], (
                    max_object_store_memory_fraction,
                    results,
                )
            assert all(x % 2 == 0 for x in results[0]), (
                max_object_store_memory_fraction,
                results,
            )

        TestStream.train_loop_per_worker(it, checker)


def test_nondeterministic_per_epoch_preprocessor(ray_start_4_cpus):
    ds = ray.data.range_table(5)

    # Use randomized per-epoch preprocessor to check that it gets applied once
    # per epoch.
    def rand(x):
        return x * random.random()

    for max_object_store_memory_fraction in [None, 1, 0.3]:
        it = make_local_dataset_iterator(
            ds,
            preprocessor=None,
            dataset_config=DatasetConfig(
                randomize_block_order=False,
                max_object_store_memory_fraction=max_object_store_memory_fraction,
                per_epoch_preprocessor=BatchMapper(rand, batch_format="pandas"),
            ),
        )

        def checker(shard, results):
            assert len(results[0]) == 5, (max_object_store_memory_fraction, results)
            # Per-epoch preprocessor is randomized, so we should get a random
            # dataset on each epoch.
            assert results[0] != results[1], (max_object_store_memory_fraction, results)

        TestStream.train_loop_per_worker(it, checker)


def test_validate_per_epoch_preprocessor(ray_start_4_cpus):
    ds = ray.data.range_table(5)

    def multiply(x):
        return x * 2

    dataset_config = DatasetConfig(
        per_epoch_preprocessor=BatchMapper(multiply, batch_format="pandas")
    )
    DatasetConfig.validated(
        {
            "train": dataset_config,
        },
        {"train": ds},
    )

    with pytest.raises(ValueError):
        # Rejected: the per-epoch preprocessor must be a ray.data Preprocessor.
        dataset_config = DatasetConfig(per_epoch_preprocessor=multiply)
        DatasetConfig.validated(
            {
                "train": dataset_config,
            },
            {"train": ds},
        )

Review comment: Nit: could you add a comment here and maybe on line 487 describing why […]

    with pytest.raises(ValueError):
        # Rejected: fittable preprocessors are not allowed as per-epoch
        # preprocessors (see review discussion above).
        dataset_config = DatasetConfig(per_epoch_preprocessor=Preprocessor())
        DatasetConfig.validated(
            {
                "train": dataset_config,
            },
            {"train": ds},
        )


if __name__ == "__main__":
    import sys
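As context for the last two cases above, the check the reviewers converge on ("raise ValueError if the per-epoch preprocessor requires fitting") might look roughly like the following sketch. The helper name and error messages are assumptions for illustration, not the merged implementation.

    from ray.data.preprocessor import Preprocessor


    def _validate_per_epoch_preprocessor(per_epoch_preprocessor) -> None:
        # Hypothetical sketch; not the merged implementation.
        if not isinstance(per_epoch_preprocessor, Preprocessor):
            raise ValueError(
                "`per_epoch_preprocessor` must be a ray.data Preprocessor, "
                f"got {type(per_epoch_preprocessor)}."
            )
        # The per-epoch preprocessor is only ever used to transform, never fit,
        # so reject preprocessors that need fitting.
        if per_epoch_preprocessor.fit_status() != Preprocessor.FitStatus.NOT_FITTABLE:
            raise ValueError(
                "`per_epoch_preprocessor` must not require fitting; "
                "pass a non-fittable preprocessor such as BatchMapper instead."
            )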
@@ -5,6 +5,7 @@
from ray.air.config import DatasetConfig

from ray.data import Dataset, DatasetPipeline
from ray.data.preprocessors import Chain
from ray.air._internal.util import _estimate_avail_object_store_memory

if TYPE_CHECKING:
@@ -188,17 +189,34 @@ def get_dataset_shards(
            )
            dataset = dataset.window(bytes_per_window=stream_window_size).repeat()
            # In windowed mode, we re-apply the preprocessor on each iteration.
            if self.preprocessor:
            if self.preprocessor or config.per_epoch_preprocessor:
                if self.preprocessor is not None:
                    preprocessor = self.preprocessor
                    if config.per_epoch_preprocessor is not None:
                        preprocessor = Chain(
                            preprocessor, config.per_epoch_preprocessor
                        )
                else:
                    preprocessor = config.per_epoch_preprocessor

                # TODO: Replace with self.preprocessor.transform when possible.
                prep = self.preprocessor.transform_batch
                prep = preprocessor.transform_batch
                dataset = dataset.map_batches(prep, batch_format="pandas")

            # Always re-randomize each window; this doesn't help with reducing
            # cluster hot-spots since we already randomized the base blocks, but
            # can help with improving randomness in combination with local shuffle.
            if config.randomize_block_order and not config.global_shuffle:
                # TODO(swang): Should randomize block order across the
                # original dataset, not the window.
                dataset = dataset.randomize_block_order_each_window()
        if config.per_epoch_preprocessor is not None:
            # Reapply the per-epoch preprocessor on each epoch.
            if isinstance(dataset, Dataset):
                dataset = dataset.repeat()
            # TODO: Replace with preprocessor.transform when possible.

Review comment: we can use […]

            per_epoch_prep = config.per_epoch_preprocessor.transform_batch
            dataset = dataset.map_batches(per_epoch_prep)

Review thread:
- Not sure if I'm just being dumb here, but doesn't this apply the per-epoch preprocessor twice on the first epoch? Like, if […]
- That's a good point... I think lines 218-219 need to be moved inside the if statement.
- Ah thanks, you're right, it should be under an […]

        if config.global_shuffle:
            # If global shuffle is requested, then we should try to overlap
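To make the double-application concern in the thread above concrete, here is a sketch of the control flow the reviewers converge on: in windowed mode the per-epoch preprocessor is already chained into the per-window re-application, so the separate repeat-and-map step should only run for bulk (non-windowed) ingest. This is an illustrative sketch, not the merged code; the helper name and parameters are assumptions standing in for the real method's state.

    from typing import Optional

    from ray.data import Dataset, DatasetPipeline
    from ray.data.preprocessor import Preprocessor
    from ray.data.preprocessors import Chain


    def apply_per_epoch_preprocessing(
        dataset: Dataset,
        base_preprocessor: Optional[Preprocessor],
        per_epoch_preprocessor: Optional[Preprocessor],
        stream_window_size: Optional[float],
    ) -> DatasetPipeline:
        """Sketch of the control flow the reviewers converge on (not the merged code)."""
        if stream_window_size is not None:
            # Windowed/pipelined ingest: re-apply base + per-epoch preprocessors
            # together on every window, so no separate per-epoch step is needed.
            pipe = dataset.window(bytes_per_window=stream_window_size).repeat()
            preps = [
                p for p in (base_preprocessor, per_epoch_preprocessor) if p is not None
            ]
            if preps:
                pipe = pipe.map_batches(
                    Chain(*preps).transform_batch, batch_format="pandas"
                )
        else:
            # Bulk ingest: the base preprocessor was already applied once up front,
            # so only the per-epoch preprocessor is re-applied on each epoch.
            pipe = dataset.repeat()
            if per_epoch_preprocessor is not None:
                pipe = pipe.map_batches(
                    per_epoch_preprocessor.transform_batch, batch_format="pandas"
                )
        return pipe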
Review thread:
- Is the "executes in parallel" part only true for the pipeline-enabled version?
- I'm using DatasetPipeline under the hood, so it is actually always true. I figure this is OK since the implementation detail is now hidden under DatasetIterator and the feature is experimental anyway. Long-term, I imagine we either switch to the fully pipelined Datasets backend, or we cache the preprocessed dataset and run the per-epoch preprocessing on the pipelined backend.
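To illustrate the DatasetPipeline point (illustrative sketch only; this is not the AIR internals, and the dataset and transform are made up): repeating a dataset and mapping the per-epoch transform over the resulting pipeline lets preprocessing of upcoming blocks execute while earlier batches are being consumed.

    import ray
    from ray.data.preprocessors import BatchMapper

    # Made-up per-epoch transform for illustration.
    augment = BatchMapper(lambda df: df + 1, batch_format="pandas")

    # Repeating a Dataset yields a DatasetPipeline; mapping the transform over it
    # means each epoch's preprocessing overlaps with consumption of earlier batches.
    pipe = ray.data.range_table(100).repeat(2)
    pipe = pipe.map_batches(augment.transform_batch, batch_format="pandas")

    for epoch_pipe in pipe.iter_epochs():
        for batch in epoch_pipe.iter_batches():
            pass  # consume the augmented batches for this epoch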