diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py
index f58bf330ce7db3..cca6d548cdf3ac 100644
--- a/src/transformers/modeling_gguf_pytorch_utils.py
+++ b/src/transformers/modeling_gguf_pytorch_utils.py
@@ -15,7 +15,7 @@
 # limitations under the License.

 import re
-from typing import Dict, Optional
+from typing import Dict, NamedTuple, Optional

 import numpy as np
 from tqdm import tqdm
@@ -55,6 +55,200 @@
 GGUF_SUPPORTED_ARCHITECTURES = list(GGUF_TO_TRANSFORMERS_MAPPING["tensors"].keys())


+class GGUFTensor(NamedTuple):
+    weights: np.ndarray
+    name: str
+    metadata: dict
+
+
+class TensorProcessor:
+    def __init__(self, config=None):
+        self.config = config or {}
+
+    def process(self, weights, name, **kwargs):
+        return GGUFTensor(weights, name, {})
+
+
+class LlamaTensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        if ".attn_k." in name or ".attn_q." in name:
+            num_heads = self.config.get("num_attention_heads")
+            num_kv_heads = self.config.get("num_key_value_heads")
+
+            if None in (num_heads, num_kv_heads):
+                return GGUFTensor(weights, name, {})
+            if ".attn_q." in name:
+                weights = self._reverse_permute_weights(weights, num_heads, num_heads)
+            elif ".attn_k." in name:
+                weights = self._reverse_permute_weights(weights, num_heads, num_kv_heads)
+        return GGUFTensor(weights, name, {})
+
+    def _reverse_permute_weights(
+        self, weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None
+    ) -> np.ndarray:
+        # Original permutation implementation
+        # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L1402-L1408
+        if num_kv_heads is not None and n_head != num_kv_heads:
+            n_head = num_kv_heads
+
+        dim = weights.shape[0] // n_head // 2
+        w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
+        return w.swapaxes(2, 1).reshape(weights.shape)
+
+
+class Qwen2MoeTensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        if "_exp" in name:
+            tensor_key_mapping = kwargs.get("tensor_key_mapping")
+            parsed_parameters = kwargs.get("parsed_parameters")
+            if tensor_key_mapping:
+                self._split_moe_expert_tensor(weights, parsed_parameters, name, tensor_key_mapping)
+            return GGUFTensor(weights, None, {})
+        if "ffn_gate_inp_shexp" in name:
+            # for compatibility tensor shared_expert_gate must be (1, 2048) dim,
+            # quantized one is (2048)
+            weights = np.expand_dims(weights, axis=0)
+        return GGUFTensor(weights, name, {})
+
+    def _split_moe_expert_tensor(
+        self, weights: np.ndarray, parsed_parameters: Dict[str, Dict], name: str, tensor_key_mapping: dict
+    ):
+        # Original merge implementation
+        # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L1994-L2022
+        exp_name = ""
+        if "ffn_gate_exps" in name:
+            exp_name = "gate_proj"
+        elif "ffn_down_exps" in name:
+            exp_name = "down_proj"
+        elif "ffn_up_exps" in name:
+            exp_name = "up_proj"
+        else:
+            raise ValueError(f"Cannot map expert tensor {name} in Qwen2Moe architecture.")
+        for tensor_name in tensor_key_mapping:
+            if tensor_name in name:
+                name = name.replace(tensor_name, tensor_key_mapping[tensor_name])
+        w_counter = self.config.get("num_experts", 60)
+        for i in range(0, w_counter):
+            temp_name = name.replace(".weight", f".{i}.{exp_name}.weight")
+            exp_weight = weights[i]
+            parsed_parameters["tensors"][temp_name] = torch.from_numpy(np.copy(exp_weight))
+
+
+class BloomTensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        if "attn_qkv" in name:
+            num_heads = self.config["n_head"]
+            n_embed = self.config["hidden_size"]
+            if "weight" in name:
+                weights = self._reverse_reshape_weights(weights, num_heads, n_embed)
+            else:
+                weights = self._reverse_reshape_bias(weights, num_heads, n_embed)
+        return GGUFTensor(weights, name, {})
+
+    def _reverse_reshape_weights(self, weights: np.ndarray, n_head: int, n_embed: int):
+        # Original reshape implementation
+        # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L972-L985
+        q, k, v = np.array_split(weights, 3, axis=0)
+
+        q = q.reshape(n_head, n_embed // n_head, n_embed)
+        k = k.reshape(n_head, n_embed // n_head, n_embed)
+        v = v.reshape(n_head, n_embed // n_head, n_embed)
+        qkv_weights = np.stack([q, k, v], axis=1)
+
+        return qkv_weights.reshape(n_head * 3 * (n_embed // n_head), n_embed)
+
+    def _reverse_reshape_bias(self, weights: np.ndarray, n_head: int, n_embed: int):
+        # Original reshape implementation
+        # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L986-L998
+        q_bias, k_bias, v_bias = np.array_split(weights, 3)
+
+        q_bias = q_bias.reshape(n_head, n_embed // n_head)
+        k_bias = k_bias.reshape(n_head, n_embed // n_head)
+        v_bias = v_bias.reshape(n_head, n_embed // n_head)
+
+        qkv_bias = np.stack([q_bias, k_bias, v_bias], axis=1).flatten()
+        return qkv_bias
+
+
+class T5TensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        bid = None
+        for chunk in name.split("."):
+            if chunk.isdigit():
+                bid = int(chunk)
+                break
+        return GGUFTensor(weights, name, {"bid": bid})
+
+
+class GPT2TensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        # Original transpose implementation
+        # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L2060-L2061
+        if (
+            "attn_qkv.weight" in name
+            or "ffn_down.weight" in name
+            or "ffn_up.weight" in name
+            or "attn_output.weight" in name
+        ):
+            weights = weights.T
+
+        # Handle special case for output.weight
+        if name == "output.weight":
+            # output.weight has conflicts with attn_output.weight in name checking
+            # Store the tensor directly and signal to skip further processing
+            name = "lm_head.weight"
+            parsed_parameters = kwargs.get("parsed_parameters", {})
+            parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))
+            name = None  # Signal to skip further processing
+        return GGUFTensor(weights, name, {})
+
+
+class MambaTensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        if "ssm_d" in name and "bias" not in name and "weight" not in name:
+            # ssm_d has conflicts with ssm_dt in name checking
+            # we have to explicitly check that name is exactly ssm_d
+            name = name.replace("ssm_d", "mixer.D")
+        if "ssm_conv1d.weight" in name:
+            # for compatibility tensor ssm_conv1d must be (5120, 1, 4) dim,
+            # quantized one is (5120, 4)
+            weights = np.expand_dims(weights, axis=1)
+        if "ssm_a" in name:
+            # Original exponential implementation
+            # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L2975-L2977
+            weights = np.log(-weights)
+        return GGUFTensor(weights, name, {})
+
+
+TENSOR_PROCESSORS = {
+    "llama": LlamaTensorProcessor,
+    "qwen2moe": Qwen2MoeTensorProcessor,
+    "bloom": BloomTensorProcessor,
+    "t5": T5TensorProcessor,
+    "t5encoder": T5TensorProcessor,
+    "gpt2": GPT2TensorProcessor,
+    "mamba": MambaTensorProcessor,
+}
+
+
 def read_field(reader, field):
     value = reader.fields[field]
     return [_gguf_parse_value(value.parts[_data_index], value.types) for _data_index in value.data]
@@ -177,73 +371,28 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):

     if return_tensors:
         tensor_key_mapping = GGUF_TO_TRANSFORMERS_MAPPING["tensors"][architecture + model_size]
+        config = parsed_parameters.get("config", {})
+
+        ProcessorClass = TENSOR_PROCESSORS.get(architecture, TensorProcessor)
+        processor = ProcessorClass(config=config)

         for tensor in tqdm(reader.tensors, desc="Converting and de-quantizing GGUF tensors..."):
             name = tensor.name
             weights = dequantize(tensor.data, tensor.tensor_type)

-            if architecture == "llama" and (".attn_k." in name or ".attn_q." in name):
-                num_heads = parsed_parameters["config"]["num_attention_heads"]
-                num_kv_heads = parsed_parameters["config"]["num_key_value_heads"]
-                if ".attn_q." in name:
-                    weights = reverse_permute_weights(weights, num_heads, num_heads)
-                elif ".attn_k." in name:
-                    weights = reverse_permute_weights(weights, num_heads, num_kv_heads)
-
-            if architecture == "qwen2moe":
-                if "_exp" in name:
-                    split_moe_expert_tensor(weights, parsed_parameters, name, tensor_key_mapping)
-                    continue
-                if "ffn_gate_inp_shexp" in name:
-                    # for compatibility tensor shared_expert_gate must be (1, 2048) dim,
-                    # quantized one is (2048)
-                    weights = np.expand_dims(weights, axis=0)
-
-            if architecture == "bloom" and "attn_qkv" in name:
-                num_heads = parsed_parameters["config"]["n_head"]
-                n_embed = parsed_parameters["config"]["hidden_size"]
-                if "weight" in name:
-                    weights = reverse_reshape_weights(weights, num_heads, n_embed)
-                else:
-                    weights = reverse_reshape_bias(weights, num_heads, n_embed)
-
-            bid = None
-            if architecture in ("t5", "t5encoder"):
-                for chunk in name.split("."):
-                    if chunk.isdigit():
-                        bid = int(chunk)
-                        break
-
-            if architecture == "gpt2":
-                if (
-                    "attn_qkv.weight" in name
-                    or "ffn_down.weight" in name
-                    or "ffn_up.weight" in name
-                    or "attn_output.weight" in name
-                ):
-                    # Original transpose implementation
-                    # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L2060-L2061
-                    weights = weights.T
-                if name == "output.weight":
-                    # output.weight has conflicts with attn_output.weight in name checking
-                    # we have to explicitly check that name is exactly output.weight
-                    name = "lm_head.weight"
-                    parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))
-                    continue
-            if architecture == "mamba":
-                if "ssm_d" in name and "bias" not in name and "weight" not in name:
-                    # ssm_d has conflicts with ssm_dt in name checking
-                    # we have to explicitly check that name is exactly ssm_d
-                    name = name.replace("ssm_d", "mixer.D")
-                if "ssm_conv1d.weight" in name:
-                    # for compatibility tensor ssm_conv1d must be (5120, 1, 4]) dim,
-                    # quantized one is (5120, 4)
-                    weights = np.expand_dims(weights, axis=1)
-                if "ssm_a" in name:
-                    # Original exponential implementation
-                    # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L2975-L2977
-                    weights = np.log(-weights)
+            result = processor.process(
+                weights=weights,
+                name=name,
+                tensor_key_mapping=tensor_key_mapping,
+                parsed_parameters=parsed_parameters,
+            )
+
+            weights = result.weights
+            name = result.name
+            bid = result.metadata.get("bid")
+
+            if name is None:
+                continue

             for tensor_name in tensor_key_mapping:
                 if tensor_name.format(bid=bid) in name:
                     name = name.replace(tensor_name.format(bid=bid), tensor_key_mapping[tensor_name])
@@ -256,64 +405,3 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
         logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}")

     return parsed_parameters
-
-
-def reverse_permute_weights(weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None) -> np.ndarray:
-    # Original permutation implementation
-    # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L1402-L1408
-    if num_kv_heads is not None and n_head != num_kv_heads:
-        n_head = num_kv_heads
-
-    dim = weights.shape[0] // n_head // 2
-    w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
-    return w.swapaxes(2, 1).reshape(weights.shape)
-
-
-def reverse_reshape_weights(weights: np.ndarray, n_head: int, n_embed: int):
-    # Original reshape implementation
-    # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L972-L985
-    q, k, v = np.array_split(weights, 3, axis=0)
-
-    q = q.reshape(n_head, n_embed // n_head, n_embed)
-    k = k.reshape(n_head, n_embed // n_head, n_embed)
-    v = v.reshape(n_head, n_embed // n_head, n_embed)
-    qkv_weights = np.stack([q, k, v], axis=1)
-
-    return qkv_weights.reshape(n_head * 3 * (n_embed // n_head), n_embed)
-
-
-def reverse_reshape_bias(weights: np.ndarray, n_head: int, n_embed: int):
-    # Original reshape implementation
-    # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L986-L998
-    q_bias, k_bias, v_bias = np.array_split(weights, 3)
-
-    q_bias = q_bias.reshape(n_head, n_embed // n_head)
-    k_bias = k_bias.reshape(n_head, n_embed // n_head)
-    v_bias = v_bias.reshape(n_head, n_embed // n_head)
-
-    qkv_bias = np.stack([q_bias, k_bias, v_bias], axis=1).flatten()
-    return qkv_bias
-
-
-def split_moe_expert_tensor(
-    weights: np.ndarray, parsed_parameters: Dict[str, Dict], name: str, tensor_key_mapping: dict
-):
-    # Original merge implementation
-    # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L1994-L2022
-    exp_name = ""
-    if "ffn_gate_exps" in name:
-        exp_name = "gate_proj"
-    elif "ffn_down_exps" in name:
-        exp_name = "down_proj"
-    elif "ffn_up_exps" in name:
-        exp_name = "up_proj"
-    else:
-        raise ValueError(f"Cannot map expert tensor {name} in Qwen2Moe architecture.")
-    for tensor_name in tensor_key_mapping:
-        if tensor_name in name:
-            name = name.replace(tensor_name, tensor_key_mapping[tensor_name])
-    w_counter = parsed_parameters["config"].get("num_experts", 60)
-    for i in range(0, w_counter):
-        temp_name = name.replace(".weight", f".{i}.{exp_name}.weight")
-        exp_weight = weights[i]
-        parsed_parameters["tensors"][temp_name] = torch.from_numpy(np.copy(exp_weight))
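
Reviewer note, not part of the patch: a minimal round-trip check that LlamaTensorProcessor._reverse_permute_weights undoes the head permutation applied on the GGUF side. The toy shapes and the permute helper below are assumptions for illustration (permute is written to mirror the llama.cpp conversion linked in the patch comments), and the import assumes the patched module and its gguf dependency are importable:

import numpy as np

from transformers.modeling_gguf_pytorch_utils import LlamaTensorProcessor

# Toy dimensions, chosen only for illustration.
num_heads, head_dim, hidden = 4, 8, 32
w = np.random.rand(num_heads * head_dim, hidden).astype(np.float32)

def permute(weights: np.ndarray, n_head: int) -> np.ndarray:
    # Assumed forward permutation, mirroring llama.cpp's convert_hf_to_gguf.py.
    return (
        weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
        .swapaxes(1, 2)
        .reshape(weights.shape)
    )

processor = LlamaTensorProcessor(
    config={"num_attention_heads": num_heads, "num_key_value_heads": num_heads}
)
result = processor.process(weights=permute(w, num_heads), name="blk.0.attn_q.weight")

assert result.name == "blk.0.attn_q.weight"
assert np.allclose(result.weights, w)  # the GGUF-side permutation is undone

The same check with num_key_value_heads < num_attention_heads exercises the grouped-query branch, since _reverse_permute_weights switches to the KV head count for attn_k tensors.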
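
Similarly, a sketch of the Bloom fused-QKV round trip: _reverse_reshape_weights should recover the per-head interleaved [Q; K; V] layout from the concatenated GGUF layout. The reshape_qkv helper and the sizes below are assumptions, written as the inverse of the reshape this patch performs:

import numpy as np

from transformers.modeling_gguf_pytorch_utils import BloomTensorProcessor

n_head, n_embed = 4, 32  # hypothetical sizes
w = np.random.rand(3 * n_embed, n_embed).astype(np.float32)

def reshape_qkv(weights: np.ndarray, n_head: int, n_embed: int) -> np.ndarray:
    # Assumed forward reshape: split the fused weight per head, then
    # concatenate all Q rows, all K rows, all V rows (GGUF layout).
    qkv = weights.reshape(n_head, 3, n_embed // n_head, n_embed)
    return np.concatenate(
        [qkv[:, 0].reshape(-1, n_embed), qkv[:, 1].reshape(-1, n_embed), qkv[:, 2].reshape(-1, n_embed)],
        axis=0,
    )

processor = BloomTensorProcessor(config={"n_head": n_head, "hidden_size": n_embed})
result = processor.process(weights=reshape_qkv(w, n_head, n_embed), name="blk.0.attn_qkv.weight")

assert np.allclose(result.weights, w)  # fused HF layout is restored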
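
Finally, the dispatch contract itself: architectures without a registered processor fall back to the no-op base class, and a processor that stores a tensor on its own (GPT2's output.weight, Qwen2Moe's expert shards) returns name=None so that the loop's `if name is None: continue` skips it. A small sketch, using "falcon" as an arbitrary stand-in for an architecture with no dedicated processor in this patch:

import numpy as np

from transformers.modeling_gguf_pytorch_utils import (
    GPT2TensorProcessor,
    TENSOR_PROCESSORS,
    TensorProcessor,
)

# Fallback path: no dedicated processor registered for this architecture,
# so the tensor passes through unchanged.
processor = TENSOR_PROCESSORS.get("falcon", TensorProcessor)(config={})
out = processor.process(weights=np.zeros((2, 2)), name="blk.0.ffn_up.weight")
assert out.name == "blk.0.ffn_up.weight" and out.metadata == {}

# Skip signal: GPT2 stores output.weight itself and returns name=None.
parsed = {"tensors": {}}
res = GPT2TensorProcessor(config={}).process(
    weights=np.ones((3, 3), dtype=np.float32),
    name="output.weight",
    parsed_parameters=parsed,
)
assert res.name is None and "lm_head.weight" in parsed["tensors"]

Since every process() call returns a GGUFTensor, the caller needs no per-architecture branching; only the T5 processors populate metadata (the "bid" used for block-indexed key mapping), and result.metadata.get("bid") safely yields None for all others.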