@@ -942,6 +942,7 @@ def __init__(
         qk_head_dim: int,
         v_head_dim: int,
         kv_b_proj: ColumnParallelLinear,
+        q_pad_num_heads: Optional[int] = None,
     ) -> None:
         if kv_sharing_target_layer_name is not None:
             raise NotImplementedError("KV sharing is not supported for MLA")
@@ -959,6 +960,7 @@ def __init__(
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
         self.kv_b_proj = kv_b_proj
+        self.q_pad_num_heads = q_pad_num_heads
 
         if use_flashinfer_prefill():
             logger.debug_once("Using FlashInfer prefill for MLA")
@@ -1134,7 +1136,7 @@ def _run_prefill_context_chunk_cudnn(self,
             True,  # Indicates actual_seq_lens are on GPU or CPU.
         )
 
-    def _v_up_proj(self, x):
+    def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
         if is_rocm_aiter_fp8bmm_enabled():
@@ -1146,12 +1148,23 @@ def _v_up_proj(self, x):
                                      transpose_bm=True)
             # Convert from (B, N, V) to (B, N * V)
             x = x.reshape(-1, self.num_heads * self.v_head_dim)
+            # Copy result
+            out.copy_(x)
         else:
+            # Convert from (B, N * V) to (N, B, V)
+            out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
+
             # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
-            x = torch.bmm(x, self.W_UV)
+            torch.bmm(x, self.W_UV, out=out)  # Reuse "out" to make it "hot"
+
             # Convert from (N, B, V) to (B, N * V)
-            x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
-            return x
+            out_new = out.transpose(0, 1).reshape(
+                -1, self.num_heads * self.v_head_dim)
+
+            # Adjust output buffer shape back to the original (B, N * V)
+            N, B, V = out.shape
+            out.resize_((B, N * V))
+            out.copy_(out_new)  # Copy result
 
     def process_weights_after_loading(self, act_dtype: torch.dtype):
 
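The `_v_up_proj` hunk above switches from returning a freshly allocated tensor to writing the up-projection result straight into a caller-provided buffer via `torch.bmm(..., out=...)`. A minimal standalone sketch of that pattern, with placeholder shapes rather than vLLM's real configuration:

import torch

# Placeholder shapes: B tokens, N heads, kv_lora_rank L, v_head_dim V.
B, N, L, V = 4, 16, 512, 128
x = torch.randn(N, B, L)      # per-head latent attention output, (N, B, L)
W_UV = torch.randn(N, L, V)   # per-head up-projection weights, (N, L, V)
out = torch.empty(B, N * V)   # pre-allocated output buffer, (B, N * V)

# View the flat buffer as (B, N, V), then transpose to (N, B, V) so bmm can
# write each head's result directly into the buffer's storage.
out_view = out.view(B, N, V).transpose(0, 1)
torch.bmm(x, W_UV, out=out_view)

# Same values as the allocate-then-copy version, without the intermediate tensor.
ref = torch.bmm(x, W_UV).transpose(0, 1).reshape(B, N * V)
torch.testing.assert_close(out, ref)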
@@ -1559,6 +1572,15 @@ def forward(
             # Convert from (B, N, P) to (N, B, P)
             decode_q_nope = decode_q_nope.transpose(0, 1)
 
+            # Pad the number of heads if necessary (for the underlying kernel)
+            if self.q_pad_num_heads is not None:
+                B, N, L = decode_q_pe.shape
+                decode_pe_padded = decode_q_pe.new_empty(
+                    (B, self.q_pad_num_heads, L))
+                decode_pe_padded.resize_((B, N, L))
+                decode_pe_padded.copy_(decode_q_pe)
+                decode_q_pe = decode_pe_padded
+
             if is_rocm_aiter_fp8bmm_enabled():
                 # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
                 decode_ql_nope = aiter_triton_fp8_bmm(decode_q_nope,
@@ -1567,8 +1589,19 @@ def forward(
                                                       group_size=128,
                                                       transpose_bm=True)
             else:
+                # Pad the number of heads if necessary (for the underlying kernel)
+                N, B, P = decode_q_nope.shape
+                _, _, L = self.W_UK_T.shape
+                if self.q_pad_num_heads is not None:
+                    decode_ql_nope = decode_q_nope.new_empty(
+                        (self.q_pad_num_heads, B, L))
+                    decode_ql_nope.resize_((N, B, L))
+
+                else:
+                    decode_ql_nope = decode_q_nope.new_empty((N, B, L))
+
                 # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
-                decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
+                torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)
                 # Convert from (N, B, L) to (B, N, L)
                 decode_ql_nope = decode_ql_nope.transpose(0, 1)
@@ -1603,5 +1636,5 @@ def forward(
                 attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
 
             # v_up projection
-            output[:num_decode_tokens] = self._v_up_proj(attn_out)
+            self._v_up_proj(attn_out, out=output[:num_decode_tokens])
         return output_padded
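The final hunk changes only the call site: `_v_up_proj` now fills the pre-allocated slice of `output` in place instead of returning a tensor that gets copied in. A tiny sketch of why computing into a slice view avoids the extra intermediate (placeholder sizes, with `torch.mm` standing in for the projection):

import torch

# Placeholder sizes standing in for forward()'s pre-allocated output tensor.
num_tokens, num_decode_tokens, hidden = 8, 5, 2048
output = torch.zeros(num_tokens, hidden)
a = torch.randn(num_decode_tokens, hidden)
b = torch.randn(hidden, hidden)

# Before-style: `a @ b` materializes an intermediate tensor, which the
# assignment then copies into the slice.
output[:num_decode_tokens] = a @ b

# After-style: compute directly into the slice view of `output`; no
# intermediate tensor and no copy at the call site.
torch.mm(a, b, out=output[:num_decode_tokens])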