
Commit 98de2f9

Add output_scale to new attn backends
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
1 parent f2b0d01 · commit 98de2f9

File tree

3 files changed: +22 -0 lines changed

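The pattern is the same in all three backends: forward() gains an optional output_scale parameter, and any non-None value is rejected with NotImplementedError until fused output quantization is actually supported. Below is a minimal stand-alone sketch of that guard; the function and the backend_name parameter are illustrative only, the real code lives inside the backend impl classes shown in the diffs that follow.

from typing import Optional

import torch


def _reject_fused_output_quant(output_scale: Optional[torch.Tensor],
                               backend_name: str) -> None:
    # Sketch of the guard each backend now applies: fused output
    # quantization is not implemented yet, so a non-None output_scale is
    # rejected explicitly rather than silently ignored.
    if output_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported"
            f" for {backend_name}")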

vllm/attention/backends/dual_chunk_flash_attn.py

Lines changed: 9 additions & 0 deletions
@@ -370,6 +370,8 @@ def forward( # type: ignore
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: DualChunkFlashAttentionMetadata,
+        output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with DualChunkFlashAttention.
         Args:

@@ -383,6 +385,13 @@ def forward( # type: ignore
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        assert output is None, "Output tensor not supported for DualChunk"
+
+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for FlashAttentionImpl")
+
         (
             query,
             query_succ,

vllm/v1/attention/backends/flashinfer.py

Lines changed: 6 additions & 0 deletions
@@ -547,6 +547,7 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: FlashInferMetadata,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashInfer.

@@ -561,6 +562,11 @@ def forward(
         """
         assert output is not None, "Output tensor must be provided."

+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for FlashInferImpl")
+
         if attn_metadata is None:
             # Profiling run.
             return output

vllm/v1/attention/backends/flex_attention.py

Lines changed: 7 additions & 0 deletions
@@ -414,6 +414,7 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: FlexAttentionMetadata,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FLexAttention.

@@ -427,6 +428,12 @@ def forward(
             shape = [num_tokens, num_heads * head_size]
         """
         assert output is not None, "Output tensor must be provided."
+
+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for FlexAttentionImpl")
+
         enable_gqa = self.num_kv_heads != self.num_heads

         if attn_metadata is None:
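From a caller's perspective the new behavior is easy to observe. The snippet below uses a small stand-in class (not the vLLM impls themselves, which need a full layer/config setup to construct) that mirrors the guard added above; note the DualChunk backend differs slightly in that it asserts output is None, while the two v1 backends require a preallocated output tensor.

from typing import Optional

import torch


class _StubAttnImpl:
    # Stand-in that mirrors the guard these backends now apply; not vLLM code.
    def forward(self,
                output: Optional[torch.Tensor] = None,
                output_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
        assert output is not None, "Output tensor must be provided."
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for this backend")
        return output


impl = _StubAttnImpl()
out = torch.zeros(4, 8)
impl.forward(output=out)  # accepted: no output_scale passed
try:
    impl.forward(output=out, output_scale=torch.tensor(1.0))
except NotImplementedError as exc:
    print(f"rejected as expected: {exc}")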
