157 changes: 155 additions & 2 deletions vllm/model_executor/models/gpt_oss.py
@@ -495,6 +495,125 @@ def _load_weights_mxfp4(
loaded_params.add(name)
return loaded_params

def load_per_expert_unfused_w4a8(

Contributor:
It would be nice for us not to specialize these functions for specific quantization schemes.
Why can't W4A8 be an argument for this function, instead of baking it into the name/impl?

Contributor (author):

I think it's good to have a separate function for each weight format, and then call that specific function from _load_weights_other. Since other combinations can each have a different weight-loading scheme, we can add functions for them if we ever need to in the future.
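
A minimal sketch of the dispatch described above, not part of this PR: _load_weights_other looks up a scheme-specific handler and falls through when none is registered. Only the W4A8 handler exists in this PR; the registry function itself and any other scheme keys are hypothetical.

```python
from typing import Callable

# Hypothetical registry: only "w4a8" maps to a real method from this PR
# (load_per_expert_unfused_w4a8); other schemes would register here if they
# ever need their own per-expert unfused loading logic.
def dispatch_per_expert_unfused(model, scheme: str, name, weight,
                                params_dict, expert_params_mapping):
    handlers: dict[str, Callable] = {
        "w4a8": model.load_per_expert_unfused_w4a8,
    }
    handler = handlers.get(scheme)
    if handler is None:
        return (False, None)  # no special per-expert handling for this scheme
    return handler(name, weight, params_dict, expert_params_mapping)
```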

self,
nm: str,
weight: torch.Tensor,
params_dict: dict[str, torch.nn.Parameter],
expert_params_mapping: list[tuple[str, str, int, str]],
) -> tuple[bool, str | None]:
"""Try to map/load per-expert unfused weights/bias for W4A8.
Returns (handled, target_param_name)."""
if "mlp.experts." not in nm:
return (False, None)
if not any(x in nm for x in (".gate_proj", ".up_proj", ".down_proj")):
return (False, None)

suffix = None
for suf in (".weight", ".bias", ".weight_scale", ".input_scale"):
if nm.endswith(suf):
suffix = suf.lstrip(".")
break
if suffix is None:
return (False, None)

try:
layer_pfx, _ = nm.split("mlp.experts.", 1)
layer_pfx = layer_pfx + "mlp.experts."
except ValueError:
return (False, None)

for param_prefix, weight_prefix, expert_id, shard_id in expert_params_mapping:
if weight_prefix not in nm:
continue

# choose fused target
if param_prefix.endswith("w13_"):
target_map = {
"weight": "w13_weight",
"weight_scale": "w13_weight_scale",
"bias": "w13_bias",
"input_scale": "w13_input_scale",
}
elif param_prefix.endswith("w2_"):
target_map = {
"weight": "w2_weight",
"weight_scale": "w2_weight_scale",
"bias": "w2_bias",
"input_scale": "w2_input_scale",
}
else:
continue

tgt_suffix = target_map.get(suffix)
if not tgt_suffix:
continue
target = layer_pfx + tgt_suffix
if target not in params_dict:
continue

param = params_dict[target]
wl = getattr(param, "weight_loader", None)

if suffix == "bias":
if callable(wl) and wl is not default_weight_loader:
ok = wl(
param,
weight,
nm,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if ok:
return (True, target)

inter_size = self.config.intermediate_size
src = weight
if src.dim() == 2:
src = src.squeeze(0) if src.size(0) == 1 else src[expert_id]

if target.endswith("w13_bias"):
if "gate_proj" in weight_prefix:
col_slice = slice(0, inter_size)
elif "up_proj" in weight_prefix:
col_slice = slice(inter_size, 2 * inter_size)
else:
return (False, None)

if param.data.dim() == 2:
param.data[expert_id, col_slice].copy_(src)
else:
param.data[col_slice].copy_(src)

elif target.endswith("w2_bias"):
if param.data.dim() == 2:
param.data[expert_id, :].copy_(src)
else:
param.data.copy_(src)
else:
return (False, None)

return (True, target)

# Weights/scales path
if callable(wl) and wl is not default_weight_loader:
ok = wl(
param,
weight,
nm,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if ok:
return (True, target)
else:
default_weight_loader(param, weight)
return (True, target)

return (False, None)

def _load_weights_other(
self,
ep_rank_end: int,
@@ -525,11 +644,36 @@ def _load_weights_other(
tp_rank_start = tp_rank * per_rank_intermediate_size
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)

# W4A8 detection (int4 weights, int8 activations)

Contributor (@fadara01, Nov 26, 2025):
Generally speaking, can we apply the logic that this PR adds in process_weights_after_loading, instead of needing to change the modeling file? If not, then why not?

Contributor (author):

No, we cannot do this in process_weights_after_loading, because the weights first need to be loaded into the w13_weight tensors before any post-processing can run. Here we receive the gate, up, and down tensors as separate per-expert weights, and we fuse them while loading so the correct weights end up in the fused parameters.
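
A minimal sketch of the fusion the author describes, not part of this PR: per-expert gate_proj and up_proj tensors are packed into one fused w13 tensor at load time, with gate rows first and up rows second, mirroring the col_slice logic in load_per_expert_unfused_w4a8 above. The shapes below are illustrative assumptions.

```python
import torch

# Illustrative shapes only: 2 experts, hidden size 8, intermediate size 4.
num_experts, hidden_size, intermediate_size = 2, 8, 4

gate = torch.randn(num_experts, intermediate_size, hidden_size)  # per-expert gate_proj
up = torch.randn(num_experts, intermediate_size, hidden_size)    # per-expert up_proj

# Fused w13 tensor: rows [0, intermediate_size) hold gate_proj,
# rows [intermediate_size, 2 * intermediate_size) hold up_proj,
# matching slice(0, inter_size) / slice(inter_size, 2 * inter_size) above.
w13 = torch.empty(num_experts, 2 * intermediate_size, hidden_size)
w13[:, :intermediate_size] = gate
w13[:, intermediate_size:] = up
```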

qc = getattr(self.config, "quantization_config", None)
group0 = (qc or {}).get("config_groups", {}).get("group_0", {})
w = group0.get("weights") or {}
ia = group0.get("input_activations") or {}
is_w4a8 = (w.get("num_bits") == 4) and (ia.get("num_bits") == 8)

Contributor (@nikhil-arm, Nov 26, 2025):
We have an API to check this here:

Would it be a good idea to re-use that API?

Contributor (author):

Good suggestion, but no, we can't use that since it's private and tied to CompressedTensorsConfig.
Here we just read the flags from the config file and make the decision, to keep it simpler.
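
A minimal sketch of the flag check described above, not part of this PR: the dict below is a trimmed illustration showing only the quantization_config keys that the detection in this diff reads, not a complete compressed-tensors config.

```python
# Trimmed, illustrative quantization_config: only the keys read by the
# W4A8 detection in this diff are shown.
quantization_config = {
    "config_groups": {
        "group_0": {
            "weights": {"num_bits": 4},
            "input_activations": {"num_bits": 8},
        }
    }
}

group0 = quantization_config.get("config_groups", {}).get("group_0", {})
w = group0.get("weights") or {}
ia = group0.get("input_activations") or {}
is_w4a8 = (w.get("num_bits") == 4) and (ia.get("num_bits") == 8)
print(is_w4a8)  # True
```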


# Map per-expert unfused (gate|up|down) → fused MoE params via FusedMoE loader
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_local_experts,
)

for name, weight in weights:
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue

# W4A8 per-expert unfused mapping
if is_w4a8:
handled, target = self.load_per_expert_unfused_w4a8(
name, weight, params_dict, expert_params_mapping
)
if handled:
if target:
loaded_params.add(target)
continue

if ".w13_weight" in name:
# Handle MLP gate and up projection weights
# Extract gate and up projection parts
@@ -591,12 +735,15 @@ def _load_weights_other(
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
if weight_loader == default_weight_loader:
weight_loader(param, weight)
else:
weight_loader(param, weight, shard_id)
loaded_params.add(name)
break
else:
# Handle all other weights with potential renaming
@@ -605,7 +752,7 @@ def _load_weights_other(
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, weight)
loaded_params.add(name)
loaded_params.add(name)
return loaded_params

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
@@ -635,6 +782,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
if hasattr(self.config, "quantization_config")
else None
)

if quant_method == "mxfp4":
return self._load_weights_mxfp4(
ep_rank_end,
@@ -657,11 +805,16 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:

class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
is_3d_moe_weight: bool = True
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
packed_modules_mapping = {
"qkv": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
}

hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
".self_attn.": ".attn.",
".qkv.": ".qkv_proj.",
".mlp.experts.experts.": ".mlp.experts.",
},
orig_to_new_suffix={
".embed_tokens.weight": ".embedding.weight",