57 changes: 36 additions & 21 deletions vllm/attention/backends/mla/utils.py
@@ -409,8 +409,9 @@
def _forward_prefill(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
kv_c: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: T,
) -> torch.Tensor:
raise NotImplementedError
@@ -442,30 +443,29 @@
is_decode = attn_metadata.decode_metadata is not None
is_prefill = attn_metadata.prefill_metadata is not None

if (is_decode and is_prefill):
raise NotImplementedError(
"chunked prefill is not supported for MLAImplBase")

# Restore head dim (for rotary embedding)
k_pe = k_pe.unsqueeze(1)
assert hasattr(attn_metadata, "input_positions")

num_prefill_tokens: int = attn_metadata.num_prefill_tokens

if is_decode:
q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
decode_q_nope = self._q_proj_and_k_up_proj(
hidden_states_or_q_c[num_prefill_tokens:])
decode_q_pe = torch.matmul(hidden_states_or_q_c[num_prefill_tokens:], self.W_QR)\

Check failure on line 455 in vllm/attention/backends/mla/utils.py (GitHub Actions / pre-commit): Ruff E501, line too long (93 > 80)
.view(-1, self.num_heads, self.qk_rope_head_dim)
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
k_pe)
else:
assert is_prefill
q = self.q_proj(hidden_states_or_q_c)[0]\
decode_q_pe, k_pe[num_prefill_tokens:] = \
self.rotary_emb(attn_metadata.input_positions[num_prefill_tokens:],
decode_q_pe, k_pe[num_prefill_tokens:])
if is_prefill:
prefill_q = self.q_proj(hidden_states_or_q_c[:num_prefill_tokens])[0]\

Check failure on line 461 in vllm/attention/backends/mla/utils.py (GitHub Actions / pre-commit): Ruff E501, line too long (82 > 80)
.view(-1, self.num_heads, self.qk_head_dim)

# TODO(lucas): there must be a nicer way to write this line
q[..., self.qk_nope_head_dim:], k_pe = \
prefill_q[..., self.qk_nope_head_dim:], k_pe[:num_prefill_tokens] = \

Check failure on line 465 in vllm/attention/backends/mla/utils.py (GitHub Actions / pre-commit): Ruff E501, line too long (81 > 80)
self.rotary_emb(
attn_metadata.input_positions,
q[..., self.qk_nope_head_dim:], k_pe)
attn_metadata.input_positions[:num_prefill_tokens],
prefill_q[..., self.qk_nope_head_dim:], k_pe[:num_prefill_tokens])

Check failure on line 468 in vllm/attention/backends/mla/utils.py (GitHub Actions / pre-commit): Ruff E501, line too long (86 > 80)

# write the latent and rope to kv cache
if kv_cache.numel() > 0:
@@ -477,12 +477,25 @@
kv_cache_dtype=self.kv_cache_dtype,
scale=layer._k_scale,
)
output = torch.empty(attn_metadata.num_prefill_tokens +
attn_metadata.num_decode_tokens,
self.o_proj.output_size,
device=hidden_states_or_q_c.device,
dtype=hidden_states_or_q_c.dtype)
# output buffer shape: [num_prefill_tokens + num_decode_tokens, o_proj.output_size]

if is_prefill:
# forward prefill output shape: [2048, 7168]
output[:num_prefill_tokens] = self._forward_prefill(
prefill_q, k_c_normed[:num_prefill_tokens].contiguous(),
k_pe[:num_prefill_tokens].contiguous(), kv_cache,
attn_metadata)

if attn_metadata.prefill_metadata is not None:
return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata)
if is_decode:
output[num_prefill_tokens:] = self._forward_decode(
decode_q_nope, decode_q_pe, kv_cache, attn_metadata)

if attn_metadata.decode_metadata is not None:
return self._forward_decode(q_nope, q_pe, kv_cache, attn_metadata)
return output

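For readers tracing the new control flow above, here is a minimal standalone sketch (hypothetical shapes and stand-in ops, not the vLLM API) of the batch layout the combined forward relies on: prefill tokens occupy the front of the flattened batch, decode tokens the tail, and each phase processes only its own slice before writing back into one shared output buffer.

import torch

# Hypothetical mixed batch: 6 prefill tokens followed by 2 decode tokens.
num_prefill_tokens, num_decode_tokens, hidden_size = 6, 2, 16
hidden_states = torch.randn(num_prefill_tokens + num_decode_tokens, hidden_size)

output = torch.empty_like(hidden_states)

# Stand-ins for _forward_prefill / _forward_decode: each phase sees only its
# slice of the batch and writes back into its slice of the output buffer.
output[:num_prefill_tokens] = hidden_states[:num_prefill_tokens] * 2.0
output[num_prefill_tokens:] = hidden_states[num_prefill_tokens:] * 3.0

assert output.shape == hidden_states.shape
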
# Optional common flash-attn based prefill
def _forward_prefill_flash(
@@ -492,6 +505,8 @@
k_pe: torch.Tensor,
seq_start_loc: torch.Tensor,
max_prefill_seq_len: int,
query_start_loc: torch.Tensor,
max_query_len: int,
) -> torch.Tensor:

kv_nope = self.kv_b_proj(k_c_normed)[0]\
@@ -510,9 +525,9 @@
q=q,
k=k,
v=v_padded,
cu_seqlens_q=seq_start_loc,
cu_seqlens_q=query_start_loc,
cu_seqlens_k=seq_start_loc,
max_seqlen_q=max_prefill_seq_len,
max_seqlen_q=max_query_len,
max_seqlen_k=max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
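A note on the cu_seqlens change in the hunk above: with chunked prefill, a sequence's query (the tokens scheduled this step) can be shorter than its full key/value context, so the flash-attention call needs separate query-side and key-side offsets. A small illustrative example of the metadata this implies (hypothetical numbers, not taken from the PR):

# Two prefill sequences under chunked prefill (hypothetical):
#   seq 0: 8 tokens already cached + 4 new tokens -> context 12, query 4
#   seq 1: nothing cached + 5 new tokens          -> context 5,  query 5
seq_lens = [12, 5]            # K/V side: full context length per sequence
query_lens = [4, 5]           # Q side: tokens scheduled this step
seq_start_loc = [0, 12, 17]   # cumulative context lengths -> cu_seqlens_k
query_start_loc = [0, 4, 9]   # cumulative query lengths   -> cu_seqlens_q
max_prefill_seq_len = 12      # max_seqlen_k
max_query_len = 5             # max_seqlen_q
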
123 changes: 119 additions & 4 deletions vllm/attention/backends/triton_mla.py
@@ -6,6 +6,9 @@
from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import triton
import triton.language as tl

from vllm.multimodal import MultiModalPlaceholderMap

try:
@@ -647,6 +650,63 @@ def build(self, seq_lens: List[int], query_lens: List[int],
)


@triton.jit
def _gather_kv_cache(
# Pointers to inputs and output
seq_start_locs, # (batch_size + 1,)
block_tables, # (batch_size, max_blocks_per_seq)
block_table_stride,
kv_cache, # (num_blocks, block_size, head_size)
kv_out,
CACHE_PAGE_SIZE: tl.constexpr,
CACHE_ENTRY_SIZE: tl.constexpr,
CACHE_ENTRIES_PER_PAGE: tl.constexpr,
CACHE_PAGE_SIZE_POW_2: tl.constexpr,
CACHE_ENTRY_SIZE_POW_2: tl.constexpr,
):
"""A Triton-accelerated function to perform per-token-group
quantization on a tensor.
This function converts the tensor values into float8 values.
"""

# Map the program id to the sequence whose cache pages it copies.
g_id = tl.program_id(0)

seq_start_loc = tl.load(seq_start_locs + g_id)
seq_len = tl.load(seq_start_locs + g_id + 1) - seq_start_loc

pages_to_copy = tl.cdiv(seq_len, CACHE_ENTRIES_PER_PAGE)
kv_out = kv_out + seq_start_loc * CACHE_ENTRY_SIZE
block_table = block_tables + g_id * block_table_stride

cache_page_range = tl.arange(0, CACHE_PAGE_SIZE_POW_2)
cache_page_mask = cache_page_range < CACHE_PAGE_SIZE
for i in range(pages_to_copy - 1):
page = tl.load(block_table + i)
page_start = kv_cache + page * CACHE_PAGE_SIZE
page_data = tl.load(page_start + cache_page_range,
mask=cache_page_mask)
tl.store(kv_out + i * CACHE_PAGE_SIZE + cache_page_range,
page_data,
mask=cache_page_mask)

last_page_len = (seq_len + CACHE_ENTRIES_PER_PAGE -
1) % CACHE_ENTRIES_PER_PAGE + 1
last_page = tl.load(block_table + pages_to_copy - 1)
last_page_start = kv_cache + last_page * CACHE_PAGE_SIZE

cache_entry_range = tl.arange(0, CACHE_ENTRY_SIZE_POW_2)
cache_entry_mask = cache_entry_range < CACHE_ENTRY_SIZE
kv_out_page = kv_out + (pages_to_copy - 1) * CACHE_PAGE_SIZE
for i in range(last_page_len):
last_page_data = tl.load(last_page_start + \
i * CACHE_ENTRY_SIZE + cache_entry_range,
mask=cache_entry_mask)
tl.store(kv_out_page + i * CACHE_ENTRY_SIZE + cache_entry_range,
last_page_data,
mask=cache_entry_mask)

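As a sanity check on the kernel above, a hedged plain-PyTorch reference of the gather it performs (illustrative only; it ignores the stride and padding handling the Triton version does via CACHE_PAGE_SIZE and CACHE_ENTRY_SIZE): for each sequence, copy its first seq_len entries out of the paged cache into a contiguous buffer.

import torch

def gather_kv_cache_ref(seq_start_locs, block_tables, kv_cache):
    # kv_cache: (num_blocks, entries_per_page, entry_size), assumed contiguous.
    entries_per_page = kv_cache.shape[1]
    entry_size = kv_cache.shape[2]
    total_entries = int(seq_start_locs[-1])
    out = torch.empty(total_entries, entry_size,
                      dtype=kv_cache.dtype, device=kv_cache.device)
    for s in range(block_tables.shape[0]):
        start = int(seq_start_locs[s])
        seq_len = int(seq_start_locs[s + 1]) - start
        num_pages = (seq_len + entries_per_page - 1) // entries_per_page
        # Look up this sequence's physical pages and flatten them into rows.
        pages = block_tables[s, :num_pages].long()
        flat = kv_cache[pages].reshape(-1, entry_size)
        out[start:start + seq_len] = flat[:seq_len]
    return out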

class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):

def __init__(
@@ -686,14 +746,69 @@ def __init__(
def _forward_prefill(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
kv_c: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: TritonMLAMetadata,
) -> torch.Tensor:
assert isinstance(attn_metadata, TritonMLAMetadata)
return self._forward_prefill_flash(q, kv_c_normed, k_pe,
attn_metadata.seq_start_loc,
attn_metadata.max_prefill_seq_len)

prefill_meta = attn_metadata.prefill_metadata
assert prefill_meta is not None

if kv_c_and_k_pe_cache.numel() > 0 and \
prefill_meta.block_tables is not None and \
prefill_meta.block_tables.numel() > 0:
assert prefill_meta.seq_start_loc is not None
assert prefill_meta.max_query_len is not None

entries_total = prefill_meta.seq_start_loc[-1]
kv_c_k_pe_cache = torch.empty_strided(
(entries_total, kv_c_and_k_pe_cache.shape[-1]),
(kv_c_and_k_pe_cache.stride(1), 1),
dtype=kv_c_and_k_pe_cache.dtype,
device=kv_c_and_k_pe_cache.device,
)

assert kv_c_and_k_pe_cache.shape[-1] == 576
assert kv_c_and_k_pe_cache.shape[-2] == 16
_gather_kv_cache[(attn_metadata.num_prefills, )](
prefill_meta.seq_start_loc,
prefill_meta.block_tables,
prefill_meta.block_tables.stride(0),
kv_c_and_k_pe_cache,
kv_c_k_pe_cache,
CACHE_PAGE_SIZE=kv_c_and_k_pe_cache.stride(0),
CACHE_ENTRY_SIZE=kv_c_and_k_pe_cache.stride(1),
CACHE_ENTRIES_PER_PAGE=kv_c_and_k_pe_cache.shape[1],
CACHE_ENTRY_SIZE_POW_2=triton.next_power_of_2(
kv_c_and_k_pe_cache.stride(1)),
CACHE_PAGE_SIZE_POW_2=triton.next_power_of_2(
kv_c_and_k_pe_cache.stride(0)),
)

kv_c = kv_c_k_pe_cache[..., :self.kv_lora_rank].unsqueeze(
1).contiguous()
k_pe = kv_c_k_pe_cache[..., self.kv_lora_rank:].unsqueeze(
1).contiguous()

return self._forward_prefill_flash(
q,
kv_c,
k_pe,
seq_start_loc=prefill_meta.seq_start_loc,
max_prefill_seq_len=prefill_meta.max_prefill_seq_len,
query_start_loc=prefill_meta.query_start_loc,
max_query_len=prefill_meta.max_query_len)
else:
return self._forward_prefill_flash(
q,
kv_c,
k_pe,
seq_start_loc=prefill_meta.seq_start_loc,
max_prefill_seq_len=prefill_meta.max_prefill_seq_len,
query_start_loc=prefill_meta.seq_start_loc,
max_query_len=prefill_meta.max_prefill_seq_len)

def _forward_decode(
self,
10 changes: 0 additions & 10 deletions vllm/config.py
@@ -3264,16 +3264,6 @@ def __post_init__(self):

current_platform.check_and_update_config(self)

# If MLA is enabled, force disable chunked prefill and prefix caching
if self.model_config and self.model_config.use_mla:
logger.info("MLA is enabled; forcing chunked prefill and prefix "
"caching to be disabled.")
self.scheduler_config.enable_chunked_prefill = False
self.scheduler_config.chunked_prefill_enabled = False

if self.cache_config is not None:
self.cache_config.enable_prefix_caching = False

if not self.instance_id:
self.instance_id = random_uuid()[:5]
