
Commit 8bf1e11

[Performance] Remove input pads in cutlass_mla and optimize v_proj output reshape
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
1 parent 8db2939 commit 8bf1e11
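
The performance change has two parts: the decode-path projections now write their batched-matmul results directly into preallocated buffers via torch.bmm(..., out=...) instead of materializing a temporary and copying it, and the explicit query padding inside cutlass_mla is replaced by allocating the query tensors with padded head capacity up front. Below is a minimal sketch of the first pattern; the shapes and the helper names (v_up_proj_alloc, v_up_proj_reuse) are illustrative, not the backend's own.

import torch

# Illustrative shapes: N heads, B decode tokens, L latent dim, V value head dim.
N, B, L, V = 16, 8, 512, 128
W_UV = torch.randn(N, L, V)

def v_up_proj_alloc(x: torch.Tensor, output: torch.Tensor) -> None:
    # Old pattern: torch.bmm allocates a fresh (N, B, V) result on every call.
    tmp = torch.bmm(x, W_UV)
    output.copy_(tmp.transpose(0, 1).reshape(B, N * V))

# New pattern: keep one (N, B, V) destination buffer and let bmm write into it,
# so no per-step temporary is allocated and the destination stays warm.
out_buf = torch.empty(N, B, V)

def v_up_proj_reuse(x: torch.Tensor, output: torch.Tensor) -> None:
    torch.bmm(x, W_UV, out=out_buf)
    output.copy_(out_buf.transpose(0, 1).reshape(B, N * V))

output = torch.empty(B, N * V)
for _ in range(3):  # stand-in for successive decode steps
    x = torch.randn(N, B, L)
    v_up_proj_reuse(x, output)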

2 files changed: 55 additions & 20 deletions

vllm/v1/attention/backends/mla/common.py

Lines changed: 39 additions & 6 deletions
@@ -942,6 +942,7 @@ def __init__(
         qk_head_dim: int,
         v_head_dim: int,
         kv_b_proj: ColumnParallelLinear,
+        q_pad_num_heads: Optional[int] = None,
     ) -> None:
         if kv_sharing_target_layer_name is not None:
             raise NotImplementedError("KV sharing is not supported for MLA")
@@ -959,6 +960,7 @@ def __init__(
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
         self.kv_b_proj = kv_b_proj
+        self.q_pad_num_heads = q_pad_num_heads
 
         if use_flashinfer_prefill():
             logger.debug_once("Using FlashInfer prefill for MLA")
@@ -1134,7 +1136,7 @@ def _run_prefill_context_chunk_cudnn(self,
             True, #Indicates actual_seq_lens are on GPU or CPU.
         )
 
-    def _v_up_proj(self, x):
+    def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
         if is_rocm_aiter_fp8bmm_enabled():
@@ -1146,12 +1148,23 @@ def _v_up_proj(self, x):
                                      transpose_bm=True)
             # Convert from (B, N, V) to (B, N * V)
             x = x.reshape(-1, self.num_heads * self.v_head_dim)
+            # Copy result
+            out.copy_(x)
         else:
+            # Convert from (B, N * V) to (N, B, V)
+            out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
+
             # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
-            x = torch.bmm(x, self.W_UV)
+            torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot"
+
             # Convert from (N, B, V) to (B, N * V)
-            x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
-        return x
+            out_new = out.transpose(0, 1).reshape(
+                -1, self.num_heads * self.v_head_dim)
+
+            # Adjust output buffer shape back to the original (B, N * V)
+            N, B, V = out.shape
+            out.resize_((B, N * V))
+            out.copy_(out_new) # Copy result
 
     def process_weights_after_loading(self, act_dtype: torch.dtype):
 
@@ -1559,6 +1572,15 @@ def forward(
         # Convert from (B, N, P) to (N, B, P)
         decode_q_nope = decode_q_nope.transpose(0, 1)
 
+        # Pads the head_dim if necessary (for the underlying kernel)
+        if self.q_pad_num_heads is not None:
+            B, N, L = decode_q_pe.shape
+            decode_pe_padded = decode_q_pe.new_empty(
+                (B, self.q_pad_num_heads, L))
+            decode_pe_padded.resize_((B, N, L))
+            decode_pe_padded.copy_(decode_q_pe)
+            decode_q_pe = decode_pe_padded
+
         if is_rocm_aiter_fp8bmm_enabled():
             # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
             decode_ql_nope = aiter_triton_fp8_bmm(decode_q_nope,
@@ -1567,8 +1589,19 @@ def forward(
                                                   group_size=128,
                                                   transpose_bm=True)
         else:
+            # Pads the head_dim if necessary (for the underlying kernel)
+            N, B, P = decode_q_nope.shape
+            _, _, L = self.W_UK_T.shape
+            if self.q_pad_num_heads is not None:
+                decode_ql_nope = decode_q_nope.new_empty(
+                    (self.q_pad_num_heads, B, L))
+                decode_ql_nope.resize_((N, B, L))
+
+            else:
+                decode_ql_nope = decode_q_nope.new_empty((N, B, L))
+
             # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
-            decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
+            torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)
             # Convert from (N, B, L) to (B, N, L)
             decode_ql_nope = decode_ql_nope.transpose(0, 1)
 
@@ -1603,5 +1636,5 @@ def forward(
             attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
 
         # v_up projection
-        output[:num_decode_tokens] = self._v_up_proj(attn_out)
+        self._v_up_proj(attn_out, out=output[:num_decode_tokens])
         return output_padded
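
The padding helper in forward() above relies on a PyTorch storage detail: Tensor.resize_() to a smaller shape keeps the existing storage, so a tensor allocated at q_pad_num_heads capacity and shrunk to the live head count can later be grown back to the padded shape without a new allocation or copy, which appears to be what lets the cutlass_mla.py wrapper below drop its per-call input padding. A minimal sketch of that storage behaviour, with illustrative sizes:

import torch

MAX_HEADS = 128          # padded head capacity expected by the kernel (illustrative)
B, N, L = 8, 64, 512     # batch, live heads, latent dim (illustrative)

q = torch.randn(B, N, L)

# Allocate at padded capacity, then shrink the logical shape to the live heads.
q_padded = q.new_empty((B, MAX_HEADS, L))   # storage sized for MAX_HEADS heads
q_padded.resize_((B, N, L))                 # logical shape; storage is kept
q_padded.copy_(q)

capacity = q_padded.untyped_storage().nbytes() // q_padded.element_size()
assert capacity == B * MAX_HEADS * L        # capacity still covers MAX_HEADS

# Growing back to the padded shape reuses the same storage: no realloc, no copy.
ptr_before = q_padded.untyped_storage().data_ptr()
q_padded.resize_((B, MAX_HEADS, L))
assert q_padded.untyped_storage().data_ptr() == ptr_before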

vllm/v1/attention/backends/mla/cutlass_mla.py

Lines changed: 16 additions & 14 deletions
@@ -74,6 +74,8 @@ def ensure_size(self, attn_metadata: MLACommonMetadata,
 
 g_sm100_workspace = SM100Workspace(128 * 1024 * 1024) # 128MB
 
+MAX_HEADS = 128
+
 
 class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
     can_return_lse_for_decode: bool = True
@@ -92,10 +94,18 @@ def __init__(
             kv_sharing_target_layer_name: Optional[str],
             # MLA Specific Arguments
             **mla_args) -> None:
-        super().__init__(num_heads, head_size, scale, num_kv_heads,
-                         alibi_slopes, sliding_window, kv_cache_dtype,
-                         logits_soft_cap, attn_type,
-                         kv_sharing_target_layer_name, **mla_args)
+        super().__init__(num_heads,
+                         head_size,
+                         scale,
+                         num_kv_heads,
+                         alibi_slopes,
+                         sliding_window,
+                         kv_cache_dtype,
+                         logits_soft_cap,
+                         attn_type,
+                         kv_sharing_target_layer_name,
+                         q_pad_num_heads=MAX_HEADS,
+                         **mla_args)
 
         unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
         if any(unsupported_features):
@@ -157,14 +167,6 @@ def _sm100_cutlass_mla_decode(
 
         MAX_HEADS = 128
         assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
-        if H < MAX_HEADS:
-            q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
-            q_nope_padded[:, :H] = q_nope
-            q_nope = q_nope_padded
-
-            q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
-            q_pe_padded[:, :H] = q_pe
-            q_pe = q_pe_padded
 
         assert len(page_table.shape) == 2
         B_block_table, block_num = page_table.shape
@@ -206,9 +208,9 @@ def _sm100_cutlass_mla_decode(
         )
 
         if H < MAX_HEADS:
+            # Extract the subsets of the outputs
+            lse = lse[:, :H] if self.need_to_return_lse_for_decode else lse
             out = out[:, :H]
-            if self.need_to_return_lse_for_decode:
-                lse = lse[:, :H].contiguous()
 
         return out, lse
 
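
With the input-side padding removed, only the kernel outputs still need trimming when fewer than MAX_HEADS query heads are active. A small illustrative sketch of that output-side slicing; run_decode_kernel and all sizes here are hypothetical stand-ins, not the actual SM100 CUTLASS entry point:

import torch

MAX_HEADS = 128
H, B, D_latent, D_rope = 64, 8, 512, 64   # illustrative sizes, H < MAX_HEADS

def run_decode_kernel(q_nope: torch.Tensor, q_pe: torch.Tensor):
    # Hypothetical stand-in for the decode kernel: it produces MAX_HEADS
    # rows per token regardless of how many heads are live.
    out = torch.randn(B, MAX_HEADS, D_latent)
    lse = torch.randn(B, MAX_HEADS)
    return out, lse

need_lse = True
q_nope = torch.randn(B, H, D_latent)
q_pe = torch.randn(B, H, D_rope)

out, lse = run_decode_kernel(q_nope, q_pe)
if H < MAX_HEADS:
    # Keep only the live heads from the padded outputs, as in the diff above.
    lse = lse[:, :H] if need_lse else lse
    out = out[:, :H]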
