Fixing counting number of batches for accumulation through epoch #2745
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2745
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
##             main    #2745      +/-   ##
==========================================
+ Coverage   60.64%   62.64%   +1.99%
==========================================
  Files         428      430       +2
  Lines       26091    26395     +304
==========================================
+ Hits        15823    16534     +711
+ Misses      10268     9861     -407
Thanks for catching this, @wesbz!
After some discussion with @pbontrager and @ebsmothers, I think we actually prefer an approach like the one you outlined as option 1, wherein we simply drop the extra grads / metrics for any data that has been accumulated before the end of an epoch.
In our opinion, this breaks the fewest assumptions about which data is trained on in a given epoch, and it also mirrors the behavior of the PyTorch DataLoader, which exposes a parameter called drop_last that similarly drops any remaining data that does not divide evenly into your batch size.
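For reference, a minimal, self-contained illustration of the drop_last behaviour referenced above (toy sizes chosen to match the example later in this PR; not torchtune code):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 63 samples with a batch size of 8: the final 7 samples do not fill a batch.
dataset = TensorDataset(torch.arange(63).float())

keep_remainder = DataLoader(dataset, batch_size=8, drop_last=False)
drop_remainder = DataLoader(dataset, batch_size=8, drop_last=True)

print(len(keep_remainder))  # 8 batches, the last one holding only 7 samples
print(len(drop_remainder))  # 7 batches; the trailing 7 samples are dropped
```

The suggestion above applies the same idea at the accumulation level: any grads / metrics accumulated for batches that do not complete a full accumulation window by the end of the epoch would simply be discarded.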
Thanks for your response, @joecummings
This reverts commit afb1c0b.
Context
What is the purpose of this PR? Is it to
As I was running DPO and SFT, I noticed two surprising behaviours when running for several epochs:

I noticed that re-initializing the statistics accumulators and zeroing the gradients only happens once a full accumulation window completes (e.g. in `recipes/full_finetune_distributed.py`).
Now the issue is that if your gradient accumulation parameter is set to, say, 8, but you only have 63 batches to process, you end up processing the last 7 batches of the epoch, accumulating statistics and gradients, without re-initializing them before the new epoch starts, and then continue accumulating stats and grads over the first 8 batches of the new epoch before finally taking a step.
You end up accumulating over 15 batches instead of 8 (see the sketch below), which messes up not only the reported statistics but also the optimisation.
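To make this concrete, here is a minimal sketch of the failure mode with an illustrative gating condition (the model, dataloader, and variable names are assumptions for illustration, not the actual recipe code):

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Illustrative setup: 63 batches per epoch, accumulation over 8 batches.
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataloader = DataLoader(TensorDataset(torch.randn(63, 4)), batch_size=1)
gradient_accumulation_steps = 8

for epoch in range(2):
    batches_accumulated = 0
    for idx, (x,) in enumerate(dataloader):
        (model(x).sum() / gradient_accumulation_steps).backward()
        batches_accumulated += 1
        # Gradients (and running stats) are only stepped and cleared when a
        # full accumulation window completes:
        if (idx + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            batches_accumulated = 0
    # 63 % 8 == 7: the last 7 batches never hit the branch above, so their
    # gradients survive into the next epoch, whose first optimizer step then
    # covers 7 + 8 = 15 batches instead of 8.
    print(f"epoch {epoch}: {batches_accumulated} batches carried over")
```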
I see at least three possibilities; one of them is to require that `len(self._dataloader) % self._gradient_accumulation_steps == 0` (a sketch of such a check follows).
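For completeness, a sketch of how that check could look as a standalone guard (hypothetical helper, not code from this PR):

```python
from torch.utils.data import DataLoader

def check_accumulation_divisibility(
    dataloader: DataLoader, gradient_accumulation_steps: int
) -> None:
    # Hypothetical guard: fail fast when the number of batches per epoch does
    # not divide evenly into the accumulation window, so no partial window is
    # ever carried across an epoch boundary.
    if len(dataloader) % gradient_accumulation_steps != 0:
        raise ValueError(
            f"len(dataloader)={len(dataloader)} is not a multiple of "
            f"gradient_accumulation_steps={gradient_accumulation_steps}"
        )
```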
Changelog
This PR implements the second option.
It should be done for all recipes, but I thought it was better to discuss the solution first.
Test plan
pre-commit install
pytest tests
pytest tests -m integration_test
UX