@@ -48,17 +48,19 @@ def _vllm_layout_trans_kernel(
 ):
     batch_idx = tl.program_id(0)
     block_idx = tl.program_id(1)
-    batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx +
-                                  tl.arange(0, 2))
-    batch_token_start, batch_token_end = tl.split(batch_token_indexes)
-    seq_len = batch_token_end - batch_token_start
 
     batch_query_indexes = tl.load(b_query_lens_loc + batch_idx +
                                   tl.arange(0, 2))
     batch_query_start, batch_query_end = tl.split(batch_query_indexes)
     query_len = batch_query_end - batch_query_start
     if query_len <= 1:
         return
+
+    batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx +
+                                  tl.arange(0, 2))
+    batch_token_start, batch_token_end = tl.split(batch_token_indexes)
+    seq_len = batch_token_end - batch_token_start
+
     if block_idx * BLOCK_SIZE < seq_len:
         block_mask = (block_idx * BLOCK_SIZE +
                       tl.arange(0, BLOCK_SIZE)[:, None]) < seq_len
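This hunk reorders the kernel so the decode check happens first: a program whose request has `query_len <= 1` now returns before it ever loads the sequence-length prefix sums, so decode-only requests neither read `b_seq_lens_loc` nor copy any K/V rows. A minimal host-side sketch of the rule the kernel follows (the prefix-sum values below are made-up examples, not from the source):

```python
import torch

# Hypothetical prefix sums for 4 requests: requests 0 and 2 are decode
# steps (query_len == 1), requests 1 and 3 are prefills.
query_start_loc = torch.tensor([0, 1, 5, 6, 10])
seq_start_loc = torch.tensor([0, 7, 19, 20, 32])

for batch_idx in range(len(query_start_loc) - 1):
    query_len = query_start_loc[batch_idx + 1] - query_start_loc[batch_idx]
    if query_len <= 1:
        continue  # mirrors the kernel's early return: no seq_len load, no copy
    seq_len = seq_start_loc[batch_idx + 1] - seq_start_loc[batch_idx]
    print(f"request {batch_idx}: transpose {seq_len} K/V rows")
```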
@@ -269,12 +271,13 @@ def build(self, common_prefix_len: int,
         max_query_len = common_attn_metadata.max_query_len
 
         max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
-        total_tokens = int(self.runner.seq_lens_np[:num_reqs].sum())
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table = self.block_table
         block_table_tensor = block_table.get_device_tensor()[:num_reqs]
-
+        query_lens = query_start_loc[1:] - query_start_loc[:-1]
+        masked_seq_lens = torch.where(query_lens > 1, seq_lens,
+                                      torch.zeros_like(seq_lens))
         block_table.slot_mapping[:num_actual_tokens].copy_(
             block_table.slot_mapping_cpu[:num_actual_tokens],
             non_blocking=True)
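The host side mirrors the kernel's skip rule: instead of summing all sequence lengths up front, `build` derives per-request query lengths from `query_start_loc` and zeroes out the sequence length of every decode request. A quick sketch with hypothetical lengths shows what `masked_seq_lens` holds:

```python
import torch

# Hypothetical batch: the middle request is a decode step (query_len == 1).
query_start_loc = torch.tensor([0, 4, 5, 9], dtype=torch.int32)
seq_lens = torch.tensor([16, 33, 12], dtype=torch.int32)

query_lens = query_start_loc[1:] - query_start_loc[:-1]   # [4, 1, 4]
masked_seq_lens = torch.where(query_lens > 1, seq_lens,
                              torch.zeros_like(seq_lens))  # [16, 0, 12]
```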
@@ -284,10 +287,10 @@ def build(self, common_prefix_len: int,
 
         slot_mapping = block_table.slot_mapping[:num_actual_tokens]
 
-        cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
+        cu_seq_lens = torch.zeros(masked_seq_lens.shape[0] + 1,
                                   dtype=torch.int32,
                                   device="cuda")
-        torch.cumsum(seq_lens,
+        torch.cumsum(masked_seq_lens,
                      dim=0,
                      dtype=cu_seq_lens.dtype,
                      out=cu_seq_lens[1:])
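Because `cu_seq_lens` now accumulates the masked lengths, decode requests collapse to zero-width slices and the final entry counts only the tokens the transposition kernel actually writes. Continuing the sketch above:

```python
cu_seq_lens = torch.zeros(masked_seq_lens.shape[0] + 1, dtype=torch.int32)
torch.cumsum(masked_seq_lens, dim=0, dtype=cu_seq_lens.dtype,
             out=cu_seq_lens[1:])
# cu_seq_lens == [0, 16, 16, 28]: the decode request owns the empty
# slice [16, 16), and cu_seq_lens[-1] == 28 is the masked token total.
```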
@@ -356,14 +359,14 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             dtype=torch.uint8,
             device=self.runner.device,
         )
-
+        masked_total_tokens = cu_seq_lens[-1].item()
         k_buffer = torch.empty(
-            (total_tokens, self.num_heads_kv, self.headdim),
+            (masked_total_tokens, self.num_heads_kv, self.headdim),
             dtype=self.runner.dtype,
             device=self.runner.device,
         )
         v_buffer = torch.empty(
-            (total_tokens, self.num_heads_kv, self.headdim),
+            (masked_total_tokens, self.num_heads_kv, self.headdim),
             dtype=self.runner.dtype,
             device=self.runner.device,
         )
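Sizing `k_buffer` and `v_buffer` by `cu_seq_lens[-1]` rather than the old `total_tokens` means decode-heavy batches no longer allocate transposed K/V rows they will never touch; in the hypothetical batch above that is 28 rows instead of 61. One caveat: `.item()` on a CUDA tensor forces a device-to-host sync at this point in `build`. A sketch of the allocation (head count and head dim are made-up values):

```python
masked_total_tokens = cu_seq_lens[-1].item()  # 28 in the example above
num_heads_kv, headdim = 8, 128                # hypothetical model shapes
k_buffer = torch.empty((masked_total_tokens, num_heads_kv, headdim),
                       dtype=torch.float16)
v_buffer = torch.empty_like(k_buffer)
```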