Commit 085e023

Authored by ctcanbol, MekkCyber, and jusjinuk
Fix Qwen3 MoE GGUF architecture mismatch (#39976)
* fix qwen3moe gguf architecture
* Fix Qwen3Moe GGUF loading

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Co-authored-by: Jinuk Kim <jusjinuk@snu.ac.kr>
1 parent 2ce0dae commit 085e023

2 files changed: +9 -3 lines

src/transformers/integrations/ggml.py

Lines changed: 2 additions & 1 deletion
@@ -102,13 +102,14 @@
         "attention.layer_norm_rms_epsilon": "rms_norm_eps",
         "vocab_size": "vocab_size",
     },
-    "qwen3moe": {
+    "qwen3_moe": {
         "context_length": "max_position_embeddings",
         "block_count": "num_hidden_layers",
         "feed_forward_length": "intermediate_size",
         "embedding_length": "hidden_size",
         "rope.dimension_count": None,
         "rope.freq_base": "rope_theta",
+        "attention.key_length": "head_dim",
         "attention.head_count": "num_attention_heads",
         "attention.head_count_kv": "num_key_value_heads",
         "attention.layer_norm_rms_epsilon": "rms_norm_eps",

src/transformers/modeling_gguf_pytorch_utils.py

Lines changed: 7 additions & 2 deletions
@@ -246,6 +246,7 @@ def process(self, weights, name, **kwargs):
 TENSOR_PROCESSORS = {
     "llama": LlamaTensorProcessor,
     "qwen2moe": Qwen2MoeTensorProcessor,
+    "qwen3moe": Qwen2MoeTensorProcessor,
     "bloom": BloomTensorProcessor,
     "t5": T5TensorProcessor,
     "t5encoder": T5TensorProcessor,
@@ -295,6 +296,8 @@ def get_gguf_hf_weights_map(
         model_type = "command-r"
     elif model_type == "qwen2_moe":
         model_type = "qwen2moe"
+    elif model_type == "qwen3_moe":
+        model_type = "qwen3moe"
     elif model_type == "gemma3_text":
         model_type = "gemma3"
     arch = None
@@ -316,8 +319,8 @@ def get_gguf_hf_weights_map(
     gguf_to_hf_name_map = {}
     state_dict = hf_model.state_dict()
     for hf_name in state_dict:
-        # An exception for qwen2moe model, where the expert layers are packed
-        if model_type == "qwen2moe" and "mlp.experts." in hf_name:
+        # An exception for qwen2moe/qwen3moe model, where the expert layers are packed
+        if model_type in ("qwen2moe", "qwen3moe") and "mlp.experts." in hf_name:
             hf_name = re.sub(r"mlp.experts.\d+.", "mlp.experts.", hf_name)
 
         name, suffix = hf_name, ""
@@ -391,6 +394,8 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False, model_to_lo
 
     if "qwen2moe" in architecture:
         updated_architecture = "qwen2_moe"
+    elif "qwen3moe" in architecture:
+        updated_architecture = "qwen3_moe"
 
     # For stablelm architecture, we need to set qkv_bias and use_parallel_residual from tensors
     # If `qkv_bias=True`, qkv_proj with bias will be present in the tensors
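
The expert-packing exception mirrors the existing qwen2moe path: expert weights are stored packed in the GGUF file, so the numbered expert index in the Hugging Face parameter name has to be stripped before the name map is built. A small standalone sketch of that renaming, reusing the regular expression from the diff above (the parameter names are illustrative):

import re

# Illustrative HF parameter names for a MoE block; the renaming mirrors the
# re.sub call in get_gguf_hf_weights_map.
hf_names = [
    "model.layers.0.mlp.experts.0.gate_proj.weight",
    "model.layers.0.mlp.experts.17.up_proj.weight",
]
for name in hf_names:
    packed = re.sub(r"mlp.experts.\d+.", "mlp.experts.", name)
    print(f"{name} -> {packed}")
# model.layers.0.mlp.experts.0.gate_proj.weight -> model.layers.0.mlp.experts.gate_proj.weight
# model.layers.0.mlp.experts.17.up_proj.weight -> model.layers.0.mlp.experts.up_proj.weight

With both files patched, a Qwen3 MoE GGUF checkpoint should load through the usual GGUF entry point, along these lines (the repo id and filename below are placeholders, not a specific published checkpoint):

from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-org/Qwen3-MoE-GGUF"  # placeholder repo id
gguf_file = "model-q4_k_m.gguf"      # placeholder filename

tokenizer = AutoTokenizer.from_pretrained(repo_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(repo_id, gguf_file=gguf_file)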
