Gemma bug fixes - Approx GELU, Layernorms, Sqrt(hd) #29402
Conversation
Actually I just noticed this requires updating all of Gemma's models on the HF Model Hub, e.g. https://huggingface.co/google/gemma-7b/blob/main/config.json
if hidden_act != "gelu_pytorch_tanh": | ||
logger.warning_once( | ||
"Gemma's activation function should be approximate GeLU and not exact GeLU.\n"\ | ||
"Please edit your model config to use `gelu_pytorch_tanh` and not `gelu`.\n"\ | ||
"For now, we shall use `gelu_pytorch_tanh` temporarily.\n"\ | ||
"See https://github.com/huggingface/transformers/pull/29402 for more details." | ||
) | ||
hidden_act = "gelu_pytorch_tanh" | ||
self.act_fn = ACT2FN[hidden_act] |
I don't mind automatically switching, but it's best if the users still have a way to use the legacy gelu! Either a big warning or use another config name
So we need a self.hidden_activation set to None by default; if it is None, warn that we will use the new approximate GeLU, otherwise use whatever was given.
Oh ok good point! Sorry, I didn't work on this in the meantime - I found a few more issues and will push them here tomorrow :)
Thanks for working on this @danielhanchen ! 🤩
I left one comment about backward compatibility, what do you think?
@@ -170,7 +173,16 @@ def __init__(self, config):
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-        self.act_fn = ACT2FN[config.hidden_act]
+        hidden_act = config.hidden_act
+        if hidden_act != "gelu_pytorch_tanh":
This way, even if there is a model with the old gelu in its config, we're force-setting hidden_act to "gelu_pytorch_tanh", right? I think we should either use a new config name or create a new attribute in the config, force_use_exact_gelu, initialized to False, so that users have the flexibility to switch back to the old activation function in case they fine-tuned with the old GeLU. What do you think?
Hmm I like that approach, i.e. getattr(config, "force_use_exact_gelu", False): it returns True if force_use_exact_gelu = True, False if force_use_exact_gelu = False, and False when the attribute is missing.
Yes! I think we can add that directly into the GemmaConfig class and default it to False.
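For illustration, a minimal sketch of how that gating could look. The attribute name force_use_exact_gelu comes from this thread; the standalone helper and the simplified config class below are assumptions for the sketch, not the merged implementation.

```python
from transformers.activations import ACT2FN  # maps activation names to activation modules


class GemmaConfigSketch:
    """Simplified stand-in for GemmaConfig, only to show the new attribute."""

    def __init__(self, hidden_act="gelu_pytorch_tanh", force_use_exact_gelu=False):
        self.hidden_act = hidden_act
        # Defaults to False: existing checkpoints get the approximate GeLU,
        # while users who fine-tuned with exact GeLU can opt back in.
        self.force_use_exact_gelu = force_use_exact_gelu


def pick_activation(config):
    # getattr with a False default keeps older configs (without the attribute) working.
    if getattr(config, "force_use_exact_gelu", False):
        return ACT2FN["gelu"]  # legacy exact GeLU
    return ACT2FN["gelu_pytorch_tanh"]  # tanh-approximate GeLU


act_fn = pick_activation(GemmaConfigSketch())
print(type(act_fn).__name__)  # PytorchGELUTanh
```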
I added force_use_exact_gelu!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
LGTM! Let's fix the CI and merge!
-        hidden_states = hidden_states * (self.config.hidden_size**0.5)
+        # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
+        # See https://github.com/huggingface/transformers/pull/29402
+        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
Should not make a difference, but it would be cleaner to cache this one as hidden_states_scale?
Done!
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Closing this since we merged this! Thanks everyone!
This PR adds `force_downcast_after` to `FastRMSNorm.forward`, which is used in the Gemma model. References huggingface/transformers#29402 and huggingface/transformers#29729. Setting `force_downcast_after=True` will perform the `hidden_states * weight` multiplication in f32 and then downcast to half. This differs slightly from the current implementation, which first casts the `hidden_states` to half and then multiplies.
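The ordering difference described there is easy to demonstrate in isolation; the snippet below is a toy comparison of the two orderings, not the FastRMSNorm kernel itself.

```python
import torch

torch.manual_seed(0)
hidden_states = torch.randn(1024, dtype=torch.float32)
weight = torch.randn(1024, dtype=torch.float32)

# force_downcast_after=True style: multiply in float32, downcast once at the end
downcast_after = (hidden_states * weight).to(torch.float16)

# previous style: downcast hidden_states first, then multiply in half precision
downcast_before = hidden_states.to(torch.float16) * weight.to(torch.float16)

# The two agree only up to float16 rounding; the max difference is typically small but nonzero.
print((downcast_after - downcast_before).abs().max())
```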
@danielhanchen Many thanks for the fixes. Do you observe any performance difference before and after this fix? Thanks.
fix precision issue mentioned in huggingface/transformers#29402. This diff:
* fixed 1) Approx GELU and 3) sqrt(hidden_dim) with dtype
* fixed the head_dim for gemma_7b_* models
Just a few more Gemma fixes :) Currently checking for more as well! All fixes are derived from Unsloth's 2.5x faster and 70% less VRAM Gemma finetuning script :) https://github.com/unslothai/unsloth
Related PR: #29285, which showed RoPE must be done in float32 and not float16; otherwise the positional encodings lose accuracy. @ArthurZucker @younesbelkada
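As a rough illustration of why the float32 requirement matters (this is not the actual RoPE code from #29285): float16 has a spacing of 1.0 for values in [1024, 2048), so a rotation angle at a large position id gets rounded before cos/sin is even applied.

```python
import torch

# Hypothetical rotation angle (position_id * inv_freq) at a large position.
angle_fp32 = torch.tensor(2047.123, dtype=torch.float32)
angle_fp16 = angle_fp32.to(torch.float16)  # rounds to 2047.0 (float16 spacing here is 1.0)

print(angle_fp16.item())                     # 2047.0
print(torch.cos(angle_fp32).item())          # cosine of the true angle
print(torch.cos(angle_fp16.float()).item())  # noticeably different after the rounding
```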
1. Approx GELU and not exact GELU
According to https://twitter.com/danielhanchen/status/1763613620909580505 (waiting for confirmation), the activation function should be approximate GeLU and not exact GeLU. I.e. `gelu` should actually be `gelu_pytorch_tanh`, which calls `PytorchGELUTanh`, which in turn calls `nn.functional.gelu(input, approximate="tanh")`. `gelu` calls `nn.functional.gelu(input, approximate="none")`.
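For reference, the difference between the two activations can be checked directly with public PyTorch APIs (a minimal check, not part of the PR):

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-4, 4, steps=9)
exact = F.gelu(x, approximate="none")   # what `gelu` ends up calling
approx = F.gelu(x, approximate="tanh")  # what `gelu_pytorch_tanh` ends up calling

# Small but nonzero difference, so the two are not interchangeable for a model
# that was trained with the tanh approximation.
print((exact - approx).abs().max())
```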
2. Layernorm (w+1) should be done in float32
The layernorms must be upcast to float32 and must not be downcast to bfloat16 or float16 halfway through. Unlike Llama's RMS layernorm, which downcasts before multiplying by the weights, Gemma must only downcast at the very end. A sketch follows below.
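A minimal sketch of the intended ordering, assuming a Gemma-style RMSNorm that stores `w` and applies `(1 + w)`; variable names are illustrative rather than the exact merged code.

```python
import torch
from torch import nn


class GemmaStyleRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))  # Gemma stores w and applies (1 + w)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        x = x.float()  # upcast: the normalization AND the (1 + w) multiply stay in float32
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        x = x * (1.0 + self.weight.float())
        return x.to(input_dtype)  # downcast only once, at the very end
```

The Llama-style ordering instead casts back to half precision right after the normalization and multiplies by the weight in half precision; keeping the multiply in float32 and downcasting once at the end is exactly what this fix is about.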
3. sqrt(3072)=55.4256, but in bfloat16 it becomes 55.5
Interestingly, Gemma multiplies the embeddings by sqrt(hidden_dim), and there is a precision problem here. Gemma uses `jnp.sqrt(self.embed_dim).astype(x.dtype)`, which means sqrt(3072) = 55.4256, but casting it to bfloat16 rounds it to 55.5. For Gemma 2b, sqrt(2048) = 45.2548, and casting it to bfloat16 makes it 45.25.
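This is easy to reproduce directly; the check below uses PyTorch rather than JAX, but the bfloat16 rounding behaves the same way.

```python
import torch

for hidden_size in (2048, 3072):  # Gemma 2b and 7b hidden sizes
    exact = hidden_size ** 0.5
    as_bf16 = torch.tensor(exact, dtype=torch.bfloat16).item()
    print(hidden_size, exact, as_bf16)

# 2048: 45.2548... becomes 45.25 in bfloat16
# 3072: 55.4256... becomes 55.5 in bfloat16
```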