@@ -1,17 +1,24 @@
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
 import pytest
 import torch
 
 from tests.kernels.utils import override_backend_env_variable
-from vllm.attention.selector import which_attn_to_use
+from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
 from vllm.platforms.cpu import CpuPlatform
 from vllm.platforms.cuda import CudaPlatform
 from vllm.platforms.openvino import OpenVinoPlatform
 from vllm.platforms.rocm import RocmPlatform
 from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
 
 
+@pytest.fixture(autouse=True)
+def clear_cache():
+    """Clear lru cache to ensure each test case runs without caching.
+    """
+    _cached_get_attn_backend.cache_clear()
+
+
 @pytest.mark.parametrize(
     "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
 @pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
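A note on the new autouse fixture above: `get_attn_backend` memoizes its result via `functools.lru_cache` (exposed here as `_cached_get_attn_backend`, whose `cache_clear()` the fixture calls), so a backend selected in one parametrized case would otherwise be replayed in the next. Below is a minimal, self-contained sketch of that pattern; `select_backend` is a hypothetical stand-in, not vLLM code.

from functools import lru_cache

import pytest


@lru_cache(maxsize=None)
def select_backend(device: str) -> str:
    # Stand-in for an expensive, environment-dependent decision whose
    # cached result would otherwise leak across test cases.
    return f"backend-for-{device}"


@pytest.fixture(autouse=True)
def clear_cache():
    # Reset the memoized result so every test starts from a cold cache.
    select_backend.cache_clear()


def test_selection_is_fresh():
    assert select_backend("cpu") == "backend-for-cpu"
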
@@ -24,67 +31,70 @@ def test_env(name: str, device: str, monkeypatch):
 
     if device == "cpu":
         with patch("vllm.attention.selector.current_platform", CpuPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        assert backend.name == "TORCH_SDPA"
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.get_name() == "TORCH_SDPA"
     elif device == "hip":
         with patch("vllm.attention.selector.current_platform", RocmPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        assert backend.name == "ROCM_FLASH"
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.get_name() == "ROCM_FLASH"
     elif device == "openvino":
         with patch("vllm.attention.selector.current_platform",
-                   OpenVinoPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        assert backend.name == "OPENVINO"
+                   OpenVinoPlatform()), patch.dict('sys.modules',
+                                                   {'openvino': Mock()}):
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.get_name() == "OPENVINO"
     else:
-        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        assert backend.name == name
+        if name in ["XFORMERS", "FLASHINFER"]:
+            with patch("vllm.attention.selector.current_platform",
+                       CudaPlatform()):
+                backend = get_attn_backend(16, torch.float16, torch.float16,
+                                           16, False)
+            assert backend.get_name() == name
 
 
 def test_flash_attn(monkeypatch):
     """Test FlashAttn validation."""
     # TODO: When testing for v1, pipe in `use_v1` as an argument to
-    # which_attn_to_use
+    # get_attn_backend
 
     override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)
 
     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
-        backend = which_attn_to_use(16, torch.float16, None, 16, False)
-        assert backend.name != STR_FLASH_ATTN_VAL
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported data type
-    backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported kv cache data type
-    backend = which_attn_to_use(16, torch.float16, "fp8", 16, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported block size
-    backend = which_attn_to_use(16, torch.float16, None, 8, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(16, torch.float16, None, 8, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # flash-attn is not installed
     with patch.dict('sys.modules', {'vllm_flash_attn': None}):
-        backend = which_attn_to_use(16, torch.float16, None, 16, False)
-        assert backend.name != STR_FLASH_ATTN_VAL
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported head size
-    backend = which_attn_to_use(17, torch.float16, None, 16, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(17, torch.float16, None, 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Attention-free models should bypass env and use PlaceholderAttention
-    backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
 
 def test_invalid_env(monkeypatch):
     """Throw an exception if the backend name is invalid."""
     override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
     with pytest.raises(ValueError):
-        which_attn_to_use(16, torch.float16, None, 16, False)
+        get_attn_backend(16, torch.float16, None, 16, False)
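The two `patch.dict('sys.modules', ...)` idioms in this diff are worth calling out: mapping a module name to a `Mock()` lets `import` succeed without the package being installed (the OpenVINO branch of `test_env`), while mapping it to `None` makes `import` raise `ImportError` (how `test_flash_attn` simulates a missing `vllm_flash_attn`). A minimal sketch of both, using a hypothetical module name `fake_pkg`:

import importlib
from unittest.mock import Mock, patch

import pytest


def test_module_pretend_installed():
    # A Mock placed in sys.modules satisfies the import machinery directly,
    # so no real package is needed.
    with patch.dict('sys.modules', {'fake_pkg': Mock()}):
        import fake_pkg
        assert fake_pkg is not None


def test_module_pretend_missing():
    # None in sys.modules makes any import of that name fail with
    # ImportError, simulating an uninstalled package.
    with patch.dict('sys.modules', {'fake_pkg': None}):
        with pytest.raises(ImportError):
            importlib.import_module('fake_pkg')

Because `patch.dict` restores the original `sys.modules` on exit, neither fake leaks into other tests.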