forked from huggingface/transformers
-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Granite language models (huggingface#31502)
* first commit * drop tokenizer * drop tokenizer * drop tokenizer * drop convert * granite * drop tokenization test * mup * fix * reformat * reformat * reformat * fix docs * stop checking for checkpoint * update support * attention multiplier * update model * tiny drop * saibo drop * skip test * fix test * fix test * drop * drop useless imports * update docs * drop flash function * copied from * drop pretraining tp * drop pretraining tp * drop pretraining tp * drop unused import * drop code path * change name * softmax scale * head dim * drop legacy cache * rename params * cleanup * fix copies * comments * add back legacy cache * multipliers * multipliers * multipliers * text fix * fix copies * merge * multipliers * attention multiplier * drop unused imports * fix * fix * fix * move rope? * Update src/transformers/models/granite/configuration_granite.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix * Update src/transformers/models/granite/modeling_granite.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix * fix * fix * fix * fix-copies * torch rmsnorm * add authors * change model path * fix * test * drop static cache test * uupdate readme * drop non-causal * readme * drop useless imports * Update docs/source/en/model_doc/granite.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update docs/source/en/model_doc/granite.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update docs/source/en/model_doc/granite.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
- Loading branch information
Showing
16 changed files
with
2,154 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
<!--Copyright 2024 The HuggingFace Team. All rights reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | ||
the License. You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | ||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
specific language governing permissions and limitations under the License. | ||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be | ||
rendered properly in your Markdown viewer. | ||
--> | ||
|
||
# Granite | ||
|
||
## Overview | ||
|
||
The Granite model was proposed in [Power Scheduler: A Batch Size and Token Number Agnostic Learning Rate Scheduler](https://arxiv.org/abs/2408.13359) by Yikang Shen, Matthew Stallone, Mayank Mishra, Gaoyuan Zhang, Shawn Tan, Aditya Prasad, Adriana Meza Soria, David D. Cox and Rameswar Panda. | ||
|
||
PowerLM-3B is a 3B state-of-the-art small language model trained with the Power learning rate scheduler. It is trained on a wide range of open-source and synthetic datasets with permissive licenses. PowerLM-3B has shown promising results compared to other models in the size categories across various benchmarks, including natural language multi-choices, code generation, and math reasoning. | ||
|
||
The abstract from the paper is the following: | ||
|
||
*Finding the optimal learning rate for language model pretraining is a challenging task. | ||
This is not only because there is a complicated correlation between learning rate, batch size, number of training tokens, model size, and other hyperparameters but also because it is prohibitively expensive to perform a hyperparameter search for large language models with Billions or Trillions of parameters. Recent studies propose using small proxy models and small corpus to perform hyperparameter searches and transposing the optimal parameters to large models and large corpus. While the zero-shot transferability is theoretically and empirically proven for model size related hyperparameters, like depth and width, the zero-shot transfer from small corpus to large corpus is underexplored. | ||
In this paper, we study the correlation between optimal learning rate, batch size, and number of training tokens for the recently proposed WSD scheduler. After thousands of small experiments, we found a power-law relationship between variables and demonstrated its transferability across model sizes. Based on the observation, we propose a new learning rate scheduler, Power scheduler, that is agnostic about the number of training tokens and batch size. The experiment shows that combining the Power scheduler with Maximum Update Parameterization (\mup) can consistently achieve impressive performance with one set of hyperparameters regardless of the number of training tokens, batch size, model size, and even model architecture. Our 3B dense and MoE models trained with the Power scheduler achieve comparable performance as state-of-the-art small language models. | ||
We [open source](https://huggingface.co/collections/ibm/power-lm-66be64ae647ddf11b9808000) these pretrained models.* | ||
|
||
Tips: | ||
|
||
```python | ||
import torch | ||
from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
||
model_path = "ibm/PowerLM-3b" | ||
tokenizer = AutoTokenizer.from_pretrained(model_path) | ||
|
||
# drop device_map if running on CPU | ||
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto") | ||
model.eval() | ||
|
||
# change input text as desired | ||
prompt = "Write a code to find the maximum value in a list of numbers." | ||
|
||
# tokenize the text | ||
input_tokens = tokenizer(prompt, return_tensors="pt") | ||
# generate output tokens | ||
output = model.generate(**input_tokens, max_new_tokens=100) | ||
# decode output tokens into text | ||
output = tokenizer.batch_decode(output) | ||
# loop over the batch to print, in this example the batch size is 1 | ||
for i in output: | ||
print(i) | ||
``` | ||
|
||
This model was contributed by [mayank-mishra](https://huggingface.co/mayank-mishra). | ||
|
||
|
||
## GraniteConfig | ||
|
||
[[autodoc]] GraniteConfig | ||
|
||
## GraniteModel | ||
|
||
[[autodoc]] GraniteModel | ||
- forward | ||
|
||
## GraniteForCausalLM | ||
|
||
[[autodoc]] GraniteForCausalLM | ||
- forward |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -105,6 +105,7 @@ | |
gpt_neox_japanese, | ||
gpt_sw3, | ||
gptj, | ||
granite, | ||
grounding_dino, | ||
groupvit, | ||
herbert, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import TYPE_CHECKING | ||
|
||
from ...utils import ( | ||
OptionalDependencyNotAvailable, | ||
_LazyModule, | ||
is_torch_available, | ||
) | ||
|
||
|
||
_import_structure = { | ||
"configuration_granite": ["GraniteConfig"], | ||
} | ||
|
||
try: | ||
if not is_torch_available(): | ||
raise OptionalDependencyNotAvailable() | ||
except OptionalDependencyNotAvailable: | ||
pass | ||
else: | ||
_import_structure["modeling_granite"] = [ | ||
"GraniteForCausalLM", | ||
"GraniteModel", | ||
"GranitePreTrainedModel", | ||
] | ||
|
||
if TYPE_CHECKING: | ||
from .configuration_granite import GraniteConfig | ||
|
||
try: | ||
if not is_torch_available(): | ||
raise OptionalDependencyNotAvailable() | ||
except OptionalDependencyNotAvailable: | ||
pass | ||
else: | ||
from .modeling_granite import ( | ||
GraniteForCausalLM, | ||
GraniteModel, | ||
GranitePreTrainedModel, | ||
) | ||
|
||
else: | ||
import sys | ||
|
||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) |
Oops, something went wrong.