diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index a020b0d276be..685b0055f51f 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -883,11 +883,16 @@ steps: - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/v1/attention/backends/flashinfer.py + - vllm/v1/attention/backends/mla/cutlass_mla.py + - vllm/v1/attention/backends/mla/flashinfer_mla.py + - vllm/platforms/cuda.py + - vllm/attention/selector.py commands: - nvidia-smi - python3 examples/offline_inference/basic/chat.py # Attention # num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353 + - pytest -v -s tests/kernels/attention/test_attention_selector.py - pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2' - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py - pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index fecb1e2e918f..ea61c94953a7 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -10,7 +10,7 @@ from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.attention import Attention, AttentionMetadata -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fx_utils import find_op_nodes @@ -104,7 +104,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: # TODO(luka) use get_kv_cache_stride_order # Create dummy KV cache for the selected backend - if backend == _Backend.ROCM_ATTN: + if backend == AttentionBackendEnum.ROCM_ATTN: # k/v as 1st dimention # HND: [num_blocks, num_kv_heads, block_size, head_size] kv_cache = torch.zeros( @@ -116,7 +116,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: dtype=self.kv_cache_dtype, device=self.device, ) - elif backend == _Backend.ROCM_AITER_UNIFIED_ATTN: + elif backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN: # k/v as 1st dimention # NHD: [num_blocks, block_size, num_kv_heads, head_size] kv_cache = torch.zeros( @@ -128,7 +128,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: dtype=self.kv_cache_dtype, device=self.device, ) - elif backend == _Backend.TRITON_ATTN: + elif backend == AttentionBackendEnum.TRITON_ATTN: # k/v as 2nd dimention # NHD: [num_blocks, block_size, num_kv_heads, head_size] kv_cache = torch.zeros( @@ -140,7 +140,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: dtype=self.kv_cache_dtype, device=self.device, ) - elif backend == _Backend.FLASHINFER: + elif backend == AttentionBackendEnum.FLASHINFER: kv_cache = torch.zeros( num_blocks, 2, @@ -244,8 +244,8 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): MODELS_FP4: list[tuple[str, type]] = [] HEADS: list[tuple[int, int]] = [] SPLIT_ATTENTION: list[bool] = [] -BACKENDS_FP8: list[_Backend] = [] -BACKENDS_FP4: list[_Backend] = [] +BACKENDS_FP8: list[AttentionBackendEnum] = [] +BACKENDS_FP4: list[AttentionBackendEnum] = [] if current_platform.is_cuda(): HEADS = [(64, 8), (40, 8)] @@ -261,8 +261,8 @@ def forward(self, q: 
torch.Tensor, k: torch.Tensor, v: torch.Tensor): TestAttentionNvfp4QuantPatternModel, ) ] - BACKENDS_FP8 = [_Backend.TRITON_ATTN, _Backend.FLASHINFER] - BACKENDS_FP4 = [_Backend.FLASHINFER] + BACKENDS_FP8 = [AttentionBackendEnum.TRITON_ATTN, AttentionBackendEnum.FLASHINFER] + BACKENDS_FP4 = [AttentionBackendEnum.FLASHINFER] elif current_platform.is_rocm(): HEADS = [(32, 8), (40, 8)] @@ -270,9 +270,9 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel) ] BACKENDS = [ - _Backend.ROCM_AITER_UNIFIED_ATTN, - _Backend.ROCM_ATTN, - _Backend.TRITON_ATTN, + AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN, + AttentionBackendEnum.ROCM_ATTN, + AttentionBackendEnum.TRITON_ATTN, ] @@ -302,11 +302,11 @@ def test_attention_quant_pattern( custom_ops: str, model_name: str, model_class: type[AttentionQuantPatternModel], - backend: _Backend, + backend: AttentionBackendEnum, dist_init, ): """Test AttentionStaticQuantPattern fusion pass""" - if backend == _Backend.FLASHINFER and ( + if backend == AttentionBackendEnum.FLASHINFER and ( not current_platform.is_device_capability((10, 0)) or not has_flashinfer() ): pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") @@ -314,6 +314,7 @@ def test_attention_quant_pattern( custom_ops_list = custom_ops.split(",") if custom_ops else [] device = torch.device("cuda:0") + torch.set_default_dtype(dtype) torch.manual_seed(42) vllm_config = VllmConfig( @@ -402,7 +403,7 @@ def test_attention_quant_pattern( result_fused_1 = model_compiled(q, k, v) - if backend == _Backend.FLASHINFER: + if backend == AttentionBackendEnum.FLASHINFER: # With the Flashinfer backend after the 1st round of the forward # pass, output quant scale should be loaded into the attn layer's # _o_scale_float, the 2nd round should reuse the loaded diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index d66c60ccb5b2..bbc545348c67 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -11,7 +11,7 @@ import pytest import regex as re -from tests.v1.attention.utils import _Backend +from tests.v1.attention.utils import AttentionBackendEnum from vllm import LLM, SamplingParams from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig from vllm.platforms import current_platform @@ -24,7 +24,7 @@ class ModelBackendTestCase(NamedTuple): model_name: str model_kwargs: dict[str, Any] - backend: _Backend + backend: AttentionBackendEnum attention_fusions: int allreduce_fusions: int | None = None @@ -39,14 +39,14 @@ class ModelBackendTestCase(NamedTuple): # Use smaller model for L40s in CI model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", model_kwargs=dict(max_model_len=1024), - backend=_Backend.TRITON_ATTN, + backend=AttentionBackendEnum.TRITON_ATTN, attention_fusions=32, allreduce_fusions=65, ), ModelBackendTestCase( model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), - backend=_Backend.FLASHINFER, + backend=AttentionBackendEnum.FLASHINFER, attention_fusions=48, allreduce_fusions=96, ), @@ -56,7 +56,7 @@ class ModelBackendTestCase(NamedTuple): ModelBackendTestCase( model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), - backend=_Backend.FLASHINFER, + backend=AttentionBackendEnum.FLASHINFER, attention_fusions=48, allreduce_fusions=96, ), @@ -67,7 +67,7 @@ class 
ModelBackendTestCase(NamedTuple): ModelBackendTestCase( model_name="meta-llama/Llama-3.1-8B-Instruct", model_kwargs=dict(max_model_len=1024), - backend=_Backend.TRITON_ATTN, + backend=AttentionBackendEnum.TRITON_ATTN, attention_fusions=0, allreduce_fusions=65, ), @@ -78,19 +78,19 @@ class ModelBackendTestCase(NamedTuple): ModelBackendTestCase( model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_kwargs=dict(max_model_len=1024), - backend=_Backend.TRITON_ATTN, + backend=AttentionBackendEnum.TRITON_ATTN, attention_fusions=32, ), ModelBackendTestCase( model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_kwargs=dict(max_model_len=1024), - backend=_Backend.ROCM_ATTN, + backend=AttentionBackendEnum.ROCM_ATTN, attention_fusions=32, ), ModelBackendTestCase( model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_kwargs=dict(max_model_len=1024), - backend=_Backend.ROCM_AITER_UNIFIED_ATTN, + backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN, attention_fusions=32, ), ] @@ -111,7 +111,7 @@ class ModelBackendTestCase(NamedTuple): def test_attn_quant( model_name: str, model_kwargs: dict[str, Any], - backend: _Backend, + backend: AttentionBackendEnum, attention_fusions: int, allreduce_fusions: int, custom_ops: str, @@ -119,7 +119,7 @@ def test_attn_quant( caplog_mp_spawn, monkeypatch, ): - if backend == _Backend.FLASHINFER and ( + if backend == AttentionBackendEnum.FLASHINFER and ( not current_platform.is_device_capability((10, 0)) or not has_flashinfer() ): pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") @@ -203,7 +203,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: def test_tp2_attn_quant_allreduce_rmsnorm( model_name: str, model_kwargs: dict, - backend: _Backend, + backend: AttentionBackendEnum, attention_fusions: int, allreduce_fusions: int, custom_ops: str, diff --git a/tests/config/test_multimodal_config.py b/tests/config/test_multimodal_config.py index b1a09d88ed9d..3d02893e52f1 100644 --- a/tests/config/test_multimodal_config.py +++ b/tests/config/test_multimodal_config.py @@ -3,13 +3,13 @@ import pytest -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config.multimodal import MultiModalConfig def test_mm_encoder_attn_backend_str_conversion(): config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN") - assert config.mm_encoder_attn_backend == _Backend.FLASH_ATTN + assert config.mm_encoder_attn_backend == AttentionBackendEnum.FLASH_ATTN def test_mm_encoder_attn_backend_invalid(): @@ -20,6 +20,6 @@ def test_mm_encoder_attn_backend_invalid(): def test_mm_encoder_attn_backend_hash_updates(): base_hash = MultiModalConfig().compute_hash() overridden_hash = MultiModalConfig( - mm_encoder_attn_backend=_Backend.FLASH_ATTN + mm_encoder_attn_backend=AttentionBackendEnum.FLASH_ATTN ).compute_hash() assert base_hash != overridden_hash diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index 48a42ce6ffab..1e7873b7f145 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -127,12 +127,13 @@ def test_env( elif device == "cuda": with patch("vllm.platforms.current_platform", CudaPlatform()): + capability = torch.cuda.get_device_capability() if use_mla: # CUDA MLA backend logic: # - CUTLASS_MLA: only supported with block_size == 128 - # and Blackwell GPUs (SM 10.0), V1 only + # and Blackwell GPUs (SM 10.x), V1 only # - FLASHINFER_MLA: only 
supported on Blackwell GPUs - # (SM 10.0+), V1 only + # (SM 10.x), V1 only # - FLASHMLA: only supported with block_size == 64 # - FLASH_ATTN_MLA: V1 only # - TRITON_MLA: fallback for other cases @@ -141,58 +142,72 @@ def test_env( if block_size != 128: # CUTLASS_MLA only supports block_size == 128 pytest.skip("CUTLASS_MLA only supports block_size 128") - else: - backend = get_attn_backend( - 16, torch.float16, None, block_size, use_mla=use_mla - ) - expected = "CUTLASS_MLA" - assert backend.get_name() == expected + if capability[0] != 10: + pytest.skip("CUTLASS MLA is not supported on this platform") + backend = get_attn_backend( + 576, torch.float16, None, block_size, use_mla=use_mla + ) + expected = "CUTLASS_MLA" + assert backend.get_name() == expected elif name == "FLASHINFER_MLA": + if capability[0] != 10: + pytest.skip( + "FlashInfer MLA is not supported on this platform" + ) if block_size not in [32, 64]: # FlashInfer MLA only supports block_size 32 or 64 pytest.skip( "FlashInfer MLA only supports block_size 32 or 64" ) - else: - backend = get_attn_backend( - 16, torch.float16, None, block_size, use_mla=use_mla - ) - expected = "FLASHINFER_MLA" - assert backend.get_name() == expected + backend = get_attn_backend( + 576, torch.float16, None, block_size, use_mla=use_mla + ) + expected = "FLASHINFER_MLA" + assert backend.get_name() == expected elif name == "FLASHMLA": if block_size != 64: # FlashMLA only supports block_size == 64 pytest.skip("FlashMLA only supports block_size 64") - else: - from vllm.v1.attention.backends.mla.flashmla import ( - is_flashmla_dense_supported, - ) + from vllm.v1.attention.backends.mla.flashmla import ( + is_flashmla_dense_supported, + ) - is_supported, _ = is_flashmla_dense_supported() - if not is_supported: - pytest.skip("FlashMLA not supported on this platform") - else: - backend = get_attn_backend( - 16, torch.float16, None, block_size, use_mla=use_mla - ) - expected = name - assert backend.get_name() == expected + is_supported, _ = is_flashmla_dense_supported() + if not is_supported: + pytest.skip("FlashMLA not supported on this platform") + backend = get_attn_backend( + 576, + torch.float16, + None, + block_size, + use_mla=use_mla, + ) + expected = name + assert backend.get_name() == expected elif name == "FLASH_ATTN_MLA": + from vllm.attention.utils.fa_utils import ( + flash_attn_supports_mla, + ) + + if not flash_attn_supports_mla(): + pytest.skip( + "FlashAttention MLA not supported on this platform" + ) backend = get_attn_backend( - 16, torch.float16, None, block_size, use_mla=use_mla + 576, torch.float16, None, block_size, use_mla=use_mla ) expected = "FLASH_ATTN_MLA" assert backend.get_name() == expected else: # TRITON_MLA or other fallback backend = get_attn_backend( - 16, torch.float16, None, block_size, use_mla=use_mla + 576, torch.float16, None, block_size, use_mla=use_mla ) expected = "TRITON_MLA" assert backend.get_name() == expected elif name == "FLASHINFER": backend = get_attn_backend( - 16, torch.float16, None, block_size, use_mla=use_mla + 64, torch.float16, None, block_size, use_mla=use_mla ) expected = "FLASHINFER" assert backend.get_name() == expected diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py index 14d1618bca3c..183bbf3bf4e0 100644 --- a/tests/kernels/attention/test_mha_attn.py +++ b/tests/kernels/attention/test_mha_attn.py @@ -11,7 +11,7 @@ import pytest import torch -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import 
AttentionBackendEnum from vllm.attention.layer import MultiHeadAttention from vllm.attention.selector import _cached_get_attn_backend from vllm.platforms import current_platform @@ -43,14 +43,14 @@ def test_mha_attn_platform(device: str): patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()), ): attn = MultiHeadAttention(16, 64, scale=1) - assert attn.attn_backend == _Backend.TORCH_SDPA + assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA elif device == "hip": with ( patch("vllm.attention.layer.current_platform", RocmPlatform()), patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()), ): attn = MultiHeadAttention(16, 64, scale=1) - assert attn.attn_backend == _Backend.TORCH_SDPA + assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA else: # Test CUDA with head_size=64 (divisible by 32) # - should use vLLM's FlashAttention @@ -59,7 +59,7 @@ def test_mha_attn_platform(device: str): patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()), ): attn = MultiHeadAttention(16, 64, scale=1) - assert attn.attn_backend == _Backend.FLASH_ATTN + assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN # Test CUDA with head_size=72 (not divisible by 32) # - with upstream FA not available @@ -73,7 +73,7 @@ def test_mha_attn_platform(device: str): ), ): attn = MultiHeadAttention(16, 72, scale=1) - assert attn.attn_backend == _Backend.XFORMERS + assert attn.attn_backend == AttentionBackendEnum.XFORMERS # Test CUDA with head_size=72 (not divisible by 32) # - with upstream FA available @@ -96,7 +96,7 @@ def test_mha_attn_platform(device: str): ), ): attn = MultiHeadAttention(16, 72, scale=1) - assert attn.attn_backend == _Backend.FLASH_ATTN + assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN def ref_attention( diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index eb00bc72b4b0..1befd644202d 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -15,7 +15,7 @@ from tests.kernels.quant_utils import native_w8a8_block_matmul from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.utils import ( @@ -878,7 +878,7 @@ def make_block_tables_slot_mapping( def make_test_metadata( - attn_backend: _Backend, + attn_backend: AttentionBackendEnum, is_prompt: bool, seq_lens: list[int] | None, decoder_test_params: PhaseTestParameters | None, diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 6659b3eb1e98..b83105428289 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -15,7 +15,7 @@ create_vllm_config, try_get_attention_backend, ) -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ModelConfig from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv @@ -27,11 +27,11 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec BACKENDS_TO_TEST = [ - _Backend.FLASH_ATTN, - _Backend.FLASHINFER, - _Backend.FLEX_ATTENTION, - _Backend.TRITON_ATTN, - _Backend.TREE_ATTN, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.FLASHINFER, + AttentionBackendEnum.FLEX_ATTENTION, + 
AttentionBackendEnum.TRITON_ATTN, + AttentionBackendEnum.TREE_ATTN, "FLEX_ATTENTION_SLOW", ] @@ -39,7 +39,7 @@ try: import flashinfer # noqa: F401 except ImportError: - BACKENDS_TO_TEST.remove(_Backend.FLASHINFER) + BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHINFER) def _convert_dtype_to_torch(dtype): @@ -192,7 +192,7 @@ def __init__(self, device: torch.device): def run_attention_backend( - backend: _Backend, + backend: AttentionBackendEnum, kv_cache_spec: FullAttentionSpec, layer_names: list[str], vllm_config, @@ -211,13 +211,13 @@ def run_attention_backend( use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0") if backend == "FLEX_ATTENTION_SLOW": - actual_backend = _Backend.FLEX_ATTENTION + actual_backend = AttentionBackendEnum.FLEX_ATTENTION use_direct_block_mask = False builder_cls, impl_cls = try_get_attention_backend(actual_backend) # Mock flashinfer's get_per_layer_parameters if needed - if actual_backend == _Backend.FLASHINFER: + if actual_backend == AttentionBackendEnum.FLASHINFER: import unittest.mock from vllm.v1.attention.backends.utils import PerLayerParameters @@ -246,7 +246,7 @@ def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls): else: # Build metadata builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) - if actual_backend == _Backend.FLEX_ATTENTION: + if actual_backend == AttentionBackendEnum.FLEX_ATTENTION: builder.direct_build = use_direct_block_mask attn_metadata = builder.build( common_prefix_len=0, @@ -289,7 +289,7 @@ def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls): def _test_backend_correctness( batch_spec: BatchSpec, model: str, - backend_to_test: list[_Backend | str], + backend_to_test: list[AttentionBackendEnum | str], mask_mod, *, block_size: int = 16, @@ -429,17 +429,20 @@ def _test_backend_correctness( # Select the appropriate KV cache format for each backend kv_cache_for_backend = kv_cache reset_kv_cache_layout = False - if backend_name in (_Backend.FLASHINFER, _Backend.TRITON_ATTN): + if backend_name in ( + AttentionBackendEnum.FLASHINFER, + AttentionBackendEnum.TRITON_ATTN, + ): kv_cache_for_backend = kv_cache.transpose(0, 1) - if backend_name == _Backend.FLASHINFER: + if backend_name == AttentionBackendEnum.FLASHINFER: # For FlashInfer default to HND layout and kv_cache_for_backend = ( kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3) ) set_kv_cache_layout("HND") reset_kv_cache_layout = True - elif backend_name == _Backend.TRITON_ATTN: + elif backend_name == AttentionBackendEnum.TRITON_ATTN: kv_cache_for_backend = kv_cache_for_backend.contiguous() try: @@ -518,7 +521,9 @@ def causal_mask_mod( batch_spec = BATCH_SPECS[batch_spec_name] LARGE_BLOCK_BACKENDS = ( - [_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else [] + [AttentionBackendEnum.FLEX_ATTENTION] + if is_torch_equal_or_newer("2.9.0.dev0") + else [] ) SMALL_BLOCK_BACKENDS = [ x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS @@ -533,9 +538,9 @@ def causal_mask_mod( SLIDING_WINDOW_BACKENDS_TO_TEST = [ - _Backend.FLASH_ATTN, - _Backend.FLEX_ATTENTION, - _Backend.TRITON_ATTN, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.FLEX_ATTENTION, + AttentionBackendEnum.TRITON_ATTN, "FLEX_ATTENTION_SLOW", ] @@ -569,7 +574,9 @@ def sliding_window_mask_mod( ) LARGE_BLOCK_BACKENDS = ( - [_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else [] + [AttentionBackendEnum.FLEX_ATTENTION] + if is_torch_equal_or_newer("2.9.0.dev0") + else [] ) SMALL_BLOCK_BACKENDS = [ x for x in 
SLIDING_WINDOW_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index cda4fb11c096..f02798fb78e7 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -18,12 +18,11 @@ try_get_attention_backend, ) from vllm import _custom_ops as ops -from vllm.attention.backends.registry import _Backend, backend_to_class_str +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.ops.flashmla import is_flashmla_dense_supported from vllm.attention.utils.fa_utils import flash_attn_supports_mla from vllm.config.vllm import set_current_vllm_config from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.attention.backends.mla.common import QueryLenSupport @@ -31,25 +30,25 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec BACKENDS_TO_TEST = [ - _Backend.CUTLASS_MLA, - _Backend.FLASHMLA, - _Backend.FLASH_ATTN_MLA, - _Backend.FLASHINFER_MLA, - _Backend.TRITON_MLA, + AttentionBackendEnum.CUTLASS_MLA, + AttentionBackendEnum.FLASHMLA, + AttentionBackendEnum.FLASH_ATTN_MLA, + AttentionBackendEnum.FLASHINFER_MLA, + AttentionBackendEnum.TRITON_MLA, ] # Remove sm100 backends from the list if not using sm100 if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10: - BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA) - BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_MLA) + BACKENDS_TO_TEST.remove(AttentionBackendEnum.CUTLASS_MLA) + BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHINFER_MLA) # Remove FLASH_ATTN_MLA from the list if not supported if not flash_attn_supports_mla(): - BACKENDS_TO_TEST.remove(_Backend.FLASH_ATTN_MLA) + BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASH_ATTN_MLA) # Remove FLASHMLA from the list if not supported if not is_flashmla_dense_supported()[0]: - BACKENDS_TO_TEST.remove(_Backend.FLASHMLA) + BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHMLA) SPEC_DECODE_BACKENDS = [] for backend in BACKENDS_TO_TEST: @@ -62,9 +61,7 @@ BACKEND_BLOCK_SIZES = {} for backend in BACKENDS_TO_TEST: - backend_class_str = backend_to_class_str(backend) - backend_class = resolve_obj_by_qualname(backend_class_str) - supported_sizes = backend_class.get_supported_kernel_block_size() + supported_sizes = backend.get_class().supported_kernel_block_sizes if supported_sizes: default_size = supported_sizes[0] block_size = ( @@ -291,7 +288,7 @@ def get_kv_cache_spec(self, vllm_config): def run_attention_backend( - backend: _Backend, + backend: AttentionBackendEnum, kv_cache_spec: FullAttentionSpec, layer_names: list[str], vllm_config, diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index b166d9d4ff68..dea89babd4b4 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -8,7 +8,7 @@ import torch from vllm.attention.backends.abstract import AttentionImpl -from vllm.attention.backends.registry import _Backend, backend_to_class_str +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ( CacheConfig, CompilationConfig, @@ -20,7 +20,6 @@ VllmConfig, ) from vllm.config.model import ModelDType -from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, @@ -120,15 
+119,14 @@ def create_common_attn_metadata( def try_get_attention_backend( - backend: _Backend, + backend: AttentionBackendEnum, ) -> tuple[type[AttentionMetadataBuilder], type[AttentionImpl]]: """Try to get the attention backend class, skipping test if not found.""" - backend_class_str = backend_to_class_str(backend) try: - backend_class = resolve_obj_by_qualname(backend_class_str) + backend_class = backend.get_class() return backend_class.get_builder_cls(), backend_class.get_impl_cls() except ImportError as e: - pytest.skip(f"{backend_class_str} not available: {e}") + pytest.skip(f"{backend.name} not available: {e}") raise AssertionError("unreachable") from None diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 47d05a20a65d..89d0ec769ac0 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -13,7 +13,7 @@ create_standard_kv_cache_spec, try_get_attention_backend, ) -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ( CacheConfig, DeviceConfig, @@ -534,11 +534,17 @@ def create_deterministic_logits(token_ids): sampling_metadata = mock.MagicMock() if attn_backend == "FLASH_ATTN": - attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.FLASH_ATTN + ) elif attn_backend == "TRITON_ATTN": - attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TRITON_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.TRITON_ATTN + ) elif attn_backend == "TREE_ATTN": - attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.TREE_ATTN + ) else: raise ValueError(f"Unsupported attention backend: {attn_backend}") @@ -673,7 +679,9 @@ def create_deterministic_logits(token_ids, k: int): proposer.attn_layer_names = ["layer.0"] # Get the tree attention metadata builder. 
- attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.TREE_ATTN + ) attn_metadata_builder = attn_metadata_builder_cls( kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), layer_names=proposer.attn_layer_names, diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py index 9ca7cf9e3e0e..6d59b58e739e 100644 --- a/tests/v1/spec_decode/test_mtp.py +++ b/tests/v1/spec_decode/test_mtp.py @@ -12,7 +12,7 @@ create_standard_kv_cache_spec, try_get_attention_backend, ) -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ( CacheConfig, DeviceConfig, @@ -177,7 +177,9 @@ def create_deterministic_logits(batch_size, vocab_size, token_offset): sampling_metadata = mock.MagicMock() # Setup attention metadata - attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.FLASH_ATTN + ) attn_metadata_builder = attn_metadata_builder_cls( kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py index b365e75d5514..6958d62dc7e9 100644 --- a/tests/v1/spec_decode/test_tree_attention.py +++ b/tests/v1/spec_decode/test_tree_attention.py @@ -10,7 +10,7 @@ create_vllm_config, try_get_attention_backend, ) -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ParallelConfig, SpeculativeConfig from vllm.v1.attention.backends.utils import CommonAttentionMetadata @@ -35,7 +35,7 @@ def forward_attention( block_table: torch.Tensor, slot_mapping: torch.Tensor, seqlen_k: int, - backend: _Backend, + backend: AttentionBackendEnum, spec_token_tree: str | None = None, num_spec_tokens: int = 0, ) -> torch.Tensor: @@ -241,7 +241,7 @@ def test_tree_attn_correctness() -> None: block_table=block_table, slot_mapping=tree_slot_mapping, seqlen_k=seqlen_k, - backend=_Backend.TREE_ATTN, + backend=AttentionBackendEnum.TREE_ATTN, spec_token_tree=spec_token_tree, num_spec_tokens=tree_size_q - 1, ).view(batch_size, -1, num_heads, dim_per_head) @@ -278,7 +278,7 @@ def test_tree_attn_correctness() -> None: block_table=block_table, slot_mapping=branch_slot_mapping, seqlen_k=sequence_position + q_len, - backend=_Backend.FLASH_ATTN, + backend=AttentionBackendEnum.FLASH_ATTN, ).view(batch_size, -1, num_heads, dim_per_head) # Compare the outputs. diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index db0215511d32..b6bf9e6add81 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -185,9 +185,7 @@ def _make_mock_backend_for_kernel_block_size( supported_sizes: list[int | MultipleOf], ): class _MockBackend: - @staticmethod - def get_supported_kernel_block_size(): - return supported_sizes + supported_kernel_block_sizes = supported_sizes return _MockBackend() @@ -466,13 +464,20 @@ def test_kv_cache_stride_order(monkeypatch, model_runner): # This test checks if GPUModelRunner initializes correctly when an attention # backend enforces a non-default KV cache stride order. 
n_heads = model_runner.model_config.get_num_kv_heads(model_runner.parallel_config) - expected_kv_cache_shape = [ - 2, - NUM_BLOCKS, - BLOCK_SIZE, - n_heads, - model_runner.model_config.get_head_size(), - ] + head_size = model_runner.model_config.get_head_size() + + # Get the expected shape from the backend's get_kv_cache_shape method + # to ensure compatibility with different backends (triton vs flexattention) + attn_backend = None + for attn_group in model_runner._attn_group_iterator(): + attn_backend = attn_group.backend + break + + assert attn_backend is not None, "No attention backend found" + expected_kv_cache_shape = list( + attn_backend.get_kv_cache_shape(NUM_BLOCKS, BLOCK_SIZE, n_heads, head_size) + ) + # TODO mla test default_stride = tuple(range(5)) # Permutation that gets you back to expected kv shape diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index e9c6a278a941..ebf12c302b4c 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -2,13 +2,18 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import Generic, Protocol, TypeVar +from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, cast, get_args import torch from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey +if TYPE_CHECKING: + from vllm.config.cache import BlockSize, CacheDType + from vllm.platforms.interface import DeviceCapability + from vllm.v1.attention.backends.utils import KVCacheLayoutType + class AttentionType: """ @@ -40,6 +45,9 @@ class AttentionBackend(ABC): # calling the custom op. When piecewise cudagraph is enabled, this # makes sure the output tensor is allocated inside the cudagraph. 
accept_output_buffer: bool = False + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(1)] + supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"] @staticmethod @abstractmethod @@ -56,10 +64,6 @@ def get_impl_cls() -> type["AttentionImpl"]: def get_metadata_cls() -> type["AttentionMetadata"]: raise NotImplementedError - @classmethod - def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: - return cls.get_impl_cls().get_supported_kernel_block_size() - @classmethod def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": return cls.get_metadata_cls()(*args, **kwargs) @@ -88,6 +92,159 @@ def get_kv_cache_stride_order() -> tuple[int, ...]: def full_cls_name(cls) -> tuple[str, str]: return (cls.__module__, cls.__qualname__) + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [] + + @classmethod + def supports_head_size(cls, head_size: int) -> bool: + supported_head_sizes = cls.get_supported_head_sizes() + return (not supported_head_sizes) or head_size in supported_head_sizes + + @classmethod + def supports_dtype(cls, dtype: torch.dtype) -> bool: + return dtype in cls.supported_dtypes + + @classmethod + def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool: + if kv_cache_dtype is None: + return True + return (not cls.supported_kv_cache_dtypes) or ( + kv_cache_dtype in cls.supported_kv_cache_dtypes + ) + + @classmethod + def supports_block_size(cls, block_size: int | None) -> bool: + from vllm.config.cache import BlockSize + + if block_size is None: + return True + + valid_sizes = get_args(BlockSize) + if block_size not in valid_sizes: + return False + + if not cls.supported_kernel_block_sizes: + return True + + for supported_size in cls.supported_kernel_block_sizes: + is_multiple_of = ( + isinstance(supported_size, MultipleOf) + and block_size % supported_size.base == 0 + ) + is_int_equal = ( + isinstance(supported_size, int) and block_size == supported_size + ) + if is_multiple_of or is_int_equal: + return True + return False + + @classmethod + def get_default_block_size(cls) -> "BlockSize": + from vllm.config.cache import BlockSize + + if not cls.supported_kernel_block_sizes: + raise ValueError( + f"Fallback failed, no explicitly supported block sizes for " + f"backend {cls.get_name()}" + ) + + block_size = cls.supported_kernel_block_sizes[0] + if isinstance(block_size, MultipleOf): + block_size = block_size.base + + valid_sizes = get_args(BlockSize) + if block_size not in valid_sizes: + raise ValueError( + f"Default block size {block_size} for backend {cls.get_name()} is not " + f"a valid BlockSize." 
+ ) + + return cast(BlockSize, block_size) + + @classmethod + def is_mla(cls) -> bool: + return False + + @classmethod + def supports_sink(cls) -> bool: + return False + + @classmethod + def is_sparse(cls) -> bool: + return False + + @classmethod + def supports_compute_capability(cls, capability: "DeviceCapability") -> bool: + return True + + @classmethod + def supports_combination( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: "CacheDType | None", + block_size: int | None, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + device_capability: "DeviceCapability", + ) -> str | None: + return None + + @classmethod + def validate_configuration( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: "CacheDType | None", + block_size: int | None, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + device_capability: "DeviceCapability", + ) -> list[str]: + invalid_reasons = [] + if not cls.supports_head_size(head_size): + invalid_reasons.append("head_size not supported") + if not cls.supports_dtype(dtype): + invalid_reasons.append("dtype not supported") + if not cls.supports_kv_cache_dtype(kv_cache_dtype): + invalid_reasons.append("kv_cache_dtype not supported") + if not cls.supports_block_size(block_size): + invalid_reasons.append("block_size not supported") + if use_mla != cls.is_mla(): + if use_mla: + invalid_reasons.append("MLA not supported") + else: + invalid_reasons.append("non-MLA not supported") + if has_sink and not cls.supports_sink(): + invalid_reasons.append("sink setting not supported") + if use_sparse != cls.is_sparse(): + if use_sparse: + invalid_reasons.append("sparse not supported") + else: + invalid_reasons.append("non-sparse not supported") + if not cls.supports_compute_capability(device_capability): + invalid_reasons.append("compute capability not supported") + combination_reason = cls.supports_combination( + head_size, + dtype, + kv_cache_dtype, + block_size, + use_mla, + has_sink, + use_sparse, + device_capability, + ) + if combination_reason is not None: + invalid_reasons.append(combination_reason) + return invalid_reasons + + @classmethod + def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None": + return None + class AttentionMetadata: pass @@ -160,11 +317,6 @@ def __init__( ) -> None: raise NotImplementedError - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - # TODO: implement this function for all backends. 
- return [MultipleOf(1)] - @abstractmethod def forward( self, diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index 05d0159d0861..4fb8c581890f 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -3,108 +3,155 @@ """Attention backend registry""" import enum +from collections.abc import Callable +from typing import TYPE_CHECKING, cast from vllm.utils.import_utils import resolve_obj_by_qualname +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend -class _Backend(enum.Enum): - FLASH_ATTN = enum.auto() - TRITON_ATTN = enum.auto() - XFORMERS = enum.auto() - ROCM_ATTN = enum.auto() - ROCM_AITER_MLA = enum.auto() - ROCM_AITER_FA = enum.auto() # used for ViT attn backend - TORCH_SDPA = enum.auto() - FLASHINFER = enum.auto() - FLASHINFER_MLA = enum.auto() - TRITON_MLA = enum.auto() - CUTLASS_MLA = enum.auto() - FLASHMLA = enum.auto() - FLASHMLA_SPARSE = enum.auto() - FLASH_ATTN_MLA = enum.auto() - PALLAS = enum.auto() - IPEX = enum.auto() - NO_ATTENTION = enum.auto() - FLEX_ATTENTION = enum.auto() - TREE_ATTN = enum.auto() - ROCM_AITER_UNIFIED_ATTN = enum.auto() - - -BACKEND_MAP = { - _Backend.FLASH_ATTN: "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend", # noqa: E501 - _Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", # noqa: E501 - _Backend.XFORMERS: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend", # noqa: E501 - _Backend.ROCM_ATTN: "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend", # noqa: E501 - _Backend.ROCM_AITER_MLA: "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend", # noqa: E501 - _Backend.ROCM_AITER_FA: "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend", # noqa: E501 - _Backend.TORCH_SDPA: "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend", # noqa: E501 - _Backend.FLASHINFER: "vllm.v1.attention.backends.flashinfer.FlashInferBackend", # noqa: E501 - _Backend.FLASHINFER_MLA: "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend", # noqa: E501 - _Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", # noqa: E501 - _Backend.CUTLASS_MLA: "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", # noqa: E501 - _Backend.FLASHMLA: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", # noqa: E501 - _Backend.FLASHMLA_SPARSE: "vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend", # noqa: E501 - _Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend", # noqa: E501 - _Backend.PALLAS: "vllm.v1.attention.backends.pallas.PallasAttentionBackend", # noqa: E501 - _Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend", # noqa: E501 - _Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend", # noqa: E501 - _Backend.ROCM_AITER_UNIFIED_ATTN: "vllm.v1.attention.backends.rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend", # noqa: E501 -} - - -def register_attn_backend(backend: _Backend, class_path: str | None = None): - """ - Decorator: register a custom attention backend into BACKEND_MAPPING. - - If class_path is provided, use it. - - Otherwise, auto-generate from the class object. - Validation: only checks if 'backend' is a valid _Backend enum member. - Overwriting existing mappings is allowed. This enables other hardware - platforms to plug in custom out-of-tree backends. 
- """ - if not isinstance(backend, _Backend): - raise ValueError(f"{backend} is not a valid _Backend enum value.") - def decorator(cls): - path = class_path or f"{cls.__module__}.{cls.__qualname__}" - BACKEND_MAP[backend] = path - return cls +class _AttentionBackendEnumMeta(enum.EnumMeta): + """Metaclass for AttentionBackendEnum to provide better error messages.""" - return decorator + def __getitem__(cls, name: str): + """Get backend by name with helpful error messages.""" + try: + return super().__getitem__(name) + except KeyError: + members = cast("dict[str, AttentionBackendEnum]", cls.__members__).values() + valid_backends = ", ".join(m.name for m in members) + raise ValueError( + f"Unknown attention backend: '{name}'. " + f"Valid options are: {valid_backends}" + ) from None -def backend_to_class_str(backend: _Backend) -> str: - """Get the backend class string +class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta): + """Enumeration of all supported attention backends. - Args: - backend: The backend enum value + The enum value is the default class path, but this can be overridden + at runtime using register_backend(). - Returns: - The backend class string + To get the actual backend class (respecting overrides), use: + backend.get_class() """ - return BACKEND_MAP[backend] - -def backend_to_class(backend: _Backend) -> type: - """Get the backend class. + FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" + XFORMERS = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" + ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" + ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" + ROCM_AITER_FA = ( + "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend" + ) + TORCH_SDPA = "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend" + FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" + FLASHINFER_MLA = ( + "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend" + ) + TRITON_MLA = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend" + CUTLASS_MLA = "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend" + FLASHMLA = "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend" + FLASHMLA_SPARSE = ( + "vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend" + ) + FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend" + PALLAS = "vllm.v1.attention.backends.pallas.PallasAttentionBackend" + IPEX = "vllm.v1.attention.backends.ipex.IpexAttentionBackend" + NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend" + FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" + TREE_ATTN = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" + ROCM_AITER_UNIFIED_ATTN = ( + "vllm.v1.attention.backends.rocm_aiter_unified_attn." + "RocmAiterUnifiedAttentionBackend" + ) + # Placeholder for third-party/custom backends - must be registered before use + CUSTOM = "" + + def get_path(self) -> str: + """Get the class path for this backend (respects overrides). + + Returns: + The fully qualified class path string + + Raises: + ValueError: If Backend.CUSTOM is used without being registered + """ + path = _OVERRIDES.get(self, self.value) + if not path: + raise ValueError( + f"Backend {self.name} must be registered before use. 
" + f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')" + ) + return path + + def get_class(self) -> "type[AttentionBackend]": + """Get the backend class (respects overrides). + + Returns: + The backend class + + Raises: + ImportError: If the backend class cannot be imported + ValueError: If Backend.CUSTOM is used without being registered + """ + return resolve_obj_by_qualname(self.get_path()) + + def is_overridden(self) -> bool: + """Check if this backend has been overridden. + + Returns: + True if the backend has a registered override + """ + return self in _OVERRIDES + + def clear_override(self) -> None: + """Clear any override for this backend, reverting to the default.""" + _OVERRIDES.pop(self, None) + + +_OVERRIDES: dict[AttentionBackendEnum, str] = {} + + +def register_backend( + backend: AttentionBackendEnum, class_path: str | None = None +) -> Callable[[type], type]: + """Register or override a backend implementation. Args: - backend: The backend enum value + backend: The AttentionBackendEnum member to register + class_path: Optional class path. If not provided and used as + decorator, will be auto-generated from the class. Returns: - The backend class + Decorator function if class_path is None, otherwise a no-op + + Examples: + # Override an existing backend + @register_backend(AttentionBackendEnum.FLASH_ATTN) + class MyCustomFlashAttn: + ... + + # Register a custom third-party backend + @register_backend(AttentionBackendEnum.CUSTOM) + class MyCustomBackend: + ... + + # Direct registration + register_backend( + AttentionBackendEnum.CUSTOM, + "my.module.MyCustomBackend" + ) """ - backend_class_name = backend_to_class_str(backend) - return resolve_obj_by_qualname(backend_class_name) + def decorator(cls: type) -> type: + _OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}" + return cls -def backend_name_to_enum(backend_name: str) -> _Backend | None: - """ - Convert a string backend name to a _Backend enum value. + if class_path is not None: + _OVERRIDES[backend] = class_path + return lambda x: x - Returns: - _Backend: enum value if backend_name is a valid in-tree type - None: otherwise it's an invalid in-tree type or an out-of-tree platform - is loaded. 
- """ - assert backend_name is not None - return _Backend[backend_name] if backend_name in _Backend.__members__ else None + return decorator diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 17e025155a43..777207950aed 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -12,7 +12,7 @@ import vllm.envs as envs from vllm.attention import AttentionType from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl -from vllm.attention.backends.registry import _Backend, backend_name_to_enum +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import get_attn_backend from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.config import CacheConfig, get_current_vllm_config @@ -99,29 +99,30 @@ def check_upstream_fa_availability(dtype: torch.dtype): def maybe_get_vit_flash_attn_backend( - attn_backend: _Backend, + attn_backend: AttentionBackendEnum, use_upstream_fa: bool, - attn_backend_override: _Backend | None = None, -) -> tuple[_Backend, Callable | None]: + attn_backend_override: AttentionBackendEnum | None = None, +) -> tuple[AttentionBackendEnum, Callable | None]: if current_platform.is_rocm(): if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): - attn_backend = _Backend.ROCM_AITER_FA + attn_backend = AttentionBackendEnum.ROCM_AITER_FA elif ( check_upstream_fa_availability(torch.get_default_dtype()) and on_gfx9() and attn_backend_override is None ): - attn_backend = _Backend.FLASH_ATTN + attn_backend = AttentionBackendEnum.FLASH_ATTN use_upstream_fa = True else: - return _Backend.TORCH_SDPA, None + return AttentionBackendEnum.TORCH_SDPA, None elif current_platform.is_cuda(): - if attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - attn_backend = _Backend.FLASH_ATTN + attn_backend = AttentionBackendEnum.FLASH_ATTN use_upstream_fa = True elif current_platform.is_xpu(): assert attn_backend == _Backend.FLASH_ATTN, ( @@ -129,10 +130,13 @@ def maybe_get_vit_flash_attn_backend( ) use_upstream_fa = False else: - return _Backend.TORCH_SDPA, None + return AttentionBackendEnum.TORCH_SDPA, None - if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: - if attn_backend == _Backend.ROCM_AITER_FA: + if attn_backend in { + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, + }: + if attn_backend == AttentionBackendEnum.ROCM_AITER_FA: from aiter import flash_attn_varlen_func else: if use_upstream_fa: @@ -309,7 +313,7 @@ def __init__( kv_sharing_target_layer_name, **extra_impl_args, ) - self.backend = backend_name_to_enum(self.attn_backend.get_name()) + self.backend = AttentionBackendEnum[self.attn_backend.get_name()] self.dtype = dtype # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how @@ -530,13 +534,13 @@ def __init__( backend if backend in { - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.PALLAS, - _Backend.ROCM_AITER_FA, - _Backend.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.PALLAS, + AttentionBackendEnum.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, } - else _Backend.TORCH_SDPA + else AttentionBackendEnum.TORCH_SDPA ) self.attn_backend, self._flash_attn_varlen_func = ( @@ -547,17 +551,23 @@ def __init__( ) ) - if self.attn_backend == _Backend.XFORMERS and not 
check_xformers_availability(): - self.attn_backend = _Backend.TORCH_SDPA + if ( + self.attn_backend == AttentionBackendEnum.XFORMERS + and not check_xformers_availability() + ): + self.attn_backend = AttentionBackendEnum.TORCH_SDPA self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } # this condition is just to make sure that the # use_upstream_fa in the log is correct - if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN: + if ( + current_platform.is_rocm() + and self.attn_backend == AttentionBackendEnum.FLASH_ATTN + ): use_upstream_fa = True logger.info_once( @@ -606,17 +616,17 @@ def forward( max_seqlen_k=kv_len, softmax_scale=self.scale, ) - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: from xformers import ops as xops out = xops.memory_efficient_attention_forward( query, key, value, scale=self.scale ) - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: query, key, value = (x.transpose(1, 2) for x in (query, key, value)) out = F.scaled_dot_product_attention(query, key, value, scale=self.scale) out = out.transpose(1, 2) - elif self.attn_backend == _Backend.PALLAS: + elif self.attn_backend == AttentionBackendEnum.PALLAS: query, key, value = (x.transpose(1, 2) for x in (query, key, value)) from torch_xla.experimental.custom_kernel import flash_attention diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 9890d8d80cba..2cf0111348e4 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -4,14 +4,15 @@ import os from collections.abc import Generator from contextlib import contextmanager -from dataclasses import dataclass from functools import cache +from typing import cast, get_args import torch import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend -from vllm.attention.backends.registry import _Backend, backend_name_to_enum +from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.utils import STR_BACKEND_ENV_VAR from vllm.utils.import_utils import resolve_obj_by_qualname @@ -19,18 +20,18 @@ logger = init_logger(__name__) -def get_env_variable_attn_backend() -> _Backend | None: +def get_env_variable_attn_backend() -> AttentionBackendEnum | None: """ Get the backend override specified by the vLLM attention backend environment variable, if one is specified. Returns: - * _Backend enum value if an override is specified + * AttentionBackendEnum value if an override is specified * None otherwise """ backend_name = os.environ.get(STR_BACKEND_ENV_VAR) - return None if backend_name is None else backend_name_to_enum(backend_name) + return None if backend_name is None else AttentionBackendEnum[backend_name] # Global state allows a particular choice of backend @@ -40,10 +41,10 @@ def get_env_variable_attn_backend() -> _Backend | None: # # THIS SELECTION TAKES PRECEDENCE OVER THE # VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE -forced_attn_backend: _Backend | None = None +forced_attn_backend: AttentionBackendEnum | None = None -def global_force_attn_backend(attn_backend: _Backend | None) -> None: +def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None: """ Force all attention operations to use a specified backend. 
@@ -58,7 +59,7 @@ def global_force_attn_backend(attn_backend: _Backend | None) -> None: forced_attn_backend = attn_backend -def get_global_forced_attn_backend() -> _Backend | None: +def get_global_forced_attn_backend() -> AttentionBackendEnum | None: """ Get the currently-forced choice of attention backend, or None if auto-selection is currently enabled. @@ -66,69 +67,11 @@ def get_global_forced_attn_backend() -> _Backend | None: return forced_attn_backend -@dataclass(frozen=True) -class _IsSupported: - can_import: bool - head_size: bool - dtype: bool - - def __bool__(self) -> bool: - return self.can_import and self.head_size and self.dtype - - -def is_attn_backend_supported( - attn_backend: str | type[AttentionBackend], - head_size: int, - dtype: torch.dtype, - *, - allow_import_error: bool = True, -) -> _IsSupported: - if isinstance(attn_backend, str): - try: - attn_backend = resolve_obj_by_qualname(attn_backend) - except ImportError: - if not allow_import_error: - raise - - return _IsSupported(can_import=False, head_size=False, dtype=False) - - assert isinstance(attn_backend, type) - - # TODO: Update the interface once V0 is removed - if get_supported_head_sizes := getattr( - attn_backend, "get_supported_head_sizes", None - ): - is_head_size_supported = head_size in get_supported_head_sizes() - elif validate_head_size := getattr(attn_backend, "validate_head_size", None): - try: - validate_head_size(head_size) - is_head_size_supported = True - except Exception: - is_head_size_supported = False - else: - raise NotImplementedError( - f"{attn_backend.__name__} does not support head size validation" - ) - - if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes", None): - is_dtype_supported = dtype in get_supported_dtypes() - else: - raise NotImplementedError( - f"{attn_backend.__name__} does not support dtype validation" - ) - - return _IsSupported( - can_import=True, - head_size=is_head_size_supported, - dtype=is_dtype_supported, - ) - - def get_attn_backend( head_size: int, dtype: torch.dtype, kv_cache_dtype: str | None, - block_size: int, + block_size: int | None, use_mla: bool = False, has_sink: bool = False, use_sparse: bool = False, @@ -138,10 +81,19 @@ def get_attn_backend( # value to be returned from the cache if the value changes between calls. # To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the # private function. + + # Validate kv_cache_dtype is a valid CacheDType value + if kv_cache_dtype is not None: + valid_cache_dtypes = get_args(CacheDType) + assert kv_cache_dtype in valid_cache_dtypes, ( + f"Invalid kv_cache_dtype: {kv_cache_dtype}. " + f"Valid values are: {valid_cache_dtypes}" + ) + return _cached_get_attn_backend( head_size=head_size, dtype=dtype, - kv_cache_dtype=kv_cache_dtype, + kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype), block_size=block_size, use_v1=envs.VLLM_USE_V1, use_mla=use_mla, @@ -154,8 +106,8 @@ def get_attn_backend( def _cached_get_attn_backend( head_size: int, dtype: torch.dtype, - kv_cache_dtype: str | None, - block_size: int, + kv_cache_dtype: CacheDType | None, + block_size: int | None, use_v1: bool = False, use_mla: bool = False, has_sink: bool = False, @@ -167,7 +119,9 @@ def _cached_get_attn_backend( # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND # ENVIRONMENT VARIABLE. 
selected_backend = None - backend_by_global_setting: _Backend | None = get_global_forced_attn_backend() + backend_by_global_setting: AttentionBackendEnum | None = ( + get_global_forced_attn_backend() + ) if backend_by_global_setting is not None: selected_backend = backend_by_global_setting else: @@ -183,12 +137,13 @@ def _cached_get_attn_backend( STR_BACKEND_ENV_VAR, ) backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1") - selected_backend = backend_name_to_enum(backend_by_env_var) - if selected_backend is None: + try: + selected_backend = AttentionBackendEnum[backend_by_env_var] + except KeyError as e: raise ValueError( - f"Invalid attention backend: '{backend_by_env_var}'. " - f"Valid backends are: {list(_Backend.__members__.keys())}" - ) + f"Invalid attention backend: '{backend_by_env_var}'. Valid " + f"backends are: {list(AttentionBackendEnum.__members__.keys())}" + ) from e # get device-specific attn_backend from vllm.platforms import current_platform @@ -208,12 +163,26 @@ def _cached_get_attn_backend( raise ValueError( f"Invalid attention backend for {current_platform.device_name}" ) - return resolve_obj_by_qualname(attention_cls) + backend = resolve_obj_by_qualname(attention_cls) + + # Adjust kv cache layout if the selected backend requires a specific one + required_layout = backend.get_required_kv_cache_layout() + if required_layout is not None: + from vllm.v1.attention.backends.utils import set_kv_cache_layout + + set_kv_cache_layout(required_layout) + logger.info( + "Using %s KV cache layout for %s backend.", + required_layout, + backend.get_name(), + ) + + return backend @contextmanager def global_force_attn_backend_context_manager( - attn_backend: _Backend, + attn_backend: AttentionBackendEnum, ) -> Generator[None, None, None]: """ Globally force a vLLM attention backend override within a diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 031df3091f1c..864cf1be81b2 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -21,7 +21,15 @@ logger = init_logger(__name__) BlockSize = Literal[1, 8, 16, 32, 64, 128, 256] -CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] +CacheDType = Literal[ + "auto", + "bfloat16", + "fp8", + "fp8_e4m3", + "fp8_e5m2", + "fp8_inc", + "fp8_ds_mla", +] MambaDType = Literal["auto", "float32"] PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"] KVOffloadingBackend = Literal["native", "lmcache"] diff --git a/vllm/config/model.py b/vllm/config/model.py index 082f90653f5a..8bd54bb560ff 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -45,7 +45,7 @@ import vllm.model_executor.layers.quantization as me_quant import vllm.model_executor.models as me_models - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config.load import LoadConfig from vllm.config.parallel import ParallelConfig from vllm.model_executor.layers.quantization import QuantizationMethods @@ -53,7 +53,7 @@ else: PretrainedConfig = Any - _Backend = Any + AttentionBackendEnum = Any me_quant = LazyLoader( "model_executor", globals(), "vllm.model_executor.layers.quantization" ) @@ -308,7 +308,7 @@ class ModelConfig: mm_processor_cache_type: InitVar[MMCacheType | None] = None mm_shm_cache_max_object_size_mb: InitVar[int | None] = None mm_encoder_tp_mode: InitVar[MMEncoderTPMode | None] = None - mm_encoder_attn_backend: InitVar[_Backend | str | None] = None + mm_encoder_attn_backend: InitVar[AttentionBackendEnum | str | None] = None 
interleave_mm_strings: InitVar[bool | None] = None skip_mm_profiling: InitVar[bool | None] = None video_pruning_rate: InitVar[float | None] = None @@ -428,7 +428,7 @@ def __post_init__( mm_processor_cache_type: MMCacheType | None, mm_shm_cache_max_object_size_mb: int | None, mm_encoder_tp_mode: MMEncoderTPMode | None, - mm_encoder_attn_backend: _Backend | str | None, + mm_encoder_attn_backend: AttentionBackendEnum | str | None, interleave_mm_strings: bool | None, skip_mm_profiling: bool | None, video_pruning_rate: float | None, diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index ef73720efe09..9348c1b2af8c 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -11,9 +11,9 @@ from vllm.config.utils import config if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum else: - _Backend = Any + AttentionBackendEnum = Any @dataclass @@ -125,10 +125,10 @@ class MultiModalConfig: DP (which is controlled by `--data-parallel-size`). This is only supported on a per-model basis and falls back to `"weights"` if the encoder does not support DP.""" - mm_encoder_attn_backend: _Backend | None = None + mm_encoder_attn_backend: AttentionBackendEnum | None = None """Optional override for the multi-modal encoder attention backend when using vision transformers. Accepts any value from - `vllm.attention.backends.registry._Backend` (e.g. `FLASH_ATTN`).""" + `vllm.attention.backends.registry.AttentionBackendEnum` (e.g. `FLASH_ATTN`).""" interleave_mm_strings: bool = False """Enable fully interleaved support for multimodal prompts, while using --chat-template-content-format=string.""" @@ -167,26 +167,16 @@ def _validate_limit_per_prompt( @field_validator("mm_encoder_attn_backend", mode="before") @classmethod - def _validate_mm_encoder_attn_backend(cls, value: object) -> _Backend | None: - from vllm.attention.backends.registry import ( - _Backend as BackendEnum, - ) - from vllm.attention.backends.registry import ( - backend_name_to_enum, - ) - - if value is None or isinstance(value, BackendEnum): + def _validate_mm_encoder_attn_backend( + cls, value: str | AttentionBackendEnum | None + ) -> AttentionBackendEnum | None: + if value is None or isinstance(value, AttentionBackendEnum): return value - if isinstance(value, str): - candidate = backend_name_to_enum(value.upper()) - if candidate is not None: - return candidate - - valid_backends = ", ".join(sorted(BackendEnum.__members__.keys())) - raise ValueError( - f"Invalid mm encoder attention backend. Expected one of: {valid_backends}." + assert isinstance(value, str), ( + "mm_encoder_attn_backend must be a string or an AttentionBackendEnum." 
) + return AttentionBackendEnum[value.upper()] @model_validator(mode="after") def _validate_multimodal_config(self): diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 4651cedbc7df..4ce62917e32f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -21,7 +21,7 @@ import zmq from vllm import envs -from vllm.attention.backends.registry import _Backend, backend_name_to_enum +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( @@ -868,9 +868,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): use_mla=self.use_mla, ) self.backend_name = backend.get_name() - attn_backend = backend_name_to_enum(self.backend_name) - self._use_flashinfer = attn_backend == _Backend.FLASHINFER - self._use_pallas = attn_backend == _Backend.PALLAS + attn_backend = AttentionBackendEnum[self.backend_name] + self._use_flashinfer = attn_backend == AttentionBackendEnum.FLASHINFER + self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS self.kv_cache_layout = get_kv_cache_layout() self.host_buffer_kv_cache_layout = self.kv_cache_layout logger.debug("Detected attention backend %s", self.backend_name) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 66c75d944ec8..49e51538f324 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -32,7 +32,7 @@ from typing_extensions import TypeIs, deprecated import vllm.envs as envs -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ( CacheConfig, CompilationConfig, @@ -464,7 +464,7 @@ class EngineArgs: MultiModalConfig.mm_shm_cache_max_object_size_mb ) mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode - mm_encoder_attn_backend: _Backend | str | None = ( + mm_encoder_attn_backend: AttentionBackendEnum | str | None = ( MultiModalConfig.mm_encoder_attn_backend ) io_processor_plugin: str | None = None @@ -1745,32 +1745,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: "such as ngram, medusa, eagle, or mtp." ) - V1_BACKENDS = [ - "FLASH_ATTN", - "PALLAS", - "TRITON_ATTN", - "TRITON_MLA", - "CUTLASS_MLA", - "FLASHMLA", - "FLASH_ATTN_MLA", - "FLASHINFER", - "FLASHINFER_MLA", - "ROCM_AITER_MLA", - "TORCH_SDPA", - "FLEX_ATTENTION", - "TREE_ATTN", - "XFORMERS", - "ROCM_ATTN", - "ROCM_AITER_UNIFIED_ATTN", - ] - if ( - envs.is_set("VLLM_ATTENTION_BACKEND") - and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS - ): - name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}" - _raise_or_fallback(feature_name=name, recommend_to_remove=True) - return False - ############################################################# # Experimental Features - allow users to opt in. 
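For context on the mm_encoder_attn_backend validator above: string values are now resolved by uppercasing and indexing AttentionBackendEnum directly, so an unknown name raises KeyError rather than going through the removed backend_name_to_enum helper. Below is a minimal sketch of that resolution, assuming a hypothetical stand-in enum with only a few members (not the real vllm.attention.backends.registry):

import enum


class AttentionBackendEnum(enum.Enum):
    # Hypothetical subset of the real registry, for illustration only.
    FLASH_ATTN = enum.auto()
    TORCH_SDPA = enum.auto()
    XFORMERS = enum.auto()


def resolve_backend(value):
    # Mirrors the validator in this diff: enum members and None pass through,
    # strings are resolved case-insensitively by enum member name.
    if value is None or isinstance(value, AttentionBackendEnum):
        return value
    assert isinstance(value, str), "expected str, enum member, or None"
    return AttentionBackendEnum[value.upper()]


assert resolve_backend("flash_attn") is AttentionBackendEnum.FLASH_ATTN
assert resolve_backend(None) is None
# resolve_backend("NOT_A_BACKEND") raises KeyError, which callers surface as
# an invalid-backend configuration error.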
diff --git a/vllm/envs.py b/vllm/envs.py index 21237c70a45e..809ab46b2bf9 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -620,14 +620,14 @@ def get_vllm_port() -> int | None: # - "FLASH_ATTN_MLA": use FlashAttention for MLA # - "FLASHINFER_MLA": use FlashInfer for MLA # - "CUTLASS_MLA": use CUTLASS for MLA - # All possible options loaded dynamically from _Backend enum + # All possible options loaded dynamically from AttentionBackendEnum "VLLM_ATTENTION_BACKEND": env_with_choices( "VLLM_ATTENTION_BACKEND", None, lambda: list( __import__( - "vllm.attention.backends.registry", fromlist=["_Backend"] - )._Backend.__members__.keys() + "vllm.attention.backends.registry", fromlist=["AttentionBackendEnum"] + ).AttentionBackendEnum.__members__.keys() ), ), # If set, vllm will use flashinfer sampler diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py index 6d462ad8ae62..1b2bb60a17c1 100644 --- a/vllm/model_executor/models/dots_ocr.py +++ b/vllm/model_executor/models/dots_ocr.py @@ -9,7 +9,7 @@ from torch.nn import LayerNorm from transformers.models.qwen2_vl import Qwen2VLProcessor -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import ( check_upstream_fa_availability, maybe_get_vit_flash_attn_backend, @@ -256,7 +256,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -303,17 +303,17 @@ def __init__( ) ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"Unsupported vision attention backend: {self.attn_backend}" ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def forward( @@ -361,7 +361,7 @@ def forward( self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, ) - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: outputs = [] for i in range(1, len(cu_seqlens)): s = int(cu_seqlens[i - 1]) @@ -373,7 +373,7 @@ def forward( out_i = out_i.permute(0, 2, 1, 3) outputs.append(out_i) context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0] - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -514,7 +514,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() @@ -567,7 +567,7 @@ def __init__( require_post_norm: bool | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() self.config = config @@ -582,10 +582,11 @@ def __init__( dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if self.attn_backend != 
_Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN self.out_hidden_size = config.hidden_size # Keep blocks for compatibility with other vision towers num_layers = ( @@ -666,11 +667,11 @@ def compute_attn_mask_seqlen( ) -> tuple[int | None, list[int] | None]: max_seqlen, seqlens = None, None if ( - self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA + self.attn_backend == AttentionBackendEnum.FLASH_ATTN + or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return max_seqlen, seqlens diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 86536b21c33f..d1644827ed64 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -36,7 +36,7 @@ from einops import rearrange, repeat from transformers import BatchFeature, PretrainedConfig -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import ( check_upstream_fa_availability, maybe_get_vit_flash_attn_backend, @@ -164,7 +164,7 @@ def __init__( projection_size: int, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() # Per attention head and per partition values. @@ -211,17 +211,17 @@ def __init__( ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"Ernie45-VL does not support {self.attn_backend} backend now." ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -291,7 +291,7 @@ def forward( context_layer = rearrange( output, "(b s) h d -> s b (h d)", b=batch_size ).contiguous() - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. 
outputs = [] for i in range(1, len(cu_seqlens)): @@ -310,7 +310,7 @@ def forward( context_layer = rearrange( context_layer, "b s h d -> s b (h d)" ).contiguous() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -370,7 +370,7 @@ def __init__( norm_layer: Callable[[int], nn.Module] | None = None, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -463,7 +463,7 @@ def __init__( norm_eps: float = 1e-6, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() patch_size = vision_config.patch_size @@ -515,10 +515,11 @@ def __init__( dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -565,11 +566,11 @@ def compute_attn_mask_seqlen( ) -> tuple[int | None, list[int] | None]: max_seqlen, seqlens = None, None if ( - self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA + self.attn_backend == AttentionBackendEnum.FLASH_ATTN + or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return max_seqlen, seqlens diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 3e243385fd04..de8000951fee 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -46,7 +46,7 @@ from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor from transformers.video_utils import VideoMetadata -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import ( check_upstream_fa_availability, maybe_get_vit_flash_attn_backend, @@ -252,7 +252,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() # Per attention head and per partition values. @@ -306,18 +306,18 @@ def __init__( ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"GLM-4V does not support {self.attn_backend} backend now." 
) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -377,7 +377,7 @@ def forward( context_layer = rearrange( output, "(b s) h d -> s b (h d)", b=batch_size ).contiguous() - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. outputs = [] for i in range(1, len(cu_seqlens)): @@ -396,7 +396,7 @@ def forward( context_layer = rearrange( context_layer, "b s h d -> s b (h d)" ).contiguous() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -425,7 +425,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -703,7 +703,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -772,10 +772,11 @@ def __init__( dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -824,8 +825,8 @@ def compute_attn_mask_seqlen( max_seqlen, seqlens = None, None seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() if ( - self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA + self.attn_backend == AttentionBackendEnum.FLASH_ATTN + or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() return max_seqlen, seqlens diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 5f8659a3064e..9a1bfc2d33b0 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -16,7 +16,7 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.utils import torch_int -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import ( maybe_get_vit_flash_attn_backend, ) @@ -360,7 +360,7 @@ def __init__( config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -414,17 +414,17 @@ def __init__( ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"Keye-VL does not support {self.attn_backend} backend now." 
) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def forward( @@ -489,7 +489,7 @@ def forward( softmax_scale=self.scale, ) context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size) - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -536,7 +536,7 @@ def __init__( config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.embed_dim = config.hidden_size @@ -590,7 +590,7 @@ def __init__( config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -685,7 +685,7 @@ def __init__( config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -768,7 +768,7 @@ def __init__( config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py index f6461ae9a412..9a4d69dea096 100644 --- a/vllm/model_executor/models/ovis2_5.py +++ b/vllm/model_executor/models/ovis2_5.py @@ -10,7 +10,7 @@ import torch.nn as nn from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.linear import ReplicatedLinear @@ -106,7 +106,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -135,7 +135,7 @@ def _init_backbone( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): model_type = config.model_type if model_type == "siglip2_navit": diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 3585783e4ccc..49bae869b834 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -42,7 +42,7 @@ Qwen2_5_VLVisionConfig, ) -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import maybe_get_vit_flash_attn_backend from vllm.attention.ops.vit_attn_wrappers import ( vit_flash_attn_wrapper, @@ -313,9 +313,9 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend: _Backend = 
_Backend.TORCH_SDPA, + attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, use_upstream_fa: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() # Per attention head and per partition values. @@ -362,13 +362,16 @@ def __init__( # On ROCm with FLASH_ATTN backend, upstream flash_attn is used from vllm.platforms import current_platform - if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN: + if ( + current_platform.is_rocm() + and self.attn_backend == AttentionBackendEnum.FLASH_ATTN + ): self.use_upstream_fa = True if current_platform.is_xpu(): self.use_upstream_fa = False self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -429,10 +432,10 @@ def forward( cu_seqlens, max_seqlen, batch_size, - self.attn_backend == _Backend.ROCM_AITER_FA, + self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA, self.use_upstream_fa, ) - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. from vllm.platforms import current_platform @@ -459,7 +462,7 @@ def forward( context_layer = einops.rearrange( context_layer, "b s h d -> s b (h d)" ).contiguous() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens) output, _ = self.proj(context_layer) @@ -488,9 +491,9 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend: _Backend = _Backend.TORCH_SDPA, + attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, use_upstream_fa: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -664,7 +667,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -716,10 +719,10 @@ def __init__( ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"Qwen2.5-VL does not support {self.attn_backend} backend now." 
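Background for the compute_attn_mask_seqlen hunks in this and the surrounding vision models: cu_seqlens is a prefix sum of per-sequence patch counts, so consecutive differences recover the individual lengths; the flash-attention style backends consume the maximum length, while the XFORMERS path consumes the full list to build a block-diagonal mask. A small standalone illustration in plain PyTorch (not vLLM code):

import torch

# Three visual sequences of 3, 5 and 4 patches, expressed as a prefix sum.
cu_seqlens = torch.tensor([0, 3, 8, 12])

seqlens = cu_seqlens[1:] - cu_seqlens[:-1]  # tensor([3, 5, 4])
max_seqlen = seqlens.max().item()           # 5, used by FLASH_ATTN / ROCM_AITER_FA
seqlens_list = seqlens.tolist()             # [3, 5, 4], used by XFORMERS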
@@ -858,9 +861,12 @@ def compute_attn_mask_seqlen( ) -> tuple[torch.Tensor, torch.Tensor]: max_seqlen = torch.zeros([], device=cu_seqlens.device) seqlens = torch.zeros(1, device=cu_seqlens.device) - if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: + if self.attn_backend in { + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, + }: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = cu_seqlens[1:] - cu_seqlens[:-1] return max_seqlen, seqlens diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 1ec12bdb55df..c198b8a1cf68 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -43,7 +43,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import ( check_upstream_fa_availability, maybe_get_vit_flash_attn_backend, @@ -329,7 +329,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() # Per attention head and per partition values. @@ -378,18 +378,18 @@ def __init__( ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"Qwen2-VL does not support {self.attn_backend} backend now." ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -460,7 +460,7 @@ def forward( context_layer = rearrange( output, "(b s) h d -> s b (h d)", b=batch_size ).contiguous() - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. 
from vllm.platforms import current_platform @@ -485,7 +485,7 @@ def forward( context_layer = rearrange( context_layer, "b s h d -> s b (h d)" ).contiguous() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -515,7 +515,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -679,7 +679,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -739,10 +739,11 @@ def __init__( dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -789,9 +790,12 @@ def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor ) -> tuple[int | None, list[int] | None]: max_seqlen, seqlens = None, None - if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: + if self.attn_backend in { + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, + }: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return max_seqlen, seqlens diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index efcd003fbbda..5729f2485023 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -47,7 +47,7 @@ ) from transformers.models.whisper import WhisperFeatureExtractor -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import check_upstream_fa_availability from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig @@ -302,7 +302,7 @@ def __init__( norm_eps: float = 1e-6, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() self.hidden_size = vision_config.hidden_size @@ -378,10 +378,11 @@ def __init__( dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -491,9 +492,9 @@ def compute_attn_mask_seqlen( ) -> tuple[torch.Tensor, torch.Tensor]: max_seqlen = 
torch.zeros([], device=cu_seqlens.device) seqlens = torch.zeros(1, device=cu_seqlens.device) - if self.attn_backend == _Backend.FLASH_ATTN: + if self.attn_backend == AttentionBackendEnum.FLASH_ATTN: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = cu_seqlens[1:] - cu_seqlens[:-1] return max_seqlen, seqlens diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index d611580c7182..da6348256e65 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -49,7 +49,7 @@ ) from transformers.video_utils import VideoMetadata -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import check_upstream_fa_availability from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig @@ -198,7 +198,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend: _Backend = _Backend.TORCH_SDPA, + attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, use_upstream_fa: bool = False, ) -> None: super().__init__() @@ -306,7 +306,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() self.hidden_size = vision_config.hidden_size @@ -372,18 +372,18 @@ def __init__( ) use_upstream_fa = False if ( - self.attn_backend != _Backend.FLASH_ATTN - and self.attn_backend != _Backend.ROCM_AITER_FA + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and self.attn_backend != AttentionBackendEnum.ROCM_AITER_FA and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN use_upstream_fa = True if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"Qwen3-VL does not support {self.attn_backend} backend now." 
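The vision encoders touched in these files share one promotion pattern: if the platform-selected ViT backend is not FLASH_ATTN (Qwen3-VL also excludes ROCM_AITER_FA) but upstream flash-attn can handle the default dtype, the backend is upgraded to AttentionBackendEnum.FLASH_ATTN. A hedged sketch of that pattern follows; the enum and the availability check are stand-ins for illustration, not the real vLLM modules:

import enum

import torch


class AttentionBackendEnum(enum.Enum):
    # Hypothetical stand-in for vllm.attention.backends.registry.
    FLASH_ATTN = enum.auto()
    TORCH_SDPA = enum.auto()
    XFORMERS = enum.auto()
    ROCM_AITER_FA = enum.auto()


def check_upstream_fa_availability(dtype: torch.dtype) -> bool:
    # Stand-in for the helper imported from vllm.attention.layer; here we just
    # assume upstream flash-attn is installed and supports half precision.
    return dtype in (torch.float16, torch.bfloat16)


def promote_vit_backend(backend, dtype):
    # Upgrade to FLASH_ATTN when the platform picked something else but an
    # upstream flash-attn install is usable for this dtype.
    flash_backends = (
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.ROCM_AITER_FA,
    )
    if backend not in flash_backends and check_upstream_fa_availability(dtype):
        return AttentionBackendEnum.FLASH_ATTN
    return backend


assert (
    promote_vit_backend(AttentionBackendEnum.XFORMERS, torch.bfloat16)
    is AttentionBackendEnum.FLASH_ATTN
)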
@@ -516,11 +516,11 @@ def compute_attn_mask_seqlen( max_seqlen = torch.zeros([], device=cu_seqlens.device) seqlens = torch.zeros(1, device=cu_seqlens.device) if ( - self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA + self.attn_backend == AttentionBackendEnum.FLASH_ATTN + or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = cu_seqlens[1:] - cu_seqlens[:-1] return max_seqlen, seqlens diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index bab5c1d82ded..c20bcd975ca3 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -12,7 +12,7 @@ from transformers import Siglip2VisionConfig from transformers.configuration_utils import PretrainedConfig -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import maybe_get_vit_flash_attn_backend from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn @@ -208,7 +208,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -264,14 +264,14 @@ def __init__( ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.ROCM_AITER_FA, }: - self.attn_backend = _Backend.TORCH_SDPA + self.attn_backend = AttentionBackendEnum.TORCH_SDPA self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def forward( @@ -308,7 +308,7 @@ def forward( attn_output = self.flash_attn_varlen_func( queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen ).reshape(seq_length, -1) - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. 
batch_size = cu_seqlens.shape[0] - 1 outputs = [] @@ -376,7 +376,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.embed_dim = config.hidden_size @@ -440,7 +440,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -626,7 +626,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -667,7 +667,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index b5f6c60514c0..0b8843dbbaee 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -10,7 +10,7 @@ import torch from transformers import PretrainedConfig -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -82,8 +82,8 @@ def get_vit_attn_backend( head_size: int, dtype: torch.dtype, *, - attn_backend_override: _Backend | None = None, -) -> _Backend: + attn_backend_override: AttentionBackendEnum | None = None, +) -> AttentionBackendEnum: """ Get the available attention backend for Vision Transformer. 
""" @@ -93,7 +93,7 @@ def get_vit_attn_backend( # Lazy import to avoid circular dependency from vllm.attention.selector import get_env_variable_attn_backend - selected_backend: _Backend | None = get_env_variable_attn_backend() + selected_backend: AttentionBackendEnum | None = get_env_variable_attn_backend() if selected_backend is not None: return selected_backend diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 4b9f4aef022d..41b205c83478 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -22,10 +22,10 @@ logger = init_logger(__name__) if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig else: - _Backend = None + AttentionBackendEnum = None VllmConfig = None @@ -126,7 +126,7 @@ def get_device_name(cls, device_id: int = 0) -> str: @classmethod def get_attn_backend_cls( cls, - selected_backend: "_Backend", + selected_backend: "AttentionBackendEnum", head_size: int, dtype: torch.dtype, kv_cache_dtype: str | None, @@ -136,9 +136,9 @@ def get_attn_backend_cls( has_sink: bool, use_sparse: bool, ) -> str: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum - if selected_backend and selected_backend != _Backend.TORCH_SDPA: + if selected_backend and selected_backend != AttentionBackendEnum.TORCH_SDPA: logger.info("Cannot use %s backend on CPU.", selected_backend) if use_mla: raise NotImplementedError("MLA is not supported on CPU.") @@ -147,7 +147,7 @@ def get_attn_backend_cls( logger.info("Using Torch SDPA backend.") if not use_v1: raise ValueError("CPU backend only supports V1.") - return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend" + return AttentionBackendEnum.TORCH_SDPA.get_path() @classmethod def get_device_total_memory(cls, device_id: int = 0) -> int: diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 32734c3aba5e..2837d5e765f0 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -22,10 +22,13 @@ from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig + from vllm.config.cache import CacheDType else: - _Backend = None + AttentionBackendEnum = None + VllmConfig = None + CacheDType = None logger = init_logger(__name__) @@ -39,6 +42,49 @@ torch.backends.cuda.enable_cudnn_sdp(False) +@cache +def _get_backend_priorities( + use_mla: bool, + device_capability: DeviceCapability, +) -> dict[AttentionBackendEnum, int]: + """Get backend priorities with lazy import to avoid circular dependency.""" + from vllm.attention.backends.registry import AttentionBackendEnum + + if use_mla: + if device_capability.major == 10: + return { + AttentionBackendEnum.CUTLASS_MLA: 0, + AttentionBackendEnum.FLASHINFER_MLA: 1, + AttentionBackendEnum.FLASHMLA: 2, + AttentionBackendEnum.FLASH_ATTN_MLA: 3, + AttentionBackendEnum.TRITON_MLA: 4, + AttentionBackendEnum.FLASHMLA_SPARSE: 5, + } + else: + return { + AttentionBackendEnum.FLASHMLA: 0, + AttentionBackendEnum.FLASH_ATTN_MLA: 1, + AttentionBackendEnum.FLASHINFER_MLA: 2, + AttentionBackendEnum.TRITON_MLA: 3, + AttentionBackendEnum.FLASHMLA_SPARSE: 4, + } + else: + if device_capability.major == 10: + return { + AttentionBackendEnum.FLASHINFER: 0, + AttentionBackendEnum.FLASH_ATTN: 1, + AttentionBackendEnum.TRITON_ATTN: 2, + 
AttentionBackendEnum.FLEX_ATTENTION: 3, + } + else: + return { + AttentionBackendEnum.FLASH_ATTN: 0, + AttentionBackendEnum.FLASHINFER: 1, + AttentionBackendEnum.TRITON_ATTN: 2, + AttentionBackendEnum.FLEX_ATTENTION: 3, + } + + def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: @wraps(fn) def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: @@ -216,217 +262,171 @@ def get_current_memory_usage( return torch.cuda.max_memory_allocated(device) @classmethod - def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": - from vllm.attention.backends.registry import _Backend + def get_vit_attn_backend( + cls, head_size: int, dtype: torch.dtype + ) -> "AttentionBackendEnum": + from vllm.attention.backends.registry import AttentionBackendEnum # For Blackwell GPUs, force TORCH_SDPA for now. # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501 if cls.has_device_capability(100): - return _Backend.TORCH_SDPA + return AttentionBackendEnum.TORCH_SDPA if dtype not in (torch.float16, torch.bfloat16): - return _Backend.XFORMERS + return AttentionBackendEnum.XFORMERS if cls.has_device_capability(80): - FLASH_ATTN_V1 = ( - "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 - ) - from vllm.attention.selector import is_attn_backend_supported - - is_default_fa_supported = is_attn_backend_supported( - FLASH_ATTN_V1, head_size, dtype, allow_import_error=False - ) - if is_default_fa_supported: - return _Backend.FLASH_ATTN + backend_class = AttentionBackendEnum.FLASH_ATTN.get_class() + if backend_class.supports_head_size( + head_size + ) and backend_class.supports_dtype(dtype): + return AttentionBackendEnum.FLASH_ATTN else: - # Fallback to XFORMERS - return _Backend.XFORMERS + return AttentionBackendEnum.XFORMERS else: # Fallback for Volta/Turing GPUs or FA not supported - return _Backend.XFORMERS + return AttentionBackendEnum.XFORMERS @classmethod - def get_attn_backend_cls( + def get_valid_backends( cls, - selected_backend, head_size, dtype, kv_cache_dtype, block_size, - use_v1, use_mla, has_sink, use_sparse, - ) -> str: - from vllm.attention.backends.registry import _Backend - - if use_mla: - # explicitly reject non-MLA backends when MLA is enabled to avoid - # silently selecting an incompatible backend (e.g., FLASHINFER). - if selected_backend in { - _Backend.FLASHINFER, - _Backend.FLASH_ATTN, - _Backend.TRITON_ATTN, - _Backend.TREE_ATTN, - _Backend.XFORMERS, - }: - raise ValueError( - f"Attention backend {selected_backend} incompatible with MLA. " - "Please use one of the MLA backends: FLASHINFER_MLA, CUTLASS_MLA, " - "FLASHMLA, FLASH_ATTN_MLA, or TRITON_MLA. Alternatively, set " - "VLLM_MLA_DISABLE=1 to disable MLA for this model." 
+ device_capability, + ) -> tuple[ + list[tuple["AttentionBackendEnum", int]], + dict["AttentionBackendEnum", list[str]], + ]: + valid_backends_priorities = [] + invalid_reasons = {} + + backend_priorities = _get_backend_priorities(use_mla, device_capability) + for backend, priority in backend_priorities.items(): + try: + backend_class = backend.get_class() + invalid_reasons_i = backend_class.validate_configuration( + head_size, + dtype, + kv_cache_dtype, + block_size, + use_mla, + has_sink, + use_sparse, + device_capability, ) + except ImportError: + invalid_reasons_i = ["ImportError"] + if invalid_reasons_i: + invalid_reasons[backend] = invalid_reasons_i + else: + valid_backends_priorities.append((backend, priority)) - from vllm.attention.ops.flashmla import is_flashmla_dense_supported - from vllm.attention.utils.fa_utils import flash_attn_supports_mla + return valid_backends_priorities, invalid_reasons - if use_sparse: - logger.info_once("Using Sparse MLA backend.") - return ( - "vllm.v1.attention.backends.mla.flashmla_sparse." - "FlashMLASparseBackend" - ) - - use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( - selected_backend is None - and cls.is_device_capability(100) - and block_size % 128 == 0 - ) - use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or ( - selected_backend is None - and cls.is_device_capability(100) - and (block_size == 32 or block_size % 64 == 0) - ) - use_flashmla = selected_backend == _Backend.FLASHMLA or ( - selected_backend is None and is_flashmla_dense_supported()[0] - ) - use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or ( - selected_backend is None and flash_attn_supports_mla() - ) - use_triton = selected_backend == _Backend.TRITON_MLA or ( - selected_backend is None + @classmethod + def get_attn_backend_cls( + cls, + selected_backend: "AttentionBackendEnum", + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: "CacheDType | None", + block_size: int | None, + use_v1: bool, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + ) -> str: + if not use_v1: + raise RuntimeError( + "V0 attention backends have been removed. Set VLLM_USE_V1=1 " + "to select a supported backend." ) - if use_cutlassmla: - logger.info_once("Using Cutlass MLA backend.", scope="local") - return "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend" - if use_flashinfermla: - from vllm.v1.attention.backends.utils import set_kv_cache_layout - - set_kv_cache_layout("HND") - logger.info_once("Using FlashInfer MLA backend.") - return ( - "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend" + device_capability = cls.get_device_capability() + assert device_capability is not None + + # First try checking just the selected backend, if there is one. 
+ if selected_backend is not None: + try: + backend_class = selected_backend.get_class() + invalid_reasons = backend_class.validate_configuration( + head_size, + dtype, + kv_cache_dtype, + None, + use_mla, + has_sink, + use_sparse, + device_capability, ) - if use_flashmla: - if block_size % 64 != 0: - logger.warning( - "FlashMLA backend is not supported for block size %d" - " (currently only supports block size 64).", - block_size, - ) - else: - logger.info_once("Using FlashMLA backend.") - return "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend" - if use_flashattn: - logger.info_once("Using FlashAttention MLA backend.") - return ( - "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend" + except ImportError: + invalid_reasons = ["ImportError"] + if invalid_reasons: + raise ValueError( + f"Selected backend {selected_backend} is not valid for " + f"this configuration. Reason: {invalid_reasons}" ) - if use_triton: - logger.info_once("Using Triton MLA backend.") - return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend" - - FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501 - FLEX_ATTENTION_V1 = ( - "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 + else: + logger.info("Using %s backend.", selected_backend) + return selected_backend.get_path() + + # No selected backend or the selected backend is invalid, + # so we try finding a valid backend. + valid_backends_priorities, invalid_reasons = cls.get_valid_backends( + head_size, + dtype, + kv_cache_dtype, + None, + use_mla, + has_sink, + use_sparse, + device_capability, ) - TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 - FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 - TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501 - XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501 - - use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith( - "fp8" + reasons_str = ( + "{" + + ", ".join( + f"{backend.name}: [{', '.join(reasons)}]" + for backend, reasons in invalid_reasons.items() + ) + + "}" ) + config_str = ( + f"head_size: {head_size}, dtype: {dtype}, " + f"kv_cache_dtype: {kv_cache_dtype}, block_size: {block_size}, " + f"use_mla: {use_mla}, has_sink: {has_sink}, use_sparse: {use_sparse}" + ) + logger.debug_once( + f"Some attention backends are not valid for {cls.device_name} with " + f"{config_str}. Reasons: {reasons_str}." + ) + if len(valid_backends_priorities) == 0: + raise ValueError( + f"No valid attention backend found for {cls.device_name} " + f"with {config_str}. Reasons: {reasons_str}." 
+ ) - if selected_backend == _Backend.FLASHINFER: - logger.info_once("Using FlashInfer backend.") - if cls.has_device_capability(100): - from vllm.v1.attention.backends.utils import set_kv_cache_layout - - set_kv_cache_layout("HND") - return FLASHINFER_V1 - elif selected_backend == _Backend.FLEX_ATTENTION: - logger.info_once("Using FlexAttention backend.") - return FLEX_ATTENTION_V1 - elif selected_backend == _Backend.TRITON_ATTN: - logger.info_once("Using Triton backend.") - return TRITON_ATTN - elif selected_backend == _Backend.FLASH_ATTN: - logger.info_once("Using Flash Attention backend.") - return FLASH_ATTN_V1 - elif selected_backend == _Backend.TREE_ATTN: - logger.info_once("Using Tree Attention backend.") - return TREE_ATTN_V1 - elif selected_backend == _Backend.XFORMERS: - logger.info_once("Using XFormers backend.") - return XFORMERS_V1 - - from vllm.attention.selector import is_attn_backend_supported - - # Default backends for V1 engine - # Prefer FlashInfer for Blackwell GPUs if installed - if cls.is_device_capability(100): - if is_default_backend_supported := is_attn_backend_supported( - FLASHINFER_V1, head_size, dtype - ): - from vllm.v1.attention.backends.utils import set_kv_cache_layout - - logger.info_once( - "Using FlashInfer backend with HND KV cache layout on " - "V1 engine by default for Blackwell (SM 10.0) GPUs." - ) - set_kv_cache_layout("HND") - - return FLASHINFER_V1 - - if not is_default_backend_supported.can_import: - logger.warning_once( - "FlashInfer failed to import on Blackwell (SM 10.0) GPUs; " - "it is recommended to install FlashInfer for better " - "performance." - ) - - # FlashAttention is the default for SM 8.0+ GPUs - if cls.has_device_capability(80): - if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90): - logger.info_once("Using Triton backend.") - return TRITON_ATTN - elif is_default_backend_supported := is_attn_backend_supported( - FLASH_ATTN_V1, head_size, dtype, allow_import_error=False - ): - logger.info_once("Using Flash Attention backend.") - return FLASH_ATTN_V1 - - # FlexAttention is the default for older GPUs - else: - logger.info_once("Using FlexAttention backend.") - return FLEX_ATTENTION_V1 - - assert not is_default_backend_supported - - use_flex_attention_reason = {} - if not is_default_backend_supported.head_size: - use_flex_attention_reason["head_size"] = head_size - if not is_default_backend_supported.dtype: - use_flex_attention_reason["dtype"] = dtype - - logger.info_once( - "Using FlexAttention backend for %s.", - ", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()), + # We have found some valid backends. Select the one with the + # highest priority. 
+ logger.info( + "Valid backends: %s", [b[0].name for b in valid_backends_priorities] ) - return FLEX_ATTENTION_V1 + sorted_indices = sorted( + range(len(valid_backends_priorities)), + key=lambda i: valid_backends_priorities[i][1], + ) + selected_index = sorted_indices[0] + selected_backend = valid_backends_priorities[selected_index][0] + logger.info( + "Using %s backend.", + selected_backend.name, + ) + + return selected_backend.get_path() @classmethod def get_punica_wrapper(cls) -> str: diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 15e3b3a22bde..4969bcf116a4 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -17,8 +17,9 @@ if TYPE_CHECKING: from torch.distributed import PrefixStore, ProcessGroup - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig + from vllm.config.cache import CacheDType from vllm.inputs import ProcessorInputs, PromptType from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -58,6 +59,31 @@ class DeviceCapability(NamedTuple): major: int minor: int + def __lt__(self, other: Any) -> bool: + if not isinstance(other, DeviceCapability): + return NotImplemented + return (self.major, self.minor) < (other.major, other.minor) + + def __le__(self, other: Any) -> bool: + if not isinstance(other, DeviceCapability): + return NotImplemented + return (self.major, self.minor) <= (other.major, other.minor) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, DeviceCapability): + return NotImplemented + return (self.major, self.minor) == (other.major, other.minor) + + def __ge__(self, other: Any) -> bool: + if not isinstance(other, DeviceCapability): + return NotImplemented + return (self.major, self.minor) >= (other.major, other.minor) + + def __gt__(self, other: Any) -> bool: + if not isinstance(other, DeviceCapability): + return NotImplemented + return (self.major, self.minor) > (other.major, other.minor) + def as_version_str(self) -> str: return f"{self.major}.{self.minor}" @@ -173,19 +199,21 @@ def import_kernels(cls) -> None: import vllm._moe_C # noqa: F401 @classmethod - def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": - # Import _Backend here to avoid circular import. - from vllm.attention.backends.registry import _Backend + def get_vit_attn_backend( + cls, head_size: int, dtype: torch.dtype + ) -> "AttentionBackendEnum": + # Import AttentionBackendEnum here to avoid circular import. 
+ from vllm.attention.backends.registry import AttentionBackendEnum - return _Backend.TORCH_SDPA + return AttentionBackendEnum.TORCH_SDPA @classmethod def get_attn_backend_cls( cls, - selected_backend: "_Backend", + selected_backend: "AttentionBackendEnum", head_size: int, dtype: torch.dtype, - kv_cache_dtype: str | None, + kv_cache_dtype: "CacheDType | None", block_size: int, use_v1: bool, use_mla: bool, diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 0c03a5564db8..626e4bd912ba 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -14,10 +14,10 @@ from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig else: - _Backend = None + AttentionBackendEnum = None logger = init_logger(__name__) @@ -201,18 +201,20 @@ class RocmPlatform(Platform): ] @classmethod - def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": + def get_vit_attn_backend( + cls, head_size: int, dtype: torch.dtype + ) -> "AttentionBackendEnum": from importlib.util import find_spec - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): - return _Backend.ROCM_AITER_FA + return AttentionBackendEnum.ROCM_AITER_FA if on_gfx9() and find_spec("flash_attn") is not None: - return _Backend.FLASH_ATTN + return AttentionBackendEnum.FLASH_ATTN - return _Backend.TORCH_SDPA + return AttentionBackendEnum.TORCH_SDPA @classmethod def get_attn_backend_cls( @@ -227,7 +229,7 @@ def get_attn_backend_cls( has_sink, use_sparse, ) -> str: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum if use_sparse: raise NotImplementedError("Sparse Attention is not supported on ROCm.") @@ -238,25 +240,23 @@ def get_attn_backend_cls( if selected_backend is None: selected_backend = ( - _Backend.ROCM_AITER_MLA + AttentionBackendEnum.ROCM_AITER_MLA if is_aiter_mla_enabled() or block_size == 1 - else _Backend.TRITON_MLA + else AttentionBackendEnum.TRITON_MLA ) - if selected_backend == _Backend.TRITON_MLA: + if selected_backend == AttentionBackendEnum.TRITON_MLA: if block_size != 1: logger.info_once("Using Triton MLA backend.") - return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend" + return AttentionBackendEnum.TRITON_MLA.get_path() raise ValueError( f" The selected backend, {selected_backend.name}," f"does not support block size {block_size}." ) - if selected_backend == _Backend.ROCM_AITER_MLA: + if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA: if block_size == 1: logger.info("Using AITER MLA backend.") - return ( - "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 - ) + return AttentionBackendEnum.ROCM_AITER_MLA.get_path() raise ValueError( f" The selected backend, {selected_backend.name}," f"does not support block size {block_size}." @@ -267,33 +267,35 @@ def get_attn_backend_cls( f"is not MLA type while requested for MLA backend." 
) - if selected_backend == _Backend.FLEX_ATTENTION: - logger.info("Using FlexAttention backend.") - return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" - if ( - envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9() - ) or selected_backend == _Backend.ROCM_AITER_FA: - logger.info("Using Aiter Flash Attention backend.") - return "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend" - if ( - envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION - ) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN: - logger.info("Using Aiter Unified Attention backend.") - return ( - "vllm.v1.attention.backends." - "rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend" - ) - if ( - envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION - or selected_backend == _Backend.ROCM_ATTN - ): - # rocm specific backend, with aiter and/or - # triton prefix-prefill - logger.info("Using Rocm Attention backend.") - return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" - # default case, using triton unified attention - logger.info("Using Triton Attention backend.") - return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" + if envs.VLLM_USE_V1: + if selected_backend == AttentionBackendEnum.FLEX_ATTENTION: + logger.info("Using FlexAttention backend.") + return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" + if ( + envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9() + ) or selected_backend == AttentionBackendEnum.ROCM_AITER_FA: + logger.info("Using Aiter Flash Attention backend.") + return AttentionBackendEnum.ROCM_AITER_FA.get_path() + if ( + envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION + ) or selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN: + logger.info("Using Aiter Unified Attention backend.") + return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path() + if ( + envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION + or selected_backend == AttentionBackendEnum.ROCM_ATTN + ): + # rocm specific backend, with aiter and/or + # triton prefix-prefill + logger.info("Using Rocm Attention backend.") + return AttentionBackendEnum.ROCM_ATTN.get_path() + # default case, using triton unified attention + logger.info("Using Triton Attention backend.") + return AttentionBackendEnum.TRITON_ATTN.get_path() + raise RuntimeError( + "V0 attention backends have been removed. Set VLLM_USE_V1=1 " + "to select a supported backend." 
+ ) @classmethod def set_device(cls, device: torch.device) -> None: diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 1a4b67a1762f..575a9892c211 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -15,16 +15,15 @@ from .interface import Platform, PlatformEnum if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend - from vllm.config import ModelConfig, VllmConfig + from vllm.attention.backends.registry import AttentionBackendEnum + from vllm.config import VllmConfig from vllm.config.cache import BlockSize from vllm.pooling_params import PoolingParams else: BlockSize = None - ModelConfig = None VllmConfig = None PoolingParams = None - _Backend = None + AttentionBackendEnum = None logger = init_logger(__name__) @@ -54,7 +53,7 @@ def import_kernels(cls) -> None: @classmethod def get_attn_backend_cls( cls, - selected_backend: "_Backend", + selected_backend: "AttentionBackendEnum", head_size: int, dtype: torch.dtype, kv_cache_dtype: str | None, @@ -64,17 +63,17 @@ def get_attn_backend_cls( has_sink, use_sparse, ) -> str: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum if use_sparse: raise NotImplementedError("Sparse Attention is not supported on TPU.") - if selected_backend != _Backend.PALLAS: + if selected_backend != AttentionBackendEnum.PALLAS: logger.info("Cannot use %s backend on TPU.", selected_backend) if not use_v1: raise ValueError("TPU backend only supports V1.") logger.info("Using Pallas V1 backend.") - return "vllm.v1.attention.backends.pallas.PallasAttentionBackend" + return AttentionBackendEnum.PALLAS.get_path() @classmethod def set_device(cls, device: torch.device) -> None: diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index e4ecd0c807da..c0ed7bc254b1 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -14,12 +14,11 @@ from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend - from vllm.config import ModelConfig, VllmConfig + from vllm.attention.backends.registry import AttentionBackendEnum + from vllm.config import VllmConfig else: - ModelConfig = None VllmConfig = None - _Backend = None + AttentionBackendEnum = None logger = init_logger(__name__) @@ -44,7 +43,7 @@ def import_kernels(cls) -> None: @classmethod def get_attn_backend_cls( cls, - selected_backend: "_Backend", + selected_backend: "AttentionBackendEnum", head_size: int, dtype: torch.dtype, kv_cache_dtype: str | None, @@ -62,18 +61,19 @@ def get_attn_backend_cls( "only NHD layout is supported by XPU attention kernels." 
) - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum if use_sparse: raise NotImplementedError("Sparse Attention is not supported on XPU.") - TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 - FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 - if selected_backend == _Backend.TRITON_ATTN: + use_v1 = envs.VLLM_USE_V1 + if not use_v1: + raise ValueError("XPU backend only supports V1.") + if selected_backend == AttentionBackendEnum.TRITON_ATTN: logger.info_once("Using Triton backend.") - return TRITON_ATTN - elif selected_backend == _Backend.FLASH_ATTN: + return AttentionBackendEnum.TRITON_ATTN.get_path() + elif selected_backend == AttentionBackendEnum.FLASH_ATTN: logger.info_once("Using Flash Attention backend.") - return FLASH_ATTN + return AttentionBackendEnum.FLASH_ATTN.get_path() elif selected_backend: raise ValueError( f"Invalid attention backend for {cls.device_name}, " @@ -81,7 +81,7 @@ def get_attn_backend_cls( ) logger.info("Using Flash Attention backend.") - return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + return AttentionBackendEnum.FLASH_ATTN.get_path() @classmethod def set_device(cls, device: torch.device) -> None: diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 0d3e1729ff20..a3960e686a44 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Optional +from typing import ClassVar, Optional import numpy as np import torch @@ -40,23 +40,16 @@ class TorchSDPABackend(AttentionBackend): accept_output_buffer: bool = False + supported_dtypes: ClassVar[list[torch.dtype]] = [ + torch.float16, + torch.bfloat16, + torch.float32, + ] @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16, torch.float32] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: + def get_supported_head_sizes(cls) -> list[int]: attn_impl = _get_paged_attn_impl() - is_valid, supported_head_sizes = attn_impl.validate_head_size(head_size) - if not is_valid: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." 
- ) + return attn_impl.get_supported_head_sizes() @staticmethod def get_name() -> str: @@ -763,9 +756,8 @@ def _make_sliding_window_bias( class _PagedAttention: @staticmethod - def validate_head_size(head_size: int) -> tuple[bool, list[int]]: - SUPPORT_HS = [32, 64, 80, 96, 112, 128, 192, 256] - return head_size in SUPPORT_HS, SUPPORT_HS + def get_supported_head_sizes() -> list[int]: + return [32, 64, 80, 96, 112, 128, 192, 256] @staticmethod def get_kv_cache_shape( @@ -865,8 +857,8 @@ def forward_decode( class _IPEXPagedAttention(_PagedAttention): @staticmethod - def validate_head_size(head_size: int) -> tuple[bool, list[int]]: - return True, [] + def get_supported_head_sizes() -> list[int]: + return [] @staticmethod def split_kv_cache( diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 1eac94940e78..657a538ed587 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -3,6 +3,7 @@ """Attention layer with FlashAttention.""" from dataclasses import dataclass +from typing import ClassVar import numpy as np import torch @@ -32,11 +33,13 @@ reshape_and_cache_flash, ) from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.config.cache import CacheDType from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) +from vllm.platforms.interface import DeviceCapability from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.utils import ( AttentionCGSupport, @@ -51,30 +54,8 @@ class FlashAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - - @classmethod - def get_supported_head_sizes(cls) -> list[int]: - return [32, 64, 96, 128, 160, 192, 224, 256] - - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - return [MultipleOf(16)] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." 
- ) + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)] @staticmethod def get_name() -> str: @@ -124,6 +105,38 @@ def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype: else: raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @classmethod + def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: + if kv_cache_dtype is None: + return True + if kv_cache_dtype.startswith("fp8"): + return flash_attn_supports_fp8() + return kv_cache_dtype in ["auto"] + + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability >= DeviceCapability(8, 0) + + @classmethod + def supports_combination( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: CacheDType | None, + block_size: int, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + device_capability: DeviceCapability, + ) -> str | None: + if has_sink and device_capability < DeviceCapability(9, 0): + return "sink not supported on compute capability < 9.0" + return None + @dataclass class FlashAttentionMetadata: @@ -479,8 +492,6 @@ def __init__( self.num_queries_per_kv = self.num_heads // self.num_kv_heads - FlashAttentionBackend.validate_head_size(head_size) - self.attn_type = attn_type self.vllm_flash_attn_version = get_flash_attn_version() # Cache the batch invariant result for use in forward passes diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index e71d4ca4629d..28cefcad7d87 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -23,6 +23,7 @@ MultipleOf, ) from vllm.config import CUDAGraphMode, VllmConfig +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, @@ -33,6 +34,7 @@ kNvfp4Quant, ) from vllm.platforms import current_platform +from vllm.platforms.interface import DeviceCapability from vllm.triton_utils import tl, triton from vllm.utils.flashinfer import ( can_use_trtllm_attention, @@ -45,6 +47,7 @@ AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + KVCacheLayoutType, get_kv_cache_layout, get_per_layer_parameters, infer_global_hyperparameters, @@ -158,34 +161,17 @@ def trtllm_prefill_attn_kvfp8_dequant( class FlashInferBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - - @classmethod - def get_supported_head_sizes(cls) -> list[int]: - # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 - return [64, 128, 256] - - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - # Note: Not sure for all platforms, - # but on Blackwell, only support a page size of - # 16, 32, 64 - return [16, 32, 64] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. 
" - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." - ) + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + # Note: Not sure for all platforms, + # but on Blackwell, only support a page size of + # 16, 32, 64 + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "fp8", + "fp8_e4m3", + "fp8_e5m2", + ] @staticmethod def get_name() -> str: @@ -235,6 +221,26 @@ def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: else: raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 + return [64, 128, 256] + + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability >= DeviceCapability(7, 5) and capability <= DeviceCapability( + 12, 1 + ) + + @classmethod + def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None: + from vllm.platforms import current_platform + + capability = current_platform.get_device_capability() + if capability is not None and capability.major == 10: + return "HND" + return None + @dataclass class FlashInferMetadata: @@ -332,7 +338,6 @@ def __init__( ) self.num_kv_heads = self.kv_cache_spec.num_kv_heads self.head_dim = self.kv_cache_spec.head_size - FlashInferBackend.validate_head_size(self.head_dim) self.page_size = self.kv_cache_spec.block_size self.cache_dtype = self.cache_config.cache_dtype diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 928252636d58..78b3050679aa 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -4,6 +4,7 @@ import math from dataclasses import dataclass +from typing import ClassVar import torch import torch._dynamo.decorators @@ -25,6 +26,7 @@ is_quantized_kv_cache, ) from vllm.config import VllmConfig +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, @@ -72,14 +74,12 @@ def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int): class FlexAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16, torch.float32] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - return # FlexAttention supports any head size + supported_dtypes: ClassVar[list[torch.dtype]] = [ + torch.float16, + torch.bfloat16, + torch.float32, + ] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"] @staticmethod def get_name() -> str: @@ -111,6 +111,10 @@ def get_builder_cls() -> type["FlexAttentionMetadataBuilder"]: def use_cascade_attention(*args, **kwargs) -> bool: return False + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [] + # @torch.compile(fullgraph=True, mode="reduce-overhead") def physical_to_logical_mapping( @@ -725,7 +729,6 @@ def __init__( if kv_sharing_target_layer_name is not None: raise NotImplementedError("FlexAttention does not support kv sharing yet.") - FlexAttentionBackend.validate_head_size(head_size) if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( "FlexAttention does 
not support quantized kv-cache. Yet" diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 0ec157300419..4b589ca9cf7b 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -324,25 +324,13 @@ def get_kv_cache_shape( ) -> tuple[int, ...]: return (num_blocks, block_size, head_size) - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - @classmethod def get_supported_head_sizes(cls) -> list[int]: return [576] @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." - ) + def is_mla(cls) -> bool: + return True @dataclass @@ -442,8 +430,10 @@ class MLACommonMetadata(Generic[D]): ) = None def __post_init__(self): - if self.head_dim is not None: - MLACommonBackend.validate_head_size(self.head_dim) + if self.head_dim is not None and not MLACommonBackend.supports_head_size( + self.head_dim + ): + raise ValueError(f"Head dimension {self.head_dim} is not supported by MLA.") M = TypeVar("M", bound=MLACommonMetadata) diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index c35e238eac4c..0a10ce74cd1d 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -13,7 +13,9 @@ MultipleOf, is_quantized_kv_cache, ) +from vllm.config.cache import CacheDType from vllm.logger import init_logger +from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, MLACommonImpl, @@ -33,6 +35,14 @@ class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]): class CutlassMLABackend(MLACommonBackend): + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [128] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "fp8", + "fp8_e4m3", + ] + @staticmethod def get_name() -> str: return "CUTLASS_MLA" @@ -45,9 +55,9 @@ def get_impl_cls() -> type["CutlassMLAImpl"]: def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]: return CutlassMLAMetadataBuilder - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - return [128] + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability.major == 10 class SM100Workspace: diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index a6aac701b784..f85569d5cc2b 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import ( AttentionLayer, AttentionType, + MultipleOf, is_quantized_kv_cache, ) from vllm.attention.utils.fa_utils import ( @@ -17,10 +18,12 @@ get_flash_attn_version, ) from vllm.config import VllmConfig +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) +from 
vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, MLACommonDecodeMetadata, @@ -37,6 +40,10 @@ class FlashAttnMLABackend(MLACommonBackend): + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"] + @staticmethod def get_name() -> str: return "FLASH_ATTN_MLA" @@ -53,6 +60,26 @@ def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]: def get_impl_cls() -> type["FlashAttnMLAImpl"]: return FlashAttnMLAImpl + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability.major == 9 + + @classmethod + def supports_combination( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: CacheDType | None, + block_size: int, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + device_capability: DeviceCapability, + ) -> str | None: + if not flash_attn_supports_mla(): + return "FlashAttention MLA not supported on this device" + return None + @dataclass class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata): diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index ebbcfd0eaa2f..9dcf39672acf 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -6,8 +6,14 @@ import torch from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla -from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf +from vllm.attention.backends.abstract import ( + AttentionLayer, + AttentionType, + MultipleOf, +) +from vllm.config.cache import BlockSize, CacheDType from vllm.logger import init_logger +from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, MLACommonImpl, @@ -15,7 +21,7 @@ MLACommonMetadataBuilder, QueryLenSupport, ) -from vllm.v1.attention.backends.utils import AttentionCGSupport +from vllm.v1.attention.backends.utils import AttentionCGSupport, KVCacheLayoutType logger = init_logger(__name__) @@ -28,6 +34,14 @@ class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]): class FlashInferMLABackend(MLACommonBackend): + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [32, 64] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "fp8", + "fp8_e4m3", + ] + @staticmethod def get_name() -> str: return "FLASHINFER_MLA" @@ -41,8 +55,16 @@ def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]: return FlashInferMLAMetadataBuilder @classmethod - def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: - return [32, 64] + def get_default_block_size(cls) -> BlockSize: + return 64 + + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability.major == 10 + + @classmethod + def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None": + return "HND" g_fi_workspace = torch.zeros( diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 1f98204031ed..dfa726a4e517 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -13,10 +13,12 @@ is_flashmla_dense_supported, ) from vllm.config 
import VllmConfig +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) +from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, MLACommonDecodeMetadata, @@ -36,6 +38,14 @@ class FlashMLABackend(MLACommonBackend): + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "fp8", + "fp8_e4m3", + ] + @staticmethod def get_name() -> str: return "FLASHMLA" @@ -52,9 +62,30 @@ def get_builder_cls() -> type["FlashMLAMetadataBuilder"]: def get_impl_cls() -> type["FlashMLAImpl"]: return FlashMLAImpl - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - return [64] + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability.major in [9, 10] + + @classmethod + def supports_combination( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: CacheDType | None, + block_size: int, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + device_capability: DeviceCapability, + ) -> str | None: + if use_sparse: + from vllm.attention.ops.flashmla import is_flashmla_sparse_supported + + return is_flashmla_sparse_supported()[1] + else: + from vllm.attention.ops.flashmla import is_flashmla_dense_supported + + return is_flashmla_dense_supported()[1] @dataclass diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index bf8e4d5a6289..c9fce9c252c7 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -11,6 +11,7 @@ AttentionBackend, AttentionLayer, AttentionMetadata, + MultipleOf, ) from vllm.attention.backends.utils import get_mla_dims from vllm.attention.ops.flashmla import ( @@ -19,8 +20,10 @@ get_mla_metadata, ) from vllm.config import VllmConfig +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.platforms.interface import DeviceCapability from vllm.triton_utils import tl, triton from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl @@ -52,6 +55,9 @@ class FlashMLASparseBackend(AttentionBackend): accept_output_buffer: bool = True + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"] @staticmethod def get_name() -> str: @@ -69,6 +75,22 @@ def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]: def get_impl_cls() -> type["FlashMLASparseImpl"]: return FlashMLASparseImpl + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [576] + + @classmethod + def is_mla(cls) -> bool: + return True + + @classmethod + def is_sparse(cls) -> bool: + return True + + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability.major in [9, 10] + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -84,14 +106,6 @@ def get_kv_cache_shape( else: return (num_blocks, block_size, head_size) - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.bfloat16] - - @classmethod - 
def get_supported_head_sizes(cls) -> list[int]: - return [576] - @dataclass class FlashMLASparseMetadata: diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 49009a939d0b..c431edae4176 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -24,6 +24,8 @@ class DeepseekV32IndexerBackend(AttentionBackend): + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64] + @staticmethod def get_metadata_cls() -> type["AttentionMetadata"]: return DeepseekV32IndexerMetadata @@ -51,10 +53,6 @@ def get_kv_cache_shape( def get_kv_cache_stride_order() -> tuple[int, ...]: return (0, 1, 2) - @classmethod - def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: - return [64] - @dataclass class DeepseekV32IndexerPrefillChunkMetadata: diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 781f77e96319..0149639e8c0b 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import ClassVar import torch @@ -12,11 +13,13 @@ ) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_flash_attention import triton_attention +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) from vllm.platforms import current_platform +from vllm.platforms.interface import DeviceCapability from vllm.triton_utils import HAS_TRITON from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, @@ -28,6 +31,9 @@ class TritonMLABackend(MLACommonBackend): + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"] + @staticmethod def get_name() -> str: return "TRITON_MLA" @@ -36,6 +42,10 @@ def get_name() -> str: def get_impl_cls() -> type["TritonMLAImpl"]: return TritonMLAImpl + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return True + class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): can_return_lse_for_decode: bool = True diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index f7a4114a0a70..90328461c748 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -3,6 +3,7 @@ """Attention layer with AiterFlashAttention.""" from dataclasses import dataclass +from typing import ClassVar import torch @@ -350,31 +351,13 @@ def use_cascade_attention(self, *args, **kwargs) -> bool: class AiterFlashAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)] @classmethod def get_supported_head_sizes(cls) -> list[int]: return [64, 128, 256] - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - return [MultipleOf(16)] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not 
in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." - ) - @staticmethod def get_name() -> str: return "FLASH_ATTN" @@ -439,8 +422,6 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - AiterFlashAttentionBackend.validate_head_size(head_size) - if attn_type != AttentionType.DECODER: raise NotImplementedError( "Encoder self-attention and " diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 8b7ce90a3cca..be925a232372 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -153,10 +153,7 @@ def build( class RocmAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] @classmethod def get_supported_head_sizes(cls) -> list[int]: @@ -164,12 +161,11 @@ def get_supported_head_sizes(cls) -> list[int]: @classmethod def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: + if not cls.supports_head_size(head_size): attn_type = cls.__name__.removesuffix("Backend") raise ValueError( f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " + f"Supported head sizes are: {cls.get_supported_head_sizes()}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " "FlexAttention backend which supports all head sizes." ) diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index ee6ead9ad9b3..73872dd6cb07 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -4,7 +4,7 @@ import ast from dataclasses import dataclass -from typing import Optional +from typing import ClassVar, Optional import torch @@ -31,31 +31,13 @@ class TreeAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)] @classmethod def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - return [MultipleOf(16)] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." 
- ) - @staticmethod def get_name() -> str: return "TREE_ATTN" @@ -336,8 +318,6 @@ def __init__( else: self.sliding_window = (sliding_window - 1, 0) - TreeAttentionBackend.validate_head_size(head_size) - if attn_type != AttentionType.DECODER: raise NotImplementedError( "Encoder self-attention and " diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index b1d34dbfd172..42721bcec094 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -19,12 +19,14 @@ ) from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, ) from vllm.platforms import current_platform +from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, @@ -148,25 +150,18 @@ def build( class TritonAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16, torch.float32] - - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - return [MultipleOf(16)] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - # Triton Attention supports any head size above 32 - if head_size < 32: - raise ValueError( - f"Head size {head_size} is not supported by TritonAttention." - f"Head sizes need to be larger or equal 32 for this backend. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." 
- ) + supported_dtypes: ClassVar[list[torch.dtype]] = [ + torch.float16, + torch.bfloat16, + torch.float32, + ] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "fp8", + "fp8_e4m3", + "fp8_e5m2", + ] @staticmethod def get_name() -> str: @@ -200,6 +195,18 @@ def use_cascade_attention(*args, **kwargs) -> bool: def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]: return TritonAttentionMetadataBuilder + @classmethod + def supports_head_size(cls, head_size: int) -> bool: + return head_size >= 32 + + @classmethod + def supports_sink(cls) -> bool: + return True + + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return True + class TritonAttentionImpl(AttentionImpl): def fused_output_quant_supported(self, quant_key: QuantKey): @@ -242,8 +249,6 @@ def __init__( self.num_queries_per_kv = self.num_heads // self.num_kv_heads - TritonAttentionBackend.validate_head_size(head_size) - if attn_type != AttentionType.DECODER: raise NotImplementedError( "Encoder self-attention and " diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index 457b15ebdd82..112e1b30de70 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -3,7 +3,7 @@ """Attention layer with XFormersAttention.""" from dataclasses import dataclass -from typing import Optional +from typing import ClassVar, Optional import torch @@ -42,10 +42,8 @@ class XFormersAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)] @classmethod def get_supported_head_sizes(cls) -> list[int]: @@ -81,22 +79,6 @@ def get_supported_head_sizes(cls) -> list[int]: 256, ] - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - return [MultipleOf(16)] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." - ) - @staticmethod def get_name() -> str: return "XFORMERS" @@ -310,8 +292,6 @@ def __init__( logits_soft_cap = 0 self.logits_soft_cap = logits_soft_cap - XFormersAttentionBackend.validate_head_size(head_size) - if attn_type != AttentionType.DECODER: raise NotImplementedError( "Encoder self-attention and " diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 1e18eea2330a..c93deeaa2915 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -150,11 +150,13 @@ def __init__( ) # Determine allowed attention backends once during initialization. 
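The backend files in this patch converge on one declarative pattern: per-backend validate_head_size()/get_supported_*() boilerplate becomes class-level data (supported_dtypes, supported_kernel_block_sizes, supported_kv_cache_dtypes) plus small supports_*() predicates that the platform selector can query uniformly. The sketch below approximates that shape; the generic supports_head_size() shown here stands in for whatever default the AttentionBackend base class provides, which is not part of this diff:

```python
from typing import ClassVar, NamedTuple

import torch


class DeviceCapability(NamedTuple):
    # Same shape as vllm.platforms.interface.DeviceCapability.
    major: int
    minor: int


class ToyAttentionBackend:
    """Illustrative backend written against the new declarative surface.

    Real backends above (TritonAttentionBackend, FlashInferBackend, ...)
    inherit these hooks from AttentionBackend; only their overrides appear
    in this diff.
    """

    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kernel_block_sizes: ClassVar[list[int]] = [16, 32, 64]
    supported_kv_cache_dtypes: ClassVar[list[str]] = ["auto", "fp8"]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        # An empty list is used by FlexAttention/IPEX above to mean "no restriction".
        return [64, 128, 256]

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        # Backends with open-ended support (TritonAttention: head_size >= 32)
        # override this predicate instead of enumerating sizes.
        sizes = cls.get_supported_head_sizes()
        return not sizes or head_size in sizes

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        return capability >= DeviceCapability(8, 0)


print(ToyAttentionBackend.supports_head_size(128))  # True
print(ToyAttentionBackend.supports_compute_capability(DeviceCapability(7, 5)))  # False
```

This is what lets RocmAttentionBackend's validate_head_size above shrink to a call to cls.supports_head_size(head_size) while keeping its error message.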
+ from vllm.attention.backends.registry import AttentionBackendEnum + self.allowed_attn_types: tuple | None = None if current_platform.is_rocm(): rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] - # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend - if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"): + # ROCM_AITER_FA is an optional backend + if find_spec(AttentionBackendEnum.ROCM_AITER_FA.get_path()): from vllm.v1.attention.backends.rocm_aiter_fa import ( AiterFlashAttentionMetadata, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 66a9d7291261..39cdad0fc9b7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4245,7 +4245,7 @@ def block_size_is_supported( """ for backend in backends: is_supported = False - for supported_size in backend.get_supported_kernel_block_size(): + for supported_size in backend.supported_kernel_block_sizes: if isinstance(supported_size, int): if block_size == supported_size: is_supported = True @@ -4276,7 +4276,7 @@ def block_size_is_supported( all_int_supported_sizes = set( supported_size for backend in backends - for supported_size in backend.get_supported_kernel_block_size() + for supported_size in backend.supported_kernel_block_sizes if isinstance(supported_size, int) )
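The final gpu_model_runner.py hunk reads supported_kernel_block_sizes directly instead of calling get_supported_kernel_block_size(). A small self-contained sketch of that check follows; the int branch mirrors the diff, while the MultipleOf branch is elided in the hunk above, so the modulo test and the .base field on the local stand-in are assumptions rather than the exact vLLM definition:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    # Local stand-in for vllm.attention.backends.abstract.MultipleOf; the
    # real class may expose its factor under a different attribute name.
    base: int


def block_size_is_supported(backends, block_size: int) -> bool:
    """Return True only if every backend accepts `block_size`.

    Exact int entries must match exactly; MultipleOf entries accept any
    multiple of their base, which is how FlashAttention's MultipleOf(16)
    coexists with FlashInfer's explicit [16, 32, 64].
    """
    for backend in backends:
        is_supported = False
        for supported_size in backend.supported_kernel_block_sizes:
            if isinstance(supported_size, int):
                if block_size == supported_size:
                    is_supported = True
            elif isinstance(supported_size, MultipleOf):
                if block_size % supported_size.base == 0:
                    is_supported = True
        if not is_supported:
            return False
    return True


class _FakeFlashAttnBackend:
    supported_kernel_block_sizes = [MultipleOf(16)]


class _FakeFlashInferBackend:
    supported_kernel_block_sizes = [16, 32, 64]


print(block_size_is_supported([_FakeFlashAttnBackend, _FakeFlashInferBackend], 32))  # True
print(block_size_is_supported([_FakeFlashAttnBackend, _FakeFlashInferBackend], 48))  # False
```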