Conversation

@pramodith (Collaborator)

What does this PR do?

Per Slack discussion: https://huggingface.slack.com/archives/C089Q56GPMM/p1761949724954879.

TinyGemma has tie_word_embeddings=True
Qwen3 has tie_word_embeddings=False
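
If needed, the flag can be checked directly on the model configs. A minimal sketch (the model names are the tiny test models used in the new test below):

from transformers import AutoConfig

# Print the tie_word_embeddings flag of each test model.
for name in [
    "trl-internal-testing/tiny-Gemma2ForCausalLM",
    "trl-internal-testing/tiny-Qwen3ForCausalLM",
]:
    config = AutoConfig.from_pretrained(name)
    print(name, config.tie_word_embeddings)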

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@pramodith (Collaborator, Author) commented Nov 3, 2025

Profiling the original Qwen3-0.6B model (tie_word_embeddings=True) in fp16 with inputs of shape [4, 128]:

============================================================
GPU MEMORY USAGE SUMMARY
============================================================
Initial memory:              1.110 GB
After forward pass:          2.789 GB
Peak during forward:         2.789 GB
After backward pass:         2.444 GB
Peak during backward:        3.088 GB
After optimizer step:        4.672 GB
Peak total (overall):        5.782 GB
============================================================

Profiling when the word embeddings and lm_head are cast to fp32:

============================================================
GPU MEMORY USAGE SUMMARY
============================================================
Initial memory:              1.401 GB
After forward pass:          3.374 GB
Peak during forward:         3.374 GB
After backward pass:         3.172 GB
Peak during backward:        4.332 GB
After optimizer step:        5.978 GB
Peak total (overall):        7.378 GB
============================================================
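
For reference, a hypothetical sketch of how a summary like the ones above could be collected with torch.cuda memory counters. It assumes model, batch (including labels so the model returns a loss), and optimizer are already set up as described above; this is not necessarily the exact profiling script used.

import torch

def gb(n_bytes):
    return n_bytes / 1024**3

torch.cuda.reset_peak_memory_stats()
initial = torch.cuda.memory_allocated()

loss = model(**batch).loss                      # forward pass
after_fwd = torch.cuda.memory_allocated()
peak_fwd = torch.cuda.max_memory_allocated()

torch.cuda.reset_peak_memory_stats()            # track the backward-phase peak separately
loss.backward()
after_bwd = torch.cuda.memory_allocated()
peak_bwd = torch.cuda.max_memory_allocated()

torch.cuda.reset_peak_memory_stats()            # track the optimizer-step peak separately
optimizer.step()
after_opt = torch.cuda.memory_allocated()
peak_opt = torch.cuda.max_memory_allocated()

print(f"Initial memory:       {gb(initial):.3f} GB")
print(f"After forward pass:   {gb(after_fwd):.3f} GB (peak {gb(peak_fwd):.3f} GB)")
print(f"After backward pass:  {gb(after_bwd):.3f} GB (peak {gb(peak_bwd):.3f} GB)")
print(f"After optimizer step: {gb(after_opt):.3f} GB")
print(f"Peak total (overall): {gb(max(peak_fwd, peak_bwd, peak_opt)):.3f} GB")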

Since lm_head and the word embeddings share the same weight tensor, the additional memory for the FP32 weights at init should be
2 * vocab_size * hidden_dim => 2 * (151936 * 1024) / (1024**2) ≈ 296 MB

which matches the measured 1.401 GB - 1.110 GB ≈ 296 MB.

The factor of 2 is the number of extra bytes per parameter when going from FP16 (2 bytes) to FP32 (4 bytes).
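
The same arithmetic as a quick sanity check:

vocab_size, hidden_dim = 151936, 1024
extra_bytes_per_param = 4 - 2                                        # fp32 (4 B) minus fp16 (2 B)
extra_weight_mb = extra_bytes_per_param * vocab_size * hidden_dim / 1024**2
print(round(extra_weight_mb))                                        # ~297 MB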

Forward pass
The additional memory for the FP32 logits should be
2 * vocab_size * seq_len * batch_size => 2 * (151936 * 128 * 4) / (1024**2) ≈ 148 MB

So 3.374 GB - 2.789 GB should equal (additional init memory + additional forward-pass memory) = 296 MB + 148 MB = 444 MB,
but 0.585 GB != 444 MB.
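
Continuing the sanity-check snippet from above:

batch_size, seq_len = 4, 128
extra_logits_mb = 2 * vocab_size * seq_len * batch_size / 1024**2    # ~148 MB
expected_mb = extra_weight_mb + extra_logits_mb                      # the ~444 MB figure above
expected_gb = expected_mb / 1024                                     # ~0.43 GB
observed_gb = 3.374 - 2.789                                          # 0.585 GB from the summaries
print(f"expected ~{expected_gb:.3f} GB vs observed {observed_gb:.3f} GB")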

🤔
The difference is roughly 140 MB. I'm wondering whether the forward pass stores both the fp32 embedding output and the output of the post-embedding-layer hook separately, which would account for the extra ~140 MB.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


@pytest.mark.parametrize(
    "model_name", ["trl-internal-testing/tiny-Qwen3ForCausalLM", "trl-internal-testing/tiny-Gemma2ForCausalLM"]
)
def test_training_with_cast_lm_head_to_fp32(self):

Member: Qwen3 has tied word embeddings and Gemma 2 doesn't, correct? If so, I'd just add a small comment so that we remember why we test these two cases.

Collaborator (Author): It's the other way around: Qwen3 has untied embeddings and Gemma 2 has tied ones.

return (inputs[0].to(torch.float32),) + inputs[1:]

original_dtype_local = target_model.lm_head.weight.dtype
target_model.lm_head = target_model.lm_head.float()

Member: For the record, .float() is in-place, so in theory you could just have

target_model.lm_head.float()

It happens that .float() returns self, so target_model.lm_head = target_model.lm_head.float() works as well. I personally prefer the current way; the assignment makes it more explicit.
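
For context, a minimal standalone sketch of the pattern being discussed here (illustrative only, not TRL's exact implementation): cast lm_head to fp32 and register a forward pre-hook that upcasts its inputs.

import torch

def cast_lm_head_to_fp32(model):
    # Illustrative helper, not TRL's actual API.
    # .float() casts the parameters in place (and returns self). With tied word
    # embeddings, the embedding weight is the same tensor, so it is upcast too.
    model.lm_head.float()

    def _upcast_inputs(module, inputs):
        # Upcast the incoming hidden states so the fp32 lm_head produces fp32 logits.
        return (inputs[0].to(torch.float32),) + inputs[1:]

    model.lm_head.register_forward_pre_hook(_upcast_inputs)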

@qgallouedec (Member) left a comment:

Nice! Just a small comment on the test.

@pramodith merged commit d9f9e2b into huggingface:main Nov 4, 2025
9 of 10 checks passed
qgallouedec added a commit to Harras3/trl that referenced this pull request Nov 4, 2025
commit 7a9592b
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Tue Nov 4 14:32:04 2025 -0700

    🐍 Drop Python 3.9 (huggingface#4183)

commit 7f15a7f
Author: Harras Mansoor <98635627+Harras3@users.noreply.github.com>
Date:   Wed Nov 5 02:06:31 2025 +0500

    Removed outdated warning about batch contamination (huggingface#4423)

commit 8b0a3ce
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Tue Nov 4 21:37:39 2025 +0100

    Update tokenizer apply_chat_template with return_dict=True default (huggingface#4448)

commit d9f9e2b
Author: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com>
Date:   Tue Nov 4 19:56:58 2025 +0000

    Support casting to fp32 when word embeddings are tied to lm_head (huggingface#4446)

commit 4e138ab
Author: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
Date:   Tue Nov 4 15:15:23 2025 +0100

    Upload notebook with T4 selected (huggingface#4449)
@pramodith deleted the pramodith/support_tied_word_embeddings branch November 4, 2025 23:11
qgallouedec added a commit that referenced this pull request Nov 4, 2025
commit 4677cf2
Author: Harras Mansoor <98635627+Harras3@users.noreply.github.com>
Date:   Wed Nov 5 04:06:13 2025 +0500

    Removed Sentiment Tuning Examples (#4424)

commit 7a9592b
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Tue Nov 4 14:32:04 2025 -0700

    🐍 Drop Python 3.9 (#4183)

commit 7f15a7f
Author: Harras Mansoor <98635627+Harras3@users.noreply.github.com>
Date:   Wed Nov 5 02:06:31 2025 +0500

    Removed outdated warning about batch contamination (#4423)

commit 8b0a3ce
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Tue Nov 4 21:37:39 2025 +0100

    Update tokenizer apply_chat_template with return_dict=True default (#4448)

commit d9f9e2b
Author: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com>
Date:   Tue Nov 4 19:56:58 2025 +0000

    Support casting to fp32 when word embeddings are tied to lm_head (#4446)

commit 4e138ab
Author: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
Date:   Tue Nov 4 15:15:23 2025 +0100

    Upload notebook with T4 selected (#4449)

commit 43253b2
Author: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com>
Date:   Mon Nov 3 21:07:31 2025 +0000

    Add On-Policy Distillation from thinking labs to paper index. (#4410)

    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

commit 6f41b18
Author: Behrooz Azarkhalili <80390531+behroozazarkhalili@users.noreply.github.com>
Date:   Mon Nov 3 10:57:51 2025 -0800

    fix: Remove chat template setting from non-SFT trainer scripts (#4437)

    Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
