Skip to content

Commit

Permalink
Improve gguf tensor processing (#34515)
Browse files Browse the repository at this point in the history
* add tensor processing system to separate logic for models

* format refactoring

* small fix

* make some methods private

* move custom methods to processors

* refactor tensor processing

* format fix
  • Loading branch information
VladOS95-cyber authored Nov 21, 2024
1 parent c57eafd commit ae5cbf8
Showing 1 changed file with 212 additions and 124 deletions.
336 changes: 212 additions & 124 deletions src/transformers/modeling_gguf_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.

import re
from typing import Dict, Optional
from typing import Dict, NamedTuple, Optional

import numpy as np
from tqdm import tqdm
Expand Down Expand Up @@ -55,6 +55,200 @@
GGUF_SUPPORTED_ARCHITECTURES = list(GGUF_TO_TRANSFORMERS_MAPPING["tensors"].keys())


class GGUFTensor(NamedTuple):
weights: np.ndarray
name: str
metadata: dict


class TensorProcessor:
def __init__(self, config=None):
self.config = config or {}

def process(self, weights, name, **kwargs):
return GGUFTensor(weights, name, {})


class LlamaTensorProcessor(TensorProcessor):
def __init__(self, config=None):
super().__init__(config=config)

def process(self, weights, name, **kwargs):
if ".attn_k." in name or ".attn_q." in name:
num_heads = self.config.get("num_attention_heads")
num_kv_heads = self.config.get("num_key_value_heads")

if None in (num_heads, num_kv_heads):
return GGUFTensor(weights, name, {})
if ".attn_q." in name:
weights = self._reverse_permute_weights(weights, num_heads, num_heads)
elif ".attn_k." in name:
weights = self._reverse_permute_weights(weights, num_heads, num_kv_heads)
return GGUFTensor(weights, name, {})

def _reverse_permute_weights(
self, weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None
) -> np.ndarray:
# Original permutation implementation
# https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L1402-L1408
if num_kv_heads is not None and n_head != num_kv_heads:
n_head = num_kv_heads

dim = weights.shape[0] // n_head // 2
w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
return w.swapaxes(2, 1).reshape(weights.shape)


class Qwen2MoeTensorProcessor(TensorProcessor):
def __init__(self, config=None):
super().__init__(config=config)

def process(self, weights, name, **kwargs):
if "_exp" in name:
tensor_key_mapping = kwargs.get("tensor_key_mapping")
parsed_parameters = kwargs.get("parsed_parameters")
if tensor_key_mapping:
self._split_moe_expert_tensor(weights, parsed_parameters, name, tensor_key_mapping)
return GGUFTensor(weights, None, {})
if "ffn_gate_inp_shexp" in name:
# for compatibility tensor shared_expert_gate must be (1, 2048) dim,
# quantized one is (2048)
weights = np.expand_dims(weights, axis=0)
return GGUFTensor(weights, name, {})

def _split_moe_expert_tensor(
self, weights: np.ndarray, parsed_parameters: Dict[str, Dict], name: str, tensor_key_mapping: dict
):
# Original merge implementation
# https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L1994-L2022
exp_name = ""
if "ffn_gate_exps" in name:
exp_name = "gate_proj"
elif "ffn_down_exps" in name:
exp_name = "down_proj"
elif "ffn_up_exps" in name:
exp_name = "up_proj"
else:
raise ValueError(f"Cannot map expert tensor {name} in Qwen2Moe architecture.")
for tensor_name in tensor_key_mapping:
if tensor_name in name:
name = name.replace(tensor_name, tensor_key_mapping[tensor_name])
w_counter = self.config.get("num_experts", 60)
for i in range(0, w_counter):
temp_name = name.replace(".weight", f".{i}.{exp_name}.weight")
exp_weight = weights[i]
parsed_parameters["tensors"][temp_name] = torch.from_numpy(np.copy(exp_weight))


class BloomTensorProcessor(TensorProcessor):
def __init__(self, config=None):
super().__init__(config=config)

def process(self, weights, name, **kwargs):
if "attn_qkv" in name:
num_heads = self.config["n_head"]
n_embed = self.config["hidden_size"]
if "weight" in name:
weights = self._reverse_reshape_weights(weights, num_heads, n_embed)
else:
weights = self._reverse_reshape_bias(weights, num_heads, n_embed)
return GGUFTensor(weights, name, {})

def _reverse_reshape_weights(self, weights: np.ndarray, n_head: int, n_embed: int):
# Original reshape implementation
# https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L972-L985
q, k, v = np.array_split(weights, 3, axis=0)

q = q.reshape(n_head, n_embed // n_head, n_embed)
k = k.reshape(n_head, n_embed // n_head, n_embed)
v = v.reshape(n_head, n_embed // n_head, n_embed)
qkv_weights = np.stack([q, k, v], axis=1)

return qkv_weights.reshape(n_head * 3 * (n_embed // n_head), n_embed)

def _reverse_reshape_bias(self, weights: np.ndarray, n_head: int, n_embed: int):
# Original reshape implementation
# https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L986-L998
q_bias, k_bias, v_bias = np.array_split(weights, 3)

q_bias = q_bias.reshape(n_head, n_embed // n_head)
k_bias = k_bias.reshape(n_head, n_embed // n_head)
v_bias = v_bias.reshape(n_head, n_embed // n_head)

qkv_bias = np.stack([q_bias, k_bias, v_bias], axis=1).flatten()
return qkv_bias


class T5TensorProcessor(TensorProcessor):
def __init__(self, config=None):
super().__init__(config=config)

def process(self, weights, name, **kwargs):
bid = None
for chunk in name.split("."):
if chunk.isdigit():
bid = int(chunk)
break
return GGUFTensor(weights, name, {"bid": bid})


class GPT2TensorProcessor(TensorProcessor):
def __init__(self, config=None):
super().__init__(config=config)

def process(self, weights, name, **kwargs):
# Original transpose implementation
# https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L2060-L2061
if (
"attn_qkv.weight" in name
or "ffn_down.weight" in name
or "ffn_up.weight" in name
or "attn_output.weight" in name
):
weights = weights.T

# Handle special case for output.weight
if name == "output.weight":
# output.weight has conflicts with attn_output.weight in name checking
# Store the tensor directly and signal to skip further processing
name = "lm_head.weight"
parsed_parameters = kwargs.get("parsed_parameters", {})
parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))
name = None # Signal to skip further processing
return GGUFTensor(weights, name, {})


class MambaTensorProcessor(TensorProcessor):
def __init__(self, config=None):
super().__init__(config=config)

def process(self, weights, name, **kwargs):
if "ssm_d" in name and "bias" not in name and "weight" not in name:
# ssm_d has conflicts with ssm_dt in name checking
# we have to explicitly check that name is exactly ssm_d
name = name.replace("ssm_d", "mixer.D")
if "ssm_conv1d.weight" in name:
# for compatibility tensor ssm_conv1d must be (5120, 1, 4]) dim,
# quantized one is (5120, 4)
weights = np.expand_dims(weights, axis=1)
if "ssm_a" in name:
# Original exponential implementation
# https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L2975-L2977
weights = np.log(-weights)
return GGUFTensor(weights, name, {})


TENSOR_PROCESSORS = {
"llama": LlamaTensorProcessor,
"qwen2moe": Qwen2MoeTensorProcessor,
"bloom": BloomTensorProcessor,
"t5": T5TensorProcessor,
"t5encoder": T5TensorProcessor,
"gpt2": GPT2TensorProcessor,
"mamba": MambaTensorProcessor,
}


def read_field(reader, field):
value = reader.fields[field]
return [_gguf_parse_value(value.parts[_data_index], value.types) for _data_index in value.data]
Expand Down Expand Up @@ -177,73 +371,28 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):

if return_tensors:
tensor_key_mapping = GGUF_TO_TRANSFORMERS_MAPPING["tensors"][architecture + model_size]
config = parsed_parameters.get("config", {})

ProcessorClass = TENSOR_PROCESSORS.get(architecture, TensorProcessor)
processor = ProcessorClass(config=config)

for tensor in tqdm(reader.tensors, desc="Converting and de-quantizing GGUF tensors..."):
name = tensor.name

weights = dequantize(tensor.data, tensor.tensor_type)

if architecture == "llama" and (".attn_k." in name or ".attn_q." in name):
num_heads = parsed_parameters["config"]["num_attention_heads"]
num_kv_heads = parsed_parameters["config"]["num_key_value_heads"]
if ".attn_q." in name:
weights = reverse_permute_weights(weights, num_heads, num_heads)
elif ".attn_k." in name:
weights = reverse_permute_weights(weights, num_heads, num_kv_heads)

if architecture == "qwen2moe":
if "_exp" in name:
split_moe_expert_tensor(weights, parsed_parameters, name, tensor_key_mapping)
continue
if "ffn_gate_inp_shexp" in name:
# for compatibility tensor shared_expert_gate must be (1, 2048) dim,
# quantized one is (2048)
weights = np.expand_dims(weights, axis=0)

if architecture == "bloom" and "attn_qkv" in name:
num_heads = parsed_parameters["config"]["n_head"]
n_embed = parsed_parameters["config"]["hidden_size"]
if "weight" in name:
weights = reverse_reshape_weights(weights, num_heads, n_embed)
else:
weights = reverse_reshape_bias(weights, num_heads, n_embed)

bid = None
if architecture in ("t5", "t5encoder"):
for chunk in name.split("."):
if chunk.isdigit():
bid = int(chunk)
break

if architecture == "gpt2":
if (
"attn_qkv.weight" in name
or "ffn_down.weight" in name
or "ffn_up.weight" in name
or "attn_output.weight" in name
):
# Original transpose implementation
# https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L2060-L2061
weights = weights.T
if name == "output.weight":
# output.weight has conflicts with attn_output.weight in name checking
# we have to explicitly check that name is exactly output.weight
name = "lm_head.weight"
parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))
continue
if architecture == "mamba":
if "ssm_d" in name and "bias" not in name and "weight" not in name:
# ssm_d has conflicts with ssm_dt in name checking
# we have to explicitly check that name is exactly ssm_d
name = name.replace("ssm_d", "mixer.D")
if "ssm_conv1d.weight" in name:
# for compatibility tensor ssm_conv1d must be (5120, 1, 4]) dim,
# quantized one is (5120, 4)
weights = np.expand_dims(weights, axis=1)
if "ssm_a" in name:
# Original exponential implementation
# https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L2975-L2977
weights = np.log(-weights)
result = processor.process(
weights=weights,
name=name,
tensor_key_mapping=tensor_key_mapping,
parsed_parameters=parsed_parameters,
)

weights = result.weights
name = result.name
bid = result.metadata.get("bid")

if name is None:
continue

for tensor_name in tensor_key_mapping:
if tensor_name.format(bid=bid) in name:
Expand All @@ -256,64 +405,3 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}")

return parsed_parameters


def reverse_permute_weights(weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None) -> np.ndarray:
# Original permutation implementation
# https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L1402-L1408
if num_kv_heads is not None and n_head != num_kv_heads:
n_head = num_kv_heads

dim = weights.shape[0] // n_head // 2
w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
return w.swapaxes(2, 1).reshape(weights.shape)


def reverse_reshape_weights(weights: np.ndarray, n_head: int, n_embed: int):
# Original reshape implementation
# https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L972-L985
q, k, v = np.array_split(weights, 3, axis=0)

q = q.reshape(n_head, n_embed // n_head, n_embed)
k = k.reshape(n_head, n_embed // n_head, n_embed)
v = v.reshape(n_head, n_embed // n_head, n_embed)
qkv_weights = np.stack([q, k, v], axis=1)

return qkv_weights.reshape(n_head * 3 * (n_embed // n_head), n_embed)


def reverse_reshape_bias(weights: np.ndarray, n_head: int, n_embed: int):
# Original reshape implementation
# https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L986-L998
q_bias, k_bias, v_bias = np.array_split(weights, 3)

q_bias = q_bias.reshape(n_head, n_embed // n_head)
k_bias = k_bias.reshape(n_head, n_embed // n_head)
v_bias = v_bias.reshape(n_head, n_embed // n_head)

qkv_bias = np.stack([q_bias, k_bias, v_bias], axis=1).flatten()
return qkv_bias


def split_moe_expert_tensor(
weights: np.ndarray, parsed_parameters: Dict[str, Dict], name: str, tensor_key_mapping: dict
):
# Original merge implementation
# https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L1994-L2022
exp_name = ""
if "ffn_gate_exps" in name:
exp_name = "gate_proj"
elif "ffn_down_exps" in name:
exp_name = "down_proj"
elif "ffn_up_exps" in name:
exp_name = "up_proj"
else:
raise ValueError(f"Cannot map expert tensor {name} in Qwen2Moe architecture.")
for tensor_name in tensor_key_mapping:
if tensor_name in name:
name = name.replace(tensor_name, tensor_key_mapping[tensor_name])
w_counter = parsed_parameters["config"].get("num_experts", 60)
for i in range(0, w_counter):
temp_name = name.replace(".weight", f".{i}.{exp_name}.weight")
exp_weight = weights[i]
parsed_parameters["tensors"][temp_name] = torch.from_numpy(np.copy(exp_weight))

0 comments on commit ae5cbf8

Please sign in to comment.