[AIR] Add option for per-epoch preprocessor #31739
Conversation
Signed-off-by: Stephanie Wang <swang@cs.berkeley.edu>
However, in some cases you may want to reapply a preprocessor on each epoch, for example to augment your training dataset with a randomized transform.

To support this use case, AIR offers an additional *per-epoch preprocessor* that gets reapplied on each epoch, after all other preprocessors and right before dataset consumption (e.g., using :meth:`~ray.data.DatasetIterator.iter_batches()`).
Per-epoch preprocessing also executes in parallel with dataset consumption to reduce pauses in dataset consumption.
Do we also fit() the per-epoch preprocessor?
Actually I wasn't sure about this part because I don't really understand how fit() works...
- when do we need to call fit()?
- if the standard preprocessor is defined, do we need to fit() on the preprocessed dataset or the input dataset?
Right now, we call fit() on start (actually fit_transform(), I believe, to create the original preprocessed dataset). A fittable preprocessor isn't usable for transformation until it is fitted.
if the standard preprocessor is defined, do we need to fit() on the preprocessed dataset or the input dataset?
Hmm I'd think you would fit the preprocessed dataset, since this preprocessor is logically consuming the output of the previous one-time preprocessor.
Perhaps we should just raise ValueError if the per-epoch preprocessor requires fitting?
That's a good idea, thanks!
+1 on not allowing fittable per-epoch preprocessors
Yep +1 on not allowing fittable per-epoch preprocessors
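For reference, a minimal sketch of what that validation could look like, assuming Ray's Preprocessor.fit_status()/FitStatus interface; the exact check and error message used by the PR are not shown here and may differ:

from ray.data.preprocessor import Preprocessor


def _validate_per_epoch_preprocessor(prep: Preprocessor) -> None:
    # A per-epoch preprocessor is reapplied on every epoch without being
    # fit, so only stateless preprocessors (non-fittable or already fitted)
    # make sense here. Illustrative check only, not the PR's actual code.
    if prep.fit_status() not in (
        Preprocessor.FitStatus.NOT_FITTABLE,
        Preprocessor.FitStatus.FITTED,
    ):
        raise ValueError(
            "The per-epoch preprocessor must not require fitting; "
            "use a stateless preprocessor such as BatchMapper instead."
        )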
# A randomized preprocessor that adds a random float to all values, to be
# reapplied on each epoch after `preprocessor`. Each epoch will therefore add a
# different random float to the scaled dataset.
rand_preprocessor = BatchMapper(lambda df: df + random.random(), batch_format="pandas")
Shall we call this add_noise or something to avoid overloading the term "random" too many times in this example?
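For example, the snippet could read as follows; this is just the suggested rename, with imports added to keep it self-contained:

import random

from ray.data.preprocessors import BatchMapper

# Adds a random float to all values; reapplied on each epoch after
# `preprocessor`, so every epoch sees a differently shifted copy of the
# scaled dataset.
add_noise = BatchMapper(lambda df: df + random.random(), batch_format="pandas")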
The standard preprocessor passed to the ``Trainer`` is only applied once to the initial dataset when using :ref:`bulk ingest <air-streaming-ingest>`.
However, in some cases you may want to reapply a preprocessor on each epoch, for example to augment your training dataset with a randomized transform.

To support this use case, AIR offers an additional *per-epoch preprocessor* that gets reapplied on each epoch, after all other preprocessors and right before dataset consumption (e.g., using :meth:`~ray.data.DatasetIterator.iter_batches()`).
Is the "executes in parallel" part only true for the pipelining-enabled version?
I'm using DatasetPipeline under the hood so actually it is always true. I figure this is OK since the implementation detail is now hidden under DatasetIterator and the feature is experimental anyway. Long-term, I imagine we want to switch to the fully pipelined Datasets backend or we cache the preprocessed dataset and run the per-epoch preprocessing on the pipelined backend.
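For context, a hedged sketch of how a user might opt in to this, using the per_epoch_preprocessor field this PR adds to DatasetConfig; the surrounding trainer wiring is an assumption rather than code from the PR:

import random

from ray.air.config import DatasetConfig
from ray.data.preprocessors import BatchMapper

add_noise = BatchMapper(lambda df: df + random.random(), batch_format="pandas")

# Reapply `add_noise` on every epoch of the "train" dataset, after the
# one-time preprocessor and right before iter_batches() consumption.
dataset_config = {
    "train": DatasetConfig(per_epoch_preprocessor=add_noise),
}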
def multiply(x):
    return x * 2

for max_object_store_memory_fraction in [None, 1, 0.3]:
Should we use @pytest.mark.parametrize for this so that it will be easier to identify which case fails?
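A minimal sketch of that suggestion; the test name and body here are placeholders rather than the PR's actual test:

import pytest


@pytest.mark.parametrize("max_object_store_memory_fraction", [None, 1, 0.3])
def test_per_epoch_preprocessing(max_object_store_memory_fraction):
    # Each memory fraction now runs as its own test case, so a failure
    # report names the parameter value that broke.
    ...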
Thanks @swang! LGTM-- just left 1 comment on using pytest for parametrization rather than doing it manually.
LGTM! This'll be a super useful addition
dataset = dataset.repeat()
# TODO: Replace with preprocessor.transform when possible.
per_epoch_prep = config.per_epoch_preprocessor.transform_batch
dataset = dataset.map_batches(per_epoch_prep)
Not sure if I'm just being dumb here, but doesn't this apply the per-epoch preprocessor twice on the first epoch? Like, if config.per_epoch_preprocessor isn't None, then we apply it on both line 204 and line 219?
That's a good point... I think lines 218-219 need to be moved inside the if statement.
Ah thanks, you're right, it should be under an elif. I'll add a test to make sure we're only applying it once.
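Roughly, the fix being described; this is a sketch only, the streaming-ingest condition and the elided branch body are assumptions, and the point is the elif placement so the per-epoch preprocessor runs exactly once per epoch:

if max_object_store_memory_fraction is not None:
    # Streaming ingest path (the "line 204" branch): the per-epoch
    # preprocessor is already applied to the windowed pipeline here.
    ...
elif config.per_epoch_preprocessor is not None:
    # Bulk ingest path: repeat the dataset and reapply the per-epoch
    # preprocessor, once per epoch.
    dataset = dataset.repeat()
    # TODO: Replace with preprocessor.transform when possible.
    per_epoch_prep = config.per_epoch_preprocessor.transform_batch
    dataset = dataset.map_batches(per_epoch_prep)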
},
{"train": ds},
)
Nit: could you add a comment here and maybe on line 487 describing why DatasetConfig raises a ValueError? It wasn't obvious to me why Preprocessor() is invalid from reading the test.
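Something along these lines could address the nit; the surrounding test call is elided, and only the explanatory comment is the point:

import pytest

# Presumably raises because a bare Preprocessor() reports itself as fittable
# but unfitted; per-epoch preprocessors are reapplied on every epoch without
# being fit, so DatasetConfig validation rejects it with a ValueError.
with pytest.raises(ValueError):
    ...  # build the config with DatasetConfig(per_epoch_preprocessor=Preprocessor())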
# Reapply the per epoch preprocessor on each epoch.
if isinstance(dataset, Dataset):
    dataset = dataset.repeat()
# TODO: Replace with preprocessor.transform when possible.
We can use preprocessor._transform_pipeline here now.
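That is, something like the following sketch of the suggested substitution, keeping the surrounding repeat() logic as quoted above:

# Reapply the per epoch preprocessor on each epoch.
if isinstance(dataset, Dataset):
    dataset = dataset.repeat()
# Let the preprocessor transform the DatasetPipeline directly instead of
# wiring transform_batch through map_batches by hand; _transform_pipeline
# is the (private) method referenced in the comment above.
dataset = config.per_epoch_preprocessor._transform_pipeline(dataset)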
This adds an option to the AIR DatasetConfig for a preprocessor that gets reapplied on each epoch. Currently the implementation uses DatasetPipeline to ensure that the extra preprocessing step is overlapped with training.
Signed-off-by: Stephanie Wang <swang@cs.berkeley.edu>
Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
Why are these changes needed?
This adds an option to the AIR DatasetConfig for a preprocessor that gets reapplied on each epoch. Currently the implementation uses DatasetPipeline to ensure that the extra preprocessing step is overlapped with training.
Related issue number
Checks
I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
I've run scripts/format.sh to lint the changes in this PR.