Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Granite language models #31502

Merged
merged 87 commits into from
Aug 27, 2024
Merged

Granite language models #31502

merged 87 commits into from
Aug 27, 2024

Conversation

mayank31398
Copy link
Contributor

@mayank31398 mayank31398 commented Jun 19, 2024

What does this PR do?

This PR adds support for IBM's upcoming LLMs 3B and 8B.

@amyeroberts
Copy link
Collaborator

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot !
isn't granite support already added in #30031 ? If not we could leverage diff tool that has been recently added - see for example #31211 for reference. I'll let @ArthurZucker comment on this

@mayank31398
Copy link
Contributor Author

hey @younesbelkada
This is for our upcoming open model releases.
3B and 8B language models (lots of tokens :D)

lets just leave this PR for now.
I will get back to this in a few days.

@mayank31398 mayank31398 marked this pull request as ready for review July 1, 2024 22:59
@mayank31398
Copy link
Contributor Author

@ArthurZucker this is ready for merge

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, 2 small nits and let's merge

docs/source/en/model_doc/granite.md Outdated Show resolved Hide resolved
docs/source/en/model_doc/granite.md Outdated Show resolved Hide resolved
docs/source/en/model_doc/granite.md Outdated Show resolved Hide resolved
Comment on lines 642 to 645
@slow
@require_torch_gpu
@require_read_token
def test_compile_static_cache(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this not work (or why was it removed!)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, it does not work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the models use mup and the error is way to high to compare generated outputs

mayank31398 and others added 3 commits August 27, 2024 11:33
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@mayank31398
Copy link
Contributor Author

I have addressed the changes

@ArthurZucker
Copy link
Collaborator

Thanks for bearing with me 🤗

@mayank31398
Copy link
Contributor Author

passed docs 🥳

@ArthurZucker ArthurZucker merged commit c35d2cc into huggingface:main Aug 27, 2024
25 checks passed
@mayank31398
Copy link
Contributor Author

thanks Arthur :)

@mayank31398 mayank31398 deleted the granite branch August 27, 2024 19:42
@ArthurZucker
Copy link
Collaborator

Thank you as well! 🤗

@Jintao-Huang
Copy link
Contributor

hello!

module 'torch.nn' has no attribute 'RMSNorm'

The version of torch < 2.4.0 will report an error.

@@ -22,7 +22,7 @@
from .utils import is_torch_xla_available, logging


ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
ALL_LAYERNORM_LAYERS = [nn.LayerNorm, nn.RMSNorm]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Encountered the same issue, opened #33177 to fix it

@NielsRogge NielsRogge mentioned this pull request Aug 28, 2024
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Aug 30, 2024
* first commit

* drop tokenizer

* drop tokenizer

* drop tokenizer

* drop convert

* granite

* drop tokenization test

* mup

* fix

* reformat

* reformat

* reformat

* fix docs

* stop checking for checkpoint

* update support

* attention multiplier

* update model

* tiny drop

* saibo drop

* skip test

* fix test

* fix test

* drop

* drop useless imports

* update docs

* drop flash function

* copied from

* drop pretraining tp

* drop pretraining tp

* drop pretraining tp

* drop unused import

* drop code path

* change name

* softmax scale

* head dim

* drop legacy cache

* rename params

* cleanup

* fix copies

* comments

* add back legacy cache

* multipliers

* multipliers

* multipliers

* text fix

* fix copies

* merge

* multipliers

* attention multiplier

* drop unused imports

* fix

* fix

* fix

* move rope?

* Update src/transformers/models/granite/configuration_granite.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix

* Update src/transformers/models/granite/modeling_granite.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix

* fix

* fix

* fix

* fix-copies

* torch rmsnorm

* add authors

* change model path

* fix

* test

* drop static cache test

* uupdate readme

* drop non-causal

* readme

* drop useless imports

* Update docs/source/en/model_doc/granite.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/model_doc/granite.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/model_doc/granite.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Aug 30, 2024
* first commit

* drop tokenizer

* drop tokenizer

* drop tokenizer

* drop convert

* granite

* drop tokenization test

* mup

* fix

* reformat

* reformat

* reformat

* fix docs

* stop checking for checkpoint

* update support

* attention multiplier

* update model

* tiny drop

* saibo drop

* skip test

* fix test

* fix test

* drop

* drop useless imports

* update docs

* drop flash function

* copied from

* drop pretraining tp

* drop pretraining tp

* drop pretraining tp

* drop unused import

* drop code path

* change name

* softmax scale

* head dim

* drop legacy cache

* rename params

* cleanup

* fix copies

* comments

* add back legacy cache

* multipliers

* multipliers

* multipliers

* text fix

* fix copies

* merge

* multipliers

* attention multiplier

* drop unused imports

* fix

* fix

* fix

* move rope?

* Update src/transformers/models/granite/configuration_granite.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix

* Update src/transformers/models/granite/modeling_granite.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix

* fix

* fix

* fix

* fix-copies

* torch rmsnorm

* add authors

* change model path

* fix

* test

* drop static cache test

* uupdate readme

* drop non-causal

* readme

* drop useless imports

* Update docs/source/en/model_doc/granite.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/model_doc/granite.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/model_doc/granite.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
itazap pushed a commit to NielsRogge/transformers that referenced this pull request Sep 20, 2024
* first commit

* drop tokenizer

* drop tokenizer

* drop tokenizer

* drop convert

* granite

* drop tokenization test

* mup

* fix

* reformat

* reformat

* reformat

* fix docs

* stop checking for checkpoint

* update support

* attention multiplier

* update model

* tiny drop

* saibo drop

* skip test

* fix test

* fix test

* drop

* drop useless imports

* update docs

* drop flash function

* copied from

* drop pretraining tp

* drop pretraining tp

* drop pretraining tp

* drop unused import

* drop code path

* change name

* softmax scale

* head dim

* drop legacy cache

* rename params

* cleanup

* fix copies

* comments

* add back legacy cache

* multipliers

* multipliers

* multipliers

* text fix

* fix copies

* merge

* multipliers

* attention multiplier

* drop unused imports

* fix

* fix

* fix

* move rope?

* Update src/transformers/models/granite/configuration_granite.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix

* Update src/transformers/models/granite/modeling_granite.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix

* fix

* fix

* fix

* fix-copies

* torch rmsnorm

* add authors

* change model path

* fix

* test

* drop static cache test

* uupdate readme

* drop non-causal

* readme

* drop useless imports

* Update docs/source/en/model_doc/granite.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/model_doc/granite.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/model_doc/granite.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
* first commit

* drop tokenizer

* drop tokenizer

* drop tokenizer

* drop convert

* granite

* drop tokenization test

* mup

* fix

* reformat

* reformat

* reformat

* fix docs

* stop checking for checkpoint

* update support

* attention multiplier

* update model

* tiny drop

* saibo drop

* skip test

* fix test

* fix test

* drop

* drop useless imports

* update docs

* drop flash function

* copied from

* drop pretraining tp

* drop pretraining tp

* drop pretraining tp

* drop unused import

* drop code path

* change name

* softmax scale

* head dim

* drop legacy cache

* rename params

* cleanup

* fix copies

* comments

* add back legacy cache

* multipliers

* multipliers

* multipliers

* text fix

* fix copies

* merge

* multipliers

* attention multiplier

* drop unused imports

* fix

* fix

* fix

* move rope?

* Update src/transformers/models/granite/configuration_granite.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix

* Update src/transformers/models/granite/modeling_granite.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix

* fix

* fix

* fix

* fix-copies

* torch rmsnorm

* add authors

* change model path

* fix

* test

* drop static cache test

* uupdate readme

* drop non-causal

* readme

* drop useless imports

* Update docs/source/en/model_doc/granite.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/model_doc/granite.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/model_doc/granite.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024
* first commit

* drop tokenizer

* drop tokenizer

* drop tokenizer

* drop convert

* granite

* drop tokenization test

* mup

* fix

* reformat

* reformat

* reformat

* fix docs

* stop checking for checkpoint

* update support

* attention multiplier

* update model

* tiny drop

* saibo drop

* skip test

* fix test

* fix test

* drop

* drop useless imports

* update docs

* drop flash function

* copied from

* drop pretraining tp

* drop pretraining tp

* drop pretraining tp

* drop unused import

* drop code path

* change name

* softmax scale

* head dim

* drop legacy cache

* rename params

* cleanup

* fix copies

* comments

* add back legacy cache

* multipliers

* multipliers

* multipliers

* text fix

* fix copies

* merge

* multipliers

* attention multiplier

* drop unused imports

* fix

* fix

* fix

* move rope?

* Update src/transformers/models/granite/configuration_granite.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix

* Update src/transformers/models/granite/modeling_granite.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix

* fix

* fix

* fix

* fix-copies

* torch rmsnorm

* add authors

* change model path

* fix

* test

* drop static cache test

* uupdate readme

* drop non-causal

* readme

* drop useless imports

* Update docs/source/en/model_doc/granite.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/model_doc/granite.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/model_doc/granite.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants