 from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 
+from vllm import envs
 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
 from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tensor_model_parallel_rank,
@@ ... @@
     ReplicatedLinear,
     RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.abstract import MambaBase
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateShapeCalculator)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ ... @@
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
+from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
 
-from .interfaces import HasInnerState, IsHybrid, SupportsV0Only
+from .interfaces import HasInnerState, IsHybrid
 from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
 from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
 
@@ -327,7 +332,17 @@ def jit_linear_forward_prefix(q: torch.Tensor,
     return rearrange(output.squeeze(0), "h n d -> n (h d)")
 
 
-class MiniMaxText01LinearAttention(nn.Module):
+class MiniMaxText01LinearAttention(nn.Module, MambaBase):
+
+    @property
+    def mamba_type(self) -> str:
+        return "linear_attention"
+
+    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
+        return MambaStateShapeCalculator.linear_attention_state_shape(
+            num_heads=self.num_heads,
+            tp_size=self.tp_size,
+            head_dim=self.head_dim)
 
     def __init__(
         self,
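# [Editor's note, a hedged sketch -- not part of the diff] get_state_shape()
# above delegates to MambaStateShapeCalculator. For this style of linear
# attention the recurrent state is assumed to be one (head_dim x head_dim)
# KV outer-product accumulator per TP-local head, so the helper presumably
# reduces to something like:
#
#     def linear_attention_state_shape(num_heads, tp_size, head_dim):
#         # one decay-weighted sum of k_t v_t^T per local head
#         return ((num_heads // tp_size, head_dim, head_dim), )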
@@ -359,6 +374,7 @@ def __init__(
         self.tp_heads = self.total_num_heads // self.tp_size
         self.qkv_size = self.num_heads * self.head_dim
         self.tp_hidden = self.head_dim * self.tp_heads
+        self.prefix = prefix
 
         self.qkv_proj = ColumnParallelLinear(
             hidden_size,
@@ -397,6 +413,12 @@ def __init__(
                                         self.tp_heads:(self.tp_rank + 1) *
                                         self.tp_heads].contiguous()
 
+        if envs.VLLM_USE_V1:
+            compilation_config = get_current_vllm_config().compilation_config
+            if prefix in compilation_config.static_forward_context:
+                raise ValueError(f"Duplicate layer name: {prefix}")
+            compilation_config.static_forward_context[prefix] = self
+
     @staticmethod
     def weight_direct_load(param: torch.Tensor,
                            loaded_weight: torch.Tensor) -> None:
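# [Editor's note] The static_forward_context registration above is how V1
# wires per-layer metadata: the engine builds forward_context.attn_metadata
# as a dict keyed by each layer's prefix, and the layer looks itself back up
# in forward(), as this diff does below:
#
#     attn_metadata = get_forward_context().attn_metadata  # dict under V1
#     attn_metadata = attn_metadata[self.prefix]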
@@ -434,13 +456,14 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
                 break
             if _prefill_idx >= len(state_indices_tensor):
                 break
-            _start = attn_metadata.query_start_loc[_prefill_idx]
-            _end = attn_metadata.query_start_loc[_prefill_idx + 1]
-            slot_id = state_indices_tensor[_prefill_idx]
+            # prefills are packed at the end of the batch in V1
+            offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
+            _start = attn_metadata.query_start_loc[offset + _prefill_idx]
+            _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
+            slot_id = state_indices_tensor[offset + _prefill_idx]
             qs = q[_start:_end].transpose(0, 1).contiguous()
             ks = k[_start:_end].transpose(0, 1).contiguous()
             vs = v[_start:_end].transpose(0, 1).contiguous()
-            slot_id = state_indices_tensor[_prefill_idx]
             slice_layer_cache = kv_cache[slot_id, ...]
 
             out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
@@ -453,9 +476,13 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
                 layer_idx=self.layer_idx)
             hidden.append(out_slice.contiguous())
         if attn_metadata.num_decode_tokens > 0:
-            hidden.append(
-                self._decode_infer(q, k, v, kv_cache, state_indices_tensor,
-                                   attn_metadata))
+            hidden_decode = self._decode_infer(q, k, v, kv_cache,
+                                               state_indices_tensor,
+                                               attn_metadata)
+            if envs.VLLM_USE_V1:
+                hidden.insert(0, hidden_decode)
+            else:
+                hidden.append(hidden_decode)
 
         if not hidden:
             return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
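# [Editor's note] Batch layout assumed by the offset arithmetic and the
# hidden.insert(0, ...) above: V1 packs decode requests first and prefills
# last, so prefill i spans query_start_loc[num_decode_tokens + i] to
# query_start_loc[num_decode_tokens + i + 1], and the decode output has to be
# re-attached at the front to restore token order:
#
#     tokens: [ d_0, d_1, ..., d_{n-1} | p_0..., p_1..., ... ]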
@@ -465,11 +492,17 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
 
     def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
                       attn_metadata):
-        q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-        k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-        v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-        slot_id = state_indices_tensor[getattr(attn_metadata, "num_prefills",
-                                               0):]
+        if not envs.VLLM_USE_V1:
+            q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
+            k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
+            v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
+            num_prefills = getattr(attn_metadata, "num_prefills", 0)
+            slot_id = state_indices_tensor[num_prefills:]
+        else:
+            q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+            k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+            v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+            slot_id = state_indices_tensor[:attn_metadata.num_decodes]
         hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
                                               slot_id, 32)
         return hidden
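# [Editor's note] Shape sketch for the decode path, under the assumption that
# q/k/v enter as (num_tokens, num_local_heads, head_dim): the slicing keeps
# only decode tokens and unsqueeze(2) adds the length-1 step axis the Triton
# decode kernel expects, while slot_id selects each request's state block in
# kv_cache.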
@@ -483,17 +516,49 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
         q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
         forward_context = get_forward_context()
         attn_metadata = forward_context.attn_metadata
-        kv_cache = kv_caches.minimax_cache
-        state_indices_tensor = kv_caches.state_indices_tensor
+        if envs.VLLM_USE_V1:
+            if attn_metadata is not None:
+                assert isinstance(attn_metadata, dict)
+                attn_metadata = attn_metadata[self.prefix]
+                assert isinstance(attn_metadata, LinearAttentionMetadata)
+                kv_cache = self.kv_cache[forward_context.virtual_engine][0]
+                state_indices_tensor = attn_metadata.state_indices_tensor
+
+                num_prefills = getattr(attn_metadata, "num_prefills", 0)
+                if num_prefills > 0:
+                    num_decode_tokens = getattr(attn_metadata,
+                                                "num_decode_tokens", 0)
+                    for prefill_idx in range(num_prefills):
+                        q_start = attn_metadata.query_start_loc[
+                            num_decode_tokens + prefill_idx]
+                        q_end = attn_metadata.query_start_loc[
+                            num_decode_tokens + prefill_idx + 1]
+                        query_len = q_end - q_start
+                        context_len = attn_metadata.seq_lens[
+                            num_decode_tokens + prefill_idx] - query_len
+                        if context_len == 0:
+                            block_to_clear = state_indices_tensor[
+                                num_decode_tokens + prefill_idx]
+                            kv_cache[block_to_clear, ...] = 0
+        else:
+            kv_cache = kv_caches.minimax_cache
+            state_indices_tensor = kv_caches.state_indices_tensor
 
         decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
-        if not decode_only:
-            hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
-                                                 state_indices_tensor,
-                                                 attn_metadata)
+        if attn_metadata is None:
+            hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]),
+                                 device=q.device,
+                                 dtype=q.dtype)
         else:
-            hidden = self._decode_infer(q, k, v, kv_cache,
-                                        state_indices_tensor, attn_metadata)
+            if not decode_only:
+                hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
+                                                     state_indices_tensor,
+                                                     attn_metadata)
+            else:
+                hidden = self._decode_infer(q, k, v, kv_cache,
+                                            state_indices_tensor,
+                                            attn_metadata)
 
         hidden = self.norm._forward(hidden)
         gate, _ = self.output_gate(hidden_states)
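# [Editor's note] The context_len == 0 branch above is the V1 counterpart of
# V0's _clear_prefill_cache (still called further down in this diff): a
# prefill with no prior context is a brand-new sequence, so its state block
# may still hold a finished request's state and is zeroed before the
# recurrence starts.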
@@ -541,6 +606,7 @@ def __init__(
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
         self.sliding_window = sliding_window
+        self.prefix = prefix
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -575,7 +641,12 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
         attn_metadata = forward_context.attn_metadata
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q, k = attn_metadata.rotary_emb(positions, q, k)
+        if envs.VLLM_USE_V1:
+            if attn_metadata is not None:
+                q, k = attn_metadata[f"{self.prefix}.attn"].rotary_emb(
+                    positions, q, k)
+        else:
+            q, k = attn_metadata.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output
@@ -595,6 +666,7 @@ def __init__(
     ) -> None:
         self._ilayer = layer_id
         self._irank = get_tensor_model_parallel_rank()
+        self.prefix = prefix
         super().__init__()
 
         self.hidden_size = config.hidden_size
@@ -876,8 +948,9 @@ def layer_fn(prefix):
         self._dtype = _dummy.dtype
         del _dummy
 
-        self.minimax_cache = MinimaxCacheManager(dtype=torch.float32,
-                                                 cache_shape=self.cache_shape)
+        if not envs.VLLM_USE_V1:
+            self.minimax_cache = MinimaxCacheManager(
+                dtype=torch.float32, cache_shape=self.cache_shape)
 
         rope_theta = getattr(config, "rope_theta", 10000)
         head_dim = getattr(config, "head_dim", None)
@@ -944,23 +1017,27 @@ def forward(self,
                 **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
         forward_context = get_forward_context()
         attn_metadata = forward_context.attn_metadata
-        if attn_metadata is None:
+        if not envs.VLLM_USE_V1 and attn_metadata is None:
             return None
         if "request_ids_to_seq_ids" not in kwargs:
             kwargs["request_ids_to_seq_ids"] = {}
         if "finished_requests_ids" not in kwargs:
             kwargs["finished_requests_ids"] = []
 
-        (
-            minimax_cache_tensors,
-            state_indices_tensor,
-        ) = self.minimax_cache.current_run_tensors(**kwargs)
-        if getattr(attn_metadata, "num_prefills", 0) > 0:
-            self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
-                                      **kwargs)
+        if not envs.VLLM_USE_V1:
+            (
+                minimax_cache_tensors,
+                state_indices_tensor,
+            ) = self.minimax_cache.current_run_tensors(**kwargs)
+            if getattr(attn_metadata, "num_prefills", 0) > 0:
+                self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
+                                          **kwargs)
+
+            minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
+                                                      state_indices_tensor)
+        else:
+            minimax_cache_params = None
 
-        minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
-                                                  state_indices_tensor)
         if get_pp_group().is_first_rank:
             if inputs_embeds is None:
                 hidden_states = self.embed_scale * self.embed_tokens(input_ids)
@@ -973,11 +1050,22 @@ def forward(self,
             residual = intermediate_tensors["residual"]
 
         minimax_cache_index = 0
-        attn_metadata.rotary_emb = self.rotary_emb
+
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
+            if attn_metadata is not None:
+                # TODO (tdoublep): this whole thing with the rotary_emb is
+                # weird. we shouldn't be passing it via attn_metadata imo.
+                if envs.VLLM_USE_V1:
+                    if isinstance(layer.self_attn, MiniMaxText01Attention):
+                        attn_metadata[layer.prefix +
+                                      ".attn"].rotary_emb = self.rotary_emb
+                else:
+                    attn_metadata.rotary_emb = self.rotary_emb
+
             _caches = None
-            if isinstance(layer.self_attn, MiniMaxText01LinearAttention):
+            if not envs.VLLM_USE_V1 and isinstance(
+                    layer.self_attn, MiniMaxText01LinearAttention):
                 current_state_layer = minimax_cache_index
                 _caches = minimax_cache_params.at_layer_idx(
                     current_state_layer)
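# [Editor's note] Net effect of the two hunks above: under V0 the recurrent
# state is threaded through MinimaxCacheParams from the model forward into
# each linear-attention layer, while under V1 every layer already owns its
# self.kv_cache (reachable via static_forward_context), so
# minimax_cache_params stays None and _caches is never populated.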
@@ -1002,8 +1090,7 @@ def forward(self,
         return hidden_states
 
 
-class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
-                               SupportsV0Only):
+class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
 
@@ -1321,3 +1408,28 @@ def load_basic_weight(name: str, loaded_weight: torch.Tensor,
 
         load_basic_weight(name, loaded_weight, self)
         return loaded_params
+
+    @classmethod
+    def get_mamba_state_shape_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+        use_v1: bool = True,
+    ) -> tuple[tuple[int, ...], ...]:
+        """Calculate shape for MiniMaxText01LinearAttention cache.
+
+        Args:
+            vllm_config: vLLM config
+            use_v1: Get shapes for V1 (or V0)
+
+        Returns:
+            Tuple containing:
+            - state_shape: Shape of the cache
+        """
+        parallel_config = vllm_config.parallel_config
+        hf_config = vllm_config.model_config.hf_config
+
+        return MambaStateShapeCalculator.linear_attention_state_shape(
+            num_heads=hf_config.num_attention_heads,
+            tp_size=parallel_config.tensor_parallel_size,
+            head_dim=hf_config.head_dim,
+        )
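# [Editor's note, usage assumed] The classmethod above lets the V1 KV-cache
# planner size the linear-attention state without instantiating the model; a
# caller would use it roughly as:
#
#     (state_shape, ) = \
#         MiniMaxText01ForCausalLM.get_mamba_state_shape_from_config(
#             vllm_config)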