|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | + |
| 3 | +from dataclasses import dataclass |
| 4 | +from typing import Any, Dict, List, Optional, Tuple, Type |
| 5 | + |
| 6 | +import torch |
| 7 | + |
| 8 | +from vllm.attention.backends.abstract import AttentionType |
| 9 | +from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, |
| 10 | + get_mla_metadata, |
| 11 | + is_flashmla_supported) |
| 12 | +from vllm.logger import init_logger |
| 13 | +from vllm.v1.attention.backends.mla.common import (MLACommonBackend, |
| 14 | + MLACommonImpl, |
| 15 | + MLACommonMetadata, |
| 16 | + MLACommonMetadataBuilder) |
| 17 | + |
| 18 | +logger = init_logger(__name__) |
| 19 | + |
| 20 | + |
| 21 | +class FlashMLABackend(MLACommonBackend): |
| 22 | + |
| 23 | + @staticmethod |
| 24 | + def get_name() -> str: |
| 25 | + return "FLASHMLA_VLLM_V1" |
| 26 | + |
| 27 | + @staticmethod |
| 28 | + def get_metadata_cls() -> Type["FlashMLAMetadata"]: |
| 29 | + return FlashMLAMetadata |
| 30 | + |
| 31 | + @staticmethod |
| 32 | + def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]: |
| 33 | + return FlashMLAMetadataBuilder |
| 34 | + |
| 35 | + @staticmethod |
| 36 | + def get_impl_cls() -> Type["FlashMLAImpl"]: |
| 37 | + return FlashMLAImpl |
| 38 | + |
| 39 | + |
| 40 | +@dataclass |
| 41 | +class FlashMLAMetadata(MLACommonMetadata): |
| 42 | + decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor, |
| 43 | + torch.Tensor]] = None |
| 44 | + decode_num_splits: Optional[torch.Tensor] = None |
| 45 | + |
| 46 | + |
| 47 | +class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): |
| 48 | + |
| 49 | + def __init__(self, runner): |
| 50 | + super().__init__(runner, cls=FlashMLAMetadata) |
| 51 | + |
| 52 | + self.num_q_heads = self.runner.model_config.get_num_attention_heads( |
| 53 | + self.runner.parallel_config) |
| 54 | + |
| 55 | + def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, |
| 56 | + common_prefix_len: int): |
| 57 | + m = super().build(num_reqs, num_actual_tokens, max_query_len, |
| 58 | + common_prefix_len) |
| 59 | + |
| 60 | + if m.num_decode_tokens is not None and m.num_decode_tokens > 0: |
| 61 | + m.decode_tile_scheduler_metadata, m.decode_num_splits = \ |
| 62 | + get_mla_metadata( |
| 63 | + m.seq_lens[:m.num_decode_tokens], |
| 64 | + self.num_q_heads, |
| 65 | + 1, # MQA for the decode path |
| 66 | + ) |
| 67 | + |
| 68 | + return m |
| 69 | + |
| 70 | + |
| 71 | +class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): |
| 72 | + |
| 73 | + def __init__( |
| 74 | + self, |
| 75 | + num_heads: int, |
| 76 | + head_size: int, |
| 77 | + scale: float, |
| 78 | + num_kv_heads: int, |
| 79 | + alibi_slopes: Optional[List[float]], |
| 80 | + sliding_window: Optional[int], |
| 81 | + kv_cache_dtype: str, |
| 82 | + blocksparse_params: Optional[Dict[str, Any]], |
| 83 | + logits_soft_cap: Optional[float], |
| 84 | + attn_type: str, |
| 85 | + # MLA Specific Arguments |
| 86 | + **mla_args) -> None: |
| 87 | + super().__init__(num_heads, head_size, scale, num_kv_heads, |
| 88 | + alibi_slopes, sliding_window, kv_cache_dtype, |
| 89 | + blocksparse_params, logits_soft_cap, attn_type, |
| 90 | + **mla_args) |
| 91 | + |
| 92 | + assert is_flashmla_supported(), \ |
| 93 | + "FlashMLA is not supported on this device" |
| 94 | + |
| 95 | + unsupported_features = [ |
| 96 | + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap |
| 97 | + ] |
| 98 | + if any(unsupported_features): |
| 99 | + raise NotImplementedError( |
| 100 | + "FlashMLAImpl does not support one of the following: " |
| 101 | + "alibi_slopes, sliding_window, blocksparse_params, " |
| 102 | + "logits_soft_cap") |
| 103 | + |
| 104 | + if attn_type != AttentionType.DECODER: |
| 105 | + raise NotImplementedError("Encoder self-attention and " |
| 106 | + "encoder/decoder cross-attention " |
| 107 | + "are not implemented for " |
| 108 | + "FlashMLAImpl") |
| 109 | + |
| 110 | + def _forward_decode( |
| 111 | + self, |
| 112 | + q_nope: torch.Tensor, |
| 113 | + q_pe: torch.Tensor, |
| 114 | + kv_c_and_k_pe_cache: torch.Tensor, |
| 115 | + attn_metadata: FlashMLAMetadata, |
| 116 | + ) -> torch.Tensor: |
| 117 | + assert kv_c_and_k_pe_cache.numel() > 0 |
| 118 | + if self.kv_cache_dtype.startswith("fp8"): |
| 119 | + raise NotImplementedError("FP8 FlashMLA not yet supported") |
| 120 | + |
| 121 | + q = torch.cat([q_nope, q_pe], dim=-1)\ |
| 122 | + .unsqueeze(1) # Add seqlen dim of 1 (decode) |
| 123 | + |
| 124 | + o, _ = flash_mla_with_kvcache( |
| 125 | + q=q, |
| 126 | + k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 |
| 127 | + block_table=attn_metadata.block_table[:attn_metadata.num_decodes, |
| 128 | + ...], |
| 129 | + cache_seqlens=attn_metadata.seq_lens[:attn_metadata. |
| 130 | + num_decode_tokens], |
| 131 | + head_dim_v=self.kv_lora_rank, |
| 132 | + tile_scheduler_metadata=attn_metadata. |
| 133 | + decode_tile_scheduler_metadata, |
| 134 | + num_splits=attn_metadata.decode_num_splits, |
| 135 | + softmax_scale=self.scale, |
| 136 | + causal=True, |
| 137 | + ) |
| 138 | + |
| 139 | + return self._v_up_proj_and_o_proj(o) |
0 commit comments