Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gemma bug fixes - Approx GELU, Layernorms, Sqrt(hd) #29402

Closed
wants to merge 13 commits into from

Conversation

danielhanchen
Copy link
Contributor

@danielhanchen danielhanchen commented Mar 2, 2024

Just a few more Gemma fixes :) Currently checking for more as well! All fixes are derived from Unsloth's 2.5x faster and 70% less VRAM Gemma finetuning script :) https://github.com/unslothai/unsloth

Related PR: #29285, which showed RoPE must be done in float32 and not float16, causing positional encodings to lose accuracy. @ArthurZucker @younesbelkada

1. Approx Gelu and not Exact

Activation function according to https://twitter.com/danielhanchen/status/1763613620909580505 (waiting for confirmation) should be approximate gelu and not exact gelu. Ie gelu should actually be gelu_pytorch_tanh which calls PytorchGELUTanh and that calls nn.functional.gelu(input, approximate="tanh"). gelu calls nn.functional.gelu(input, approximate="none")
Approx_Gelu_MqgVNPUOIJJL0PByVo10d

2. Layernorm (w+1) should be done in float32

The layernorms must be upcasted to float32 and not bfloat16 or float16 halfway. It must be done unlike Llama’s RMS Layernorm which downcasted before multiplying by the weights. We must downcast at the end.
RMS_Layernorm_IQDKDDYqOynaWPbbJT71e

3. sqrt(3072)=55.4256 but bfloat16 is 55.5

Interestingly, Gemma multiplies the embeddings by sqrt(hidden_dim). However, there is a precision problem! Gemma uses jnp.sqrt(self.embed_dim) .astype(x.dtype) which means sqrt(3072) = 55.4256, but casting it to bfloat16 rounds it to 55.5. For Gemma 2b, sqrt(2048) = 45.2548, but casting it to bfloat16 makes it 45.25.
Embedding_i3Hbs2CnyVrNF-otch_tL

@danielhanchen
Copy link
Contributor Author

Actually I just noticed this requires updating all of Gemma's models on the HF Model Hub: https://huggingface.co/google/gemma-7b/blob/main/config.json for eg gelu should be changed to gelu_pytorch_tanh.

Comment on lines 173 to 181
if hidden_act != "gelu_pytorch_tanh":
logger.warning_once(
"Gemma's activation function should be approximate GeLU and not exact GeLU.\n"\
"Please edit your model config to use `gelu_pytorch_tanh` and not `gelu`.\n"\
"For now, we shall use `gelu_pytorch_tanh` temporarily.\n"\
"See https://github.com/huggingface/transformers/pull/29402 for more details."
)
hidden_act = "gelu_pytorch_tanh"
self.act_fn = ACT2FN[hidden_act]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't mind automatically switching, but it's best if the users still have a way to use the legacy gelu! Either a big warning or use another config name

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we need a self.hidden_activation set to None by default and if None warn that we will use the new approx else use what was give

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh ok good point! Sorry didn't work on this in the meantime - I found a few more issues, and will push them here tomorrow :)

@danielhanchen danielhanchen changed the title Gemma fixes - gelu Gemma bug fixes - Approx GELU, Layernorms, Sqrt(hd) Mar 9, 2024
Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this @danielhanchen ! 🤩
I left one comment about backward compatibility, what do you think?

@@ -170,7 +173,16 @@ def __init__(self, config):
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
hidden_act = config.hidden_act
if hidden_act != "gelu_pytorch_tanh":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This way even if there is a model with the old gelu on the config we're force-setting hidden_act to "gelu_pytorch_tanh" right?
I think we should either use a new config name or create a new attribute in the config force_use_exact_gelu, that is iniailizaed to False so that users can have the flexibility to switch to the old act function in case they fine-tuned it with old GeLU, what do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I like that approach ie getattr(config, "force_use_exact_gelu", False) so if force_use_exact_gelu = True then True. If force_use_exact_gelu = False then also False, and False otherwise.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes! I think we can add that directly into GemmaConfig class and default it to False

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added force_use_exact_gelu!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM ! Let's fix the CI and merge!

src/transformers/models/gemma/modeling_gemma.py Outdated Show resolved Hide resolved
hidden_states = hidden_states * (self.config.hidden_size**0.5)
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
# See https://github.com/huggingface/transformers/pull/29402
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype = hidden_states.dtype)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should not make a difference but it should be cleaner to cache this one as hidden_states_scale ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

@danielhanchen
Copy link
Contributor Author

Closing this since we merged this! Thanks everyone!

Narsil pushed a commit to huggingface/text-generation-inference that referenced this pull request Mar 21, 2024
)

This PR adds `force_downcast_after` to `FastRMSNorm.forward` which is
used in the Gemma model. References
huggingface/transformers#29402 and
huggingface/transformers#29729

Setting `force_downcast_after=True` will perform the `hidden_states *
weight` multiplication in f32 and then downcast to half. This differs
slightly from the current implementation which first casts the
`hidden_states` to a half and then multiples.
@nxphi47
Copy link
Contributor

nxphi47 commented Mar 22, 2024

@danielhanchen Many thanks for the fixes. Do you observe any performance difference before and after this fix? thanks.

guocuimi added a commit to vectorch-ai/ScaleLLM that referenced this pull request Apr 15, 2024
fix precision issue mentioned in
huggingface/transformers#29402
this diff:
* fixed 1> Approx Gelu and 3> sqrt(hidden_dim) with dtype
* fixed the head_dim for gemma_7b_* models
@cuichenx cuichenx mentioned this pull request Apr 17, 2024
8 tasks
cr313 added a commit to cr313/text-generation-inference-load-test that referenced this pull request Apr 19, 2024
…658)

This PR adds `force_downcast_after` to `FastRMSNorm.forward` which is
used in the Gemma model. References
huggingface/transformers#29402 and
huggingface/transformers#29729

Setting `force_downcast_after=True` will perform the `hidden_states *
weight` multiplication in f32 and then downcast to half. This differs
slightly from the current implementation which first casts the
`hidden_states` to a half and then multiples.
kdamaszk pushed a commit to kdamaszk/tgi-gaudi that referenced this pull request Apr 29, 2024
…ggingface#1658)

This PR adds `force_downcast_after` to `FastRMSNorm.forward` which is
used in the Gemma model. References
huggingface/transformers#29402 and
huggingface/transformers#29729

Setting `force_downcast_after=True` will perform the `hidden_states *
weight` multiplication in f32 and then downcast to half. This differs
slightly from the current implementation which first casts the
`hidden_states` to a half and then multiples.
alfredgui2 pushed a commit to mlsys-io/kv.run that referenced this pull request Jul 6, 2024
…658)

This PR adds `force_downcast_after` to `FastRMSNorm.forward` which is
used in the Gemma model. References
huggingface/transformers#29402 and
huggingface/transformers#29729

Setting `force_downcast_after=True` will perform the `hidden_states *
weight` multiplication in f32 and then downcast to half. This differs
slightly from the current implementation which first casts the
`hidden_states` to a half and then multiples.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants