Commit a9082a4

[Bugfix] Fix Qwen3 MoE GPTQ inference (#23490)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
1 parent e0329ed commit a9082a4
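
Per the commit title and the diff below, GPTQ checkpoints of Qwen3 MoE failed at inference because the MoE router gate was handed the model-wide GPTQ quant_config even though AutoGPTQ leaves the gate unquantized. A minimal reproduction sketch, assuming vLLM's standard offline API and the checkpoint referenced in the diff's comment (prompt and sampling settings are arbitrary):

    from vllm import LLM, SamplingParams

    # Before this commit, loading this AutoGPTQ checkpoint broke because the
    # MoE gate (left unquantized by AutoGPTQ) was wrapped with the GPTQ config.
    llm = LLM(model="Qwen/Qwen3-30B-A3B-GPTQ-Int4")
    outputs = llm.generate(["Hello, world"], SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)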

File tree

1 file changed
+18 −6 lines changed

vllm/model_executor/models/qwen3_moe.py

Lines changed: 18 additions & 6 deletions
@@ -45,6 +45,9 @@
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+from vllm.model_executor.layers.quantization.gptq_marlin import (
+    GPTQMarlinConfig)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
@@ -146,11 +149,20 @@ def __init__(
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts)

-        self.gate = ReplicatedLinear(config.hidden_size,
-                                     config.num_experts,
-                                     bias=False,
-                                     quant_config=quant_config,
-                                     prefix=f"{prefix}.gate")
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.num_experts,
+            bias=False,
+            quant_config=self._maybe_ignore_quant_config(quant_config),
+            prefix=f"{prefix}.gate")
+
+    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+        # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
+        # seems to avoid gate quantization.
+        # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4
+        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
+            return None
+        return quant_config

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
@@ -682,4 +694,4 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loader.load_weights(weights)

     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-        return self.model.get_expert_mapping()
+        return self.model.get_expert_mapping()
