Support compressed-tensors W4A8 MoE checkpoints in GptOssModel weight loader for CPU #29315
@@ -495,6 +495,125 @@ def _load_weights_mxfp4(
        loaded_params.add(name)
        return loaded_params

    def load_per_expert_unfused_w4a8(
        self,
        nm: str,
        weight: torch.Tensor,
        params_dict: dict[str, torch.nn.Parameter],
        expert_params_mapping: list[tuple[str, str, int, str]],
    ) -> tuple[bool, str | None]:
        """Try to map/load per-expert unfused weights/bias for W4A8.
        Returns (handled, target_param_name)."""
        if "mlp.experts." not in nm:
            return (False, None)
        if not any(x in nm for x in (".gate_proj", ".up_proj", ".down_proj")):
            return (False, None)

        # Determine which tensor kind this checkpoint entry carries.
        suffix = None
        for suf in (".weight", ".bias", ".weight_scale", ".input_scale"):
            if nm.endswith(suf):
                suffix = suf.lstrip(".")
                break
        if suffix is None:
            return (False, None)

        try:
            layer_pfx, _ = nm.split("mlp.experts.", 1)
            layer_pfx = layer_pfx + "mlp.experts."
        except ValueError:
            return (False, None)

        for param_prefix, weight_prefix, expert_id, shard_id in expert_params_mapping:
            if weight_prefix not in nm:
                continue

            # choose fused target
            if param_prefix.endswith("w13_"):
                target_map = {
                    "weight": "w13_weight",
                    "weight_scale": "w13_weight_scale",
                    "bias": "w13_bias",
                    "input_scale": "w13_input_scale",
                }
            elif param_prefix.endswith("w2_"):
                target_map = {
                    "weight": "w2_weight",
                    "weight_scale": "w2_weight_scale",
                    "bias": "w2_bias",
                    "input_scale": "w2_input_scale",
                }
            else:
                continue

            tgt_suffix = target_map.get(suffix)
            if not tgt_suffix:
                continue
            target = layer_pfx + tgt_suffix
            if target not in params_dict:
                continue

            param = params_dict[target]
            wl = getattr(param, "weight_loader", None)

            if suffix == "bias":
                if callable(wl) and wl is not default_weight_loader:
                    ok = wl(
                        param,
                        weight,
                        nm,
                        shard_id=shard_id,
                        expert_id=expert_id,
                        return_success=True,
                    )
                    if ok:
                        return (True, target)

                # Fall back to copying the bias slice into the fused parameter.
                inter_size = self.config.intermediate_size
                src = weight
                if src.dim() == 2:
                    src = src.squeeze(0) if src.size(0) == 1 else src[expert_id]

                if target.endswith("w13_bias"):
                    if "gate_proj" in weight_prefix:
                        col_slice = slice(0, inter_size)
                    elif "up_proj" in weight_prefix:
                        col_slice = slice(inter_size, 2 * inter_size)
                    else:
                        return (False, None)

                    if param.data.dim() == 2:
                        param.data[expert_id, col_slice].copy_(src)
                    else:
                        param.data[col_slice].copy_(src)

                elif target.endswith("w2_bias"):
                    if param.data.dim() == 2:
                        param.data[expert_id, :].copy_(src)
                    else:
                        param.data.copy_(src)
                else:
                    return (False, None)

                return (True, target)

            # Weights/scales path
            if callable(wl) and wl is not default_weight_loader:
                ok = wl(
                    param,
                    weight,
                    nm,
                    shard_id=shard_id,
                    expert_id=expert_id,
                    return_success=True,
                )
                if ok:
                    return (True, target)
            else:
                default_weight_loader(param, weight)
                return (True, target)

        return (False, None)

    def _load_weights_other(
        self,
        ep_rank_end: int,
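For orientation, here is a small self-contained sketch (not part of the diff; the helper `fused_target` is hypothetical) of the naming convention the loader above relies on: per-expert unfused checkpoint tensors under `mlp.experts.<id>.{gate,up,down}_proj` are redirected to the fused `w13_*` / `w2_*` parameters of the layer. Bias column slices and expert rows are handled separately in the real code.

```python
# Hypothetical illustration of the name mapping performed by the loader.
FUSED_TARGETS = {
    "gate_proj": "w13_",   # gate goes into the first half of the fused gate/up param
    "up_proj": "w13_",     # up goes into the second half
    "down_proj": "w2_",
}

def fused_target(name: str) -> str | None:
    """Map e.g. '...mlp.experts.3.gate_proj.bias' -> '...mlp.experts.w13_bias'."""
    if "mlp.experts." not in name:
        return None
    prefix, rest = name.split("mlp.experts.", 1)
    parts = rest.split(".")  # e.g. ['3', 'gate_proj', 'bias']
    if len(parts) < 3 or parts[1] not in FUSED_TARGETS:
        return None
    return f"{prefix}mlp.experts.{FUSED_TARGETS[parts[1]]}{parts[2]}"

assert (
    fused_target("model.layers.0.mlp.experts.3.gate_proj.bias")
    == "model.layers.0.mlp.experts.w13_bias"
)
```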
@@ -525,11 +644,36 @@ def _load_weights_other(
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)

        # W4A8 detection (int4 weights, int8 activations)
Contributor:
Generally speaking, can we apply the logic that this PR applies in

Contributor (Author):
No, we cannot do in
        qc = getattr(self.config, "quantization_config", None)
        group0 = (qc or {}).get("config_groups", {}).get("group_0", {})
        w = group0.get("weights") or {}
        ia = group0.get("input_activations") or {}
        is_w4a8 = (w.get("num_bits") == 4) and (ia.get("num_bits") == 8)
Contributor:
We have an API to check this here: vllm/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py (line 382 in 70d5953). Would it be a good idea to re-use this API?

Contributor (Author):
Good suggestion, but no, we can't use that since it's private and tied to CompressedTensorsConfig.
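For reference, a minimal sketch of the `quantization_config` shape the detection above expects from a compressed-tensors W4A8 checkpoint (values are illustrative; only the `num_bits` fields matter for the check):

```python
# Illustrative config shape only; real checkpoints carry additional fields.
quantization_config = {
    "config_groups": {
        "group_0": {
            "weights": {"num_bits": 4, "type": "int"},
            "input_activations": {"num_bits": 8, "type": "int"},
        }
    }
}

group0 = quantization_config.get("config_groups", {}).get("group_0", {})
is_w4a8 = (
    (group0.get("weights") or {}).get("num_bits") == 4
    and (group0.get("input_activations") or {}).get("num_bits") == 8
)
assert is_w4a8
```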
        # Map per-expert unfused (gate|up|down) → fused MoE params via FusedMoE loader
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_local_experts,
        )

        for name, weight in weights:
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            # W4A8 per-expert unfused mapping
            if is_w4a8:
                handled, target = self.load_per_expert_unfused_w4a8(
                    name, weight, params_dict, expert_params_mapping
                )
                if handled:
                    if target:
                        loaded_params.add(target)
                    continue

            if ".w13_weight" in name:
                # Handle MLP gate and up projection weights
                # Extract gate and up projection parts
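As a rough illustration (the exact tuple layout is an assumption about `FusedMoE.make_expert_params_mapping`, shown here only to explain how the new loader consumes it), each mapping entry pairs a fused parameter prefix with a per-expert checkpoint substring, and the loader matches checkpoint names against that substring:

```python
# Assumed entry shape: (param_prefix, weight_prefix, expert_id, shard_id).
expert_params_mapping = [
    ("experts.w13_", "experts.0.gate_proj.", 0, "w1"),
    ("experts.w13_", "experts.0.up_proj.", 0, "w3"),
    ("experts.w2_", "experts.0.down_proj.", 0, "w2"),
]

name = "model.layers.0.mlp.experts.0.up_proj.bias"
matches = [m for m in expert_params_mapping if m[1] in name]
assert matches[0][0] == "experts.w13_"  # up_proj folds into the fused w13_* params
```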
@@ -591,12 +735,15 @@ def _load_weights_other(
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
                    if weight_loader == default_weight_loader:
                        weight_loader(param, weight)
                    else:
                        weight_loader(param, weight, shard_id)
                    loaded_params.add(name)
                    break
                else:
                    # Handle all other weights with potential renaming
@@ -605,7 +752,7 @@ def _load_weights_other(
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
                    weight_loader(param, weight)
                    loaded_params.add(name)
        return loaded_params

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
@@ -635,6 +782,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
            if hasattr(self.config, "quantization_config")
            else None
        )

        if quant_method == "mxfp4":
            return self._load_weights_mxfp4(
                ep_rank_end,
@@ -657,11 +805,16 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:

class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
    is_3d_moe_weight: bool = True
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
    packed_modules_mapping = {
        "qkv": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            ".self_attn.": ".attn.",
            ".qkv.": ".qkv_proj.",
            ".mlp.experts.experts.": ".mlp.experts.",
        },
        orig_to_new_suffix={
            ".embed_tokens.weight": ".embedding.weight",
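As a quick illustration (a standalone sketch, not part of the diff; it assumes the `orig_to_new_substr` entries are applied as plain substring replacements), the renaming rules above map HF checkpoint names onto vLLM parameter names like so:

```python
# Hypothetical re-implementation of the substring renaming shown above.
substr_rules = {
    ".self_attn.": ".attn.",
    ".qkv.": ".qkv_proj.",
    ".mlp.experts.experts.": ".mlp.experts.",
}

def rename(name: str) -> str:
    for old, new in substr_rules.items():
        name = name.replace(old, new)
    return name

assert rename("model.layers.0.self_attn.qkv.weight") == "model.layers.0.attn.qkv_proj.weight"
```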
Contributor:
It would be nice for us not to specialize these functions for specific quantization schemes. Why can't `w4a8` be an argument for this function, instead of baking it into the name/impl?

Contributor (Author):
I think it's good to have a separate function per weight format and call those specific functions from `load_weights_other`; other quantization combinations can have different weight-loading schemes, and we can add functions for those if we need any in the future.
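For what it's worth, a minimal standalone sketch of the alternative the reviewer describes (all names here are hypothetical, not part of the PR): the quantization scheme becomes an argument and dispatches to scheme-specific helpers, so adding a new scheme only means registering one more entry.

```python
from typing import Callable

PerExpertLoader = Callable[..., tuple[bool, str | None]]

def _load_w4a8(name, weight, params_dict, mapping) -> tuple[bool, str | None]:
    # would hold the body of load_per_expert_unfused_w4a8
    return (False, None)

# Hypothetical dispatch table; only "w4a8" exists in this sketch.
PER_EXPERT_LOADERS: dict[str, PerExpertLoader] = {"w4a8": _load_w4a8}

def load_per_expert_unfused(scheme, name, weight, params_dict, mapping):
    """Generic entry point: the scheme is an argument rather than part of the name."""
    loader = PER_EXPERT_LOADERS.get(scheme)
    if loader is None:
        return (False, None)
    return loader(name, weight, params_dict, mapping)
```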