diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 0209c7236278..2a4cac46c066 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -161,15 +161,9 @@ def get_current_memory_usage(cls,
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1,
                              use_mla) -> str:
-        if use_v1:
-            if use_mla:
-                logger.info("Using Triton MLA backend on V1 engine.")
-                return "vllm.v1.attention.backends.triton_mla.TritonMLABackend"
-            else:
-                logger.info("Using Flash Attention backend on V1 engine.")
-                return ("vllm.v1.attention.backends.flash_attn."
-                        "FlashAttentionBackend")
         if use_mla:
+            # TODO(lucas): refactor to be more concise
+            # we should probably consider factoring out V1 here
             if selected_backend == _Backend.FLASHMLA:
                 from vllm.attention.backends.flashmla import (
                     is_flashmla_supported)
@@ -183,11 +177,26 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                         " (currently only supports block size 64).",
                         block_size)
                 else:
-                    logger.info("Using FlashMLA backend.")
-                    return "vllm.attention.backends.flashmla.FlashMLABackend"
-
-            logger.info("Using Triton MLA backend.")
-            return "vllm.attention.backends.triton_mla.TritonMLABackend"
+                    if use_v1:
+                        logger.info("Using FlashMLA backend on V1 engine.")
+                        return ("vllm.v1.attention.backends.mla."
+                                "flashmla.FlashMLABackend")
+                    else:
+                        logger.info("Using FlashMLA backend.")
+                        return ("vllm.attention.backends."
+                                "flashmla.FlashMLABackend")
+
+            if use_v1:
+                logger.info("Using Triton MLA backend on V1 engine.")
+                return ("vllm.v1.attention.backends.mla."
+                        "triton_mla.TritonMLABackend")
+            else:
+                logger.info("Using Triton MLA backend.")
+                return "vllm.attention.backends.triton_mla.TritonMLABackend"
+        if use_v1:
+            logger.info("Using Flash Attention backend on V1 engine.")
+            return ("vllm.v1.attention.backends.flash_attn."
+                    "FlashAttentionBackend")
         if selected_backend == _Backend.FLASHINFER:
             logger.info("Using FlashInfer backend.")
             return "vllm.attention.backends.flashinfer.FlashInferBackend"
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 4af413dff0fa..d81a66e4bcb1 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -34,9 +34,8 @@ class _Backend(enum.Enum):
     TORCH_SDPA = enum.auto()
     OPENVINO = enum.auto()
     FLASHINFER = enum.auto()
-    TRITON_MLA = enum.auto()
-    TRITON_MLA_VLLM_V1 = enum.auto()
-    FLASHMLA = enum.auto()
+    TRITON_MLA = enum.auto()  # Supported by V1
+    FLASHMLA = enum.auto()  # Supported by V1
     HPU_ATTN = enum.auto()
     PALLAS = enum.auto()
     PALLAS_VLLM_V1 = enum.auto()
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 2a742f5ce524..30bce5cc8b68 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -333,13 +333,16 @@ def __post_init__(self):
 T = TypeVar("T", bound=MLACommonMetadata)
 
 
-class MLACommonMetadataBuilder:
+class MLACommonMetadataBuilder(Generic[T]):
     """
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
     """
 
-    def __init__(self, runner: "GPUModelRunner"):
+    def __init__(self,
+                 runner: "GPUModelRunner",
+                 cls: Optional[type[T]] = None):
+        self.cls = cls if cls is not None else MLACommonMetadata
         self.runner = runner
         scheduler_config = runner.scheduler_config
         model_config = runner.model_config
@@ -431,7 +434,7 @@ def reorder_batch(self, input_batch: "InputBatch",
         self._num_prefill_tokens = num_prefill_tokens
 
     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int):
+              common_prefix_len: int) -> T:
         device = self.runner.device
         max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
         query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
@@ -502,7 +505,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
             assert max(context_chunk_seq_tot) <= \
                 self.chunked_prefill_workspace_size
 
-        return MLACommonMetadata(
+        return self.cls(
             input_positions=input_positions,
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
new file mode 100644
index 000000000000..8a7b7b974e36
--- /dev/null
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -0,0 +1,139 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+import torch
+
+from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
+                                         get_mla_metadata,
+                                         is_flashmla_supported)
+from vllm.logger import init_logger
+from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
+                                                   MLACommonImpl,
+                                                   MLACommonMetadata,
+                                                   MLACommonMetadataBuilder)
+
+logger = init_logger(__name__)
+
+
+class FlashMLABackend(MLACommonBackend):
+
+    @staticmethod
+    def get_name() -> str:
+        return "FLASHMLA_VLLM_V1"
+
+    @staticmethod
+    def get_metadata_cls() -> Type["FlashMLAMetadata"]:
+        return FlashMLAMetadata
+
+    @staticmethod
+    def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]:
+        return FlashMLAMetadataBuilder
+
+    @staticmethod
+    def get_impl_cls() -> Type["FlashMLAImpl"]:
+        return FlashMLAImpl
+
+
+@dataclass
+class FlashMLAMetadata(MLACommonMetadata):
+    decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor,
+                                                   torch.Tensor]] = None
+    decode_num_splits: Optional[torch.Tensor] = None
+
+
+class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
+
+    def __init__(self, runner):
+        super().__init__(runner, cls=FlashMLAMetadata)
+
+        self.num_q_heads = self.runner.model_config.get_num_attention_heads(
+            self.runner.parallel_config)
+
+    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
+              common_prefix_len: int):
+        m = super().build(num_reqs, num_actual_tokens, max_query_len,
+                          common_prefix_len)
+
+        if m.num_decode_tokens is not None and m.num_decode_tokens > 0:
+            m.decode_tile_scheduler_metadata, m.decode_num_splits = \
+                get_mla_metadata(
+                m.seq_lens[:m.num_decode_tokens],
+                self.num_q_heads,
+                1,  # MQA for the decode path
+            )
+
+        return m
+
+
+class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
+
+    def __init__(
+            self,
+            num_heads: int,
+            head_size: int,
+            scale: float,
+            num_kv_heads: int,
+            alibi_slopes: Optional[List[float]],
+            sliding_window: Optional[int],
+            kv_cache_dtype: str,
+            blocksparse_params: Optional[Dict[str, Any]],
+            logits_soft_cap: Optional[float],
+            attn_type: str,
+            # MLA Specific Arguments
+            **mla_args) -> None:
+        super().__init__(num_heads, head_size, scale, num_kv_heads,
+                         alibi_slopes, sliding_window, kv_cache_dtype,
+                         blocksparse_params, logits_soft_cap, attn_type,
+                         **mla_args)
+
+        assert is_flashmla_supported(), \
+            "FlashMLA is not supported on this device"
+
+        unsupported_features = [
+            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
+        ]
+        if any(unsupported_features):
+            raise NotImplementedError(
+                "FlashMLAImpl does not support one of the following: "
+                "alibi_slopes, sliding_window, blocksparse_params, "
+                "logits_soft_cap")
+
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "FlashMLAImpl")
+
+    def _forward_decode(
+        self,
+        q_nope: torch.Tensor,
+        q_pe: torch.Tensor,
+        kv_c_and_k_pe_cache: torch.Tensor,
+        attn_metadata: FlashMLAMetadata,
+    ) -> torch.Tensor:
+        assert kv_c_and_k_pe_cache.numel() > 0
+        if self.kv_cache_dtype.startswith("fp8"):
+            raise NotImplementedError("FP8 FlashMLA not yet supported")
+
+        q = torch.cat([q_nope, q_pe], dim=-1)\
+            .unsqueeze(1)  # Add seqlen dim of 1 (decode)
+
+        o, _ = flash_mla_with_kvcache(
+            q=q,
+            k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
+            block_table=attn_metadata.block_table[:attn_metadata.num_decodes,
+                                                  ...],
+            cache_seqlens=attn_metadata.seq_lens[:attn_metadata.
+                                                 num_decode_tokens],
+            head_dim_v=self.kv_lora_rank,
+            tile_scheduler_metadata=attn_metadata.
+            decode_tile_scheduler_metadata,
+            num_splits=attn_metadata.decode_num_splits,
+            softmax_scale=self.scale,
+            causal=True,
+        )
+
+        return self._v_up_proj_and_o_proj(o)
diff --git a/vllm/v1/attention/backends/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py
similarity index 100%
rename from vllm/v1/attention/backends/triton_mla.py
rename to vllm/v1/attention/backends/mla/triton_mla.py
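
For reviewers who want to exercise the new V1 FlashMLA path end to end, here is a minimal sketch; it is not part of the diff. It assumes a GPU on which FlashMLA is supported, the usual VLLM_USE_V1 / VLLM_ATTENTION_BACKEND environment-variable overrides, and an MLA model (the model name below is illustrative only). Setting block_size=64 mirrors the block-size check in get_attn_backend_cls above, which otherwise falls back to Triton MLA.

    import os

    os.environ["VLLM_USE_V1"] = "1"                    # opt into the V1 engine
    os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHMLA"  # maps to _Backend.FLASHMLA

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="deepseek-ai/DeepSeek-V2-Lite",  # illustrative MLA model
        trust_remote_code=True,
        block_size=64,  # FlashMLA currently only supports block size 64
    )
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0, max_tokens=16))
    print(outputs[0].outputs[0].text)

With these settings the log should report "Using FlashMLA backend on V1 engine."; with any other block size it should instead report the Triton MLA fallback.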