
Commit a99564a

[Attention] Add missing kv cache scale setup (#27490)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
1 parent 4c5f632 commit a99564a

1 file changed

vllm/attention/layer.py

Lines changed: 72 additions & 59 deletions
@@ -123,6 +123,69 @@ def maybe_get_vit_flash_attn_backend(
     return attn_backend, flash_attn_varlen_func


+def _init_kv_cache_quant(
+    layer: nn.Module,
+    quant_config: QuantizationConfig | None,
+    prefix: str,
+    kv_cache_dtype: str,
+    calculate_kv_scales: bool,
+) -> None:
+    """Initializes KV cache scaling factors and quantization method.
+
+    This helper function sets up the KV cache quantization attributes that are
+    shared between Attention and MLAAttention layers. It initializes scale
+    tensors for query, key, value, and probability, and configures the
+    quantization method if applicable.
+
+    Args:
+        layer: The attention layer instance to initialize.
+        quant_config: Optional quantization configuration.
+        prefix: Layer name prefix for quantization method lookup.
+        kv_cache_dtype: The KV cache data type string.
+        calculate_kv_scales: Whether to calculate KV scales dynamically.
+    """
+    # The default k/v_scale is set to 1.0. This is ignored
+    # when kv-cache is not fp8, and should be used with
+    # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
+    # expect the pre-quantized k/v_scale to be loaded along
+    # with the model weights.
+    layer.kv_cache_dtype = kv_cache_dtype
+    layer.calculate_kv_scales = calculate_kv_scales
+    layer._k_scale = torch.tensor(1.0, dtype=torch.float32)
+    layer._v_scale = torch.tensor(1.0, dtype=torch.float32)
+    layer._q_scale = torch.tensor(1.0, dtype=torch.float32)
+    layer._prob_scale = torch.tensor(1.0, dtype=torch.float32)
+
+    # We also keep q/k/v_scale on host (cpu) memory for attention
+    # backends that require the scales to be on host instead of on device.
+    # e.g. Flashinfer
+    layer._q_scale_float = 1.0
+    layer._k_scale_float = 1.0
+    layer._v_scale_float = 1.0
+
+    # The output scale on host memory. This should be the input scale of
+    # the quant op after this attention layer.
+    layer._o_scale_float = None
+
+    quant_method = (
+        quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None
+    )
+    if quant_method is not None and not isinstance(
+        quant_method, UnquantizedLinearMethod
+    ):
+        assert isinstance(quant_method, BaseKVCacheMethod)
+        # TODO (mgoin): kv cache dtype should be specified in the FP8
+        # checkpoint config and become the "auto" behavior
+        if kv_cache_dtype == "fp8_e5m2":
+            raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")
+        # If quantization is enabled, we make "k_scale" and "v_scale"
+        # parameters so that it can be loaded from the model checkpoint.
+        # The k/v_scale will then be converted back to native float32
+        # values after weight loading.
+        layer.quant_method = quant_method
+        layer.quant_method.create_weights(layer)
+
+
 class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.
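
The new helper centralizes the scale and quant-method setup so both Attention and MLAAttention get identical defaults. The snippet below is a minimal sketch of what the helper guarantees when no quantization config is present; it is not part of the diff, and it assumes vllm is installed and that this private helper remains importable from vllm.attention.layer (internal API, may change).

```python
import torch
import torch.nn as nn

from vllm.attention.layer import _init_kv_cache_quant  # internal helper

layer = nn.Module()
# With quant_config=None the quant-method branch is skipped; only the default
# device-side scale tensors and the host-side float mirrors are set.
_init_kv_cache_quant(
    layer,
    quant_config=None,
    prefix="model.layers.0.self_attn.attn",  # illustrative prefix
    kv_cache_dtype="auto",
    calculate_kv_scales=False,
)

assert layer.kv_cache_dtype == "auto"
assert layer._k_scale.item() == 1.0 and layer._v_scale.item() == 1.0
assert layer._q_scale_float == 1.0 and layer._o_scale_float is None
```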
@@ -184,57 +247,17 @@ def __init__(
                 f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
             )

-        # The default k/v_scale is set to 1.0. This is ignored
-        # when kv-cache is not fp8, and should be used with
-        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
-        # expect the pre-quantized k/v_scale to be loaded along
-        # with the model weights.
-        self.kv_cache_dtype = kv_cache_dtype
-        self.calculate_kv_scales = calculate_kv_scales
-        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
-        # FlashAttn doesn't support quantizing the kv-cache only
-        # but requires q to be quantized as well.
-        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
-
-        # We also keep q/k/v_scale on host (cpu) memory for attention
-        # backends that require the scales to be on host instead of on device.
-        # e.g. Flashinfer
-        self._q_scale_float = 1.0
-        self._k_scale_float = 1.0
-        self._v_scale_float = 1.0
-
-        # The output scale on host memory. This should be the input scale of
-        # the quant op after this attention layer.
-        self._o_scale_float: float | None = None
+        # Initialize KV cache quantization attributes
+        _init_kv_cache_quant(
+            self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
+        )

         self.num_heads = num_heads
         self.head_size = head_size
         self.num_kv_heads = num_kv_heads
         self.sliding_window = sliding_window
         self.has_sink = extra_impl_args.get("sinks") is not None

-        quant_method = (
-            quant_config.get_quant_method(self, prefix=prefix) if quant_config else None
-        )
-        if quant_method is not None and not isinstance(
-            quant_method, UnquantizedLinearMethod
-        ):
-            assert isinstance(quant_method, BaseKVCacheMethod)
-            # TODO (mgoin): kv cache dtype should be specified in the FP8
-            # checkpoint config and become the "auto" behavior
-            if self.kv_cache_dtype == "fp8_e5m2":
-                raise ValueError(
-                    "fp8_e5m2 kv-cache is not supported with fp8 checkpoints."
-                )
-            # If quantization is enabled, we make "k_scale" and "v_scale"
-            # parameters so that it can be loaded from the model checkpoint.
-            # The k/v_scale will then be converted back to native float32
-            # values after weight loading.
-            self.quant_method = quant_method
-            self.quant_method.create_weights(self)
-
         # During model initialization, the default dtype is set as the model
         # weight and activation dtype.
         dtype = torch.get_default_dtype()
@@ -636,7 +659,11 @@ def __init__(
             kv_cache_dtype = "auto"
             block_size = 16
             calculate_kv_scales = False
-        self.kv_cache_dtype = kv_cache_dtype
+
+        # Initialize KV cache quantization attributes
+        _init_kv_cache_quant(
+            self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
+        )

         dtype = torch.get_default_dtype()
         self.attn_backend = get_attn_backend(
@@ -685,20 +712,6 @@ def __init__(
             )
         ]

-        # Align with Attention's scale attributes for MLA backends.
-
-        self.calculate_kv_scales = calculate_kv_scales
-        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
-
-        # Host-side mirrors used by some attention backends
-        self._q_scale_float = 1.0
-        self._k_scale_float = 1.0
-        self._v_scale_float = 1.0
-        self._o_scale_float: float | None = None
-
         self.use_sparse = use_sparse

         # Initialize q/k/v range constants.
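For context on what the k/v scales are for: when the kv-cache runs in fp8, values are divided by the per-layer scale before the cast and multiplied back by the same scale on read, which is why the helper defaults them to 1.0 and lets fp8_e4m3 checkpoints overwrite them. The sketch below is illustrative only, not vllm backend code, and assumes a PyTorch build that provides torch.float8_e4m3fn.

```python
import torch


def quantize_to_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Scale into the fp8 range, clamp to avoid overflow, then cast.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)


key = torch.randn(2, 8, 64)
k_scale = torch.tensor(1.0, dtype=torch.float32)  # default set by _init_kv_cache_quant
k_fp8 = quantize_to_fp8(key, k_scale)
# Dequantize on read with the same scale (approximate round trip).
k_restored = k_fp8.to(torch.float32) * k_scale
```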
