diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 697beed91869..9275d70fd86a 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -142,6 +142,17 @@ def supports_sink(cls) -> bool:
     def is_sparse(cls) -> bool:
         return False
 
+    @classmethod
+    def supports_attn_type(cls, attn_type: str) -> bool:
+        """Check if backend supports a given attention type.
+
+        By default, only supports decoder attention.
+        Backends should override this to support other attention types.
+        """
+        from vllm.attention import AttentionType
+
+        return attn_type == AttentionType.DECODER
+
     @classmethod
     def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
         return True
@@ -171,6 +182,7 @@ def validate_configuration(
         has_sink: bool,
         use_sparse: bool,
         device_capability: "DeviceCapability",
+        attn_type: str,
     ) -> list[str]:
         invalid_reasons = []
         if not cls.supports_head_size(head_size):
@@ -195,6 +207,8 @@ def validate_configuration(
             invalid_reasons.append("non-sparse not supported")
         if not cls.supports_compute_capability(device_capability):
             invalid_reasons.append("compute capability not supported")
+        if not cls.supports_attn_type(attn_type):
+            invalid_reasons.append(f"attention type {attn_type} not supported")
         combination_reason = cls.supports_combination(
             head_size,
             dtype,
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 487bba76babf..37f9a4b383ce 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -291,6 +291,7 @@ def __init__(
                 block_size,
                 use_mla=False,
                 has_sink=self.has_sink,
+                attn_type=attn_type,
             )
         else:
             self.attn_backend = attn_backend
diff --git a/vllm/attention/layers/encoder_only_attention.py b/vllm/attention/layers/encoder_only_attention.py
index 4929bbf5efc7..5e99c9901003 100644
--- a/vllm/attention/layers/encoder_only_attention.py
+++ b/vllm/attention/layers/encoder_only_attention.py
@@ -74,7 +74,11 @@ def __init__(
             block_size = 16
 
         underlying_attn_backend = get_attn_backend(
-            head_size, dtype, kv_cache_dtype, block_size
+            head_size,
+            dtype,
+            kv_cache_dtype,
+            block_size,
+            attn_type=AttentionType.ENCODER_ONLY,
         )
 
         attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 262cdf0e575b..1a092db9ce37 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -76,6 +76,7 @@ def get_attn_backend(
     use_mla: bool = False,
     has_sink: bool = False,
     use_sparse: bool = False,
+    attn_type: str | None = None,
 ) -> type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""
 
@@ -94,6 +95,7 @@ def get_attn_backend(
         use_mla=use_mla,
         has_sink=has_sink,
         use_sparse=use_sparse,
+        attn_type=attn_type,
     )
 
 
@@ -106,6 +108,7 @@ def _cached_get_attn_backend(
     use_mla: bool = False,
     has_sink: bool = False,
     use_sparse: bool = False,
+    attn_type: str | None = None,
 ) -> type[AttentionBackend]:
     # Check whether a particular choice of backend was
     # previously forced.
@@ -159,6 +162,7 @@ def _cached_get_attn_backend(
             use_mla,
             has_sink,
             use_sparse,
+            attn_type,
         )
     else:
         attention_cls = current_platform.get_attn_backend_cls(
@@ -170,6 +174,7 @@ def _cached_get_attn_backend(
             use_mla,
             has_sink,
             use_sparse,
+            attn_type,
         )
     if not attention_cls:
         raise ValueError(
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index 8b3b8d4cb44f..cf954768689f 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -134,6 +134,7 @@ def get_attn_backend_cls(
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
+        attn_type: str | None = None,
     ) -> str:
         from vllm.attention.backends.registry import AttentionBackendEnum
 
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index ebcc290a64cd..2e4dd8bb808b 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -298,6 +298,7 @@ def get_valid_backends(
         has_sink,
         use_sparse,
         device_capability,
+        attn_type,
     ) -> tuple[
         list[tuple["AttentionBackendEnum", int]],
         dict["AttentionBackendEnum", list[str]],
@@ -318,6 +319,7 @@ def get_valid_backends(
                     has_sink,
                     use_sparse,
                     device_capability,
+                    attn_type,
                 )
             except ImportError:
                 invalid_reasons_i = ["ImportError"]
@@ -339,7 +341,13 @@ def get_attn_backend_cls(
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
+        attn_type: str | None = None,
     ) -> str:
+        from vllm.attention import AttentionType
+
+        if attn_type is None:
+            attn_type = AttentionType.DECODER
+
         device_capability = cls.get_device_capability()
         assert device_capability is not None
 
@@ -356,6 +364,7 @@ def get_attn_backend_cls(
                 has_sink,
                 use_sparse,
                 device_capability,
+                attn_type,
             )
         except ImportError:
             invalid_reasons = ["ImportError"]
@@ -379,6 +388,7 @@ def get_attn_backend_cls(
             has_sink,
             use_sparse,
             device_capability,
+            attn_type,
         )
         reasons_str = (
             "{"
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 12c377384270..0471c20429b1 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -222,6 +222,7 @@ def get_attn_backend_cls(
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
+        attn_type: str | None = None,
     ) -> str:
         """Get the attention backend class of a device."""
         return ""
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index d20dc9e6b067..788f9d69c357 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -216,6 +216,7 @@ def get_attn_backend_cls(
         use_mla,
         has_sink,
         use_sparse,
+        attn_type: str | None = None,
     ) -> str:
         from vllm._aiter_ops import rocm_aiter_ops
         from vllm.attention.backends.registry import AttentionBackendEnum
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 4773fef6829d..b997bb9e6999 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -61,6 +61,7 @@ def get_attn_backend_cls(
         use_mla: bool,
         has_sink,
         use_sparse,
+        attn_type: str | None = None,
     ) -> str:
         from vllm.attention.backends.registry import AttentionBackendEnum
 
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index c629325f76a3..5552e4ca4b2f 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -51,6 +51,7 @@ def get_attn_backend_cls(
         use_mla: bool,
         has_sink: bool,
         use_sparse,
+        attn_type: str | None = None,
     ) -> str:
         from vllm.v1.attention.backends.utils import set_kv_cache_layout
 
diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py
index 674398e19c4c..f1254352c058 100644
--- a/vllm/v1/attention/backends/cpu_attn.py
+++ b/vllm/v1/attention/backends/cpu_attn.py
@@ -48,6 +48,17 @@ def get_supported_head_sizes(cls) -> list[int]:
     def get_name() -> str:
         return "CPU_ATTN"
 
+    @classmethod
+    def supports_attn_type(cls, attn_type: str) -> bool:
+        """CPU attention supports decoder, encoder, and encoder-only attention."""
+        from vllm.attention import AttentionType
+
+        return attn_type in (
+            AttentionType.DECODER,
+            AttentionType.ENCODER,
+            AttentionType.ENCODER_ONLY,
+        )
+
     @staticmethod
     def get_impl_cls() -> type["CPUAttentionBackendImpl"]:
         return CPUAttentionBackendImpl
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index d9bd52d8f980..bfb4a45c2b56 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -66,6 +66,18 @@ class FlashAttentionBackend(AttentionBackend):
     def get_name() -> str:
         return "FLASH_ATTN"
 
+    @classmethod
+    def supports_attn_type(cls, attn_type: str) -> bool:
+        """FlashAttention supports all attention types."""
+        from vllm.attention import AttentionType
+
+        return attn_type in (
+            AttentionType.DECODER,
+            AttentionType.ENCODER,
+            AttentionType.ENCODER_ONLY,
+            AttentionType.ENCODER_DECODER,
+        )
+
     @staticmethod
     def get_impl_cls() -> type["FlashAttentionImpl"]:
         return FlashAttentionImpl
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index e53cd0d8af4f..7768827d26dc 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -84,6 +84,13 @@ class FlexAttentionBackend(AttentionBackend):
     def get_name() -> str:
         return "FLEX_ATTENTION"
 
+    @classmethod
+    def supports_attn_type(cls, attn_type: str) -> bool:
+        """FlexAttention supports both decoder and encoder-only attention."""
+        from vllm.attention import AttentionType
+
+        return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
+
     @staticmethod
     def get_impl_cls() -> type["FlexAttentionImpl"]:
         return FlexAttentionImpl
diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py
index 5fe9c69d3500..bb8d914d1571 100644
--- a/vllm/v1/attention/backends/mla/flashmla_sparse.py
+++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py
@@ -40,14 +40,14 @@
 """
 NOTE: FlashMLA Sparse uses an fp8 cache with the following format
 
-In the "FP8 with scale" format, each token's KV cache is 656 Bytes, 
+In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
 structured as:
 
-- **First 512 bytes:** The "quantized NoPE" part, containing 512 
+- **First 512 bytes:** The "quantized NoPE" part, containing 512
   `float8_e4m3` values.
-- **Next 16 bytes:** Scale factors, containing 4 `float32` values. 
-  The first `float32` is the scale for the first 128 `float8_e4m3` values, 
+- **Next 16 bytes:** Scale factors, containing 4 `float32` values.
+  The first `float32` is the scale for the first 128 `float8_e4m3` values,
   the second for the next 128, and so on.
-- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This 
+- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
   part is not quantized for accuracy.
 """
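
Usage sketch (not part of the diff): with this change, a backend opts into non-decoder attention by overriding `supports_attn_type`, and callers thread the attention type through `get_attn_backend` (as `encoder_only_attention.py` does above). The backend class below is hypothetical and only illustrates the hook added in `vllm/attention/backends/abstract.py`.

    from vllm.attention import AttentionType
    from vllm.attention.backends.abstract import AttentionBackend


    class HypotheticalEncoderCapableBackend(AttentionBackend):
        """Illustrative only; a real backend also implements get_name(), get_impl_cls(), etc."""

        @classmethod
        def supports_attn_type(cls, attn_type: str) -> bool:
            # Without this override, the base class accepts only AttentionType.DECODER,
            # so validate_configuration() would report
            # "attention type ... not supported" for encoder-only layers.
            return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)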