Commit
Merge branch 'huggingface:main' into moshi-integration
ylacombe authored Oct 2, 2024
2 parents 9e650fd + 614c79a commit 0a147ff
Showing 9 changed files with 115 additions and 66 deletions.
51 changes: 0 additions & 51 deletions .github/workflows/slow_ci_remainder.yml

This file was deleted.

1 change: 1 addition & 0 deletions docs/source/en/gguf.md
@@ -81,6 +81,7 @@ For now the supported model architectures are the architectures that have been v
- Qwen2Moe
- Phi3
- Bloom
- Falcon

## Example usage

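With Falcon now listed among the supported architectures, a minimal sketch of what this enables follows; the repository and GGUF file names are not from the doc page itself but are reused from the Falcon tests added later in this commit.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Repo and file names reused from the Falcon tests added in this commit.
model_id = "xaviviro/falcon-7b-quantized-gguf"
gguf_file = "falcon-7b-q2_k.gguf"

tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=gguf_file)
```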
1 change: 0 additions & 1 deletion src/transformers/convert_slow_tokenizer.py
@@ -345,7 +345,6 @@ def converted(self, vocab: Dict[str, int] = None, merges: List[Tuple[str, str]]
)
)

add_prefix_space = False
add_prefix_space = getattr(self.original_tokenizer, "add_prefix_space", False)
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
tokenizer.decoder = decoders.ByteLevel()
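A small illustration of why the deleted assignment was redundant: the surviving `getattr` call already falls back to `False` when the original tokenizer defines no `add_prefix_space` attribute. The toy classes below are illustrative only.

```python
# Toy tokenizers standing in for self.original_tokenizer.
class TokenizerWithFlag:
    add_prefix_space = True

class TokenizerWithoutFlag:
    pass

# getattr returns the attribute when present, otherwise the False default,
# so a preceding hard-coded `add_prefix_space = False` had no effect.
print(getattr(TokenizerWithFlag(), "add_prefix_space", False))     # True
print(getattr(TokenizerWithoutFlag(), "add_prefix_space", False))  # False
```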
36 changes: 36 additions & 0 deletions src/transformers/integrations/ggml.py
@@ -120,6 +120,29 @@
"output.weight": "lm_head.weight",
"output_norm": "transformer.ln_f",
},
"falcon7b": {
"token_embd": "word_embeddings",
"blk": "h",
"ffn_up": "mlp.dense_h_to_4h",
"ffn_down": "mlp.dense_4h_to_h",
"attn_norm": "input_layernorm",
"attn_qkv": "self_attention.query_key_value",
"attn_output": "self_attention.dense",
".output.": ".lm_head.",
"output_norm": "ln_f",
},
"falcon40b": {
"token_embd": "word_embeddings",
"blk": "h",
"ffn_up": "mlp.dense_h_to_4h",
"ffn_down": "mlp.dense_4h_to_h",
".attn_norm.": ".ln_mlp.",
"attn_norm_2": "ln_attn",
"attn_qkv": "self_attention.query_key_value",
"attn_output": "self_attention.dense",
".output.": ".lm_head.",
"output_norm": "ln_f",
},
}


@@ -178,6 +201,18 @@
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"vocab_size": "vocab_size",
},
"falcon": {
"context_length": "max_position_embeddings",
"block_count": "num_hidden_layers",
"feed_forward_length": "intermediate_size",
"embedding_length": "hidden_size",
"rope.dimension_count": None,
"rope.freq_base": "rope_theta",
"attention.head_count": "num_attention_heads",
"attention.head_count_kv": "num_key_value_heads",
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"vocab_size": "vocab_size",
},
"tokenizer": {
"ggml.bos_token_id": "bos_token_id",
"ggml.eos_token_id": "eos_token_id",
@@ -530,6 +565,7 @@ def converted(self) -> Tokenizer:
"qwen2_moe": GGUFQwen2Converter,
"phi3": GGUFPhi3Converter,
"bloom": GGUFBloomConverter,
"falcon": GGUFBloomConverter,
}


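To make the intent of the new `falcon7b`/`falcon40b` tables concrete, here is a minimal standalone sketch (not the library's actual loader) of how such a table is applied: every GGUF key substring is swapped for its transformers counterpart, mirroring the substring-replacement step performed when tensors are loaded.

```python
# Copy of the falcon7b table added above.
FALCON7B_TENSOR_MAP = {
    "token_embd": "word_embeddings",
    "blk": "h",
    "ffn_up": "mlp.dense_h_to_4h",
    "ffn_down": "mlp.dense_4h_to_h",
    "attn_norm": "input_layernorm",
    "attn_qkv": "self_attention.query_key_value",
    "attn_output": "self_attention.dense",
    ".output.": ".lm_head.",
    "output_norm": "ln_f",
}

def rename_tensor(gguf_name: str) -> str:
    # Replace every mapped substring found in the GGUF tensor name.
    for gguf_key, transformers_key in FALCON7B_TENSOR_MAP.items():
        if gguf_key in gguf_name:
            gguf_name = gguf_name.replace(gguf_key, transformers_key)
    return gguf_name

print(rename_tensor("blk.0.attn_qkv.weight"))  # h.0.self_attention.query_key_value.weight
print(rename_tensor("output_norm.weight"))     # ln_f.weight
```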
27 changes: 16 additions & 11 deletions src/transformers/modeling_gguf_pytorch_utils.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Optional

import numpy as np
@@ -99,8 +100,20 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
if "qwen2moe" in architecture:
updated_architecture = "qwen2_moe"

if architecture not in GGUF_SUPPORTED_ARCHITECTURES:
raise ValueError(f"Architecture {architecture} not supported")
model_size = ""
# Extract the number of parameters from the file name, as architectures can differ;
# e.g. for falcon: `...falcon-7b-...`
if "falcon" in architecture:
gguf_file_name = gguf_checkpoint_path.split("/")[-1].lower()
m = re.search(r"-\d+b-", gguf_file_name) # regex to catch `-7b-`
if m is None:
raise ValueError(
f"From file name, cannot determine the number of parameters for {architecture} architecture"
)
model_size = m.group().strip("-") # only keeps `7b`

if architecture + model_size not in GGUF_SUPPORTED_ARCHITECTURES:
raise ValueError(f"Architecture {architecture + model_size} not supported")

# List all key-value pairs in a columnized format
for gguf_key, field in reader.fields.items():
@@ -146,17 +159,9 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
)

if return_tensors:
tensor_key_mapping = GGUF_TO_TRANSFORMERS_MAPPING["tensors"][architecture]
tensor_key_mapping = GGUF_TO_TRANSFORMERS_MAPPING["tensors"][architecture + model_size]

for tensor in tqdm(reader.tensors, desc="Converting and de-quantizing GGUF tensors..."):
renamed_tensor_name = tensor.name

for tensor_name_mapping in GGUF_TO_TRANSFORMERS_MAPPING["tensors"]:
if tensor_name_mapping in renamed_tensor_name:
renamed_tensor_name = renamed_tensor_name.replace(
tensor_name_mapping, GGUF_TO_TRANSFORMERS_MAPPING["tensors"][tensor_name_mapping]
)

name = tensor.name

weights = dequantize(tensor.data, tensor.tensor_type)
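The file-name parsing introduced above is easy to exercise in isolation; the sketch below repeats the same logic under the same assumption (the parameter count appears as `-7b-`, `-40b-`, ... in the GGUF file name), with example names taken from the tests in this commit.

```python
import re

def extract_model_size(gguf_checkpoint_path: str) -> str:
    # Same approach as the loader above: look for `-<N>b-` in the file name.
    gguf_file_name = gguf_checkpoint_path.split("/")[-1].lower()
    m = re.search(r"-\d+b-", gguf_file_name)
    if m is None:
        raise ValueError("Cannot determine the number of parameters from the file name")
    return m.group().strip("-")

print(extract_model_size("models/falcon-7b-fp16.gguf"))   # 7b
print(extract_model_size("tiiuae-falcon-40b-Q2_K.gguf"))  # 40b
```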
2 changes: 1 addition & 1 deletion src/transformers/quantizers/auto.py
@@ -108,7 +108,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
quantization_config_dict = model_config.quantization_config
quantization_config = cls.from_dict(quantization_config_dict)
# Update with potential kwargs that are passed through from_pretrained.
quantization_config.update(kwargs)
quantization_config.update(**kwargs)
return quantization_config


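The one-character fix above reads as if the config's `update` method accepts keyword arguments rather than a positional dict; a toy stand-in (not the actual transformers class) showing the difference:

```python
# Hypothetical config whose update() only accepts keyword arguments.
class ToyQuantizationConfig:
    def update(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

config = ToyQuantizationConfig()
kwargs = {"bits": 4}

config.update(**kwargs)   # works: the dict is unpacked into keyword arguments
# config.update(kwargs)   # TypeError: update() takes 1 positional argument but 2 were given
```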
3 changes: 2 additions & 1 deletion src/transformers/utils/import_utils.py
@@ -143,7 +143,8 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
# `importlib.metadata.version` doesn't work with `awq`
_auto_awq_available = importlib.util.find_spec("awq") is not None
_quanto_available = _is_package_available("quanto")
_compressed_tensors_available = _is_package_available("compressed_tensors")
# For compressed_tensors, only check spec to allow compressed_tensors-nightly package
_compressed_tensors_available = importlib.util.find_spec("compressed_tensors") is not None
_pandas_available = _is_package_available("pandas")
_peft_available = _is_package_available("peft")
_phonemizer_available = _is_package_available("phonemizer")
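A short sketch of why the spec check is more permissive: a metadata/version lookup is keyed on the distribution name, so a nightly distribution such as `compressed_tensors-nightly` would not be found, while `find_spec` only needs the importable `compressed_tensors` module. The helper names below are illustrative, not transformers APIs.

```python
import importlib.metadata
import importlib.util

def available_via_metadata(distribution_name: str) -> bool:
    # Fails when the package is installed under a different distribution name.
    try:
        importlib.metadata.version(distribution_name)
        return True
    except importlib.metadata.PackageNotFoundError:
        return False

def available_via_spec(module_name: str) -> bool:
    # Succeeds as long as the module is importable, whatever the distribution is called.
    return importlib.util.find_spec(module_name) is not None
```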
2 changes: 1 addition & 1 deletion src/transformers/utils/quantization_config.py
@@ -1105,7 +1105,7 @@ def __init__(
self.sparsity_config = None

# parse from dict to load nested QuantizationScheme objects
if config_groups:
if config_groups or kv_cache_scheme:
self.quantization_config = QuantizationConfig.parse_obj(
{
"config_groups": config_groups,
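A tiny sketch of the guard's effect (with a hypothetical helper, not the transformers code): the nested quantization config is now parsed when either `config_groups` or `kv_cache_scheme` is supplied, instead of only when `config_groups` is non-empty.

```python
def should_parse_nested_config(config_groups, kv_cache_scheme) -> bool:
    # Mirrors the updated condition above.
    return bool(config_groups or kv_cache_scheme)

print(should_parse_nested_config(None, {"type": "float"}))  # True after this change, False before
print(should_parse_nested_config({"group_0": {}}, None))    # True in both versions
```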
58 changes: 58 additions & 0 deletions tests/quantization/ggml/test_ggml.py
Expand Up @@ -44,6 +44,9 @@ class GgufIntegrationTests(unittest.TestCase):
phi3_model_id = "microsoft/Phi-3-mini-4k-instruct-gguf"
bloom_model_id = "afrideva/bloom-560m-GGUF"
original_bloom_model_id = "bigscience/bloom-560m"
falcon7b_model_id = "xaviviro/falcon-7b-quantized-gguf"
falcon40b_model_id = "maddes8cht/tiiuae-falcon-40b-gguf"
original_falcon7b_model_id = "tiiuae/falcon-7b"

# standard quants
q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
@@ -74,6 +77,9 @@ class GgufIntegrationTests(unittest.TestCase):
fp16_bloom_model_id = "bloom-560m.fp16.gguf"
q8_bloom_model_id = "bloom-560m.q8_0.gguf"
f16_tinyllama_model_id = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf"
q2_k_falcon7b_model_id = "falcon-7b-q2_k.gguf"
fp16_falcon7b_model_id = "falcon-7b-fp16.gguf"
q2_k_falcon40b_model_id = "tiiuae-falcon-40b-Q2_K.gguf"

example_text = "Hello"

@@ -445,6 +451,58 @@ def test_bloom_weights_conversion_fp16(self):
self.assertTrue(quantized_param.shape == original_param.shape)
torch.testing.assert_close(quantized_param, original_param)

@unittest.skip(reason="Heavy memory")
def test_falcon40b_q2_k(self):
tokenizer = AutoTokenizer.from_pretrained(self.falcon40b_model_id, gguf_file=self.q2_k_falcon40b_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.falcon40b_model_id,
gguf_file=self.q2_k_falcon40b_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)

EXPECTED_TEXT = "Hello All,\nI am new to this forum."
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_falcon7b_q2_k(self):
tokenizer = AutoTokenizer.from_pretrained(self.falcon7b_model_id, gguf_file=self.q2_k_falcon7b_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.falcon7b_model_id,
gguf_file=self.q2_k_falcon7b_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)

EXPECTED_TEXT = "Hello All,\nI am new to this forum."
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_falcon7b_weights_conversion_fp16(self):
quantized_model = AutoModelForCausalLM.from_pretrained(
self.falcon7b_model_id,
gguf_file=self.fp16_falcon7b_model_id,
device_map="auto",
torch_dtype=torch.float16,
)
original_model = AutoModelForCausalLM.from_pretrained(
self.original_falcon7b_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

quantized_state_dict = quantized_model.state_dict()
original_state_dict = original_model.state_dict()

for layer_name, original_params in original_state_dict.items():
if layer_name in quantized_state_dict:
self.assertTrue(original_params.shape == quantized_state_dict[layer_name].shape)
torch.testing.assert_close(original_params, quantized_state_dict[layer_name])

def test_tokenization_xnli(self):
import tqdm
from datasets import load_dataset
