From cb707077b0df8bb02adadc7eda63c97a8c6d6c76 Mon Sep 17 00:00:00 2001
From: weixiao-huang
Date: Sun, 7 Sep 2025 14:56:43 +0800
Subject: [PATCH] [BugFix] use _wrap_parameter_or_copy instead of using
 Parameter and add missing scale attributes

Signed-off-by: huangweixiao
---
 .../model_executor/layers/quantization/fp8.py | 40 ++++++++++++++-----
 .../layers/quantization/kv_cache.py           |  7 ++++
 2 files changed, 36 insertions(+), 11 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 65e0b7062153..46f1885ccd27 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -65,6 +65,25 @@ def _is_col_major(x: torch.Tensor) -> bool:
     return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m
 
 
+def _wrap_parameter_or_copy(layer: torch.nn.Module, name: str,
+                            weight: torch.Tensor):
+    layer_weight = getattr(layer, name)
+    if type(layer_weight) is Parameter:
+        # Already a plain Parameter, assumed to have the right shape:
+        # copy in place so the data pointer stays unchanged and any
+        # captured CUDA Graph remains valid.
+        layer_weight.copy_(weight)
+    else:
+        # torch.compile() cannot handle Parameter subclasses, so replace
+        # the subclass with a plain Parameter wrapping the same data.
+        param = Parameter(weight, requires_grad=False)
+        if hasattr(layer_weight, "weight_loader"):
+            # Preserve weight_loader so the weight can still be
+            # loaded correctly during a later weight update.
+            param.weight_loader = layer_weight.weight_loader
+        setattr(layer, name, param)
+
+
 class Fp8Config(QuantizationConfig):
     """Config class for FP8."""
 
@@ -387,10 +406,9 @@ def process_weights_after_loading(self, layer: Module) -> None:
 
             weight = self._maybe_pad_weight(weight)
 
-            # Torch.compile cannot use Parameter subclasses.
-            layer.weight = Parameter(weight, requires_grad=False)
-            layer.weight_scale_inv = Parameter(weight_scale_inv,
-                                               requires_grad=False)
+            _wrap_parameter_or_copy(layer, "weight", weight)
+            _wrap_parameter_or_copy(layer, "weight_scale_inv",
+                                    weight_scale_inv)
 
         # If checkpoint not serialized fp8, quantize the weights.
         elif not self.quant_config.is_checkpoint_fp8_serialized:
@@ -740,13 +758,13 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 w2_weight = layer.w2_weight
                 w2_weight_scale_inv = layer.w2_weight_scale_inv
 
-            # torch.compile() cannot use Parameter subclasses.
-            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
-            layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
-                                                   requires_grad=False)
-            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
-            layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
-                                                  requires_grad=False)
+            _wrap_parameter_or_copy(layer, "w13_weight", w13_weight)
+            _wrap_parameter_or_copy(layer, "w13_weight_scale_inv",
+                                    w13_weight_scale_inv)
+            _wrap_parameter_or_copy(layer, "w2_weight", w2_weight)
+            _wrap_parameter_or_copy(layer, "w2_weight_scale_inv",
+                                    w2_weight_scale_inv)
+
             if self.rocm_aiter_moe_enabled:
                 # reshaping weights is required for aiter moe kernel.
                 shuffled_w13, shuffled_w2 = shuffle_weights(
diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py
index e5604670fb4c..0483e9b324ff 100644
--- a/vllm/model_executor/layers/quantization/kv_cache.py
+++ b/vllm/model_executor/layers/quantization/kv_cache.py
@@ -48,6 +48,13 @@ def apply(self, layer: torch.nn.Module) -> torch.Tensor:
             f"{self.__class__.__name__}.apply should not be called.")
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # Weight updates may arrive without these attributes; create them.
+        if not hasattr(layer, "q_scale"):
+            assert not hasattr(layer, "k_scale")
+            assert not hasattr(layer, "v_scale")
+            assert not hasattr(layer, "prob_scale")
+            self.create_weights(layer)
+
         # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
         # regardless whether the kv-scale is available in the checkpoint.
         # No need to process kv scales after loading if we are going to
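
Note (not part of the patch): a minimal, self-contained sketch of the
copy-versus-wrap behavior that _wrap_parameter_or_copy implements. The
wrap_or_copy function is a local mirror of the helper's logic; SubclassParam
and the weight_loader stub are hypothetical stand-ins for vLLM's Parameter
subclasses:

    import torch
    from torch.nn import Parameter

    class SubclassParam(Parameter):
        """Hypothetical stand-in for vLLM's Parameter subclasses."""

    def wrap_or_copy(layer, name, weight):
        # Same logic as the patch's _wrap_parameter_or_copy.
        layer_weight = getattr(layer, name)
        if type(layer_weight) is Parameter:
            # In-place copy: the data pointer stays stable.
            layer_weight.copy_(weight)
        else:
            param = Parameter(weight, requires_grad=False)
            if hasattr(layer_weight, "weight_loader"):
                param.weight_loader = layer_weight.weight_loader
            setattr(layer, name, param)

    layer = torch.nn.Module()
    sub = SubclassParam(torch.zeros(4, 4), requires_grad=False)
    sub.weight_loader = lambda *args: None  # hypothetical loader stub
    layer.weight = sub

    # First call: the subclass becomes a plain (torch.compile-friendly)
    # Parameter, and weight_loader is carried over.
    wrap_or_copy(layer, "weight", torch.ones(4, 4))
    assert type(layer.weight) is Parameter
    assert hasattr(layer.weight, "weight_loader")

    # Second call (e.g. an in-place weight update): the existing storage
    # is reused, so a CUDA Graph that captured this pointer stays valid.
    ptr = layer.weight.data_ptr()
    wrap_or_copy(layer, "weight", torch.full((4, 4), 2.0))
    assert layer.weight.data_ptr() == ptr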
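
Note (not part of the patch): the kv_cache.py guard covers repeated calls to
process_weights_after_loading, which can happen when weights are reloaded at
runtime after the scale attributes were already consumed. A sketch of the
recreate-if-missing pattern; ensure_scales is a hypothetical mirror of the
guard, and the -1.0 value is the assumed "no scale loaded yet" sentinel used
by create_weights:

    import torch
    from torch.nn import Parameter

    SCALE_NAMES = ("q_scale", "k_scale", "v_scale", "prob_scale")

    def ensure_scales(layer: torch.nn.Module) -> None:
        # If q_scale is gone, the whole set should be gone, so recreate
        # all of them before reprocessing.
        if not hasattr(layer, "q_scale"):
            for name in SCALE_NAMES[1:]:
                assert not hasattr(layer, name)
            for name in SCALE_NAMES:
                # -1.0 marks "no scale loaded from the checkpoint".
                setattr(layer, name,
                        Parameter(torch.tensor(-1.0), requires_grad=False))

    layer = torch.nn.Module()
    ensure_scales(layer)   # first pass: creates all four scales
    del layer.q_scale, layer.k_scale, layer.v_scale, layer.prob_scale
    ensure_scales(layer)   # after a weight update: recreates them
    assert float(layer.q_scale) == -1.0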