 from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
 from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
                                      global_force_attn_backend_context_manager)
+from vllm.config import VllmConfig
 from vllm.forward_context import set_forward_context
 from vllm.platforms import current_platform
+from vllm.plugins import set_current_vllm_config
 
 # List of supported backends for encoder/decoder models
 LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
@@ -594,6 +596,7 @@ def _run_encoder_attention_test(
     encoder_test_params: PhaseTestParameters,
     attn_metadata: AttentionMetadata,
     test_pt: TestPoint,
+    vllm_config: VllmConfig,
 ) -> torch.Tensor:
     '''
     Run encoder attention.
@@ -623,7 +626,7 @@ def _run_encoder_attention_test(
     attn_type = AttentionType.ENCODER
     packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
     assert packed_qkv is not None
-    with set_forward_context(attn_metadata):
+    with set_forward_context(attn_metadata, vllm_config):
         # In the test setup the shape of the query is
         # [batch_size, seq_len, num_heads, head_size]. However
         # the attention backend expects the shape to be
@@ -648,6 +651,7 @@ def _run_decoder_self_attention_test(
     decoder_test_params: PhaseTestParameters,
     attn_metadata: AttentionMetadata,
     test_pt: TestPoint,
+    vllm_config: VllmConfig,
 ) -> torch.Tensor:
     '''
     Run decoder self-attention test.
@@ -677,7 +681,7 @@ def _run_decoder_self_attention_test(
     kv_cache = test_rsrcs.kv_cache
     packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
     assert packed_qkv is not None
-    with set_forward_context(attn_metadata):
+    with set_forward_context(attn_metadata, vllm_config):
         # In the test setup the shape of the query is
         # [batch_size, seq_len, num_heads, head_size]. However
         # the attention backend expects the shape to be
@@ -701,6 +705,7 @@ def _run_encoder_decoder_cross_attention_test(
     cross_test_params: Optional[PhaseTestParameters],
     attn_metadata: AttentionMetadata,
     test_pt: TestPoint,
+    vllm_config: VllmConfig,
 ) -> torch.Tensor:
     '''
     Run encoder/decoder cross-attention test.
@@ -748,7 +753,7 @@ def _run_encoder_decoder_cross_attention_test(
     cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv
     key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key)
     value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value)
-    with set_forward_context(attn_metadata):
+    with set_forward_context(attn_metadata, vllm_config):
         # In the test setup the shape of the query is
         # [batch_size, seq_len, num_heads, head_size]. However
         # the attention backend expects the shape to be
@@ -839,7 +844,9 @@ def test_encoder_only(
 
     # Attention scale factor, attention backend instance, attention wrapper
     # instance, KV cache init
-    test_rsrcs = _make_test_resources(test_pt)
+    vllm_config = VllmConfig()
+    with set_current_vllm_config(vllm_config):
+        test_rsrcs = _make_test_resources(test_pt)
 
     # Construct encoder attention test params (only used
     # during prefill)
@@ -863,7 +870,8 @@ def test_encoder_only(
         test_rsrcs.attn,
         enc_test_params,
         prephase_attn_metadata,
-        test_pt=test_pt))
+        test_pt=test_pt,
+        vllm_config=vllm_config))
 
     # - Is encoder attention result correct?
     assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
@@ -960,7 +968,9 @@ def test_e2e_enc_dec_attn(
 
     # Attention scale factor, attention backend instance, attention wrapper
     # instance, KV cache init
-    test_rsrcs = _make_test_resources(test_pt)
+    vllm_config = VllmConfig()
+    with set_current_vllm_config(vllm_config):
+        test_rsrcs = _make_test_resources(test_pt)
 
     # Construct encoder attention test params (only used
     # during prefill)
@@ -1011,7 +1021,8 @@ def test_e2e_enc_dec_attn(
     enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
                                                    enc_test_params,
                                                    prephase_attn_metadata,
-                                                   test_pt=test_pt)
+                                                   test_pt=test_pt,
+                                                   vllm_config=vllm_config)
 
     # - Is encoder attention result correct?
     assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
@@ -1023,7 +1034,8 @@ def test_e2e_enc_dec_attn(
         test_rsrcs,
         prephase_dec_test_params,
         prephase_attn_metadata,
-        test_pt=test_pt)
+        test_pt=test_pt,
+        vllm_config=vllm_config)
 
     # - Is prefill decoder self-attention correct?
     assert_actual_matches_ideal(prephase_dec_test_params,
@@ -1037,7 +1049,8 @@ def test_e2e_enc_dec_attn(
         prephase_dec_test_params,
         prephase_cross_test_params,
         prephase_attn_metadata,
-        test_pt=test_pt)
+        test_pt=test_pt,
+        vllm_config=vllm_config)
 
     # - Is prefill encoder/decoder cross-attention correct?
     assert_actual_matches_ideal(prephase_cross_test_params,
@@ -1061,7 +1074,8 @@ def test_e2e_enc_dec_attn(
         test_rsrcs,
         decphase_dec_test_params,
         decphase_attn_metadata,
-        test_pt=test_pt)
+        test_pt=test_pt,
+        vllm_config=vllm_config)
 
     # - Is decode-phase decoder self-attention correct?
     assert_actual_matches_ideal(decphase_dec_test_params,
@@ -1075,7 +1089,8 @@ def test_e2e_enc_dec_attn(
         decphase_dec_test_params,
         None,
         decphase_attn_metadata,
-        test_pt=test_pt)
+        test_pt=test_pt,
+        vllm_config=vllm_config)
 
     # - Is decode-phase encoder/decoder cross-attention correct?
     assert_actual_matches_ideal(decphase_cross_test_params,
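
Every hunk above applies the same two-part pattern: each test builds a VllmConfig, constructs its attention resources while that config is current via set_current_vllm_config, and then threads the same config into set_forward_context on every forward pass. A minimal sketch of the pattern, assuming only the imports shown in the first hunk; build_test_attn() is a hypothetical placeholder for the test's _make_test_resources() and metadata setup:

    from vllm.config import VllmConfig
    from vllm.forward_context import set_forward_context
    from vllm.plugins import set_current_vllm_config

    vllm_config = VllmConfig()

    # Backend and KV-cache construction happens while the config is
    # current, mirroring the set_current_vllm_config() hunks above.
    with set_current_vllm_config(vllm_config):
        attn, attn_metadata = build_test_attn()  # hypothetical helper

    # Each attention call then passes the same config alongside the
    # metadata, mirroring set_forward_context(attn_metadata, vllm_config).
    with set_forward_context(attn_metadata, vllm_config):
        # attn.forward(...) goes here, exactly as in the
        # _run_*_attention_test helpers above.
        ...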