diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py
index 6a92cf153321..acd30ee5f20d 100644
--- a/vllm/model_executor/models/gpt_oss.py
+++ b/vllm/model_executor/models/gpt_oss.py
@@ -495,6 +495,125 @@ def _load_weights_mxfp4(
             loaded_params.add(name)
         return loaded_params
 
+    def load_per_expert_unfused_w4a8(
+        self,
+        nm: str,
+        weight: torch.Tensor,
+        params_dict: dict[str, torch.nn.Parameter],
+        expert_params_mapping: list[tuple[str, str, int, str]],
+    ) -> tuple[bool, str | None]:
+        """Try to map/load per-expert unfused weights/bias for W4A8.
+        Returns (handled, target_param_name)."""
+        if "mlp.experts." not in nm:
+            return (False, None)
+        if not any(x in nm for x in (".gate_proj", ".up_proj", ".down_proj")):
+            return (False, None)
+
+        suffix = None
+        for suf in (".weight", ".bias", ".weight_scale", ".input_scale"):
+            if nm.endswith(suf):
+                suffix = suf.lstrip(".")
+                break
+        if suffix is None:
+            return (False, None)
+
+        try:
+            layer_pfx, _ = nm.split("mlp.experts.", 1)
+            layer_pfx = layer_pfx + "mlp.experts."
+        except ValueError:
+            return (False, None)
+
+        for param_prefix, weight_prefix, expert_id, shard_id in expert_params_mapping:
+            if weight_prefix not in nm:
+                continue
+
+            # choose fused target
+            if param_prefix.endswith("w13_"):
+                target_map = {
+                    "weight": "w13_weight",
+                    "weight_scale": "w13_weight_scale",
+                    "bias": "w13_bias",
+                    "input_scale": "w13_input_scale",
+                }
+            elif param_prefix.endswith("w2_"):
+                target_map = {
+                    "weight": "w2_weight",
+                    "weight_scale": "w2_weight_scale",
+                    "bias": "w2_bias",
+                    "input_scale": "w2_input_scale",
+                }
+            else:
+                continue
+
+            tgt_suffix = target_map.get(suffix)
+            if not tgt_suffix:
+                continue
+            target = layer_pfx + tgt_suffix
+            if target not in params_dict:
+                continue
+
+            param = params_dict[target]
+            wl = getattr(param, "weight_loader", None)
+
+            if suffix == "bias":
+                if callable(wl) and wl is not default_weight_loader:
+                    ok = wl(
+                        param,
+                        weight,
+                        nm,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                        return_success=True,
+                    )
+                    if ok:
+                        return (True, target)
+
+                inter_size = self.config.intermediate_size
+                src = weight
+                if src.dim() == 2:
+                    src = src.squeeze(0) if src.size(0) == 1 else src[expert_id]
+
+                if target.endswith("w13_bias"):
+                    if "gate_proj" in weight_prefix:
+                        col_slice = slice(0, inter_size)
+                    elif "up_proj" in weight_prefix:
+                        col_slice = slice(inter_size, 2 * inter_size)
+                    else:
+                        return (False, None)
+
+                    if param.data.dim() == 2:
+                        param.data[expert_id, col_slice].copy_(src)
+                    else:
+                        param.data[col_slice].copy_(src)
+
+                elif target.endswith("w2_bias"):
+                    if param.data.dim() == 2:
+                        param.data[expert_id, :].copy_(src)
+                    else:
+                        param.data.copy_(src)
+                else:
+                    return (False, None)
+
+                return (True, target)
+
+            # Weights/scales path
+            if callable(wl) and wl is not default_weight_loader:
+                ok = wl(
+                    param,
+                    weight,
+                    nm,
+                    shard_id=shard_id,
+                    expert_id=expert_id,
+                    return_success=True,
+                )
+                if ok:
+                    return (True, target)
+            else:
+                default_weight_loader(param, weight)
+                return (True, target)
+
+        return (False, None)
+
     def _load_weights_other(
         self,
         ep_rank_end: int,
@@ -525,11 +644,36 @@ def _load_weights_other(
         tp_rank_start = tp_rank * per_rank_intermediate_size
         tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)
 
+        # W4A8 detection (int4 weights, int8 activations)
+        qc = getattr(self.config, "quantization_config", None)
+        group0 = (qc or {}).get("config_groups", {}).get("group_0", {})
+        w = group0.get("weights") or {}
+        ia = group0.get("input_activations") or {}
+        is_w4a8 = (w.get("num_bits") == 4) and (ia.get("num_bits") == 8)
+
+        # Map per-expert unfused (gate|up|down) → fused MoE params via FusedMoE loader
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_local_experts,
+        )
+
         for name, weight in weights:
             # Skip layers on other devices.
             if is_pp_missing_parameter(name, self):
                 continue
 
+            # W4A8 per-expert unfused mapping
+            if is_w4a8:
+                handled, target = self.load_per_expert_unfused_w4a8(
+                    name, weight, params_dict, expert_params_mapping
+                )
+                if handled:
+                    if target:
+                        loaded_params.add(target)
+                    continue
+
             if ".w13_weight" in name:
                 # Handle MLP gate and up projection weights
                 # Extract gate and up projection parts
@@ -591,12 +735,15 @@ def _load_weights_other(
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
+                if name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 if weight_loader == default_weight_loader:
                     weight_loader(param, weight)
                 else:
                     weight_loader(param, weight, shard_id)
+                loaded_params.add(name)
                 break
             else:
                 # Handle all other weights with potential renaming
@@ -605,7 +752,7 @@ def _load_weights_other(
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, weight)
-            loaded_params.add(name)
+                loaded_params.add(name)
         return loaded_params
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
@@ -635,6 +782,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             if hasattr(self.config, "quantization_config")
             else None
         )
+
         if quant_method == "mxfp4":
             return self._load_weights_mxfp4(
                 ep_rank_end,
@@ -657,11 +805,16 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
     is_3d_moe_weight: bool = True
 
-    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
+    packed_modules_mapping = {
+        "qkv": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"],
+    }
 
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_substr={
             ".self_attn.": ".attn.",
+            ".qkv.": ".qkv_proj.",
+            ".mlp.experts.experts.": ".mlp.experts.",
         },
         orig_to_new_suffix={
             ".embed_tokens.weight": ".embedding.weight",