Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
- [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7)
- [x] [BailingMoeV2 (Ring/Ling 2.0) models](https://huggingface.co/collections/inclusionAI/ling-v2-68bf1dd2fc34c306c1fa6f86)

#### Multimodal

Expand Down
101 changes: 99 additions & 2 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,8 +892,8 @@ def get_vocab_base_pre(self, tokenizer) -> str:
# ref: https://huggingface.co/JetBrains/Mellum-4b-base
res = "mellum"
if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
# ref: https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base
res = "llada-moe"
# ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
res = "bailingmoe2"
if chkhsh == "53e325976a6e142379c19b09afcae354f2f496f147afa8f9e189a33fe4e3024e":
# ref: https://huggingface.co/ibm-granite/granite-docling-258M
res = "granite-docling"
Expand Down Expand Up @@ -8063,6 +8063,103 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("BailingMoeV2ForCausalLM")
class BailingMoeV2Model(TextModel):
model_arch = gguf.MODEL_ARCH.BAILINGMOE2

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if nextn_layers := self.hparams.get("num_nextn_predict_layers", 0):
self.block_count = self.hparams["num_hidden_layers"] + nextn_layers
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)

def set_vocab(self):
self._set_vocab_gpt2()

def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
if (rope_dim := hparams.get("head_dim")) is None:
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]

self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
rope_scaling = self.hparams.get("rope_scaling") or {}
if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
else:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
self.gguf_writer.add_expert_shared_feed_forward_length(hparams.get("moe_shared_expert_intermediate_size", hparams["moe_intermediate_size"] * hparams["num_shared_experts"]))
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
self.gguf_writer.add_expert_count(hparams["num_experts"])
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
self.gguf_writer.add_expert_group_count(hparams["n_group"])
self.gguf_writer.add_expert_group_used_count(hparams["topk_group"])
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])

if hparams["score_function"] == "sigmoid":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
elif hparams["score_function"] == "softmax":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
else:
raise ValueError(f"Unsupported score_function value: {hparams['score_function']}")

if (nextn_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
self.gguf_writer.add_nextn_predict_layers(nextn_layers)

_experts: list[dict[str, Tensor]] | None = None

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if "mlp.experts" in name:
n_experts = self.hparams["num_experts"]
assert bid is not None

tensors: list[tuple[str, Tensor]] = []

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 3:
# merge the experts into a single 3d tensor
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)

merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"

new_name = self.map_tensor_name(merged_name)

tensors.append((new_name, data_torch))

return tensors

if name.endswith(".expert_bias"):
name = name.replace(".expert_bias", ".expert_bias.bias")

return [(self.map_tensor_name(name), data_torch)]

def prepare_tensors(self):
super().prepare_tensors()

if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("GroveMoeForCausalLM", "modeling_grove_moe.GroveMoeForCausalLM")
class GroveMoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.GROVEMOE
Expand Down
2 changes: 1 addition & 1 deletion convert_hf_to_gguf_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
{"name": "llada-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base", },
{"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
]

Expand Down
33 changes: 33 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ class LLM:
EXPERT_COUNT = "{arch}.expert_count"
EXPERT_USED_COUNT = "{arch}.expert_used_count"
EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
EXPERT_GROUP_COUNT = "{arch}.expert_group_count"
EXPERT_GROUP_USED_COUNT = "{arch}.expert_group_used_count"
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm"
EXPERT_GATING_FUNC = "{arch}.expert_gating_func"
Expand Down Expand Up @@ -400,6 +402,7 @@ class MODEL_ARCH(IntEnum):
WAVTOKENIZER_DEC = auto()
PLM = auto()
BAILINGMOE = auto()
BAILINGMOE2 = auto()
DOTS1 = auto()
ARCEE = auto()
ERNIE4_5 = auto()
Expand Down Expand Up @@ -744,6 +747,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
MODEL_ARCH.PLM: "plm",
MODEL_ARCH.BAILINGMOE: "bailingmoe",
MODEL_ARCH.BAILINGMOE2: "bailingmoe2",
MODEL_ARCH.DOTS1: "dots1",
MODEL_ARCH.ARCEE: "arcee",
MODEL_ARCH.ERNIE4_5: "ernie4_5",
Expand Down Expand Up @@ -2533,6 +2537,35 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
],
MODEL_ARCH.BAILINGMOE2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.NEXTN_EH_PROJ,
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
MODEL_TENSOR.NEXTN_ENORM,
MODEL_TENSOR.NEXTN_HNORM,
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
MODEL_TENSOR.LAYER_OUT_NORM,
],
MODEL_ARCH.DOTS1: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
Expand Down
6 changes: 6 additions & 0 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,12 @@ def add_expert_used_count(self, count: int) -> None:
def add_expert_shared_count(self, count: int) -> None:
self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count)

def add_expert_group_count(self, count: int) -> None:
self.add_uint32(Keys.LLM.EXPERT_GROUP_COUNT.format(arch=self.arch), count)

def add_expert_group_used_count(self, count: int) -> None:
self.add_uint32(Keys.LLM.EXPERT_GROUP_USED_COUNT.format(arch=self.arch), count)

def add_expert_weights_scale(self, value: float) -> None:
self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)

Expand Down
6 changes: 6 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ class TensorNameMap:
"h.{bid}.self_attention.query_key_value", # bloom
"language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
"model.layers.{bid}.self_attn.query_key_value", # persimmon
"model.layers.{bid}.attention.query_key_value", # bailingmoe2
"h.{bid}.attn.c_attn", # gpt2
"transformer.h.{bid}.mixer.Wqkv", # phi2
"encoder.layers.{bid}.attn.Wqkv", # nomic-bert
Expand Down Expand Up @@ -260,6 +261,7 @@ class TensorNameMap:
"transformer.h.{bid}.attn.out_proj", # gpt-j
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
"model.layers.{bid}.self_attn.dense", # persimmon
"model.layers.{bid}.attention.dense", # bailingmoe2
"h.{bid}.attn.c_proj", # gpt2
"transformer.h.{bid}.mixer.out_proj", # phi2
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
Expand Down Expand Up @@ -373,6 +375,7 @@ class TensorNameMap:
MODEL_TENSOR.FFN_EXP_PROBS_B: (
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
"model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
"model.layers.{bid}.mlp.gate.expert_bias", # bailingmoe2
"model.layers.{bid}.feed_forward.expert_bias", # lfm2moe
),

Expand Down Expand Up @@ -549,6 +552,7 @@ class TensorNameMap:
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
"model.layers.{bid}.self_attn.query_layernorm", # hunyuan
"model.layers.{bid}.attention.query_layernorm", # bailingmoe2
"model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2
"layers.{bid}.self_attn.q_norm", # embeddinggemma
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
Expand All @@ -563,6 +567,7 @@ class TensorNameMap:
"language_model.encoder.layers.{bid}.self_attention.k_layernorm",
"model.layers.{bid}.self_attn.k_layernorm", # persimmon
"model.layers.{bid}.self_attn.key_layernorm", # hunyuan
"model.layers.{bid}.attention.key_layernorm", # bailingmoe2
"model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2
"layers.{bid}.self_attn.k_norm", # embeddinggemma
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
Expand All @@ -584,6 +589,7 @@ class TensorNameMap:
"transformer.decoder_layer.{bid}.rms_norm_3", # Grok
"encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2
"encoder.layer.{bid}.layer_norm_2", # jina-v2-code
"model.layers.{bid}.final_layernorm", # bailingmoe2
),

MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: (
Expand Down
35 changes: 35 additions & 0 deletions src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
{ LLM_ARCH_PLM, "plm" },
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
{ LLM_ARCH_BAILINGMOE2, "bailingmoe2" },
{ LLM_ARCH_DOTS1, "dots1" },
{ LLM_ARCH_ARCEE, "arcee" },
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
Expand Down Expand Up @@ -134,6 +135,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_EXPERT_COUNT, "%s.expert_count" },
{ LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
{ LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" },
{ LLM_KV_EXPERT_GROUP_COUNT, "%s.expert_group_count" },
{ LLM_KV_EXPERT_GROUP_USED_COUNT, "%s.expert_group_used_count" },
{ LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
{ LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
{ LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
Expand Down Expand Up @@ -1941,6 +1944,38 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
LLM_ARCH_BAILINGMOE2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" },
{ LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" },
{ LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" },
{ LLM_TENSOR_NEXTN_HNORM, "blk.%d.nextn.hnorm" },
{ LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.nextn.shared_head_head" },
{ LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" },
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
},
},
{
LLM_ARCH_DOTS1,
{
Expand Down
3 changes: 3 additions & 0 deletions src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ enum llm_arch {
LLM_ARCH_WAVTOKENIZER_DEC,
LLM_ARCH_PLM,
LLM_ARCH_BAILINGMOE,
LLM_ARCH_BAILINGMOE2,
LLM_ARCH_DOTS1,
LLM_ARCH_ARCEE,
LLM_ARCH_ERNIE4_5,
Expand Down Expand Up @@ -138,6 +139,8 @@ enum llm_kv {
LLM_KV_EXPERT_COUNT,
LLM_KV_EXPERT_USED_COUNT,
LLM_KV_EXPERT_SHARED_COUNT,
LLM_KV_EXPERT_GROUP_COUNT,
LLM_KV_EXPERT_GROUP_USED_COUNT,
LLM_KV_EXPERT_WEIGHTS_SCALE,
LLM_KV_EXPERT_WEIGHTS_NORM,
LLM_KV_EXPERT_GATING_FUNC,
Expand Down
Loading
Loading