From eb34c8dfc0f566df2c3e980043e6c6c2691d734a Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 8 Oct 2025 15:24:12 -0400 Subject: [PATCH 01/84] add support methods to abstract Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 129 ++++++++++++++++++++++++++++ 1 file changed, 129 insertions(+) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index bb2f36271103..dce0360f2b76 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -84,6 +84,135 @@ def get_kv_cache_stride_order() -> tuple[int, ...]: def full_cls_name(cls) -> tuple[str, str]: return (cls.__module__, cls.__qualname__) + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [] + + @classmethod + def supports_head_size(cls, head_size: int) -> bool: + supported_head_sizes = cls.get_supported_head_sizes() + return (not supported_head_sizes) or head_size in supported_head_sizes + + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def supports_dtype(cls, dtype: torch.dtype) -> bool: + supported_dtypes = cls.get_supported_dtypes() + return (not supported_dtypes) or dtype in supported_dtypes + + @classmethod + def get_supported_kv_cache_dtypes(cls) -> list[Optional[str]]: + return ["auto"] + + @classmethod + def supports_kv_cache_dtype(cls, kv_cache_dtype: Optional[str]) -> bool: + supported_kv_cache_dtypes = cls.get_supported_kv_cache_dtypes() + return (not supported_kv_cache_dtypes) or ( + kv_cache_dtype is not None and kv_cache_dtype in supported_kv_cache_dtypes + ) + + @classmethod + def get_supported_block_sizes(cls) -> list[int]: + return [] + + @classmethod + def supports_block_size(cls, block_size: int) -> bool: + supported_block_sizes = cls.get_supported_block_sizes() + return (not supported_block_sizes) or block_size in supported_block_sizes + + @classmethod + def is_mla(cls) -> bool: + raise NotImplementedError + + @classmethod + def supports_sink(cls) -> bool: + return True + + @classmethod + def is_sparse(cls) -> bool: + return False + + @classmethod + def get_min_compute_capability(cls) -> Optional[int]: + return None + + @classmethod + def get_max_compute_capability(cls) -> Optional[int]: + return None + + @classmethod + def supports_compute_capability(cls, capability: int) -> bool: + min_capability = cls.get_min_compute_capability() + max_capability = cls.get_max_compute_capability() + return ((min_capability is None) or (capability >= min_capability)) and ( + (max_capability is None) or (capability <= max_capability) + ) + + @classmethod + def supports_combination( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + use_v1: bool, + use_mla: bool, + has_sink: bool, + device_capability: int, + ) -> Optional[str]: + return None + + @classmethod + def validate_configuration( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + device_capability: int, + ) -> list[str]: + invalid_reasons = [] + if not cls.supports_head_size(head_size): + invalid_reasons.append("head_size not supported") + if not cls.supports_dtype(dtype): + invalid_reasons.append("dtype not supported") + if not cls.supports_kv_cache_dtype(kv_cache_dtype): + invalid_reasons.append("kv_cache_dtype not supported") + if not cls.supports_block_size(block_size): + invalid_reasons.append("block_size not supported") + 
if use_mla != cls.is_mla(): + if use_mla: + invalid_reasons.append("MLA not supported") + else: + invalid_reasons.append("non-MLA not supported") + if has_sink and not cls.supports_sink(): + invalid_reasons.append("sink setting not supported") + if use_sparse != cls.is_sparse(): + if use_sparse: + invalid_reasons.append("sparse not supported") + else: + invalid_reasons.append("non-sparse not supported") + if not cls.supports_compute_capability(device_capability): + invalid_reasons.append("compute capability not supported") + combination_reason = cls.supports_combination( + head_size, + dtype, + kv_cache_dtype, + block_size, + use_mla, + has_sink, + use_sparse, + device_capability, + ) + if combination_reason is not None: + invalid_reasons.append(combination_reason) + return invalid_reasons + class AttentionMetadata: pass From 87edf38538948ff4c89d40ff852f42f3b3e4e408 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 8 Oct 2025 15:26:31 -0400 Subject: [PATCH 02/84] remove is_attn_backend_supported Signed-off-by: Matthew Bonanni --- vllm/attention/selector.py | 61 +------------------------------------- 1 file changed, 1 insertion(+), 60 deletions(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 53677372e055..ebe75fb976d7 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -4,9 +4,8 @@ import os from collections.abc import Generator from contextlib import contextmanager -from dataclasses import dataclass from functools import cache -from typing import Optional, Union +from typing import Optional import torch @@ -67,64 +66,6 @@ def get_global_forced_attn_backend() -> Optional[_Backend]: return forced_attn_backend -@dataclass(frozen=True) -class _IsSupported: - can_import: bool - head_size: bool - dtype: bool - - def __bool__(self) -> bool: - return self.can_import and self.head_size and self.dtype - - -def is_attn_backend_supported( - attn_backend: Union[str, type[AttentionBackend]], - head_size: int, - dtype: torch.dtype, - *, - allow_import_error: bool = True, -) -> _IsSupported: - if isinstance(attn_backend, str): - try: - attn_backend = resolve_obj_by_qualname(attn_backend) - except ImportError: - if not allow_import_error: - raise - - return _IsSupported(can_import=False, head_size=False, dtype=False) - - assert isinstance(attn_backend, type) - - # TODO: Update the interface once V0 is removed - if get_supported_head_sizes := getattr( - attn_backend, "get_supported_head_sizes", None - ): - is_head_size_supported = head_size in get_supported_head_sizes() - elif validate_head_size := getattr(attn_backend, "validate_head_size", None): - try: - validate_head_size(head_size) - is_head_size_supported = True - except Exception: - is_head_size_supported = False - else: - raise NotImplementedError( - f"{attn_backend.__name__} does not support head size validation" - ) - - if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes", None): - is_dtype_supported = dtype in get_supported_dtypes() - else: - raise NotImplementedError( - f"{attn_backend.__name__} does not support dtype validation" - ) - - return _IsSupported( - can_import=True, - head_size=is_head_size_supported, - dtype=is_dtype_supported, - ) - - def get_attn_backend( head_size: int, dtype: torch.dtype, From fc493ae45a6272cf747fa1a84e0235812f97e9c2 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 8 Oct 2025 15:27:06 -0400 Subject: [PATCH 03/84] all backends are V1 now Signed-off-by: Matthew Bonanni --- vllm/engine/arg_utils.py | 26 
-------------------------- 1 file changed, 26 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index df1abe3c8459..abd4592b0740 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1621,32 +1621,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: "such as ngram, medusa, eagle, or mtp." ) - V1_BACKENDS = [ - "FLASH_ATTN", - "PALLAS", - "TRITON_ATTN", - "TRITON_MLA", - "CUTLASS_MLA", - "FLASHMLA", - "FLASH_ATTN_MLA", - "FLASHINFER", - "FLASHINFER_MLA", - "ROCM_AITER_MLA", - "TORCH_SDPA", - "FLEX_ATTENTION", - "TREE_ATTN", - "XFORMERS", - "ROCM_ATTN", - "ROCM_AITER_UNIFIED_ATTN", - ] - if ( - envs.is_set("VLLM_ATTENTION_BACKEND") - and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS - ): - name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}" - _raise_or_fallback(feature_name=name, recommend_to_remove=True) - return False - ############################################################# # Experimental Features - allow users to opt in. From 961897909493331961ad0d6b7ef0fb235a696d9a Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 8 Oct 2025 15:31:43 -0400 Subject: [PATCH 04/84] use backend_to_class_str Signed-off-by: Matthew Bonanni --- vllm/platforms/cpu.py | 4 ++-- vllm/platforms/rocm.py | 22 +++++++--------------- vllm/platforms/tpu.py | 4 ++-- vllm/platforms/xpu.py | 10 ++++------ 4 files changed, 15 insertions(+), 25 deletions(-) diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 24e08a8ecbd7..d5ecce2642cb 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -134,7 +134,7 @@ def get_attn_backend_cls( has_sink: bool, use_sparse: bool, ) -> str: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import _Backend, backend_to_class_str if selected_backend and selected_backend != _Backend.TORCH_SDPA: logger.info("Cannot use %s backend on CPU.", selected_backend) @@ -145,7 +145,7 @@ def get_attn_backend_cls( logger.info("Using Torch SDPA backend.") if not use_v1: raise ValueError("CPU backend only supports V1.") - return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend" + return backend_to_class_str(_Backend.TORCH_SDPA) @classmethod def get_device_total_memory(cls, device_id: int = 0) -> int: diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 25601011491f..4b3ed6801c84 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -229,7 +229,7 @@ def get_attn_backend_cls( has_sink, use_sparse, ) -> str: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import _Backend, backend_to_class_str if use_sparse: raise NotImplementedError("Sparse Attention is not supported on ROCm.") @@ -254,7 +254,7 @@ def get_attn_backend_cls( if selected_backend == _Backend.TRITON_MLA: if block_size != 1: logger.info_once("Using Triton MLA backend on V1 engine.") - return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend" + return backend_to_class_str(_Backend.TRITON_MLA) raise ValueError( f" The selected backend, {selected_backend.name}," f"does not support block size {block_size}." 
@@ -262,9 +262,7 @@ def get_attn_backend_cls( if selected_backend == _Backend.ROCM_AITER_MLA: if block_size == 1: logger.info("Using AITER MLA backend on V1 engine.") - return ( - "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 - ) + return backend_to_class_str(_Backend.ROCM_AITER_MLA) raise ValueError( f" The selected backend, {selected_backend.name}," f"does not support block size {block_size}." @@ -280,18 +278,12 @@ def get_attn_backend_cls( envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9() ) or selected_backend == _Backend.ROCM_AITER_FA: logger.info("Using Aiter Flash Attention backend on V1 engine.") - return ( - "vllm.v1.attention.backends." - "rocm_aiter_fa.AiterFlashAttentionBackend" - ) + return backend_to_class_str(_Backend.ROCM_AITER_FA) if ( envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION ) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN: logger.info("Using Aiter Unified Attention backend on V1 engine.") - return ( - "vllm.v1.attention.backends." - "rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend" - ) + return backend_to_class_str(_Backend.ROCM_AITER_UNIFIED_ATTN) if ( envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION or selected_backend == _Backend.ROCM_ATTN @@ -299,10 +291,10 @@ def get_attn_backend_cls( # rocm specific backend, with aiter and/or # triton prefix-prefill logger.info("Using Rocm Attention backend on V1 engine.") - return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" + return backend_to_class_str(_Backend.ROCM_ATTN) # default case, using triton unified attention logger.info("Using Triton Attention backend on V1 engine.") - return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" + return backend_to_class_str(_Backend.TRITON_ATTN) raise RuntimeError( "V0 attention backends have been removed. Set VLLM_USE_V1=1 " "to select a supported backend." 
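Note on the change above: the hard-coded dotted class paths are being replaced by a registry lookup keyed on the _Backend enum. The following is a rough standalone sketch of that pattern, not the actual vllm.attention.backends.registry module: the enum here is trimmed to three members and the resolver is simplified, while the dotted paths are the ones this patch removes from the platform files.

import enum
import importlib


class _Backend(enum.Enum):
    # Mirrors vllm.attention.backends.registry._Backend (heavily trimmed).
    TRITON_ATTN = enum.auto()
    TRITON_MLA = enum.auto()
    PALLAS = enum.auto()


# Enum -> fully qualified backend class path (paths taken from this patch).
BACKEND_MAP = {
    _Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",
    _Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
    _Backend.PALLAS: "vllm.v1.attention.backends.pallas.PallasAttentionBackend",
}


def backend_to_class_str(backend: _Backend) -> str:
    # Platforms' get_attn_backend_cls() now return this string.
    return BACKEND_MAP[backend]


def backend_to_class(backend: _Backend):
    # vLLM uses resolve_obj_by_qualname for this step; plain importlib shown here.
    module_name, _, class_name = backend_to_class_str(backend).rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)


if __name__ == "__main__":
    print(backend_to_class_str(_Backend.TRITON_ATTN))

Centralizing the enum-to-path mapping in one registry means a renamed or relocated backend module only needs updating in BACKEND_MAP instead of in every per-platform selector.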
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 8c23b1de44e4..773e7d826826 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -61,7 +61,7 @@ def get_attn_backend_cls( has_sink, use_sparse, ) -> str: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import _Backend, backend_to_class_str if use_sparse: raise NotImplementedError("Sparse Attention is not supported on TPU.") @@ -71,7 +71,7 @@ def get_attn_backend_cls( if not use_v1: raise ValueError("TPU backend only supports V1.") logger.info("Using Pallas V1 backend.") - return "vllm.v1.attention.backends.pallas.PallasAttentionBackend" + return backend_to_class_str(_Backend.PALLAS) @classmethod def set_device(cls, device: torch.device) -> None: diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 2f2f3ab8b9d9..ddf02158d960 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -51,21 +51,19 @@ def get_attn_backend_cls( has_sink: bool, use_sparse, ) -> str: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import _Backend, backend_to_class_str if use_sparse: raise NotImplementedError("Sparse Attention is not supported on XPU.") use_v1 = envs.VLLM_USE_V1 if not use_v1: raise ValueError("XPU backend only supports V1.") - TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 - FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 if selected_backend == _Backend.TRITON_ATTN: logger.info_once("Using Triton backend on V1 engine.") - return TRITON_ATTN + return backend_to_class_str(_Backend.TRITON_ATTN) elif selected_backend == _Backend.FLASH_ATTN: logger.info_once("Using Flash Attention backend on V1 engine.") - return FLASH_ATTN + return backend_to_class_str(_Backend.FLASH_ATTN) elif selected_backend: raise ValueError( f"Invalid attention backend for {cls.device_name}, " @@ -73,7 +71,7 @@ def get_attn_backend_cls( ) logger.info("Using Flash Attention backend on V1 engine.") - return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + return backend_to_class_str(_Backend.FLASH_ATTN) @classmethod def is_kv_cache_dtype_supported( From 8aeb461c94558556d973cdaef323ac05cc9ec0ff Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 8 Oct 2025 15:36:54 -0400 Subject: [PATCH 05/84] add MLA backend support details Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/common.py | 16 ++------- vllm/v1/attention/backends/mla/cutlass_mla.py | 20 +++++++++++ .../attention/backends/mla/flashattn_mla.py | 36 +++++++++++++++++++ .../attention/backends/mla/flashinfer_mla.py | 20 +++++++++++ vllm/v1/attention/backends/mla/flashmla.py | 20 +++++++++++ .../attention/backends/mla/flashmla_sparse.py | 32 ++++++++++++----- vllm/v1/attention/backends/mla/triton_mla.py | 20 +++++++++++ 7 files changed, 142 insertions(+), 22 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 3fb00f5917ea..ed3c579569c2 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -300,25 +300,13 @@ def get_kv_cache_shape( ) -> tuple[int, ...]: return (num_blocks, block_size, head_size) - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - @classmethod def get_supported_head_sizes(cls) -> list[int]: return [576] @classmethod - def validate_head_size(cls, head_size: int) -> None: - 
supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." - ) + def is_mla(cls) -> bool: + return True @dataclass diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index a3c677ca2108..e587020438f0 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -44,6 +44,26 @@ def get_impl_cls() -> type["CutlassMLAImpl"]: def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]: return CutlassMLAMetadataBuilder + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_supported_kv_cache_dtypes(cls) -> list[Optional[str]]: + return ["auto", "fp16", "bf16", "e4m3fn"] + + @classmethod + def get_supported_block_sizes(cls) -> list[int]: + return [128] + + @classmethod + def get_min_compute_capability(cls) -> Optional[int]: + return 100 + + @classmethod + def get_max_compute_capability(cls) -> Optional[int]: + return 109 + class SM100Workspace: def __init__(self, initial_workspace_size): diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index c0c2dbe1f961..9086f0243df2 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -50,6 +50,42 @@ def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]: def get_impl_cls() -> type["FlashAttnMLAImpl"]: return FlashAttnMLAImpl + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_supported_kv_cache_dtypes(cls) -> list[Optional[str]]: + return ["auto", "fp16", "bf16"] + + @classmethod + def get_supported_block_sizes(cls) -> list[int]: + return [] + + @classmethod + def get_min_compute_capability(cls) -> Optional[int]: + return 90 + + @classmethod + def get_max_compute_capability(cls) -> Optional[int]: + return 90 + + @classmethod + def supports_combination( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + use_v1: bool, + use_mla: bool, + has_sink: bool, + device_capability: int, + ) -> Optional[str]: + if not flash_attn_supports_mla(): + return "FlashAttention MLA not supported on this device" + return None + @dataclass class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata): diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index 206f96ea366a..e848b430797a 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -42,6 +42,26 @@ def get_impl_cls() -> type["FlashInferMLAImpl"]: def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]: return FlashInferMLAMetadataBuilder + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_supported_kv_cache_dtypes(cls) -> list[Optional[str]]: + return ["auto", "fp16", "bf16", "fp8", "fp8_e4m3"] + + @classmethod + def get_supported_block_sizes(cls) -> list[int]: + return [32, 64] + + @classmethod + def get_min_compute_capability(cls) -> 
Optional[int]: + return 100 + + @classmethod + def get_max_compute_capability(cls) -> Optional[int]: + return 109 + g_fi_workspace = torch.zeros( FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE, diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 6ba2c682760c..15514106110f 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -44,6 +44,26 @@ def get_builder_cls() -> type["FlashMLAMetadataBuilder"]: def get_impl_cls() -> type["FlashMLAImpl"]: return FlashMLAImpl + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_supported_kv_cache_dtypes(cls) -> list[Optional[str]]: + return ["auto", "fp16", "bf16", "fp8", "fp8_e4m3"] + + @classmethod + def get_supported_block_sizes(cls) -> list[int]: + return [64] + + @classmethod + def get_min_compute_capability(cls) -> Optional[int]: + return 90 + + @classmethod + def get_max_compute_capability(cls) -> Optional[int]: + return 109 + @dataclass class FlashMLADecodeMetadata(MLACommonDecodeMetadata): diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 49c29de35da1..103054efaf05 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -69,6 +69,30 @@ def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]: def get_impl_cls() -> type["FlashMLASparseImpl"]: return FlashMLASparseImpl + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16] + + @classmethod + def get_supported_kv_cache_dtypes(cls) -> list[Optional[str]]: + return ["auto", "bf16", "fp8_ds_mla"] + + @classmethod + def get_supported_block_sizes(cls) -> list[int]: + return [32, 64] + + @classmethod + def is_sparse(cls) -> bool: + return True + + @classmethod + def get_min_compute_capability(cls) -> Optional[int]: + return 100 + + @classmethod + def get_max_compute_capability(cls) -> Optional[int]: + return 109 + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -84,14 +108,6 @@ def get_kv_cache_shape( else: return (num_blocks, block_size, head_size) - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.bfloat16] - - @classmethod - def get_supported_head_sizes(cls) -> list[int]: - return [576] - @dataclass class FlashMLASparseMetadata: diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 3b6718c48d09..d3f9b2bd4f16 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -34,6 +34,26 @@ def get_name() -> str: def get_impl_cls() -> type["TritonMLAImpl"]: return TritonMLAImpl + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_supported_kv_cache_dtypes(cls) -> list[Optional[str]]: + return ["auto", "fp16", "bf16"] + + @classmethod + def get_supported_block_sizes(cls) -> list[int]: + return [] + + @classmethod + def get_min_compute_capability(cls) -> Optional[int]: + return None + + @classmethod + def get_max_compute_capability(cls) -> Optional[int]: + return None + class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): can_return_lse_for_decode: bool = True From eb8426f13c9ca8ec236d84411f55b7b991d05e43 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 8 Oct 2025 15:38:53 -0400 Subject: [PATCH 06/84] use 
backend_to_class_str Signed-off-by: Matthew Bonanni --- vllm/v1/spec_decode/eagle.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 1e1161727be1..c902320c4441 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -132,11 +132,13 @@ def __init__( ) # Determine allowed attention backends once during initialization. + from vllm.attention.backends.registry import _Backend, backend_to_class_str + self.allowed_attn_types: Optional[tuple] = None if current_platform.is_rocm(): rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] - # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend - if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"): + # ROCM_AITER_FA is an optional backend + if find_spec(backend_to_class_str(_Backend.ROCM_AITER_FA)): from vllm.v1.attention.backends.rocm_aiter_fa import ( AiterFlashAttentionMetadata, ) From aba576c346e2601d697883cdcce4d05871034907 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 8 Oct 2025 15:51:16 -0400 Subject: [PATCH 07/84] add support details for standard attention backends Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/cpu_attn.py | 27 +++----- vllm/v1/attention/backends/flash_attn.py | 68 +++++++++++++------- vllm/v1/attention/backends/flashinfer.py | 51 ++++++++------- vllm/v1/attention/backends/flex_attention.py | 37 ++++++++--- vllm/v1/attention/backends/rocm_aiter_fa.py | 14 ---- vllm/v1/attention/backends/tree_attn.py | 14 ---- vllm/v1/attention/backends/triton_attn.py | 52 ++++++++++----- vllm/v1/attention/backends/xformers.py | 14 ---- 8 files changed, 147 insertions(+), 130 deletions(-) diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 6e27e93c9115..420e4fd115f1 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -42,21 +42,13 @@ class TorchSDPABackend(AttentionBackend): accept_output_buffer: bool = False @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16, torch.float32] + def get_supported_head_sizes(cls) -> list[int]: + attn_impl = _get_paged_attn_impl() + return attn_impl.get_supported_head_sizes() @classmethod - def validate_head_size(cls, head_size: int) -> None: - attn_impl = _get_paged_attn_impl() - is_valid, supported_head_sizes = attn_impl.validate_head_size(head_size) - if not is_valid: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." 
- ) + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16, torch.float32] @staticmethod def get_name() -> str: @@ -759,9 +751,8 @@ def _make_sliding_window_bias( class _PagedAttention: @staticmethod - def validate_head_size(head_size: int) -> tuple[bool, list[int]]: - SUPPORT_HS = [32, 64, 80, 96, 112, 128, 192, 256] - return head_size in SUPPORT_HS, SUPPORT_HS + def get_supported_head_sizes() -> list[int]: + return [32, 64, 80, 96, 112, 128, 192, 256] @staticmethod def get_kv_cache_shape( @@ -861,8 +852,8 @@ def forward_decode( class _IPEXPagedAttention(_PagedAttention): @staticmethod - def validate_head_size(head_size: int) -> tuple[bool, list[int]]: - return True, [] + def get_supported_head_sizes() -> list[int]: + return [] @staticmethod def split_kv_cache( diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 1f6b7e41b37e..d4bca0b07d07 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -49,26 +49,6 @@ class FlashAttentionBackend(AttentionBackend): accept_output_buffer: bool = True supports_quant_query_input: bool = True - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - - @classmethod - def get_supported_head_sizes(cls) -> list[int]: - return [32, 64, 96, 128, 160, 192, 224, 256] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." 
- ) - @staticmethod def get_name() -> str: return "FLASH_ATTN" @@ -117,6 +97,52 @@ def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype: else: raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def supports_kv_cache_dtype(cls, kv_cache_dtype: Optional[str]) -> bool: + if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): + return flash_attn_supports_fp8() + return kv_cache_dtype in [None, "auto", "fp16", "bf16"] + + @classmethod + def supports_block_size(cls, block_size: int) -> bool: + return block_size % 16 == 0 + + @classmethod + def is_mla(cls) -> bool: + return False + + @classmethod + def get_min_compute_capability(cls) -> Optional[int]: + return 80 + + @classmethod + def get_max_compute_capability(cls) -> Optional[int]: + return None + + @classmethod + def supports_combination( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + use_v1: bool, + use_mla: bool, + has_sink: bool, + device_capability: int, + ) -> Optional[str]: + if has_sink and device_capability < 90: + return "sink not supported on compute capability < 9.0" + return None + @dataclass class FlashAttentionMetadata: @@ -428,8 +454,6 @@ def __init__( self.num_queries_per_kv = self.num_heads // self.num_kv_heads - FlashAttentionBackend.validate_head_size(head_size) - self.attn_type = attn_type self.vllm_flash_attn_version = get_flash_attn_version() if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8(): diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 38cf0ca56733..2f8790f1159b 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -157,27 +157,6 @@ def trtllm_prefill_attn_kvfp8_dequant( class FlashInferBackend(AttentionBackend): accept_output_buffer: bool = True - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - - @classmethod - def get_supported_head_sizes(cls) -> list[int]: - # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 - return [64, 128, 256] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." 
- ) - @staticmethod def get_name() -> str: return "FLASHINFER" @@ -226,6 +205,35 @@ def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: else: raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 + return [64, 128, 256] + + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_supported_kv_cache_dtypes(cls) -> list[str | None]: + return ["auto", "fp16", "bf16", "fp8", "fp8_e4m3", "fp8_e5m2"] + + @classmethod + def get_supported_block_sizes(cls) -> list[int]: + return [] + + @classmethod + def is_mla(cls) -> bool: + return False + + @classmethod + def get_min_compute_capability(cls) -> int | None: + return 100 + + @classmethod + def get_max_compute_capability(cls) -> int | None: + return 109 + @dataclass class FlashInferMetadata: @@ -307,7 +315,6 @@ def __init__( ) self.num_kv_heads = self.kv_cache_spec.num_kv_heads self.head_dim = self.kv_cache_spec.head_size - FlashInferBackend.validate_head_size(self.head_dim) self.page_size = self.kv_cache_spec.block_size self.cache_dtype = self.cache_config.cache_dtype diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 7775445ae773..3b2207dced18 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -72,14 +72,6 @@ def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int): class FlexAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16, torch.float32] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - return # FlexAttention supports any head size - @staticmethod def get_name() -> str: return "FLEX_ATTENTION" @@ -110,6 +102,34 @@ def get_builder_cls() -> type["FlexAttentionMetadataBuilder"]: def use_cascade_attention(*args, **kwargs) -> bool: return False + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [] + + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16, torch.float32] + + @classmethod + def get_supported_kv_cache_dtypes(cls) -> list[Optional[str]]: + return ["auto", "fp16", "bf16"] + + @classmethod + def get_supported_block_sizes(cls) -> list[int]: + return [] + + @classmethod + def is_mla(cls) -> bool: + return False + + @classmethod + def get_min_compute_capability(cls) -> Optional[int]: + return None + + @classmethod + def get_max_compute_capability(cls) -> Optional[int]: + return None + # @torch.compile(fullgraph=True, mode="reduce-overhead") def physical_to_logical_mapping( @@ -720,7 +740,6 @@ def __init__( if kv_sharing_target_layer_name is not None: raise NotImplementedError("FlexAttention does not support kv sharing yet.") - FlexAttentionBackend.validate_head_size(head_size) if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( "FlexAttention does not support quantized kv-cache. 
Yet" diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 348eca55eefb..821483d19283 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -359,18 +359,6 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: def get_supported_head_sizes(cls) -> list[int]: return [64, 128, 256] - @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." - ) - @staticmethod def get_name() -> str: return "FLASH_ATTN" @@ -435,8 +423,6 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - AiterFlashAttentionBackend.validate_head_size(head_size) - if attn_type != AttentionType.DECODER: raise NotImplementedError( "Encoder self-attention and " diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index a209bb79580c..8a7387c7ff98 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -39,18 +39,6 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] - @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." - ) - @staticmethod def get_name() -> str: return "TREE_ATTN" @@ -331,8 +319,6 @@ def __init__( else: self.sliding_window = (sliding_window - 1, 0) - TreeAttentionBackend.validate_head_size(head_size) - if attn_type != AttentionType.DECODER: raise NotImplementedError( "Encoder self-attention and " diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 9997ed16bed1..b69b142da53f 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -153,21 +153,6 @@ def build( class TritonAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16, torch.float32] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - # Triton Attention supports any head size above 32 - if head_size < 32: - raise ValueError( - f"Head size {head_size} is not supported by TritonAttention." - f"Head sizes need to be larger or equal 32 for this backend. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." 
- ) - @staticmethod def get_name() -> str: return "TRITON_ATTN" @@ -200,6 +185,41 @@ def use_cascade_attention(*args, **kwargs) -> bool: def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]: return TritonAttentionMetadataBuilder + @classmethod + def validate_head_size(cls, head_size: int) -> None: + # Triton Attention supports any head size above 32 + if head_size < 32: + raise ValueError( + f"Head size {head_size} is not supported by TritonAttention." + f"Head sizes need to be larger or equal 32 for this backend. " + "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " + "FlexAttention backend which supports all head sizes." + ) + + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16, torch.float32] + + @classmethod + def get_supported_kv_cache_dtypes(cls) -> list[Optional[str]]: + return ["auto", "fp16", "bf16"] + + @classmethod + def get_supported_block_sizes(cls) -> list[int]: + return [] + + @classmethod + def is_mla(cls) -> bool: + return False + + @classmethod + def get_min_compute_capability(cls) -> Optional[int]: + return None + + @classmethod + def get_max_compute_capability(cls) -> Optional[int]: + return None + class TritonAttentionImpl(AttentionImpl): def fused_output_quant_supported(self, quant_key: QuantKey): @@ -239,8 +259,6 @@ def __init__( self.num_queries_per_kv = self.num_heads // self.num_kv_heads - TritonAttentionBackend.validate_head_size(head_size) - if attn_type != AttentionType.DECODER: raise NotImplementedError( "Encoder self-attention and " diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index b21562fac741..eb7c98310c71 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -80,18 +80,6 @@ def get_supported_head_sizes(cls) -> list[int]: 256, ] - @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." 
- ) - @staticmethod def get_name() -> str: return "XFORMERS" @@ -305,8 +293,6 @@ def __init__( logits_soft_cap = 0 self.logits_soft_cap = logits_soft_cap - XFormersAttentionBackend.validate_head_size(head_size) - if attn_type != AttentionType.DECODER: raise NotImplementedError( "Encoder self-attention and " From ff18a9a7c286f8dd85940aae7fd087a628f8398f Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 8 Oct 2025 16:02:09 -0400 Subject: [PATCH 08/84] update cuda logic Signed-off-by: Matthew Bonanni --- vllm/platforms/cuda.py | 424 ++++++++---------- .../attention/backends/mla/flashmla_sparse.py | 2 +- 2 files changed, 183 insertions(+), 243 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 8a4565b4d1a0..7d5446d81138 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -18,7 +18,11 @@ import vllm._C # noqa import vllm.envs as envs from vllm.logger import init_logger -from vllm.utils import cuda_device_count_stateless, import_pynvml +from vllm.utils import ( + cuda_device_count_stateless, + import_pynvml, + resolve_obj_by_qualname, +) from .interface import DeviceCapability, Platform, PlatformEnum @@ -40,6 +44,26 @@ torch.backends.cuda.enable_cudnn_sdp(False) +@cache +def _get_backend_priorities(): + """Get backend priorities with lazy import to avoid circular dependency.""" + from vllm.attention.backends.registry import _Backend + + return { + # non-MLA backends + _Backend.FLASHINFER: 0, + _Backend.FLASH_ATTN: 1, + _Backend.TRITON_ATTN: 2, + _Backend.FLEX_ATTENTION: 3, + # MLA backends + _Backend.CUTLASS_MLA: 0, + _Backend.FLASHINFER_MLA: 1, + _Backend.FLASHMLA: 2, + _Backend.FLASH_ATTN_MLA: 3, + _Backend.TRITON_MLA: 4, + } + + def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: @wraps(fn) def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: @@ -116,64 +140,38 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: if cache_config and cache_config.block_size is None: cache_config.block_size = 16 - # TODO(lucas): handle this more gracefully # Note: model_config may be None during testing if model_config is not None and model_config.use_mla: - use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk") # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, # then we default to FlashMLA backend for non-blackwell GPUs, # else we default to CutlassMLA. For each case, we force the # required block_size. - use_flashmla = False - use_cutlass_mla = False - use_flashinfer_mla = False if envs.VLLM_ATTENTION_BACKEND is None: # Default case if cls.is_device_capability(100): # Blackwell => Force CutlassMLA. - use_cutlass_mla = True # TODO: This does not work, because the # global_force_attn_backend_context_manager is not set. # See vllm/attention/selector.py:_cached_get_attn_backend envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA" else: - # Not Blackwell - use_flashmla = True - else: - # Forced case - use_flashmla = envs.VLLM_ATTENTION_BACKEND == "FLASHMLA" - use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA" - use_flashinfer_mla = envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA" - - from vllm.attention.ops.flashmla import is_flashmla_dense_supported - - if ( - use_flashmla - and is_flashmla_dense_supported()[0] - and cache_config.block_size != 64 - ): - cache_config.block_size = 64 - logger.info("Forcing kv cache block size to 64 for FlashMLA backend.") + # Not Blackwell => Force FlashMLA. 
+ envs.VLLM_ATTENTION_BACKEND = "FLASHMLA" - if use_cutlass_mla and cache_config.block_size != 128: - cache_config.block_size = 128 - logger.info( - "Forcing kv cache block size to 128 for CUTLASS_MLA backend." - ) + # Adjust block sizes for MLA backends based on their requirements + from vllm.attention.backends.registry import _Backend, backend_to_class - if use_flashinfer_mla and cache_config.block_size not in [32, 64]: - cache_config.block_size = 64 + backend_enum = _Backend[envs.VLLM_ATTENTION_BACKEND] + backend_class = backend_to_class(backend_enum) + if not backend_class.supports_block_size(cache_config.block_size): + cache_config.block_size = backend_class.get_supported_block_sizes()[0] logger.info( - "Forcing kv cache block size to 64 for FlashInferMLA backend." + "Forcing kv cache block size to %s for %s backend.", + cache_config.block_size, + envs.VLLM_ATTENTION_BACKEND, ) - # TODO(Chen): remove this hacky code - if use_sparse and cache_config.block_size != 64: - cache_config.block_size = 64 - logger.info( - "Forcing kv cache block size to 64 for FlashMLASparse backend." - ) # lazy import to avoid circular import from vllm.config import CUDAGraphMode @@ -206,7 +204,7 @@ def get_current_memory_usage( @classmethod def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import _Backend, backend_to_class # For Blackwell GPUs, force TORCH_SDPA for now. # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501 @@ -217,204 +215,177 @@ def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": return _Backend.XFORMERS if cls.has_device_capability(80): - FLASH_ATTN_V1 = ( - "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 - ) - from vllm.attention.selector import is_attn_backend_supported - - is_default_fa_supported = is_attn_backend_supported( - FLASH_ATTN_V1, head_size, dtype, allow_import_error=False - ) - if is_default_fa_supported: + backend_class = backend_to_class(_Backend.FLASH_ATTN) + if backend_class.supports_head_size( + head_size + ) and backend_class.supports_dtype(dtype): return _Backend.FLASH_ATTN else: - # Fallback to XFORMERS return _Backend.XFORMERS else: # Fallback for Volta/Turing GPUs or FA not supported return _Backend.XFORMERS @classmethod - def get_attn_backend_cls( + def get_valid_backends( cls, - selected_backend, head_size, dtype, kv_cache_dtype, block_size, - use_v1, use_mla, has_sink, use_sparse, - ) -> str: - from vllm.attention.backends.registry import _Backend - - if use_mla: - if not use_v1: - raise RuntimeError( - "MLA attention backends require the V1 engine. " - "Set VLLM_USE_V1=1 to enable them." 
+ device_capability_int, + ) -> tuple[list[tuple["_Backend", int]], dict["_Backend", list[str]]]: + valid_backends_priorities = [] + invalid_reasons = {} + from vllm.attention.backends.registry import _Backend, backend_to_class + + backend_priorities = _get_backend_priorities() + for backend in _Backend: + if backend not in backend_priorities: + continue + priority = backend_priorities[backend] + try: + backend_class = backend_to_class(backend) + invalid_reasons_i = backend_class.validate_configuration( + head_size, + dtype, + kv_cache_dtype, + block_size, + use_mla, + has_sink, + use_sparse, + device_capability_int, ) + except ImportError: + invalid_reasons_i = ["ImportError"] + if invalid_reasons_i: + invalid_reasons[backend] = invalid_reasons_i + else: + valid_backends_priorities.append((backend, priority)) - from vllm.attention.ops.flashmla import is_flashmla_dense_supported - from vllm.attention.utils.fa_utils import flash_attn_supports_mla - - if use_sparse: - logger.info_once("Using Sparse MLA backend on V1 engine.") - return ( - "vllm.v1.attention.backends.mla.flashmla_sparse." - "FlashMLASparseBackend" - ) + return valid_backends_priorities, invalid_reasons - use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( - selected_backend is None - and cls.is_device_capability(100) - and block_size == 128 - ) - use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or ( - selected_backend is None - and cls.is_device_capability(100) - and block_size in [32, 64] - ) - use_flashmla = selected_backend == _Backend.FLASHMLA or ( - selected_backend is None and is_flashmla_dense_supported()[0] - ) - use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or ( - selected_backend is None and flash_attn_supports_mla() - ) - use_triton = selected_backend == _Backend.TRITON_MLA or ( - selected_backend is None + @classmethod + def get_attn_backend_cls( + cls, + selected_backend: "_Backend", + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + use_v1: bool, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + ) -> str: + if not use_v1: + raise RuntimeError( + "V0 attention backends have been removed. Set VLLM_USE_V1=1 " + "to select a supported backend." ) - if use_cutlassmla: - logger.info_once("Using Cutlass MLA backend on V1 engine.") - return "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend" - if use_flashinfermla: - from vllm.v1.attention.backends.utils import set_kv_cache_layout + from vllm.attention.backends.registry import _Backend, backend_to_class_str + + device_capability = cls.get_device_capability() + device_capability_int = ( + device_capability.to_int() if device_capability is not None else None + ) - set_kv_cache_layout("HND") - logger.info_once("Using FlashInfer MLA backend on V1 engine.") - return ( - "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend" + # First try checking just the selected backend, if there is one. 
+ if selected_backend is not None: + backend_class_str = backend_to_class_str(selected_backend) + try: + backend_class = resolve_obj_by_qualname(backend_class_str) + invalid_reasons = backend_class.validate_configuration( + head_size, + dtype, + kv_cache_dtype, + block_size, + use_v1, + use_mla, + has_sink, + use_sparse, + device_capability_int, ) - if use_flashmla: - if block_size != 64: - logger.warning( - "FlashMLA backend is not supported for block size %d" - " (currently only supports block size 64).", - block_size, - ) - else: - logger.info_once("Using FlashMLA backend on V1 engine.") - return "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend" - if use_flashattn: - logger.info_once("Using FlashAttention MLA backend on V1 engine.") - return ( - "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend" + except ImportError: + invalid_reasons = ["ImportError"] + if invalid_reasons: + logger.warning( + "Selected backend %s is not valid for this configuration. " + "Reason: %s", + selected_backend, + invalid_reasons, ) - if use_triton: - logger.info_once("Using Triton MLA backend on V1 engine.") - return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend" - if use_v1: - FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501 - FLEX_ATTENTION_V1 = ( - "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 - ) - TRITON_ATTN = ( - "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 + else: + engine_version = "V1" if use_v1 else "V0" + logger.info( + "Using %s backend on %s engine.", selected_backend, engine_version + ) + return backend_class_str + + # No selected backend or the selected backend is invalid, + # so we try finding a valid backend. + valid_backends_priorities, invalid_reasons = cls.get_valid_backends( + head_size, + dtype, + kv_cache_dtype, + block_size, + use_mla, + has_sink, + use_sparse, + device_capability_int, + ) + + if len(valid_backends_priorities) == 0: + reasons_str = ( + "{" + + ", ".join( + f"{backend.name}: [{', '.join(reasons)}]" + for backend, reasons in invalid_reasons.items() + ) + + "}" ) - FLASH_ATTN_V1 = ( - "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 + raise ValueError( + f"No valid attention backend from priority list for " + f"{cls.device_name} with head_size: {head_size}, " + f"dtype: {dtype}, kv_cache_dtype: {kv_cache_dtype}, " + f"use_mla: {use_mla}, has_sink: {has_sink}, " + f"use_sparse: {use_sparse}. Reasons: {reasons_str}" ) - TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501 - XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501 - use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith( - "fp8" - ) + # We have found some valid backends. Select the one with the + # highest priority. 
+ logger.info( + "Valid backends: %s", [b[0].name for b in valid_backends_priorities] + ) - if selected_backend == _Backend.FLASHINFER: - logger.info_once("Using FlashInfer backend on V1 engine.") - if cls.has_device_capability(100): - from vllm.v1.attention.backends.utils import set_kv_cache_layout - - set_kv_cache_layout("HND") - return FLASHINFER_V1 - elif selected_backend == _Backend.FLEX_ATTENTION: - logger.info_once("Using FlexAttention backend on V1 engine.") - return FLEX_ATTENTION_V1 - elif selected_backend == _Backend.TRITON_ATTN: - logger.info_once("Using Triton backend on V1 engine.") - return TRITON_ATTN - elif selected_backend == _Backend.FLASH_ATTN: - logger.info_once("Using Flash Attention backend on V1 engine.") - return FLASH_ATTN_V1 - elif selected_backend == _Backend.TREE_ATTN: - logger.info_once("Using Tree Attention backend on V1 engine.") - return TREE_ATTN_V1 - elif selected_backend == _Backend.XFORMERS: - logger.info_once("Using XFormers backend on V1 engine.") - return XFORMERS_V1 - - from vllm.attention.selector import is_attn_backend_supported - - # Default backends for V1 engine - # Prefer FlashInfer for Blackwell GPUs if installed - if cls.is_device_capability(100): - if is_default_backend_supported := is_attn_backend_supported( - FLASHINFER_V1, head_size, dtype - ): - from vllm.v1.attention.backends.utils import set_kv_cache_layout - - logger.info_once( - "Using FlashInfer backend with HND KV cache layout on " - "V1 engine by default for Blackwell (SM 10.0) GPUs." - ) - set_kv_cache_layout("HND") - - return FLASHINFER_V1 - - if not is_default_backend_supported.can_import: - logger.warning_once( - "FlashInfer failed to import for V1 engine on " - "Blackwell (SM 10.0) GPUs; it is recommended to " - "install FlashInfer for better performance." 
- ) - - # FlashAttention is the default for SM 8.0+ GPUs - if cls.has_device_capability(80): - if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90): - logger.info_once("Using Triton backend on V1 engine.") - return TRITON_ATTN - elif is_default_backend_supported := is_attn_backend_supported( - FLASH_ATTN_V1, head_size, dtype, allow_import_error=False - ): - logger.info_once("Using Flash Attention backend on V1 engine.") - return FLASH_ATTN_V1 - - # FlexAttention is the default for older GPUs - else: - logger.info_once("Using FlexAttention backend on V1 engine.") - return FLEX_ATTENTION_V1 + valid_backends_classes_str = [ + backend_to_class_str(b[0]) for b in valid_backends_priorities + ] + sorted_indices = sorted( + range(len(valid_backends_priorities)), + key=lambda i: valid_backends_priorities[i][1], + ) + selected_index = sorted_indices[0] - assert not is_default_backend_supported + engine_version = "V1" if use_v1 else "V0" + logger.info( + "Using %s backend on %s engine.", + valid_backends_priorities[selected_index][0].name, + engine_version, + ) - use_flex_attention_reason = {} - if not is_default_backend_supported.head_size: - use_flex_attention_reason["head_size"] = head_size - if not is_default_backend_supported.dtype: - use_flex_attention_reason["dtype"] = dtype + # Post-selection modifications + if valid_backends_priorities[selected_index][0] == _Backend.FLASHINFER_MLA: + from vllm.v1.attention.backends.utils import set_kv_cache_layout - logger.info_once( - "Using FlexAttention backend for %s on V1 engine.", - ", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()), - ) - return FLEX_ATTENTION_V1 + set_kv_cache_layout("HND") + logger.info("Using HND KV cache layout for FlashInferMLA.") - raise RuntimeError( - "V0 attention backends have been removed. Set VLLM_USE_V1=1 " - "to select a supported backend." 
- ) + return valid_backends_classes_str[selected_index] @classmethod def get_punica_wrapper(cls) -> str: @@ -481,44 +452,13 @@ def device_count(cls) -> int: def is_kv_cache_dtype_supported( cls, kv_cache_dtype: str, model_config: "ModelConfig" ) -> bool: - fp8_attention = kv_cache_dtype.startswith("fp8") - attention_backend = envs.VLLM_ATTENTION_BACKEND + if not envs.VLLM_ATTENTION_BACKEND: + return True + from vllm.attention.backends.registry import _Backend, backend_to_class - supported = False - if model_config is not None and model_config.use_mla: - # Default to CutlassMLA for blackwell, - # FlashMLA otherwise - if attention_backend is None: - if cls.is_device_capability(100): - attention_backend = "CUTLASS_MLA" - else: - attention_backend = "FLASHMLA" - - # Only FlashMLA and CUTLASS_MLA support fp8 - if attention_backend in ["FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"]: - supported = True - else: - supported = not fp8_attention - else: - # Default to FlashAttention - if attention_backend is None: - attention_backend = "FLASH_ATTN" - - # All Blackwell backends support fp8 - if cls.is_device_capability(100): - supported = True - elif attention_backend == "FLASH_ATTN": - if fp8_attention: - from vllm.attention.utils.fa_utils import flash_attn_supports_fp8 - - supported = flash_attn_supports_fp8() - else: - supported = True - elif attention_backend == "FLASHINFER": - supported = True - elif attention_backend == "TRITON_ATTN": - supported = cls.supports_fp8() - return supported + attention_backend = _Backend[envs.VLLM_ATTENTION_BACKEND] + backend_class = backend_to_class(attention_backend) + return backend_class.supports_kv_cache_dtype(kv_cache_dtype) @classmethod def check_if_supports_dtype(cls, torch_dtype: torch.dtype): diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 103054efaf05..b3785f950e50 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -79,7 +79,7 @@ def get_supported_kv_cache_dtypes(cls) -> list[Optional[str]]: @classmethod def get_supported_block_sizes(cls) -> list[int]: - return [32, 64] + return [64] @classmethod def is_sparse(cls) -> bool: From df494842d4d74c0db195a853e29ec27af231eb75 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 8 Oct 2025 16:15:20 -0400 Subject: [PATCH 09/84] fix pre-commit Signed-off-by: Matthew Bonanni --- vllm/attention/backends/registry.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index b74ae09e6112..97f107357ae3 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -3,10 +3,13 @@ """Attention backend registry""" import enum -from typing import Optional +from typing import TYPE_CHECKING, Optional from vllm.utils import resolve_obj_by_qualname +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + class _Backend(enum.Enum): FLASH_ATTN = enum.auto() @@ -83,7 +86,7 @@ def backend_to_class_str(backend: _Backend) -> str: return BACKEND_MAP[backend] -def backend_to_class(backend: _Backend) -> type: +def backend_to_class(backend: _Backend) -> "type[AttentionBackend]": """Get the backend class. 
Args: @@ -93,7 +96,7 @@ def backend_to_class(backend: _Backend) -> type: The backend class """ backend_class_name = backend_to_class_str(backend) - return resolve_obj_by_qualname(backend_class_name) + return resolve_obj_by_qualname(backend_class_name) # type: ignore[return-value] def backend_name_to_enum(backend_name: str) -> Optional[_Backend]: From ff5ad7c7e7511650e7070fc604d0918d62fa4fc0 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 8 Oct 2025 16:17:38 -0400 Subject: [PATCH 10/84] fix argument mismatch Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 2 +- vllm/platforms/cuda.py | 1 - vllm/v1/attention/backends/flash_attn.py | 2 +- vllm/v1/attention/backends/mla/flashattn_mla.py | 2 +- 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index dce0360f2b76..958aa992a4f6 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -157,9 +157,9 @@ def supports_combination( dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, - use_v1: bool, use_mla: bool, has_sink: bool, + use_sparse: bool, device_capability: int, ) -> Optional[str]: return None diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 7d5446d81138..76f1799436eb 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -304,7 +304,6 @@ def get_attn_backend_cls( dtype, kv_cache_dtype, block_size, - use_v1, use_mla, has_sink, use_sparse, diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index d4bca0b07d07..85f535c21c06 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -134,9 +134,9 @@ def supports_combination( dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, - use_v1: bool, use_mla: bool, has_sink: bool, + use_sparse: bool, device_capability: int, ) -> Optional[str]: if has_sink and device_capability < 90: diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 9086f0243df2..0835b74375e4 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -77,9 +77,9 @@ def supports_combination( dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, - use_v1: bool, use_mla: bool, has_sink: bool, + use_sparse: bool, device_capability: int, ) -> Optional[str]: if not flash_attn_supports_mla(): From 712ae590ad4402cdf98401f5dd8d3c88c92f6a83 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 8 Oct 2025 16:25:17 -0400 Subject: [PATCH 11/84] fix pre-commit Signed-off-by: Matthew Bonanni --- vllm/platforms/cuda.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 76f1799436eb..9fc5c5bc6fe2 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -165,7 +165,13 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: backend_enum = _Backend[envs.VLLM_ATTENTION_BACKEND] backend_class = backend_to_class(backend_enum) if not backend_class.supports_block_size(cache_config.block_size): - cache_config.block_size = backend_class.get_supported_block_sizes()[0] + from typing import cast + + from vllm.config.cache import BlockSize + + cache_config.block_size = cast( + BlockSize, backend_class.get_supported_block_sizes()[0] + ) logger.info( "Forcing kv cache block size to %s for %s backend.", 
cache_config.block_size, From 97e1a2c422120da6deb3d413d9452e5d6f36ec48 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 8 Oct 2025 16:35:21 -0400 Subject: [PATCH 12/84] use block size literals Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 13 ++++++++++--- vllm/platforms/cuda.py | 8 +------- vllm/v1/attention/backends/flashinfer.py | 3 ++- vllm/v1/attention/backends/flex_attention.py | 3 ++- vllm/v1/attention/backends/mla/cutlass_mla.py | 3 ++- vllm/v1/attention/backends/mla/flashattn_mla.py | 3 ++- vllm/v1/attention/backends/mla/flashinfer_mla.py | 3 ++- vllm/v1/attention/backends/mla/flashmla.py | 3 ++- vllm/v1/attention/backends/mla/flashmla_sparse.py | 3 ++- vllm/v1/attention/backends/mla/triton_mla.py | 3 ++- vllm/v1/attention/backends/triton_attn.py | 3 ++- 11 files changed, 29 insertions(+), 19 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 958aa992a4f6..a86b0138ba49 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -2,10 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import Generic, Optional, Protocol, TypeVar +from typing import Generic, Optional, Protocol, TypeVar, cast import torch +from vllm.config.cache import BlockSize from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey @@ -114,13 +115,19 @@ def supports_kv_cache_dtype(cls, kv_cache_dtype: Optional[str]) -> bool: ) @classmethod - def get_supported_block_sizes(cls) -> list[int]: + def get_supported_block_sizes(cls) -> list[BlockSize]: return [] @classmethod def supports_block_size(cls, block_size: int) -> bool: + try: + block_size_literal = cast(BlockSize, block_size) + except ValueError: + return False supported_block_sizes = cls.get_supported_block_sizes() - return (not supported_block_sizes) or block_size in supported_block_sizes + return ( + not supported_block_sizes + ) or block_size_literal in supported_block_sizes @classmethod def is_mla(cls) -> bool: diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 9fc5c5bc6fe2..76f1799436eb 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -165,13 +165,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: backend_enum = _Backend[envs.VLLM_ATTENTION_BACKEND] backend_class = backend_to_class(backend_enum) if not backend_class.supports_block_size(cache_config.block_size): - from typing import cast - - from vllm.config.cache import BlockSize - - cache_config.block_size = cast( - BlockSize, backend_class.get_supported_block_sizes()[0] - ) + cache_config.block_size = backend_class.get_supported_block_sizes()[0] logger.info( "Forcing kv cache block size to %s for %s backend.", cache_config.block_size, diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 3b875ae31e20..34a4bda693cb 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -25,6 +25,7 @@ AttentionType, ) from vllm.config import CUDAGraphMode, VllmConfig +from vllm.config.cache import BlockSize from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, @@ -218,7 +219,7 @@ def get_supported_kv_cache_dtypes(cls) -> list[str | None]: return ["auto", "fp16", "bf16", "fp8", "fp8_e4m3", "fp8_e5m2"] @classmethod - def get_supported_block_sizes(cls) -> list[int]: + def 
get_supported_block_sizes(cls) -> list[BlockSize]: return [] @classmethod diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 3b2207dced18..2471ff175e17 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -25,6 +25,7 @@ is_quantized_kv_cache, ) from vllm.config import VllmConfig +from vllm.config.cache import BlockSize from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_kernel_override_batch_invariant, @@ -115,7 +116,7 @@ def get_supported_kv_cache_dtypes(cls) -> list[Optional[str]]: return ["auto", "fp16", "bf16"] @classmethod - def get_supported_block_sizes(cls) -> list[int]: + def get_supported_block_sizes(cls) -> list[BlockSize]: return [] @classmethod diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index e587020438f0..6e0795ea8a47 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -12,6 +12,7 @@ AttentionType, is_quantized_kv_cache, ) +from vllm.config.cache import BlockSize from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, @@ -53,7 +54,7 @@ def get_supported_kv_cache_dtypes(cls) -> list[Optional[str]]: return ["auto", "fp16", "bf16", "e4m3fn"] @classmethod - def get_supported_block_sizes(cls) -> list[int]: + def get_supported_block_sizes(cls) -> list[BlockSize]: return [128] @classmethod diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 0835b74375e4..9929cb7166cd 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -17,6 +17,7 @@ get_flash_attn_version, ) from vllm.config import VllmConfig +from vllm.config.cache import BlockSize from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import ( @@ -59,7 +60,7 @@ def get_supported_kv_cache_dtypes(cls) -> list[Optional[str]]: return ["auto", "fp16", "bf16"] @classmethod - def get_supported_block_sizes(cls) -> list[int]: + def get_supported_block_sizes(cls) -> list[BlockSize]: return [] @classmethod diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index e848b430797a..80ddb0366226 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -7,6 +7,7 @@ from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla from vllm.attention.backends.abstract import AttentionLayer, AttentionType +from vllm.config.cache import BlockSize from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, @@ -51,7 +52,7 @@ def get_supported_kv_cache_dtypes(cls) -> list[Optional[str]]: return ["auto", "fp16", "bf16", "fp8", "fp8_e4m3"] @classmethod - def get_supported_block_sizes(cls) -> list[int]: + def get_supported_block_sizes(cls) -> list[BlockSize]: return [32, 64] @classmethod diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 15514106110f..911047a46e2a 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -13,6 +13,7 @@ is_flashmla_dense_supported, ) from vllm.config import VllmConfig +from vllm.config.cache import BlockSize 
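As a rough illustration of the support-list convention these backends share (an empty list from get_supported_block_sizes means "no restriction"), a minimal sketch with a hypothetical supported list:

from typing import Optional

SUPPORTED_BLOCK_SIZES = [64, 128]  # hypothetical; real values come from get_supported_block_sizes()

def supports_block_size(block_size: Optional[int]) -> bool:
    # None means the caller has no preference, which every backend accepts.
    if block_size is None:
        return True
    # An empty support list is treated as "any block size works".
    return not SUPPORTED_BLOCK_SIZES or block_size in SUPPORTED_BLOCK_SIZES

assert supports_block_size(None)
assert supports_block_size(64)
assert not supports_block_size(32)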
from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, @@ -53,7 +54,7 @@ def get_supported_kv_cache_dtypes(cls) -> list[Optional[str]]: return ["auto", "fp16", "bf16", "fp8", "fp8_e4m3"] @classmethod - def get_supported_block_sizes(cls) -> list[int]: + def get_supported_block_sizes(cls) -> list[BlockSize]: return [64] @classmethod diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index b3785f950e50..6083b2dbb7bf 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -19,6 +19,7 @@ get_mla_metadata, ) from vllm.config import VllmConfig +from vllm.config.cache import BlockSize from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.triton_utils import tl, triton @@ -78,7 +79,7 @@ def get_supported_kv_cache_dtypes(cls) -> list[Optional[str]]: return ["auto", "bf16", "fp8_ds_mla"] @classmethod - def get_supported_block_sizes(cls) -> list[int]: + def get_supported_block_sizes(cls) -> list[BlockSize]: return [64] @classmethod diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index d3f9b2bd4f16..2fe2de9788c0 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -13,6 +13,7 @@ ) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_flash_attention import triton_attention +from vllm.config.cache import BlockSize from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.triton_utils import HAS_TRITON @@ -43,7 +44,7 @@ def get_supported_kv_cache_dtypes(cls) -> list[Optional[str]]: return ["auto", "fp16", "bf16"] @classmethod - def get_supported_block_sizes(cls) -> list[int]: + def get_supported_block_sizes(cls) -> list[BlockSize]: return [] @classmethod diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index b69b142da53f..4bd40da229cc 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -18,6 +18,7 @@ ) from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig +from vllm.config.cache import BlockSize from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, @@ -205,7 +206,7 @@ def get_supported_kv_cache_dtypes(cls) -> list[Optional[str]]: return ["auto", "fp16", "bf16"] @classmethod - def get_supported_block_sizes(cls) -> list[int]: + def get_supported_block_sizes(cls) -> list[BlockSize]: return [] @classmethod From 8f867147148082cada75b81d13007d4ab286da9d Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 9 Oct 2025 14:56:11 -0400 Subject: [PATCH 13/84] replace backend_name_to_enum with direct calls Signed-off-by: Matthew Bonanni --- vllm/attention/backends/registry.py | 15 +-------------- vllm/attention/layer.py | 4 ++-- vllm/attention/selector.py | 11 ++++++----- .../kv_transfer/kv_connector/v1/nixl_connector.py | 4 ++-- 4 files changed, 11 insertions(+), 23 deletions(-) diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index 97f107357ae3..c88514c79d4c 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -96,17 +96,4 @@ def backend_to_class(backend: _Backend) -> "type[AttentionBackend]": The 
backend class """ backend_class_name = backend_to_class_str(backend) - return resolve_obj_by_qualname(backend_class_name) # type: ignore[return-value] - - -def backend_name_to_enum(backend_name: str) -> Optional[_Backend]: - """ - Convert a string backend name to a _Backend enum value. - - Returns: - _Backend: enum value if backend_name is a valid in-tree type - None: otherwise it's an invalid in-tree type or an out-of-tree platform - is loaded. - """ - assert backend_name is not None - return _Backend[backend_name] if backend_name in _Backend.__members__ else None + return resolve_obj_by_qualname(backend_class_name) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index b429c74aa559..2950384bfaaa 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -11,7 +11,7 @@ import vllm.envs as envs from vllm.attention import AttentionType from vllm.attention.backends.abstract import AttentionBackend -from vllm.attention.backends.registry import _Backend, backend_name_to_enum +from vllm.attention.backends.registry import _Backend from vllm.attention.selector import get_attn_backend from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.config import CacheConfig, get_current_vllm_config @@ -250,7 +250,7 @@ def __init__( kv_sharing_target_layer_name, **extra_impl_args, ) - self.backend = backend_name_to_enum(self.attn_backend.get_name()) + self.backend = _Backend[self.attn_backend.get_name()] self.dtype = dtype # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index ebe75fb976d7..30da874bbc5f 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -11,7 +11,7 @@ import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend -from vllm.attention.backends.registry import _Backend, backend_name_to_enum +from vllm.attention.backends.registry import _Backend from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname @@ -30,7 +30,7 @@ def get_env_variable_attn_backend() -> Optional[_Backend]: * None otherwise """ backend_name = os.environ.get(STR_BACKEND_ENV_VAR) - return None if backend_name is None else backend_name_to_enum(backend_name) + return None if backend_name is None else _Backend[backend_name] # Global state allows a particular choice of backend @@ -125,12 +125,13 @@ def _cached_get_attn_backend( STR_BACKEND_ENV_VAR, ) backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1") - selected_backend = backend_name_to_enum(backend_by_env_var) - if selected_backend is None: + try: + selected_backend = _Backend[backend_by_env_var] + except KeyError as e: raise ValueError( f"Invalid attention backend: '{backend_by_env_var}'. 
" f"Valid backends are: {list(_Backend.__members__.keys())}" - ) + ) from e # get device-specific attn_backend attention_cls = current_platform.get_attn_backend_cls( diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 0d4744b9f4ab..66fc8ea80ce5 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -21,7 +21,7 @@ import zmq from vllm import envs -from vllm.attention.backends.registry import _Backend, backend_name_to_enum +from vllm.attention.backends.registry import _Backend from vllm.attention.selector import get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( @@ -646,7 +646,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): use_mla=self.use_mla, ) self.backend_name = backend.get_name() - attn_backend = backend_name_to_enum(self.backend_name) + attn_backend = _Backend[self.backend_name] self._use_flashinfer = attn_backend == _Backend.FLASHINFER self._use_pallas = attn_backend == _Backend.PALLAS self.kv_cache_layout = get_kv_cache_layout() From 50596d82a88d8a2e7b2355733e4740d444630172 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 9 Oct 2025 15:10:40 -0400 Subject: [PATCH 14/84] use DeviceCapability objects Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 11 ++++---- vllm/platforms/cuda.py | 11 +++----- vllm/platforms/interface.py | 25 +++++++++++++++++++ vllm/v1/attention/backends/flash_attn.py | 7 +++--- vllm/v1/attention/backends/flashinfer.py | 9 ++++--- vllm/v1/attention/backends/flex_attention.py | 8 ------ vllm/v1/attention/backends/mla/cutlass_mla.py | 9 ++++--- .../attention/backends/mla/flashattn_mla.py | 9 ++++--- .../attention/backends/mla/flashinfer_mla.py | 9 ++++--- vllm/v1/attention/backends/mla/flashmla.py | 9 ++++--- .../attention/backends/mla/flashmla_sparse.py | 9 ++++--- vllm/v1/attention/backends/mla/triton_mla.py | 5 ++-- vllm/v1/attention/backends/triton_attn.py | 5 ++-- 13 files changed, 75 insertions(+), 51 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index a86b0138ba49..04d4bcd8aebe 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -8,6 +8,7 @@ from vllm.config.cache import BlockSize from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey +from vllm.platforms.interface import DeviceCapability class AttentionType: @@ -142,15 +143,15 @@ def is_sparse(cls) -> bool: return False @classmethod - def get_min_compute_capability(cls) -> Optional[int]: + def get_min_compute_capability(cls) -> Optional[DeviceCapability]: return None @classmethod - def get_max_compute_capability(cls) -> Optional[int]: + def get_max_compute_capability(cls) -> Optional[DeviceCapability]: return None @classmethod - def supports_compute_capability(cls, capability: int) -> bool: + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: min_capability = cls.get_min_compute_capability() max_capability = cls.get_max_compute_capability() return ((min_capability is None) or (capability >= min_capability)) and ( @@ -167,7 +168,7 @@ def supports_combination( use_mla: bool, has_sink: bool, use_sparse: bool, - device_capability: int, + device_capability: DeviceCapability, ) -> Optional[str]: return None @@ -181,7 +182,7 @@ def validate_configuration( use_mla: bool, 
has_sink: bool, use_sparse: bool, - device_capability: int, + device_capability: DeviceCapability, ) -> list[str]: invalid_reasons = [] if not cls.supports_head_size(head_size): diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 76f1799436eb..cb0ded968897 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -236,7 +236,7 @@ def get_valid_backends( use_mla, has_sink, use_sparse, - device_capability_int, + device_capability, ) -> tuple[list[tuple["_Backend", int]], dict["_Backend", list[str]]]: valid_backends_priorities = [] invalid_reasons = {} @@ -257,7 +257,7 @@ def get_valid_backends( use_mla, has_sink, use_sparse, - device_capability_int, + device_capability, ) except ImportError: invalid_reasons_i = ["ImportError"] @@ -290,9 +290,6 @@ def get_attn_backend_cls( from vllm.attention.backends.registry import _Backend, backend_to_class_str device_capability = cls.get_device_capability() - device_capability_int = ( - device_capability.to_int() if device_capability is not None else None - ) # First try checking just the selected backend, if there is one. if selected_backend is not None: @@ -307,7 +304,7 @@ def get_attn_backend_cls( use_mla, has_sink, use_sparse, - device_capability_int, + device_capability, ) except ImportError: invalid_reasons = ["ImportError"] @@ -335,7 +332,7 @@ def get_attn_backend_cls( use_mla, has_sink, use_sparse, - device_capability_int, + device_capability, ) if len(valid_backends_priorities) == 0: diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 6dc49f99ac2a..a9e82df40270 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -65,6 +65,31 @@ class DeviceCapability(NamedTuple): major: int minor: int + def __lt__(self, other: Any) -> bool: + if not isinstance(other, DeviceCapability): + return NotImplemented + return (self.major, self.minor) < (other.major, other.minor) + + def __le__(self, other: Any) -> bool: + if not isinstance(other, DeviceCapability): + return NotImplemented + return (self.major, self.minor) <= (other.major, other.minor) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, DeviceCapability): + return NotImplemented + return (self.major, self.minor) == (other.major, other.minor) + + def __ge__(self, other: Any) -> bool: + if not isinstance(other, DeviceCapability): + return NotImplemented + return (self.major, self.minor) >= (other.major, other.minor) + + def __gt__(self, other: Any) -> bool: + if not isinstance(other, DeviceCapability): + return NotImplemented + return (self.major, self.minor) > (other.major, other.minor) + def as_version_str(self) -> str: return f"{self.major}.{self.minor}" diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 85f535c21c06..e5de1e4d81b1 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -33,6 +33,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger +from vllm.platforms.interface import DeviceCapability from vllm.utils import cdiv from vllm.v1.attention.backends.utils import ( AttentionCGSupport, @@ -120,11 +121,11 @@ def is_mla(cls) -> bool: return False @classmethod - def get_min_compute_capability(cls) -> Optional[int]: - return 80 + def get_min_compute_capability(cls) -> Optional[DeviceCapability]: + return DeviceCapability(8, 0) @classmethod - def get_max_compute_capability(cls) -> Optional[int]: + def get_max_compute_capability(cls) -> 
Optional[DeviceCapability]: return None @classmethod diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 34a4bda693cb..55d25908514f 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -33,6 +33,7 @@ kNvfp4Quant, ) from vllm.platforms import current_platform +from vllm.platforms.interface import DeviceCapability from vllm.triton_utils import tl, triton from vllm.utils import cdiv, is_pin_memory_available from vllm.utils.flashinfer import ( @@ -227,12 +228,12 @@ def is_mla(cls) -> bool: return False @classmethod - def get_min_compute_capability(cls) -> int | None: - return 100 + def get_min_compute_capability(cls) -> DeviceCapability | None: + return DeviceCapability(10, 0) @classmethod - def get_max_compute_capability(cls) -> int | None: - return 109 + def get_max_compute_capability(cls) -> DeviceCapability | None: + return DeviceCapability(10, 9) @dataclass diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 2471ff175e17..7622b35b26a8 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -123,14 +123,6 @@ def get_supported_block_sizes(cls) -> list[BlockSize]: def is_mla(cls) -> bool: return False - @classmethod - def get_min_compute_capability(cls) -> Optional[int]: - return None - - @classmethod - def get_max_compute_capability(cls) -> Optional[int]: - return None - # @torch.compile(fullgraph=True, mode="reduce-overhead") def physical_to_logical_mapping( diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index 6e0795ea8a47..a7595b524f41 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -14,6 +14,7 @@ ) from vllm.config.cache import BlockSize from vllm.logger import init_logger +from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, MLACommonImpl, @@ -58,12 +59,12 @@ def get_supported_block_sizes(cls) -> list[BlockSize]: return [128] @classmethod - def get_min_compute_capability(cls) -> Optional[int]: - return 100 + def get_min_compute_capability(cls) -> Optional[DeviceCapability]: + return DeviceCapability(10, 0) @classmethod - def get_max_compute_capability(cls) -> Optional[int]: - return 109 + def get_max_compute_capability(cls) -> Optional[DeviceCapability]: + return DeviceCapability(10, 3) class SM100Workspace: diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 9929cb7166cd..1a0fe7829047 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -20,6 +20,7 @@ from vllm.config.cache import BlockSize from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger +from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, MLACommonDecodeMetadata, @@ -64,12 +65,12 @@ def get_supported_block_sizes(cls) -> list[BlockSize]: return [] @classmethod - def get_min_compute_capability(cls) -> Optional[int]: - return 90 + def get_min_compute_capability(cls) -> Optional[DeviceCapability]: + return DeviceCapability(9, 0) @classmethod - def get_max_compute_capability(cls) -> Optional[int]: - return 90 + def get_max_compute_capability(cls) -> Optional[DeviceCapability]: + return 
DeviceCapability(9, 0) @classmethod def supports_combination( diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index 80ddb0366226..eed4f73849da 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import AttentionLayer, AttentionType from vllm.config.cache import BlockSize from vllm.logger import init_logger +from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, MLACommonImpl, @@ -56,12 +57,12 @@ def get_supported_block_sizes(cls) -> list[BlockSize]: return [32, 64] @classmethod - def get_min_compute_capability(cls) -> Optional[int]: - return 100 + def get_min_compute_capability(cls) -> Optional[DeviceCapability]: + return DeviceCapability(10, 0) @classmethod - def get_max_compute_capability(cls) -> Optional[int]: - return 109 + def get_max_compute_capability(cls) -> Optional[DeviceCapability]: + return DeviceCapability(10, 9) g_fi_workspace = torch.zeros( diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 911047a46e2a..c81e54549415 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -15,6 +15,7 @@ from vllm.config import VllmConfig from vllm.config.cache import BlockSize from vllm.logger import init_logger +from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, MLACommonDecodeMetadata, @@ -58,12 +59,12 @@ def get_supported_block_sizes(cls) -> list[BlockSize]: return [64] @classmethod - def get_min_compute_capability(cls) -> Optional[int]: - return 90 + def get_min_compute_capability(cls) -> Optional[DeviceCapability]: + return DeviceCapability(9, 0) @classmethod - def get_max_compute_capability(cls) -> Optional[int]: - return 109 + def get_max_compute_capability(cls) -> Optional[DeviceCapability]: + return DeviceCapability(10, 9) @dataclass diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 6083b2dbb7bf..f07779df231b 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -22,6 +22,7 @@ from vllm.config.cache import BlockSize from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.platforms.interface import DeviceCapability from vllm.triton_utils import tl, triton from vllm.utils import cdiv from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl @@ -87,12 +88,12 @@ def is_sparse(cls) -> bool: return True @classmethod - def get_min_compute_capability(cls) -> Optional[int]: - return 100 + def get_min_compute_capability(cls) -> Optional[DeviceCapability]: + return DeviceCapability(10, 0) @classmethod - def get_max_compute_capability(cls) -> Optional[int]: - return 109 + def get_max_compute_capability(cls) -> Optional[DeviceCapability]: + return DeviceCapability(10, 9) @staticmethod def get_kv_cache_shape( diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 2fe2de9788c0..3c1bf23baf7a 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -16,6 +16,7 @@ from vllm.config.cache import BlockSize from vllm.logger import init_logger from vllm.platforms import current_platform 
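Because DeviceCapability now orders like a plain (major, minor) tuple, the min/max capability checks reduce to ordinary comparisons; a small sketch using a stand-in NamedTuple rather than vLLM's class:

from typing import NamedTuple

class Cap(NamedTuple):  # stand-in for DeviceCapability
    major: int
    minor: int

lo, hi = Cap(8, 0), Cap(10, 3)   # e.g. a backend's min/max supported capability
cap = Cap(9, 0)                  # e.g. an SM 9.0 device

# NamedTuple already compares lexicographically, matching the operators defined above.
assert lo <= cap <= hi
assert Cap(12, 1) > hi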
+from vllm.platforms.interface import DeviceCapability from vllm.triton_utils import HAS_TRITON from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, @@ -48,11 +49,11 @@ def get_supported_block_sizes(cls) -> list[BlockSize]: return [] @classmethod - def get_min_compute_capability(cls) -> Optional[int]: + def get_min_compute_capability(cls) -> Optional[DeviceCapability]: return None @classmethod - def get_max_compute_capability(cls) -> Optional[int]: + def get_max_compute_capability(cls) -> Optional[DeviceCapability]: return None diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 4bd40da229cc..882e0429a702 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -25,6 +25,7 @@ kFp8StaticTensorSym, ) from vllm.platforms import current_platform +from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, @@ -214,11 +215,11 @@ def is_mla(cls) -> bool: return False @classmethod - def get_min_compute_capability(cls) -> Optional[int]: + def get_min_compute_capability(cls) -> Optional[DeviceCapability]: return None @classmethod - def get_max_compute_capability(cls) -> Optional[int]: + def get_max_compute_capability(cls) -> Optional[DeviceCapability]: return None From 03f69632a8e05618a4a83c441cc56dd6da23317b Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 9 Oct 2025 15:11:32 -0400 Subject: [PATCH 15/84] update max Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/flashinfer.py | 2 +- vllm/v1/attention/backends/mla/flashinfer_mla.py | 2 +- vllm/v1/attention/backends/mla/flashmla.py | 2 +- vllm/v1/attention/backends/mla/flashmla_sparse.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 55d25908514f..3e3682b7a20b 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -233,7 +233,7 @@ def get_min_compute_capability(cls) -> DeviceCapability | None: @classmethod def get_max_compute_capability(cls) -> DeviceCapability | None: - return DeviceCapability(10, 9) + return DeviceCapability(10, 3) @dataclass diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index eed4f73849da..bcbdfb9f7cc0 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -62,7 +62,7 @@ def get_min_compute_capability(cls) -> Optional[DeviceCapability]: @classmethod def get_max_compute_capability(cls) -> Optional[DeviceCapability]: - return DeviceCapability(10, 9) + return DeviceCapability(10, 3) g_fi_workspace = torch.zeros( diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index c81e54549415..544caee0c7d4 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -64,7 +64,7 @@ def get_min_compute_capability(cls) -> Optional[DeviceCapability]: @classmethod def get_max_compute_capability(cls) -> Optional[DeviceCapability]: - return DeviceCapability(10, 9) + return DeviceCapability(10, 3) @dataclass diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index f07779df231b..c585badb6d37 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ 
b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -93,7 +93,7 @@ def get_min_compute_capability(cls) -> Optional[DeviceCapability]: @classmethod def get_max_compute_capability(cls) -> Optional[DeviceCapability]: - return DeviceCapability(10, 9) + return DeviceCapability(10, 3) @staticmethod def get_kv_cache_shape( From 3bee84e162c938c43ec8484a137d7d9587f0b4dc Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 9 Oct 2025 16:03:37 -0400 Subject: [PATCH 16/84] Fix block size adjustment Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 8 +++-- vllm/attention/selector.py | 19 +++++++++++- vllm/platforms/cuda.py | 37 ++---------------------- vllm/v1/attention/backends/flash_attn.py | 4 ++- 4 files changed, 28 insertions(+), 40 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 04d4bcd8aebe..f5b9d19e0999 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -120,7 +120,9 @@ def get_supported_block_sizes(cls) -> list[BlockSize]: return [] @classmethod - def supports_block_size(cls, block_size: int) -> bool: + def supports_block_size(cls, block_size: Optional[int]) -> bool: + if block_size is None: + return True try: block_size_literal = cast(BlockSize, block_size) except ValueError: @@ -164,7 +166,7 @@ def supports_combination( head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, + block_size: Optional[int], use_mla: bool, has_sink: bool, use_sparse: bool, @@ -178,7 +180,7 @@ def validate_configuration( head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, + block_size: Optional[int], use_mla: bool, has_sink: bool, use_sparse: bool, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 30da874bbc5f..2733bd6c9df3 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -149,7 +149,24 @@ def _cached_get_attn_backend( raise ValueError( f"Invalid attention backend for {current_platform.device_name}" ) - return resolve_obj_by_qualname(attention_cls) + backend = resolve_obj_by_qualname(attention_cls) + + # Adjust block size if the selected backend doesn't support it + if not backend.supports_block_size(block_size): + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() + if vllm_config and vllm_config.cache_config: + new_block_size = backend.get_supported_block_sizes()[0] + logger.info( + "Adjusting kv cache block size from %d to %d for %s backend.", + block_size, + new_block_size, + backend.get_name(), + ) + vllm_config.cache_config.block_size = new_block_size + + return backend @contextmanager diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index cb0ded968897..7f52e753a1f0 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -131,7 +131,6 @@ def log_warnings(cls): @classmethod def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: parallel_config = vllm_config.parallel_config - model_config = vllm_config.model_config if parallel_config.worker_cls == "auto": parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" @@ -140,38 +139,6 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: if cache_config and cache_config.block_size is None: cache_config.block_size = 16 - # Note: model_config may be None during testing - if model_config is not None and model_config.use_mla: - # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, - # then we default to FlashMLA 
backend for non-blackwell GPUs, - # else we default to CutlassMLA. For each case, we force the - # required block_size. - - if envs.VLLM_ATTENTION_BACKEND is None: - # Default case - if cls.is_device_capability(100): - # Blackwell => Force CutlassMLA. - # TODO: This does not work, because the - # global_force_attn_backend_context_manager is not set. - # See vllm/attention/selector.py:_cached_get_attn_backend - envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA" - else: - # Not Blackwell => Force FlashMLA. - envs.VLLM_ATTENTION_BACKEND = "FLASHMLA" - - # Adjust block sizes for MLA backends based on their requirements - from vllm.attention.backends.registry import _Backend, backend_to_class - - backend_enum = _Backend[envs.VLLM_ATTENTION_BACKEND] - backend_class = backend_to_class(backend_enum) - if not backend_class.supports_block_size(cache_config.block_size): - cache_config.block_size = backend_class.get_supported_block_sizes()[0] - logger.info( - "Forcing kv cache block size to %s for %s backend.", - cache_config.block_size, - envs.VLLM_ATTENTION_BACKEND, - ) - # lazy import to avoid circular import from vllm.config import CUDAGraphMode @@ -253,7 +220,7 @@ def get_valid_backends( head_size, dtype, kv_cache_dtype, - block_size, + None, # ignore block_size here, it will be adjusted if needed use_mla, has_sink, use_sparse, @@ -300,7 +267,7 @@ def get_attn_backend_cls( head_size, dtype, kv_cache_dtype, - block_size, + None, # ignore block_size here, it will be adjusted if needed use_mla, has_sink, use_sparse, diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index e5de1e4d81b1..0ff62817739a 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -113,7 +113,9 @@ def supports_kv_cache_dtype(cls, kv_cache_dtype: Optional[str]) -> bool: return kv_cache_dtype in [None, "auto", "fp16", "bf16"] @classmethod - def supports_block_size(cls, block_size: int) -> bool: + def supports_block_size(cls, block_size: Optional[int]) -> bool: + if block_size is None: + return True return block_size % 16 == 0 @classmethod From 24336693f116136629b123e11bb7ff7823193042 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 14 Oct 2025 11:34:30 -0400 Subject: [PATCH 17/84] split priorities by capability, update flashinfer min capability Signed-off-by: Matthew Bonanni --- vllm/platforms/cuda.py | 44 +++++++++++++------ vllm/v1/attention/backends/flashinfer.py | 2 +- .../attention/backends/mla/flashinfer_mla.py | 2 +- 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index ecef7effb177..82655eab79a5 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -46,23 +46,39 @@ @cache -def _get_backend_priorities(): +def _get_backend_priorities( + device_capability: DeviceCapability | None = None, +) -> dict[_Backend, int]: """Get backend priorities with lazy import to avoid circular dependency.""" from vllm.attention.backends.registry import _Backend - return { - # non-MLA backends - _Backend.FLASHINFER: 0, - _Backend.FLASH_ATTN: 1, - _Backend.TRITON_ATTN: 2, - _Backend.FLEX_ATTENTION: 3, - # MLA backends - _Backend.CUTLASS_MLA: 0, - _Backend.FLASHINFER_MLA: 1, - _Backend.FLASHMLA: 2, - _Backend.FLASH_ATTN_MLA: 3, - _Backend.TRITON_MLA: 4, - } + if device_capability >= DeviceCapability(10, 0): + return { + # non-MLA backends + _Backend.FLASHINFER: 0, + _Backend.FLASH_ATTN: 1, + _Backend.TRITON_ATTN: 2, + _Backend.FLEX_ATTENTION: 3, + # MLA backends + 
_Backend.CUTLASS_MLA: 0, + _Backend.FLASHINFER_MLA: 1, + _Backend.FLASHMLA: 2, + _Backend.FLASH_ATTN_MLA: 3, + _Backend.TRITON_MLA: 4, + } + else: + return { + # non-MLA backends + _Backend.FLASH_ATTN: 0, + _Backend.FLASHINFER: 1, + _Backend.TRITON_ATTN: 2, + _Backend.FLEX_ATTENTION: 3, + # MLA backends + _Backend.FLASHMLA: 0, + _Backend.FLASH_ATTN_MLA: 1, + _Backend.FLASHINFER_MLA: 2, + _Backend.TRITON_MLA: 3, + } def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 4293f65239ed..7626a4fbc889 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -235,7 +235,7 @@ def is_mla(cls) -> bool: @classmethod def get_min_compute_capability(cls) -> DeviceCapability | None: - return DeviceCapability(10, 0) + return DeviceCapability(7, 5) @classmethod def get_max_compute_capability(cls) -> DeviceCapability | None: diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index f05ef41c7878..a56fb08620d2 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -58,7 +58,7 @@ def get_supported_block_sizes(cls) -> list[BlockSize]: @classmethod def get_min_compute_capability(cls) -> DeviceCapability | None: - return DeviceCapability(10, 0) + return DeviceCapability(7, 5) @classmethod def get_max_compute_capability(cls) -> DeviceCapability | None: From a3617d79f81f9acf4efc807de910ff926469bcfd Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 15 Oct 2025 09:28:26 -0400 Subject: [PATCH 18/84] change to typing imports Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index aea32c69cc18..9708c4860d78 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -2,14 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import Generic, Protocol, TypeVar, cast +from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast import torch -from vllm.config.cache import BlockSize from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey -from vllm.platforms.interface import DeviceCapability + +if TYPE_CHECKING: + from vllm.config.cache import BlockSize + from vllm.platforms.interface import DeviceCapability class AttentionType: From 81d1b7b28add310a2ac7efda0218d2270c2c6ede Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 15 Oct 2025 09:40:19 -0400 Subject: [PATCH 19/84] backends specify their required kv cache layout Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 8 +++++ vllm/platforms/cuda.py | 29 +++++++++++-------- vllm/v1/attention/backends/flashinfer.py | 8 +++++ .../attention/backends/mla/flashinfer_mla.py | 4 +++ 4 files changed, 37 insertions(+), 12 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 9708c4860d78..6adeb171d764 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -237,6 +237,14 @@ def validate_configuration( invalid_reasons.append(combination_reason) return invalid_reasons + @classmethod + def 
get_required_kv_cache_layout(cls, capability: DeviceCapability) -> str | None: + """ + Some backends require a specific kv cache layout. + This function returns the required layout if any. + """ + return None + class AttentionMetadata: pass diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 82655eab79a5..b0f02b342c40 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -271,7 +271,7 @@ def get_attn_backend_cls( "to select a supported backend." ) - from vllm.attention.backends.registry import _Backend, backend_to_class_str + from vllm.attention.backends.registry import backend_to_class_str device_capability = cls.get_device_capability() @@ -341,31 +341,36 @@ def get_attn_backend_cls( logger.info( "Valid backends: %s", [b[0].name for b in valid_backends_priorities] ) - - valid_backends_classes_str = [ - backend_to_class_str(b[0]) for b in valid_backends_priorities - ] sorted_indices = sorted( range(len(valid_backends_priorities)), key=lambda i: valid_backends_priorities[i][1], ) selected_index = sorted_indices[0] - + selected_backend = valid_backends_priorities[selected_index][0] + selected_backend_class_str = backend_to_class_str(selected_backend) + selected_backend_class = resolve_obj_by_qualname(selected_backend_class_str) engine_version = "V1" if use_v1 else "V0" logger.info( "Using %s backend on %s engine.", - valid_backends_priorities[selected_index][0].name, + selected_backend.name, engine_version, ) - # Post-selection modifications - if valid_backends_priorities[selected_index][0] == _Backend.FLASHINFER_MLA: + # Set required kv cache layout if any + required_layout = selected_backend_class.get_required_kv_cache_layout( + device_capability + ) + if required_layout is not None: from vllm.v1.attention.backends.utils import set_kv_cache_layout - set_kv_cache_layout("HND") - logger.info("Using HND KV cache layout for FlashInferMLA.") + set_kv_cache_layout(required_layout) + logger.info( + "Using %s KV cache layout for %s backend.", + required_layout, + selected_backend.name, + ) - return valid_backends_classes_str[selected_index] + return selected_backend_class_str @classmethod def get_punica_wrapper(cls) -> str: diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 7626a4fbc889..03e5e42b1453 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -241,6 +241,14 @@ def get_min_compute_capability(cls) -> DeviceCapability | None: def get_max_compute_capability(cls) -> DeviceCapability | None: return DeviceCapability(10, 3) + @classmethod + def get_required_kv_cache_layout(cls, capability: DeviceCapability) -> str | None: + if capability >= DeviceCapability(10, 0) and capability <= DeviceCapability( + 10, 3 + ): + return "HND" + return None + @dataclass class FlashInferMetadata: diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index a56fb08620d2..2a58aeef38b3 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -64,6 +64,10 @@ def get_min_compute_capability(cls) -> DeviceCapability | None: def get_max_compute_capability(cls) -> DeviceCapability | None: return DeviceCapability(10, 3) + @classmethod + def get_required_kv_cache_layout(cls, capability: DeviceCapability) -> str | None: + return "HND" + g_fi_workspace = torch.zeros( FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE, From adaf53b2fcc536c5aa1abd38822a35304b5ea999 Mon Sep 17 00:00:00 2001 
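A rough sketch of how a capability-dependent layout requirement like FlashInfer's can be expressed, using stand-in types rather than the real DeviceCapability:

from typing import NamedTuple, Optional

class Cap(NamedTuple):  # stand-in for DeviceCapability
    major: int
    minor: int

def required_kv_cache_layout(cap: Cap) -> Optional[str]:
    # Only SM 10.0-10.3 devices ask for the HND layout in this sketch.
    if Cap(10, 0) <= cap <= Cap(10, 3):
        return "HND"
    return None

assert required_kv_cache_layout(Cap(10, 0)) == "HND"
assert required_kv_cache_layout(Cap(9, 0)) is None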
From: Matthew Bonanni Date: Wed, 15 Oct 2025 09:42:55 -0400 Subject: [PATCH 20/84] flashinfer supports up to 12.1 Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/flashinfer.py | 2 +- vllm/v1/attention/backends/mla/flashinfer_mla.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 03e5e42b1453..4a4604ea96fc 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -239,7 +239,7 @@ def get_min_compute_capability(cls) -> DeviceCapability | None: @classmethod def get_max_compute_capability(cls) -> DeviceCapability | None: - return DeviceCapability(10, 3) + return DeviceCapability(12, 1) @classmethod def get_required_kv_cache_layout(cls, capability: DeviceCapability) -> str | None: diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index 2a58aeef38b3..66df9af325f4 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -62,7 +62,7 @@ def get_min_compute_capability(cls) -> DeviceCapability | None: @classmethod def get_max_compute_capability(cls) -> DeviceCapability | None: - return DeviceCapability(10, 3) + return DeviceCapability(12, 1) @classmethod def get_required_kv_cache_layout(cls, capability: DeviceCapability) -> str | None: From d1f1362625fc2789c6c8af6b577a9f1f935a775c Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 15 Oct 2025 09:45:02 -0400 Subject: [PATCH 21/84] is_mla is false in base class Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 2 +- vllm/v1/attention/backends/flash_attn.py | 4 ---- vllm/v1/attention/backends/flashinfer.py | 4 ---- vllm/v1/attention/backends/flex_attention.py | 4 ---- vllm/v1/attention/backends/triton_attn.py | 4 ---- 5 files changed, 1 insertion(+), 17 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 6adeb171d764..cbc620a04dbf 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -148,7 +148,7 @@ def supports_block_size(cls, block_size: int | None) -> bool: @classmethod def is_mla(cls) -> bool: - raise NotImplementedError + return False @classmethod def supports_sink(cls) -> bool: diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index bc14e9b6e2c7..30ae64436c60 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -122,10 +122,6 @@ def supports_block_size(cls, block_size: int | None) -> bool: return True return block_size % 16 == 0 - @classmethod - def is_mla(cls) -> bool: - return False - @classmethod def get_min_compute_capability(cls) -> DeviceCapability | None: return DeviceCapability(8, 0) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 4a4604ea96fc..cb36c6949995 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -229,10 +229,6 @@ def get_supported_kv_cache_dtypes(cls) -> list[str | None]: def get_supported_block_sizes(cls) -> list[BlockSize]: return [] - @classmethod - def is_mla(cls) -> bool: - return False - @classmethod def get_min_compute_capability(cls) -> DeviceCapability | None: return DeviceCapability(7, 5) diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 
cf4329b5a4ea..3ba965ce8a77 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -118,10 +118,6 @@ def get_supported_kv_cache_dtypes(cls) -> list[str | None]: def get_supported_block_sizes(cls) -> list[BlockSize]: return [] - @classmethod - def is_mla(cls) -> bool: - return False - # @torch.compile(fullgraph=True, mode="reduce-overhead") def physical_to_logical_mapping( diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index ab037857acb8..b3b12ea8357f 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -215,10 +215,6 @@ def get_supported_kv_cache_dtypes(cls) -> list[str | None]: def get_supported_block_sizes(cls) -> list[BlockSize]: return [] - @classmethod - def is_mla(cls) -> bool: - return False - @classmethod def get_min_compute_capability(cls) -> DeviceCapability | None: return None From abb83750f7a36ca7252f3af90b1571211d046729 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 15 Oct 2025 09:51:16 -0400 Subject: [PATCH 22/84] triton supports fp8 Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/triton_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index b3b12ea8357f..522b6f755a21 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -209,7 +209,7 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: @classmethod def get_supported_kv_cache_dtypes(cls) -> list[str | None]: - return ["auto", "fp16", "bf16"] + return ["auto", "fp16", "bf16", "fp8", "fp8_e4m3", "fp8_e5m2"] @classmethod def get_supported_block_sizes(cls) -> list[BlockSize]: From 85d8719fdcb3e8ed9220478065cae10dc3b5e9bd Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 15 Oct 2025 10:19:50 -0400 Subject: [PATCH 23/84] use CacheDType Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 10 +++++----- vllm/config/cache.py | 10 +++++++++- vllm/v1/attention/backends/flash_attn.py | 11 +++++++---- vllm/v1/attention/backends/flashinfer.py | 6 +++--- vllm/v1/attention/backends/flex_attention.py | 6 +++--- vllm/v1/attention/backends/mla/cutlass_mla.py | 6 +++--- vllm/v1/attention/backends/mla/flashattn_mla.py | 8 ++++---- vllm/v1/attention/backends/mla/flashinfer_mla.py | 6 +++--- vllm/v1/attention/backends/mla/flashmla.py | 6 +++--- vllm/v1/attention/backends/mla/flashmla_sparse.py | 6 +++--- vllm/v1/attention/backends/mla/triton_mla.py | 6 +++--- vllm/v1/attention/backends/triton_attn.py | 6 +++--- 12 files changed, 49 insertions(+), 38 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index cbc620a04dbf..734e0abeacb4 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey if TYPE_CHECKING: - from vllm.config.cache import BlockSize + from vllm.config.cache import BlockSize, CacheDType from vllm.platforms.interface import DeviceCapability @@ -119,11 +119,11 @@ def supports_dtype(cls, dtype: torch.dtype) -> bool: return (not supported_dtypes) or dtype in supported_dtypes @classmethod - def get_supported_kv_cache_dtypes(cls) -> list[str | None]: + def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: return ["auto"] @classmethod - def supports_kv_cache_dtype(cls, kv_cache_dtype: str 
| None) -> bool: + def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool: supported_kv_cache_dtypes = cls.get_supported_kv_cache_dtypes() return (not supported_kv_cache_dtypes) or ( kv_cache_dtype is not None and kv_cache_dtype in supported_kv_cache_dtypes @@ -179,7 +179,7 @@ def supports_combination( cls, head_size: int, dtype: torch.dtype, - kv_cache_dtype: str | None, + kv_cache_dtype: "CacheDType | None", block_size: int | None, use_mla: bool, has_sink: bool, @@ -193,7 +193,7 @@ def validate_configuration( cls, head_size: int, dtype: torch.dtype, - kv_cache_dtype: str | None, + kv_cache_dtype: "CacheDType | None", block_size: int | None, use_mla: bool, has_sink: bool, diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 04b1e7bf2ac1..0e674854e5b7 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -20,7 +20,15 @@ logger = init_logger(__name__) BlockSize = Literal[1, 8, 16, 32, 64, 128] -CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] +CacheDType = Literal[ + "auto", + "bfloat16", + "fp8", + "fp8_e4m3", + "fp8_e5m2", + "fp8_inc", + "fp8_ds_mla", +] MambaDType = Literal["auto", "float32"] PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"] diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 30ae64436c60..49646ddc72b5 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -32,6 +32,7 @@ ) from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.platforms.interface import DeviceCapability from vllm.utils import cdiv @@ -111,10 +112,12 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: return [torch.float16, torch.bfloat16] @classmethod - def supports_kv_cache_dtype(cls, kv_cache_dtype: str | None) -> bool: - if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): + def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: + if kv_cache_dtype is None: + return True + if kv_cache_dtype.startswith("fp8"): return flash_attn_supports_fp8() - return kv_cache_dtype in [None, "auto", "fp16", "bf16"] + return kv_cache_dtype in ["auto"] @classmethod def supports_block_size(cls, block_size: int | None) -> bool: @@ -135,7 +138,7 @@ def supports_combination( cls, head_size: int, dtype: torch.dtype, - kv_cache_dtype: str | None, + kv_cache_dtype: CacheDType | None, block_size: int, use_mla: bool, has_sink: bool, diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index cb36c6949995..ca030ea96a8c 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -24,7 +24,7 @@ MultipleOf, ) from vllm.config import CUDAGraphMode, VllmConfig -from vllm.config.cache import BlockSize +from vllm.config.cache import BlockSize, CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, @@ -222,8 +222,8 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: return [torch.float16, torch.bfloat16] @classmethod - def get_supported_kv_cache_dtypes(cls) -> list[str | None]: - return ["auto", "fp16", "bf16", "fp8", "fp8_e4m3", "fp8_e5m2"] + def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: + return ["auto", "fp8", "fp8_e4m3", "fp8_e5m2"] @classmethod def get_supported_block_sizes(cls) -> list[BlockSize]: diff --git 
a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 3ba965ce8a77..22ac9198f547 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -24,7 +24,7 @@ is_quantized_kv_cache, ) from vllm.config import VllmConfig -from vllm.config.cache import BlockSize +from vllm.config.cache import BlockSize, CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_kernel_override_batch_invariant, @@ -111,8 +111,8 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: return [torch.float16, torch.bfloat16, torch.float32] @classmethod - def get_supported_kv_cache_dtypes(cls) -> list[str | None]: - return ["auto", "fp16", "bf16"] + def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: + return ["auto"] @classmethod def get_supported_block_sizes(cls) -> list[BlockSize]: diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index c6a92773fe59..ad101db39b63 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -13,7 +13,7 @@ MultipleOf, is_quantized_kv_cache, ) -from vllm.config.cache import BlockSize +from vllm.config.cache import BlockSize, CacheDType from vllm.logger import init_logger from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.mla.common import ( @@ -56,8 +56,8 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: return [torch.float16, torch.bfloat16] @classmethod - def get_supported_kv_cache_dtypes(cls) -> list[str | None]: - return ["auto", "fp16", "bf16", "e4m3fn"] + def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: + return ["auto", "fp8", "fp8_e4m3"] @classmethod def get_supported_block_sizes(cls) -> list[BlockSize]: diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 9317d22f977c..48578329ad9e 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -17,7 +17,7 @@ get_flash_attn_version, ) from vllm.config import VllmConfig -from vllm.config.cache import BlockSize +from vllm.config.cache import BlockSize, CacheDType from vllm.logger import init_logger from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.mla.common import ( @@ -56,8 +56,8 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: return [torch.float16, torch.bfloat16] @classmethod - def get_supported_kv_cache_dtypes(cls) -> list[str | None]: - return ["auto", "fp16", "bf16"] + def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: + return ["auto"] @classmethod def get_supported_block_sizes(cls) -> list[BlockSize]: @@ -76,7 +76,7 @@ def supports_combination( cls, head_size: int, dtype: torch.dtype, - kv_cache_dtype: str | None, + kv_cache_dtype: CacheDType | None, block_size: int, use_mla: bool, has_sink: bool, diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index 66df9af325f4..44010716e05e 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -7,7 +7,7 @@ from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla from vllm.attention.backends.abstract import AttentionLayer, AttentionType -from vllm.config.cache import BlockSize +from vllm.config.cache import BlockSize, CacheDType from vllm.logger import init_logger 
from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.mla.common import ( @@ -49,8 +49,8 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: return [torch.float16, torch.bfloat16] @classmethod - def get_supported_kv_cache_dtypes(cls) -> list[str | None]: - return ["auto", "fp16", "bf16", "fp8", "fp8_e4m3"] + def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: + return ["auto", "fp8", "fp8_e4m3"] @classmethod def get_supported_block_sizes(cls) -> list[BlockSize]: diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index ba41cb1c9117..adece45f76dd 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -13,7 +13,7 @@ is_flashmla_dense_supported, ) from vllm.config import VllmConfig -from vllm.config.cache import BlockSize +from vllm.config.cache import BlockSize, CacheDType from vllm.logger import init_logger from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.mla.common import ( @@ -55,8 +55,8 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: return [torch.float16, torch.bfloat16] @classmethod - def get_supported_kv_cache_dtypes(cls) -> list[str | None]: - return ["auto", "fp16", "bf16", "fp8", "fp8_e4m3"] + def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: + return ["auto", "fp8", "fp8_e4m3"] @classmethod def get_supported_block_sizes(cls) -> list[BlockSize]: diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index dbafd1292ef5..de03a2d70504 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -19,7 +19,7 @@ get_mla_metadata, ) from vllm.config import VllmConfig -from vllm.config.cache import BlockSize +from vllm.config.cache import BlockSize, CacheDType from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability @@ -76,8 +76,8 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: return [torch.bfloat16] @classmethod - def get_supported_kv_cache_dtypes(cls) -> list[str | None]: - return ["auto", "bf16", "fp8_ds_mla"] + def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: + return ["auto", "fp8_ds_mla"] @classmethod def get_supported_block_sizes(cls) -> list[BlockSize]: diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 2fc3f976f0e0..8d1cef8ab42e 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -12,7 +12,7 @@ ) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_flash_attention import triton_attention -from vllm.config.cache import BlockSize +from vllm.config.cache import BlockSize, CacheDType from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability @@ -40,8 +40,8 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: return [torch.float16, torch.bfloat16] @classmethod - def get_supported_kv_cache_dtypes(cls) -> list[str | None]: - return ["auto", "fp16", "bf16"] + def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: + return ["auto"] @classmethod def get_supported_block_sizes(cls) -> list[BlockSize]: diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 522b6f755a21..19fd6a115e69 
100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -19,7 +19,7 @@ ) from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig -from vllm.config.cache import BlockSize +from vllm.config.cache import BlockSize, CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, @@ -208,8 +208,8 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: return [torch.float16, torch.bfloat16, torch.float32] @classmethod - def get_supported_kv_cache_dtypes(cls) -> list[str | None]: - return ["auto", "fp16", "bf16", "fp8", "fp8_e4m3", "fp8_e5m2"] + def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: + return ["auto", "fp8", "fp8_e4m3", "fp8_e5m2"] @classmethod def get_supported_block_sizes(cls) -> list[BlockSize]: From 1ef0417d6110db5d9cb18ac9f589d0a8de7afe2a Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 15 Oct 2025 10:28:49 -0400 Subject: [PATCH 24/84] add todo Signed-off-by: Matthew Bonanni --- vllm/attention/selector.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index a11dc59ede05..4f3be1c17ac4 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -152,6 +152,7 @@ def _cached_get_attn_backend( backend = resolve_obj_by_qualname(attention_cls) # Adjust block size if the selected backend doesn't support it + # TODO: per-layer block size configuration if not backend.supports_block_size(block_size): from vllm.config import get_current_vllm_config From 16f937348de26986552a2b97b37fe2dee9499953 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 15 Oct 2025 11:23:08 -0400 Subject: [PATCH 25/84] is_quantized_kv_cache use CacheDType Signed-off-by: Matthew Bonanni --- vllm/platforms/cuda.py | 3 ++- vllm/platforms/interface.py | 3 ++- vllm/platforms/rocm.py | 3 ++- vllm/platforms/tpu.py | 4 ++-- vllm/platforms/xpu.py | 3 ++- 5 files changed, 10 insertions(+), 6 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 6554862b0eb5..966c8994c747 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -30,6 +30,7 @@ if TYPE_CHECKING: from vllm.attention.backends.registry import _Backend from vllm.config import ModelConfig, VllmConfig + from vllm.config.cache import CacheDType else: _Backend = None @@ -435,7 +436,7 @@ def device_count(cls) -> int: @classmethod def is_kv_cache_dtype_supported( - cls, kv_cache_dtype: str, model_config: "ModelConfig" + cls, kv_cache_dtype: "CacheDType", model_config: "ModelConfig" ) -> bool: if not envs.VLLM_ATTENTION_BACKEND: return True diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 89231f9fe621..857874d7ad49 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: from vllm.attention.backends.registry import _Backend from vllm.config import ModelConfig, VllmConfig + from vllm.config.cache import CacheDType from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils import FlexibleArgumentParser @@ -580,7 +581,7 @@ def stateless_init_device_torch_dist_pg( @classmethod def is_kv_cache_dtype_supported( - cls, kv_cache_dtype: str, model_config: ModelConfig + cls, kv_cache_dtype: "CacheDType", model_config: ModelConfig ) -> bool: """ Returns if the kv_cache_dtype is supported by the current platform. 
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 3b33978054b7..23afd99eae02 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -19,6 +19,7 @@ if TYPE_CHECKING: from vllm.attention.backends.registry import _Backend from vllm.config import ModelConfig, VllmConfig + from vllm.config.cache import CacheDType else: _Backend = None @@ -505,7 +506,7 @@ def device_count(cls) -> int: @classmethod def is_kv_cache_dtype_supported( - cls, kv_cache_dtype: str, model_config: "ModelConfig" + cls, kv_cache_dtype: "CacheDType", model_config: "ModelConfig" ) -> bool: return True diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 232fe7844477..d215a7cb1867 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: from vllm.attention.backends.registry import _Backend from vllm.config import ModelConfig, VllmConfig - from vllm.config.cache import BlockSize + from vllm.config.cache import BlockSize, CacheDType from vllm.pooling_params import PoolingParams else: BlockSize = None @@ -223,7 +223,7 @@ def validate_request( @classmethod def is_kv_cache_dtype_supported( - cls, kv_cache_dtype: str, model_config: "ModelConfig" + cls, kv_cache_dtype: "CacheDType", model_config: "ModelConfig" ) -> bool: return True diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 7cc5758573f6..96f3ee1e54d4 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -16,6 +16,7 @@ if TYPE_CHECKING: from vllm.attention.backends.registry import _Backend from vllm.config import ModelConfig, VllmConfig + from vllm.config.cache import CacheDType else: ModelConfig = None VllmConfig = None @@ -86,7 +87,7 @@ def get_attn_backend_cls( @classmethod def is_kv_cache_dtype_supported( - cls, kv_cache_dtype: str, model_config: "ModelConfig" + cls, kv_cache_dtype: "CacheDType", model_config: "ModelConfig" ) -> bool: """ Check if the kv_cache_dtype is supported. 
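Because CacheDType is a typing.Literal, callers that still receive free-form strings (CLI flags, config files) can narrow them before they reach these typed hooks. A small sketch, assuming a helper name (coerce_cache_dtype) that does not exist in vLLM:

from typing import cast, get_args

from vllm.config.cache import CacheDType


def coerce_cache_dtype(value: str) -> CacheDType:
    # Enumerate the literal's members ("auto", "bfloat16", "fp8", ...) and
    # reject anything outside that set before it is passed further down.
    valid = get_args(CacheDType)
    if value not in valid:
        raise ValueError(
            f"Unsupported kv_cache_dtype {value!r}; expected one of {valid}"
        )
    return cast(CacheDType, value)
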
From 8474a1445cb58bbbde468d8f81966aaa17f6e679 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 15 Oct 2025 17:48:03 -0400 Subject: [PATCH 26/84] fix supports_sink Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 2 +- vllm/v1/attention/backends/triton_attn.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 734e0abeacb4..1f3881badc86 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -152,7 +152,7 @@ def is_mla(cls) -> bool: @classmethod def supports_sink(cls) -> bool: - return True + return False @classmethod def is_sparse(cls) -> bool: diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 19fd6a115e69..1b0b3e453605 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -215,6 +215,10 @@ def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: def get_supported_block_sizes(cls) -> list[BlockSize]: return [] + @classmethod + def supports_sink(cls) -> bool: + return True + @classmethod def get_min_compute_capability(cls) -> DeviceCapability | None: return None From 62e629052c1a9ffc4e1d5b58e939429bb5a335af Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 15 Oct 2025 17:48:47 -0400 Subject: [PATCH 27/84] fix priority list Signed-off-by: Matthew Bonanni --- vllm/platforms/cuda.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 966c8994c747..859839416152 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -53,7 +53,7 @@ def _get_backend_priorities( """Get backend priorities with lazy import to avoid circular dependency.""" from vllm.attention.backends.registry import _Backend - if device_capability >= DeviceCapability(10, 0): + if device_capability == DeviceCapability(10, 0): return { # non-MLA backends _Backend.FLASHINFER: 0, From 22dd1b8ff69ceeaeffb9da6ad090a47f76a93385 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 16 Oct 2025 10:07:41 -0400 Subject: [PATCH 28/84] fix FA block sizes Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 2 +- vllm/v1/attention/backends/flash_attn.py | 4 ++-- vllm/v1/attention/backends/mla/flashattn_mla.py | 6 ++++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 1f3881badc86..60a76234aae9 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -134,7 +134,7 @@ def get_supported_block_sizes(cls) -> list[BlockSize]: return [] @classmethod - def supports_block_size(cls, block_size: int | None) -> bool: + def supports_block_size(cls, block_size: BlockSize | None) -> bool: if block_size is None: return True try: diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index ebe722026cdf..55b218218304 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -33,7 +33,7 @@ ) from vllm.config import VllmConfig, get_layers_from_vllm_config -from vllm.config.cache import CacheDType +from vllm.config.cache import BlockSize, CacheDType from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger from vllm.platforms.interface import DeviceCapability @@ -122,7 +122,7 @@ def supports_kv_cache_dtype(cls, kv_cache_dtype: 
CacheDType | None) -> bool: return kv_cache_dtype in ["auto"] @classmethod - def supports_block_size(cls, block_size: int | None) -> bool: + def supports_block_size(cls, block_size: BlockSize | None) -> bool: if block_size is None: return True return block_size % 16 == 0 diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index b2d641fbc351..0ec740a6433c 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -61,8 +61,10 @@ def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: return ["auto"] @classmethod - def get_supported_block_sizes(cls) -> list[BlockSize]: - return [] + def supports_block_size(cls, block_size: BlockSize | None) -> bool: + if block_size is None: + return True + return block_size % 16 == 0 @classmethod def get_min_compute_capability(cls) -> DeviceCapability | None: From 121d442335301f706ed2d0010dca5b23762f0c3a Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 16 Oct 2025 10:17:17 -0400 Subject: [PATCH 29/84] fix import failure Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 7a3bddc5b307..ca84d4c56649 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -111,7 +111,7 @@ def supports_dtype(cls, dtype: torch.dtype) -> bool: return (not supported_dtypes) or dtype in supported_dtypes @classmethod - def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: + def get_supported_kv_cache_dtypes(cls) -> list["CacheDType"]: return ["auto"] @classmethod From 963cc9f1090fd0f387d1beacb6d53e3ade93c135 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 16 Oct 2025 10:49:46 -0400 Subject: [PATCH 30/84] fix import error Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index ca84d4c56649..74807e3cd33b 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -122,11 +122,13 @@ def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool: ) @classmethod - def get_supported_block_sizes(cls) -> list[BlockSize]: + def get_supported_block_sizes(cls) -> list["BlockSize"]: return [] @classmethod - def supports_block_size(cls, block_size: BlockSize | None) -> bool: + def supports_block_size(cls, block_size: "BlockSize | None") -> bool: + from vllm.config.cache import BlockSize + if block_size is None: return True try: From de3f302dd98a3a8ddeae9f3a6783f01a7e13be77 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 20 Oct 2025 09:53:08 -0400 Subject: [PATCH 31/84] fix import error Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 859eff85d6cd..20518ddacc9f 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -153,15 +153,15 @@ def is_sparse(cls) -> bool: return False @classmethod - def get_min_compute_capability(cls) -> DeviceCapability | None: + def get_min_compute_capability(cls) -> "DeviceCapability | None": return None @classmethod - def get_max_compute_capability(cls) -> DeviceCapability | None: + def 
get_max_compute_capability(cls) -> "DeviceCapability | None": return None @classmethod - def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + def supports_compute_capability(cls, capability: "DeviceCapability") -> bool: min_capability = cls.get_min_compute_capability() max_capability = cls.get_max_compute_capability() return ((min_capability is None) or (capability >= min_capability)) and ( @@ -178,7 +178,7 @@ def supports_combination( use_mla: bool, has_sink: bool, use_sparse: bool, - device_capability: DeviceCapability, + device_capability: "DeviceCapability", ) -> str | None: return None @@ -192,7 +192,7 @@ def validate_configuration( use_mla: bool, has_sink: bool, use_sparse: bool, - device_capability: DeviceCapability, + device_capability: "DeviceCapability", ) -> list[str]: invalid_reasons = [] if not cls.supports_head_size(head_size): @@ -232,7 +232,7 @@ def validate_configuration( return invalid_reasons @classmethod - def get_required_kv_cache_layout(cls, capability: DeviceCapability) -> str | None: + def get_required_kv_cache_layout(cls, capability: "DeviceCapability") -> str | None: """ Some backends require a specific kv cache layout. This function returns the required layout if any. From bc10beefa89b64f2f5641cae08df69f159fb3af5 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 20 Oct 2025 12:10:16 -0400 Subject: [PATCH 32/84] fix import Signed-off-by: Matthew Bonanni --- vllm/platforms/cuda.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index aa2ca59269ff..c043bf25d4b4 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -16,7 +16,8 @@ import vllm._C # noqa import vllm.envs as envs from vllm.logger import init_logger -from vllm.utils import import_pynvml, resolve_obj_by_qualname +from vllm.utils import import_pynvml +from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.torch_utils import cuda_device_count_stateless from .interface import DeviceCapability, Platform, PlatformEnum From 05aab3e239937e36c87d56ae3e81783ef2b75fbf Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 21 Oct 2025 09:53:32 -0400 Subject: [PATCH 33/84] fix type error Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/flash_attn.py | 4 ++-- vllm/v1/attention/backends/mla/flashattn_mla.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 4d6a70b2c7bb..75c3680d7c34 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -146,9 +146,9 @@ def supports_combination( use_mla: bool, has_sink: bool, use_sparse: bool, - device_capability: int, + device_capability: DeviceCapability, ) -> str | None: - if has_sink and device_capability < 90: + if has_sink and device_capability < DeviceCapability(9, 0): return "sink not supported on compute capability < 9.0" return None diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 0dfd9521b511..aec5bd9b8e00 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -87,7 +87,7 @@ def supports_combination( use_mla: bool, has_sink: bool, use_sparse: bool, - device_capability: int, + device_capability: DeviceCapability, ) -> str | None: if not flash_attn_supports_mla(): return "FlashAttention MLA not supported on this device" From 
7936c4748ad7e002c6ed25e954bba8caa699a409 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 21 Oct 2025 10:04:33 -0400 Subject: [PATCH 34/84] add flashmla support test Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/flashmla.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index afb48cb9f93c..0ae88bf29d1d 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -78,6 +78,27 @@ def get_min_compute_capability(cls) -> DeviceCapability | None: def get_max_compute_capability(cls) -> DeviceCapability | None: return DeviceCapability(10, 3) + @classmethod + def supports_combination( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: CacheDType | None, + block_size: int, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + device_capability: DeviceCapability, + ) -> str | None: + if use_sparse: + from vllm.attention.ops.flashmla import is_flashmla_sparse_supported + + return is_flashmla_sparse_supported()[1] + else: + from vllm.attention.ops.flashmla import is_flashmla_dense_supported + + return is_flashmla_dense_supported()[1] + @dataclass class FlashMLADecodeMetadata(MLACommonDecodeMetadata): From 4f0f955741f45f069d1124dc534cb7ea1937a596 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 21 Oct 2025 10:15:00 -0400 Subject: [PATCH 35/84] clean up head size validation Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/common.py | 6 ++++-- vllm/v1/attention/backends/rocm_attn.py | 5 ++--- vllm/v1/attention/backends/triton_attn.py | 11 ++--------- 3 files changed, 8 insertions(+), 14 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 0852812c81b8..90fd97d94392 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -430,8 +430,10 @@ class MLACommonMetadata(Generic[D]): ) = None def __post_init__(self): - if self.head_dim is not None: - MLACommonBackend.validate_head_size(self.head_dim) + if self.head_dim is not None and not MLACommonBackend.supports_head_size( + self.head_dim + ): + raise ValueError(f"Head dimension {self.head_dim} is not supported by MLA.") M = TypeVar("M", bound=MLACommonMetadata) diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 8b7ce90a3cca..2279496636ba 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -164,12 +164,11 @@ def get_supported_head_sizes(cls) -> list[int]: @classmethod def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: + if not cls.supports_head_size(head_size): attn_type = cls.__name__.removesuffix("Backend") raise ValueError( f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " + f"Supported head sizes are: {cls.get_supported_head_sizes()}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " "FlexAttention backend which supports all head sizes." 
) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index efe6e390b6e1..f7236c02f717 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -184,15 +184,8 @@ def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]: return TritonAttentionMetadataBuilder @classmethod - def validate_head_size(cls, head_size: int) -> None: - # Triton Attention supports any head size above 32 - if head_size < 32: - raise ValueError( - f"Head size {head_size} is not supported by TritonAttention." - f"Head sizes need to be larger or equal 32 for this backend. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." - ) + def supports_head_size(cls, head_size: int) -> bool: + return head_size >= 32 @classmethod def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: From d8b8043cc664adef60d7bee3f14cfa6f00f60947 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 21 Oct 2025 12:20:19 -0400 Subject: [PATCH 36/84] use KVCacheLayoutType Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 5 ++++- vllm/v1/attention/backends/flashinfer.py | 5 ++++- vllm/v1/attention/backends/mla/flashinfer_mla.py | 6 ++++-- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 20518ddacc9f..aeb0f69a9c99 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: from vllm.config.cache import BlockSize, CacheDType from vllm.platforms.interface import DeviceCapability + from vllm.v1.attention.backends.utils import KVCacheLayoutType class AttentionType: @@ -232,7 +233,9 @@ def validate_configuration( return invalid_reasons @classmethod - def get_required_kv_cache_layout(cls, capability: "DeviceCapability") -> str | None: + def get_required_kv_cache_layout( + cls, capability: "DeviceCapability" + ) -> "KVCacheLayoutType | None": """ Some backends require a specific kv cache layout. This function returns the required layout if any. 
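The concrete values in this sketch mirror the FlashInfer hunk that follows: on Blackwell (SM 10.x) the backend requires the HND layout, while other capabilities return None so the selector leaves the global layout untouched. A quick sanity check, assuming FlashInfer is installed:

from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.flashinfer import FlashInferBackend

# Blackwell requires HND; pre-Blackwell capabilities impose no requirement.
assert FlashInferBackend.get_required_kv_cache_layout(DeviceCapability(10, 0)) == "HND"
assert FlashInferBackend.get_required_kv_cache_layout(DeviceCapability(9, 0)) is None
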
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 25f44cac601e..ac93c81e6432 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -46,6 +46,7 @@ AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + KVCacheLayoutType, get_kv_cache_layout, get_per_layer_parameters, infer_global_hyperparameters, @@ -241,7 +242,9 @@ def get_max_compute_capability(cls) -> DeviceCapability | None: return DeviceCapability(12, 1) @classmethod - def get_required_kv_cache_layout(cls, capability: DeviceCapability) -> str | None: + def get_required_kv_cache_layout( + cls, capability: DeviceCapability + ) -> KVCacheLayoutType | None: if capability >= DeviceCapability(10, 0) and capability <= DeviceCapability( 10, 3 ): diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index ad45e91e6642..8d2d47a30aa6 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -17,7 +17,7 @@ MLACommonMetadataBuilder, QueryLenSupport, ) -from vllm.v1.attention.backends.utils import AttentionCGSupport +from vllm.v1.attention.backends.utils import AttentionCGSupport, KVCacheLayoutType logger = init_logger(__name__) @@ -63,7 +63,9 @@ def get_max_compute_capability(cls) -> DeviceCapability | None: return DeviceCapability(12, 1) @classmethod - def get_required_kv_cache_layout(cls, capability: DeviceCapability) -> str | None: + def get_required_kv_cache_layout( + cls, capability: DeviceCapability + ) -> "KVCacheLayoutType | None": return "HND" From a3ccbba5936a7c70d61ee66d9b1c89d8f5c0c901 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 21 Oct 2025 12:27:22 -0400 Subject: [PATCH 37/84] move selector layout change to same place as block size change Signed-off-by: Matthew Bonanni --- vllm/attention/selector.py | 13 +++++++++++++ vllm/platforms/cuda.py | 15 --------------- vllm/v1/attention/backends/flashinfer.py | 6 ++++-- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index cd6908b4b5d2..4262f1681585 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -168,6 +168,19 @@ def _cached_get_attn_backend( ) vllm_config.cache_config.block_size = new_block_size + # Adjust kv cache layout if the selected backend requires a specific one + device_capability = current_platform.get_device_capability() + required_layout = backend.get_required_kv_cache_layout(device_capability) + if required_layout is not None: + from vllm.v1.attention.backends.utils import set_kv_cache_layout + + set_kv_cache_layout(required_layout) + logger.info( + "Using %s KV cache layout for %s backend.", + required_layout, + backend.get_name(), + ) + return backend diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index c043bf25d4b4..844a8fa285cc 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -344,7 +344,6 @@ def get_attn_backend_cls( selected_index = sorted_indices[0] selected_backend = valid_backends_priorities[selected_index][0] selected_backend_class_str = backend_to_class_str(selected_backend) - selected_backend_class = resolve_obj_by_qualname(selected_backend_class_str) engine_version = "V1" if use_v1 else "V0" logger.info( "Using %s backend on %s engine.", @@ -352,20 +351,6 @@ def get_attn_backend_cls( engine_version, ) - # Set required kv cache layout if any - required_layout = 
selected_backend_class.get_required_kv_cache_layout( - device_capability - ) - if required_layout is not None: - from vllm.v1.attention.backends.utils import set_kv_cache_layout - - set_kv_cache_layout(required_layout) - logger.info( - "Using %s KV cache layout for %s backend.", - required_layout, - selected_backend.name, - ) - return selected_backend_class_str @classmethod diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index ac93c81e6432..3766edc0b0ee 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -245,8 +245,10 @@ def get_max_compute_capability(cls) -> DeviceCapability | None: def get_required_kv_cache_layout( cls, capability: DeviceCapability ) -> KVCacheLayoutType | None: - if capability >= DeviceCapability(10, 0) and capability <= DeviceCapability( - 10, 3 + if ( + capability is not None + and capability >= DeviceCapability(10, 0) + and capability <= DeviceCapability(10, 3) ): return "HND" return None From 3285c2c7b79c09ab99193fe4e47c77a34d09240f Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 21 Oct 2025 13:06:18 -0400 Subject: [PATCH 38/84] MLA only supports head size 576 Signed-off-by: Matthew Bonanni --- tests/kernels/attention/test_attention_selector.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index 48a42ce6ffab..cf75b13ee5bb 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -143,7 +143,7 @@ def test_env( pytest.skip("CUTLASS_MLA only supports block_size 128") else: backend = get_attn_backend( - 16, torch.float16, None, block_size, use_mla=use_mla + 576, torch.float16, None, block_size, use_mla=use_mla ) expected = "CUTLASS_MLA" assert backend.get_name() == expected @@ -155,7 +155,7 @@ def test_env( ) else: backend = get_attn_backend( - 16, torch.float16, None, block_size, use_mla=use_mla + 576, torch.float16, None, block_size, use_mla=use_mla ) expected = "FLASHINFER_MLA" assert backend.get_name() == expected @@ -173,20 +173,24 @@ def test_env( pytest.skip("FlashMLA not supported on this platform") else: backend = get_attn_backend( - 16, torch.float16, None, block_size, use_mla=use_mla + 576, + torch.float16, + None, + block_size, + use_mla=use_mla, ) expected = name assert backend.get_name() == expected elif name == "FLASH_ATTN_MLA": backend = get_attn_backend( - 16, torch.float16, None, block_size, use_mla=use_mla + 576, torch.float16, None, block_size, use_mla=use_mla ) expected = "FLASH_ATTN_MLA" assert backend.get_name() == expected else: # TRITON_MLA or other fallback backend = get_attn_backend( - 16, torch.float16, None, block_size, use_mla=use_mla + 576, torch.float16, None, block_size, use_mla=use_mla ) expected = "TRITON_MLA" assert backend.get_name() == expected From 6eab50470ad914d4781001f972c2e0085c53af11 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 21 Oct 2025 13:06:40 -0400 Subject: [PATCH 39/84] fix kv_cache_dtype support logic Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index aeb0f69a9c99..0d9501eb79b4 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -117,9 +117,11 @@ def get_supported_kv_cache_dtypes(cls) -> 
list["CacheDType"]: @classmethod def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool: + if kv_cache_dtype is None: + return True supported_kv_cache_dtypes = cls.get_supported_kv_cache_dtypes() return (not supported_kv_cache_dtypes) or ( - kv_cache_dtype is not None and kv_cache_dtype in supported_kv_cache_dtypes + kv_cache_dtype in supported_kv_cache_dtypes ) @classmethod From 5523dacbc6a00a3f9ed1ee33ec5a342b40769fe5 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 21 Oct 2025 20:13:28 +0000 Subject: [PATCH 40/84] fix test Signed-off-by: Matthew Bonanni --- tests/v1/worker/test_gpu_model_runner.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index e985578f05ec..bfc9b0f2acd1 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -425,13 +425,19 @@ def test_kv_cache_stride_order(monkeypatch, model_runner): # This test checks if GPUModelRunner initializes correctly when an attention # backend enforces a non-default KV cache stride order. n_heads = model_runner.model_config.get_num_kv_heads(model_runner.parallel_config) - expected_kv_cache_shape = [ - 2, - NUM_BLOCKS, - BLOCK_SIZE, - n_heads, - model_runner.model_config.get_head_size(), - ] + head_size = model_runner.model_config.get_head_size() + + # Get the expected shape from the backend's get_kv_cache_shape method + # to ensure compatibility with different backends (triton vs flexattention) + attn_backend = None + for attn_group in model_runner._attn_group_iterator(): + attn_backend = attn_group.backend + break + + expected_kv_cache_shape = list(attn_backend.get_kv_cache_shape( + NUM_BLOCKS, BLOCK_SIZE, n_heads, head_size + )) + # TODO mla test default_stride = tuple(range(5)) # Permutation that gets you back to expected kv shape From 58fc8889dbe9174222c0d179c68867d916a6ce9c Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 21 Oct 2025 20:43:05 +0000 Subject: [PATCH 41/84] skip FA MLA if test is run on hardware where it's not supported Signed-off-by: Matthew Bonanni --- .../attention/test_attention_selector.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index cf75b13ee5bb..08d780141529 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -182,11 +182,20 @@ def test_env( expected = name assert backend.get_name() == expected elif name == "FLASH_ATTN_MLA": - backend = get_attn_backend( - 576, torch.float16, None, block_size, use_mla=use_mla + from vllm.attention.utils.fa_utils import ( + flash_attn_supports_mla, ) - expected = "FLASH_ATTN_MLA" - assert backend.get_name() == expected + + if not flash_attn_supports_mla(): + pytest.skip( + "FlashAttention MLA not supported on this platform" + ) + else: + backend = get_attn_backend( + 576, torch.float16, None, block_size, use_mla=use_mla + ) + expected = "FLASH_ATTN_MLA" + assert backend.get_name() == expected else: # TRITON_MLA or other fallback backend = get_attn_backend( From 17fd9546d925db36989115b2274f412267c948ee Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 21 Oct 2025 21:03:15 +0000 Subject: [PATCH 42/84] fix test Signed-off-by: Matthew Bonanni --- tests/compile/test_fusion_attn.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index fecb1e2e918f..7f510c221e93 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -314,6 +314,7 @@ def test_attention_quant_pattern( custom_ops_list = custom_ops.split(",") if custom_ops else [] device = torch.device("cuda:0") + torch.set_default_dtype(dtype) torch.manual_seed(42) vllm_config = VllmConfig( From 2b237127bbfee99a774fbc178d329ced1f3d6e42 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 21 Oct 2025 21:42:53 +0000 Subject: [PATCH 43/84] fix pre-commit Signed-off-by: Matthew Bonanni --- tests/v1/worker/test_gpu_model_runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index bfc9b0f2acd1..a45767de3eb8 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -434,9 +434,9 @@ def test_kv_cache_stride_order(monkeypatch, model_runner): attn_backend = attn_group.backend break - expected_kv_cache_shape = list(attn_backend.get_kv_cache_shape( - NUM_BLOCKS, BLOCK_SIZE, n_heads, head_size - )) + expected_kv_cache_shape = list( + attn_backend.get_kv_cache_shape(NUM_BLOCKS, BLOCK_SIZE, n_heads, head_size) + ) # TODO mla test default_stride = tuple(range(5)) From fc1d3f3bfe95c6f291c6da8eaca67a1a6b6efb03 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 22 Oct 2025 10:52:34 -0400 Subject: [PATCH 44/84] fix head size Signed-off-by: Matthew Bonanni --- tests/kernels/attention/test_attention_selector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index 08d780141529..4f5156f1d829 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -205,7 +205,7 @@ def test_env( assert backend.get_name() == expected elif name == "FLASHINFER": backend = get_attn_backend( - 16, torch.float16, None, block_size, use_mla=use_mla + 64, torch.float16, None, block_size, use_mla=use_mla ) expected = "FLASHINFER" assert backend.get_name() == expected From ecdef49dd75c728c6883be741eb70a5d6d94296a Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 22 Oct 2025 10:57:10 -0400 Subject: [PATCH 45/84] fix pre-commit Signed-off-by: Matthew Bonanni --- tests/v1/worker/test_gpu_model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index a45767de3eb8..c253560a72b9 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -434,6 +434,7 @@ def test_kv_cache_stride_order(monkeypatch, model_runner): attn_backend = attn_group.backend break + assert attn_backend is not None, "No attention backend found" expected_kv_cache_shape = list( attn_backend.get_kv_cache_shape(NUM_BLOCKS, BLOCK_SIZE, n_heads, head_size) ) From 9008e56e74641cb9ee9025752c50c04e0a7e5432 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 22 Oct 2025 11:13:01 -0400 Subject: [PATCH 46/84] flashinfer_mla only support blackwell (only uses TRTLLM kernels) Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/flashinfer_mla.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index 8d2d47a30aa6..bd5b222a23b8 100644 
--- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -56,11 +56,11 @@ def get_supported_block_sizes(cls) -> list[BlockSize]: @classmethod def get_min_compute_capability(cls) -> DeviceCapability | None: - return DeviceCapability(7, 5) + return DeviceCapability(10, 0) @classmethod def get_max_compute_capability(cls) -> DeviceCapability | None: - return DeviceCapability(12, 1) + return DeviceCapability(10, 3) @classmethod def get_required_kv_cache_layout( From b756ceb3fbaf95404ed952905168e3033599b39c Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 22 Oct 2025 13:19:06 -0400 Subject: [PATCH 47/84] compute capability checks Signed-off-by: Matthew Bonanni --- .../attention/test_attention_selector.py | 76 ++++++++++--------- 1 file changed, 39 insertions(+), 37 deletions(-) diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index 4f5156f1d829..1e7873b7f145 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -127,12 +127,13 @@ def test_env( elif device == "cuda": with patch("vllm.platforms.current_platform", CudaPlatform()): + capability = torch.cuda.get_device_capability() if use_mla: # CUDA MLA backend logic: # - CUTLASS_MLA: only supported with block_size == 128 - # and Blackwell GPUs (SM 10.0), V1 only + # and Blackwell GPUs (SM 10.x), V1 only # - FLASHINFER_MLA: only supported on Blackwell GPUs - # (SM 10.0+), V1 only + # (SM 10.x), V1 only # - FLASHMLA: only supported with block_size == 64 # - FLASH_ATTN_MLA: V1 only # - TRITON_MLA: fallback for other cases @@ -141,46 +142,48 @@ def test_env( if block_size != 128: # CUTLASS_MLA only supports block_size == 128 pytest.skip("CUTLASS_MLA only supports block_size 128") - else: - backend = get_attn_backend( - 576, torch.float16, None, block_size, use_mla=use_mla - ) - expected = "CUTLASS_MLA" - assert backend.get_name() == expected + if capability[0] != 10: + pytest.skip("CUTLASS MLA is not supported on this platform") + backend = get_attn_backend( + 576, torch.float16, None, block_size, use_mla=use_mla + ) + expected = "CUTLASS_MLA" + assert backend.get_name() == expected elif name == "FLASHINFER_MLA": + if capability[0] != 10: + pytest.skip( + "FlashInfer MLA is not supported on this platform" + ) if block_size not in [32, 64]: # FlashInfer MLA only supports block_size 32 or 64 pytest.skip( "FlashInfer MLA only supports block_size 32 or 64" ) - else: - backend = get_attn_backend( - 576, torch.float16, None, block_size, use_mla=use_mla - ) - expected = "FLASHINFER_MLA" - assert backend.get_name() == expected + backend = get_attn_backend( + 576, torch.float16, None, block_size, use_mla=use_mla + ) + expected = "FLASHINFER_MLA" + assert backend.get_name() == expected elif name == "FLASHMLA": if block_size != 64: # FlashMLA only supports block_size == 64 pytest.skip("FlashMLA only supports block_size 64") - else: - from vllm.v1.attention.backends.mla.flashmla import ( - is_flashmla_dense_supported, - ) + from vllm.v1.attention.backends.mla.flashmla import ( + is_flashmla_dense_supported, + ) - is_supported, _ = is_flashmla_dense_supported() - if not is_supported: - pytest.skip("FlashMLA not supported on this platform") - else: - backend = get_attn_backend( - 576, - torch.float16, - None, - block_size, - use_mla=use_mla, - ) - expected = name - assert backend.get_name() == expected + is_supported, _ = is_flashmla_dense_supported() + if not 
is_supported: + pytest.skip("FlashMLA not supported on this platform") + backend = get_attn_backend( + 576, + torch.float16, + None, + block_size, + use_mla=use_mla, + ) + expected = name + assert backend.get_name() == expected elif name == "FLASH_ATTN_MLA": from vllm.attention.utils.fa_utils import ( flash_attn_supports_mla, @@ -190,12 +193,11 @@ def test_env( pytest.skip( "FlashAttention MLA not supported on this platform" ) - else: - backend = get_attn_backend( - 576, torch.float16, None, block_size, use_mla=use_mla - ) - expected = "FLASH_ATTN_MLA" - assert backend.get_name() == expected + backend = get_attn_backend( + 576, torch.float16, None, block_size, use_mla=use_mla + ) + expected = "FLASH_ATTN_MLA" + assert backend.get_name() == expected else: # TRITON_MLA or other fallback backend = get_attn_backend( From afccece026cc104de9a2916df1e8c82684f3bfd5 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 22 Oct 2025 15:11:26 -0400 Subject: [PATCH 48/84] remove reference to backend_name_to_enum Signed-off-by: Matthew Bonanni --- vllm/config/multimodal.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index e80d072dab45..f50c6a20e67f 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -163,22 +163,21 @@ def _validate_mm_encoder_attn_backend(cls, value: object) -> _Backend | None: from vllm.attention.backends.registry import ( _Backend as BackendEnum, ) - from vllm.attention.backends.registry import ( - backend_name_to_enum, - ) if value is None or isinstance(value, BackendEnum): return value - if isinstance(value, str): - candidate = backend_name_to_enum(value.upper()) - if candidate is not None: - return candidate - - valid_backends = ", ".join(sorted(BackendEnum.__members__.keys())) - raise ValueError( - f"Invalid mm encoder attention backend. Expected one of: {valid_backends}." + assert isinstance(value, str), ( + "mm_encoder_attn_backend must be a string or a BackendEnum." ) + try: + return BackendEnum[value.upper()] + except KeyError as exc: + valid_backends = ", ".join(sorted(BackendEnum.__members__.keys())) + raise ValueError( + f"Invalid mm encoder attention backend. " + f"Expected one of: {valid_backends}." 
+ ) from exc @model_validator(mode="after") def _validate_multimodal_config(self): From 33cb1ef787796be7cb8f591455f18ec06fce8cfd Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 22 Oct 2025 15:18:28 -0400 Subject: [PATCH 49/84] fix default block size Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 10 ++++++++++ vllm/attention/selector.py | 2 +- vllm/v1/attention/backends/flash_attn.py | 4 ++++ vllm/v1/attention/backends/mla/flashattn_mla.py | 4 ++++ 4 files changed, 19 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 0d9501eb79b4..5aefe76b3583 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -143,6 +143,16 @@ def supports_block_size(cls, block_size: "BlockSize | None") -> bool: not supported_block_sizes ) or block_size_literal in supported_block_sizes + @classmethod + def get_default_block_size(cls) -> "BlockSize": + supported_block_sizes = cls.get_supported_block_sizes() + if not supported_block_sizes: + raise ValueError( + f"Fallback failed, no explicitly supported block sizes for " + f"backend {cls.get_name()}" + ) + return supported_block_sizes[0] + @classmethod def is_mla(cls) -> bool: return False diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 4262f1681585..0d6ed072c7ca 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -159,7 +159,7 @@ def _cached_get_attn_backend( vllm_config = get_current_vllm_config() if vllm_config and vllm_config.cache_config: - new_block_size = backend.get_supported_block_sizes()[0] + new_block_size = backend.get_default_block_size() logger.info( "Adjusting kv cache block size from %d to %d for %s backend.", block_size, diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 75c3680d7c34..f2ac56a88a90 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -128,6 +128,10 @@ def supports_block_size(cls, block_size: BlockSize | None) -> bool: return True return block_size % 16 == 0 + @classmethod + def get_default_block_size(cls) -> BlockSize: + return 16 + @classmethod def get_min_compute_capability(cls) -> DeviceCapability | None: return DeviceCapability(8, 0) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index aec5bd9b8e00..c8c07e911aa4 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -69,6 +69,10 @@ def supports_block_size(cls, block_size: BlockSize | None) -> bool: return True return block_size % 16 == 0 + @classmethod + def get_default_block_size(cls) -> BlockSize: + return 16 + @classmethod def get_min_compute_capability(cls) -> DeviceCapability | None: return DeviceCapability(9, 0) From 3f5439e9f1ee4f04e3d3ca2a4cba79d51cd5e28a Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 23 Oct 2025 10:31:50 -0400 Subject: [PATCH 50/84] improve logs Signed-off-by: Matthew Bonanni --- vllm/attention/selector.py | 2 +- vllm/platforms/cuda.py | 37 ++++++++++++++++++++----------------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 0d6ed072c7ca..11266c141dbd 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -161,7 +161,7 @@ def _cached_get_attn_backend( if vllm_config and vllm_config.cache_config: new_block_size = 
backend.get_default_block_size() logger.info( - "Adjusting kv cache block size from %d to %d for %s backend.", + "Adjusting KV cache block size from %d to %d for %s backend.", block_size, new_block_size, backend.get_name(), diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 6e18bfc6fd30..948ebae393d2 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -312,22 +312,27 @@ def get_attn_backend_cls( use_sparse, device_capability, ) - - if len(valid_backends_priorities) == 0: - reasons_str = ( - "{" - + ", ".join( - f"{backend.name}: [{', '.join(reasons)}]" - for backend, reasons in invalid_reasons.items() - ) - + "}" + reasons_str = ( + "{" + + ", ".join( + f"{backend.name}: [{', '.join(reasons)}]" + for backend, reasons in invalid_reasons.items() ) + + "}" + ) + config_str = ( + f"head_size: {head_size}, dtype: {dtype}, " + f"kv_cache_dtype: {kv_cache_dtype}, block_size: {block_size}, " + f"use_mla: {use_mla}, has_sink: {has_sink}, use_sparse: {use_sparse}" + ) + logger.debug_once( + f"Some attention backends are not valid for {cls.device_name} with " + f"{config_str}. Reasons: {reasons_str}." + ) + if len(valid_backends_priorities) == 0: raise ValueError( - f"No valid attention backend from priority list for " - f"{cls.device_name} with head_size: {head_size}, " - f"dtype: {dtype}, kv_cache_dtype: {kv_cache_dtype}, " - f"use_mla: {use_mla}, has_sink: {has_sink}, " - f"use_sparse: {use_sparse}. Reasons: {reasons_str}" + f"No valid attention backend found for {cls.device_name} " + f"with {config_str}. Reasons: {reasons_str}." ) # We have found some valid backends. Select the one with the @@ -342,11 +347,9 @@ def get_attn_backend_cls( selected_index = sorted_indices[0] selected_backend = valid_backends_priorities[selected_index][0] selected_backend_class_str = backend_to_class_str(selected_backend) - engine_version = "V1" if use_v1 else "V0" logger.info( - "Using %s backend on %s engine.", + "Using %s backend.", selected_backend.name, - engine_version, ) return selected_backend_class_str From 75fce85cd3786886c7bd96229551aa31f4f915df Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 23 Oct 2025 10:40:06 -0400 Subject: [PATCH 51/84] fix block size support Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/flashinfer_mla.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index bd5b222a23b8..84e2f062a693 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -51,8 +51,14 @@ def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: return ["auto", "fp8", "fp8_e4m3"] @classmethod - def get_supported_block_sizes(cls) -> list[BlockSize]: - return [32, 64] + def supports_block_size(cls, block_size: BlockSize | None) -> bool: + if block_size is None: + return True + return (block_size == 32) or (block_size % 64 == 0) + + @classmethod + def get_default_block_size(cls) -> BlockSize: + return 64 @classmethod def get_min_compute_capability(cls) -> DeviceCapability | None: From ba51339c9f9fbe3931e9695de3b8a6572c155859 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 23 Oct 2025 11:32:01 -0400 Subject: [PATCH 52/84] fix getting priority list Signed-off-by: Matthew Bonanni --- vllm/platforms/cuda.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 948ebae393d2..2b6431e04459 100644 
--- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -46,7 +46,9 @@ def _get_backend_priorities( """Get backend priorities with lazy import to avoid circular dependency.""" from vllm.attention.backends.registry import _Backend - if device_capability == DeviceCapability(10, 0): + if device_capability >= DeviceCapability(10, 0) and ( + device_capability < DeviceCapability(11, 0) + ): return { # non-MLA backends _Backend.FLASHINFER: 0, @@ -220,7 +222,7 @@ def get_valid_backends( invalid_reasons = {} from vllm.attention.backends.registry import _Backend, backend_to_class - backend_priorities = _get_backend_priorities() + backend_priorities = _get_backend_priorities(device_capability) for backend in _Backend: if backend not in backend_priorities: continue From d49fbf9b8314124d78bbc18ec71e5892eb0bc96f Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Fri, 24 Oct 2025 11:48:23 -0400 Subject: [PATCH 53/84] remove redundant block size methods Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 64 ++++++++++++------- vllm/attention/selector.py | 1 + vllm/v1/attention/backends/flash_attn.py | 14 +--- vllm/v1/attention/backends/flashinfer.py | 8 +-- vllm/v1/attention/backends/flex_attention.py | 6 +- vllm/v1/attention/backends/mla/cutlass_mla.py | 8 +-- .../attention/backends/mla/flashattn_mla.py | 17 ++--- .../attention/backends/mla/flashinfer_mla.py | 20 +++--- vllm/v1/attention/backends/mla/flashmla.py | 8 +-- .../attention/backends/mla/flashmla_sparse.py | 11 ++-- vllm/v1/attention/backends/mla/indexer.py | 2 +- vllm/v1/attention/backends/mla/triton_mla.py | 6 +- vllm/v1/attention/backends/rocm_aiter_fa.py | 2 +- vllm/v1/attention/backends/tree_attn.py | 2 +- vllm/v1/attention/backends/triton_attn.py | 8 +-- vllm/v1/attention/backends/xformers.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 2 +- 17 files changed, 81 insertions(+), 100 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 5aefe76b3583..aa486148091b 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast +from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast, get_args import torch @@ -62,8 +62,8 @@ def get_metadata_cls() -> type["AttentionMetadata"]: raise NotImplementedError @classmethod - def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: - return cls.get_impl_cls().get_supported_kernel_block_size() + def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: + return [MultipleOf(1)] @classmethod def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": @@ -125,33 +125,54 @@ def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool: ) @classmethod - def get_supported_block_sizes(cls) -> list["BlockSize"]: - return [] - - @classmethod - def supports_block_size(cls, block_size: "BlockSize | None") -> bool: - from vllm.config.cache import BlockSize - + def supports_block_size(cls, block_size: int | None) -> bool: if block_size is None: return True - try: - block_size_literal = cast(BlockSize, block_size) - except ValueError: + + valid_sizes = get_args(BlockSize) + if block_size not in valid_sizes: return False - supported_block_sizes = cls.get_supported_block_sizes() - return ( - not supported_block_sizes - ) or block_size_literal in supported_block_sizes + + 
supported_block_sizes = cls.get_supported_kernel_block_sizes() + if not supported_block_sizes: + return True + + for supported_size in supported_block_sizes: + is_multiple_of = ( + isinstance(supported_size, MultipleOf) + and block_size % supported_size.base == 0 + ) + is_int_divisor = ( + isinstance(supported_size, int) and block_size % supported_size == 0 + ) + if is_multiple_of or is_int_divisor: + return True + return False @classmethod def get_default_block_size(cls) -> "BlockSize": - supported_block_sizes = cls.get_supported_block_sizes() + from vllm.config.cache import BlockSize + + valid_sizes = get_args(BlockSize) + + supported_block_sizes = cls.get_supported_kernel_block_sizes() if not supported_block_sizes: raise ValueError( f"Fallback failed, no explicitly supported block sizes for " f"backend {cls.get_name()}" ) - return supported_block_sizes[0] + + block_size = supported_block_sizes[0] + if isinstance(block_size, MultipleOf): + block_size = block_size.base + + if block_size not in valid_sizes: + raise ValueError( + f"Default block size {block_size} for backend {cls.get_name()} is not " + f"a valid BlockSize." + ) + + return cast(BlockSize, block_size) @classmethod def is_mla(cls) -> bool: @@ -326,11 +347,6 @@ def __init__( ) -> None: raise NotImplementedError - @classmethod - def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: - # TODO: implement this function for all backends. - return [MultipleOf(1)] - @abstractmethod def forward( self, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 11266c141dbd..57aa528cef1d 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -153,6 +153,7 @@ def _cached_get_attn_backend( backend = resolve_obj_by_qualname(attention_cls) # Adjust block size if the selected backend doesn't support it + # The hybrid block table will handle mapping between allocation and kernel sizes # TODO: per-layer block size configuration if not backend.supports_block_size(block_size): from vllm.config import get_current_vllm_config diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index f2ac56a88a90..8c369c569473 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -32,7 +32,7 @@ reshape_and_cache_flash, ) from vllm.config import VllmConfig, get_layers_from_vllm_config -from vllm.config.cache import BlockSize, CacheDType +from vllm.config.cache import CacheDType from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( @@ -107,7 +107,7 @@ def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] @classmethod - def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: + def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: return [MultipleOf(16)] @classmethod @@ -122,16 +122,6 @@ def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: return flash_attn_supports_fp8() return kv_cache_dtype in ["auto"] - @classmethod - def supports_block_size(cls, block_size: BlockSize | None) -> bool: - if block_size is None: - return True - return block_size % 16 == 0 - - @classmethod - def get_default_block_size(cls) -> BlockSize: - return 16 - @classmethod def get_min_compute_capability(cls) -> DeviceCapability | None: return DeviceCapability(8, 0) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 
3766edc0b0ee..88eca4bb9266 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -23,7 +23,7 @@ MultipleOf, ) from vllm.config import CUDAGraphMode, VllmConfig -from vllm.config.cache import BlockSize, CacheDType +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, @@ -215,7 +215,7 @@ def get_supported_head_sizes(cls) -> list[int]: return [64, 128, 256] @classmethod - def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: + def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: # Note: Not sure for all platforms, # but on Blackwell, only support a page size of # 16, 32, 64 @@ -229,10 +229,6 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: return ["auto", "fp8", "fp8_e4m3", "fp8_e5m2"] - @classmethod - def get_supported_block_sizes(cls) -> list[BlockSize]: - return [] - @classmethod def get_min_compute_capability(cls) -> DeviceCapability | None: return DeviceCapability(7, 5) diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index d06f316a501e..f8eae5aa5ab2 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -24,7 +24,7 @@ is_quantized_kv_cache, ) from vllm.config import VllmConfig -from vllm.config.cache import BlockSize, CacheDType +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, @@ -115,10 +115,6 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: return ["auto"] - @classmethod - def get_supported_block_sizes(cls) -> list[BlockSize]: - return [] - # @torch.compile(fullgraph=True, mode="reduce-overhead") def physical_to_logical_mapping( diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index bcb17fe670e2..1500a55e2d74 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -13,7 +13,7 @@ MultipleOf, is_quantized_kv_cache, ) -from vllm.config.cache import BlockSize, CacheDType +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.mla.common import ( @@ -48,7 +48,7 @@ def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]: return CutlassMLAMetadataBuilder @classmethod - def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: + def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: return [128] @classmethod @@ -59,10 +59,6 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: return ["auto", "fp8", "fp8_e4m3"] - @classmethod - def get_supported_block_sizes(cls) -> list[BlockSize]: - return [128] - @classmethod def get_min_compute_capability(cls) -> DeviceCapability | None: return DeviceCapability(10, 0) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index c8c07e911aa4..e739ec4311a1 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import ( AttentionLayer, AttentionType, + 
MultipleOf, is_quantized_kv_cache, ) from vllm.attention.utils.fa_utils import ( @@ -17,7 +18,7 @@ get_flash_attn_version, ) from vllm.config import VllmConfig -from vllm.config.cache import BlockSize, CacheDType +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, @@ -60,18 +61,12 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: return [torch.float16, torch.bfloat16] @classmethod - def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: - return ["auto"] - - @classmethod - def supports_block_size(cls, block_size: BlockSize | None) -> bool: - if block_size is None: - return True - return block_size % 16 == 0 + def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: + return [MultipleOf(16)] @classmethod - def get_default_block_size(cls) -> BlockSize: - return 16 + def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: + return ["auto"] @classmethod def get_min_compute_capability(cls) -> DeviceCapability | None: diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index 84e2f062a693..fe0e88bd8266 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -6,7 +6,11 @@ import torch from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla -from vllm.attention.backends.abstract import AttentionLayer, AttentionType +from vllm.attention.backends.abstract import ( + AttentionLayer, + AttentionType, + MultipleOf, +) from vllm.config.cache import BlockSize, CacheDType from vllm.logger import init_logger from vllm.platforms.interface import DeviceCapability @@ -47,19 +51,17 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: return [torch.float16, torch.bfloat16] @classmethod - def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: - return ["auto", "fp8", "fp8_e4m3"] - - @classmethod - def supports_block_size(cls, block_size: BlockSize | None) -> bool: - if block_size is None: - return True - return (block_size == 32) or (block_size % 64 == 0) + def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: + return [32, MultipleOf(64)] @classmethod def get_default_block_size(cls) -> BlockSize: return 64 + @classmethod + def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: + return ["auto", "fp8", "fp8_e4m3"] + @classmethod def get_min_compute_capability(cls) -> DeviceCapability | None: return DeviceCapability(10, 0) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index da1d69423bd6..2faf74dc2752 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -13,7 +13,7 @@ is_flashmla_dense_supported, ) from vllm.config import VllmConfig -from vllm.config.cache import BlockSize, CacheDType +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, @@ -55,7 +55,7 @@ def get_impl_cls() -> type["FlashMLAImpl"]: return FlashMLAImpl @classmethod - def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: + def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: return [64] @classmethod @@ -66,10 +66,6 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: return ["auto", "fp8", "fp8_e4m3"] - @classmethod - def get_supported_block_sizes(cls) -> 
list[BlockSize]: - return [64] - @classmethod def get_min_compute_capability(cls) -> DeviceCapability | None: return DeviceCapability(9, 0) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index de03a2d70504..11612235ad13 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -11,6 +11,7 @@ AttentionBackend, AttentionLayer, AttentionMetadata, + MultipleOf, ) from vllm.attention.backends.utils import get_mla_dims from vllm.attention.ops.flashmla import ( @@ -19,7 +20,7 @@ get_mla_metadata, ) from vllm.config import VllmConfig -from vllm.config.cache import BlockSize, CacheDType +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability @@ -76,12 +77,12 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: return [torch.bfloat16] @classmethod - def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: - return ["auto", "fp8_ds_mla"] + def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: + return [64] @classmethod - def get_supported_block_sizes(cls) -> list[BlockSize]: - return [64] + def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: + return ["auto", "fp8_ds_mla"] @classmethod def is_sparse(cls) -> bool: diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 49009a939d0b..661787c3aa58 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -52,7 +52,7 @@ def get_kv_cache_stride_order() -> tuple[int, ...]: return (0, 1, 2) @classmethod - def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: + def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: return [64] diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 7d1fd99b32a5..5cc78eef2881 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -12,7 +12,7 @@ ) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_flash_attention import triton_attention -from vllm.config.cache import BlockSize, CacheDType +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, @@ -46,10 +46,6 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: return ["auto"] - @classmethod - def get_supported_block_sizes(cls) -> list[BlockSize]: - return [] - @classmethod def get_min_compute_capability(cls) -> DeviceCapability | None: return None diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 3963a75872d1..3e7a9b7f6e36 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -388,7 +388,7 @@ def get_kv_cache_shape( return (2, num_blocks, block_size, num_kv_heads, head_size) @classmethod - def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: + def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: return [MultipleOf(16)] diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index e0529b3642b8..8df5f5b04617 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ 
b/vllm/v1/attention/backends/tree_attn.py @@ -41,7 +41,7 @@ def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] @classmethod - def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: + def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: return [MultipleOf(16)] @staticmethod diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index f7236c02f717..31410b47509f 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -19,7 +19,7 @@ ) from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig -from vllm.config.cache import BlockSize, CacheDType +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, @@ -188,7 +188,7 @@ def supports_head_size(cls, head_size: int) -> bool: return head_size >= 32 @classmethod - def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: + def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: return [MultipleOf(16)] @classmethod @@ -199,10 +199,6 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: return ["auto", "fp8", "fp8_e4m3", "fp8_e5m2"] - @classmethod - def get_supported_block_sizes(cls) -> list[BlockSize]: - return [] - @classmethod def supports_sink(cls) -> bool: return True diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index f1f2a5c3f8a4..47847d6fa6f7 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -82,7 +82,7 @@ def get_supported_head_sizes(cls) -> list[int]: ] @classmethod - def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: + def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: return [MultipleOf(16)] @staticmethod diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b2d99a0ec69b..092a41827b72 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4157,7 +4157,7 @@ def _find_compatible_block_sizes( Raises: ValueError: If no compatible block size found """ - supported_block_size = backend_cls.get_supported_kernel_block_size() + supported_block_size = backend_cls.get_supported_kernel_block_sizes() compatible_sizes = [] for block_size in supported_block_size: From b18a19361e881038387aef4d433044492ec32cf3 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Fri, 24 Oct 2025 11:56:38 -0400 Subject: [PATCH 54/84] fix import Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index aa486148091b..dc3b61e8552e 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -126,6 +126,8 @@ def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool: @classmethod def supports_block_size(cls, block_size: int | None) -> bool: + from vllm.config.cache import BlockSize + if block_size is None: return True @@ -153,8 +155,6 @@ def supports_block_size(cls, block_size: int | None) -> bool: def get_default_block_size(cls) -> "BlockSize": from vllm.config.cache import BlockSize - valid_sizes = get_args(BlockSize) - supported_block_sizes = 
cls.get_supported_kernel_block_sizes() if not supported_block_sizes: raise ValueError( @@ -166,6 +166,7 @@ def get_default_block_size(cls) -> "BlockSize": if isinstance(block_size, MultipleOf): block_size = block_size.base + valid_sizes = get_args(BlockSize) if block_size not in valid_sizes: raise ValueError( f"Default block size {block_size} for backend {cls.get_name()} is not " From 0e0cb6d81e12adb15f24dddffff5a3abef787ee7 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Fri, 24 Oct 2025 17:34:44 -0400 Subject: [PATCH 55/84] raise error instead of implicitly changing backend Signed-off-by: Matthew Bonanni --- vllm/platforms/cuda.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 2b6431e04459..c703a4ccc180 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -289,17 +289,12 @@ def get_attn_backend_cls( except ImportError: invalid_reasons = ["ImportError"] if invalid_reasons: - logger.warning( - "Selected backend %s is not valid for this configuration. " - "Reason: %s", - selected_backend, - invalid_reasons, + raise ValueError( + f"Selected backend {selected_backend} is not valid for " + f"this configuration. Reason: {invalid_reasons}" ) else: - engine_version = "V1" if use_v1 else "V0" - logger.info( - "Using %s backend on %s engine.", selected_backend, engine_version - ) + logger.info("Using %s backend.", selected_backend) return backend_class_str # No selected backend or the selected backend is invalid, From 1eefe900fa35f68977fca8fb5e1c701386ad1508 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 27 Oct 2025 12:56:18 -0400 Subject: [PATCH 56/84] don't ignore block size Signed-off-by: Matthew Bonanni --- vllm/platforms/cuda.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index d72699305f93..e34f8c926a5d 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -232,7 +232,7 @@ def get_valid_backends( head_size, dtype, kv_cache_dtype, - None, # ignore block_size here, it will be adjusted if needed + block_size, use_mla, has_sink, use_sparse, @@ -279,7 +279,7 @@ def get_attn_backend_cls( head_size, dtype, kv_cache_dtype, - None, # ignore block_size here, it will be adjusted if needed + block_size, use_mla, has_sink, use_sparse, From 97bee0437127fc6d6fb4b36aaa239ff0992a3bf2 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 27 Oct 2025 13:19:46 -0400 Subject: [PATCH 57/84] move block_size update back to check_and_update_config Signed-off-by: Matthew Bonanni --- vllm/attention/selector.py | 21 ++------------------- vllm/platforms/cuda.py | 33 ++++++++++++++++++++++++++++++--- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 57aa528cef1d..52af311df3ac 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -69,7 +69,7 @@ def get_attn_backend( head_size: int, dtype: torch.dtype, kv_cache_dtype: str | None, - block_size: int, + block_size: int | None, use_mla: bool = False, has_sink: bool = False, use_sparse: bool = False, @@ -96,7 +96,7 @@ def _cached_get_attn_backend( head_size: int, dtype: torch.dtype, kv_cache_dtype: str | None, - block_size: int, + block_size: int | None, use_v1: bool = False, use_mla: bool = False, has_sink: bool = False, @@ -152,23 +152,6 @@ def _cached_get_attn_backend( ) backend = resolve_obj_by_qualname(attention_cls) - # Adjust block size if the selected backend doesn't 
support it - # The hybrid block table will handle mapping between allocation and kernel sizes - # TODO: per-layer block size configuration - if not backend.supports_block_size(block_size): - from vllm.config import get_current_vllm_config - - vllm_config = get_current_vllm_config() - if vllm_config and vllm_config.cache_config: - new_block_size = backend.get_default_block_size() - logger.info( - "Adjusting KV cache block size from %d to %d for %s backend.", - block_size, - new_block_size, - backend.get_name(), - ) - vllm_config.cache_config.block_size = new_block_size - # Adjust kv cache layout if the selected backend requires a specific one device_capability = current_platform.get_device_capability() required_layout = backend.get_required_kv_cache_layout(device_capability) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index e34f8c926a5d..313d5b131921 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -14,6 +14,7 @@ # import custom ops, trigger op registration import vllm._C # noqa +from vllm.attention.selector import get_attn_backend from vllm.logger import init_logger from vllm.utils.import_utils import import_pynvml, resolve_obj_by_qualname from vllm.utils.torch_utils import cuda_device_count_stateless @@ -148,8 +149,34 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" cache_config = vllm_config.cache_config - if cache_config and cache_config.block_size is None: - cache_config.block_size = 16 + model_config = vllm_config.model_config + + # Attempt to set an appropriate block size based on what backend will be used. + # TODO: per-layer block size configuration + if cache_config and model_config: + backend = get_attn_backend( + head_size=model_config.get_head_size(), + dtype=model_config.dtype, + kv_cache_dtype=cache_config.cache_dtype, + block_size=None, + use_mla=model_config.use_mla, + has_sink=False, # Model isn't loaded yet, can't determine this + use_sparse=hasattr(model_config.hf_config, "index_topk"), + ) + if cache_config.block_size and not backend.supports_block_size( + cache_config.block_size + ): + new_block_size = backend.get_default_block_size() + logger.info( + "Adjusting KV cache block size from %d to %d for %s backend.", + cache_config.block_size, + new_block_size, + backend.get_name(), + ) + cache_config.block_size = new_block_size + else: + if cache_config.block_size is None: + cache_config.block_size = 16 # default block size # lazy import to avoid circular import from vllm.config import CUDAGraphMode @@ -254,7 +281,7 @@ def get_attn_backend_cls( head_size: int, dtype: torch.dtype, kv_cache_dtype: str | None, - block_size: int, + block_size: int | None, use_v1: bool, use_mla: bool, has_sink: bool, From 0812fac45a4048967a5e597d7be4eb7764e11892 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 27 Oct 2025 16:18:29 -0400 Subject: [PATCH 58/84] fix import Signed-off-by: Matthew Bonanni --- vllm/platforms/cuda.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 313d5b131921..ef2ce846064b 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -14,7 +14,6 @@ # import custom ops, trigger op registration import vllm._C # noqa -from vllm.attention.selector import get_attn_backend from vllm.logger import init_logger from vllm.utils.import_utils import import_pynvml, resolve_obj_by_qualname from vllm.utils.torch_utils import cuda_device_count_stateless @@ -154,6 +153,8 @@ 
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: # Attempt to set an appropriate block size based on what backend will be used. # TODO: per-layer block size configuration if cache_config and model_config: + from vllm.attention.selector import get_attn_backend + backend = get_attn_backend( head_size=model_config.get_head_size(), dtype=model_config.dtype, From ec392470129f8fcc4d5da4d55901e51c699d038b Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 27 Oct 2025 17:47:14 -0400 Subject: [PATCH 59/84] address missing case Signed-off-by: Matthew Bonanni --- vllm/platforms/cuda.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index ef2ce846064b..194e258149b0 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -164,17 +164,26 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: has_sink=False, # Model isn't loaded yet, can't determine this use_sparse=hasattr(model_config.hf_config, "index_topk"), ) - if cache_config.block_size and not backend.supports_block_size( - cache_config.block_size - ): - new_block_size = backend.get_default_block_size() + backend_default = backend.get_default_block_size() + + if cache_config.block_size is None: + cache_config.block_size = backend_default + logger.info( + "Setting KV cache block size to %d for %s backend.", + backend_default, + backend.get_name(), + ) + elif not backend.supports_block_size(cache_config.block_size): logger.info( "Adjusting KV cache block size from %d to %d for %s backend.", cache_config.block_size, - new_block_size, + backend_default, backend.get_name(), ) - cache_config.block_size = new_block_size + cache_config.block_size = backend_default + else: + # user-specified block size is supported + pass else: if cache_config.block_size is None: cache_config.block_size = 16 # default block size From 860bfdbe6afe970e8b1e42fea6dbbccbc7c12111 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 28 Oct 2025 16:12:46 -0400 Subject: [PATCH 60/84] fix flashmla_sparse support Signed-off-by: Matthew Bonanni --- vllm/platforms/cuda.py | 1 + vllm/v1/attention/backends/mla/flashmla_sparse.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 194e258149b0..f656e2eac482 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -60,6 +60,7 @@ def _get_backend_priorities( _Backend.FLASHMLA: 2, _Backend.FLASH_ATTN_MLA: 3, _Backend.TRITON_MLA: 4, + _Backend.FLASHMLA_SPARSE: 5, } else: return { diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 5fe15c87cd39..aa0d6972635d 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -90,7 +90,7 @@ def is_sparse(cls) -> bool: @classmethod def get_min_compute_capability(cls) -> DeviceCapability | None: - return DeviceCapability(10, 0) + return DeviceCapability(9, 0) @classmethod def get_max_compute_capability(cls) -> DeviceCapability | None: From df1cd64db9d56795dc2b8c5c3982d395b03a17a0 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 28 Oct 2025 17:27:17 -0400 Subject: [PATCH 61/84] fix hybrid models Signed-off-by: Matthew Bonanni --- vllm/platforms/cuda.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index f656e2eac482..cf9b463bed82 100644 --- 
a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -151,9 +151,14 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: cache_config = vllm_config.cache_config model_config = vllm_config.model_config - # Attempt to set an appropriate block size based on what backend will be used. + # Attempt to set an appropriate block size based on what backend will + # be used. # TODO: per-layer block size configuration - if cache_config and model_config: + # Note: Hybrid models (models with both attention and mamba layers) + # have their block_size initialized in + # HybridAttentionMambaModelConfig.verify_and_update_config, + # which is called before this method. We should not override it here. + if cache_config and model_config and not model_config.is_hybrid: from vllm.attention.selector import get_attn_backend backend = get_attn_backend( @@ -186,7 +191,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: # user-specified block size is supported pass else: - if cache_config.block_size is None: + if cache_config and cache_config.block_size is None: cache_config.block_size = 16 # default block size # lazy import to avoid circular import @@ -317,7 +322,7 @@ def get_attn_backend_cls( head_size, dtype, kv_cache_dtype, - block_size, + None, use_mla, has_sink, use_sparse, @@ -340,7 +345,7 @@ def get_attn_backend_cls( head_size, dtype, kv_cache_dtype, - block_size, + None, use_mla, has_sink, use_sparse, From 01b43ff324dba2880a7d9e12be330eab7823055e Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 29 Oct 2025 11:29:07 -0400 Subject: [PATCH 62/84] return only mla or non-mla priorities Signed-off-by: Matthew Bonanni --- vllm/platforms/cuda.py | 71 ++++++++++++++++++++++++------------------ 1 file changed, 41 insertions(+), 30 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index cf9b463bed82..e755b046fe9e 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -40,41 +40,52 @@ @cache def _get_backend_priorities( + use_mla: bool, device_capability: DeviceCapability | None = None, ) -> dict[_Backend, int]: """Get backend priorities with lazy import to avoid circular dependency.""" from vllm.attention.backends.registry import _Backend - if device_capability >= DeviceCapability(10, 0) and ( - device_capability < DeviceCapability(11, 0) - ): - return { - # non-MLA backends - _Backend.FLASHINFER: 0, - _Backend.FLASH_ATTN: 1, - _Backend.TRITON_ATTN: 2, - _Backend.FLEX_ATTENTION: 3, - # MLA backends - _Backend.CUTLASS_MLA: 0, - _Backend.FLASHINFER_MLA: 1, - _Backend.FLASHMLA: 2, - _Backend.FLASH_ATTN_MLA: 3, - _Backend.TRITON_MLA: 4, - _Backend.FLASHMLA_SPARSE: 5, - } + if use_mla: + if ( + device_capability + and device_capability >= DeviceCapability(10, 0) + and device_capability < DeviceCapability(11, 0) + ): + return { + _Backend.CUTLASS_MLA: 0, + _Backend.FLASHINFER_MLA: 1, + _Backend.FLASHMLA: 2, + _Backend.FLASH_ATTN_MLA: 3, + _Backend.TRITON_MLA: 4, + _Backend.FLASHMLA_SPARSE: 5, + } + else: + return { + _Backend.FLASHMLA: 0, + _Backend.FLASH_ATTN_MLA: 1, + _Backend.FLASHINFER_MLA: 2, + _Backend.TRITON_MLA: 3, + } else: - return { - # non-MLA backends - _Backend.FLASH_ATTN: 0, - _Backend.FLASHINFER: 1, - _Backend.TRITON_ATTN: 2, - _Backend.FLEX_ATTENTION: 3, - # MLA backends - _Backend.FLASHMLA: 0, - _Backend.FLASH_ATTN_MLA: 1, - _Backend.FLASHINFER_MLA: 2, - _Backend.TRITON_MLA: 3, - } + if ( + device_capability + and device_capability >= DeviceCapability(10, 0) + and device_capability < DeviceCapability(11, 0) + ): + 
return { + _Backend.FLASHINFER: 0, + _Backend.FLASH_ATTN: 1, + _Backend.TRITON_ATTN: 2, + _Backend.FLEX_ATTENTION: 3, + } + else: + return { + _Backend.FLASH_ATTN: 0, + _Backend.FLASHINFER: 1, + _Backend.TRITON_ATTN: 2, + _Backend.FLEX_ATTENTION: 3, + } def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: @@ -264,7 +275,7 @@ def get_valid_backends( invalid_reasons = {} from vllm.attention.backends.registry import _Backend, backend_to_class - backend_priorities = _get_backend_priorities(device_capability) + backend_priorities = _get_backend_priorities(use_mla, device_capability) for backend in _Backend: if backend not in backend_priorities: continue From ee894eaa4915c2a4508906e332b40c0c33c11c2e Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 29 Oct 2025 11:31:02 -0400 Subject: [PATCH 63/84] cleanup Signed-off-by: Matthew Bonanni --- vllm/platforms/cuda.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index e755b046fe9e..272d98be4da3 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -273,13 +273,10 @@ def get_valid_backends( ) -> tuple[list[tuple["_Backend", int]], dict["_Backend", list[str]]]: valid_backends_priorities = [] invalid_reasons = {} - from vllm.attention.backends.registry import _Backend, backend_to_class + from vllm.attention.backends.registry import backend_to_class backend_priorities = _get_backend_priorities(use_mla, device_capability) - for backend in _Backend: - if backend not in backend_priorities: - continue - priority = backend_priorities[backend] + for backend, priority in backend_priorities.items(): try: backend_class = backend_to_class(backend) invalid_reasons_i = backend_class.validate_configuration( From 842e89b098f0d9daeaf90c3e814359f74e017af1 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 29 Oct 2025 12:33:45 -0400 Subject: [PATCH 64/84] skip test on hopper Signed-off-by: Matthew Bonanni --- tests/v1/attention/test_sparse_mla_backends.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py index 02324d2aca6e..f289f3c876ec 100644 --- a/tests/v1/attention/test_sparse_mla_backends.py +++ b/tests/v1/attention/test_sparse_mla_backends.py @@ -116,6 +116,8 @@ def _quantize_dequantize_fp8_ds_mla( def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype): if not torch.cuda.is_available(): pytest.skip("CUDA is required for sparse MLA decode test") + if torch.cuda.get_device_capability()[0] != 10: + pytest.skip("Sparse MLA only supported on SM 10 devices") device = torch.device("cuda") dtype = torch.bfloat16 From bd190e761e89c6d60d02ae4a12732f0d91a56df2 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 29 Oct 2025 18:05:26 +0000 Subject: [PATCH 65/84] temp: apply fixes for test Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 4 +++- vllm/attention/selector.py | 17 +---------------- vllm/platforms/cuda.py | 5 ++++- vllm/v1/attention/backends/utils.py | 3 ++- 4 files changed, 10 insertions(+), 19 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index dc3b61e8552e..4cf9d2cee372 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -196,7 +196,9 @@ def get_max_compute_capability(cls) -> "DeviceCapability | None": return None @classmethod - def supports_compute_capability(cls, capability: "DeviceCapability") -> bool: 
+ def supports_compute_capability(cls, capability: "DeviceCapability | None") -> bool: + if capability is None: + return True min_capability = cls.get_min_compute_capability() max_capability = cls.get_max_compute_capability() return ((min_capability is None) or (capability >= min_capability)) and ( diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 52af311df3ac..5a35b67d8441 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -150,22 +150,7 @@ def _cached_get_attn_backend( raise ValueError( f"Invalid attention backend for {current_platform.device_name}" ) - backend = resolve_obj_by_qualname(attention_cls) - - # Adjust kv cache layout if the selected backend requires a specific one - device_capability = current_platform.get_device_capability() - required_layout = backend.get_required_kv_cache_layout(device_capability) - if required_layout is not None: - from vllm.v1.attention.backends.utils import set_kv_cache_layout - - set_kv_cache_layout(required_layout) - logger.info( - "Using %s KV cache layout for %s backend.", - required_layout, - backend.get_name(), - ) - - return backend + return resolve_obj_by_qualname(attention_cls) @contextmanager diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 272d98be4da3..90f13874d07e 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -310,6 +310,7 @@ def get_attn_backend_cls( use_mla: bool, has_sink: bool, use_sparse: bool, + device_capability: "DeviceCapability | None" = None, ) -> str: if not use_v1: raise RuntimeError( @@ -319,7 +320,9 @@ def get_attn_backend_cls( from vllm.attention.backends.registry import backend_to_class_str - device_capability = cls.get_device_capability() + # Don't get device capability here to avoid early CUDA init. + # The validation functions can handle None for device_capability, + # and it will be retrieved later when actually needed. # First try checking just the selected backend, if there is one. if selected_backend is not None: diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 389baf1488be..fb4a4bb4f7be 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -379,9 +379,10 @@ def get_kv_cache_layout(): return cache_layout -def set_kv_cache_layout(cache_layout: KVCacheLayoutType): +def set_kv_cache_layout(cache_layout: KVCacheLayoutType | None): global _KV_CACHE_LAYOUT_OVERRIDE _KV_CACHE_LAYOUT_OVERRIDE = cache_layout + get_kv_cache_layout.cache_clear() @dataclass From 5bf94f69436f6c96eda4560f59163b6e335da7de Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 29 Oct 2025 18:23:27 +0000 Subject: [PATCH 66/84] Revert "skip test on hopper" This reverts commit 842e89b098f0d9daeaf90c3e814359f74e017af1. 
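(For reference, a minimal standalone sketch of the capability-range semantics the backend-validation hunks above rely on, where a None bound — and, in the temporary fix above, a None capability — is treated as unconstrained. DeviceCapability is re-declared here purely for illustration and is not the vLLM class.)

from dataclasses import dataclass

@dataclass(order=True, frozen=True)
class DeviceCapability:
    major: int
    minor: int

def supports_compute_capability(capability, min_cap=None, max_cap=None):
    # An unknown device or a missing bound imposes no constraint.
    if capability is None:
        return True
    return (min_cap is None or capability >= min_cap) and (
        max_cap is None or capability <= max_cap
    )

# A backend restricted to SM 10.x accepts (10, 0) but rejects (9, 0).
assert supports_compute_capability(
    DeviceCapability(10, 0),
    min_cap=DeviceCapability(10, 0),
    max_cap=DeviceCapability(10, 9),
)
assert not supports_compute_capability(
    DeviceCapability(9, 0), min_cap=DeviceCapability(10, 0)
)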
Signed-off-by: Matthew Bonanni --- tests/v1/attention/test_sparse_mla_backends.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py index f289f3c876ec..02324d2aca6e 100644 --- a/tests/v1/attention/test_sparse_mla_backends.py +++ b/tests/v1/attention/test_sparse_mla_backends.py @@ -116,8 +116,6 @@ def _quantize_dequantize_fp8_ds_mla( def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype): if not torch.cuda.is_available(): pytest.skip("CUDA is required for sparse MLA decode test") - if torch.cuda.get_device_capability()[0] != 10: - pytest.skip("Sparse MLA only supported on SM 10 devices") device = torch.device("cuda") dtype = torch.bfloat16 From 7e34939be2ffea90fb7349be74a6886d554e8e8c Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 29 Oct 2025 18:59:45 +0000 Subject: [PATCH 67/84] revert to old check_and_update_config block_size logic Signed-off-by: Matthew Bonanni --- vllm/platforms/cuda.py | 106 ++++++++++++++++++++++++++--------------- 1 file changed, 68 insertions(+), 38 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 90f13874d07e..89e87e29a6a0 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -14,6 +14,7 @@ # import custom ops, trigger op registration import vllm._C # noqa +import vllm.envs as envs from vllm.logger import init_logger from vllm.utils.import_utils import import_pynvml, resolve_obj_by_qualname from vllm.utils.torch_utils import cuda_device_count_stateless @@ -155,56 +156,85 @@ def log_warnings(cls): @classmethod def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: parallel_config = vllm_config.parallel_config + model_config = vllm_config.model_config if parallel_config.worker_cls == "auto": parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" cache_config = vllm_config.cache_config - model_config = vllm_config.model_config + if cache_config and cache_config.block_size is None: + cache_config.block_size = 16 + + # TODO(lucas): handle this more gracefully + # Note: model_config may be None during testing + # Note: block_size is initialized in + # HybridAttentionMambaModelConfig.verify_and_update_config + # for models with both attention and mamba, + # and doesn't need to be reinitialized here + if ( + model_config is not None + and model_config.use_mla + and cache_config.block_size is not None + ): + use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk") + # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, + # then we default to FlashMLA backend for non-blackwell GPUs, + # else we default to CutlassMLA. For each case, we force the + # required block_size. + use_flashmla = False + use_cutlass_mla = False + use_flashinfer_mla = False + + if envs.VLLM_ATTENTION_BACKEND is None: + # Default case + if cls.is_device_capability(100): + # Blackwell => Force CutlassMLA. + use_cutlass_mla = True + # TODO: This does not work, because the + # global_force_attn_backend_context_manager is not set. + # See vllm/attention/selector.py:_cached_get_attn_backend + envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA" + else: + # Not Blackwell + use_flashmla = True + else: + # Forced case + use_flashmla = envs.VLLM_ATTENTION_BACKEND == "FLASHMLA" + use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA" + use_flashinfer_mla = envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA" - # Attempt to set an appropriate block size based on what backend will - # be used. 
- # TODO: per-layer block size configuration - # Note: Hybrid models (models with both attention and mamba layers) - # have their block_size initialized in - # HybridAttentionMambaModelConfig.verify_and_update_config, - # which is called before this method. We should not override it here. - if cache_config and model_config and not model_config.is_hybrid: - from vllm.attention.selector import get_attn_backend - - backend = get_attn_backend( - head_size=model_config.get_head_size(), - dtype=model_config.dtype, - kv_cache_dtype=cache_config.cache_dtype, - block_size=None, - use_mla=model_config.use_mla, - has_sink=False, # Model isn't loaded yet, can't determine this - use_sparse=hasattr(model_config.hf_config, "index_topk"), - ) - backend_default = backend.get_default_block_size() + from vllm.attention.ops.flashmla import is_flashmla_dense_supported - if cache_config.block_size is None: - cache_config.block_size = backend_default + if ( + use_flashmla + and is_flashmla_dense_supported()[0] + and cache_config.block_size % 64 != 0 + ): + cache_config.block_size = 64 + logger.info("Forcing kv cache block size to 64 for FlashMLA backend.") + + if use_cutlass_mla and cache_config.block_size % 128 != 0: + cache_config.block_size = 128 logger.info( - "Setting KV cache block size to %d for %s backend.", - backend_default, - backend.get_name(), + "Forcing kv cache block size to 128 for CUTLASS_MLA backend." ) - elif not backend.supports_block_size(cache_config.block_size): + + if ( + use_flashinfer_mla + and cache_config.block_size != 32 + and cache_config.block_size % 64 != 0 + ): + cache_config.block_size = 64 logger.info( - "Adjusting KV cache block size from %d to %d for %s backend.", - cache_config.block_size, - backend_default, - backend.get_name(), + "Forcing kv cache block size to 64 for FlashInferMLA backend." ) - cache_config.block_size = backend_default - else: - # user-specified block size is supported - pass - else: - if cache_config and cache_config.block_size is None: - cache_config.block_size = 16 # default block size + # TODO(Chen): remove this hacky code + if use_sparse and cache_config.block_size != 64: + cache_config.block_size = 64 + logger.info( + "Forcing kv cache block size to 64 for FlashMLASparse backend." + ) # lazy import to avoid circular import from vllm.config import CUDAGraphMode From 3b1e92f71afc4d2cef880f839f8e5756b522edbd Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 29 Oct 2025 16:22:41 -0400 Subject: [PATCH 68/84] Revert "temp: apply fixes for test" This reverts commit cdf907d096b34944b002e1e58f22028c71c1d963. 
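(As a side note, a minimal sketch of how a kernel block-size spec such as [32, MultipleOf(64)] is matched by the generic supports_block_size check introduced earlier in the series; the names below are standalone stand-ins, not the vLLM classes.)

from dataclasses import dataclass

@dataclass(frozen=True)
class MultipleOf:
    base: int

def supports_block_size(block_size, supported=(32, MultipleOf(64))):
    # None means "defer to the backend's default block size".
    if block_size is None:
        return True
    for spec in supported:
        base = spec.base if isinstance(spec, MultipleOf) else spec
        if block_size % base == 0:
            return True
    return False

assert supports_block_size(None) and supports_block_size(32) and supports_block_size(128)
assert not supports_block_size(48)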
Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 4 +--- vllm/attention/selector.py | 17 ++++++++++++++++- vllm/platforms/cuda.py | 5 +---- vllm/v1/attention/backends/utils.py | 3 +-- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 4cf9d2cee372..dc3b61e8552e 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -196,9 +196,7 @@ def get_max_compute_capability(cls) -> "DeviceCapability | None": return None @classmethod - def supports_compute_capability(cls, capability: "DeviceCapability | None") -> bool: - if capability is None: - return True + def supports_compute_capability(cls, capability: "DeviceCapability") -> bool: min_capability = cls.get_min_compute_capability() max_capability = cls.get_max_compute_capability() return ((min_capability is None) or (capability >= min_capability)) and ( diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 5a35b67d8441..52af311df3ac 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -150,7 +150,22 @@ def _cached_get_attn_backend( raise ValueError( f"Invalid attention backend for {current_platform.device_name}" ) - return resolve_obj_by_qualname(attention_cls) + backend = resolve_obj_by_qualname(attention_cls) + + # Adjust kv cache layout if the selected backend requires a specific one + device_capability = current_platform.get_device_capability() + required_layout = backend.get_required_kv_cache_layout(device_capability) + if required_layout is not None: + from vllm.v1.attention.backends.utils import set_kv_cache_layout + + set_kv_cache_layout(required_layout) + logger.info( + "Using %s KV cache layout for %s backend.", + required_layout, + backend.get_name(), + ) + + return backend @contextmanager diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 89e87e29a6a0..e818b7008b66 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -340,7 +340,6 @@ def get_attn_backend_cls( use_mla: bool, has_sink: bool, use_sparse: bool, - device_capability: "DeviceCapability | None" = None, ) -> str: if not use_v1: raise RuntimeError( @@ -350,9 +349,7 @@ def get_attn_backend_cls( from vllm.attention.backends.registry import backend_to_class_str - # Don't get device capability here to avoid early CUDA init. - # The validation functions can handle None for device_capability, - # and it will be retrieved later when actually needed. + device_capability = cls.get_device_capability() # First try checking just the selected backend, if there is one. 
if selected_backend is not None: diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index fb4a4bb4f7be..389baf1488be 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -379,10 +379,9 @@ def get_kv_cache_layout(): return cache_layout -def set_kv_cache_layout(cache_layout: KVCacheLayoutType | None): +def set_kv_cache_layout(cache_layout: KVCacheLayoutType): global _KV_CACHE_LAYOUT_OVERRIDE _KV_CACHE_LAYOUT_OVERRIDE = cache_layout - get_kv_cache_layout.cache_clear() @dataclass From d34eb77d299bfd1566eba8f3c3455701748b0cab Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Fri, 31 Oct 2025 08:52:09 -0400 Subject: [PATCH 69/84] add test_attention_selector to Blackwell Tests Signed-off-by: Matthew Bonanni --- .buildkite/test-pipeline.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 339e3aab6c03..a94d125a6676 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -873,11 +873,16 @@ steps: - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/v1/attention/backends/flashinfer.py + - vllm/v1/attention/backends/mla/cutlass_mla.py + - vllm/v1/attention/backends/mla/flashinfer_mla.py + - vllm/platforms/cuda.py + - vllm/attention/selector.py commands: - nvidia-smi - python3 examples/offline_inference/basic/chat.py # Attention # num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353 + - pytest -v -s tests/kernels/attention/test_attention_selector.py - pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2' - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py - pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py From 48290ee273128af2cfffecf1f614e863575695e3 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Fri, 31 Oct 2025 10:56:37 -0400 Subject: [PATCH 70/84] rename _Backend to AttentionBackendEnum, add class methods Signed-off-by: Matthew Bonanni --- tests/compile/test_fusion_attn.py | 30 +-- tests/compile/test_fusions_e2e.py | 24 +- tests/config/test_multimodal_config.py | 6 +- tests/kernels/attention/test_mha_attn.py | 12 +- tests/kernels/utils.py | 4 +- tests/v1/attention/test_attention_backends.py | 47 ++-- tests/v1/attention/test_mla_backends.py | 21 +- tests/v1/attention/utils.py | 10 +- tests/v1/spec_decode/test_eagle.py | 18 +- tests/v1/spec_decode/test_mtp.py | 6 +- tests/v1/spec_decode/test_tree_attention.py | 8 +- vllm/attention/backends/registry.py | 208 +++++++++++------- vllm/attention/layer.py | 68 +++--- vllm/attention/selector.py | 26 ++- vllm/config/model.py | 8 +- vllm/config/multimodal.py | 29 +-- .../kv_connector/v1/nixl_connector.py | 8 +- vllm/engine/arg_utils.py | 4 +- vllm/envs.py | 6 +- vllm/model_executor/models/dots_ocr.py | 37 ++-- vllm/model_executor/models/ernie45_vl.py | 37 ++-- vllm/model_executor/models/glm4_1v.py | 35 +-- vllm/model_executor/models/keye.py | 28 ++- vllm/model_executor/models/ovis2_5.py | 6 +- vllm/model_executor/models/qwen2_5_vl.py | 41 ++-- vllm/model_executor/models/qwen2_vl.py | 37 ++-- .../models/qwen3_omni_moe_thinker.py | 15 +- vllm/model_executor/models/qwen3_vl.py | 26 +-- vllm/model_executor/models/siglip2navit.py | 26 +-- vllm/model_executor/models/vision.py | 8 +- vllm/platforms/cpu.py | 12 +- vllm/platforms/cuda.py | 84 +++---- vllm/platforms/interface.py | 14 +- 
vllm/platforms/rocm.py | 46 ++-- vllm/platforms/tpu.py | 12 +- vllm/platforms/xpu.py | 18 +- vllm/v1/spec_decode/eagle.py | 4 +- 37 files changed, 560 insertions(+), 469 deletions(-) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 7f510c221e93..ea61c94953a7 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -10,7 +10,7 @@ from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.attention import Attention, AttentionMetadata -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fx_utils import find_op_nodes @@ -104,7 +104,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: # TODO(luka) use get_kv_cache_stride_order # Create dummy KV cache for the selected backend - if backend == _Backend.ROCM_ATTN: + if backend == AttentionBackendEnum.ROCM_ATTN: # k/v as 1st dimention # HND: [num_blocks, num_kv_heads, block_size, head_size] kv_cache = torch.zeros( @@ -116,7 +116,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: dtype=self.kv_cache_dtype, device=self.device, ) - elif backend == _Backend.ROCM_AITER_UNIFIED_ATTN: + elif backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN: # k/v as 1st dimention # NHD: [num_blocks, block_size, num_kv_heads, head_size] kv_cache = torch.zeros( @@ -128,7 +128,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: dtype=self.kv_cache_dtype, device=self.device, ) - elif backend == _Backend.TRITON_ATTN: + elif backend == AttentionBackendEnum.TRITON_ATTN: # k/v as 2nd dimention # NHD: [num_blocks, block_size, num_kv_heads, head_size] kv_cache = torch.zeros( @@ -140,7 +140,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: dtype=self.kv_cache_dtype, device=self.device, ) - elif backend == _Backend.FLASHINFER: + elif backend == AttentionBackendEnum.FLASHINFER: kv_cache = torch.zeros( num_blocks, 2, @@ -244,8 +244,8 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): MODELS_FP4: list[tuple[str, type]] = [] HEADS: list[tuple[int, int]] = [] SPLIT_ATTENTION: list[bool] = [] -BACKENDS_FP8: list[_Backend] = [] -BACKENDS_FP4: list[_Backend] = [] +BACKENDS_FP8: list[AttentionBackendEnum] = [] +BACKENDS_FP4: list[AttentionBackendEnum] = [] if current_platform.is_cuda(): HEADS = [(64, 8), (40, 8)] @@ -261,8 +261,8 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): TestAttentionNvfp4QuantPatternModel, ) ] - BACKENDS_FP8 = [_Backend.TRITON_ATTN, _Backend.FLASHINFER] - BACKENDS_FP4 = [_Backend.FLASHINFER] + BACKENDS_FP8 = [AttentionBackendEnum.TRITON_ATTN, AttentionBackendEnum.FLASHINFER] + BACKENDS_FP4 = [AttentionBackendEnum.FLASHINFER] elif current_platform.is_rocm(): HEADS = [(32, 8), (40, 8)] @@ -270,9 +270,9 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel) ] BACKENDS = [ - _Backend.ROCM_AITER_UNIFIED_ATTN, - _Backend.ROCM_ATTN, - _Backend.TRITON_ATTN, + AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN, + AttentionBackendEnum.ROCM_ATTN, + AttentionBackendEnum.TRITON_ATTN, ] @@ -302,11 +302,11 @@ def test_attention_quant_pattern( custom_ops: str, 
model_name: str, model_class: type[AttentionQuantPatternModel], - backend: _Backend, + backend: AttentionBackendEnum, dist_init, ): """Test AttentionStaticQuantPattern fusion pass""" - if backend == _Backend.FLASHINFER and ( + if backend == AttentionBackendEnum.FLASHINFER and ( not current_platform.is_device_capability((10, 0)) or not has_flashinfer() ): pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") @@ -403,7 +403,7 @@ def test_attention_quant_pattern( result_fused_1 = model_compiled(q, k, v) - if backend == _Backend.FLASHINFER: + if backend == AttentionBackendEnum.FLASHINFER: # With the Flashinfer backend after the 1st round of the forward # pass, output quant scale should be loaded into the attn layer's # _o_scale_float, the 2nd round should reuse the loaded diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index d66c60ccb5b2..bbc545348c67 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -11,7 +11,7 @@ import pytest import regex as re -from tests.v1.attention.utils import _Backend +from tests.v1.attention.utils import AttentionBackendEnum from vllm import LLM, SamplingParams from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig from vllm.platforms import current_platform @@ -24,7 +24,7 @@ class ModelBackendTestCase(NamedTuple): model_name: str model_kwargs: dict[str, Any] - backend: _Backend + backend: AttentionBackendEnum attention_fusions: int allreduce_fusions: int | None = None @@ -39,14 +39,14 @@ class ModelBackendTestCase(NamedTuple): # Use smaller model for L40s in CI model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", model_kwargs=dict(max_model_len=1024), - backend=_Backend.TRITON_ATTN, + backend=AttentionBackendEnum.TRITON_ATTN, attention_fusions=32, allreduce_fusions=65, ), ModelBackendTestCase( model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), - backend=_Backend.FLASHINFER, + backend=AttentionBackendEnum.FLASHINFER, attention_fusions=48, allreduce_fusions=96, ), @@ -56,7 +56,7 @@ class ModelBackendTestCase(NamedTuple): ModelBackendTestCase( model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), - backend=_Backend.FLASHINFER, + backend=AttentionBackendEnum.FLASHINFER, attention_fusions=48, allreduce_fusions=96, ), @@ -67,7 +67,7 @@ class ModelBackendTestCase(NamedTuple): ModelBackendTestCase( model_name="meta-llama/Llama-3.1-8B-Instruct", model_kwargs=dict(max_model_len=1024), - backend=_Backend.TRITON_ATTN, + backend=AttentionBackendEnum.TRITON_ATTN, attention_fusions=0, allreduce_fusions=65, ), @@ -78,19 +78,19 @@ class ModelBackendTestCase(NamedTuple): ModelBackendTestCase( model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_kwargs=dict(max_model_len=1024), - backend=_Backend.TRITON_ATTN, + backend=AttentionBackendEnum.TRITON_ATTN, attention_fusions=32, ), ModelBackendTestCase( model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_kwargs=dict(max_model_len=1024), - backend=_Backend.ROCM_ATTN, + backend=AttentionBackendEnum.ROCM_ATTN, attention_fusions=32, ), ModelBackendTestCase( model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_kwargs=dict(max_model_len=1024), - backend=_Backend.ROCM_AITER_UNIFIED_ATTN, + backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN, attention_fusions=32, ), ] @@ -111,7 +111,7 @@ class ModelBackendTestCase(NamedTuple): def test_attn_quant( model_name: str, model_kwargs: dict[str, Any], - 
backend: _Backend, + backend: AttentionBackendEnum, attention_fusions: int, allreduce_fusions: int, custom_ops: str, @@ -119,7 +119,7 @@ def test_attn_quant( caplog_mp_spawn, monkeypatch, ): - if backend == _Backend.FLASHINFER and ( + if backend == AttentionBackendEnum.FLASHINFER and ( not current_platform.is_device_capability((10, 0)) or not has_flashinfer() ): pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") @@ -203,7 +203,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: def test_tp2_attn_quant_allreduce_rmsnorm( model_name: str, model_kwargs: dict, - backend: _Backend, + backend: AttentionBackendEnum, attention_fusions: int, allreduce_fusions: int, custom_ops: str, diff --git a/tests/config/test_multimodal_config.py b/tests/config/test_multimodal_config.py index b1a09d88ed9d..3d02893e52f1 100644 --- a/tests/config/test_multimodal_config.py +++ b/tests/config/test_multimodal_config.py @@ -3,13 +3,13 @@ import pytest -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config.multimodal import MultiModalConfig def test_mm_encoder_attn_backend_str_conversion(): config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN") - assert config.mm_encoder_attn_backend == _Backend.FLASH_ATTN + assert config.mm_encoder_attn_backend == AttentionBackendEnum.FLASH_ATTN def test_mm_encoder_attn_backend_invalid(): @@ -20,6 +20,6 @@ def test_mm_encoder_attn_backend_invalid(): def test_mm_encoder_attn_backend_hash_updates(): base_hash = MultiModalConfig().compute_hash() overridden_hash = MultiModalConfig( - mm_encoder_attn_backend=_Backend.FLASH_ATTN + mm_encoder_attn_backend=AttentionBackendEnum.FLASH_ATTN ).compute_hash() assert base_hash != overridden_hash diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py index 14d1618bca3c..183bbf3bf4e0 100644 --- a/tests/kernels/attention/test_mha_attn.py +++ b/tests/kernels/attention/test_mha_attn.py @@ -11,7 +11,7 @@ import pytest import torch -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import MultiHeadAttention from vllm.attention.selector import _cached_get_attn_backend from vllm.platforms import current_platform @@ -43,14 +43,14 @@ def test_mha_attn_platform(device: str): patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()), ): attn = MultiHeadAttention(16, 64, scale=1) - assert attn.attn_backend == _Backend.TORCH_SDPA + assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA elif device == "hip": with ( patch("vllm.attention.layer.current_platform", RocmPlatform()), patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()), ): attn = MultiHeadAttention(16, 64, scale=1) - assert attn.attn_backend == _Backend.TORCH_SDPA + assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA else: # Test CUDA with head_size=64 (divisible by 32) # - should use vLLM's FlashAttention @@ -59,7 +59,7 @@ def test_mha_attn_platform(device: str): patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()), ): attn = MultiHeadAttention(16, 64, scale=1) - assert attn.attn_backend == _Backend.FLASH_ATTN + assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN # Test CUDA with head_size=72 (not divisible by 32) # - with upstream FA not available @@ -73,7 +73,7 @@ def test_mha_attn_platform(device: str): ), ): attn = MultiHeadAttention(16, 72, 
scale=1) - assert attn.attn_backend == _Backend.XFORMERS + assert attn.attn_backend == AttentionBackendEnum.XFORMERS # Test CUDA with head_size=72 (not divisible by 32) # - with upstream FA available @@ -96,7 +96,7 @@ def test_mha_attn_platform(device: str): ), ): attn = MultiHeadAttention(16, 72, scale=1) - assert attn.attn_backend == _Backend.FLASH_ATTN + assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN def ref_attention( diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index eb00bc72b4b0..1befd644202d 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -15,7 +15,7 @@ from tests.kernels.quant_utils import native_w8a8_block_matmul from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.utils import ( @@ -878,7 +878,7 @@ def make_block_tables_slot_mapping( def make_test_metadata( - attn_backend: _Backend, + attn_backend: AttentionBackendEnum, is_prompt: bool, seq_lens: list[int] | None, decoder_test_params: PhaseTestParameters | None, diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 6659b3eb1e98..b83105428289 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -15,7 +15,7 @@ create_vllm_config, try_get_attention_backend, ) -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ModelConfig from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv @@ -27,11 +27,11 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec BACKENDS_TO_TEST = [ - _Backend.FLASH_ATTN, - _Backend.FLASHINFER, - _Backend.FLEX_ATTENTION, - _Backend.TRITON_ATTN, - _Backend.TREE_ATTN, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.FLASHINFER, + AttentionBackendEnum.FLEX_ATTENTION, + AttentionBackendEnum.TRITON_ATTN, + AttentionBackendEnum.TREE_ATTN, "FLEX_ATTENTION_SLOW", ] @@ -39,7 +39,7 @@ try: import flashinfer # noqa: F401 except ImportError: - BACKENDS_TO_TEST.remove(_Backend.FLASHINFER) + BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHINFER) def _convert_dtype_to_torch(dtype): @@ -192,7 +192,7 @@ def __init__(self, device: torch.device): def run_attention_backend( - backend: _Backend, + backend: AttentionBackendEnum, kv_cache_spec: FullAttentionSpec, layer_names: list[str], vllm_config, @@ -211,13 +211,13 @@ def run_attention_backend( use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0") if backend == "FLEX_ATTENTION_SLOW": - actual_backend = _Backend.FLEX_ATTENTION + actual_backend = AttentionBackendEnum.FLEX_ATTENTION use_direct_block_mask = False builder_cls, impl_cls = try_get_attention_backend(actual_backend) # Mock flashinfer's get_per_layer_parameters if needed - if actual_backend == _Backend.FLASHINFER: + if actual_backend == AttentionBackendEnum.FLASHINFER: import unittest.mock from vllm.v1.attention.backends.utils import PerLayerParameters @@ -246,7 +246,7 @@ def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls): else: # Build metadata builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) - if actual_backend == _Backend.FLEX_ATTENTION: + if actual_backend == 
AttentionBackendEnum.FLEX_ATTENTION: builder.direct_build = use_direct_block_mask attn_metadata = builder.build( common_prefix_len=0, @@ -289,7 +289,7 @@ def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls): def _test_backend_correctness( batch_spec: BatchSpec, model: str, - backend_to_test: list[_Backend | str], + backend_to_test: list[AttentionBackendEnum | str], mask_mod, *, block_size: int = 16, @@ -429,17 +429,20 @@ def _test_backend_correctness( # Select the appropriate KV cache format for each backend kv_cache_for_backend = kv_cache reset_kv_cache_layout = False - if backend_name in (_Backend.FLASHINFER, _Backend.TRITON_ATTN): + if backend_name in ( + AttentionBackendEnum.FLASHINFER, + AttentionBackendEnum.TRITON_ATTN, + ): kv_cache_for_backend = kv_cache.transpose(0, 1) - if backend_name == _Backend.FLASHINFER: + if backend_name == AttentionBackendEnum.FLASHINFER: # For FlashInfer default to HND layout and kv_cache_for_backend = ( kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3) ) set_kv_cache_layout("HND") reset_kv_cache_layout = True - elif backend_name == _Backend.TRITON_ATTN: + elif backend_name == AttentionBackendEnum.TRITON_ATTN: kv_cache_for_backend = kv_cache_for_backend.contiguous() try: @@ -518,7 +521,9 @@ def causal_mask_mod( batch_spec = BATCH_SPECS[batch_spec_name] LARGE_BLOCK_BACKENDS = ( - [_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else [] + [AttentionBackendEnum.FLEX_ATTENTION] + if is_torch_equal_or_newer("2.9.0.dev0") + else [] ) SMALL_BLOCK_BACKENDS = [ x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS @@ -533,9 +538,9 @@ def causal_mask_mod( SLIDING_WINDOW_BACKENDS_TO_TEST = [ - _Backend.FLASH_ATTN, - _Backend.FLEX_ATTENTION, - _Backend.TRITON_ATTN, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.FLEX_ATTENTION, + AttentionBackendEnum.TRITON_ATTN, "FLEX_ATTENTION_SLOW", ] @@ -569,7 +574,9 @@ def sliding_window_mask_mod( ) LARGE_BLOCK_BACKENDS = ( - [_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else [] + [AttentionBackendEnum.FLEX_ATTENTION] + if is_torch_equal_or_newer("2.9.0.dev0") + else [] ) SMALL_BLOCK_BACKENDS = [ x for x in SLIDING_WINDOW_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index 1b1753288484..d409413cde5c 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -19,7 +19,7 @@ try_get_attention_backend, ) from vllm import _custom_ops as ops -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.ops.flashmla import is_flashmla_dense_supported from vllm.config.vllm import set_current_vllm_config from vllm.utils.math_utils import cdiv @@ -28,19 +28,19 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec BACKENDS_TO_TEST = [ - _Backend.CUTLASS_MLA, - _Backend.FLASHMLA, - _Backend.FLASH_ATTN_MLA, - _Backend.TRITON_MLA, + AttentionBackendEnum.CUTLASS_MLA, + AttentionBackendEnum.FLASHMLA, + AttentionBackendEnum.FLASH_ATTN_MLA, + AttentionBackendEnum.TRITON_MLA, ] # Remove CUTLASS_MLA from the list if not using sm100 if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10: - BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA) + BACKENDS_TO_TEST.remove(AttentionBackendEnum.CUTLASS_MLA) # Remove FLASHMLA from the list if not supported if not is_flashmla_dense_supported()[0]: - 
BACKENDS_TO_TEST.remove(_Backend.FLASHMLA) + BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHMLA) torch.manual_seed(42) @@ -239,7 +239,7 @@ def __init__(self, device: torch.device): def run_attention_backend( - backend: _Backend, + backend: AttentionBackendEnum, kv_cache_spec: FullAttentionSpec, layer_names: list[str], vllm_config, @@ -357,7 +357,10 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): batch_spec = BATCH_SPECS[batch_spec_name] is_spec_decode_test = batch_spec_name.startswith("spec_decode") - spec_decode_backends = {_Backend.FLASH_ATTN_MLA, _Backend.FLASHMLA} + spec_decode_backends = { + AttentionBackendEnum.FLASH_ATTN_MLA, + AttentionBackendEnum.FLASHMLA, + } block_size = 16 required_blocks = sum( diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index 15ed7bdc835b..192a0edf6e04 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -8,7 +8,7 @@ import torch from vllm.attention.backends.abstract import AttentionImpl -from vllm.attention.backends.registry import _Backend, backend_to_class_str +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ( CacheConfig, CompilationConfig, @@ -20,7 +20,6 @@ VllmConfig, ) from vllm.config.model import ModelDType -from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, @@ -120,15 +119,14 @@ def create_common_attn_metadata( def try_get_attention_backend( - backend: _Backend, + backend: AttentionBackendEnum, ) -> tuple[type[AttentionMetadataBuilder], type[AttentionImpl]]: """Try to get the attention backend class, skipping test if not found.""" - backend_class_str = backend_to_class_str(backend) try: - backend_class = resolve_obj_by_qualname(backend_class_str) + backend_class = backend.get_class() return backend_class.get_builder_cls(), backend_class.get_impl_cls() except ImportError as e: - pytest.skip(f"{backend_class_str} not available: {e}") + pytest.skip(f"{backend.name} not available: {e}") raise AssertionError("unreachable") from None diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 47d05a20a65d..89d0ec769ac0 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -13,7 +13,7 @@ create_standard_kv_cache_spec, try_get_attention_backend, ) -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ( CacheConfig, DeviceConfig, @@ -534,11 +534,17 @@ def create_deterministic_logits(token_ids): sampling_metadata = mock.MagicMock() if attn_backend == "FLASH_ATTN": - attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.FLASH_ATTN + ) elif attn_backend == "TRITON_ATTN": - attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TRITON_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.TRITON_ATTN + ) elif attn_backend == "TREE_ATTN": - attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.TREE_ATTN + ) else: raise ValueError(f"Unsupported attention backend: {attn_backend}") @@ -673,7 +679,9 @@ def create_deterministic_logits(token_ids, k: int): proposer.attn_layer_names = ["layer.0"] # Get the tree attention metadata 
builder. - attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.TREE_ATTN + ) attn_metadata_builder = attn_metadata_builder_cls( kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), layer_names=proposer.attn_layer_names, diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py index 9ca7cf9e3e0e..6d59b58e739e 100644 --- a/tests/v1/spec_decode/test_mtp.py +++ b/tests/v1/spec_decode/test_mtp.py @@ -12,7 +12,7 @@ create_standard_kv_cache_spec, try_get_attention_backend, ) -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ( CacheConfig, DeviceConfig, @@ -177,7 +177,9 @@ def create_deterministic_logits(batch_size, vocab_size, token_offset): sampling_metadata = mock.MagicMock() # Setup attention metadata - attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.FLASH_ATTN + ) attn_metadata_builder = attn_metadata_builder_cls( kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py index b365e75d5514..6958d62dc7e9 100644 --- a/tests/v1/spec_decode/test_tree_attention.py +++ b/tests/v1/spec_decode/test_tree_attention.py @@ -10,7 +10,7 @@ create_vllm_config, try_get_attention_backend, ) -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ParallelConfig, SpeculativeConfig from vllm.v1.attention.backends.utils import CommonAttentionMetadata @@ -35,7 +35,7 @@ def forward_attention( block_table: torch.Tensor, slot_mapping: torch.Tensor, seqlen_k: int, - backend: _Backend, + backend: AttentionBackendEnum, spec_token_tree: str | None = None, num_spec_tokens: int = 0, ) -> torch.Tensor: @@ -241,7 +241,7 @@ def test_tree_attn_correctness() -> None: block_table=block_table, slot_mapping=tree_slot_mapping, seqlen_k=seqlen_k, - backend=_Backend.TREE_ATTN, + backend=AttentionBackendEnum.TREE_ATTN, spec_token_tree=spec_token_tree, num_spec_tokens=tree_size_q - 1, ).view(batch_size, -1, num_heads, dim_per_head) @@ -278,7 +278,7 @@ def test_tree_attn_correctness() -> None: block_table=block_table, slot_mapping=branch_slot_mapping, seqlen_k=sequence_position + q_len, - backend=_Backend.FLASH_ATTN, + backend=AttentionBackendEnum.FLASH_ATTN, ).view(batch_size, -1, num_heads, dim_per_head) # Compare the outputs. 
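
The registry.py rewrite that follows folds the old BACKEND_MAP lookup into the enum itself: each AttentionBackendEnum member carries its default class path, and register_backend() records overrides that get_path()/get_class() consult. A minimal usage sketch assembled from the docstrings in that diff (the plugin class path my_plugin.attention.MyBackend is a hypothetical placeholder, not part of this patch):

    from vllm.attention.backends.registry import AttentionBackendEnum, register_backend

    # Default resolution: the member's value is a class path, imported lazily.
    flash_attn_cls = AttentionBackendEnum.FLASH_ATTN.get_class()

    # Out-of-tree backends register against the CUSTOM placeholder by path...
    register_backend(AttentionBackendEnum.CUSTOM, "my_plugin.attention.MyBackend")

    # ...or override an existing member as a decorator; the path is derived
    # from the decorated class's module and qualname.
    @register_backend(AttentionBackendEnum.FLASH_ATTN)
    class MyFlashAttentionBackend:
        ...

    assert AttentionBackendEnum.FLASH_ATTN.is_overridden()
    AttentionBackendEnum.FLASH_ATTN.clear_override()  # revert to the default path
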
diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index 0a3e70ac6c89..19d51f57a168 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -3,7 +3,7 @@ """Attention backend registry""" import enum -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from vllm.utils.import_utils import resolve_obj_by_qualname @@ -11,91 +11,145 @@ from vllm.attention.backends.abstract import AttentionBackend -class _Backend(enum.Enum): - FLASH_ATTN = enum.auto() - TRITON_ATTN = enum.auto() - XFORMERS = enum.auto() - ROCM_ATTN = enum.auto() - ROCM_AITER_MLA = enum.auto() - ROCM_AITER_FA = enum.auto() # used for ViT attn backend - TORCH_SDPA = enum.auto() - FLASHINFER = enum.auto() - FLASHINFER_MLA = enum.auto() - TRITON_MLA = enum.auto() - CUTLASS_MLA = enum.auto() - FLASHMLA = enum.auto() - FLASHMLA_SPARSE = enum.auto() - FLASH_ATTN_MLA = enum.auto() - PALLAS = enum.auto() - IPEX = enum.auto() - NO_ATTENTION = enum.auto() - FLEX_ATTENTION = enum.auto() - TREE_ATTN = enum.auto() - ROCM_AITER_UNIFIED_ATTN = enum.auto() - - -BACKEND_MAP = { - _Backend.FLASH_ATTN: "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend", # noqa: E501 - _Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", # noqa: E501 - _Backend.XFORMERS: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend", # noqa: E501 - _Backend.ROCM_ATTN: "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend", # noqa: E501 - _Backend.ROCM_AITER_MLA: "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend", # noqa: E501 - _Backend.ROCM_AITER_FA: "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend", # noqa: E501 - _Backend.TORCH_SDPA: "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend", # noqa: E501 - _Backend.FLASHINFER: "vllm.v1.attention.backends.flashinfer.FlashInferBackend", # noqa: E501 - _Backend.FLASHINFER_MLA: "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend", # noqa: E501 - _Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", # noqa: E501 - _Backend.CUTLASS_MLA: "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", # noqa: E501 - _Backend.FLASHMLA: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", # noqa: E501 - _Backend.FLASHMLA_SPARSE: "vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend", # noqa: E501 - _Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend", # noqa: E501 - _Backend.PALLAS: "vllm.v1.attention.backends.pallas.PallasAttentionBackend", # noqa: E501 - _Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend", # noqa: E501 - _Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend", # noqa: E501 - _Backend.ROCM_AITER_UNIFIED_ATTN: "vllm.v1.attention.backends.rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend", # noqa: E501 -} - - -def register_attn_backend(backend: _Backend, class_path: str | None = None): - """ - Decorator: register a custom attention backend into BACKEND_MAPPING. - - If class_path is provided, use it. - - Otherwise, auto-generate from the class object. - Validation: only checks if 'backend' is a valid _Backend enum member. - Overwriting existing mappings is allowed. This enables other hardware - platforms to plug in custom out-of-tree backends. 
- """ - if not isinstance(backend, _Backend): - raise ValueError(f"{backend} is not a valid _Backend enum value.") +class _AttentionBackendEnumMeta(enum.EnumMeta): + """Metaclass for AttentionBackendEnum to provide better error messages.""" - def decorator(cls): - path = class_path or f"{cls.__module__}.{cls.__qualname__}" - BACKEND_MAP[backend] = path - return cls - - return decorator + def __getitem__(cls, name: str): + """Get backend by name with helpful error messages.""" + try: + return super().__getitem__(name) + except KeyError: + members = cast("dict[str, AttentionBackendEnum]", cls.__members__).values() + valid_backends = ", ".join(m.name for m in members) + raise ValueError( + f"Unknown attention backend: '{name}'. " + f"Valid options are: {valid_backends}" + ) from None -def backend_to_class_str(backend: _Backend) -> str: - """Get the backend class string +class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta): + """Enumeration of all supported attention backends. - Args: - backend: The backend enum value + The enum value is the default class path, but this can be overridden + at runtime using register_backend(). - Returns: - The backend class string + To get the actual backend class (respecting overrides), use: + backend.get_class() """ - return BACKEND_MAP[backend] - -def backend_to_class(backend: _Backend) -> "type[AttentionBackend]": - """Get the backend class. + FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" + XFORMERS = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" + ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" + ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" + ROCM_AITER_FA = ( + "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend" + ) + TORCH_SDPA = "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend" + FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" + FLASHINFER_MLA = ( + "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend" + ) + TRITON_MLA = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend" + CUTLASS_MLA = "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend" + FLASHMLA = "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend" + FLASHMLA_SPARSE = ( + "vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend" + ) + FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend" + PALLAS = "vllm.v1.attention.backends.pallas.PallasAttentionBackend" + IPEX = "vllm.v1.attention.backends.ipex.IpexAttentionBackend" + NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend" + FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" + TREE_ATTN = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" + ROCM_AITER_UNIFIED_ATTN = ( + "vllm.v1.attention.backends.rocm_aiter_unified_attn." + "RocmAiterUnifiedAttentionBackend" + ) + # Placeholder for third-party/custom backends - must be registered before use + CUSTOM = "" + + def get_path(self) -> str: + """Get the class path for this backend (respects overrides). + + Returns: + The fully qualified class path string + + Raises: + ValueError: If Backend.CUSTOM is used without being registered + """ + path = _OVERRIDES.get(self, self.value) + if not path: + raise ValueError( + f"Backend {self.name} must be registered before use. 
" + f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')" + ) + return path + + def get_class(self) -> "type[AttentionBackend]": + """Get the backend class (respects overrides). + + Returns: + The backend class + + Raises: + ImportError: If the backend class cannot be imported + ValueError: If Backend.CUSTOM is used without being registered + """ + return resolve_obj_by_qualname(self.get_path()) + + def is_overridden(self) -> bool: + """Check if this backend has been overridden. + + Returns: + True if the backend has a registered override + """ + return self in _OVERRIDES + + def clear_override(self) -> None: + """Clear any override for this backend, reverting to the default.""" + _OVERRIDES.pop(self, None) + + +_OVERRIDES: dict[AttentionBackendEnum, str] = {} + + +def register_backend(backend: AttentionBackendEnum, class_path: str | None = None): + """Register or override a backend implementation. Args: - backend: The backend enum value + backend: The AttentionBackendEnum member to register + class_path: Optional class path. If not provided and used as + decorator, will be auto-generated from the class. Returns: - The backend class + Decorator function if class_path is None, otherwise a no-op + + Examples: + # Override an existing backend + @register_backend(AttentionBackendEnum.FLASH_ATTN) + class MyCustomFlashAttn: + ... + + # Register a custom third-party backend + @register_backend(AttentionBackendEnum.CUSTOM) + class MyCustomBackend: + ... + + # Direct registration + register_backend( + AttentionBackendEnum.CUSTOM, + "my.module.MyCustomBackend" + ) """ - backend_class_name = backend_to_class_str(backend) - return resolve_obj_by_qualname(backend_class_name) + + def decorator(cls): + path = class_path or f"{cls.__module__}.{cls.__qualname__}" + _OVERRIDES[backend] = path + return cls + + if class_path is not None: + _OVERRIDES[backend] = class_path + return lambda x: x + + return decorator diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 221e0ca4057f..3310fdb0072e 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -12,7 +12,7 @@ import vllm.envs as envs from vllm.attention import AttentionType from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import get_attn_backend from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.config import CacheConfig, get_current_vllm_config @@ -99,35 +99,39 @@ def check_upstream_fa_availability(dtype: torch.dtype): def maybe_get_vit_flash_attn_backend( - attn_backend: _Backend, + attn_backend: AttentionBackendEnum, use_upstream_fa: bool, - attn_backend_override: _Backend | None = None, -) -> tuple[_Backend, Callable | None]: + attn_backend_override: AttentionBackendEnum | None = None, +) -> tuple[AttentionBackendEnum, Callable | None]: if current_platform.is_rocm(): if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): - attn_backend = _Backend.ROCM_AITER_FA + attn_backend = AttentionBackendEnum.ROCM_AITER_FA elif ( check_upstream_fa_availability(torch.get_default_dtype()) and on_gfx9() and attn_backend_override is None ): - attn_backend = _Backend.FLASH_ATTN + attn_backend = AttentionBackendEnum.FLASH_ATTN use_upstream_fa = True else: - return _Backend.TORCH_SDPA, None + return AttentionBackendEnum.TORCH_SDPA, None elif current_platform.is_cuda(): - if attn_backend != 
_Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - attn_backend = _Backend.FLASH_ATTN + attn_backend = AttentionBackendEnum.FLASH_ATTN use_upstream_fa = True else: - return _Backend.TORCH_SDPA, None + return AttentionBackendEnum.TORCH_SDPA, None - if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: - if attn_backend == _Backend.ROCM_AITER_FA: + if attn_backend in { + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, + }: + if attn_backend == AttentionBackendEnum.ROCM_AITER_FA: from aiter import flash_attn_varlen_func else: if use_upstream_fa: @@ -304,7 +308,7 @@ def __init__( kv_sharing_target_layer_name, **extra_impl_args, ) - self.backend = _Backend[self.attn_backend.get_name()] + self.backend = AttentionBackendEnum[self.attn_backend.get_name()] self.dtype = dtype # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how @@ -523,19 +527,19 @@ def __init__( if current_platform.is_xpu(): # currently, only torch_sdpa is supported on xpu - self.attn_backend = _Backend.TORCH_SDPA + self.attn_backend = AttentionBackendEnum.TORCH_SDPA else: self.attn_backend = ( backend if backend in { - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.PALLAS, - _Backend.ROCM_AITER_FA, - _Backend.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.PALLAS, + AttentionBackendEnum.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, } - else _Backend.TORCH_SDPA + else AttentionBackendEnum.TORCH_SDPA ) self.attn_backend, self._flash_attn_varlen_func = ( @@ -546,17 +550,23 @@ def __init__( ) ) - if self.attn_backend == _Backend.XFORMERS and not check_xformers_availability(): - self.attn_backend = _Backend.TORCH_SDPA + if ( + self.attn_backend == AttentionBackendEnum.XFORMERS + and not check_xformers_availability() + ): + self.attn_backend = AttentionBackendEnum.TORCH_SDPA self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } # this condition is just to make sure that the # use_upstream_fa in the log is correct - if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN: + if ( + current_platform.is_rocm() + and self.attn_backend == AttentionBackendEnum.FLASH_ATTN + ): use_upstream_fa = True logger.info_once( @@ -605,17 +615,17 @@ def forward( max_seqlen_k=kv_len, softmax_scale=self.scale, ) - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: from xformers import ops as xops out = xops.memory_efficient_attention_forward( query, key, value, scale=self.scale ) - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: query, key, value = (x.transpose(1, 2) for x in (query, key, value)) out = F.scaled_dot_product_attention(query, key, value, scale=self.scale) out = out.transpose(1, 2) - elif self.attn_backend == _Backend.PALLAS: + elif self.attn_backend == AttentionBackendEnum.PALLAS: query, key, value = (x.transpose(1, 2) for x in (query, key, value)) from torch_xla.experimental.custom_kernel import flash_attention diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 52af311df3ac..0fe9e844c2e3 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -10,7 +10,7 @@ import vllm.envs as envs 
from vllm.attention.backends.abstract import AttentionBackend -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.logger import init_logger from vllm.utils import STR_BACKEND_ENV_VAR from vllm.utils.import_utils import resolve_obj_by_qualname @@ -18,18 +18,18 @@ logger = init_logger(__name__) -def get_env_variable_attn_backend() -> _Backend | None: +def get_env_variable_attn_backend() -> AttentionBackendEnum | None: """ Get the backend override specified by the vLLM attention backend environment variable, if one is specified. Returns: - * _Backend enum value if an override is specified + * AttentionBackendEnum value if an override is specified * None otherwise """ backend_name = os.environ.get(STR_BACKEND_ENV_VAR) - return None if backend_name is None else _Backend[backend_name] + return None if backend_name is None else AttentionBackendEnum[backend_name] # Global state allows a particular choice of backend @@ -39,10 +39,10 @@ def get_env_variable_attn_backend() -> _Backend | None: # # THIS SELECTION TAKES PRECEDENCE OVER THE # VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE -forced_attn_backend: _Backend | None = None +forced_attn_backend: AttentionBackendEnum | None = None -def global_force_attn_backend(attn_backend: _Backend | None) -> None: +def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None: """ Force all attention operations to use a specified backend. @@ -57,7 +57,7 @@ def global_force_attn_backend(attn_backend: _Backend | None) -> None: forced_attn_backend = attn_backend -def get_global_forced_attn_backend() -> _Backend | None: +def get_global_forced_attn_backend() -> AttentionBackendEnum | None: """ Get the currently-forced choice of attention backend, or None if auto-selection is currently enabled. @@ -108,7 +108,9 @@ def _cached_get_attn_backend( # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND # ENVIRONMENT VARIABLE. selected_backend = None - backend_by_global_setting: _Backend | None = get_global_forced_attn_backend() + backend_by_global_setting: AttentionBackendEnum | None = ( + get_global_forced_attn_backend() + ) if backend_by_global_setting is not None: selected_backend = backend_by_global_setting else: @@ -125,11 +127,11 @@ def _cached_get_attn_backend( ) backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1") try: - selected_backend = _Backend[backend_by_env_var] + selected_backend = AttentionBackendEnum[backend_by_env_var] except KeyError as e: raise ValueError( - f"Invalid attention backend: '{backend_by_env_var}'. " - f"Valid backends are: {list(_Backend.__members__.keys())}" + f"Invalid attention backend: '{backend_by_env_var}'. 
Valid " + f"backends are: {list(AttentionBackendEnum.__members__.keys())}" ) from e # get device-specific attn_backend @@ -170,7 +172,7 @@ def _cached_get_attn_backend( @contextmanager def global_force_attn_backend_context_manager( - attn_backend: _Backend, + attn_backend: AttentionBackendEnum, ) -> Generator[None, None, None]: """ Globally force a vLLM attention backend override within a diff --git a/vllm/config/model.py b/vllm/config/model.py index 092c67e7bed8..3c25aafc1ce0 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -48,7 +48,7 @@ import vllm.model_executor.layers.quantization as me_quant import vllm.model_executor.models as me_models - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config.load import LoadConfig from vllm.config.parallel import ParallelConfig from vllm.model_executor.layers.quantization import QuantizationMethods @@ -56,7 +56,7 @@ else: PretrainedConfig = Any - _Backend = Any + AttentionBackendEnum = Any me_quant = LazyLoader( "model_executor", globals(), "vllm.model_executor.layers.quantization" ) @@ -311,7 +311,7 @@ class ModelConfig: mm_processor_cache_type: InitVar[MMCacheType | None] = None mm_shm_cache_max_object_size_mb: InitVar[int | None] = None mm_encoder_tp_mode: InitVar[MMEncoderTPMode | None] = None - mm_encoder_attn_backend: InitVar[_Backend | str | None] = None + mm_encoder_attn_backend: InitVar[AttentionBackendEnum | str | None] = None interleave_mm_strings: InitVar[bool | None] = None skip_mm_profiling: InitVar[bool | None] = None video_pruning_rate: InitVar[float | None] = None @@ -431,7 +431,7 @@ def __post_init__( mm_processor_cache_type: MMCacheType | None, mm_shm_cache_max_object_size_mb: int | None, mm_encoder_tp_mode: MMEncoderTPMode | None, - mm_encoder_attn_backend: _Backend | str | None, + mm_encoder_attn_backend: AttentionBackendEnum | str | None, interleave_mm_strings: bool | None, skip_mm_profiling: bool | None, video_pruning_rate: float | None, diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index 8e001441f3cf..3bfc20100123 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -11,9 +11,9 @@ from vllm.config.utils import config if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum else: - _Backend = Any + AttentionBackendEnum = Any @dataclass @@ -125,10 +125,10 @@ class MultiModalConfig: DP (which is controlled by `--data-parallel-size`). This is only supported on a per-model basis and falls back to `"weights"` if the encoder does not support DP.""" - mm_encoder_attn_backend: _Backend | None = None + mm_encoder_attn_backend: AttentionBackendEnum | None = None """Optional override for the multi-modal encoder attention backend when using vision transformers. Accepts any value from - `vllm.attention.backends.registry._Backend` (e.g. `FLASH_ATTN`).""" + `vllm.attention.backends.registry.AttentionBackendEnum` (e.g. 
`FLASH_ATTN`).""" interleave_mm_strings: bool = False """Enable fully interleaved support for multimodal prompts, while using --chat-template-content-format=string.""" @@ -167,25 +167,16 @@ def _validate_limit_per_prompt( @field_validator("mm_encoder_attn_backend", mode="before") @classmethod - def _validate_mm_encoder_attn_backend(cls, value: object) -> _Backend | None: - from vllm.attention.backends.registry import ( - _Backend as BackendEnum, - ) - - if value is None or isinstance(value, BackendEnum): + def _validate_mm_encoder_attn_backend( + cls, value: object + ) -> AttentionBackendEnum | None: + if value is None or isinstance(value, AttentionBackendEnum): return value assert isinstance(value, str), ( - "mm_encoder_attn_backend must be a string or a BackendEnum." + "mm_encoder_attn_backend must be a string or an AttentionBackendEnum." ) - try: - return BackendEnum[value.upper()] - except KeyError as exc: - valid_backends = ", ".join(sorted(BackendEnum.__members__.keys())) - raise ValueError( - f"Invalid mm encoder attention backend. " - f"Expected one of: {valid_backends}." - ) from exc + return AttentionBackendEnum[value.upper()] @model_validator(mode="after") def _validate_multimodal_config(self): diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 42f3f6e93a0a..d985c105d1f0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -21,7 +21,7 @@ import zmq from vllm import envs -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( @@ -768,9 +768,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): use_mla=self.use_mla, ) self.backend_name = backend.get_name() - attn_backend = _Backend[self.backend_name] - self._use_flashinfer = attn_backend == _Backend.FLASHINFER - self._use_pallas = attn_backend == _Backend.PALLAS + attn_backend = AttentionBackendEnum[self.backend_name] + self._use_flashinfer = attn_backend == AttentionBackendEnum.FLASHINFER + self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS self.kv_cache_layout = get_kv_cache_layout() self.host_buffer_kv_cache_layout = self.kv_cache_layout logger.debug("Detected attention backend %s", self.backend_name) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0f82d73664f9..17b929202088 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -32,7 +32,7 @@ from typing_extensions import TypeIs, deprecated import vllm.envs as envs -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ( CacheConfig, CompilationConfig, @@ -458,7 +458,7 @@ class EngineArgs: MultiModalConfig.mm_shm_cache_max_object_size_mb ) mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode - mm_encoder_attn_backend: _Backend | str | None = ( + mm_encoder_attn_backend: AttentionBackendEnum | str | None = ( MultiModalConfig.mm_encoder_attn_backend ) io_processor_plugin: str | None = None diff --git a/vllm/envs.py b/vllm/envs.py index 0548f01fc8cd..b87ad0906550 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -615,14 +615,14 @@ def get_vllm_port() -> int | None: # - "FLASH_ATTN_MLA": use FlashAttention 
for MLA # - "FLASHINFER_MLA": use FlashInfer for MLA # - "CUTLASS_MLA": use CUTLASS for MLA - # All possible options loaded dynamically from _Backend enum + # All possible options loaded dynamically from AttentionBackendEnum "VLLM_ATTENTION_BACKEND": env_with_choices( "VLLM_ATTENTION_BACKEND", None, lambda: list( __import__( - "vllm.attention.backends.registry", fromlist=["_Backend"] - )._Backend.__members__.keys() + "vllm.attention.backends.registry", fromlist=["AttentionBackendEnum"] + ).AttentionBackendEnum.__members__.keys() ), ), # If set, vllm will use flashinfer sampler diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py index 6d462ad8ae62..1b2bb60a17c1 100644 --- a/vllm/model_executor/models/dots_ocr.py +++ b/vllm/model_executor/models/dots_ocr.py @@ -9,7 +9,7 @@ from torch.nn import LayerNorm from transformers.models.qwen2_vl import Qwen2VLProcessor -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import ( check_upstream_fa_availability, maybe_get_vit_flash_attn_backend, @@ -256,7 +256,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -303,17 +303,17 @@ def __init__( ) ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"Unsupported vision attention backend: {self.attn_backend}" ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def forward( @@ -361,7 +361,7 @@ def forward( self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, ) - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: outputs = [] for i in range(1, len(cu_seqlens)): s = int(cu_seqlens[i - 1]) @@ -373,7 +373,7 @@ def forward( out_i = out_i.permute(0, 2, 1, 3) outputs.append(out_i) context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0] - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -514,7 +514,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() @@ -567,7 +567,7 @@ def __init__( require_post_norm: bool | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() self.config = config @@ -582,10 +582,11 @@ def __init__( dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - 
self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN self.out_hidden_size = config.hidden_size # Keep blocks for compatibility with other vision towers num_layers = ( @@ -666,11 +667,11 @@ def compute_attn_mask_seqlen( ) -> tuple[int | None, list[int] | None]: max_seqlen, seqlens = None, None if ( - self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA + self.attn_backend == AttentionBackendEnum.FLASH_ATTN + or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return max_seqlen, seqlens diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 86536b21c33f..d1644827ed64 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -36,7 +36,7 @@ from einops import rearrange, repeat from transformers import BatchFeature, PretrainedConfig -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import ( check_upstream_fa_availability, maybe_get_vit_flash_attn_backend, @@ -164,7 +164,7 @@ def __init__( projection_size: int, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() # Per attention head and per partition values. @@ -211,17 +211,17 @@ def __init__( ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"Ernie45-VL does not support {self.attn_backend} backend now." ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -291,7 +291,7 @@ def forward( context_layer = rearrange( output, "(b s) h d -> s b (h d)", b=batch_size ).contiguous() - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. 
outputs = [] for i in range(1, len(cu_seqlens)): @@ -310,7 +310,7 @@ def forward( context_layer = rearrange( context_layer, "b s h d -> s b (h d)" ).contiguous() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -370,7 +370,7 @@ def __init__( norm_layer: Callable[[int], nn.Module] | None = None, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -463,7 +463,7 @@ def __init__( norm_eps: float = 1e-6, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() patch_size = vision_config.patch_size @@ -515,10 +515,11 @@ def __init__( dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -565,11 +566,11 @@ def compute_attn_mask_seqlen( ) -> tuple[int | None, list[int] | None]: max_seqlen, seqlens = None, None if ( - self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA + self.attn_backend == AttentionBackendEnum.FLASH_ATTN + or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return max_seqlen, seqlens diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 9f1439e21ef7..a199ad5c246e 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -45,7 +45,7 @@ from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor from transformers.video_utils import VideoMetadata -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import ( check_upstream_fa_availability, maybe_get_vit_flash_attn_backend, @@ -250,7 +250,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() # Per attention head and per partition values. @@ -304,18 +304,18 @@ def __init__( ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"GLM-4V does not support {self.attn_backend} backend now." 
) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -375,7 +375,7 @@ def forward( context_layer = rearrange( output, "(b s) h d -> s b (h d)", b=batch_size ).contiguous() - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. outputs = [] for i in range(1, len(cu_seqlens)): @@ -394,7 +394,7 @@ def forward( context_layer = rearrange( context_layer, "b s h d -> s b (h d)" ).contiguous() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -423,7 +423,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -701,7 +701,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -770,10 +770,11 @@ def __init__( dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -822,8 +823,8 @@ def compute_attn_mask_seqlen( max_seqlen, seqlens = None, None seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() if ( - self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA + self.attn_backend == AttentionBackendEnum.FLASH_ATTN + or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() return max_seqlen, seqlens diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index acfd51a6d0cc..ef734898a2a8 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -16,7 +16,7 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.utils import torch_int -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import check_upstream_fa_availability from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions @@ -353,7 +353,7 @@ def __init__( config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -399,13 +399,17 @@ def __init__( ) self.use_upstream_fa = False - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and 
check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN self.use_upstream_fa = True - if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}: + if self.attn_backend not in { + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.XFORMERS, + }: raise RuntimeError( f"Keye-VL does not support {self.attn_backend} backend now." ) @@ -457,7 +461,7 @@ def forward( self.head_dim, ) - if self.attn_backend == _Backend.FLASH_ATTN: + if self.attn_backend == AttentionBackendEnum.FLASH_ATTN: if self.use_upstream_fa: from flash_attn import flash_attn_varlen_func else: @@ -477,7 +481,7 @@ def forward( softmax_scale=self.scale, ) context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size) - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -524,7 +528,7 @@ def __init__( config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.embed_dim = config.hidden_size @@ -578,7 +582,7 @@ def __init__( config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -673,7 +677,7 @@ def __init__( config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -756,7 +760,7 @@ def __init__( config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py index f6461ae9a412..9a4d69dea096 100644 --- a/vllm/model_executor/models/ovis2_5.py +++ b/vllm/model_executor/models/ovis2_5.py @@ -10,7 +10,7 @@ import torch.nn as nn from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.linear import ReplicatedLinear @@ -106,7 +106,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -135,7 +135,7 @@ def _init_backbone( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): model_type = config.model_type if model_type == "siglip2_navit": diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 3d67653726bd..2ee016f72e64 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ 
b/vllm/model_executor/models/qwen2_5_vl.py @@ -42,7 +42,7 @@ Qwen2_5_VLVisionConfig, ) -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import maybe_get_vit_flash_attn_backend from vllm.attention.ops.vit_attn_wrappers import ( vit_flash_attn_wrapper, @@ -313,9 +313,9 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend: _Backend = _Backend.TORCH_SDPA, + attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, use_upstream_fa: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() # Per attention head and per partition values. @@ -362,11 +362,14 @@ def __init__( # On ROCm with FLASH_ATTN backend, upstream flash_attn is used from vllm.platforms import current_platform - if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN: + if ( + current_platform.is_rocm() + and self.attn_backend == AttentionBackendEnum.FLASH_ATTN + ): self.use_upstream_fa = True self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -427,10 +430,10 @@ def forward( cu_seqlens, max_seqlen, batch_size, - self.attn_backend == _Backend.ROCM_AITER_FA, + self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA, self.use_upstream_fa, ) - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. from vllm.platforms import current_platform @@ -457,7 +460,7 @@ def forward( context_layer = einops.rearrange( context_layer, "b s h d -> s b (h d)" ).contiguous() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens) output, _ = self.proj(context_layer) @@ -486,9 +489,9 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend: _Backend = _Backend.TORCH_SDPA, + attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, use_upstream_fa: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -662,7 +665,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -714,10 +717,10 @@ def __init__( ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"Qwen2.5-VL does not support {self.attn_backend} backend now." 
@@ -857,11 +860,11 @@ def compute_attn_mask_seqlen( max_seqlen = torch.zeros([], device=cu_seqlens.device) seqlens = torch.zeros(1, device=cu_seqlens.device) if ( - self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA + self.attn_backend == AttentionBackendEnum.FLASH_ATTN + or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = cu_seqlens[1:] - cu_seqlens[:-1] return max_seqlen, seqlens diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index f0d7e2e7d7ec..20daf59c88e7 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -43,7 +43,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import ( check_upstream_fa_availability, maybe_get_vit_flash_attn_backend, @@ -329,7 +329,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() # Per attention head and per partition values. @@ -378,18 +378,18 @@ def __init__( ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"Qwen2-VL does not support {self.attn_backend} backend now." ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -460,7 +460,7 @@ def forward( context_layer = rearrange( output, "(b s) h d -> s b (h d)", b=batch_size ).contiguous() - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. 
from vllm.platforms import current_platform @@ -485,7 +485,7 @@ def forward( context_layer = rearrange( context_layer, "b s h d -> s b (h d)" ).contiguous() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -515,7 +515,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -679,7 +679,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -739,10 +739,11 @@ def __init__( dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -790,11 +791,11 @@ def compute_attn_mask_seqlen( ) -> tuple[int | None, list[int] | None]: max_seqlen, seqlens = None, None if ( - self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA + self.attn_backend == AttentionBackendEnum.FLASH_ATTN + or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return max_seqlen, seqlens diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index efcd003fbbda..5729f2485023 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -47,7 +47,7 @@ ) from transformers.models.whisper import WhisperFeatureExtractor -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import check_upstream_fa_availability from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig @@ -302,7 +302,7 @@ def __init__( norm_eps: float = 1e-6, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() self.hidden_size = vision_config.hidden_size @@ -378,10 +378,11 @@ def __init__( dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -491,9 +492,9 @@ def compute_attn_mask_seqlen( ) -> tuple[torch.Tensor, torch.Tensor]: 
max_seqlen = torch.zeros([], device=cu_seqlens.device) seqlens = torch.zeros(1, device=cu_seqlens.device) - if self.attn_backend == _Backend.FLASH_ATTN: + if self.attn_backend == AttentionBackendEnum.FLASH_ATTN: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = cu_seqlens[1:] - cu_seqlens[:-1] return max_seqlen, seqlens diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index d611580c7182..da6348256e65 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -49,7 +49,7 @@ ) from transformers.video_utils import VideoMetadata -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import check_upstream_fa_availability from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig @@ -198,7 +198,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend: _Backend = _Backend.TORCH_SDPA, + attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, use_upstream_fa: bool = False, ) -> None: super().__init__() @@ -306,7 +306,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() self.hidden_size = vision_config.hidden_size @@ -372,18 +372,18 @@ def __init__( ) use_upstream_fa = False if ( - self.attn_backend != _Backend.FLASH_ATTN - and self.attn_backend != _Backend.ROCM_AITER_FA + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and self.attn_backend != AttentionBackendEnum.ROCM_AITER_FA and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN use_upstream_fa = True if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"Qwen3-VL does not support {self.attn_backend} backend now." 
@@ -516,11 +516,11 @@ def compute_attn_mask_seqlen( max_seqlen = torch.zeros([], device=cu_seqlens.device) seqlens = torch.zeros(1, device=cu_seqlens.device) if ( - self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA + self.attn_backend == AttentionBackendEnum.FLASH_ATTN + or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = cu_seqlens[1:] - cu_seqlens[:-1] return max_seqlen, seqlens diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index bab5c1d82ded..c20bcd975ca3 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -12,7 +12,7 @@ from transformers import Siglip2VisionConfig from transformers.configuration_utils import PretrainedConfig -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import maybe_get_vit_flash_attn_backend from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn @@ -208,7 +208,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -264,14 +264,14 @@ def __init__( ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.ROCM_AITER_FA, }: - self.attn_backend = _Backend.TORCH_SDPA + self.attn_backend = AttentionBackendEnum.TORCH_SDPA self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def forward( @@ -308,7 +308,7 @@ def forward( attn_output = self.flash_attn_varlen_func( queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen ).reshape(seq_length, -1) - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. 
batch_size = cu_seqlens.shape[0] - 1 outputs = [] @@ -376,7 +376,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.embed_dim = config.hidden_size @@ -440,7 +440,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -626,7 +626,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -667,7 +667,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index b5f6c60514c0..0b8843dbbaee 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -10,7 +10,7 @@ import torch from transformers import PretrainedConfig -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -82,8 +82,8 @@ def get_vit_attn_backend( head_size: int, dtype: torch.dtype, *, - attn_backend_override: _Backend | None = None, -) -> _Backend: + attn_backend_override: AttentionBackendEnum | None = None, +) -> AttentionBackendEnum: """ Get the available attention backend for Vision Transformer. 
""" @@ -93,7 +93,7 @@ def get_vit_attn_backend( # Lazy import to avoid circular dependency from vllm.attention.selector import get_env_variable_attn_backend - selected_backend: _Backend | None = get_env_variable_attn_backend() + selected_backend: AttentionBackendEnum | None = get_env_variable_attn_backend() if selected_backend is not None: return selected_backend diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index e73dbf62bd3b..15a1e73847f1 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -22,10 +22,10 @@ logger = init_logger(__name__) if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig else: - _Backend = None + AttentionBackendEnum = None VllmConfig = None @@ -126,7 +126,7 @@ def get_device_name(cls, device_id: int = 0) -> str: @classmethod def get_attn_backend_cls( cls, - selected_backend: "_Backend", + selected_backend: "AttentionBackendEnum", head_size: int, dtype: torch.dtype, kv_cache_dtype: str | None, @@ -136,9 +136,9 @@ def get_attn_backend_cls( has_sink: bool, use_sparse: bool, ) -> str: - from vllm.attention.backends.registry import _Backend, backend_to_class_str + from vllm.attention.backends.registry import AttentionBackendEnum - if selected_backend and selected_backend != _Backend.TORCH_SDPA: + if selected_backend and selected_backend != AttentionBackendEnum.TORCH_SDPA: logger.info("Cannot use %s backend on CPU.", selected_backend) if use_mla: raise NotImplementedError("MLA is not supported on CPU.") @@ -147,7 +147,7 @@ def get_attn_backend_cls( logger.info("Using Torch SDPA backend.") if not use_v1: raise ValueError("CPU backend only supports V1.") - return backend_to_class_str(_Backend.TORCH_SDPA) + return AttentionBackendEnum.TORCH_SDPA.get_path() @classmethod def get_device_total_memory(cls, device_id: int = 0) -> int: diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index e818b7008b66..b685e7b3b028 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -16,16 +16,16 @@ import vllm._C # noqa import vllm.envs as envs from vllm.logger import init_logger -from vllm.utils.import_utils import import_pynvml, resolve_obj_by_qualname +from vllm.utils.import_utils import import_pynvml from vllm.utils.torch_utils import cuda_device_count_stateless from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig else: - _Backend = None + AttentionBackendEnum = None logger = init_logger(__name__) @@ -43,9 +43,9 @@ def _get_backend_priorities( use_mla: bool, device_capability: DeviceCapability | None = None, -) -> dict[_Backend, int]: +) -> dict[AttentionBackendEnum, int]: """Get backend priorities with lazy import to avoid circular dependency.""" - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum if use_mla: if ( @@ -54,19 +54,19 @@ def _get_backend_priorities( and device_capability < DeviceCapability(11, 0) ): return { - _Backend.CUTLASS_MLA: 0, - _Backend.FLASHINFER_MLA: 1, - _Backend.FLASHMLA: 2, - _Backend.FLASH_ATTN_MLA: 3, - _Backend.TRITON_MLA: 4, - _Backend.FLASHMLA_SPARSE: 5, + AttentionBackendEnum.CUTLASS_MLA: 0, + AttentionBackendEnum.FLASHINFER_MLA: 1, + AttentionBackendEnum.FLASHMLA: 2, + AttentionBackendEnum.FLASH_ATTN_MLA: 3, + 
AttentionBackendEnum.TRITON_MLA: 4, + AttentionBackendEnum.FLASHMLA_SPARSE: 5, } else: return { - _Backend.FLASHMLA: 0, - _Backend.FLASH_ATTN_MLA: 1, - _Backend.FLASHINFER_MLA: 2, - _Backend.TRITON_MLA: 3, + AttentionBackendEnum.FLASHMLA: 0, + AttentionBackendEnum.FLASH_ATTN_MLA: 1, + AttentionBackendEnum.FLASHINFER_MLA: 2, + AttentionBackendEnum.TRITON_MLA: 3, } else: if ( @@ -75,17 +75,17 @@ def _get_backend_priorities( and device_capability < DeviceCapability(11, 0) ): return { - _Backend.FLASHINFER: 0, - _Backend.FLASH_ATTN: 1, - _Backend.TRITON_ATTN: 2, - _Backend.FLEX_ATTENTION: 3, + AttentionBackendEnum.FLASHINFER: 0, + AttentionBackendEnum.FLASH_ATTN: 1, + AttentionBackendEnum.TRITON_ATTN: 2, + AttentionBackendEnum.FLEX_ATTENTION: 3, } else: return { - _Backend.FLASH_ATTN: 0, - _Backend.FLASHINFER: 1, - _Backend.TRITON_ATTN: 2, - _Backend.FLEX_ATTENTION: 3, + AttentionBackendEnum.FLASH_ATTN: 0, + AttentionBackendEnum.FLASHINFER: 1, + AttentionBackendEnum.TRITON_ATTN: 2, + AttentionBackendEnum.FLEX_ATTENTION: 3, } @@ -266,28 +266,30 @@ def get_current_memory_usage( return torch.cuda.max_memory_allocated(device) @classmethod - def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": - from vllm.attention.backends.registry import _Backend, backend_to_class + def get_vit_attn_backend( + cls, head_size: int, dtype: torch.dtype + ) -> "AttentionBackendEnum": + from vllm.attention.backends.registry import AttentionBackendEnum # For Blackwell GPUs, force TORCH_SDPA for now. # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501 if cls.has_device_capability(100): - return _Backend.TORCH_SDPA + return AttentionBackendEnum.TORCH_SDPA if dtype not in (torch.float16, torch.bfloat16): - return _Backend.XFORMERS + return AttentionBackendEnum.XFORMERS if cls.has_device_capability(80): - backend_class = backend_to_class(_Backend.FLASH_ATTN) + backend_class = AttentionBackendEnum.FLASH_ATTN.get_class() if backend_class.supports_head_size( head_size ) and backend_class.supports_dtype(dtype): - return _Backend.FLASH_ATTN + return AttentionBackendEnum.FLASH_ATTN else: - return _Backend.XFORMERS + return AttentionBackendEnum.XFORMERS else: # Fallback for Volta/Turing GPUs or FA not supported - return _Backend.XFORMERS + return AttentionBackendEnum.XFORMERS @classmethod def get_valid_backends( @@ -300,15 +302,17 @@ def get_valid_backends( has_sink, use_sparse, device_capability, - ) -> tuple[list[tuple["_Backend", int]], dict["_Backend", list[str]]]: + ) -> tuple[ + list[tuple["AttentionBackendEnum", int]], + dict["AttentionBackendEnum", list[str]], + ]: valid_backends_priorities = [] invalid_reasons = {} - from vllm.attention.backends.registry import backend_to_class backend_priorities = _get_backend_priorities(use_mla, device_capability) for backend, priority in backend_priorities.items(): try: - backend_class = backend_to_class(backend) + backend_class = backend.get_class() invalid_reasons_i = backend_class.validate_configuration( head_size, dtype, @@ -331,7 +335,7 @@ def get_valid_backends( @classmethod def get_attn_backend_cls( cls, - selected_backend: "_Backend", + selected_backend: "AttentionBackendEnum", head_size: int, dtype: torch.dtype, kv_cache_dtype: str | None, @@ -347,15 +351,12 @@ def get_attn_backend_cls( "to select a supported backend." ) - from vllm.attention.backends.registry import backend_to_class_str - device_capability = cls.get_device_capability() # First try checking just the selected backend, if there is one. 
if selected_backend is not None: - backend_class_str = backend_to_class_str(selected_backend) try: - backend_class = resolve_obj_by_qualname(backend_class_str) + backend_class = selected_backend.get_class() invalid_reasons = backend_class.validate_configuration( head_size, dtype, @@ -375,7 +376,7 @@ def get_attn_backend_cls( ) else: logger.info("Using %s backend.", selected_backend) - return backend_class_str + return selected_backend.get_path() # No selected backend or the selected backend is invalid, # so we try finding a valid backend. @@ -423,13 +424,12 @@ def get_attn_backend_cls( ) selected_index = sorted_indices[0] selected_backend = valid_backends_priorities[selected_index][0] - selected_backend_class_str = backend_to_class_str(selected_backend) logger.info( "Using %s backend.", selected_backend.name, ) - return selected_backend_class_str + return selected_backend.get_path() @classmethod def get_punica_wrapper(cls) -> str: diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 2934f36f6cad..fdbe739421a4 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: from torch.distributed import PrefixStore, ProcessGroup - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig from vllm.inputs import ProcessorInputs, PromptType from vllm.pooling_params import PoolingParams @@ -198,16 +198,18 @@ def import_kernels(cls) -> None: import vllm._moe_C # noqa: F401 @classmethod - def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": - # Import _Backend here to avoid circular import. - from vllm.attention.backends.registry import _Backend + def get_vit_attn_backend( + cls, head_size: int, dtype: torch.dtype + ) -> "AttentionBackendEnum": + # Import AttentionBackendEnum here to avoid circular import. 
+ from vllm.attention.backends.registry import AttentionBackendEnum - return _Backend.TORCH_SDPA + return AttentionBackendEnum.TORCH_SDPA @classmethod def get_attn_backend_cls( cls, - selected_backend: "_Backend", + selected_backend: "AttentionBackendEnum", head_size: int, dtype: torch.dtype, kv_cache_dtype: str | None, diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 51a0172188b6..71371f4af4a9 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -14,10 +14,10 @@ from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig else: - _Backend = None + AttentionBackendEnum = None logger = init_logger(__name__) @@ -205,18 +205,20 @@ class RocmPlatform(Platform): ] @classmethod - def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": + def get_vit_attn_backend( + cls, head_size: int, dtype: torch.dtype + ) -> "AttentionBackendEnum": from importlib.util import find_spec - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): - return _Backend.ROCM_AITER_FA + return AttentionBackendEnum.ROCM_AITER_FA if on_gfx9() and find_spec("flash_attn") is not None: - return _Backend.FLASH_ATTN + return AttentionBackendEnum.FLASH_ATTN - return _Backend.TORCH_SDPA + return AttentionBackendEnum.TORCH_SDPA @classmethod def get_attn_backend_cls( @@ -231,7 +233,7 @@ def get_attn_backend_cls( has_sink, use_sparse, ) -> str: - from vllm.attention.backends.registry import _Backend, backend_to_class_str + from vllm.attention.backends.registry import AttentionBackendEnum if use_sparse: raise NotImplementedError("Sparse Attention is not supported on ROCm.") @@ -248,23 +250,23 @@ def get_attn_backend_cls( if selected_backend is None: selected_backend = ( - _Backend.ROCM_AITER_MLA + AttentionBackendEnum.ROCM_AITER_MLA if is_aiter_mla_enabled() or block_size == 1 - else _Backend.TRITON_MLA + else AttentionBackendEnum.TRITON_MLA ) - if selected_backend == _Backend.TRITON_MLA: + if selected_backend == AttentionBackendEnum.TRITON_MLA: if block_size != 1: logger.info_once("Using Triton MLA backend on V1 engine.") - return backend_to_class_str(_Backend.TRITON_MLA) + return AttentionBackendEnum.TRITON_MLA.get_path() raise ValueError( f" The selected backend, {selected_backend.name}," f"does not support block size {block_size}." ) - if selected_backend == _Backend.ROCM_AITER_MLA: + if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA: if block_size == 1: logger.info("Using AITER MLA backend on V1 engine.") - return backend_to_class_str(_Backend.ROCM_AITER_MLA) + return AttentionBackendEnum.ROCM_AITER_MLA.get_path() raise ValueError( f" The selected backend, {selected_backend.name}," f"does not support block size {block_size}." 
@@ -276,30 +278,30 @@ def get_attn_backend_cls( ) if envs.VLLM_USE_V1: - if selected_backend == _Backend.FLEX_ATTENTION: + if selected_backend == AttentionBackendEnum.FLEX_ATTENTION: logger.info("Using FlexAttention backend on V1 engine.") return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" if ( envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9() - ) or selected_backend == _Backend.ROCM_AITER_FA: + ) or selected_backend == AttentionBackendEnum.ROCM_AITER_FA: logger.info("Using Aiter Flash Attention backend on V1 engine.") - return backend_to_class_str(_Backend.ROCM_AITER_FA) + return AttentionBackendEnum.ROCM_AITER_FA.get_path() if ( envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION - ) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN: + ) or selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN: logger.info("Using Aiter Unified Attention backend on V1 engine.") - return backend_to_class_str(_Backend.ROCM_AITER_UNIFIED_ATTN) + return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path() if ( envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION - or selected_backend == _Backend.ROCM_ATTN + or selected_backend == AttentionBackendEnum.ROCM_ATTN ): # rocm specific backend, with aiter and/or # triton prefix-prefill logger.info("Using Rocm Attention backend on V1 engine.") - return backend_to_class_str(_Backend.ROCM_ATTN) + return AttentionBackendEnum.ROCM_ATTN.get_path() # default case, using triton unified attention logger.info("Using Triton Attention backend on V1 engine.") - return backend_to_class_str(_Backend.TRITON_ATTN) + return AttentionBackendEnum.TRITON_ATTN.get_path() raise RuntimeError( "V0 attention backends have been removed. Set VLLM_USE_V1=1 " "to select a supported backend." 
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 18735012804a..04e40ea753fa 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -15,7 +15,7 @@ from .interface import Platform, PlatformEnum if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig from vllm.config.cache import BlockSize from vllm.pooling_params import PoolingParams @@ -23,7 +23,7 @@ BlockSize = None VllmConfig = None PoolingParams = None - _Backend = None + AttentionBackendEnum = None logger = init_logger(__name__) @@ -53,7 +53,7 @@ def import_kernels(cls) -> None: @classmethod def get_attn_backend_cls( cls, - selected_backend: "_Backend", + selected_backend: "AttentionBackendEnum", head_size: int, dtype: torch.dtype, kv_cache_dtype: str | None, @@ -63,17 +63,17 @@ def get_attn_backend_cls( has_sink, use_sparse, ) -> str: - from vllm.attention.backends.registry import _Backend, backend_to_class_str + from vllm.attention.backends.registry import AttentionBackendEnum if use_sparse: raise NotImplementedError("Sparse Attention is not supported on TPU.") - if selected_backend != _Backend.PALLAS: + if selected_backend != AttentionBackendEnum.PALLAS: logger.info("Cannot use %s backend on TPU.", selected_backend) if not use_v1: raise ValueError("TPU backend only supports V1.") logger.info("Using Pallas V1 backend.") - return backend_to_class_str(_Backend.PALLAS) + return AttentionBackendEnum.PALLAS.get_path() @classmethod def set_device(cls, device: torch.device) -> None: diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index eb44a91aa0ed..90475539d966 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -14,11 +14,11 @@ from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig else: VllmConfig = None - _Backend = None + AttentionBackendEnum = None logger = init_logger(__name__) @@ -43,7 +43,7 @@ def import_kernels(cls) -> None: @classmethod def get_attn_backend_cls( cls, - selected_backend: "_Backend", + selected_backend: "AttentionBackendEnum", head_size: int, dtype: torch.dtype, kv_cache_dtype: str | None, @@ -61,19 +61,19 @@ def get_attn_backend_cls( "only NHD layout is supported by XPU attention kernels." 
) - from vllm.attention.backends.registry import _Backend, backend_to_class_str + from vllm.attention.backends.registry import AttentionBackendEnum if use_sparse: raise NotImplementedError("Sparse Attention is not supported on XPU.") use_v1 = envs.VLLM_USE_V1 if not use_v1: raise ValueError("XPU backend only supports V1.") - if selected_backend == _Backend.TRITON_ATTN: + if selected_backend == AttentionBackendEnum.TRITON_ATTN: logger.info_once("Using Triton backend on V1 engine.") - return backend_to_class_str(_Backend.TRITON_ATTN) - elif selected_backend == _Backend.FLASH_ATTN: + return AttentionBackendEnum.TRITON_ATTN.get_path() + elif selected_backend == AttentionBackendEnum.FLASH_ATTN: logger.info_once("Using Flash Attention backend on V1 engine.") - return backend_to_class_str(_Backend.FLASH_ATTN) + return AttentionBackendEnum.FLASH_ATTN.get_path() elif selected_backend: raise ValueError( f"Invalid attention backend for {cls.device_name}, " @@ -81,7 +81,7 @@ def get_attn_backend_cls( ) logger.info("Using Flash Attention backend on V1 engine.") - return backend_to_class_str(_Backend.FLASH_ATTN) + return AttentionBackendEnum.FLASH_ATTN.get_path() @classmethod def set_device(cls, device: torch.device) -> None: diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 8b2f095009a5..9aee504e935d 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -149,13 +149,13 @@ def __init__( ) # Determine allowed attention backends once during initialization. - from vllm.attention.backends.registry import _Backend, backend_to_class_str + from vllm.attention.backends.registry import AttentionBackendEnum self.allowed_attn_types: tuple | None = None if current_platform.is_rocm(): rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] # ROCM_AITER_FA is an optional backend - if find_spec(backend_to_class_str(_Backend.ROCM_AITER_FA)): + if find_spec(AttentionBackendEnum.ROCM_AITER_FA.get_path()): from vllm.v1.attention.backends.rocm_aiter_fa import ( AiterFlashAttentionMetadata, ) From 1c71eabc1c3477bddae07ef9ce53c43f51c04c9a Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Fri, 31 Oct 2025 11:04:57 -0400 Subject: [PATCH 71/84] get rid of get_min_compute_capability and get_max_compute_capability Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 14 +------------- vllm/v1/attention/backends/flash_attn.py | 8 ++------ vllm/v1/attention/backends/flashinfer.py | 10 ++++------ vllm/v1/attention/backends/mla/cutlass_mla.py | 8 ++------ vllm/v1/attention/backends/mla/flashattn_mla.py | 8 ++------ vllm/v1/attention/backends/mla/flashinfer_mla.py | 8 ++------ vllm/v1/attention/backends/mla/flashmla.py | 8 ++------ vllm/v1/attention/backends/mla/flashmla_sparse.py | 8 ++------ vllm/v1/attention/backends/mla/triton_mla.py | 8 ++------ vllm/v1/attention/backends/triton_attn.py | 8 ++------ 10 files changed, 21 insertions(+), 67 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index dc3b61e8552e..f517296f62c0 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -187,21 +187,9 @@ def supports_sink(cls) -> bool: def is_sparse(cls) -> bool: return False - @classmethod - def get_min_compute_capability(cls) -> "DeviceCapability | None": - return None - - @classmethod - def get_max_compute_capability(cls) -> "DeviceCapability | None": - return None - @classmethod def supports_compute_capability(cls, capability: "DeviceCapability") -> bool: - 
min_capability = cls.get_min_compute_capability() - max_capability = cls.get_max_compute_capability() - return ((min_capability is None) or (capability >= min_capability)) and ( - (max_capability is None) or (capability <= max_capability) - ) + return True @classmethod def supports_combination( diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 8cd5ef7040e5..0c3f84f411b0 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -123,12 +123,8 @@ def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: return kv_cache_dtype in ["auto"] @classmethod - def get_min_compute_capability(cls) -> DeviceCapability | None: - return DeviceCapability(8, 0) - - @classmethod - def get_max_compute_capability(cls) -> DeviceCapability | None: - return None + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability >= DeviceCapability(8, 0) @classmethod def supports_combination( diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 3f0f0c603db8..5b622147b9b2 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -231,12 +231,10 @@ def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: return ["auto", "fp8", "fp8_e4m3", "fp8_e5m2"] @classmethod - def get_min_compute_capability(cls) -> DeviceCapability | None: - return DeviceCapability(7, 5) - - @classmethod - def get_max_compute_capability(cls) -> DeviceCapability | None: - return DeviceCapability(12, 1) + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability >= DeviceCapability(7, 5) and capability <= DeviceCapability( + 12, 1 + ) @classmethod def get_required_kv_cache_layout( diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index 1500a55e2d74..51199ea2c68c 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -60,12 +60,8 @@ def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: return ["auto", "fp8", "fp8_e4m3"] @classmethod - def get_min_compute_capability(cls) -> DeviceCapability | None: - return DeviceCapability(10, 0) - - @classmethod - def get_max_compute_capability(cls) -> DeviceCapability | None: - return DeviceCapability(10, 3) + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability.major == 10 class SM100Workspace: diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 5c200435b8f2..4ae3d7a50764 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -69,12 +69,8 @@ def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: return ["auto"] @classmethod - def get_min_compute_capability(cls) -> DeviceCapability | None: - return DeviceCapability(9, 0) - - @classmethod - def get_max_compute_capability(cls) -> DeviceCapability | None: - return DeviceCapability(9, 0) + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability.major == 9 @classmethod def supports_combination( diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index fe0e88bd8266..4e626f064322 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ 
-63,12 +63,8 @@ def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: return ["auto", "fp8", "fp8_e4m3"] @classmethod - def get_min_compute_capability(cls) -> DeviceCapability | None: - return DeviceCapability(10, 0) - - @classmethod - def get_max_compute_capability(cls) -> DeviceCapability | None: - return DeviceCapability(10, 3) + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability.major == 10 @classmethod def get_required_kv_cache_layout( diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 5476dfc99d30..25c02958e505 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -67,12 +67,8 @@ def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: return ["auto", "fp8", "fp8_e4m3"] @classmethod - def get_min_compute_capability(cls) -> DeviceCapability | None: - return DeviceCapability(9, 0) - - @classmethod - def get_max_compute_capability(cls) -> DeviceCapability | None: - return DeviceCapability(10, 3) + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability.major in [9, 10] @classmethod def supports_combination( diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index aa0d6972635d..eb941c780550 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -89,12 +89,8 @@ def is_sparse(cls) -> bool: return True @classmethod - def get_min_compute_capability(cls) -> DeviceCapability | None: - return DeviceCapability(9, 0) - - @classmethod - def get_max_compute_capability(cls) -> DeviceCapability | None: - return DeviceCapability(10, 3) + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability.major in [9, 10] @staticmethod def get_kv_cache_shape( diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 5cc78eef2881..f66a8e5f2723 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -47,12 +47,8 @@ def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: return ["auto"] @classmethod - def get_min_compute_capability(cls) -> DeviceCapability | None: - return None - - @classmethod - def get_max_compute_capability(cls) -> DeviceCapability | None: - return None + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return True class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 31410b47509f..1ca6c8a15159 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -204,12 +204,8 @@ def supports_sink(cls) -> bool: return True @classmethod - def get_min_compute_capability(cls) -> DeviceCapability | None: - return None - - @classmethod - def get_max_compute_capability(cls) -> DeviceCapability | None: - return None + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return True class TritonAttentionImpl(AttentionImpl): From 6e9d1f16a06d92024bb0a3b7e67ec96b296817cb Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Fri, 31 Oct 2025 11:24:21 -0400 Subject: [PATCH 72/84] fix pre-commit Signed-off-by: Matthew Bonanni --- vllm/attention/selector.py | 15 +++++++++++++-- vllm/platforms/cuda.py | 6 +++++- vllm/platforms/interface.py | 
3 ++- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 0fe9e844c2e3..d41d9b5e309d 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -5,12 +5,14 @@ from collections.abc import Generator from contextlib import contextmanager from functools import cache +from typing import cast, get_args import torch import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.utils import STR_BACKEND_ENV_VAR from vllm.utils.import_utils import resolve_obj_by_qualname @@ -79,10 +81,19 @@ def get_attn_backend( # value to be returned from the cache if the value changes between calls. # To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the # private function. + + # Validate kv_cache_dtype is a valid CacheDType value + if kv_cache_dtype is not None: + valid_cache_dtypes = get_args(CacheDType) + assert kv_cache_dtype in valid_cache_dtypes, ( + f"Invalid kv_cache_dtype: {kv_cache_dtype}. " + f"Valid values are: {valid_cache_dtypes}" + ) + return _cached_get_attn_backend( head_size=head_size, dtype=dtype, - kv_cache_dtype=kv_cache_dtype, + kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype), block_size=block_size, use_v1=envs.VLLM_USE_V1, use_mla=use_mla, @@ -95,7 +106,7 @@ def get_attn_backend( def _cached_get_attn_backend( head_size: int, dtype: torch.dtype, - kv_cache_dtype: str | None, + kv_cache_dtype: CacheDType | None, block_size: int | None, use_v1: bool = False, use_mla: bool = False, diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index b685e7b3b028..b825b67efa5f 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -24,8 +24,11 @@ if TYPE_CHECKING: from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig + from vllm.config.cache import CacheDType else: AttentionBackendEnum = None + VllmConfig = None + CacheDType = None logger = init_logger(__name__) @@ -338,7 +341,7 @@ def get_attn_backend_cls( selected_backend: "AttentionBackendEnum", head_size: int, dtype: torch.dtype, - kv_cache_dtype: str | None, + kv_cache_dtype: "CacheDType | None", block_size: int | None, use_v1: bool, use_mla: bool, @@ -352,6 +355,7 @@ def get_attn_backend_cls( ) device_capability = cls.get_device_capability() + assert device_capability is not None # First try checking just the selected backend, if there is one. 
if selected_backend is not None: diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index fdbe739421a4..d2a8f997ad12 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -19,6 +19,7 @@ from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig + from vllm.config.cache import CacheDType from vllm.inputs import ProcessorInputs, PromptType from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -212,7 +213,7 @@ def get_attn_backend_cls( selected_backend: "AttentionBackendEnum", head_size: int, dtype: torch.dtype, - kv_cache_dtype: str | None, + kv_cache_dtype: "CacheDType | None", block_size: int, use_v1: bool, use_mla: bool, From d3cdda7f2dd53dcf084e7b8a97e3a6c6bf348181 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Fri, 31 Oct 2025 11:42:06 -0400 Subject: [PATCH 73/84] change methods to properties Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 35 ++++++------------- vllm/v1/attention/backends/cpu_attn.py | 11 +++--- vllm/v1/attention/backends/flash_attn.py | 11 ++---- vllm/v1/attention/backends/flashinfer.py | 23 +++++------- vllm/v1/attention/backends/flex_attention.py | 15 ++++---- vllm/v1/attention/backends/mla/cutlass_mla.py | 20 +++++------ .../attention/backends/mla/flashattn_mla.py | 16 +++------ .../attention/backends/mla/flashinfer_mla.py | 23 ++++++------ vllm/v1/attention/backends/mla/flashmla.py | 20 +++++------ .../attention/backends/mla/flashmla_sparse.py | 15 ++------ vllm/v1/attention/backends/mla/indexer.py | 6 ++-- vllm/v1/attention/backends/mla/triton_mla.py | 12 +++---- vllm/v1/attention/backends/rocm_aiter_fa.py | 11 ++---- vllm/v1/attention/backends/rocm_attn.py | 5 +-- vllm/v1/attention/backends/tree_attn.py | 12 ++----- vllm/v1/attention/backends/triton_attn.py | 24 ++++++------- vllm/v1/attention/backends/xformers.py | 12 ++----- vllm/v1/worker/gpu_model_runner.py | 3 +- 18 files changed, 98 insertions(+), 176 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index f517296f62c0..d1bc3d9f4608 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast, get_args +from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, cast, get_args import torch @@ -45,6 +45,9 @@ class AttentionBackend(ABC): # calling the custom op. When piecewise cudagraph is enabled, this # makes sure the output tensor is allocated inside the cudagraph. 
accept_output_buffer: bool = False + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(1)] + supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"] @staticmethod @abstractmethod @@ -61,10 +64,6 @@ def get_impl_cls() -> type["AttentionImpl"]: def get_metadata_cls() -> type["AttentionMetadata"]: raise NotImplementedError - @classmethod - def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: - return [MultipleOf(1)] - @classmethod def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": return cls.get_metadata_cls()(*args, **kwargs) @@ -102,26 +101,16 @@ def supports_head_size(cls, head_size: int) -> bool: supported_head_sizes = cls.get_supported_head_sizes() return (not supported_head_sizes) or head_size in supported_head_sizes - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - @classmethod def supports_dtype(cls, dtype: torch.dtype) -> bool: - supported_dtypes = cls.get_supported_dtypes() - return (not supported_dtypes) or dtype in supported_dtypes - - @classmethod - def get_supported_kv_cache_dtypes(cls) -> list["CacheDType"]: - return ["auto"] + return dtype in cls.supported_dtypes @classmethod def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool: if kv_cache_dtype is None: return True - supported_kv_cache_dtypes = cls.get_supported_kv_cache_dtypes() - return (not supported_kv_cache_dtypes) or ( - kv_cache_dtype in supported_kv_cache_dtypes + return (not cls.supported_kv_cache_dtypes) or ( + kv_cache_dtype in cls.supported_kv_cache_dtypes ) @classmethod @@ -135,11 +124,10 @@ def supports_block_size(cls, block_size: int | None) -> bool: if block_size not in valid_sizes: return False - supported_block_sizes = cls.get_supported_kernel_block_sizes() - if not supported_block_sizes: + if not cls.supported_kernel_block_sizes: return True - for supported_size in supported_block_sizes: + for supported_size in cls.supported_kernel_block_sizes: is_multiple_of = ( isinstance(supported_size, MultipleOf) and block_size % supported_size.base == 0 @@ -155,14 +143,13 @@ def supports_block_size(cls, block_size: int | None) -> bool: def get_default_block_size(cls) -> "BlockSize": from vllm.config.cache import BlockSize - supported_block_sizes = cls.get_supported_kernel_block_sizes() - if not supported_block_sizes: + if not cls.supported_kernel_block_sizes: raise ValueError( f"Fallback failed, no explicitly supported block sizes for " f"backend {cls.get_name()}" ) - block_size = supported_block_sizes[0] + block_size = cls.supported_kernel_block_sizes[0] if isinstance(block_size, MultipleOf): block_size = block_size.base diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 0d54417119e9..a3960e686a44 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Optional +from typing import ClassVar, Optional import numpy as np import torch @@ -40,16 +40,17 @@ class TorchSDPABackend(AttentionBackend): accept_output_buffer: bool = False + supported_dtypes: ClassVar[list[torch.dtype]] = [ + torch.float16, + torch.bfloat16, + torch.float32, + ] @classmethod def get_supported_head_sizes(cls) -> list[int]: attn_impl = 
_get_paged_attn_impl() return attn_impl.get_supported_head_sizes() - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16, torch.float32] - @staticmethod def get_name() -> str: return "TORCH_SDPA" diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 0c3f84f411b0..657a538ed587 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -3,6 +3,7 @@ """Attention layer with FlashAttention.""" from dataclasses import dataclass +from typing import ClassVar import numpy as np import torch @@ -53,6 +54,8 @@ class FlashAttentionBackend(AttentionBackend): accept_output_buffer: bool = True + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)] @staticmethod def get_name() -> str: @@ -106,14 +109,6 @@ def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype: def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] - @classmethod - def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: - return [MultipleOf(16)] - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - @classmethod def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: if kv_cache_dtype is None: diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 5b622147b9b2..05d1aa18111e 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -161,6 +161,14 @@ def trtllm_prefill_attn_kvfp8_dequant( class FlashInferBackend(AttentionBackend): accept_output_buffer: bool = True + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "fp8", + "fp8_e4m3", + "fp8_e5m2", + ] @staticmethod def get_name() -> str: @@ -215,21 +223,6 @@ def get_supported_head_sizes(cls) -> list[int]: # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 return [64, 128, 256] - @classmethod - def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: - # Note: Not sure for all platforms, - # but on Blackwell, only support a page size of - # 16, 32, 64 - return [16, 32, 64] - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - - @classmethod - def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: - return ["auto", "fp8", "fp8_e4m3", "fp8_e5m2"] - @classmethod def supports_compute_capability(cls, capability: DeviceCapability) -> bool: return capability >= DeviceCapability(7, 5) and capability <= DeviceCapability( diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 27ddcc0c1916..a28ac50b06e7 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -4,6 +4,7 @@ import math from dataclasses import dataclass +from typing import ClassVar import torch import torch._dynamo.decorators @@ -73,6 +74,12 @@ def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int): class FlexAttentionBackend(AttentionBackend): accept_output_buffer: bool = True + supported_dtypes: 
ClassVar[list[torch.dtype]] = [ + torch.float16, + torch.bfloat16, + torch.float32, + ] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"] @staticmethod def get_name() -> str: @@ -108,14 +115,6 @@ def use_cascade_attention(*args, **kwargs) -> bool: def get_supported_head_sizes(cls) -> list[int]: return [] - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16, torch.float32] - - @classmethod - def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: - return ["auto"] - # @torch.compile(fullgraph=True, mode="reduce-overhead") def physical_to_logical_mapping( diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index 51199ea2c68c..0a10ce74cd1d 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -35,6 +35,14 @@ class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]): class CutlassMLABackend(MLACommonBackend): + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [128] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "fp8", + "fp8_e4m3", + ] + @staticmethod def get_name() -> str: return "CUTLASS_MLA" @@ -47,18 +55,6 @@ def get_impl_cls() -> type["CutlassMLAImpl"]: def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]: return CutlassMLAMetadataBuilder - @classmethod - def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: - return [128] - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - - @classmethod - def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: - return ["auto", "fp8", "fp8_e4m3"] - @classmethod def supports_compute_capability(cls, capability: DeviceCapability) -> bool: return capability.major == 10 diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 4ae3d7a50764..f85569d5cc2b 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -40,6 +40,10 @@ class FlashAttnMLABackend(MLACommonBackend): + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"] + @staticmethod def get_name() -> str: return "FLASH_ATTN_MLA" @@ -56,18 +60,6 @@ def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]: def get_impl_cls() -> type["FlashAttnMLAImpl"]: return FlashAttnMLAImpl - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - - @classmethod - def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: - return [MultipleOf(16)] - - @classmethod - def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: - return ["auto"] - @classmethod def supports_compute_capability(cls, capability: DeviceCapability) -> bool: return capability.major == 9 diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index 4e626f064322..5744f59b3c60 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -34,6 +34,17 @@ class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]): class 
FlashInferMLABackend(MLACommonBackend): + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [ + 32, + MultipleOf(64), + ] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "fp8", + "fp8_e4m3", + ] + @staticmethod def get_name() -> str: return "FLASHINFER_MLA" @@ -46,22 +57,10 @@ def get_impl_cls() -> type["FlashInferMLAImpl"]: def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]: return FlashInferMLAMetadataBuilder - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - - @classmethod - def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: - return [32, MultipleOf(64)] - @classmethod def get_default_block_size(cls) -> BlockSize: return 64 - @classmethod - def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: - return ["auto", "fp8", "fp8_e4m3"] - @classmethod def supports_compute_capability(cls, capability: DeviceCapability) -> bool: return capability.major == 10 diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 25c02958e505..dfa726a4e517 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -38,6 +38,14 @@ class FlashMLABackend(MLACommonBackend): + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "fp8", + "fp8_e4m3", + ] + @staticmethod def get_name() -> str: return "FLASHMLA" @@ -54,18 +62,6 @@ def get_builder_cls() -> type["FlashMLAMetadataBuilder"]: def get_impl_cls() -> type["FlashMLAImpl"]: return FlashMLAImpl - @classmethod - def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: - return [64] - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - - @classmethod - def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: - return ["auto", "fp8", "fp8_e4m3"] - @classmethod def supports_compute_capability(cls, capability: DeviceCapability) -> bool: return capability.major in [9, 10] diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index eb941c780550..d76b63159475 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -55,6 +55,9 @@ class FlashMLASparseBackend(AttentionBackend): accept_output_buffer: bool = True + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"] @staticmethod def get_name() -> str: @@ -72,18 +75,6 @@ def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]: def get_impl_cls() -> type["FlashMLASparseImpl"]: return FlashMLASparseImpl - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.bfloat16] - - @classmethod - def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: - return [64] - - @classmethod - def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: - return ["auto", "fp8_ds_mla"] - @classmethod def is_sparse(cls) -> bool: return True diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 
661787c3aa58..c431edae4176 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -24,6 +24,8 @@ class DeepseekV32IndexerBackend(AttentionBackend): + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64] + @staticmethod def get_metadata_cls() -> type["AttentionMetadata"]: return DeepseekV32IndexerMetadata @@ -51,10 +53,6 @@ def get_kv_cache_shape( def get_kv_cache_stride_order() -> tuple[int, ...]: return (0, 1, 2) - @classmethod - def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: - return [64] - @dataclass class DeepseekV32IndexerPrefillChunkMetadata: diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index f66a8e5f2723..0149639e8c0b 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import ClassVar import torch @@ -30,6 +31,9 @@ class TritonMLABackend(MLACommonBackend): + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"] + @staticmethod def get_name() -> str: return "TRITON_MLA" @@ -38,14 +42,6 @@ def get_name() -> str: def get_impl_cls() -> type["TritonMLAImpl"]: return TritonMLAImpl - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - - @classmethod - def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: - return ["auto"] - @classmethod def supports_compute_capability(cls, capability: DeviceCapability) -> bool: return True diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 3e7a9b7f6e36..90328461c748 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -3,6 +3,7 @@ """Attention layer with AiterFlashAttention.""" from dataclasses import dataclass +from typing import ClassVar import torch @@ -350,10 +351,8 @@ def use_cascade_attention(self, *args, **kwargs) -> bool: class AiterFlashAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)] @classmethod def get_supported_head_sizes(cls) -> list[int]: @@ -387,10 +386,6 @@ def get_kv_cache_shape( raise ValueError("Block size must be a multiple of 16.") return (2, num_blocks, block_size, num_kv_heads, head_size) - @classmethod - def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: - return [MultipleOf(16)] - class AiterFlashAttentionImpl(AttentionImpl): def __init__( diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 2279496636ba..be925a232372 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -153,10 +153,7 @@ def build( class RocmAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] @classmethod def get_supported_head_sizes(cls) -> 
list[int]: diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index 8df5f5b04617..73872dd6cb07 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -4,7 +4,7 @@ import ast from dataclasses import dataclass -from typing import Optional +from typing import ClassVar, Optional import torch @@ -31,19 +31,13 @@ class TreeAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)] @classmethod def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] - @classmethod - def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: - return [MultipleOf(16)] - @staticmethod def get_name() -> str: return "TREE_ATTN" diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 1ca6c8a15159..42721bcec094 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -150,6 +150,18 @@ def build( class TritonAttentionBackend(AttentionBackend): accept_output_buffer: bool = True + supported_dtypes: ClassVar[list[torch.dtype]] = [ + torch.float16, + torch.bfloat16, + torch.float32, + ] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "fp8", + "fp8_e4m3", + "fp8_e5m2", + ] @staticmethod def get_name() -> str: @@ -187,18 +199,6 @@ def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]: def supports_head_size(cls, head_size: int) -> bool: return head_size >= 32 - @classmethod - def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: - return [MultipleOf(16)] - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16, torch.float32] - - @classmethod - def get_supported_kv_cache_dtypes(cls) -> list[CacheDType]: - return ["auto", "fp8", "fp8_e4m3", "fp8_e5m2"] - @classmethod def supports_sink(cls) -> bool: return True diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index 47847d6fa6f7..112e1b30de70 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -3,7 +3,7 @@ """Attention layer with XFormersAttention.""" from dataclasses import dataclass -from typing import Optional +from typing import ClassVar, Optional import torch @@ -42,10 +42,8 @@ class XFormersAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)] @classmethod def get_supported_head_sizes(cls) -> list[int]: @@ -81,10 +79,6 @@ def get_supported_head_sizes(cls) -> list[int]: 256, ] - @classmethod - def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]: - return [MultipleOf(16)] - @staticmethod def get_name() -> str: return "XFORMERS" diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 13b4526ca2d3..a95eed7bcf19 100644 --- 
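The backends above now declare their support matrix as ClassVar attributes rather than classmethod getters. A minimal sketch of the pattern, assuming the base-class helpers simply consult these lists; BackendSketch and ExampleBackend are illustrative names, not vLLM classes, and the real helpers may differ in detail.

from typing import ClassVar

import torch


class BackendSketch:
    # Empty lists mean "no declared restriction" in this sketch.
    supported_dtypes: ClassVar[list[torch.dtype]] = []
    supported_kv_cache_dtypes: ClassVar[list[str]] = []

    @classmethod
    def supports_dtype(cls, dtype: torch.dtype) -> bool:
        return not cls.supported_dtypes or dtype in cls.supported_dtypes

    @classmethod
    def supports_kv_cache_dtype(cls, kv_cache_dtype: str) -> bool:
        return (
            not cls.supported_kv_cache_dtypes
            or kv_cache_dtype in cls.supported_kv_cache_dtypes
        )


class ExampleBackend(BackendSketch):
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kv_cache_dtypes: ClassVar[list[str]] = ["auto", "fp8"]


assert ExampleBackend.supports_dtype(torch.bfloat16)
assert not ExampleBackend.supports_dtype(torch.float32)
assert ExampleBackend.supports_kv_cache_dtype("fp8")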
a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4171,10 +4171,9 @@ def _find_compatible_block_sizes( Raises: ValueError: If no compatible block size found """ - supported_block_size = backend_cls.get_supported_kernel_block_sizes() compatible_sizes = [] - for block_size in supported_block_size: + for block_size in backend_cls.supported_kernel_block_sizes: if isinstance(block_size, int): if kv_manager_block_size % block_size == 0: compatible_sizes.append(block_size) From 925069ca027db3031d60d3614816437cfac9e86e Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Fri, 31 Oct 2025 11:47:51 -0400 Subject: [PATCH 74/84] device_capability not None Signed-off-by: Matthew Bonanni --- vllm/platforms/cuda.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index b825b67efa5f..994dd0d18b04 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -45,17 +45,13 @@ @cache def _get_backend_priorities( use_mla: bool, - device_capability: DeviceCapability | None = None, + device_capability: DeviceCapability, ) -> dict[AttentionBackendEnum, int]: """Get backend priorities with lazy import to avoid circular dependency.""" from vllm.attention.backends.registry import AttentionBackendEnum if use_mla: - if ( - device_capability - and device_capability >= DeviceCapability(10, 0) - and device_capability < DeviceCapability(11, 0) - ): + if device_capability.major == 10: return { AttentionBackendEnum.CUTLASS_MLA: 0, AttentionBackendEnum.FLASHINFER_MLA: 1, @@ -72,11 +68,7 @@ def _get_backend_priorities( AttentionBackendEnum.TRITON_MLA: 3, } else: - if ( - device_capability - and device_capability >= DeviceCapability(10, 0) - and device_capability < DeviceCapability(11, 0) - ): + if device_capability.major == 10: return { AttentionBackendEnum.FLASHINFER: 0, AttentionBackendEnum.FLASH_ATTN: 1, From a0b56c5af5013247757bc32999bdc904f1e56a9a Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Fri, 31 Oct 2025 11:53:28 -0400 Subject: [PATCH 75/84] query device_capability inside get_required_kv_cache_layout Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 8 +------- vllm/attention/selector.py | 3 +-- vllm/v1/attention/backends/flashinfer.py | 13 +++++-------- vllm/v1/attention/backends/mla/flashinfer_mla.py | 4 +--- 4 files changed, 8 insertions(+), 20 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index d1bc3d9f4608..003c3cb9b91e 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -242,13 +242,7 @@ def validate_configuration( return invalid_reasons @classmethod - def get_required_kv_cache_layout( - cls, capability: "DeviceCapability" - ) -> "KVCacheLayoutType | None": - """ - Some backends require a specific kv cache layout. - This function returns the required layout if any. 
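The model runner hunk above iterates backend_cls.supported_kernel_block_sizes directly when reconciling the KV-cache manager's block size with what a kernel accepts. A hedged sketch of that reconciliation follows; only the exact-integer branch is visible in the diff, so the MultipleOf branch and the error handling here are assumptions.

from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    base: int


def find_compatible_block_sizes(
    kv_manager_block_size: int, supported: list[int | MultipleOf]
) -> list[int]:
    compatible: list[int] = []
    for entry in supported:
        if isinstance(entry, int):
            # An exact kernel block size is usable only if it evenly tiles
            # the block size chosen by the KV-cache manager.
            if kv_manager_block_size % entry == 0:
                compatible.append(entry)
        elif isinstance(entry, MultipleOf):
            # Assumed: any multiple of `base` works, so the manager block size
            # itself is a valid candidate when it is divisible by the base.
            if kv_manager_block_size % entry.base == 0:
                compatible.append(kv_manager_block_size)
    if not compatible:
        raise ValueError(
            f"no kernel block size is compatible with {kv_manager_block_size}"
        )
    return compatible


print(find_compatible_block_sizes(128, [64]))              # [64]
print(find_compatible_block_sizes(128, [MultipleOf(16)]))  # [128]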
- """ + def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None": return None diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index d41d9b5e309d..2cf0111348e4 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -166,8 +166,7 @@ def _cached_get_attn_backend( backend = resolve_obj_by_qualname(attention_cls) # Adjust kv cache layout if the selected backend requires a specific one - device_capability = current_platform.get_device_capability() - required_layout = backend.get_required_kv_cache_layout(device_capability) + required_layout = backend.get_required_kv_cache_layout() if required_layout is not None: from vllm.v1.attention.backends.utils import set_kv_cache_layout diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 05d1aa18111e..cfe00ee0c781 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -230,14 +230,11 @@ def supports_compute_capability(cls, capability: DeviceCapability) -> bool: ) @classmethod - def get_required_kv_cache_layout( - cls, capability: DeviceCapability - ) -> KVCacheLayoutType | None: - if ( - capability is not None - and capability >= DeviceCapability(10, 0) - and capability <= DeviceCapability(10, 3) - ): + def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None: + from vllm.platforms import current_platform + + capability = current_platform.get_device_capability() + if capability is not None and capability.major == 10: return "HND" return None diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index 5744f59b3c60..1abe40386ebb 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -66,9 +66,7 @@ def supports_compute_capability(cls, capability: DeviceCapability) -> bool: return capability.major == 10 @classmethod - def get_required_kv_cache_layout( - cls, capability: DeviceCapability - ) -> "KVCacheLayoutType | None": + def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None": return "HND" From fff453a41eda2e5e4d19c8b03515dc5d97163739 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Fri, 31 Oct 2025 12:17:36 -0400 Subject: [PATCH 76/84] Update vllm/attention/backends/abstract.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Luka Govedič Signed-off-by: Matthew Bonanni --- vllm/attention/backends/abstract.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 003c3cb9b91e..ebf12c302b4c 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -132,10 +132,10 @@ def supports_block_size(cls, block_size: int | None) -> bool: isinstance(supported_size, MultipleOf) and block_size % supported_size.base == 0 ) - is_int_divisor = ( - isinstance(supported_size, int) and block_size % supported_size == 0 + is_int_equal = ( + isinstance(supported_size, int) and block_size == supported_size ) - if is_multiple_of or is_int_divisor: + if is_multiple_of or is_int_equal: return True return False From 530f3569faf6edcc429ad24cc9da4c70b0c64a0c Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Fri, 31 Oct 2025 13:33:48 -0400 Subject: [PATCH 77/84] class_path always None in decorator Signed-off-by: Matthew Bonanni --- vllm/attention/backends/registry.py | 3 +-- 1 file changed, 1 
insertion(+), 2 deletions(-) diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index 19d51f57a168..da3ee35ad3d6 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -144,8 +144,7 @@ class MyCustomBackend: """ def decorator(cls): - path = class_path or f"{cls.__module__}.{cls.__qualname__}" - _OVERRIDES[backend] = path + _OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}" return cls if class_path is not None: From 933ee5f73eaf4d4fd24fc40e5bc7d83668c7180b Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Fri, 31 Oct 2025 13:35:40 -0400 Subject: [PATCH 78/84] type hint for value Signed-off-by: Matthew Bonanni --- vllm/config/multimodal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index 3bfc20100123..9348c1b2af8c 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -168,7 +168,7 @@ def _validate_limit_per_prompt( @field_validator("mm_encoder_attn_backend", mode="before") @classmethod def _validate_mm_encoder_attn_backend( - cls, value: object + cls, value: str | AttentionBackendEnum | None ) -> AttentionBackendEnum | None: if value is None or isinstance(value, AttentionBackendEnum): return value From 255edc9d72abff3add377215360834750abca555 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Fri, 31 Oct 2025 13:37:17 -0400 Subject: [PATCH 79/84] restore comment Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/flashinfer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index cfe00ee0c781..28cefcad7d87 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -162,6 +162,9 @@ def trtllm_prefill_attn_kvfp8_dequant( class FlashInferBackend(AttentionBackend): accept_output_buffer: bool = True supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + # Note: Not sure for all platforms, + # but on Blackwell, only support a page size of + # 16, 32, 64 supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64] supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ "auto", From c9d62f8f138206ff6d67cdf2421de4bf58f74aae Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Fri, 31 Oct 2025 14:09:27 -0400 Subject: [PATCH 80/84] fix docs Signed-off-by: Matthew Bonanni --- vllm/attention/backends/registry.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index da3ee35ad3d6..4fb8c581890f 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -3,6 +3,7 @@ """Attention backend registry""" import enum +from collections.abc import Callable from typing import TYPE_CHECKING, cast from vllm.utils.import_utils import resolve_obj_by_qualname @@ -114,7 +115,9 @@ def clear_override(self) -> None: _OVERRIDES: dict[AttentionBackendEnum, str] = {} -def register_backend(backend: AttentionBackendEnum, class_path: str | None = None): +def register_backend( + backend: AttentionBackendEnum, class_path: str | None = None +) -> Callable[[type], type]: """Register or override a backend implementation. 
Args: @@ -143,7 +146,7 @@ class MyCustomBackend: ) """ - def decorator(cls): + def decorator(cls: type) -> type: _OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}" return cls From f6a5a3257fb43164a4e5018e9dc324e75162f427 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Fri, 31 Oct 2025 15:44:42 -0400 Subject: [PATCH 81/84] add FLASHMLA_SPARSE to priority list Signed-off-by: Matthew Bonanni --- vllm/platforms/cuda.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 994dd0d18b04..2837d5e765f0 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -66,6 +66,7 @@ def _get_backend_priorities( AttentionBackendEnum.FLASH_ATTN_MLA: 1, AttentionBackendEnum.FLASHINFER_MLA: 2, AttentionBackendEnum.TRITON_MLA: 3, + AttentionBackendEnum.FLASHMLA_SPARSE: 4, } else: if device_capability.major == 10: From 0435eca4502956cc9dc1496fb0c1944175576e50 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Sat, 1 Nov 2025 21:47:27 -0400 Subject: [PATCH 82/84] fix test Signed-off-by: Matthew Bonanni --- tests/v1/worker/test_gpu_model_runner.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 8ad0014ac587..23d1f1240be7 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -187,9 +187,7 @@ def _make_mock_backend_for_kernel_block_size( supported_sizes: list[int | MultipleOf], ): class _MockBackend: - @staticmethod - def get_supported_kernel_block_size(): - return supported_sizes + supported_kernel_block_sizes = supported_sizes return _MockBackend() From a098d823373ebb45af809599419780797d0f16c5 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Sat, 1 Nov 2025 21:50:23 -0400 Subject: [PATCH 83/84] fix flashmla_sparse Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/flashmla_sparse.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index d76b63159475..c9fce9c252c7 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -75,6 +75,14 @@ def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]: def get_impl_cls() -> type["FlashMLASparseImpl"]: return FlashMLASparseImpl + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [576] + + @classmethod + def is_mla(cls) -> bool: + return True + @classmethod def is_sparse(cls) -> bool: return True From 4452f5f41003972013ff29972275252bcd64a230 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Sat, 1 Nov 2025 23:36:21 -0400 Subject: [PATCH 84/84] fix pre-commit Signed-off-by: Matthew Bonanni --- vllm/model_executor/models/qwen2_5_vl.py | 2 +- vllm/model_executor/models/qwen2_vl.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index a5b70b713c41..49bae869b834 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -863,7 +863,7 @@ def compute_attn_mask_seqlen( seqlens = torch.zeros(1, device=cu_seqlens.device) if self.attn_backend in { AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.ROCM_AITER_FA + AttentionBackendEnum.ROCM_AITER_FA, }: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() elif self.attn_backend == AttentionBackendEnum.XFORMERS: diff --git 
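For context, a hypothetical usage sketch of the two registration forms the registry patch touches; MyFlashAttnBackend and my_pkg.attention are illustrative names, and the non-decorator form is inferred from the class_path parameter rather than shown in the diff.

from vllm.attention.backends.registry import AttentionBackendEnum, register_backend


# Decorator form: the override path is now always derived from the decorated
# class, so a class_path argument can no longer shadow it.
@register_backend(AttentionBackendEnum.FLASH_ATTN)
class MyFlashAttnBackend:
    ...


# Non-decorator form (assumed from the class_path parameter): register an
# implementation by its qualified name without importing it here.
register_backend(
    AttentionBackendEnum.TRITON_MLA,
    class_path="my_pkg.attention.MyTritonMLABackend",
)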
a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 319c4f11eba0..c198b8a1cf68 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -792,7 +792,7 @@ def compute_attn_mask_seqlen( max_seqlen, seqlens = None, None if self.attn_backend in { AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.ROCM_AITER_FA + AttentionBackendEnum.ROCM_AITER_FA, }: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() elif self.attn_backend == AttentionBackendEnum.XFORMERS:
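The Qwen2-VL hunks above only reflow the backend membership set, but the helper they touch is a convenient place to show how max_seqlen falls out of the cumulative sequence lengths; a tiny worked example:

import torch

# Consecutive differences of cu_seqlens are the per-sequence lengths; their
# maximum is the max_seqlen that flash-attention-style kernels expect.
cu_seqlens = torch.tensor([0, 5, 12, 20])   # three sequences of lengths 5, 7, 8
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]  # tensor([5, 7, 8])
max_seqlen = seqlens.max().item()           # 8
print(seqlens.tolist(), max_seqlen)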