
Commit 2e94b9c

[Attention] Flash MLA for V1 (#13867)
Signed-off-by: Yang Chen <yangche@fb.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Yang Chen <yangche@fb.com>
1 parent 8294773 commit 2e94b9c

File tree

5 files changed: +170 -20 lines changed

vllm/platforms/cuda.py

Lines changed: 22 additions & 13 deletions
@@ -161,15 +161,9 @@ def get_current_memory_usage(cls,
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1,
                              use_mla) -> str:
-        if use_v1:
-            if use_mla:
-                logger.info("Using Triton MLA backend on V1 engine.")
-                return "vllm.v1.attention.backends.triton_mla.TritonMLABackend"
-            else:
-                logger.info("Using Flash Attention backend on V1 engine.")
-                return ("vllm.v1.attention.backends.flash_attn."
-                        "FlashAttentionBackend")
         if use_mla:
+            # TODO(lucas): refactor to be more concise
+            # we should probably consider factoring out V1 here
             if selected_backend == _Backend.FLASHMLA:
                 from vllm.attention.backends.flashmla import (
                     is_flashmla_supported)
@@ -183,11 +177,26 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                         " (currently only supports block size 64).",
                         block_size)
                 else:
-                    logger.info("Using FlashMLA backend.")
-                    return "vllm.attention.backends.flashmla.FlashMLABackend"
-
-            logger.info("Using Triton MLA backend.")
-            return "vllm.attention.backends.triton_mla.TritonMLABackend"
+                    if use_v1:
+                        logger.info("Using FlashMLA backend on V1 engine.")
+                        return ("vllm.v1.attention.backends.mla."
+                                "flashmla.FlashMLABackend")
+                    else:
+                        logger.info("Using FlashMLA backend.")
+                        return ("vllm.attention.backends."
+                                "flashmla.FlashMLABackend")
+
+            if use_v1:
+                logger.info("Using Triton MLA backend on V1 engine.")
+                return ("vllm.v1.attention.backends.mla."
+                        "triton_mla.TritonMLABackend")
+            else:
+                logger.info("Using Triton MLA backend.")
+                return "vllm.attention.backends.triton_mla.TritonMLABackend"
+        if use_v1:
+            logger.info("Using Flash Attention backend on V1 engine.")
+            return ("vllm.v1.attention.backends.flash_attn."
+                    "FlashAttentionBackend")
         if selected_backend == _Backend.FLASHINFER:
             logger.info("Using FlashInfer backend.")
             return "vllm.attention.backends.flashinfer.FlashInferBackend"

vllm/platforms/interface.py

Lines changed: 2 additions & 3 deletions
@@ -34,9 +34,8 @@ class _Backend(enum.Enum):
     TORCH_SDPA = enum.auto()
     OPENVINO = enum.auto()
     FLASHINFER = enum.auto()
-    TRITON_MLA = enum.auto()
-    TRITON_MLA_VLLM_V1 = enum.auto()
-    FLASHMLA = enum.auto()
+    TRITON_MLA = enum.auto()  # Supported by V1
+    FLASHMLA = enum.auto()  # Supported by V1
     HPU_ATTN = enum.auto()
     PALLAS = enum.auto()
     PALLAS_VLLM_V1 = enum.auto()
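The dropped TRITON_MLA_VLLM_V1 member reflects the same consolidation as in cuda.py: the V0/V1 split now lives in get_attn_backend_cls rather than in separate enum members. A rough sketch of what that means for a user override, assuming the backend-override environment variable is still resolved by _Backend name (as in vLLM's selector logic); this snippet is not part of the commit.

# Rough sketch, assuming VLLM_ATTENTION_BACKEND resolves by _Backend name.
import os

from vllm.platforms.interface import _Backend

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHMLA"
selected = _Backend[os.environ["VLLM_ATTENTION_BACKEND"]]
assert selected is _Backend.FLASHMLA  # same member now serves V0 and V1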

vllm/v1/attention/backends/mla/common.py

Lines changed: 7 additions & 4 deletions
@@ -333,13 +333,16 @@ def __post_init__(self):
 T = TypeVar("T", bound=MLACommonMetadata)


-class MLACommonMetadataBuilder:
+class MLACommonMetadataBuilder(Generic[T]):
     """
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
     """

-    def __init__(self, runner: "GPUModelRunner"):
+    def __init__(self,
+                 runner: "GPUModelRunner",
+                 cls: Optional[type[T]] = None):
+        self.cls = cls if cls is not None else MLACommonMetadata
         self.runner = runner
         scheduler_config = runner.scheduler_config
         model_config = runner.model_config
@@ -431,7 +434,7 @@ def reorder_batch(self, input_batch: "InputBatch",
         self._num_prefill_tokens = num_prefill_tokens

     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int):
+              common_prefix_len: int) -> T:
         device = self.runner.device
         max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
         query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
@@ -502,7 +505,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
         assert max(context_chunk_seq_tot) <= \
             self.chunked_prefill_workspace_size

-        return MLACommonMetadata(
+        return self.cls(
             input_positions=input_positions,
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
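Making the builder generic and having it construct self.cls is what lets the new FlashMLA backend (next file) reuse the common build() while getting back its own metadata dataclass. A minimal sketch of the intended subclassing pattern follows; the My* names are hypothetical and not part of vLLM.

# Minimal sketch of the subclassing pattern enabled by this change; the
# My* names are hypothetical. build() now returns the subclass's dataclass.
from dataclasses import dataclass
from typing import Optional

import torch

from vllm.v1.attention.backends.mla.common import (MLACommonMetadata,
                                                    MLACommonMetadataBuilder)


@dataclass
class MyMLAMetadata(MLACommonMetadata):
    # Backend-specific field filled in after the common build step
    my_decode_schedule: Optional[torch.Tensor] = None


class MyMLAMetadataBuilder(MLACommonMetadataBuilder[MyMLAMetadata]):

    def __init__(self, runner):
        # Pass the metadata class so the common build() instantiates it
        super().__init__(runner, cls=MyMLAMetadata)

    def build(self, num_reqs, num_actual_tokens, max_query_len,
              common_prefix_len):
        m = super().build(num_reqs, num_actual_tokens, max_query_len,
                          common_prefix_len)
        # ...populate m.my_decode_schedule for the decode path...
        return m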
vllm/v1/attention/backends/mla/flashmla.py (new file)

Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+import torch
+
+from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
+                                         get_mla_metadata,
+                                         is_flashmla_supported)
+from vllm.logger import init_logger
+from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
+                                                   MLACommonImpl,
+                                                   MLACommonMetadata,
+                                                   MLACommonMetadataBuilder)
+
+logger = init_logger(__name__)
+
+
+class FlashMLABackend(MLACommonBackend):
+
+    @staticmethod
+    def get_name() -> str:
+        return "FLASHMLA_VLLM_V1"
+
+    @staticmethod
+    def get_metadata_cls() -> Type["FlashMLAMetadata"]:
+        return FlashMLAMetadata
+
+    @staticmethod
+    def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]:
+        return FlashMLAMetadataBuilder
+
+    @staticmethod
+    def get_impl_cls() -> Type["FlashMLAImpl"]:
+        return FlashMLAImpl
+
+
+@dataclass
+class FlashMLAMetadata(MLACommonMetadata):
+    decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor,
+                                                   torch.Tensor]] = None
+    decode_num_splits: Optional[torch.Tensor] = None
+
+
+class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
+
+    def __init__(self, runner):
+        super().__init__(runner, cls=FlashMLAMetadata)
+
+        self.num_q_heads = self.runner.model_config.get_num_attention_heads(
+            self.runner.parallel_config)
+
+    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
+              common_prefix_len: int):
+        m = super().build(num_reqs, num_actual_tokens, max_query_len,
+                          common_prefix_len)
+
+        if m.num_decode_tokens is not None and m.num_decode_tokens > 0:
+            m.decode_tile_scheduler_metadata, m.decode_num_splits = \
+                get_mla_metadata(
+                    m.seq_lens[:m.num_decode_tokens],
+                    self.num_q_heads,
+                    1,  # MQA for the decode path
+                )
+
+        return m
+
+
+class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
+
+    def __init__(
+            self,
+            num_heads: int,
+            head_size: int,
+            scale: float,
+            num_kv_heads: int,
+            alibi_slopes: Optional[List[float]],
+            sliding_window: Optional[int],
+            kv_cache_dtype: str,
+            blocksparse_params: Optional[Dict[str, Any]],
+            logits_soft_cap: Optional[float],
+            attn_type: str,
+            # MLA Specific Arguments
+            **mla_args) -> None:
+        super().__init__(num_heads, head_size, scale, num_kv_heads,
+                         alibi_slopes, sliding_window, kv_cache_dtype,
+                         blocksparse_params, logits_soft_cap, attn_type,
+                         **mla_args)
+
+        assert is_flashmla_supported(), \
+            "FlashMLA is not supported on this device"
+
+        unsupported_features = [
+            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
+        ]
+        if any(unsupported_features):
+            raise NotImplementedError(
+                "FlashMLAImpl does not support one of the following: "
+                "alibi_slopes, sliding_window, blocksparse_params, "
+                "logits_soft_cap")
+
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "FlashMLAImpl")
+
+    def _forward_decode(
+        self,
+        q_nope: torch.Tensor,
+        q_pe: torch.Tensor,
+        kv_c_and_k_pe_cache: torch.Tensor,
+        attn_metadata: FlashMLAMetadata,
+    ) -> torch.Tensor:
+        assert kv_c_and_k_pe_cache.numel() > 0
+        if self.kv_cache_dtype.startswith("fp8"):
+            raise NotImplementedError("FP8 FlashMLA not yet supported")
+
+        q = torch.cat([q_nope, q_pe], dim=-1)\
+            .unsqueeze(1)  # Add seqlen dim of 1 (decode)
+
+        o, _ = flash_mla_with_kvcache(
+            q=q,
+            k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
+            block_table=attn_metadata.block_table[:attn_metadata.num_decodes,
+                                                  ...],
+            cache_seqlens=attn_metadata.seq_lens[:attn_metadata.
+                                                 num_decode_tokens],
+            head_dim_v=self.kv_lora_rank,
+            tile_scheduler_metadata=attn_metadata.
+            decode_tile_scheduler_metadata,
+            num_splits=attn_metadata.decode_num_splits,
+            softmax_scale=self.scale,
+            causal=True,
+        )
+
+        return self._v_up_proj_and_o_proj(o)
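For readers unfamiliar with the FlashMLA kernel, here is a hedged, standalone sketch of the decode-path call pattern used in _forward_decode above, with made-up shapes. It assumes a Hopper-class GPU, that vllm.attention.ops.flashmla exposes the same argument names used in this file, and illustrative DeepSeek-style dimensions (512 latent + 64 rope dims, 128 query heads). It is not part of the commit.

# Hedged sketch with illustrative shapes; not part of the commit.
import torch

from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                         get_mla_metadata)

num_decodes = 4                  # one query token per decoding request (MQA path)
num_q_heads = 128                # illustrative DeepSeek-style head count
kv_lora_rank, rope_dim = 512, 64
block_size, num_blocks = 64, 8   # FlashMLA currently requires block size 64

device, dtype = "cuda", torch.bfloat16
q = torch.randn(num_decodes, 1, num_q_heads, kv_lora_rank + rope_dim,
                device=device, dtype=dtype)
# Cache layout: (num_blocks, block_size, 1 kv head, latent + rope dims)
k_cache = torch.randn(num_blocks, block_size, 1, kv_lora_rank + rope_dim,
                      device=device, dtype=dtype)
block_table = torch.arange(num_blocks, dtype=torch.int32,
                           device=device).repeat(num_decodes, 1)
cache_seqlens = torch.full((num_decodes,), 17, dtype=torch.int32,
                           device=device)

# Same call pattern as the builder/impl above: compute the tile schedule once
# per batch, then run the paged MLA decode kernel.
tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens, num_q_heads, 1)
o, _ = flash_mla_with_kvcache(
    q=q,
    k_cache=k_cache,
    block_table=block_table,
    cache_seqlens=cache_seqlens,
    head_dim_v=kv_lora_rank,
    tile_scheduler_metadata=tile_scheduler_metadata,
    num_splits=num_splits,
    softmax_scale=(kv_lora_rank + rope_dim) ** -0.5,  # illustrative scale
    causal=True,
)
print(o.shape)  # (num_decodes, 1, num_q_heads, kv_lora_rank)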
