
Commit eebad39

[torch.compile] support all attention backends (#10558)

Authored by youkaichao
Signed-off-by: youkaichao <youkaichao@gmail.com>

1 parent db100c5 · commit eebad39


77 files changed: +879 / -651 lines

tests/kernels/test_encoder_decoder_attn.py

Lines changed: 26 additions & 11 deletions
@@ -18,8 +18,10 @@
 from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
 from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
                                      global_force_attn_backend_context_manager)
+from vllm.config import VllmConfig
 from vllm.forward_context import set_forward_context
 from vllm.platforms import current_platform
+from vllm.plugins import set_current_vllm_config
 
 # List of support backends for encoder/decoder models
 LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
@@ -594,6 +596,7 @@ def _run_encoder_attention_test(
         encoder_test_params: PhaseTestParameters,
         attn_metadata: AttentionMetadata,
         test_pt: TestPoint,
+        vllm_config: VllmConfig,
 ) -> torch.Tensor:
     '''
     Run encoder attention.
@@ -623,7 +626,7 @@ def _run_encoder_attention_test(
     attn_type = AttentionType.ENCODER
     packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
     assert packed_qkv is not None
-    with set_forward_context(attn_metadata):
+    with set_forward_context(attn_metadata, vllm_config):
         # In the test setup the shape of the query is
         # [batch_size, seq_len, num_heads, head_size]. However
         # the attention backend expect the shape to be
@@ -648,6 +651,7 @@ def _run_decoder_self_attention_test(
         decoder_test_params: PhaseTestParameters,
         attn_metadata: AttentionMetadata,
         test_pt: TestPoint,
+        vllm_config: VllmConfig,
 ) -> torch.Tensor:
     '''
     Run decoder self-attention test.
@@ -677,7 +681,7 @@ def _run_decoder_self_attention_test(
     kv_cache = test_rsrcs.kv_cache
     packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
     assert packed_qkv is not None
-    with set_forward_context(attn_metadata):
+    with set_forward_context(attn_metadata, vllm_config):
         # In the test setup the shape of the query is
         # [batch_size, seq_len, num_heads, head_size]. However
         # the attention backend expect the shape to be
@@ -701,6 +705,7 @@ def _run_encoder_decoder_cross_attention_test(
         cross_test_params: Optional[PhaseTestParameters],
         attn_metadata: AttentionMetadata,
         test_pt: TestPoint,
+        vllm_config: VllmConfig,
 ) -> torch.Tensor:
     '''
     Run encoder/decoder cross-attention test.
@@ -748,7 +753,7 @@ def _run_encoder_decoder_cross_attention_test(
     cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv
     key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key)
     value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value)
-    with set_forward_context(attn_metadata):
+    with set_forward_context(attn_metadata, vllm_config):
         # In the test setup the shape of the query is
         # [batch_size, seq_len, num_heads, head_size]. However
         # the attention backend expect the shape to be
@@ -839,7 +844,9 @@ def test_encoder_only(
 
     # Attention scale factor, attention backend instance, attention wrapper
     # instance, KV cache init
-    test_rsrcs = _make_test_resources(test_pt)
+    vllm_config = VllmConfig()
+    with set_current_vllm_config(vllm_config):
+        test_rsrcs = _make_test_resources(test_pt)
 
     # Construct encoder attention test params (only used
     # during prefill)
@@ -863,7 +870,8 @@ def test_encoder_only(
         test_rsrcs.attn,
         enc_test_params,
         prephase_attn_metadata,
-        test_pt=test_pt))
+        test_pt=test_pt,
+        vllm_config=vllm_config))
 
     # - Is encoder attention result correct?
     assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
@@ -960,7 +968,9 @@ def test_e2e_enc_dec_attn(
 
     # Attention scale factor, attention backend instance, attention wrapper
    # instance, KV cache init
-    test_rsrcs = _make_test_resources(test_pt)
+    vllm_config = VllmConfig()
+    with set_current_vllm_config(vllm_config):
+        test_rsrcs = _make_test_resources(test_pt)
 
     # Construct encoder attention test params (only used
     # during prefill)
@@ -1011,7 +1021,8 @@ def test_e2e_enc_dec_attn(
     enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
                                                    enc_test_params,
                                                    prephase_attn_metadata,
-                                                   test_pt=test_pt)
+                                                   test_pt=test_pt,
+                                                   vllm_config=vllm_config)
 
     # - Is encoder attention result correct?
     assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
@@ -1023,7 +1034,8 @@ def test_e2e_enc_dec_attn(
         test_rsrcs,
         prephase_dec_test_params,
         prephase_attn_metadata,
-        test_pt=test_pt)
+        test_pt=test_pt,
+        vllm_config=vllm_config)
 
     # - Is prefill decoder self-attention correct?
     assert_actual_matches_ideal(prephase_dec_test_params,
@@ -1037,7 +1049,8 @@ def test_e2e_enc_dec_attn(
         prephase_dec_test_params,
         prephase_cross_test_params,
         prephase_attn_metadata,
-        test_pt=test_pt)
+        test_pt=test_pt,
+        vllm_config=vllm_config)
 
     # - Is prefill encoder/decoder cross-attention correct?
     assert_actual_matches_ideal(prephase_cross_test_params,
@@ -1061,7 +1074,8 @@ def test_e2e_enc_dec_attn(
         test_rsrcs,
         decphase_dec_test_params,
         decphase_attn_metadata,
-        test_pt=test_pt)
+        test_pt=test_pt,
+        vllm_config=vllm_config)
 
     # - Is decode-phase decoder self-attention correct?
     assert_actual_matches_ideal(decphase_dec_test_params,
@@ -1075,7 +1089,8 @@ def test_e2e_enc_dec_attn(
         decphase_dec_test_params,
         None,
         decphase_attn_metadata,
-        test_pt=test_pt)
+        test_pt=test_pt,
+        vllm_config=vllm_config)
 
     # - Is decode-phase encoder/decoder cross-attention correct?
     assert_actual_matches_ideal(decphase_cross_test_params,
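
Taken together, the test changes above reflect a new calling convention: attention resources are built while a `VllmConfig` is installed as the current config (`set_current_vllm_config`), and that same config object is then passed to `set_forward_context` alongside the attention metadata. A minimal sketch of that pattern, assuming vLLM at this commit; `make_attention_and_metadata` and `run_attention` are hypothetical placeholders for the test's own helpers, not part of the diff:

from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.plugins import set_current_vllm_config

vllm_config = VllmConfig()

# Backends/layers must be constructed while this config is current,
# mirroring the set_current_vllm_config(...) block wrapped around
# _make_test_resources(test_pt) in the hunks above.
with set_current_vllm_config(vllm_config):
    attn, attn_metadata = make_attention_and_metadata()  # hypothetical helper

# Every forward pass now supplies the config together with the metadata,
# i.e. set_forward_context(attn_metadata) becomes the two-argument form.
with set_forward_context(attn_metadata, vllm_config):
    output = run_attention(attn, attn_metadata)  # hypothetical stand-in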

vllm/attention/backends/abstract.py

Lines changed: 14 additions & 9 deletions
@@ -1,7 +1,6 @@
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from dataclasses import dataclass, fields
-from enum import Enum, auto
 from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
                     Tuple, Type, TypeVar)
 
@@ -15,13 +14,19 @@
                                     ModelRunnerInputBuilderBase)
 
 
-class AttentionType(Enum):
-    DECODER = auto()  # Decoder attention between previous layer Q/K/V
-    ENCODER = auto(
-    )  # Encoder attention between previous layer Q/K/V for encoder-decoder
-    ENCODER_ONLY = auto()  # Encoder attention between previous layer Q/K/V
-    ENCODER_DECODER = auto(
-    )  # Attention between dec. Q and enc. K/V for encoder-decoder
+class AttentionType:
+    """
+    Attention type.
+    Use string to be compatible with `torch.compile`.
+    """
+    # Decoder attention between previous layer Q/K/V
+    DECODER = "decoder"
+    # Encoder attention between previous layer Q/K/V for encoder-decoder
+    ENCODER = "encoder"
+    # Encoder attention between previous layer Q/K/V
+    ENCODER_ONLY = "encoder_only"
+    # Attention between dec. Q and enc. K/V for encoder-decoder
+    ENCODER_DECODER = "encoder_decoder"
 
 
 class AttentionBackend(ABC):
@@ -241,6 +246,6 @@ def forward(
         attn_metadata: T,
         k_scale: float = 1.0,
         v_scale: float = 1.0,
-        attn_type: AttentionType = AttentionType.DECODER,
+        attn_type: str = AttentionType.DECODER,
     ) -> torch.Tensor:
         raise NotImplementedError
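
The key interface change here is that `AttentionType` is no longer an `Enum`: per the new docstring it holds plain string constants so that `attn_type` stays `torch.compile`-friendly, and the abstract `forward` signature is annotated as `str` accordingly. A toy example (not vLLM code) of the constant-string branching this enables under `torch.compile`:

import torch


class AttentionType:
    # Mirrors the string constants introduced in this diff.
    DECODER = "decoder"
    ENCODER = "encoder"
    ENCODER_ONLY = "encoder_only"
    ENCODER_DECODER = "encoder_decoder"


def toy_forward(q: torch.Tensor,
                attn_type: str = AttentionType.DECODER) -> torch.Tensor:
    # attn_type is an ordinary str, so this branch is a plain constant
    # comparison that torch.compile can specialize on.
    if attn_type == AttentionType.DECODER:
        return q * 2.0
    return q + 1.0


compiled = torch.compile(toy_forward)
print(compiled(torch.ones(2, 3), AttentionType.ENCODER))  # tensor of 2.0s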

vllm/attention/backends/blocksparse_attn.py

Lines changed: 1 addition & 1 deletion
@@ -354,7 +354,7 @@ def forward(
         attn_metadata: BlocksparseFlashAttentionMetadata,
         k_scale: float = 1.0,
         v_scale: float = 1.0,
-        attn_type: AttentionType = AttentionType.DECODER,
+        attn_type: str = AttentionType.DECODER,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
 
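
The concrete backends (this one and the others in the 77-file change) only swap the `attn_type` annotation from `AttentionType` to `str`; since the class attributes are now the strings themselves, existing call sites keep working unchanged. A quick check, assuming vLLM at this commit is importable:

from vllm.attention.backends.abstract import AttentionType

# With the Enum gone, the constants compare equal to their literal spellings,
# so passing AttentionType.DECODER or "decoder" selects the same code path.
assert AttentionType.DECODER == "decoder"
assert isinstance(AttentionType.ENCODER_DECODER, str)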
