@@ -15,7 +15,6 @@
 from einops import rearrange
 from torch import nn

-from vllm import envs
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
 from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
@@ -42,8 +41,6 @@
 import torch
 import torch.distributed

-from vllm.model_executor.models.minimax_cache import MinimaxCacheParams
-

 class MiniMaxText01RMSNormTP(CustomOp):
     name = "MiniMaxText01RMSNormTP"
@@ -225,11 +222,10 @@ def __init__(
                                         self.tp_heads:(self.tp_rank + 1) *
                                         self.tp_heads].contiguous()

-        if envs.VLLM_USE_V1:
-            compilation_config = get_current_vllm_config().compilation_config
-            if prefix in compilation_config.static_forward_context:
-                raise ValueError(f"Duplicate layer name: {prefix}")
-            compilation_config.static_forward_context[prefix] = self
+        compilation_config = get_current_vllm_config().compilation_config
+        if prefix in compilation_config.static_forward_context:
+            raise ValueError(f"Duplicate layer name: {prefix}")
+        compilation_config.static_forward_context[prefix] = self

     @staticmethod
     def weight_direct_load(param: torch.Tensor,
@@ -268,8 +264,7 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
                 break
             if _prefill_idx >= len(state_indices_tensor):
                 break
-            # prefills are packed at end of batch in V1
-            offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
+            offset = attn_metadata.num_decode_tokens
             _start = attn_metadata.query_start_loc[offset + _prefill_idx]
             _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
             slot_id = state_indices_tensor[offset + _prefill_idx]
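For context on the hunk above: the removed comment noted that prefills are packed at the end of the batch in V1, i.e. decode tokens (one per decode request) come first, so the per-prefill boundaries in `query_start_loc` are read starting at index `num_decode_tokens`. A minimal sketch of that indexing with invented metadata values (field names mirror the code above; the numbers are illustrative only):

```python
import torch

# Illustrative only: 3 decode requests (one token each) followed by two
# prefill requests of lengths 4 and 2 in the same packed batch.
num_decode_tokens = 3                       # also the number of decode requests
query_start_loc = torch.tensor([0, 1, 2, 3, 7, 9])

offset = num_decode_tokens                  # prefill boundaries start after the decodes
for _prefill_idx in range(2):
    _start = int(query_start_loc[offset + _prefill_idx])
    _end = int(query_start_loc[offset + _prefill_idx + 1])
    print(f"prefill {_prefill_idx}: tokens [{_start}, {_end})")
# prefill 0: tokens [3, 7)
# prefill 1: tokens [7, 9)
```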
@@ -291,10 +286,7 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
            hidden_decode = self._decode_infer(q, k, v, kv_cache,
                                               state_indices_tensor,
                                               attn_metadata)
-            if envs.VLLM_USE_V1:
-                hidden.insert(0, hidden_decode)
-            else:
-                hidden.append(hidden_decode)
+            hidden.insert(0, hidden_decode)

         if not hidden:
             return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
@@ -304,40 +296,28 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,

     def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
                       attn_metadata):
-        if not envs.VLLM_USE_V1:
-            q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-            k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-            v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-            num_prefills = getattr(attn_metadata, "num_prefills", 0)
-            slot_id = state_indices_tensor[num_prefills:]
-        else:
-            q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
-            k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
-            v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
-            slot_id = state_indices_tensor[:attn_metadata.num_decodes]
+        q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+        k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+        v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+        slot_id = state_indices_tensor[:attn_metadata.num_decodes]
         hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
                                               slot_id, 32)
         return hidden

     def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
-                positions: torch.Tensor,
-                kv_caches: MinimaxCacheParams) -> None:
-        if not envs.VLLM_USE_V1:
-            self._forward(hidden_states, output, positions, kv_caches)
-        else:
-            torch.ops.vllm.linear_attention(
-                hidden_states,
-                output,
-                positions,
-                self.prefix,
-            )
+                positions: torch.Tensor) -> None:
+        torch.ops.vllm.linear_attention(
+            hidden_states,
+            output,
+            positions,
+            self.prefix,
+        )

     def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
-                 positions: torch.Tensor,
-                 kv_caches: Optional[MinimaxCacheParams]) -> None:
+                 positions: torch.Tensor) -> None:
         forward_context = get_forward_context()
         attn_metadata: AttentionMetadata = forward_context.attn_metadata
-        if envs.VLLM_USE_V1 and attn_metadata is not None:
+        if attn_metadata is not None:
             assert isinstance(attn_metadata, dict)
             attn_metadata = attn_metadata[self.prefix]
             assert isinstance(attn_metadata, LinearAttentionMetadata)
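Because decode tokens sit at the front of the packed batch, `_decode_infer` above only needs a prefix slice plus `unsqueeze(2)` to add a length-1 per-token sequence axis before calling the Triton decode kernel. A small sketch of that slicing on dummy tensors (shapes invented for illustration):

```python
import torch

num_tokens, num_heads, head_dim = 9, 4, 8    # invented shapes
num_decode_tokens = 3

q = torch.randn(num_tokens, num_heads, head_dim)
# Prefix slice selects the decode tokens; unsqueeze(2) inserts a length-1
# sequence axis, giving (num_decode_tokens, num_heads, 1, head_dim).
q_decode = q[:num_decode_tokens].unsqueeze(2).contiguous()
print(q_decode.shape)  # torch.Size([3, 4, 1, 8])
```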
@@ -351,32 +331,26 @@ def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
         qkvact = torch.nn.functional.silu(qkv32)
         qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
         q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
-        if envs.VLLM_USE_V1:
-            if attn_metadata is not None:
-                kv_cache = self.kv_cache[forward_context.virtual_engine][0]
-                state_indices_tensor = attn_metadata.state_indices_tensor
-
-                num_prefills = getattr(attn_metadata, "num_prefills", 0)
-                if num_prefills > 0:
-                    num_decode_tokens = getattr(attn_metadata,
-                                                "num_decode_tokens", 0)
-                    for prefill_idx in range(num_prefills):
-                        q_start = attn_metadata.query_start_loc[
-                            num_decode_tokens + prefill_idx]
-                        q_end = attn_metadata.query_start_loc[num_decode_tokens
-                                                              + prefill_idx +
-                                                              1]
-                        query_len = q_end - q_start
-                        context_len = attn_metadata.seq_lens[
-                            num_decode_tokens + prefill_idx] - query_len
-                        if context_len == 0:
-                            block_to_clear = state_indices_tensor[
-                                num_decode_tokens + prefill_idx]
-                            kv_cache[block_to_clear, ...] = 0
-        else:
-            assert kv_caches is not None
-            kv_cache = kv_caches.minimax_cache
-            state_indices_tensor = kv_caches.state_indices_tensor
+        if attn_metadata is not None:
+            kv_cache = self.kv_cache[forward_context.virtual_engine][0]
+            state_indices_tensor = attn_metadata.state_indices_tensor
+
+            num_prefills = getattr(attn_metadata, "num_prefills", 0)
+            if num_prefills > 0:
+                num_decode_tokens = getattr(attn_metadata, "num_decode_tokens",
+                                            0)
+                for prefill_idx in range(num_prefills):
+                    q_start = attn_metadata.query_start_loc[num_decode_tokens +
+                                                            prefill_idx]
+                    q_end = attn_metadata.query_start_loc[num_decode_tokens +
+                                                          prefill_idx + 1]
+                    query_len = q_end - q_start
+                    context_len = attn_metadata.seq_lens[
+                        num_decode_tokens + prefill_idx] - query_len
+                    if context_len == 0:
+                        block_to_clear = state_indices_tensor[num_decode_tokens
+                                                              + prefill_idx]
+                        kv_cache[block_to_clear, ...] = 0

         decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
         if attn_metadata is None:
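The hunk above zeroes the recurrent cache block for any prefill whose `context_len` is 0, i.e. a sequence starting from scratch, so stale state left in a reused block cannot leak into the new request. A toy version of the same check (all values invented):

```python
import torch

kv_cache = torch.ones(4, 2, 2)                  # 4 cache blocks of recurrent state
state_indices_tensor = torch.tensor([2, 0, 3])  # block assigned to each request
num_decode_tokens = 1
seq_lens = torch.tensor([5, 7, 10])             # total sequence length per request
query_start_loc = torch.tensor([0, 1, 8, 12])   # 1 decode token, prefills of 7 and 4

for prefill_idx in range(2):
    q_start = query_start_loc[num_decode_tokens + prefill_idx]
    q_end = query_start_loc[num_decode_tokens + prefill_idx + 1]
    context_len = seq_lens[num_decode_tokens + prefill_idx] - (q_end - q_start)
    if context_len == 0:
        # No prior context: brand-new sequence, so reset its state block.
        kv_cache[state_indices_tensor[num_decode_tokens + prefill_idx], ...] = 0

print(kv_cache.sum(dim=(1, 2)))  # only block 0 has been cleared
```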
@@ -410,8 +384,7 @@ def linear_attention(
     self = forward_context.no_compile_layers[layer_name]
     self._forward(hidden_states=hidden_states,
                   output=output,
-                  positions=positions,
-                  kv_caches=None)
+                  positions=positions)


def linear_attention_fake(
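`linear_attention` and `linear_attention_fake` are the real and meta ("fake") implementations behind the `torch.ops.vllm.linear_attention` call used in `forward` above; the fake variant gives the compiler something to trace without running the kernel. The registration helper vLLM uses is outside this hunk, so here is only a rough sketch of the general pattern using plain `torch.library` APIs (the `demo::` namespace and function bodies are invented for illustration):

```python
import torch


@torch.library.custom_op("demo::linear_attention", mutates_args=("output",))
def demo_linear_attention(hidden_states: torch.Tensor,
                          output: torch.Tensor) -> None:
    # Real implementation: writes results into the preallocated output buffer.
    output.copy_(hidden_states)


@demo_linear_attention.register_fake
def _(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    # Fake implementation: no computation, just enough for tracing/compilation.
    return None


x = torch.randn(2, 4)
out = torch.empty_like(x)
torch.ops.demo.linear_attention(x, out)   # dispatches to the real impl
print(torch.allclose(out, x))             # True
```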