
Commit c312320

hl475 and yeqcharlotte authored
[CI/Build] tests(v1): feed Triton attention the (num_blocks, 2, …) KV cache layout in backend-correctness tests (#26663)
Signed-off-by: Huamin Li <3ericli@gmail.com>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
1 parent c981f0e · commit c312320

File tree

1 file changed (+3 −2 lines)


tests/v1/attention/test_attention_backends.py

Lines changed: 3 additions & 2 deletions
@@ -423,13 +423,14 @@ def _test_backend_correctness(
     for backend_name in backend_to_test:
         # FlashAttention + FlexAttention:
         #   [2, num_blocks, block_size, num_kv_heads, head_size]
-        # FlashInfer:
+        # FlashInfer + Triton:
         #   [num_blocks, 2, block_size, num_kv_heads, head_size]
         # Select the appropriate KV cache format for each backend
         kv_cache_for_backend = kv_cache
-        if backend_name == _Backend.FLASHINFER:
+        if backend_name in (_Backend.FLASHINFER, _Backend.TRITON_ATTN):
             kv_cache_for_backend = kv_cache.transpose(0, 1)
 
+        if backend_name == _Backend.FLASHINFER:
             # For FlashInfer default to HND layout and
             kv_cache_for_backend = (
                 kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3)
