refactor tensor processing
VladOS95-cyber committed Nov 6, 2024
1 parent b47e483 commit ddcea0f
Showing 1 changed file with 59 additions and 40 deletions.
99 changes: 59 additions & 40 deletions src/transformers/modeling_gguf_pytorch_utils.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import re
-from typing import Dict, Optional
+from typing import Dict, Optional, NamedTuple
 
 import numpy as np
 from tqdm import tqdm
@@ -55,25 +55,36 @@
 GGUF_SUPPORTED_ARCHITECTURES = list(GGUF_TO_TRANSFORMERS_MAPPING["tensors"].keys())
 
 
+class GGUFTensor(NamedTuple):
+    weights: np.ndarray
+    name: str
+    metadata: dict
+
+
 class TensorProcessor:
+    def __init__(self, config=None):
+        self.config = config or {}
+
     def process(self, weights, name, **kwargs):
-        return {"weights": weights, "name": name, "metadata": {}}
+        return GGUFTensor(weights, name, {})
 
 
 class LlamaTensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
     def process(self, weights, name, **kwargs):
         if ".attn_k." in name or ".attn_q." in name:
-            config = kwargs.get("config", {})
-            num_heads = config.get("num_attention_heads")
-            num_kv_heads = config.get("num_key_value_heads")
+            num_heads = self.config.get("num_attention_heads")
+            num_kv_heads = self.config.get("num_key_value_heads")
 
             if None in (num_heads, num_kv_heads):
-                return weights, name
+                return GGUFTensor(weights, name, {})
             if ".attn_q." in name:
                 weights = self._reverse_permute_weights(weights, num_heads, num_heads)
             elif ".attn_k." in name:
                 weights = self._reverse_permute_weights(weights, num_heads, num_kv_heads)
-        return {"weights": weights, "name": name, "metadata": {}}
+        return GGUFTensor(weights, name, {})
 
     def _reverse_permute_weights(
         self, weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None
@@ -89,22 +100,21 @@ def _reverse_permute_weights(
 
 
 class Qwen2MoeTensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
     def process(self, weights, name, **kwargs):
         if "_exp" in name:
             tensor_key_mapping = kwargs.get("tensor_key_mapping")
-            config = kwargs.get("config", {})
+            parsed_parameters = kwargs.get("parsed_parameters")
             if tensor_key_mapping:
-                self._split_moe_expert_tensor(weights, config, name, tensor_key_mapping)
-            return {
-                "weights": weights,
-                "name": None,  # Signal to skip further processing
-                "metadata": {},
-            }
+                self._split_moe_expert_tensor(weights, parsed_parameters, name, tensor_key_mapping)
+            return GGUFTensor(weights, None, {})
         if "ffn_gate_inp_shexp" in name:
             # for compatibility tensor shared_expert_gate must be (1, 2048) dim,
             # quantized one is (2048)
             weights = np.expand_dims(weights, axis=0)
-        return {"weights": weights, "name": name, "metadata": {}}
+        return GGUFTensor(weights, name, {})
 
     def _split_moe_expert_tensor(
         self, weights: np.ndarray, parsed_parameters: Dict[str, Dict], name: str, tensor_key_mapping: dict
@@ -123,24 +133,26 @@ def _split_moe_expert_tensor(
         for tensor_name in tensor_key_mapping:
             if tensor_name in name:
                 name = name.replace(tensor_name, tensor_key_mapping[tensor_name])
-        w_counter = parsed_parameters["config"].get("num_experts", 60)
+        w_counter = self.config.get("num_experts", 60)
         for i in range(0, w_counter):
             temp_name = name.replace(".weight", f".{i}.{exp_name}.weight")
             exp_weight = weights[i]
             parsed_parameters["tensors"][temp_name] = torch.from_numpy(np.copy(exp_weight))
 
 
 class BloomTensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
     def process(self, weights, name, **kwargs):
         if "attn_qkv" in name:
-            config = kwargs.get("config", {})
-            num_heads = config["n_head"]
-            n_embed = config["hidden_size"]
+            num_heads = self.config["n_head"]
+            n_embed = self.config["hidden_size"]
             if "weight" in name:
                 weights = self._reverse_reshape_weights(weights, num_heads, n_embed)
             else:
                 weights = self._reverse_reshape_bias(weights, num_heads, n_embed)
-        return {"weights": weights, "name": name, "metadata": {}}
+        return GGUFTensor(weights, name, {})
 
     def _reverse_reshape_weights(self, weights: np.ndarray, n_head: int, n_embed: int):
         # Original reshape implementation
@@ -168,16 +180,22 @@ def _reverse_reshape_bias(self, weights: np.ndarray, n_head: int, n_embed: int):
 
 
 class T5TensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
     def process(self, weights, name, **kwargs):
         bid = None
         for chunk in name.split("."):
             if chunk.isdigit():
                 bid = int(chunk)
                 break
-        return {"weights": weights, "name": name, "metadata": {"bid": bid}}
+        return GGUFTensor(weights, name, {"bid": bid})
 
 
 class GPT2TensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
     def process(self, weights, name, **kwargs):
         # Original transpose implementation
         # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L2060-L2061
@@ -197,14 +215,13 @@ def process(self, weights, name, **kwargs):
             parsed_parameters = kwargs.get("parsed_parameters", {})
             parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))
             name = None  # Signal to skip further processing
-        return {
-            "weights": weights,
-            "name": name,
-            "metadata": {},
-        }
+        return GGUFTensor(weights, name, {})
 
 
 class MambaTensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
     def process(self, weights, name, **kwargs):
         if "ssm_d" in name and "bias" not in name and "weight" not in name:
             # ssm_d has conflicts with ssm_dt in name checking
@@ -218,17 +235,17 @@ def process(self, weights, name, **kwargs):
             # Original exponential implementation
             # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L2975-L2977
             weights = np.log(-weights)
-        return {"weights": weights, "name": name, "metadata": {}}
+        return GGUFTensor(weights, name, {})
 
 
 TENSOR_PROCESSORS = {
-    "llama": LlamaTensorProcessor(),
-    "qwen2moe": Qwen2MoeTensorProcessor(),
-    "bloom": BloomTensorProcessor(),
-    "t5": T5TensorProcessor(),
-    "t5encoder": T5TensorProcessor(),
-    "gpt2": GPT2TensorProcessor(),
-    "mamba": MambaTensorProcessor(),
+    "llama": LlamaTensorProcessor,
+    "qwen2moe": Qwen2MoeTensorProcessor,
+    "bloom": BloomTensorProcessor,
+    "t5": T5TensorProcessor,
+    "t5encoder": T5TensorProcessor,
+    "gpt2": GPT2TensorProcessor,
+    "mamba": MambaTensorProcessor,
 }
 
 
@@ -354,7 +371,10 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
 
     if return_tensors:
         tensor_key_mapping = GGUF_TO_TRANSFORMERS_MAPPING["tensors"][architecture + model_size]
-        processor = TENSOR_PROCESSORS.get(architecture, TensorProcessor())
+        config = parsed_parameters.get("config", {})
+
+        ProcessorClass = TENSOR_PROCESSORS.get(architecture, TensorProcessor)
+        processor = ProcessorClass(config=config)
 
         for tensor in tqdm(reader.tensors, desc="Converting and de-quantizing GGUF tensors..."):
             name = tensor.name
@@ -363,16 +383,15 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
             result = processor.process(
                 weights=weights,
                 name=name,
-                config=parsed_parameters.get("config", {}),
                 tensor_key_mapping=tensor_key_mapping,
                 parsed_parameters=parsed_parameters,
             )
 
-            weights = result["weights"]
-            name = result["name"]
-            bid = result["metadata"].get("bid")
+            weights = result.weights
+            name = result.name
+            bid = result.metadata.get("bid")
 
-            if result["name"] is None:
+            if name is None:
                 continue
 
             for tensor_name in tensor_key_mapping:
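Note: under the refactored design above, supporting a new architecture means subclassing TensorProcessor, reading model parameters from self.config (populated once when load_gguf_checkpoint constructs the processor) instead of from kwargs, and returning a GGUFTensor. A minimal sketch of that pattern; the "my_arch" key and its transpose rule are hypothetical illustrations, not part of this commit:

# Hypothetical example, not part of this commit: a processor for a new
# architecture ("my_arch") following the refactored design.
class MyArchTensorProcessor(TensorProcessor):
    def __init__(self, config=None):
        super().__init__(config=config)

    def process(self, weights, name, **kwargs):
        # Model parameters come from self.config, set once at construction,
        # rather than being passed through kwargs on every call.
        num_heads = self.config.get("num_attention_heads")
        if num_heads is not None and ".attn_output." in name:
            weights = weights.T  # illustrative transformation only
        # Returning name=None signals the caller to skip further processing;
        # metadata carries extras such as T5's "bid".
        return GGUFTensor(weights, name, {})


# Register the class itself (not an instance); load_gguf_checkpoint
# instantiates it with the parsed config.
TENSOR_PROCESSORS["my_arch"] = MyArchTensorProcessor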
