 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
-from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
+from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
 from mamba_ssm.ops.triton.selective_state_update import selective_state_update
 from causal_conv1d import causal_conv1d_fn, causal_conv1d_update

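For context on what the decode-path kernels used in the hunks below compute: causal_conv1d_update advances a rolling convolution window by one token, and selective_state_update advances the SSM hidden state by one token, both in place on the cached tensors. The following is only a rough pure-PyTorch sketch of that single-step math under assumed shapes ((batch, dim) inputs and (batch, dim, d_conv) / (batch, dim, d_state) states); it is not the API of either library, and it omits the dt softplus and output gating handled elsewhere.

import torch
import torch.nn.functional as F

def conv_update_sketch(x, conv_state, weight, bias=None):
    # x: (batch, dim) new token; conv_state: (batch, dim, d_conv) rolling window
    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
    conv_state[:, :, -1] = x
    out = torch.einsum("bdw,dw->bd", conv_state, weight)
    if bias is not None:
        out = out + bias
    return F.silu(out)

def ssm_update_sketch(ssm_state, x, dt, A, B, C, D):
    # ssm_state: (batch, dim, d_state); dt assumed already softplus-ed
    dA = torch.exp(dt.unsqueeze(-1) * A)            # (batch, dim, d_state)
    dB = dt.unsqueeze(-1) * B.unsqueeze(1)          # (batch, dim, d_state)
    ssm_state.copy_(ssm_state * dA + dB * x.unsqueeze(-1))
    return torch.einsum("bdn,bn->bd", ssm_state, C) + D * x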
@@ -114,7 +114,7 @@ def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCacheParams

         # 2. Convolution sequence transformation
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
-        if cache_params is not None and cache_params.seqlen_offset > 0:
+        if cache_params is not None and not cache_params.is_prompt:
             hidden_states = causal_conv1d_update(
                 hidden_states.squeeze(-1),
                 cache_params.conv_state,
@@ -154,7 +154,7 @@ def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCacheParams
         A = -torch.exp(self.A_log.float())
         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
         time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
-        if cache_params is not None and cache_params.seqlen_offset > 0:
+        if cache_params is not None and not cache_params.is_prompt:
             scan_outputs = selective_state_update(
                 cache_params.ssm_state,
                 hidden_states[..., 0],
@@ -187,50 +187,14 @@ def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCacheParams
         contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
         return contextualized_states

-    def forward(self, hidden_states: torch.Tensor, input_metadata: InputMetadata):
-        if input_metadata.is_prompt:
-            batch_size = hidden_states.shape[0]
-            conv_cache = torch.zeros(
-                batch_size,
-                self.config.mamba_expand * self.config.hidden_size,
-                self.config.mamba_d_conv,
-                device=hidden_states.device,
-                dtype=hidden_states.dtype
-            )
-            ssm_cache = torch.zeros(
-                batch_size,
-                self.config.mamba_expand * self.config.hidden_size,
-                self.config.mamba_d_state,
-                device=hidden_states.device,
-                dtype=hidden_states.dtype
-            )
-            cache = MambaCacheParams(0, conv_cache, ssm_cache)
-        else:
-            for mamba_cache_request in input_metadata.mamba_cache_batch:
-                # check if batch size of cache fits "n"
-                n = mamba_cache_request.request_info.n
-                if mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].conv_state.shape[0] < n:
-                    expanded_dims_conv = (n, *mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].conv_state.shape[1:])
-                    conv_state = mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].conv_state.expand(*expanded_dims_conv)
-                    expanded_dims_ssm = (n, *mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].ssm_state.shape[1:])
-                    ssm_state = mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].ssm_state.expand(*expanded_dims_ssm)
-                    mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].conv_state = conv_state
-                    mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].ssm_state = ssm_state
-
-            # mamba requires concatenated cache
-            conv_state = torch.concat([req.layer_idx2mamba_cache[self.layer_idx].conv_state for req in input_metadata.mamba_cache_batch], dim=0)
-            ssm_state = torch.concat([req.layer_idx2mamba_cache[self.layer_idx].ssm_state for req in input_metadata.mamba_cache_batch], dim=0)
-            cache = MambaCacheParams(1, conv_state, ssm_state)
+    def forward(self, hidden_states: torch.Tensor, input_metadata: InputMetadata, conv_state: torch.Tensor, ssm_state: torch.Tensor):
+        cache = MambaCacheParams(
+            input_metadata.is_prompt,
+            conv_state=conv_state[self.layer_idx],
+            ssm_state=ssm_state[self.layer_idx]
+        )
         hidden_states = self.mamba_forward(hidden_states, cache_params=cache)

-        # split cache back to individual requests
-        sample_id = 0
-        for req_mamba_metadata in input_metadata.mamba_cache_batch:
-            n = 1 if input_metadata.is_prompt else req_mamba_metadata.request_info.n
-            req_mamba_metadata.layer_idx2mamba_cache[self.layer_idx].conv_state = cache.conv_state[sample_id:sample_id + n]
-            req_mamba_metadata.layer_idx2mamba_cache[self.layer_idx].ssm_state = cache.ssm_state[sample_id:sample_id + n]
-            sample_id += n
-
         return hidden_states


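The reworked forward above assumes MambaCacheParams now carries an is_prompt flag plus per-layer state tensors instead of a seqlen_offset. Its definition is not part of this hunk, so the dataclass below is only a guess at its shape, reconstructed from how it is constructed here.

import torch
from dataclasses import dataclass

@dataclass
class MambaCacheParams:
    # hypothetical reconstruction; the real definition lives elsewhere in this PR
    is_prompt: bool
    conv_state: torch.Tensor   # this layer's slice, roughly (batch, intermediate_size, d_conv)
    ssm_state: torch.Tensor    # this layer's slice, roughly (batch, intermediate_size, d_state)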
@@ -352,6 +316,8 @@ def forward(self,
                 hidden_states: torch.Tensor,
                 input_metadata: InputMetadata,
                 residual: Optional[torch.Tensor],
+                conv_state: torch.Tensor,
+                ssm_state: torch.Tensor,
                 **kwargs):

         if residual is None:
@@ -360,7 +326,12 @@ def forward(self,
         else:
             hidden_states, residual = self.input_layernorm(hidden_states, residual)

-        hidden_states = self.mamba(hidden_states, input_metadata)
+        hidden_states = self.mamba(
+            hidden_states,
+            input_metadata,
+            conv_state,
+            ssm_state
+        )
         # Fully Connected
         hidden_states, residual = self.pre_moe_layernorm(
             hidden_states, residual)
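Since each Mamba layer now slices conv_state[self.layer_idx] and ssm_state[self.layer_idx], the caller presumably pre-allocates one stacked buffer per state kind across all Mamba layers. A rough allocation sketch follows; the last three dimensions are taken from the per-request torch.zeros calls removed above, while the leading num_mamba_layers dimension and the helper itself are assumptions implied by the per-layer indexing.

import torch

def allocate_mamba_states(config, num_mamba_layers, batch_size, device, dtype):
    # hypothetical helper; mirrors the removed per-request allocations, stacked per layer
    intermediate_size = config.mamba_expand * config.hidden_size
    conv_state = torch.zeros(num_mamba_layers, batch_size, intermediate_size,
                             config.mamba_d_conv, device=device, dtype=dtype)
    ssm_state = torch.zeros(num_mamba_layers, batch_size, intermediate_size,
                            config.mamba_d_state, device=device, dtype=dtype)
    return conv_state, ssm_state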
@@ -433,7 +404,8 @@ def self_attention(self,
                        positions: torch.Tensor,
                        hidden_states: torch.Tensor,
                        kv_cache: KVCache,
-                       input_metadata: InputMetadata) -> torch.Tensor:
+                       input_metadata: InputMetadata,
+                       **kwargs) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         # TODO - add embedding flag
@@ -450,7 +422,8 @@ def forward(
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        residual: Optional[torch.Tensor]):
+        residual: Optional[torch.Tensor],
+        **kwargs):
         if residual is None:
             residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
@@ -524,6 +497,8 @@ def forward(
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
+        conv_state: torch.Tensor,
+        ssm_state: torch.Tensor
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
@@ -534,7 +509,10 @@ def forward(
                 hidden_states=hidden_states,
                 kv_cache=kv_caches[i],
                 input_metadata=input_metadata,
-                residual=residual)
+                residual=residual,
+                conv_state=conv_state,
+                ssm_state=ssm_state
+            )
         hidden_states, _ = self.final_layernorm(hidden_states, residual)
         return hidden_states

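Both the model-level forward above and the top-level forward in the next hunk now take the two state buffers explicitly, so the caller threads them through on every step. A hypothetical call-site sketch, reusing the assumed allocate_mamba_states helper from earlier; model, input_ids, positions, kv_caches, input_metadata, batch_size, and num_mamba_layers are placeholders for objects the vLLM worker already holds.

import torch

conv_state, ssm_state = allocate_mamba_states(
    model.config, num_mamba_layers, batch_size,
    device=input_ids.device, dtype=torch.float16)

hidden_states = model(
    input_ids,
    positions,
    kv_caches,
    input_metadata,
    conv_state,
    ssm_state,
)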
@@ -593,9 +571,17 @@ def forward(
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-    ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   input_metadata)
+        conv_state: torch.Tensor,
+        ssm_state: torch.Tensor
+    ):
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            kv_caches,
+            input_metadata,
+            conv_state,
+            ssm_state
+        )
         return hidden_states

     def sample(