40 changes: 29 additions & 11 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -65,6 +65,25 @@ def _is_col_major(x: torch.Tensor) -> bool:
return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m


def _wrap_parameter_or_copy(layer: torch.nn.Module, name: str,
@kylesayrs (Contributor), Nov 11, 2025:
As of torch 2.8 (and torch 2.7 when using this PR), torch.compile supports Parameter subclasses. Given that, all that should be required is updating the existing parameter with the new (possibly padded) weight; a new Parameter need not be created.
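A minimal sketch of the in-place update described here, assuming layer.weight already exists as a (possibly subclassed) Parameter; update_weight_in_place and padded_weight are illustrative names, not part of the PR:

import torch

def update_weight_in_place(layer: torch.nn.Module, padded_weight: torch.Tensor) -> None:
    # Rebind the existing parameter's storage instead of constructing a new
    # plain Parameter, so attributes such as weight_loader and any Parameter
    # subclass are preserved (per the comment above, torch >= 2.8 can compile
    # Parameter subclasses).
    layer.weight.data = padded_weight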

Agreed. In fact, the second branch in this code never gets triggered. All that is needed is to clean up fp8.py by removing the param = Parameter(...) statements that drop the weight loaders. (A short check illustrating why the second branch is unreachable follows the function below.)

weight: torch.Tensor):
layer_weight = getattr(layer, name)
if isinstance(layer_weight, Parameter):
# If it is already a Parameter, assume it already has the right shape and
# copy the new data in place so the tensor pointer stays unchanged for CUDA Graphs.
layer_weight.copy_(weight)
else:
# torch.compile() cannot handle Parameter subclasses, so re-wrap the tensor
# in a plain Parameter to keep the layer compatible with torch.compile.
param = Parameter(weight, requires_grad=False)
if hasattr(layer_weight, "weight_loader"):
# Keep the weight_loader attribute so the weight can still be
# loaded correctly during later weight updates.
param.weight_loader = layer_weight.weight_loader
setattr(layer, name, param)
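For reference, a short check of the point raised in the reply above: Parameter subclasses still pass isinstance(..., Parameter), so the else branch only runs for attributes that are not Parameters at all. MyWeightParameter below is a hypothetical stand-in, not a vLLM class:

import torch
from torch.nn import Parameter

class MyWeightParameter(Parameter):
    # stand-in for a Parameter subclass carrying a weight_loader
    pass

w = MyWeightParameter(torch.empty(2, 2))
assert isinstance(w, Parameter)  # True, so the copy_ branch above is taken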


class Fp8Config(QuantizationConfig):
"""Config class for FP8."""

@@ -387,10 +406,9 @@ def process_weights_after_loading(self, layer: Module) -> None:

weight = self._maybe_pad_weight(weight)

# Torch.compile cannot use Parameter subclasses.
layer.weight = Parameter(weight, requires_grad=False)
layer.weight_scale_inv = Parameter(weight_scale_inv,
requires_grad=False)
_wrap_parameter_or_copy(layer, "weight", weight)
_wrap_parameter_or_copy(layer, "weight_scale_inv",
weight_scale_inv)

# If checkpoint not serialized fp8, quantize the weights.
elif not self.quant_config.is_checkpoint_fp8_serialized:
@@ -740,13 +758,13 @@ def process_weights_after_loading(self, layer: Module) -> None:
w2_weight = layer.w2_weight
w2_weight_scale_inv = layer.w2_weight_scale_inv

# torch.compile() cannot use Parameter subclasses.
layer.w13_weight = Parameter(w13_weight, requires_grad=False)
layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
requires_grad=False)
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
requires_grad=False)
_wrap_parameter_or_copy(layer, "w13_weight", w13_weight)
_wrap_parameter_or_copy(layer, "w13_weight_scale_inv",
w13_weight_scale_inv)
_wrap_parameter_or_copy(layer, "w2_weight", w2_weight)
_wrap_parameter_or_copy(layer, "w2_weight_scale_inv",
w2_weight_scale_inv)
Comment on lines +764 to +766
Contributor
critical

In the else branch of the conditional starting at line 733, the variables w2_weight and w2_weight_scale_inv are assigned torch.nn.Parameter objects on lines 755-756, instead of their underlying tensor data. Consequently, these calls to _wrap_parameter_or_copy become no-ops due to self-copying, which is likely not the intended behavior and can lead to incorrect weight updates.

This is inconsistent with how w13_weight is handled in the same block, which correctly uses .data. To fix this, you should modify lines 755-756 to extract the tensor data, like so:

# In vllm/model_executor/layers/quantization/fp8.py, lines 755-756
w2_weight = layer.w2_weight.data
w2_weight_scale_inv = layer.w2_weight_scale_inv.data

Since the fix is outside the diff, I'm placing this comment here to highlight this critical issue.
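To make the self-copy point concrete, a standalone illustration (not taken from the PR):

import torch
from torch.nn import Parameter

p = Parameter(torch.zeros(2), requires_grad=False)
p.copy_(p)  # source and destination share the same storage, so nothing changes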


if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
7 changes: 7 additions & 0 deletions vllm/model_executor/layers/quantization/kv_cache.py
@@ -48,6 +48,13 @@ def apply(self, layer: torch.nn.Module) -> torch.Tensor:
f"{self.__class__.__name__}.apply should not be called.")

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# A weight update may find these attributes missing; create them if not present.
if not hasattr(layer, "q_scale"):
assert not hasattr(layer, "k_scale")
assert not hasattr(layer, "v_scale")
assert not hasattr(layer, "prob_scale")
Comment on lines +53 to +55
Contributor
high

These assertions could make the code brittle. If another part of the codebase modifies these attributes partially (e.g., removes q_scale but not k_scale), these assertions will fail. The main goal here is to ensure all weights are present if any are missing. Simply checking for q_scale and then creating all weights is sufficient and more robust against unforeseen state changes.
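A sketch of the simpler guard this comment describes (an assumed alternative, not part of the diff):

# Create the full set of scale attributes whenever the leading one is absent,
# without asserting anything about the others.
if not hasattr(layer, "q_scale"):
    self.create_weights(layer)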

self.create_weights(layer)

# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
# No need to process kv scales after loading if we are going to