From 3fdc9570ca59671c4e2bc3bbf52bf32a3f323d9f Mon Sep 17 00:00:00 2001 From: JartX Date: Sat, 30 Aug 2025 17:14:19 +0200 Subject: [PATCH 1/9] Fix load AutoGPTQ and Autoround_GPTQ Models using fallback Signed-off-by: JartX --- vllm/model_executor/models/qwen3_moe.py | 126 +++++++++++------------- 1 file changed, 56 insertions(+), 70 deletions(-) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 94e6a66bea5c..79672cd26b4d 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - # Copyright 2024 The Qwen team. # Copyright 2023 The vLLM team. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. @@ -22,6 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen3MoE model compatible with HuggingFace weights.""" +import os import typing from collections.abc import Callable, Iterable from itertools import islice @@ -112,28 +112,23 @@ def __init__( ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() - self.ep_group = get_ep_group().device_group self.ep_rank = self.ep_group.rank() self.ep_size = self.ep_group.size() self.n_routed_experts = config.num_experts - if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " f"the number of experts {config.num_experts}.") - # Load balancing settings. 
vllm_config = get_current_vllm_config() eplb_config = vllm_config.parallel_config.eplb_config self.enable_eplb = enable_eplb - self.n_logical_experts = self.n_routed_experts self.n_redundant_experts = eplb_config.num_redundant_experts self.n_physical_experts = (self.n_logical_experts + self.n_redundant_experts) self.n_local_physical_experts = self.n_physical_experts // self.ep_size - self.physical_expert_start = (self.ep_rank * self.n_local_physical_experts) self.physical_expert_end = (self.physical_expert_start + @@ -149,33 +144,28 @@ def __init__( prefix=f"{prefix}.experts", enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts) - self.gate = ReplicatedLinear( config.hidden_size, config.num_experts, bias=False, - quant_config=self._maybe_ignore_quant_config(quant_config), + quant_config=self.maybe_not_quantization(quant_config), prefix=f"{prefix}.gate") - def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): - # GPTQ configs do not have a list of ignored modules, however AutoGPTQ - # seems to avoid gate quantization. - # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4 + def maybe_not_quantization(self, + quant_config: Optional[QuantizationConfig]): + if os.environ.get("VLLM_QUANTIZATION_FROM_AUTOROUND_GPTQ") == "1": + return quant_config if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): return None return quant_config def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - # NOTE: hidden_states can have either 1D or 2D shape. 
orig_shape = hidden_states.shape hidden_dim = hidden_states.shape[-1] hidden_states = hidden_states.view(-1, hidden_dim) - - # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits) - return final_hidden_states.view(orig_shape) @@ -228,7 +218,6 @@ def __init__( bias=qkv_bias, quant_config=quant_config, prefix=f"{prefix}.qkv_proj") - self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, hidden_size, bias=False, @@ -267,16 +256,16 @@ def forward( ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - # Add qk-norm + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) q_by_head = self.q_norm(q_by_head) q = q_by_head.view(q.shape) - k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) k_by_head = self.k_norm(k_by_head) k = k_by_head.view(k.shape) + q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) @@ -302,6 +291,7 @@ def __init__( dual_chunk_attention_config = getattr(config, "dual_chunk_attention_config", None) + self.self_attn = Qwen3MoeAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -318,13 +308,12 @@ def __init__( dual_chunk_attention_config=dual_chunk_attention_config, ) - # `mlp_only_layers` in the config. 
layer_idx = extract_layer_index(prefix) mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers) if (layer_idx not in mlp_only_layers) and ( - config.num_experts > 0 and - (layer_idx + 1) % config.decoder_sparse_step == 0): + config.num_experts > 0 and + (layer_idx + 1) % config.decoder_sparse_step == 0): self.mlp = Qwen3MoeSparseMoeBlock(config=config, quant_config=quant_config, prefix=f"{prefix}.mlp", @@ -346,19 +335,17 @@ def forward( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: - # Self Attention if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm( hidden_states, residual) + hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) - - # Fully Connected hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) hidden_states = self.mlp(hidden_states) @@ -370,11 +357,11 @@ class Qwen3MoeModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config + enable_eplb = parallel_config.enable_eplb eplb_config = parallel_config.eplb_config self.num_redundant_experts = eplb_config.num_redundant_experts @@ -382,6 +369,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.config = config + self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, @@ -397,6 +385,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( 
["hidden_states", "residual"], config.hidden_size)) @@ -421,19 +410,20 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) + if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, "residual": residual }) + hidden_states, _ = self.norm(hidden_states, residual) return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) return FusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", @@ -444,7 +434,6 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ - # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), @@ -452,38 +441,30 @@ def load_weights(self, weights: Iterable[tuple[str, ("gate_up_proj", "up_proj", 1), ] - # Skip loading extra parameters for GPTQ/modelopt models. ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale", ".v_scale", "_v_scale", ".weight_scale", "_weight_scale", ".input_scale", "_input_scale") params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() + for name, loaded_weight in weights: for (param_name, weight_name, shard_id) in stacked_params_mapping: - # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue - # We have mlp.experts[0].gate_proj in the checkpoint. 
- # Since we handle the experts below in expert_params_mapping, - # we need to skip here BEFORE we update the name, otherwise - # name will be updated to mlp.experts[0].gate_up_proj, which - # will then be updated below in expert_params_mapping - # for mlp.experts[0].gate_gate_up_proj, which breaks load. if "mlp.experts" in name: continue name = name.replace(weight_name, param_name) - # Skip loading extra parameters for GPTQ/modelopt models. if name.endswith(ignore_suffixes) and name not in params_dict: continue - # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue + if name.endswith("scale"): - # Remapping the name of FP8 kv-scale. name = maybe_remap_kv_scale_name(name, params_dict) if name is None: continue @@ -504,28 +485,16 @@ def load_weights(self, weights: Iterable[tuple[str, param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue - - # Anyway, this is an expert weight and should not be - # attempted to load as other weights later is_expert_weight = True - - # Do not modify `name` since the loop may continue here - # Instead, create a new variable name_mapped = name.replace(weight_name, param_name) - if is_pp_missing_parameter(name_mapped, self): continue - - # Skip loading extra parameters for GPTQ/modelopt models. if name_mapped.endswith( ignore_suffixes ) and name_mapped not in params_dict: continue param = params_dict[name_mapped] - # We should ask the weight loader to return success or not - # here since otherwise we may skip experts with other - # available replicas. weight_loader = typing.cast(Callable[..., bool], param.weight_loader) success = weight_loader(param, @@ -539,19 +508,14 @@ def load_weights(self, weights: Iterable[tuple[str, break else: if is_expert_weight: - # We've checked that this is an expert weight - # However it's not mapped locally to this rank - # So we simply skip it continue - - # Skip loading extra parameters for GPTQ/modelopt models. 
if name.endswith( ignore_suffixes) and name not in params_dict: continue - # Skip layers on other devices. + if is_pp_missing_parameter(name, self): continue - # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): remapped_kv_scale_name = name.replace( ".kv_scale", ".attn.kv_scale") @@ -568,6 +532,7 @@ def load_weights(self, weights: Iterable[tuple[str, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) return loaded_params @@ -585,7 +550,6 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, "up_proj", ], } - fall_back_to_pt_during_load = False def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -604,24 +568,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - - # Set MoE hyperparameters self.expert_weights = [] - self.moe_layers: list[FusedMoE] = [] example_layer = None for layer in self.model.layers: if isinstance(layer, PPMissingLayer): continue - assert isinstance(layer, Qwen3MoeDecoderLayer) if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock): example_layer = layer.mlp self.moe_layers.append(layer.mlp.experts) - if example_layer is None: - raise RuntimeError("No Qwen3MoE layer found in the model.layers.") - + raise RuntimeError("No Qwen3Moe layer found in the model.layers.") self.num_moe_layers = len(self.moe_layers) self.num_expert_groups = 1 self.num_shared_experts = 0 @@ -638,7 +596,6 @@ def set_eplb_state( logical_replica_count: torch.Tensor, ) -> None: for layer_idx, layer in enumerate(self.moe_layers): - # Register the expert weights. 
self.expert_weights.append(layer.get_expert_weights()) layer.set_eplb_state( moe_layer_idx=layer_idx, @@ -690,8 +647,37 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) - return loader.load_weights(weights) + weights_list = list(weights) + try: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights_list) + except Exception: + logger.warning("Detected quantized MoE gate layers. " + "Proceeding with automatic" + "gate layer adjustment" + "for compatibility.") + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + if hasattr(layer, "mlp") and isinstance( + layer.mlp, Qwen3MoeSparseMoeBlock): + moe_block = layer.mlp + original_gate = moe_block.gate + + new_gate = ReplicatedLinear( + self.config.hidden_size, + self.config.num_experts, + bias=False, + quant_config=self.quant_config, + ).to(device=original_gate.weight.device, + dtype=original_gate.weight.dtype) + + moe_block.gate = new_gate + + logger.info("MoE gate layers adjusted successfully." 
+ " Continuing with weight loading.") + loader = AutoWeightsLoader(self) + return loader.load_weights(weights_list) def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() From 0d5658da9f555e24657f3bf906ab364f69f579b4 Mon Sep 17 00:00:00 2001 From: JartX Date: Sat, 30 Aug 2025 17:14:48 +0200 Subject: [PATCH 2/9] Fix load AutoGPTQ and Autoround_GPTQ Models using fallback Signed-off-by: JartX --- vllm/model_executor/models/qwen3_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 79672cd26b4d..267748dd7949 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -312,8 +312,8 @@ def __init__( mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers) if (layer_idx not in mlp_only_layers) and ( - config.num_experts > 0 and - (layer_idx + 1) % config.decoder_sparse_step == 0): + config.num_experts > 0 and + (layer_idx + 1) % config.decoder_sparse_step == 0): self.mlp = Qwen3MoeSparseMoeBlock(config=config, quant_config=quant_config, prefix=f"{prefix}.mlp", From 06c041028fc08fe9ad5524e96828c968a70c317d Mon Sep 17 00:00:00 2001 From: JartX Date: Sat, 30 Aug 2025 17:27:18 +0200 Subject: [PATCH 3/9] complete warning message to informe to the user of the use of the env variable Signed-off-by: JartX --- vllm/model_executor/models/qwen3_moe.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 267748dd7949..830ce2054f3d 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -653,9 +653,13 @@ def load_weights(self, weights: Iterable[tuple[str, return loader.load_weights(weights_list) except Exception: logger.warning("Detected quantized MoE gate layers. 
" - "Proceeding with automatic" - "gate layer adjustment" - "for compatibility.") + "Proceeding with automatic " + "gate layer adjustment " + "for compatibility. " + "Please use the env variable: " + "VLLM_QUANTIZATION_FROM_AUTOROUND_GPTQ=1 " + "to avoid the adjustment and the WARNING: " + "Current vLLM config is not set.") for layer in self.model.layers: if isinstance(layer, PPMissingLayer): continue From 4e83a7ec365ed7ca80529ff4a54a33b02efb4c6c Mon Sep 17 00:00:00 2001 From: JartX Date: Sat, 30 Aug 2025 17:36:34 +0200 Subject: [PATCH 4/9] restore comments Signed-off-by: JartX --- vllm/model_executor/models/qwen3_moe.py | 40 +++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 830ce2054f3d..8c34c8a4865b 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# # Copyright 2024 The Qwen team. # Copyright 2023 The vLLM team. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. @@ -121,6 +122,7 @@ def __init__( f"Tensor parallel size {self.tp_size} is greater than " f"the number of experts {config.num_experts}.") + # Load balancing settings. vllm_config = get_current_vllm_config() eplb_config = vllm_config.parallel_config.eplb_config self.enable_eplb = enable_eplb @@ -160,9 +162,11 @@ def maybe_not_quantization(self, return quant_config def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. 
orig_shape = hidden_states.shape hidden_dim = hidden_states.shape[-1] hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits) @@ -257,6 +261,7 @@ def forward( qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + # Add qk-norm q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) q_by_head = self.q_norm(q_by_head) @@ -308,6 +313,7 @@ def __init__( dual_chunk_attention_config=dual_chunk_attention_config, ) + # `mlp_only_layers` in the config. layer_idx = extract_layer_index(prefix) mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers) @@ -335,6 +341,7 @@ def forward( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -346,6 +353,7 @@ def forward( positions=positions, hidden_states=hidden_states, ) + # Fully Connected hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) hidden_states = self.mlp(hidden_states) @@ -424,6 +432,8 @@ def forward( return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) return FusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", @@ -434,6 +444,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ + # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), @@ -441,6 +452,7 @@ 
def load_weights(self, weights: Iterable[tuple[str, ("gate_up_proj", "up_proj", 1), ] + # Skip loading extra parameters for GPTQ/modelopt models. ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale", ".v_scale", "_v_scale", ".weight_scale", "_weight_scale", ".input_scale", "_input_scale") @@ -452,19 +464,29 @@ def load_weights(self, weights: Iterable[tuple[str, for name, loaded_weight in weights: for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. if "mlp.experts" in name: continue name = name.replace(weight_name, param_name) + # Skip loading extra parameters for GPTQ/modelopt models. if name.endswith(ignore_suffixes) and name not in params_dict: continue + # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue if name.endswith("scale"): + # Remapping the name of FP8 kv-scale. name = maybe_remap_kv_scale_name(name, params_dict) if name is None: continue @@ -485,16 +507,24 @@ def load_weights(self, weights: Iterable[tuple[str, param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later is_expert_weight = True + # Do not modify `name` since the loop may continue here + # Instead, create a new variable name_mapped = name.replace(weight_name, param_name) if is_pp_missing_parameter(name_mapped, self): continue + # Skip loading extra parameters for GPTQ/modelopt models. 
if name_mapped.endswith( ignore_suffixes ) and name_mapped not in params_dict: continue param = params_dict[name_mapped] + # We should ask the weight loader to return success or not + # here since otherwise we may skip experts with other + # available replicas. weight_loader = typing.cast(Callable[..., bool], param.weight_loader) success = weight_loader(param, @@ -508,14 +538,20 @@ def load_weights(self, weights: Iterable[tuple[str, break else: if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it continue + # Skip loading extra parameters for GPTQ/modelopt models. if name.endswith( ignore_suffixes) and name not in params_dict: continue + # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue + # Remapping the name of FP8 kv-scale. if name.endswith("kv_scale"): remapped_kv_scale_name = name.replace( ".kv_scale", ".attn.kv_scale") @@ -550,6 +586,7 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, "up_proj", ], } + fall_back_to_pt_during_load = False def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -568,6 +605,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + + # Set MoE hyperparameters self.expert_weights = [] self.moe_layers: list[FusedMoE] = [] example_layer = None @@ -596,6 +635,7 @@ def set_eplb_state( logical_replica_count: torch.Tensor, ) -> None: for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. 
self.expert_weights.append(layer.get_expert_weights()) layer.set_eplb_state( moe_layer_idx=layer_idx, From 52539f503971e34aa062a6350b8ed0f80ea7f91f Mon Sep 17 00:00:00 2001 From: JartX Date: Sat, 30 Aug 2025 17:49:48 +0200 Subject: [PATCH 5/9] restore _maybe_not_quantization name Signed-off-by: JartX --- vllm/model_executor/models/qwen3_moe.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 8c34c8a4865b..e4e86f4d47d1 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -150,11 +150,11 @@ def __init__( config.hidden_size, config.num_experts, bias=False, - quant_config=self.maybe_not_quantization(quant_config), + quant_config=self._maybe_not_quantization(quant_config), prefix=f"{prefix}.gate") - def maybe_not_quantization(self, - quant_config: Optional[QuantizationConfig]): + def _maybe_not_quantization(self, + quant_config: Optional[QuantizationConfig]): if os.environ.get("VLLM_QUANTIZATION_FROM_AUTOROUND_GPTQ") == "1": return quant_config if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): From 749fdde468476d8c53ede50c5639abe624965173 Mon Sep 17 00:00:00 2001 From: JartX Date: Sat, 30 Aug 2025 18:06:42 +0200 Subject: [PATCH 6/9] changed Exception to KeyError Signed-off-by: JartX --- vllm/model_executor/models/qwen3_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index e4e86f4d47d1..ef41f04eccd3 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -691,7 +691,7 @@ def load_weights(self, weights: Iterable[tuple[str, try: loader = AutoWeightsLoader(self) return loader.load_weights(weights_list) - except Exception: + except KeyError: logger.warning("Detected quantized MoE gate layers. 
" "Proceeding with automatic " "gate layer adjustment " From 146bdc72a4386c1dc63ad92aba5ed91320fc7ae4 Mon Sep 17 00:00:00 2001 From: JartX Date: Sun, 31 Aug 2025 01:00:15 +0200 Subject: [PATCH 7/9] extract autoround key at init Signed-off-by: JartX --- vllm/model_executor/models/qwen3_moe.py | 69 ++++++------------------- 1 file changed, 17 insertions(+), 52 deletions(-) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index ef41f04eccd3..3d2a7a3292b0 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -22,7 +22,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen3MoE model compatible with HuggingFace weights.""" -import os import typing from collections.abc import Callable, Iterable from itertools import islice @@ -47,9 +46,6 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.quantization.gptq import GPTQConfig -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -122,6 +118,16 @@ def __init__( f"Tensor parallel size {self.tp_size} is greater than " f"the number of experts {config.num_experts}.") + from_autoround_gptq = False + if hasattr(config, "quantization_config"): + q_config = config.quantization_config + if (isinstance(q_config, dict) + and q_config.get("quant_method") == "gptq" + and "autoround_version" in q_config): + from_autoround_gptq = True + + gate_quant_config = quant_config if from_autoround_gptq else None + # Load balancing settings. 
vllm_config = get_current_vllm_config() eplb_config = vllm_config.parallel_config.eplb_config @@ -146,20 +152,11 @@ def __init__( prefix=f"{prefix}.experts", enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts) - self.gate = ReplicatedLinear( - config.hidden_size, - config.num_experts, - bias=False, - quant_config=self._maybe_not_quantization(quant_config), - prefix=f"{prefix}.gate") - - def _maybe_not_quantization(self, - quant_config: Optional[QuantizationConfig]): - if os.environ.get("VLLM_QUANTIZATION_FROM_AUTOROUND_GPTQ") == "1": - return quant_config - if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): - return None - return quant_config + self.gate = ReplicatedLinear(config.hidden_size, + config.num_experts, + bias=False, + quant_config=gate_quant_config, + prefix=f"{prefix}.gate") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. @@ -688,40 +685,8 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weights_list = list(weights) - try: - loader = AutoWeightsLoader(self) - return loader.load_weights(weights_list) - except KeyError: - logger.warning("Detected quantized MoE gate layers. " - "Proceeding with automatic " - "gate layer adjustment " - "for compatibility. 
" - "Please use the env variable: " - "VLLM_QUANTIZATION_FROM_AUTOROUND_GPTQ=1 " - "to avoid the adjustment and the WARNING: " - "Current vLLM config is not set.") - for layer in self.model.layers: - if isinstance(layer, PPMissingLayer): - continue - if hasattr(layer, "mlp") and isinstance( - layer.mlp, Qwen3MoeSparseMoeBlock): - moe_block = layer.mlp - original_gate = moe_block.gate - - new_gate = ReplicatedLinear( - self.config.hidden_size, - self.config.num_experts, - bias=False, - quant_config=self.quant_config, - ).to(device=original_gate.weight.device, - dtype=original_gate.weight.dtype) - - moe_block.gate = new_gate - - logger.info("MoE gate layers adjusted successfully." - " Continuing with weight loading.") - loader = AutoWeightsLoader(self) - return loader.load_weights(weights_list) + loader = AutoWeightsLoader(self) + return loader.load_weights(weights_list) def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() From 7e1eddf3abd67129d23581eef65f6174095537db Mon Sep 17 00:00:00 2001 From: JartX Date: Sun, 31 Aug 2025 01:06:40 +0200 Subject: [PATCH 8/9] restore load_weights to original Signed-off-by: JartX --- vllm/model_executor/models/qwen3_moe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 3d2a7a3292b0..42d2ced04874 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -684,9 +684,8 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - weights_list = list(weights) loader = AutoWeightsLoader(self) - return loader.load_weights(weights_list) + return loader.load_weights(weights) def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() From 66d0ab5b2406afea396e5e944d4cfed628a798d7 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Sun, 31 Aug 2025 20:11:21 
+0800 Subject: [PATCH 9/9] update quant config Signed-off-by: Isotr0py --- .../layers/quantization/gptq.py | 8 +- .../layers/quantization/gptq_marlin.py | 3 + vllm/model_executor/models/qwen3_moe.py | 76 +++++++++++-------- 3 files changed, 53 insertions(+), 34 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index f18c936bac60..2272709f9309 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -37,6 +37,7 @@ def __init__( desc_act: bool, lm_head_quantized: bool, dynamic: dict[str, dict[str, Union[int, bool]]], + autoround_version: str = "", ) -> None: # GPTQModel use `dynamic` config property to allow per module # quantization config so each module can be individually optimized. @@ -74,6 +75,9 @@ def __init__( "Currently, only 2/3/4/8-bit weight quantization is " f"supported for GPTQ, but got {self.weight_bits} bits.") + # used to identify GPTQ model quantized by autoround + self.autoround_version = autoround_version + def __repr__(self) -> str: return (f"GPTQConfig(weight_bits={self.weight_bits}, " f"group_size={self.group_size}, " @@ -108,8 +112,10 @@ def from_config(cls, config: dict[str, Any]) -> "GPTQConfig": desc_act = cls.get_from_keys(config, ["desc_act"]) lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) + autoround_version = cls.get_from_keys_or(config, ["autoround_version"], + default="") return cls(weight_bits, group_size, desc_act, lm_head_quantized, - dynamic) + dynamic, autoround_version) def get_quant_method( self, layer: torch.nn.Module, prefix: str diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 350975966668..3644d91f64e3 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -119,6 +119,9 @@ def __init__(self, weight_bits: 
int, group_size: int, desc_act: bool, self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)] + # used to identify GPTQ model quantized by autoround + self.autoround_version = full_config.get("autoround_version", "") + def __repr__(self) -> str: return (f"GPTQMarlinConfig(quant_type={self.quant_type}, " f"group_size={self.group_size}, " diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 42d2ced04874..a7e0a00350e6 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# + # Copyright 2024 The Qwen team. # Copyright 2023 The vLLM team. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. @@ -46,6 +46,9 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -109,34 +112,28 @@ def __init__( ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() + self.ep_group = get_ep_group().device_group self.ep_rank = self.ep_group.rank() self.ep_size = self.ep_group.size() self.n_routed_experts = config.num_experts + if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " f"the number of experts {config.num_experts}.") - from_autoround_gptq = False - if hasattr(config, "quantization_config"): - q_config = config.quantization_config - if (isinstance(q_config, dict) - and q_config.get("quant_method") == "gptq" - and "autoround_version" 
in q_config): - from_autoround_gptq = True - - gate_quant_config = quant_config if from_autoround_gptq else None - # Load balancing settings. vllm_config = get_current_vllm_config() eplb_config = vllm_config.parallel_config.eplb_config self.enable_eplb = enable_eplb + self.n_logical_experts = self.n_routed_experts self.n_redundant_experts = eplb_config.num_redundant_experts self.n_physical_experts = (self.n_logical_experts + self.n_redundant_experts) self.n_local_physical_experts = self.n_physical_experts // self.ep_size + self.physical_expert_start = (self.ep_rank * self.n_local_physical_experts) self.physical_expert_end = (self.physical_expert_start + @@ -152,21 +149,37 @@ def __init__( prefix=f"{prefix}.experts", enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts) - self.gate = ReplicatedLinear(config.hidden_size, - config.num_experts, - bias=False, - quant_config=gate_quant_config, - prefix=f"{prefix}.gate") + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=f"{prefix}.gate") + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + # GPTQ configs do not have a list of ignored modules, however AutoGPTQ + # seems to avoid gate quantization while AutoRound quantizes it. + # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4, + # and https://huggingface.co/jart25/Qwen3-Coder-30B-A3B-Instruct-Int4-gptq + if isinstance( + quant_config, + (GPTQConfig, + GPTQMarlinConfig)) and not quant_config.autoround_version: + return None + return quant_config def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. 
orig_shape = hidden_states.shape hidden_dim = hidden_states.shape[-1] hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits) + return final_hidden_states.view(orig_shape) @@ -219,6 +232,7 @@ def __init__( bias=qkv_bias, quant_config=quant_config, prefix=f"{prefix}.qkv_proj") + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, hidden_size, bias=False, @@ -257,17 +271,16 @@ def forward( ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - # Add qk-norm q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) q_by_head = self.q_norm(q_by_head) q = q_by_head.view(q.shape) + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) k_by_head = self.k_norm(k_by_head) k = k_by_head.view(k.shape) - q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) @@ -293,7 +306,6 @@ def __init__( dual_chunk_attention_config = getattr(config, "dual_chunk_attention_config", None) - self.self_attn = Qwen3MoeAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -345,11 +357,11 @@ def forward( else: hidden_states, residual = self.input_layernorm( hidden_states, residual) - hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) + # Fully Connected hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) @@ -362,11 +374,11 @@ class Qwen3MoeModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config - enable_eplb = 
parallel_config.enable_eplb eplb_config = parallel_config.eplb_config self.num_redundant_experts = eplb_config.num_redundant_experts @@ -374,7 +386,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.config = config - self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, @@ -390,7 +401,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) @@ -415,16 +425,13 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) - if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, "residual": residual }) - hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -456,9 +463,7 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() - expert_params_mapping = self.get_expert_mapping() - for name, loaded_weight in weights: for (param_name, weight_name, shard_id) in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). @@ -481,7 +486,6 @@ def load_weights(self, weights: Iterable[tuple[str, # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue - if name.endswith("scale"): # Remapping the name of FP8 kv-scale. 
name = maybe_remap_kv_scale_name(name, params_dict) @@ -504,14 +508,18 @@ def load_weights(self, weights: Iterable[tuple[str, param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue + # Anyway, this is an expert weight and should not be # attempted to load as other weights later is_expert_weight = True + # Do not modify `name` since the loop may continue here # Instead, create a new variable name_mapped = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name_mapped, self): continue + # Skip loading extra parameters for GPTQ/modelopt models. if name_mapped.endswith( ignore_suffixes @@ -539,15 +547,14 @@ def load_weights(self, weights: Iterable[tuple[str, # However it's not mapped locally to this rank # So we simply skip it continue + # Skip loading extra parameters for GPTQ/modelopt models. if name.endswith( ignore_suffixes) and name not in params_dict: continue - # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue - # Remapping the name of FP8 kv-scale. 
if name.endswith("kv_scale"): remapped_kv_scale_name = name.replace( @@ -565,7 +572,6 @@ def load_weights(self, weights: Iterable[tuple[str, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - loaded_params.add(name) return loaded_params @@ -605,17 +611,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Set MoE hyperparameters self.expert_weights = [] + self.moe_layers: list[FusedMoE] = [] example_layer = None for layer in self.model.layers: if isinstance(layer, PPMissingLayer): continue + assert isinstance(layer, Qwen3MoeDecoderLayer) if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock): example_layer = layer.mlp self.moe_layers.append(layer.mlp.experts) + if example_layer is None: - raise RuntimeError("No Qwen3Moe layer found in the model.layers.") + raise RuntimeError("No Qwen3MoE layer found in the model.layers.") + self.num_moe_layers = len(self.moe_layers) self.num_expert_groups = 1 self.num_shared_experts = 0