
Fixing the counting of batches for gradient accumulation across epochs #2745


Open
wesbz wants to merge 4 commits into main

Conversation

@wesbz (Contributor) commented May 17, 2025

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

As I was running DPO and SFT, I noticed two surprising behaviours when running for several epochs:
[Screenshot: training curves showing the loss dropping and the DPO accuracy exceeding 100% at the start of each new epoch]

  1. The loss would drop at the beginning of a new epoch;
  2. The accuracy (for DPO) would be >100% at the beginning of a new epoch.
    I noticed that re-initializing the statistics accumulators and zeroing the gradients only happen under this condition (e.g. in recipes/full_finetune_distributed.py):
for idx, batch in enumerate(self._dataloader):
    ...
    # Optimizer step (if not fused in backward call)
    if (idx+1) % self._gradient_accumulation_steps == 0:
        ...

Now the issue is that if your gradient accumulation parameter is set to, say, 8, but you only have 63 batches to process, the last 7 batches of the epoch are processed, accumulating statistics and gradients, without those being re-initialized before the new epoch starts; the new epoch then keeps accumulating over its first 8 batches before finally taking a step.
You end up accumulating 15 batches instead of 8, which skews not only the reported statistics but also the optimisation.
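
To make the miscount concrete, here is a small standalone illustration (not recipe code) of which batch indices trigger a step under the current logic, assuming 63 batches and a gradient accumulation of 8:

gradient_accumulation_steps = 8
num_batches = 63  # e.g. len(self._dataloader)

# With per-epoch indexing, a step only fires when (idx + 1) is a multiple of 8.
step_indices = [idx for idx in range(num_batches)
                if (idx + 1) % gradient_accumulation_steps == 0]
print(step_indices)                           # [7, 15, 23, 31, 39, 47, 55] -> only 7 steps
print(num_batches - (step_indices[-1] + 1))   # 7 leftover batches (56..62) stay accumulated

# The next epoch's first step (idx == 7) then flushes 7 + 8 == 15 batches.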
I see at least three possibilities:

  1. dropping the last few batches so that len(self._dataloader) % self._gradient_accumulation_steps == 0;
  2. counting batches for accumulation by the absolute number of processed batches rather than the per-epoch batch index;
  3. making an optimizer step with the last few batches, with the gradients scaled accordingly.

Changelog

This PR implements the second option.
It should be done for all recipes, but I thought it was better to discuss the solution first.
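
For illustration, here is a minimal sketch of what option 2 looks like in a generic training loop (compute_loss and the other names are hypothetical placeholders, not the recipe's actual code): the accumulation counter is a running total of processed batches that persists across epochs, rather than the per-epoch batch index.

num_processed_batches = 0  # persists across epochs

for epoch in range(num_epochs):
    for batch in dataloader:
        loss = compute_loss(batch)  # hypothetical helper
        loss.backward()
        num_processed_batches += 1

        # Step every `gradient_accumulation_steps` processed batches,
        # no matter where the epoch boundary falls.
        if num_processed_batches % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()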

Test plan

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

  • I did not change any public API
  • I have added an example to docs or docstrings


pytorch-bot bot commented May 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2745

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 17, 2025
@codecov-commenter commented May 17, 2025

Codecov Report

Attention: Patch coverage is 0% with 6 lines in your changes missing coverage. Please review.

Project coverage is 62.64%. Comparing base (c8e670b) to head (72c1eea).
Report is 9 commits behind head on main.

Files with missing lines Patch % Lines
recipes/full_dpo_distributed.py 0.00% 3 Missing ⚠️
recipes/full_finetune_distributed.py 0.00% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2745      +/-   ##
==========================================
+ Coverage   60.64%   62.64%   +1.99%     
==========================================
  Files         428      430       +2     
  Lines       26091    26395     +304     
==========================================
+ Hits        15823    16534     +711     
+ Misses      10268     9861     -407     

☔ View full report in Codecov by Sentry.
@joecummings joecummings added triage review This issue should be discussed in weekly review and removed triage review This issue should be discussed in weekly review labels May 19, 2025
@joecummings (Contributor) left a comment


Thanks for catching this @wesbz !

After some discussion with @pbontrager and @ebsmothers, I think we actually prefer an approach like the one you outlined as option 1, wherein we simply drop the extra grads / metrics for any data that has been accumulated before the end of an epoch.

In our opinion, this breaks the fewest assumptions about which data is trained in a given epoch, and it also mirrors the behavior of the PyTorch DataLoader, which exposes a parameter called drop_last that similarly drops any remaining data that does not divide evenly into your batch size.
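
For reference, a minimal sketch of this drop-style approach in a generic loop (hypothetical names, not the actual recipe code): each epoch is capped at the last full accumulation window, so no partially accumulated gradients or metrics carry over into the next epoch.

steps_per_epoch = len(dataloader) // gradient_accumulation_steps
max_batches = steps_per_epoch * gradient_accumulation_steps  # 56 when there are 63 batches and accumulation is 8

for idx, batch in enumerate(dataloader):
    if idx >= max_batches:
        break  # drop the trailing batches that cannot form a full window
    loss = compute_loss(batch)  # hypothetical helper
    loss.backward()
    if (idx + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()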

@wesbz (Contributor, Author) commented May 19, 2025

Thanks for your response, @joecummings.
OK, sure. I can suggest one way of doing it that induces minimal changes.
