
add Glm #33823

Merged: @Cyrilvallez merged 75 commits into main from glm on Oct 18, 2024.

Conversation

@Cyrilvallez (Member) commented Sep 30, 2024

GLM model!

@HuggingFaceDocBuilderDev

Hey! 🤗 Thanks for your contribution to the transformers library!

Before merging this pull request, the slow tests CI should be triggered. To enable this:

  • Add the run-slow label to the PR
  • When your PR is ready for merge and all reviewers' comments have been addressed, push an empty commit with the command [run-slow] followed by a comma-separated list of all the models to be tested, i.e. [run-slow] model_to_test_1, model_to_test_2 (see the example command after this list)
    • If the pull request affects a lot of models, put at most 10 models in the commit message
  • A transformers maintainer will then approve the workflow to start the tests
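
For example, the empty commit for this PR could be created with git commit --allow-empty -m "[run-slow] glm" and then pushed (this assumes only the new GLM model needs testing).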

(For maintainers) The documentation for slow tests CI on PRs is here.

@Cyrilvallez force-pushed the glm branch 2 times, most recently from 207ec14 to 152569e on October 1, 2024 at 14:08
@ArthurZucker (Collaborator) left a comment:

Very very nice!

@@ -0,0 +1,27 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
Collaborator:

Suggested change:
- # Copyright 2020 The HuggingFace Team. All rights reserved.
+ # Copyright 2024 The HuggingFace Team. All rights reserved.

Comment on lines 12 to 36
STATE_DICT_MAPPING = {
    "transformer.output_layer.": "lm_head.",
    "transformer.": "model.",
    ".embedding.word_embeddings.": ".embed_tokens.",
    ".encoder.final_layernorm.": ".norm.",
    ".encoder.layers.": ".layers.",
    "rotary_pos_embed.": "rotary_emb.",
    "self_attention.": "self_attn.",
    "query_key_value.": "qkv_proj.",
    "dense.": "o_proj.",
    "dense_h_to_4h.": "gate_up_proj.",
    "dense_4h_to_h.": "down_proj.",
}
Collaborator:

Cool! Let's set up good standards though: see MLLAMA, fully explicit regexes are more informative IMO! 🤗
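
For illustration, the regex-style table could look something like this: a minimal sketch with hypothetical patterns and helper name, not the final converter code:

import re

# Fully explicit, MLLAMA-style mapping: each regex matches a complete source
# key, so the table documents exactly which original keys exist and where
# they go after conversion (patterns here are illustrative).
STATE_DICT_MAPPING = {
    r"^transformer\.output_layer\.(.*)$": r"lm_head.\1",
    r"^transformer\.embedding\.word_embeddings\.(.*)$": r"model.embed_tokens.\1",
    r"^transformer\.encoder\.final_layernorm\.(.*)$": r"model.norm.\1",
    r"^transformer\.encoder\.layers\.(\d+)\.self_attention\.query_key_value\.(.*)$": r"model.layers.\1.self_attn.qkv_proj.\2",
}

def map_old_key_to_new(old_key: str) -> str:
    # Return the converted key, or the key unchanged if nothing matches.
    for pattern, replacement in STATE_DICT_MAPPING.items():
        new_key, n_subs = re.subn(pattern, replacement, old_key)
        if n_subs:
            return new_key
    return old_key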

Comment on lines 73 to 92
    vocab_size=original_config.pop("padded_vocab_size"),
    hidden_size=original_config.pop("hidden_size"),
    intermediate_size=original_config.pop("ffn_hidden_size"),
    num_hidden_layers=original_config.pop("num_layers"),
    num_attention_heads=num_attention_heads,
    num_key_value_heads=(
        num_attention_heads
        if not original_config.pop("multi_query_attention")
        else original_config.pop("multi_query_group_num")
    ),
    attention_dropout=original_config.pop("attention_dropout"),
    max_position_embeddings=original_config.pop("seq_length"),
    rms_norm_eps=original_config.pop("layernorm_epsilon"),
    rope_theta=10000.0 * original_config.pop("rope_ratio", 1),
    use_cache=original_config.pop("use_cache"),
    head_dim=original_config.pop("kv_channels"),
    attention_bias=original_config.pop("add_qkv_bias"),
    eos_token_id=original_config.pop("eos_token_id"),
    pad_token_id=original_config.pop("pad_token_id"),
    tie_word_embeddings=original_config.pop("tie_word_embeddings"),
Collaborator:

Let's try to use ** here for attributes that have the same name

Member (author):

I didn't, to avoid adding unused fields, but I refactored to make that block nicer to read.
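
For reference, the ** idea would look roughly like this: a sketch assuming the converter's original_config dict, with an illustrative RENAMES table (not the merged code) for the keys whose names differ:

# Sketch only: rename the keys that differ, pass everything else through.
# Keys GlmConfig does not define would be stored as extra attributes on the
# config, which is the "unused fields" concern mentioned above.
RENAMES = {
    "padded_vocab_size": "vocab_size",
    "ffn_hidden_size": "intermediate_size",
    "num_layers": "num_hidden_layers",
    "seq_length": "max_position_embeddings",
    "layernorm_epsilon": "rms_norm_eps",
    "kv_channels": "head_dim",
    "add_qkv_bias": "attention_bias",
}
kwargs = {RENAMES.get(key, key): value for key, value in original_config.items()}
config = GlmConfig(**kwargs)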

pass


class GlmSdpaAttention(GlmAttention, GraniteSdpaAttention):
Collaborator:

Suggested change:
- class GlmSdpaAttention(GlmAttention, GraniteSdpaAttention):
+ class GlmSdpaAttention(GraniteSdpaAttention):

Collaborator:

I think this should be enough

Collaborator:

holy moly, so nice!

@require_torch_sdpa
@slow
@is_flaky
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
Collaborator:

why do we have to overwrite this one?

@Cyrilvallez (Member, author) commented Oct 2, 2024:

Unfortunately, depending on the random inputs, one of the cases sometimes fails - I overwrote the test to add the flaky decorator (which allows it to pass consistently).
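
For context, a flaky-test decorator boils down to a retry loop like the following: a generic sketch of the mechanism, not transformers' actual is_flaky implementation:

import functools

def flaky(max_attempts: int = 5):
    """Rerun a test on assertion failure; fail only if every attempt fails."""
    def decorator(test_fn):
        @functools.wraps(test_fn)
        def wrapper(*args, **kwargs):
            last_error = None
            for _ in range(max_attempts):
                try:
                    return test_fn(*args, **kwargs)
                except AssertionError as err:
                    last_error = err  # random inputs occasionally fail; retry
            raise last_error
        return wrapper
    return decorator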

Collaborator:

Cool! In general, the less we have to overwrite, the better!

Collaborator:

Meaning: are there ways to remove some of the tests you added?

Member (author):

Unfortunately no - depending on the random seed, some fail from time to time, so they need to be flaky to pass consistently.

@Cyrilvallez Cyrilvallez marked this pull request as ready for review October 2, 2024 12:14
@Cyrilvallez (author):

Ready for a last review @ArthurZucker. setup_and_quality fails because of the __all__ issue, but will pass once #33859 is merged.

@zRzRzRzRzRzRzR (Contributor):

Thank you very much for your help. I also saw this Hugging Face PR. I have replied, and some of the code may need to be modified; perhaps we can work together to improve it and merge this work.

Thank you again for your support!

@ArthurZucker (Collaborator):

Of course!

@ArthurZucker (Collaborator) left a comment:

LGTM! Anything missing before we merge?

@Cyrilvallez (author) commented Oct 18, 2024:

> LGTM! Anything missing before we merge?

No, the only issue is the docstrings in the configuration, but that will be solved by the auto-docstrings. In the meantime, I just moved the config outside modular to please the CIs.

@Cyrilvallez (author):

Confirmed that the slow tests pass for the model. Merging.

@Cyrilvallez Cyrilvallez merged commit 6604764 into main Oct 18, 2024
23 of 27 checks passed
@Cyrilvallez Cyrilvallez deleted the glm branch October 18, 2024 15:41
NielsRogge pushed a commit to NielsRogge/transformers that referenced this pull request Oct 21, 2024
* Create modular_glm.py
* Update modular_glm.py
* Finalize architecture without all attentions
* Add all attentions modules
* Finalize modular
* Update given last version
* Last update
* Finalize model
* Finalize converter
* Update convert_glm_weights_to_hf.py
* style
* style
* Create __init__.py
* Aff all inits
* Update convert_glm_weights_to_hf.py
* Update convert_glm_weights_to_hf.py
* Update convert_glm_weights_to_hf.py
* Update convert_glm_weights_to_hf.py
* Update convert_glm_weights_to_hf.py
* Update convert_glm_weights_to_hf.py
* Update convert_glm_weights_to_hf.py
* Update convert_glm_weights_to_hf.py
* Update convert_glm_weights_to_hf.py
* Correct the rotary embeddings
* Remove apply_residual_connection_post_layernorm (always false)
* remove use_rms_norm (always true)
* remove past_layer_norm (always true)
* Update __init__.py
* Update config and license
* start adding tests and doc
* Add doc + style
* Update test_modeling_glm.py
* Add dummies
* Apply correct modeling
* Refactor attention to follow llama
* Update __init__.py
* Update convert_glm_weights_to_hf.py
* Correct bias
* remove linear_bias and pdrop (never used)
* apply modular
* Simplify converter
* remove dummies + style
* add model_input_names
* Add pretraining_tp to config for when eager attention is used
* Update modular to remove all pretraining_tp
* Update test_modeling_glm.py
* Update the __all__
* Update __all__
* Update __init__.py
* Update test_modeling_glm.py
* add revisions
* Add the correct repos and revisions
* style
* Update __init__.py
* update exports
* remove import of modular files
* style
* Apply Llama changes + refine converter
* Update convert_glm_weights_to_hf.py
* Update convert_glm_weights_to_hf.py
* Update convert_glm_weights_to_hf.py
* Update convert_glm_weights_to_hf.py
* Update convert_glm_weights_to_hf.py
* Update convert_glm_weights_to_hf.py
* Update convert_glm_weights_to_hf.py
* Update convert_glm_weights_to_hf.py
* style
* Use new modular converter
* add pretrainedmodel to init
* style
* Update test_modeling_glm.py
* Move config outside modular to please CI about docstrings
* Add dummies to please CI
* Update glm.md
* Update glm.md
stevhliu pushed a commit to stevhliu/transformers that referenced this pull request Oct 21, 2024 (same commit message as above)
@liyucheng09:

@Cyrilvallez Hi Cyril, your PR for the 1M version of the model produces an unexpected generation. Please see here for more information: https://huggingface.co/THUDM/glm-4-9b-chat-1m/discussions/17.

@Qubitium Qubitium mentioned this pull request Nov 10, 2024
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024 (same commit message as above)
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024 (same commit message as above)