From 596a494c7f53af5307d7f3f576b7ae962312df68 Mon Sep 17 00:00:00 2001
From: Matthew Bonanni
Date: Fri, 24 Oct 2025 17:21:08 -0400
Subject: [PATCH 1/4] fix scales calculation

Signed-off-by: Matthew Bonanni
---
 vllm/attention/layer.py | 49 +++++++++++++++++++++++++++++------------
 1 file changed, 35 insertions(+), 14 deletions(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index bd38f3679ece..74ac5e99527e 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -637,6 +637,41 @@ def __init__(
             block_size = 16
             calculate_kv_scales = False
         self.kv_cache_dtype = kv_cache_dtype
+        self.calculate_kv_scales = calculate_kv_scales
+        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
+        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
+        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
+        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
+
+        # We also keep q/k/v_scale on host (cpu) memory for attention
+        # backends that require the scales to be on host instead of on device.
+        # e.g. Flashinfer
+        self._q_scale_float = 1.0
+        self._k_scale_float = 1.0
+        self._v_scale_float = 1.0
+
+        # The output scale on host memory. This should be the input scale of
+        # the quant op after this attention layer.
+        self._o_scale_float: float | None = None
+        quant_method = (
+            quant_config.get_quant_method(self, prefix=prefix) if quant_config else None
+        )
+        if quant_method is not None and not isinstance(
+            quant_method, UnquantizedLinearMethod
+        ):
+            assert isinstance(quant_method, BaseKVCacheMethod)
+            # TODO (mgoin): kv cache dtype should be specified in the FP8
+            # checkpoint config and become the "auto" behavior
+            if self.kv_cache_dtype == "fp8_e5m2":
+                raise ValueError(
+                    "fp8_e5m2 kv-cache is not supported with fp8 checkpoints."
+                )
+            # If quantization is enabled, we make "k_scale" and "v_scale"
+            # parameters so that it can be loaded from the model checkpoint.
+            # The k/v_scale will then be converted back to native float32
+            # values after weight loading.
+            self.quant_method = quant_method
+            self.quant_method.create_weights(self)
 
         dtype = torch.get_default_dtype()
         self.attn_backend = get_attn_backend(
@@ -685,20 +720,6 @@ def __init__(
             )
         ]
 
-        # Align with Attention's scale attributes for MLA backends.
-
-        self.calculate_kv_scales = calculate_kv_scales
-        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
-
-        # Host-side mirrors used by some attention backends
-        self._q_scale_float = 1.0
-        self._k_scale_float = 1.0
-        self._v_scale_float = 1.0
-        self._o_scale_float: float | None = None
-
         self.use_sparse = use_sparse
 
         # Initialize q/k/v range constants.

From d4e2794e41736a25e31396f7ec53a79223c9b29e Mon Sep 17 00:00:00 2001
From: Matthew Bonanni
Date: Fri, 24 Oct 2025 17:27:06 -0400
Subject: [PATCH 2/4] copy comments over

Signed-off-by: Matthew Bonanni
---
 vllm/attention/layer.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 74ac5e99527e..f521d250d72e 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -636,10 +636,18 @@ def __init__(
             kv_cache_dtype = "auto"
             block_size = 16
             calculate_kv_scales = False
+
+        # The default k/v_scale is set to 1.0. This is ignored
+        # when kv-cache is not fp8, and should be used with
+        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
+        # expect the pre-quantized k/v_scale to be loaded along
+        # with the model weights.
         self.kv_cache_dtype = kv_cache_dtype
         self.calculate_kv_scales = calculate_kv_scales
         self._k_scale = torch.tensor(1.0, dtype=torch.float32)
         self._v_scale = torch.tensor(1.0, dtype=torch.float32)
+        # FlashAttn doesn't support quantizing the kv-cache only
+        # but requires q to be quantized as well.
         self._q_scale = torch.tensor(1.0, dtype=torch.float32)
         self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
 

From 59d47568c62ef541c50c4b58355d6f1a65d71c7c Mon Sep 17 00:00:00 2001
From: Matthew Bonanni
Date: Fri, 24 Oct 2025 20:20:25 -0400
Subject: [PATCH 3/4] refactor to helper

Signed-off-by: Matthew Bonanni
---
 vllm/attention/layer.py | 158 ++++++++++++++++++----------------------
 1 file changed, 72 insertions(+), 86 deletions(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index f521d250d72e..6bfe65b6fc36 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -123,6 +123,71 @@ def maybe_get_vit_flash_attn_backend(
     return attn_backend, flash_attn_varlen_func
 
 
+def _init_kv_cache_quant(
+    layer: nn.Module,
+    quant_config: QuantizationConfig | None,
+    prefix: str,
+    kv_cache_dtype: str,
+    calculate_kv_scales: bool,
+) -> None:
+    """Initializes KV cache scaling factors and quantization method.
+
+    This helper function sets up the KV cache quantization attributes that are
+    shared between Attention and MLAAttention layers. It initializes scale
+    tensors for query, key, value, and probability, and configures the
+    quantization method if applicable.
+
+    Args:
+        layer: The attention layer instance to initialize.
+        quant_config: Optional quantization configuration.
+        prefix: Layer name prefix for quantization method lookup.
+        kv_cache_dtype: The KV cache data type string.
+        calculate_kv_scales: Whether to calculate KV scales dynamically.
+    """
+    # The default k/v_scale is set to 1.0. This is ignored
+    # when kv-cache is not fp8, and should be used with
+    # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
+    # expect the pre-quantized k/v_scale to be loaded along
+    # with the model weights.
+    layer.kv_cache_dtype = kv_cache_dtype
+    layer.calculate_kv_scales = calculate_kv_scales
+    layer._k_scale = torch.tensor(1.0, dtype=torch.float32)
+    layer._v_scale = torch.tensor(1.0, dtype=torch.float32)
+    # FlashAttn doesn't support quantizing the kv-cache only
+    # but requires q to be quantized as well.
+    layer._q_scale = torch.tensor(1.0, dtype=torch.float32)
+    layer._prob_scale = torch.tensor(1.0, dtype=torch.float32)
+
+    # We also keep q/k/v_scale on host (cpu) memory for attention
+    # backends that require the scales to be on host instead of on device.
+    # e.g. Flashinfer
+    layer._q_scale_float = 1.0
+    layer._k_scale_float = 1.0
+    layer._v_scale_float = 1.0
+
+    # The output scale on host memory. This should be the input scale of
+    # the quant op after this attention layer.
+    layer._o_scale_float = None
+
+    quant_method = (
+        quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None
+    )
+    if quant_method is not None and not isinstance(
+        quant_method, UnquantizedLinearMethod
+    ):
+        assert isinstance(quant_method, BaseKVCacheMethod)
+        # TODO (mgoin): kv cache dtype should be specified in the FP8
+        # checkpoint config and become the "auto" behavior
+        if kv_cache_dtype == "fp8_e5m2":
+            raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")
+        # If quantization is enabled, we make "k_scale" and "v_scale"
+        # parameters so that it can be loaded from the model checkpoint.
+        # The k/v_scale will then be converted back to native float32
+        # values after weight loading.
+        layer.quant_method = quant_method
+        layer.quant_method.create_weights(layer)
+
+
 class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.
 
@@ -184,30 +249,10 @@ def __init__(
             f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
         )
 
-        # The default k/v_scale is set to 1.0. This is ignored
-        # when kv-cache is not fp8, and should be used with
-        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
-        # expect the pre-quantized k/v_scale to be loaded along
-        # with the model weights.
-        self.kv_cache_dtype = kv_cache_dtype
-        self.calculate_kv_scales = calculate_kv_scales
-        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
-        # FlashAttn doesn't support quantizing the kv-cache only
-        # but requires q to be quantized as well.
-        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
-
-        # We also keep q/k/v_scale on host (cpu) memory for attention
-        # backends that require the scales to be on host instead of on device.
-        # e.g. Flashinfer
-        self._q_scale_float = 1.0
-        self._k_scale_float = 1.0
-        self._v_scale_float = 1.0
-
-        # The output scale on host memory. This should be the input scale of
-        # the quant op after this attention layer.
-        self._o_scale_float: float | None = None
+        # Initialize KV cache quantization attributes
+        _init_kv_cache_quant(
+            self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
+        )
 
         self.num_heads = num_heads
         self.head_size = head_size
@@ -215,26 +260,6 @@ def __init__(
         self.sliding_window = sliding_window
         self.has_sink = extra_impl_args.get("sinks") is not None
 
-        quant_method = (
-            quant_config.get_quant_method(self, prefix=prefix) if quant_config else None
-        )
-        if quant_method is not None and not isinstance(
-            quant_method, UnquantizedLinearMethod
-        ):
-            assert isinstance(quant_method, BaseKVCacheMethod)
-            # TODO (mgoin): kv cache dtype should be specified in the FP8
-            # checkpoint config and become the "auto" behavior
-            if self.kv_cache_dtype == "fp8_e5m2":
-                raise ValueError(
-                    "fp8_e5m2 kv-cache is not supported with fp8 checkpoints."
-                )
-            # If quantization is enabled, we make "k_scale" and "v_scale"
-            # parameters so that it can be loaded from the model checkpoint.
-            # The k/v_scale will then be converted back to native float32
-            # values after weight loading.
-            self.quant_method = quant_method
-            self.quant_method.create_weights(self)
-
         # During model initialization, the default dtype is set as the model
         # weight and activation dtype.
         dtype = torch.get_default_dtype()
@@ -637,49 +662,10 @@ def __init__(
             block_size = 16
             calculate_kv_scales = False
 
-        # The default k/v_scale is set to 1.0. This is ignored
-        # when kv-cache is not fp8, and should be used with
-        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
-        # expect the pre-quantized k/v_scale to be loaded along
-        # with the model weights.
-        self.kv_cache_dtype = kv_cache_dtype
-        self.calculate_kv_scales = calculate_kv_scales
-        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
-        # FlashAttn doesn't support quantizing the kv-cache only
-        # but requires q to be quantized as well.
-        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
-
-        # We also keep q/k/v_scale on host (cpu) memory for attention
-        # backends that require the scales to be on host instead of on device.
-        # e.g. Flashinfer
-        self._q_scale_float = 1.0
-        self._k_scale_float = 1.0
-        self._v_scale_float = 1.0
-
-        # The output scale on host memory. This should be the input scale of
-        # the quant op after this attention layer.
-        self._o_scale_float: float | None = None
-        quant_method = (
-            quant_config.get_quant_method(self, prefix=prefix) if quant_config else None
+        # Initialize KV cache quantization attributes
+        _init_kv_cache_quant(
+            self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
         )
-        if quant_method is not None and not isinstance(
-            quant_method, UnquantizedLinearMethod
-        ):
-            assert isinstance(quant_method, BaseKVCacheMethod)
-            # TODO (mgoin): kv cache dtype should be specified in the FP8
-            # checkpoint config and become the "auto" behavior
-            if self.kv_cache_dtype == "fp8_e5m2":
-                raise ValueError(
-                    "fp8_e5m2 kv-cache is not supported with fp8 checkpoints."
-                )
-            # If quantization is enabled, we make "k_scale" and "v_scale"
-            # parameters so that it can be loaded from the model checkpoint.
-            # The k/v_scale will then be converted back to native float32
-            # values after weight loading.
-            self.quant_method = quant_method
-            self.quant_method.create_weights(self)
 
         dtype = torch.get_default_dtype()
         self.attn_backend = get_attn_backend(

From 801b65ca15861f06138ea2226ea85b01a5d6e500 Mon Sep 17 00:00:00 2001
From: Matthew Bonanni
Date: Fri, 24 Oct 2025 20:20:55 -0400
Subject: [PATCH 4/4] remove comment

Signed-off-by: Matthew Bonanni
---
 vllm/attention/layer.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 6bfe65b6fc36..7544daa3aff7 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -153,8 +153,6 @@ def _init_kv_cache_quant(
     layer.calculate_kv_scales = calculate_kv_scales
     layer._k_scale = torch.tensor(1.0, dtype=torch.float32)
     layer._v_scale = torch.tensor(1.0, dtype=torch.float32)
-    # FlashAttn doesn't support quantizing the kv-cache only
-    # but requires q to be quantized as well.
     layer._q_scale = torch.tensor(1.0, dtype=torch.float32)
     layer._prob_scale = torch.tensor(1.0, dtype=torch.float32)
 
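Note: the core of this series (PATCH 3) is moving two duplicated scale-initialization blocks out of `Attention.__init__` and `MLAAttention.__init__` into a single module-level helper that mutates the layer passed to it. The snippet below is a minimal, self-contained sketch of that pattern, not the vLLM code itself: `DemoAttention`, `DemoMLAAttention`, and the trimmed-down `_init_kv_cache_quant_sketch` are hypothetical stand-ins, and the quant-method lookup from the real helper is omitted.

```python
import torch
from torch import nn


def _init_kv_cache_quant_sketch(
    layer: nn.Module, kv_cache_dtype: str, calculate_kv_scales: bool
) -> None:
    """Set the KV-cache scale attributes shared by both attention layer types."""
    layer.kv_cache_dtype = kv_cache_dtype
    layer.calculate_kv_scales = calculate_kv_scales
    # Device-side scales default to 1.0; checkpoints may overwrite them later.
    layer._q_scale = torch.tensor(1.0, dtype=torch.float32)
    layer._k_scale = torch.tensor(1.0, dtype=torch.float32)
    layer._v_scale = torch.tensor(1.0, dtype=torch.float32)
    layer._prob_scale = torch.tensor(1.0, dtype=torch.float32)
    # Host-side copies for backends that want plain Python floats.
    layer._q_scale_float = 1.0
    layer._k_scale_float = 1.0
    layer._v_scale_float = 1.0
    layer._o_scale_float = None


class DemoAttention(nn.Module):
    def __init__(self, kv_cache_dtype: str = "auto") -> None:
        super().__init__()
        # One call replaces the previously duplicated init block.
        _init_kv_cache_quant_sketch(self, kv_cache_dtype, calculate_kv_scales=False)


class DemoMLAAttention(nn.Module):
    def __init__(self, kv_cache_dtype: str = "auto") -> None:
        super().__init__()
        _init_kv_cache_quant_sketch(self, kv_cache_dtype, calculate_kv_scales=False)


if __name__ == "__main__":
    layer = DemoAttention(kv_cache_dtype="fp8_e4m3")
    print(layer.kv_cache_dtype, layer._k_scale.item())  # fp8_e4m3 1.0
```

The design point the refactor makes is that both layer classes end up with an identical set of scale attributes, so backends can read `_k_scale`, `_v_scale`, and the host-side floats without caring whether the layer is standard attention or MLA.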