Gradient accumulation yields worse results than the equivalent batch size #2175

Closed
benjamin-marie opened this issue Oct 4, 2024 · 27 comments · Fixed by huggingface/transformers#34198
Labels
⏳ needs more info Additional information or clarification is required to proceed ❓ question Seeking clarification or more information

Comments

@benjamin-marie

I expected a training configuration with per_device_train_batch_size=1 and gradient_accumulation_steps=32 to yield the same (or similar) result as per_device_train_batch_size=32 and gradient_accumulation_steps=1, but that's not the case: the former is much worse.
I ran several experiments with SmolLM-135M and Llama 3.2 1B, always using the same seed, and the results are consistent with this observation.

[learning-curve plots comparing the two configurations]

Maybe I misunderstand something here?

My training code is in this Colab notebook. I ran this notebook to draw the learning curves above, restarting it between training runs to avoid OOM.
Note that I observe the same behavior with Qwen2.
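
For reference, the equivalence being assumed here holds exactly in plain PyTorch whenever every micro-batch contributes the same number of loss terms. A minimal sketch (independent of TRL and the linked notebook; the tiny model and data are made up for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(16, 4)            # stand-in for the real language model
criterion = nn.CrossEntropyLoss()   # mean reduction
x, y = torch.randn(32, 16), torch.randint(0, 4, (32,))

# (a) one batch of 32
model.zero_grad()
criterion(model(x), y).backward()
full_grad = model.weight.grad.clone()

# (b) 32 micro-batches of 1, each loss scaled by 1/32 before backward
model.zero_grad()
for i in range(32):
    (criterion(model(x[i:i+1]), y[i:i+1]) / 32).backward()
accum_grad = model.weight.grad.clone()

# Should print True (up to float rounding): with equal per-micro-batch
# counts, gradient accumulation matches the large batch.
print(torch.allclose(full_grad, accum_grad, atol=1e-5))
```

As later comments explain, the discrepancy appears when micro-batches contain different numbers of non-ignored tokens, which breaks this equality.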

@mayank31398

Hey, this is expected behaviour.
FSDP-1 only allows accumulation in 16-bit precision.
This is not the case for FSDP-2, which allows accumulation in both 16-bit and 32-bit precision.

@mayank31398

Documentation for FSDP-1: [screenshot]
Documentation for FSDP-2: [screenshot]

@benjamin-marie
Author

Interesting, I didn't know this. But I don't think it matters here: I would be surprised if TRL used FSDP's reduce-scatter for single-GPU training.

@qgallouedec
Member

Hi, thanks for reporting this.
Can you share your system info and the code you use for training?

@qgallouedec qgallouedec added ❓ question Seeking clarification or more information ⏳ needs more info Additional information or clarification is required to proceed labels Oct 7, 2024
@benjamin-marie
Author

Sure, it's all in the notebook I linked to in my first post. I ran this notebook on Colab with an A100.

@teknium1

Someone tried fp32 and it didn't help, so that doesn't seem to be the reason:

https://x.com/bnjmn_marie/status/1842464802636980564

@vigneshwaran

Have you tried a full/mixed-precision AdamW optimiser?

@benjamin-marie
Author

Yes:

[screenshot of the training arguments]

This configuration uses fp32 and adamw_torch.

@fzyzcjy
Contributor

fzyzcjy commented Oct 15, 2024

Hi, are there any updates? Thanks!

@danielhanchen
Contributor

I'm writing up a report about this - I think I managed to fix it :)
(Yes, it is in fact a subtle bug!) I'll tweet and post about it in about 8-10 hours!

@shimmyshimmer

We have fixed the issue guys!

Tweet: https://twitter.com/UnslothAI/status/1846231235749990699
Blogpost: https://unsloth.ai/blog/gradient

@geronimi73

> We have fixed the issue guys!

nice! feel like fixing it in TRL too?

@shimmyshimmer

> We have fixed the issue guys!
>
> nice! feel like fixing it in TRL too?

The Hugging Face team is already on it! :)

@muellerzr
Contributor

muellerzr commented Oct 15, 2024

(Somewhat; currently trying to reverse-engineer a few of the ways you did it. You guys would be much faster at it, I imagine, if you want to beat us to it ;) As this is more than TRL, it's ground-up transformers/Trainer work, tbh I think.)

@danielhanchen
Contributor

:)
Wrote a detailed tweet about it: https://x.com/danielhanchen/status/1846235913443262891
Also Reddit post: https://www.reddit.com/r/LocalLLaMA/comments/1g4ego7/llm_training_bug_fixes_gradient_accumulation_was/
Blog post: https://unsloth.ai/blog/gradient
Also @shimmyshimmer is my brother!! :)

@muellerzr
Contributor

Just as a fair warning, this will not be an immediate or quick fix, since essentially it means every single model's loss calculation is off when relying on output.loss, and every single model will need a custom variation of CrossEntropy (and the other affected loss functions) if you do not calculate the loss by hand.

We are working on figuring out the best solution.
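
For concreteness, a small numeric illustration (not the transformers code): when micro-batches contain different numbers of non-ignored tokens, averaging the per-step mean losses is not the same as taking the mean over the whole effective batch.

```python
import torch

# Per-token losses for two micro-batches, with ignored (-100) positions already dropped.
step1 = torch.tensor([2.0, 2.0, 2.0, 2.0])  # 4 valid tokens
step2 = torch.tensor([1.0])                 # 1 valid token

# What one big batch computes: the mean over all 5 valid tokens.
true_mean = torch.cat([step1, step2]).mean()                           # 1.8

# Naive gradient accumulation: mean per micro-batch, averaged over steps.
naive = (step1.mean() + step2.mean()) / 2                              # 1.5

# Corrected accumulation: sum per micro-batch, divide by total valid tokens.
fixed = (step1.sum() + step2.sum()) / (step1.numel() + step2.numel())  # 1.8

print(true_mean.item(), naive.item(), fixed.item())
```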

@nahidalam

@danielhanchen from the blog: "The 2nd theory was there is in fact a bug in the loss calculation, which we find to be the case." Is this bug specific to the CrossEntropy loss calculation in HF TRL? Would it not be an issue if someone is using, say, torch.nn.CrossEntropyLoss directly?

@huseinzol05

@muellerzr, I believe this only makes sense for padding-based batches. With packing there is no 0 / pad token in the batch, so the average cross-entropy is consistent.

@danielhanchen
Contributor

@nahidalam Unfortunately this is not an HF-only issue. The way gradient accumulation was originally implemented in many packages, even those that use PyTorch directly, accidentally missed accounting for ignored tokens. Using CE loss directly does not solve the issue, since mean reduction does not work across accumulation steps, and sum reduction will scale the loss incorrectly.

@huseinzol05 Packing is also affected, albeit less so; since some people also train only on completions, it will still make the loss incorrect.

@muellerzr If you guys need any help on anything, ping me!
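
A sketch of the recipe described above, under illustrative assumptions rather than the exact patch that later landed in transformers: compute each micro-batch's loss with sum reduction, then normalize by the total number of non-ignored tokens in the whole accumulation window.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # label value excluded from the loss (padding / prompt tokens)

def accumulation_backward(logits, labels, total_valid_tokens):
    """Backward one micro-batch: sum the per-token CE loss and divide by the
    valid-token count of the *entire* accumulation window, not this step's."""
    loss_sum = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        labels.view(-1),
        ignore_index=IGNORE_INDEX,
        reduction="sum",
    )
    (loss_sum / total_valid_tokens).backward()

# Usage sketch (names are illustrative, label shifting omitted):
# total_valid_tokens = sum((lbl != IGNORE_INDEX).sum() for _, lbl in micro_batches)
# for inp, lbl in micro_batches:
#     accumulation_backward(model(inp).logits, lbl, total_valid_tokens)
# optimizer.step(); optimizer.zero_grad()
```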

@wongjingping

wongjingping commented Oct 16, 2024

Kudos @danielhanchen on the fix! Neat write-up as well! Back to the OP: I think the issue isn't with the trl library but with the transformers library, because of how SFTTrainer extends Trainer, how the loss is calculated in Trainer's compute_loss, and how it is naively scaled by the number of steps here. I don't have a ton of context, but I imagine the more principled solution would be to fix it within the Trainer.compute_loss function, rather than having SFTTrainer override the compute_loss method. Happy to assist with the transformers fix if anyone from HF would like to take me up on it 😄
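
Schematically, the naive scaling being referred to looks like this (a paraphrase of the pattern, not the actual Trainer source; names are illustrative):

```python
# Inside each accumulation micro-step (schematic):
loss = model(**inputs).loss                           # mean over this micro-batch's valid tokens
(loss / args.gradient_accumulation_steps).backward()  # weights every micro-batch equally,
                                                      # regardless of its valid-token count
```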

@qingjianbuyi

Does DDP have the same issue? @danielhanchen

@muellerzr
Contributor

muellerzr commented Oct 28, 2024

Yes, DDP does. We have already documented this, and a fix is being put in. (I also have an article talking about this in more detail; tl;dr, you can choose a slower option of gathering all of the inputs/counts, which adds a communication step that generally isn't recommended, so it's False by default.)
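
The "gather the counts" option amounts to something like the following sketch, assuming a plain torch.distributed setup (illustrative, not accelerate's actual implementation):

```python
import torch
import torch.distributed as dist

IGNORE_INDEX = -100

def global_valid_token_count(labels):
    """All-reduce the non-ignored token count so every rank normalizes its
    summed loss by the same global denominator (one extra comm per step)."""
    count = (labels != IGNORE_INDEX).sum()
    dist.all_reduce(count, op=dist.ReduceOp.SUM)
    return count

# On each rank, with loss_sum computed using reduction="sum":
# DDP averages gradients across ranks, so scale back up by world_size:
# (loss_sum * dist.get_world_size() / global_valid_token_count(labels)).backward()
```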

@burtenshaw
Contributor

Should this be closed since it's fixed in transformers?

cc @qgallouedec @lewtun

@qgallouedec
Member

qgallouedec commented Nov 25, 2024

Right @burtenshaw.
Closed by huggingface/transformers#34198

@pminervini

[screenshot showing the issue still marked as open]

Time to hit that "Close Issue" button @qgallouedec @burtenshaw! :) I thought the issue was still open because the bug hadn't been fixed!

@qgallouedec
Member

Oops

@surprisedPikachu007

> @huseinzol05 Packing is also affected, albeit less so; since some people also train only on completions, it will still make the loss incorrect.

For a language modeling task, will this be a problem even if all samples in a batch have exactly the same sequence length?
