|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | + |
| 3 | +from dataclasses import dataclass |
| 4 | +from typing import Any, Optional |
| 5 | + |
| 6 | +import torch |
| 7 | + |
| 8 | +import vllm.envs as envs |
| 9 | +from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd |
| 10 | +# yapf conflicts with isort for this docstring |
| 11 | +# yapf: disable |
| 12 | +from vllm.v1.attention.backends.mla.common import (MLACommonBackend, |
| 13 | + MLACommonDecodeMetadata, |
| 14 | + MLACommonImpl, |
| 15 | + MLACommonMetadata, |
| 16 | + MLACommonMetadataBuilder) |
| 17 | + |
| 18 | +# yapf: enable |
| 19 | + |
| 20 | + |
| 21 | +def is_aiter_mla_enabled() -> bool: |
| 22 | + return envs.VLLM_ROCM_USE_AITER \ |
| 23 | + and envs.VLLM_ROCM_USE_AITER_MLA |
| 24 | + |
| 25 | + |
| 26 | +class AiterMLABackend(MLACommonBackend): |
| 27 | + |
| 28 | + @staticmethod |
| 29 | + def get_name() -> str: |
| 30 | + return "ROCM_AITER_MLA_VLLM_V1" |
| 31 | + |
| 32 | + @staticmethod |
| 33 | + def get_impl_cls() -> type["AiterMLAImpl"]: |
| 34 | + return AiterMLAImpl |
| 35 | + |
| 36 | + @staticmethod |
| 37 | + def get_metadata_cls() -> type["AiterMLAMetadata"]: |
| 38 | + return AiterMLAMetadata |
| 39 | + |
| 40 | + @staticmethod |
| 41 | + def get_builder_cls() -> type["AiterMLAMetadataBuilder"]: |
| 42 | + return AiterMLAMetadataBuilder |
| 43 | + |
| 44 | + |
| 45 | +@dataclass |
| 46 | +class AiterMLADecodeMetadata(MLACommonDecodeMetadata): |
| 47 | + # The indptr of the paged kv cache, shape: [batch_size + 1] |
| 48 | + paged_kv_indptr: Optional[torch.Tensor] = None |
| 49 | + # The page indices of the paged kv cache |
| 50 | + paged_kv_indices: Optional[torch.Tensor] = None |
| 51 | + # The number of entries in the last page of each request in |
| 52 | + # the paged kv cache, shape: [batch_size] |
| 53 | + paged_kv_last_page_len: Optional[torch.Tensor] = None |
| 54 | + |
| 55 | + |
| 56 | +class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): |
| 57 | + pass |
| 58 | + |
| 59 | + |
| 60 | +class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): |
| 61 | + |
| 62 | + def __init__(self, runner): |
| 63 | + super().__init__(runner) |
| 64 | + max_model_len = self.runner.model_config.max_model_len |
| 65 | + assert max_model_len == 32768,\ |
| 66 | + "AITER MLA requires max_model_len=32768" |
| 67 | + assert self.runner.block_size == 1, "AITER MLA" \ |
| 68 | + "only supports block size 1." |
| 69 | + |
| 70 | + def _get_paged_kv_tensors( |
| 71 | + self, block_table: torch.Tensor, |
| 72 | + seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]: |
| 73 | + page_size = self.runner.block_size |
| 74 | + block_table_bounds = (seq_lens + page_size - 1) // page_size |
| 75 | + |
| 76 | + mask = (torch.arange(block_table.size(1), |
| 77 | + dtype=block_table.dtype, |
| 78 | + device=block_table.device).unsqueeze(0) |
| 79 | + < block_table_bounds.unsqueeze(1)) |
| 80 | + paged_kv_indices = block_table[mask] |
| 81 | + |
| 82 | + paged_kv_indptr = torch.cat([ |
| 83 | + torch.zeros(1, |
| 84 | + dtype=block_table_bounds.dtype, |
| 85 | + device=block_table_bounds.device), |
| 86 | + block_table_bounds.cumsum(dim=0, dtype=torch.int32) |
| 87 | + ]) |
| 88 | + |
| 89 | + paged_kv_last_page_len = seq_lens % page_size |
| 90 | + paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, |
| 91 | + page_size, paged_kv_last_page_len) |
| 92 | + return ( |
| 93 | + paged_kv_indices, |
| 94 | + paged_kv_indptr, |
| 95 | + paged_kv_last_page_len, |
| 96 | + ) |
| 97 | + |
| 98 | + def _build_decode(self, input_positions: torch.Tensor, |
| 99 | + block_table: torch.Tensor, |
| 100 | + seq_lens: torch.Tensor) -> AiterMLADecodeMetadata: |
| 101 | + |
| 102 | + ( |
| 103 | + paged_kv_indices, |
| 104 | + paged_kv_indptr, |
| 105 | + paged_last_page_len, |
| 106 | + ) = self._get_paged_kv_tensors(block_table, seq_lens) |
| 107 | + |
| 108 | + attn_metadata = AiterMLADecodeMetadata( |
| 109 | + input_positions=input_positions, |
| 110 | + block_table=block_table, |
| 111 | + seq_lens=seq_lens, |
| 112 | + paged_kv_indptr=paged_kv_indptr, |
| 113 | + paged_kv_indices=paged_kv_indices, |
| 114 | + paged_kv_last_page_len=paged_last_page_len) |
| 115 | + |
| 116 | + return attn_metadata |
| 117 | + |
| 118 | + |
| 119 | +class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): |
| 120 | + |
| 121 | + def __init__( |
| 122 | + self, |
| 123 | + num_heads: int, |
| 124 | + head_size: int, |
| 125 | + scale: float, |
| 126 | + num_kv_heads: int, |
| 127 | + alibi_slopes: Optional[list[float]], |
| 128 | + sliding_window: Optional[int], |
| 129 | + kv_cache_dtype: str, |
| 130 | + blocksparse_params: Optional[dict[str, Any]], |
| 131 | + logits_soft_cap: Optional[float], |
| 132 | + attn_type: str, |
| 133 | + # MLA Specific Arguments |
| 134 | + **mla_args) -> None: |
| 135 | + super().__init__(num_heads, head_size, scale, num_kv_heads, |
| 136 | + alibi_slopes, sliding_window, kv_cache_dtype, |
| 137 | + blocksparse_params, logits_soft_cap, attn_type, |
| 138 | + **mla_args) |
| 139 | + |
| 140 | + unsupported_features = [ |
| 141 | + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap |
| 142 | + ] |
| 143 | + if any(unsupported_features): |
| 144 | + raise NotImplementedError( |
| 145 | + "Aiter MLA does not support one of the following: " |
| 146 | + "alibi_slopes, sliding_window, blocksparse_params, " |
| 147 | + "logits_soft_cap") |
| 148 | + |
| 149 | + from aiter import flash_attn_varlen_func |
| 150 | + self.flash_attn_varlen_func = flash_attn_varlen_func |
| 151 | + |
| 152 | + def _flash_attn_varlen_diff_headdims(self, |
| 153 | + q, |
| 154 | + k, |
| 155 | + v, |
| 156 | + return_softmax_lse=False, |
| 157 | + softmax_scale=None, |
| 158 | + **kwargs): |
| 159 | + output = self.flash_attn_varlen_func( |
| 160 | + q=q, |
| 161 | + k=k, |
| 162 | + v=v, |
| 163 | + softmax_scale=softmax_scale, |
| 164 | + return_lse=return_softmax_lse, |
| 165 | + **kwargs, |
| 166 | + ) |
| 167 | + |
| 168 | + return output |
| 169 | + |
| 170 | + def _forward_decode( |
| 171 | + self, |
| 172 | + q_nope: torch.Tensor, |
| 173 | + q_pe: torch.Tensor, |
| 174 | + kv_c_and_k_pe_cache: torch.Tensor, |
| 175 | + attn_metadata: AiterMLAMetadata, |
| 176 | + ) -> torch.Tensor: |
| 177 | + assert kv_c_and_k_pe_cache.numel() > 0 |
| 178 | + assert attn_metadata.decode is not None |
| 179 | + |
| 180 | + B = q_nope.shape[0] |
| 181 | + |
| 182 | + q = torch.cat([q_nope, q_pe], dim=-1) |
| 183 | + o = torch.zeros(B, |
| 184 | + self.num_heads, |
| 185 | + self.kv_lora_rank, |
| 186 | + dtype=q.dtype, |
| 187 | + device=q.device) |
| 188 | + |
| 189 | + kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) |
| 190 | + |
| 191 | + aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, |
| 192 | + attn_metadata.decode.paged_kv_indptr, |
| 193 | + attn_metadata.decode.paged_kv_indices, |
| 194 | + attn_metadata.decode.paged_kv_last_page_len) |
| 195 | + |
| 196 | + return self._v_up_proj(o) |
0 commit comments