Merged

42 commits
998803e
[Attention] MLA support for V1
chenyang78 Feb 24, 2025
12a5221
torch library bindings, unit tests running
LucasWilkinson Feb 24, 2025
3807602
comments
LucasWilkinson Feb 24, 2025
955cead
working in eager mode
LucasWilkinson Feb 24, 2025
1d5c868
format
LucasWilkinson Feb 24, 2025
eae4787
cuda-graphs still broken but closer i think
LucasWilkinson Feb 24, 2025
c79927d
better comments
LucasWilkinson Feb 24, 2025
37c4f9e
remove extra files
LucasWilkinson Feb 24, 2025
084b031
add attribution
LucasWilkinson Feb 24, 2025
07a9bad
fix cuda graphs
LucasWilkinson Feb 24, 2025
4dc8c35
cleaner build fallbacks
LucasWilkinson Feb 24, 2025
a6c36cc
ok cuda-graphs actually fixed now I think
LucasWilkinson Feb 24, 2025
3ae4a6e
format
LucasWilkinson Feb 25, 2025
a6213a4
fix deepseek-v2
LucasWilkinson Feb 25, 2025
68895a2
Merge branch 'lwilkinson/fix-deepseek-v2' into lwilkinson/flashmla-in…
LucasWilkinson Feb 25, 2025
5e7cd97
clean up
LucasWilkinson Feb 25, 2025
8bb3bdc
Merge remote-tracking branch 'origin/main' into lwilkinson/flashmla-i…
LucasWilkinson Feb 25, 2025
d439969
review comment
LucasWilkinson Feb 25, 2025
aa42226
fix mypy
LucasWilkinson Feb 25, 2025
d18261c
review comments
LucasWilkinson Feb 25, 2025
4c08a0a
cleanup
LucasWilkinson Feb 25, 2025
07332bf
fix bad logic
LucasWilkinson Feb 25, 2025
c4434d9
review comments
LucasWilkinson Feb 25, 2025
f570fe0
update to latest flashMLA which supports fp16
LucasWilkinson Feb 25, 2025
0bbcf27
update to use fork
LucasWilkinson Feb 25, 2025
177ee29
remove unnessary include
LucasWilkinson Feb 25, 2025
642456f
add fp16 source
LucasWilkinson Feb 25, 2025
2fa62a9
missing symbol
LucasWilkinson Feb 25, 2025
4b7ef4d
[Attention] MLA support for V1
chenyang78 Feb 24, 2025
0ae026a
Merge remote-tracking branch 'yang/mla-v1' into lwilkinson/flash-mla-v1
LucasWilkinson Feb 25, 2025
23c780f
address review feedback
chenyang78 Feb 25, 2025
29c06c7
restore to use attn_module.head_size
chenyang78 Feb 26, 2025
5f8526b
wip v1 FlashMLA
LucasWilkinson Feb 26, 2025
f955164
Merge remote-tracking branch 'yang/mla-v1' into lwilkinson/flash-mla-v1
LucasWilkinson Feb 26, 2025
04c8db4
[Attention] MLA support for V1
chenyang78 Feb 24, 2025
a456e05
address review feedback
chenyang78 Feb 25, 2025
867d2ed
restore to use attn_module.head_size
chenyang78 Feb 26, 2025
8715cfb
included more fixes from Lucas
chenyang78 Feb 26, 2025
6bf7bfb
addressed feedback from Woosuk Kwon
chenyang78 Feb 27, 2025
dab8ad6
Merge remote-tracking branch 'yang/mla-v1' into lwilkinson/flash-mla-v1
LucasWilkinson Feb 27, 2025
67b2b62
Merge remote-tracking branch 'origin/main' into lwilkinson/flash-mla-v1
LucasWilkinson Feb 27, 2025
e6e5789
cleanup
LucasWilkinson Feb 27, 2025
35 changes: 22 additions & 13 deletions vllm/platforms/cuda.py
@@ -161,15 +161,9 @@ def get_current_memory_usage(cls,
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1,
                              use_mla) -> str:
-        if use_v1:
-            if use_mla:
-                logger.info("Using Triton MLA backend on V1 engine.")
-                return "vllm.v1.attention.backends.triton_mla.TritonMLABackend"
-            else:
-                logger.info("Using Flash Attention backend on V1 engine.")
-                return ("vllm.v1.attention.backends.flash_attn."
-                        "FlashAttentionBackend")
         if use_mla:
+            # TODO(lucas): refactor to be more concise
+            # we should probably consider factoring out V1 here
             if selected_backend == _Backend.FLASHMLA:
                 from vllm.attention.backends.flashmla import (
                     is_flashmla_supported)
@@ -183,11 +177,26 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                         " (currently only supports block size 64).",
                         block_size)
                 else:
-                    logger.info("Using FlashMLA backend.")
-                    return "vllm.attention.backends.flashmla.FlashMLABackend"
-
-            logger.info("Using Triton MLA backend.")
-            return "vllm.attention.backends.triton_mla.TritonMLABackend"
+                    if use_v1:
+                        logger.info("Using FlashMLA backend on V1 engine.")
+                        return ("vllm.v1.attention.backends.mla."
+                                "flashmla.FlashMLABackend")
+                    else:
+                        logger.info("Using FlashMLA backend.")
+                        return ("vllm.attention.backends."
+                                "flashmla.FlashMLABackend")
+
+            if use_v1:
+                logger.info("Using Triton MLA backend on V1 engine.")
+                return ("vllm.v1.attention.backends.mla."
+                        "triton_mla.TritonMLABackend")
+            else:
+                logger.info("Using Triton MLA backend.")
+                return "vllm.attention.backends.triton_mla.TritonMLABackend"
+        if use_v1:
+            logger.info("Using Flash Attention backend on V1 engine.")
+            return ("vllm.v1.attention.backends.flash_attn."
+                    "FlashAttentionBackend")
         if selected_backend == _Backend.FLASHINFER:
             logger.info("Using FlashInfer backend.")
             return "vllm.attention.backends.flashinfer.FlashInferBackend"
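For readers skimming the diff, the new dispatch order in get_attn_backend_cls can be summarized with a small standalone sketch (illustrative only: the function name, argument list, and return strings below are simplified stand-ins, not the actual vLLM code). MLA models try FlashMLA first when it is explicitly selected, the device supports it, and the KV-cache block size is 64; otherwise they fall back to Triton MLA. Non-MLA models keep the existing FlashAttention/FlashInfer paths, and every branch additionally splits on use_v1.

# Illustrative sketch of the dispatch precedence introduced above; the real
# logic lives in vllm/platforms/cuda.py::get_attn_backend_cls and returns
# fully qualified backend class paths rather than these shortened labels.
def pick_backend(use_mla: bool, use_v1: bool, flashmla_selected: bool,
                 flashmla_supported: bool, block_size: int) -> str:
    if use_mla:
        if flashmla_selected and flashmla_supported and block_size == 64:
            # FlashMLA currently only supports a KV-cache block size of 64.
            return "V1 FlashMLA" if use_v1 else "V0 FlashMLA"
        # FlashMLA unavailable or unsupported config: fall back to Triton MLA.
        return "V1 TritonMLA" if use_v1 else "V0 TritonMLA"
    if use_v1:
        return "V1 FlashAttention"
    # V0 non-MLA selection (FlashInfer, FlashAttention, ...) is unchanged.
    return "V0 FlashAttention / FlashInfer"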
5 changes: 2 additions & 3 deletions vllm/platforms/interface.py
@@ -34,9 +34,8 @@ class _Backend(enum.Enum):
     TORCH_SDPA = enum.auto()
     OPENVINO = enum.auto()
     FLASHINFER = enum.auto()
-    TRITON_MLA = enum.auto()
-    TRITON_MLA_VLLM_V1 = enum.auto()
-    FLASHMLA = enum.auto()
+    TRITON_MLA = enum.auto()  # Supported by V1
+    FLASHMLA = enum.auto()  # Supported by V1
     HPU_ATTN = enum.auto()
     PALLAS = enum.auto()
     PALLAS_VLLM_V1 = enum.auto()
11 changes: 7 additions & 4 deletions vllm/v1/attention/backends/mla/common.py
@@ -333,13 +333,16 @@ def __post_init__(self):
 T = TypeVar("T", bound=MLACommonMetadata)
 
 
-class MLACommonMetadataBuilder:
+class MLACommonMetadataBuilder(Generic[T]):
     """
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
     """
 
-    def __init__(self, runner: "GPUModelRunner"):
+    def __init__(self,
+                 runner: "GPUModelRunner",
+                 cls: Optional[type[T]] = None):
+        self.cls = cls if cls is not None else MLACommonMetadata
         self.runner = runner
         scheduler_config = runner.scheduler_config
         model_config = runner.model_config
@@ -431,7 +434,7 @@ def reorder_batch(self, input_batch: "InputBatch",
         self._num_prefill_tokens = num_prefill_tokens
 
     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int):
+              common_prefix_len: int) -> T:
         device = self.runner.device
         max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
         query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
@@ -502,7 +505,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
             assert max(context_chunk_seq_tot) <= \
                 self.chunked_prefill_workspace_size
 
-        return MLACommonMetadata(
+        return self.cls(
             input_positions=input_positions,
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
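The Generic[T] change above is what lets each backend-specific builder reuse the common build path while getting back its own metadata dataclass (FlashMLAMetadataBuilder below is the first user). A minimal sketch of the pattern, with hypothetical class names standing in for the vLLM ones:

from dataclasses import dataclass
from typing import Generic, Optional, TypeVar


@dataclass
class BaseMetadata:
    num_actual_tokens: int


T = TypeVar("T", bound=BaseMetadata)


class BaseBuilder(Generic[T]):

    def __init__(self, cls: Optional[type[T]] = None):
        # Subclasses pass their own metadata dataclass; the base type is the default.
        self.cls = cls if cls is not None else BaseMetadata

    def build(self, num_actual_tokens: int) -> T:
        # Shared construction logic lives here; only the concrete
        # metadata type differs between backends.
        return self.cls(num_actual_tokens=num_actual_tokens)


@dataclass
class FancyMetadata(BaseMetadata):
    num_splits: Optional[int] = None


class FancyBuilder(BaseBuilder[FancyMetadata]):

    def __init__(self):
        super().__init__(cls=FancyMetadata)

    def build(self, num_actual_tokens: int) -> FancyMetadata:
        m = super().build(num_actual_tokens)
        m.num_splits = 1  # backend-specific field filled in after the base build
        return m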
139 changes: 139 additions & 0 deletions vllm/v1/attention/backends/mla/flashmla.py
@@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch

from vllm.attention.backends.abstract import AttentionType
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                         get_mla_metadata,
                                         is_flashmla_supported)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonImpl,
                                                    MLACommonMetadata,
                                                    MLACommonMetadataBuilder)

logger = init_logger(__name__)


class FlashMLABackend(MLACommonBackend):

    @staticmethod
    def get_name() -> str:
        return "FLASHMLA_VLLM_V1"

    @staticmethod
    def get_metadata_cls() -> Type["FlashMLAMetadata"]:
        return FlashMLAMetadata

    @staticmethod
    def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]:
        return FlashMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> Type["FlashMLAImpl"]:
        return FlashMLAImpl


@dataclass
class FlashMLAMetadata(MLACommonMetadata):
    decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor,
                                                   torch.Tensor]] = None
    decode_num_splits: Optional[torch.Tensor] = None


class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):

    def __init__(self, runner):
        super().__init__(runner, cls=FlashMLAMetadata)

        self.num_q_heads = self.runner.model_config.get_num_attention_heads(
            self.runner.parallel_config)

    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
              common_prefix_len: int):
        m = super().build(num_reqs, num_actual_tokens, max_query_len,
                          common_prefix_len)

        if m.num_decode_tokens is not None and m.num_decode_tokens > 0:
            m.decode_tile_scheduler_metadata, m.decode_num_splits = \
                get_mla_metadata(
                    m.seq_lens[:m.num_decode_tokens],
                    self.num_q_heads,
                    1,  # MQA for the decode path
                )

        return m


class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[Dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         **mla_args)

        assert is_flashmla_supported(), \
            "FlashMLA is not supported on this device"

        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashMLAImpl")

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLAMetadata,
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0
        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError("FP8 FlashMLA not yet supported")

        q = torch.cat([q_nope, q_pe], dim=-1)\
            .unsqueeze(1)  # Add seqlen dim of 1 (decode)

        o, _ = flash_mla_with_kvcache(
            q=q,
            k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            block_table=attn_metadata.block_table[:attn_metadata.num_decodes,
                                                  ...],
            cache_seqlens=attn_metadata.seq_lens[:attn_metadata.
                                                 num_decode_tokens],
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=attn_metadata.
            decode_tile_scheduler_metadata,
            num_splits=attn_metadata.decode_num_splits,
            softmax_scale=self.scale,
            causal=True,
        )

        return self._v_up_proj_and_o_proj(o)
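As a usage note: this PR does not spell out the invocation, but based on vLLM's existing configuration knobs one would expect to opt into the new backend roughly as follows. The environment variables and model name below are assumptions for illustration; FlashMLA additionally requires an MLA model, a GPU that passes is_flashmla_supported(), and a KV-cache block size of 64.

# Hypothetical usage sketch (not taken from this PR): select the V1 FlashMLA
# backend via environment variables before constructing the engine.
import os

os.environ["VLLM_USE_V1"] = "1"                    # assumed knob: enable the V1 engine
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHMLA"  # assumed knob: maps to _Backend.FLASHMLA

from vllm import LLM

# DeepSeek-style models use MLA; block_size=64 matches the requirement
# checked in get_attn_backend_cls.
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite",
          trust_remote_code=True,
          block_size=64)
print(llm.generate("Hello, my name is")[0].outputs[0].text)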