Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 72 additions & 59 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,69 @@ def maybe_get_vit_flash_attn_backend(
return attn_backend, flash_attn_varlen_func


def _init_kv_cache_quant(
layer: nn.Module,
quant_config: QuantizationConfig | None,
prefix: str,
kv_cache_dtype: str,
calculate_kv_scales: bool,
) -> None:
"""Initializes KV cache scaling factors and quantization method.

This helper function sets up the KV cache quantization attributes that are
shared between Attention and MLAAttention layers. It initializes scale
tensors for query, key, value, and probability, and configures the
quantization method if applicable.

Args:
layer: The attention layer instance to initialize.
quant_config: Optional quantization configuration.
prefix: Layer name prefix for quantization method lookup.
kv_cache_dtype: The KV cache data type string.
calculate_kv_scales: Whether to calculate KV scales dynamically.
"""
# The default k/v_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
# expect the pre-quantized k/v_scale to be loaded along
# with the model weights.
layer.kv_cache_dtype = kv_cache_dtype
layer.calculate_kv_scales = calculate_kv_scales
layer._k_scale = torch.tensor(1.0, dtype=torch.float32)
layer._v_scale = torch.tensor(1.0, dtype=torch.float32)
layer._q_scale = torch.tensor(1.0, dtype=torch.float32)
layer._prob_scale = torch.tensor(1.0, dtype=torch.float32)

# We also keep q/k/v_scale on host (cpu) memory for attention
# backends that require the scales to be on host instead of on device.
# e.g. Flashinfer
layer._q_scale_float = 1.0
layer._k_scale_float = 1.0
layer._v_scale_float = 1.0

# The output scale on host memory. This should be the input scale of
# the quant op after this attention layer.
layer._o_scale_float = None

quant_method = (
quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None
)
if quant_method is not None and not isinstance(
quant_method, UnquantizedLinearMethod
):
assert isinstance(quant_method, BaseKVCacheMethod)
# TODO (mgoin): kv cache dtype should be specified in the FP8
# checkpoint config and become the "auto" behavior
if kv_cache_dtype == "fp8_e5m2":
raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")
# If quantization is enabled, we make "k_scale" and "v_scale"
# parameters so that it can be loaded from the model checkpoint.
# The k/v_scale will then be converted back to native float32
# values after weight loading.
layer.quant_method = quant_method
layer.quant_method.create_weights(layer)


class Attention(nn.Module, AttentionLayerBase):
"""Attention layer.

Expand Down Expand Up @@ -184,57 +247,17 @@ def __init__(
f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
)

# The default k/v_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
# expect the pre-quantized k/v_scale to be loaded along
# with the model weights.
self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
# FlashAttn doesn't support quantizing the kv-cache only
# but requires q to be quantized as well.
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

# We also keep q/k/v_scale on host (cpu) memory for attention
# backends that require the scales to be on host instead of on device.
# e.g. Flashinfer
self._q_scale_float = 1.0
self._k_scale_float = 1.0
self._v_scale_float = 1.0

# The output scale on host memory. This should be the input scale of
# the quant op after this attention layer.
self._o_scale_float: float | None = None
# Initialize KV cache quantization attributes
_init_kv_cache_quant(
self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
)

self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window
self.has_sink = extra_impl_args.get("sinks") is not None

quant_method = (
quant_config.get_quant_method(self, prefix=prefix) if quant_config else None
)
if quant_method is not None and not isinstance(
quant_method, UnquantizedLinearMethod
):
assert isinstance(quant_method, BaseKVCacheMethod)
# TODO (mgoin): kv cache dtype should be specified in the FP8
# checkpoint config and become the "auto" behavior
if self.kv_cache_dtype == "fp8_e5m2":
raise ValueError(
"fp8_e5m2 kv-cache is not supported with fp8 checkpoints."
)
# If quantization is enabled, we make "k_scale" and "v_scale"
# parameters so that it can be loaded from the model checkpoint.
# The k/v_scale will then be converted back to native float32
# values after weight loading.
self.quant_method = quant_method
self.quant_method.create_weights(self)

# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
Expand Down Expand Up @@ -636,7 +659,11 @@ def __init__(
kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False
self.kv_cache_dtype = kv_cache_dtype

# Initialize KV cache quantization attributes
_init_kv_cache_quant(
self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
)

dtype = torch.get_default_dtype()
self.attn_backend = get_attn_backend(
Expand Down Expand Up @@ -685,20 +712,6 @@ def __init__(
)
]

# Align with Attention's scale attributes for MLA backends.

self.calculate_kv_scales = calculate_kv_scales
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

# Host-side mirrors used by some attention backends
self._q_scale_float = 1.0
self._k_scale_float = 1.0
self._v_scale_float = 1.0
self._o_scale_float: float | None = None

self.use_sparse = use_sparse

# Initialize q/k/v range constants.
Expand Down