Support casting to fp32 when word embeddings are tied to lm_head #4446
Conversation
Profiling when lm_head and the word embeddings share the same object: the additional memory for FP32 during the forward pass is 1.401 GB - 1.110 GB ≈ 296 MB, which matches the factor of 2 between FP32 and FP16 bytes per parameter. 🤔
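For reference, a rough back-of-the-envelope check of those numbers (the ~150M parameter count below is illustrative, not taken from the profile):

```python
# Upcasting a shared FP16 lm_head/embedding weight to FP32 adds
# (4 - 2) = 2 extra bytes per parameter.
FP16_BYTES = 2
FP32_BYTES = 4

def extra_fp32_bytes(num_params: int) -> int:
    """Extra memory from holding the weight in FP32 instead of FP16."""
    return num_params * (FP32_BYTES - FP16_BYTES)

# With roughly 150M parameters in the tied weight (vocab_size * hidden_size),
# the overhead is ~0.3 GB, in the same ballpark as the profiled
# 1.401 GB - 1.110 GB ≈ 296 MB gap.
print(extra_fp32_bytes(150_000_000) / 1e9)  # ~0.30 GB
```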
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
tests/test_grpo_trainer.py
def test_training_with_cast_lm_head_to_fp32(self):
@pytest.mark.parametrize(
    "model_name", ["trl-internal-testing/tiny-Qwen3ForCausalLM", "trl-internal-testing/tiny-Gemma2ForCausalLM"]
Qwen3 has tied word embeddings and Gemma 2 doesn't, correct? If so, I'd just add a small comment so that we remember why we test these two cases.
It's the other way around: Qwen3 has untied embeddings and Gemma 2 has tied.
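For context, a minimal standalone sketch of why both checkpoints are covered (separate from the actual GRPO test; it only assumes the tiny checkpoints expose the usual tie_word_embeddings config flag):

```python
import pytest
from transformers import AutoConfig

@pytest.mark.parametrize(
    "model_name, tied",
    [
        ("trl-internal-testing/tiny-Qwen3ForCausalLM", False),   # untied lm_head
        ("trl-internal-testing/tiny-Gemma2ForCausalLM", True),   # lm_head tied to embeddings
    ],
)
def test_tie_word_embeddings_flag(model_name, tied):
    # Checks the config flag only; the real test in tests/test_grpo_trainer.py
    # runs GRPO training with the fp32 lm_head cast for both cases.
    config = AutoConfig.from_pretrained(model_name)
    assert bool(config.tie_word_embeddings) is tied
```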
return (inputs[0].to(torch.float32),) + inputs[1:]

original_dtype_local = target_model.lm_head.weight.dtype
target_model.lm_head = target_model.lm_head.float()
For the record, .float() is in-place, so in theory you could just have target_model.lm_head.float().

It happens that .float() returns self, so target_model.lm_head = target_model.lm_head.float() works as well. I personally prefer the current way; the assignment makes it more explicit.
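To make both points concrete, a small self-contained sketch (the hook registration here is illustrative, not the exact trainer code):

```python
import torch
from torch import nn

lm_head = nn.Linear(8, 16, bias=False, dtype=torch.float16)

# Module.float() casts the parameters in place *and* returns the module itself,
# so both spellings end up with the same FP32 module object.
returned = lm_head.float()
assert returned is lm_head
assert lm_head.weight.dtype == torch.float32

# The forward pre-hook from the diff above upcasts the incoming hidden states
# so they match the FP32 lm_head weight.
def cast_inputs_to_fp32(module, inputs):
    return (inputs[0].to(torch.float32),) + inputs[1:]

lm_head.register_forward_pre_hook(cast_inputs_to_fp32)
logits = lm_head(torch.randn(2, 8, dtype=torch.float16))
assert logits.dtype == torch.float32
```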
qgallouedec
left a comment
Nice! just a small comment on the test
commit 4677cf2 Author: Harras Mansoor <98635627+Harras3@users.noreply.github.com> Date: Wed Nov 5 04:06:13 2025 +0500 Removed Sentiment Tuning Examples (#4424)
commit 7a9592b Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Tue Nov 4 14:32:04 2025 -0700 🐍 Drop Python 3.9 (#4183)
commit 7f15a7f Author: Harras Mansoor <98635627+Harras3@users.noreply.github.com> Date: Wed Nov 5 02:06:31 2025 +0500 Removed outdated warning about batch contamination (#4423)
commit 8b0a3ce Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Tue Nov 4 21:37:39 2025 +0100 Update tokenizer apply_chat_template with return_dict=True default (#4448)
commit d9f9e2b Author: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com> Date: Tue Nov 4 19:56:58 2025 +0000 Support casting to fp32 when word embeddings are tied to lm_head (#4446)
commit 4e138ab Author: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com> Date: Tue Nov 4 15:15:23 2025 +0100 Upload notebook with T4 selected (#4449)
commit 43253b2 Author: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com> Date: Mon Nov 3 21:07:31 2025 +0000 Add On-Policy Distillation from thinking labs to paper index. (#4410) Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
commit 6f41b18 Author: Behrooz Azarkhalili <80390531+behroozazarkhalili@users.noreply.github.com> Date: Mon Nov 3 10:57:51 2025 -0800 fix: Remove chat template setting from non-SFT trainer scripts (#4437) Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com> Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
What does this PR do?
Per the Slack discussion: https://huggingface.slack.com/archives/C089Q56GPMM/p1761949724954879.
TinyGemma has tie_word_embeddings=True and Qwen3 has tie_word_embeddings=False.
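As an illustrative check (not part of the PR): when the flag is set, lm_head and the input embeddings share one weight tensor, so casting one side to FP32 also affects the other, which is why the tied case needs explicit support.

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Gemma2ForCausalLM")

# With tie_word_embeddings=True the output projection reuses the embedding weight.
print(model.config.tie_word_embeddings)  # True for this checkpoint
print(model.get_output_embeddings().weight is model.get_input_embeddings().weight)  # True: shared tensor
```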
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.