[Flax] Fix incomplete batches in example scripts #17863
Conversation
The documentation is not available anymore as the PR was closed or merged.
Thanks, this looks good to me! The pad_shard_unpad trick is neat!
@@ -847,7 +846,7 @@ def generate_step(params, batch):
     # train
     for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
         batch = next(train_loader)
-        state, train_metric = p_train_step(state, batch)
+        state, train_metric = pad_shard_unpad(p_train_step)(state, batch)
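For context on the trick: `pad_shard_unpad` wraps a pmapped function so that an arbitrary-sized batch is first padded up to a multiple of the local device count, then sharded, run, and finally un-padded again. Below is a minimal sketch of the idea for a pmapped eval step. It is illustrative only, not Flax's actual implementation; the real utility additionally handles static arguments such as the replicated state.

```python
import jax
import numpy as np


def pad_shard_unpad_sketch(p_eval_step):
    """Illustrative pad -> shard -> call -> unpad wrapper (not the Flax implementation)."""

    def wrapper(params, batch):
        n_devices = jax.local_device_count()
        batch_size = next(iter(batch.values())).shape[0]

        # 1. Pad every array up to the next multiple of the device count
        #    (here by repeating the last example).
        padded_size = int(np.ceil(batch_size / n_devices)) * n_devices
        pad = padded_size - batch_size
        batch = {k: np.concatenate([v, np.repeat(v[-1:], pad, axis=0)]) for k, v in batch.items()}

        # 2. Shard: [padded_size, ...] -> [n_devices, padded_size // n_devices, ...].
        batch = {k: v.reshape((n_devices, -1) + v.shape[1:]) for k, v in batch.items()}

        # 3. Run the pmapped step, then un-shard the per-example outputs and
        #    drop the padded rows again.
        out = p_eval_step(params, batch)
        return jax.tree_util.tree_map(
            lambda x: np.asarray(x).reshape((-1,) + x.shape[2:])[:batch_size], out
        )

    return wrapper
```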
This is really a nit!
     for idx in batch_idx:
         batch = dataset[idx]
         batch = {k: jnp.array(v) for k, v in batch.items()}

         batch = shard(batch)
Does pad_shard_unpad also take care of sharding the batch? (seems like it, but just to confirm)
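For reference, `shard` from `flax.training.common_utils` only reshapes the leading batch dimension across the local devices, roughly as sketched below, and performs no padding. `pad_shard_unpad` is documented to pad, shard, and un-pad in one go, which is why the explicit `shard(batch)` call becomes redundant when the wrapper is used. The device counts in the example are assumptions for illustration.

```python
import jax
import numpy as np

# Roughly what flax.training.common_utils.shard does: split the leading batch
# dimension across the local devices. No padding happens here, which is why the
# batch size must otherwise be an exact multiple of the device count.
def shard(xs):
    n_devices = jax.local_device_count()
    return jax.tree_util.tree_map(
        lambda x: x.reshape((n_devices, -1) + x.shape[1:]), xs
    )

# Example: 16 examples on an assumed 8-device host -> per-device batches of 2.
batch = {"input_ids": np.zeros((16, 128), dtype=np.int32)}
print(shard(batch)["input_ids"].shape)  # (8, 2, 128) with 8 local devices
```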
Thanks for fixing @sanchit-gandhi!
Just waiting to double check that the slow tests pass.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@patil-suraj @sanchit-gandhi can we merge this one?
Just verifying the slow tests.
* [Flax] Fix incomplete batches in example scripts
* fix dataloader batching
* convert jnp batch idxs to np array
* add missing `pad_shard_unpad` to final prediction generate step
* only `pad_shard_unpad` at inference time
* merge conflicts
* remove incomplete batch step from eval
* fix run_qa.py
* add `pad_shard_unpad` to run_flax_ner.py
* add `pad_shard_unpad` to run_flax_glue.py
* add `pad_shard_unpad` to run_image_classification.py
* make style
* fix mlm flax eval batches
* remove redundant imports
What does this PR do?
Currently in our Flax example scripts, we drop the last incomplete batch during training and inference:
transformers/examples/flax/summarization/run_summarization_flax.py, line 350 (at commit 0917870)
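The pattern being referenced looks roughly like the sketch below (paraphrased from the example script rather than quoted verbatim):

```python
import jax
import jax.numpy as jnp
from flax.training.common_utils import shard


def data_loader(rng, dataset, batch_size, shuffle=False):
    """Paraphrased sketch of the pre-PR loader: note the truncation of batch_idx."""
    steps_per_epoch = len(dataset) // batch_size

    if shuffle:
        batch_idx = jax.random.permutation(rng, len(dataset))
    else:
        batch_idx = jnp.arange(len(dataset))

    # The trailing `len(dataset) % batch_size` examples are silently dropped here.
    batch_idx = batch_idx[: steps_per_epoch * batch_size]
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

    for idx in batch_idx:
        batch = dataset[idx]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        yield shard(batch)
```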
We do this for two reasons:

1. Every batch must have the same static shape, otherwise the pmap'd function is re-compiled.
2. The batch size must be divisible by the number of devices so the batch can be sharded evenly.

During training, dropping the last batch isn't an issue: since we shuffle the data and train for multiple epochs, all of the training data is eventually used and the effects of dropping the last batch are amortised.
However, during evaluation and prediction, dropping the last batch leads to incorrect results: since we don't account for the examples in the last batch, we do not evaluate over the whole dataset, and thus have partial results.
This PR corrects for the incomplete batches in the relevant Flax training examples.
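As a rough illustration of the fix on the inference side, the sketch below assumes `flax.jax_utils.pad_shard_unpad` with its default arguments (which treat the first positional argument as static, so the replicated params pass through untouched); the `data_loader(..., drop_last=False)` helper is hypothetical and stands in for whichever loader the script uses.

```python
import math

import jax
from flax.jax_utils import pad_shard_unpad  # assumed available in recent Flax versions


def predict(state, pred_dataset, eval_batch_size, p_generate_step, data_loader):
    """Sketch: keep the final incomplete batch instead of dropping it."""
    # Round *up*, so the last `len(pred_dataset) % eval_batch_size` examples
    # are no longer skipped.
    pred_steps = math.ceil(len(pred_dataset) / eval_batch_size)
    pred_loader = data_loader(pred_dataset, eval_batch_size, drop_last=False)

    all_preds = []
    for _ in range(pred_steps):
        batch = next(pred_loader)
        # pad_shard_unpad pads the (possibly short) batch up to a multiple of the
        # device count, shards it, runs the pmapped generate step, then un-shards
        # and un-pads the generated ids back to the true batch size.
        generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
        all_preds.extend(jax.device_get(generated_ids))
    return all_preds
```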
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.