Support Gemma #940

Closed
carmocca opened this issue Feb 21, 2024 · 14 comments · Fixed by #941

Comments

@carmocca
Contributor

carmocca commented Feb 21, 2024

Announcement: https://blog.google/technology/developers/gemma-open-models/
Technical report: https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf
HF Hub weights: https://huggingface.co/google/gemma-7b
HF Transformers PR: huggingface/transformers#29167 with https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py as the model implementation

From a brief skim, I think we just need to add GeGLU as the activation

@rasbt
Collaborator

rasbt commented Feb 21, 2024

Based on the paper:

  1. We already support multi-query attention (through Llama)
  2. Yes, GeGLU is one of the "novelties"
  3. The other novelty is that they apply RMSNorm before and after attention instead of only one or the other as in Post/Pre-Norm

@rasbt
Collaborator

rasbt commented Feb 21, 2024

There doesn't seem to be an official GeGLU implementation in PyTorch, yet, but this looks good: https://github.com/pfnet-research/deep-table/blob/237c8be8a405349ce6ab78075234c60d9bfe60b7/deep_table/nn/layers/activation.py#L22
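
For reference, a chunk-style GeGLU fits in a few lines. The following is a dependency-free sketch, not the linked implementation: the function names are made up for illustration, and which half acts as the gate is an assumption (conventions differ between codebases).

```python
import math


def gelu(x: float) -> float:
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def geglu(x: list[float]) -> list[float]:
    """Split x into (values, gates) halves and return values * gelu(gates)."""
    if len(x) % 2 != 0:
        raise ValueError("GeGLU needs an even number of features")
    half = len(x) // 2
    values, gates = x[:half], x[half:]  # assumption: the second half is the gate
    return [v * gelu(g) for v, g in zip(values, gates)]


features = [1.0, 2.0, 3.0, 0.0, 1.0, -1.0]
out = geglu(features)
print(len(features), "->", len(out))  # the output has half as many features
```

Note that, unlike a plain GELU, this form returns half as many features as it receives, which is why it is usually paired with a wider preceding linear layer.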

@carmocca
Contributor Author

> The other novelty is that they apply RMSNorm before and after attention instead of only one or the other as in Post/Pre-Norm

Can you be specific? It seems to be exactly what we implement: https://github.com/Lightning-AI/lit-gpt/blob/main/lit_gpt/model.py#L154-L166 and seems to match what's in HF: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L621-L640

@carmocca
Contributor Author

BTW this is up-for-grabs if one of you wants to add this quickly

@rasbt
Collaborator

rasbt commented Feb 21, 2024

> Can you be specific? It seems to be exactly what we implement: https://github.com/Lightning-AI/lit-gpt/blob/main/lit_gpt/model.py#L154-L166 and seems to match what's in HF: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L621-L640

Hm, dunno what they mean. Even GPT-2 had a LayerNorm before and after each multihead attention module, so I thought they meant something different since they specifically highlighted that. From the paper:

"We normalize both the input and the output of each transformer sub-layer, a deviation from the standard practice of solely normalizing one or the other."

Maybe they mean that they added an additional normalization after the feedforward module.

@Andrei-Aksionov
Collaborator

Andrei-Aksionov commented Feb 21, 2024

Maybe something is wrong with the HF implementation?
Because they don't even use geglu, only a standard gelu.

I think it's better to check how it's implemented in Keras.

Update:

> There doesn't seem to be an official GeGLU implementation in PyTorch, yet, but this looks good: https://github.com/pfnet-research/deep-table/blob/237c8be8a405349ce6ab78075234c60d9bfe60b7/deep_table/nn/layers/activation.py#L22

Oh, I thought that GeGLU was some smart variant that improves GELU, but it's just that weird thing as in OLMo.
I expected a lot of math, but not this.
I don't like that the activation function changes the size of the input; it's unexpected behavior IMO, but if it's a thing I guess I just have to get used to it.

Update 2:
In their convert_gemma_weights.py code I haven't noticed anything related to MLP weights.
And they use the regular LLaMAMLP code. Maybe Google provided already-prepared weights, where the up and gate weights aren't merged 🤷‍♂️?

In theory, the code should work with just updating the config file.

@Andrei-Aksionov
Collaborator

> We normalize both the input and the output of each transformer sub-layer

When I read it I imagined something like this:

  • norm
  • attention
  • norm
  • MLP
  • norm
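
To make the two readings concrete, here is a small sketch contrasting a standard pre-norm block with the three-norm layout listed above. It is only an illustration of the interpretations being discussed, not Gemma's actual architecture: the function names are hypothetical, the RMSNorm has no learned scale, and the sub-layers are toy stand-ins.

```python
import math
from typing import Callable, List

Vector = List[float]
Layer = Callable[[Vector], Vector]


def rms(x: Vector) -> float:
    """Root mean square of a vector."""
    return math.sqrt(sum(v * v for v in x) / len(x))


def rms_norm(x: Vector, eps: float = 1e-6) -> Vector:
    """RMSNorm without a learned scale: x / sqrt(mean(x^2) + eps)."""
    denom = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / denom for v in x]


def add(a: Vector, b: Vector) -> Vector:
    """Element-wise residual addition."""
    return [p + q for p, q in zip(a, b)]


def pre_norm_block(x: Vector, attention: Layer, mlp: Layer) -> Vector:
    """Two norms: one before attention, one before the MLP."""
    x = add(x, attention(rms_norm(x)))
    x = add(x, mlp(rms_norm(x)))
    return x


def three_norm_block(x: Vector, attention: Layer, mlp: Layer) -> Vector:
    """Hypothetical reading: same as pre-norm, plus a norm on the block output."""
    return rms_norm(pre_norm_block(x, attention, mlp))


def toy_sublayer(v: Vector) -> Vector:
    """Stand-in for attention or the MLP; it just doubles its input."""
    return [2.0 * t for t in v]


x = [3.0, 4.0]
a = pre_norm_block(x, toy_sublayer, toy_sublayer)
b = three_norm_block(x, toy_sublayer, toy_sublayer)
print(f"pre-norm output RMS: {rms(a):.3f}, three-norm output RMS: {rms(b):.3f}")
```

The only structural difference is the final normalization, so a checkpoint for the three-norm layout would need a third norm weight per block.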

@rasbt
Collaborator

rasbt commented Feb 21, 2024

Yes, same. It's weird. I agree we should check the Keras code.

@Andrei-Aksionov
Collaborator

Ok, here is their TransformerBlock from KerasNLP.
Not sure what is going on, but:

  1. They use only 2 norm layers: before attention and before MLP
  2. They don't use the geglu thing either 😆

@carmocca
Contributor Author

carmocca commented Feb 21, 2024

https://github.com/keras-team/keras-nlp/blob/dd3ceb7957ab993fc7cc53890a72239e7a566218/keras_nlp/models/gemma/gemma_decoder_block.py#L167

Keras does implement geglu by having two fc layers - as in our LLaMAMLP class - but their size is halved compared to the size in LLaMAMLP

It's also interesting that they use approximate=True which maps to tanh in PyTorch: https://github.com/keras-team/keras/blob/2ad117c44c5346691a646512ded3d25a5e3cb322/keras/backend/torch/nn.py#L88-L89. This is as in our GptNeoxMLP class.
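
As a quick check of how much the tanh variant differs from the exact GELU, here is a stdlib-only sketch of both formulas. The function names are illustrative; in PyTorch the tanh form is selected with `torch.nn.functional.gelu(x, approximate="tanh")`.

```python
import math


def gelu_exact(x: float) -> float:
    """Exact GELU: x * Phi(x), using the error function."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    """Tanh approximation of GELU."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + math.tanh(inner))


# Compare the two on a grid of points in [-5, 5].
points = [i / 10 for i in range(-50, 51)]
max_gap = max(abs(gelu_exact(p) - gelu_tanh(p)) for p in points)
print(f"largest gap on [-5, 5]: {max_gap:.1e}")
```

The two are close but not identical, so matching reference outputs exactly requires using the same variant as the reference implementation.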

In conclusion, Gemma needs a new MLP class that is a mix of both
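
A rough sketch of what such a mix could look like: two parallel input projections as in a LLaMA-style gated MLP, with the tanh-approximated GELU applied to one branch. This is an assumption-laden illustration in plain Python, not the eventual implementation; the names `gated_mlp`, `w_fc_1`, `w_fc_2`, and `w_proj` are hypothetical, biases are omitted, and which branch is gated is an assumption.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[Vector]  # one row of weights per output feature


def linear(x: Vector, weight: Matrix) -> Vector:
    """Bias-free linear layer: one dot product per weight row."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def gelu_tanh(x: float) -> float:
    """Tanh approximation of GELU."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + math.tanh(inner))


def gated_mlp(x: Vector, w_fc_1: Matrix, w_fc_2: Matrix, w_proj: Matrix) -> Vector:
    """proj(gelu_tanh(fc_1(x)) * fc_2(x)); assumption: fc_1 is the gated branch."""
    gate = [gelu_tanh(v) for v in linear(x, w_fc_1)]
    up = linear(x, w_fc_2)
    hidden = [g * u for g, u in zip(gate, up)]
    return linear(hidden, w_proj)


# Tiny example: 2 input features, 3 intermediate features, 2 output features.
x = [1.0, -1.0]
w_in = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w_out = [[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]]
y = gated_mlp(x, w_in, w_in, w_out)
print(len(y))
```

Keeping the two projections as separate layers means the activation itself never changes the feature size; the gating is an element-wise product between two equally sized branches.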

@rasbt
Collaborator

rasbt commented Feb 21, 2024

Thanks for checking. In that case, let me submit a PR. Almost done.

@carmocca
Contributor Author

geglu is gelu but only applied on half of the input. I agree that the HF impl doesn't look equal to that in Keras

@rasbt rasbt mentioned this issue Feb 21, 2024
18 tasks
@Andrei-Aksionov
Collaborator

> Keras does implement geglu by having two fc layers - as in our LLaMAMLP class - but their size is halved compared to the size in LLaMAMLP

Didn't know that this thing is called geglu. I seriously expected something more math heavy.
I definitely need to read the paper.

What confuses me is why you need to specify an intermediate size, which is used only in DecoderBlock, just to halve it in the process. Maybe the 4x scaling factor is hard-coded? 🤷‍♂️

> In conclusion, Gemma needs a new MLP class that is a mix of both

Yep.

@rasbt
Collaborator

rasbt commented Feb 21, 2024

> > We normalize both the input and the output of each transformer sub-layer
>
> When I read it I imagined something like this:
>
>   • norm
>   • attention
>   • norm
>   • MLP
>   • norm

The weird thing, though, is that they don't seem to have a third layernorm weight in the HF checkpoint. That's where the paper and the implementation seem to differ.

3 participants