diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index bd38f3679ece..7544daa3aff7 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -123,6 +123,69 @@ def maybe_get_vit_flash_attn_backend(
     return attn_backend, flash_attn_varlen_func
 
 
+def _init_kv_cache_quant(
+    layer: nn.Module,
+    quant_config: QuantizationConfig | None,
+    prefix: str,
+    kv_cache_dtype: str,
+    calculate_kv_scales: bool,
+) -> None:
+    """Initializes KV cache scaling factors and quantization method.
+
+    This helper function sets up the KV cache quantization attributes that are
+    shared between Attention and MLAAttention layers. It initializes scale
+    tensors for query, key, value, and probability, and configures the
+    quantization method if applicable.
+
+    Args:
+        layer: The attention layer instance to initialize.
+        quant_config: Optional quantization configuration.
+        prefix: Layer name prefix for quantization method lookup.
+        kv_cache_dtype: The KV cache data type string.
+        calculate_kv_scales: Whether to calculate KV scales dynamically.
+    """
+    # The default k/v_scale is set to 1.0. This is ignored
+    # when kv-cache is not fp8, and should be used with
+    # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
+    # expect the pre-quantized k/v_scale to be loaded along
+    # with the model weights.
+    layer.kv_cache_dtype = kv_cache_dtype
+    layer.calculate_kv_scales = calculate_kv_scales
+    layer._k_scale = torch.tensor(1.0, dtype=torch.float32)
+    layer._v_scale = torch.tensor(1.0, dtype=torch.float32)
+    layer._q_scale = torch.tensor(1.0, dtype=torch.float32)
+    layer._prob_scale = torch.tensor(1.0, dtype=torch.float32)
+
+    # We also keep q/k/v_scale on host (cpu) memory for attention
+    # backends that require the scales to be on host instead of on device.
+    # e.g. Flashinfer
+    layer._q_scale_float = 1.0
+    layer._k_scale_float = 1.0
+    layer._v_scale_float = 1.0
+
+    # The output scale on host memory. This should be the input scale of
+    # the quant op after this attention layer.
+    layer._o_scale_float = None
+
+    quant_method = (
+        quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None
+    )
+    if quant_method is not None and not isinstance(
+        quant_method, UnquantizedLinearMethod
+    ):
+        assert isinstance(quant_method, BaseKVCacheMethod)
+        # TODO (mgoin): kv cache dtype should be specified in the FP8
+        # checkpoint config and become the "auto" behavior
+        if kv_cache_dtype == "fp8_e5m2":
+            raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")
+        # If quantization is enabled, we make "k_scale" and "v_scale"
+        # parameters so that it can be loaded from the model checkpoint.
+        # The k/v_scale will then be converted back to native float32
+        # values after weight loading.
+        layer.quant_method = quant_method
+        layer.quant_method.create_weights(layer)
+
+
 class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.
 
@@ -184,30 +247,10 @@ def __init__(
             f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
         )
 
-        # The default k/v_scale is set to 1.0. This is ignored
-        # when kv-cache is not fp8, and should be used with
-        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
-        # expect the pre-quantized k/v_scale to be loaded along
-        # with the model weights.
-        self.kv_cache_dtype = kv_cache_dtype
-        self.calculate_kv_scales = calculate_kv_scales
-        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
-        # FlashAttn doesn't support quantizing the kv-cache only
-        # but requires q to be quantized as well.
-        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
-
-        # We also keep q/k/v_scale on host (cpu) memory for attention
-        # backends that require the scales to be on host instead of on device.
-        # e.g. Flashinfer
-        self._q_scale_float = 1.0
-        self._k_scale_float = 1.0
-        self._v_scale_float = 1.0
-
-        # The output scale on host memory. This should be the input scale of
-        # the quant op after this attention layer.
-        self._o_scale_float: float | None = None
+        # Initialize KV cache quantization attributes
+        _init_kv_cache_quant(
+            self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
+        )
 
         self.num_heads = num_heads
         self.head_size = head_size
@@ -215,26 +258,6 @@ def __init__(
         self.sliding_window = sliding_window
         self.has_sink = extra_impl_args.get("sinks") is not None
 
-        quant_method = (
-            quant_config.get_quant_method(self, prefix=prefix) if quant_config else None
-        )
-        if quant_method is not None and not isinstance(
-            quant_method, UnquantizedLinearMethod
-        ):
-            assert isinstance(quant_method, BaseKVCacheMethod)
-            # TODO (mgoin): kv cache dtype should be specified in the FP8
-            # checkpoint config and become the "auto" behavior
-            if self.kv_cache_dtype == "fp8_e5m2":
-                raise ValueError(
-                    "fp8_e5m2 kv-cache is not supported with fp8 checkpoints."
-                )
-            # If quantization is enabled, we make "k_scale" and "v_scale"
-            # parameters so that it can be loaded from the model checkpoint.
-            # The k/v_scale will then be converted back to native float32
-            # values after weight loading.
-            self.quant_method = quant_method
-            self.quant_method.create_weights(self)
-
         # During model initialization, the default dtype is set as the model
         # weight and activation dtype.
         dtype = torch.get_default_dtype()
@@ -636,7 +659,11 @@ def __init__(
             kv_cache_dtype = "auto"
             block_size = 16
             calculate_kv_scales = False
-        self.kv_cache_dtype = kv_cache_dtype
+
+        # Initialize KV cache quantization attributes
+        _init_kv_cache_quant(
+            self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
+        )
 
         dtype = torch.get_default_dtype()
         self.attn_backend = get_attn_backend(
@@ -685,20 +712,6 @@ def __init__(
             )
         ]
 
-        # Align with Attention's scale attributes for MLA backends.
-
-        self.calculate_kv_scales = calculate_kv_scales
-        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
-
-        # Host-side mirrors used by some attention backends
-        self._q_scale_float = 1.0
-        self._k_scale_float = 1.0
-        self._v_scale_float = 1.0
-        self._o_scale_float: float | None = None
-
         self.use_sparse = use_sparse
 
         # Initialize q/k/v range constants.
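For context, here is a minimal, self-contained sketch of the pattern this patch applies: both layer types delegate their default KV-cache scale setup to a single module-level helper so the defaults live in one place. The ToyAttention / ToyMLAAttention classes and the trimmed _init_kv_cache_scales helper below are hypothetical illustrations, not vLLM's actual API.

import torch
import torch.nn as nn


def _init_kv_cache_scales(layer: nn.Module) -> None:
    # Hypothetical, trimmed-down analogue of _init_kv_cache_quant: install the
    # default unit scales that checkpoint loading may later overwrite.
    layer._q_scale = torch.tensor(1.0, dtype=torch.float32)
    layer._k_scale = torch.tensor(1.0, dtype=torch.float32)
    layer._v_scale = torch.tensor(1.0, dtype=torch.float32)
    # Host-side mirrors for backends that want plain Python floats.
    layer._q_scale_float = 1.0
    layer._k_scale_float = 1.0
    layer._v_scale_float = 1.0


class ToyAttention(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        _init_kv_cache_scales(self)  # shared default initialization


class ToyMLAAttention(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        _init_kv_cache_scales(self)  # same helper, no duplicated defaults


if __name__ == "__main__":
    # Both layer types end up with identical default scale attributes.
    assert float(ToyAttention()._k_scale) == float(ToyMLAAttention()._k_scale) == 1.0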