@@ -541,7 +541,6 @@ def forward_cuda(
         # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         # Separate prefill and decode by splitting varlen input
         # Split along token dimension
-        # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         if envs.VLLM_USE_V1:
             hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
                 hidden_states_B_C[:num_actual_tokens],
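The V0/V1 ordering noted in the comment drives the argument order of every `torch.split` in this function. A toy sketch of the two layouts (sizes invented for illustration, not vLLM code):

```python
import torch

# V1 flattens the batch as [decode tokens | prefill tokens]; V0 puts
# prefill first, so the section sizes handed to torch.split swap order.
num_decodes, num_prefill_tokens = 3, 7
tokens = torch.arange(num_decodes + num_prefill_tokens)

# V1 layout: decode tokens come first
tokens_d, tokens_p = torch.split(tokens, [num_decodes, num_prefill_tokens])
# V0 layout: prefill tokens come first
tokens_p, tokens_d = torch.split(tokens, [num_prefill_tokens, num_decodes])
```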
@@ -583,7 +582,28 @@ def forward_cuda(
                                                                 1]
                              if has_prefill else None)

-        ssd_output_list = []
+        # Preallocate output tensor to avoid memcpy cost for merging prefill
+        # and decode outputs
+        preallocated_ssm_out = torch.empty(
+            [
+                num_prefill_tokens + num_decodes,
+                (self.num_heads // self.tp_size) * self.head_dim
+            ],
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+        if envs.VLLM_USE_V1:
+            preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
+                preallocated_ssm_out,
+                [num_decodes, num_prefill_tokens],
+                dim=0,
+            )
+        else:
+            preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
+                preallocated_ssm_out,
+                [num_prefill_tokens, num_decodes],
+                dim=0,
+            )

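Why preallocating one buffer and splitting it removes the merge cost: `torch.split` returns views over the same storage, so the prefill and decode kernels can write their halves in place via `out=`. A minimal standalone sketch, with invented sizes:

```python
import torch

# torch.split returns views, not copies, so writes through either half
# land directly in the merged buffer (invented sizes, not vLLM code).
num_decodes, num_prefill_tokens, hidden = 2, 5, 8
merged = torch.empty(num_decodes + num_prefill_tokens, hidden)
out_d, out_p = torch.split(merged, [num_decodes, num_prefill_tokens], dim=0)

out_d.fill_(1.0)  # stands in for the decode kernel's out= write
out_p.fill_(2.0)  # stands in for the prefill kernel's out= write

assert out_d.data_ptr() == merged.data_ptr()  # shared storage
assert torch.equal(merged, torch.cat([out_d, out_p]))
```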
         # Process prefill requests
         if has_prefill:
@@ -623,7 +643,8 @@ def forward_cuda(
                     has_initial_states_p[:num_prefills, None, None, None],
                     ssm_state[state_indices_tensor_p], 0)

-            scan_output, varlen_state = mamba_chunk_scan_combined(
+            # NOTE: final output is an in-place update of out tensor
+            varlen_state = mamba_chunk_scan_combined(
                 hidden_states_p.view(1, num_prefill_tokens,
                                      self.num_heads // self.tp_size,
                                      self.head_dim),
@@ -646,15 +667,14 @@ def forward_cuda(
                 return_final_states=False,
                 dt_softplus=True,
                 dt_limit=(0.0, float("inf")),
+                out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
+                                                self.head_dim),
             )

             # update ssm states
             # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
             ssm_state[state_indices_tensor_p] = varlen_state

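The state write-back above relies on advanced indexing: indexing `ssm_state` with a slot tensor scatters each sequence's final state into its cache slot in one assignment. A small sketch with assumed dimensions:

```python
import torch

# All dimensions invented for illustration.
num_slots, nheads, headdim, dstate = 8, 4, 16, 32
ssm_state = torch.zeros(num_slots, nheads, headdim, dstate)
state_indices_tensor_p = torch.tensor([5, 2, 7])  # cache slot per prefill seq
varlen_state = torch.randn(3, nheads, headdim, dstate)  # final states
ssm_state[state_indices_tensor_p] = varlen_state  # batched scatter into cache
```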
-            # - reshape
-            ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
-
         # Process decode requests
         if has_decode:
             # 2. Convolution sequence transformation
@@ -684,8 +704,8 @@ def forward_cuda(
             # - the hidden is reshaped into (bs, num_heads, head_dim)
             # - mamba_cache_params.ssm_state's slots will be selected
             # using state_indices_tensor_d
-
-            hidden_states_d = selective_state_update(
+            # NOTE: final output is an in-place update of out tensor
+            selective_state_update(
                 ssm_state,
                 hidden_states_d,
                 dt_d,
@@ -697,26 +717,16 @@ def forward_cuda(
                 dt_bias=dt_bias,
                 dt_softplus=True,
                 state_batch_indices=state_indices_tensor_d,
+                out=preallocated_ssm_out_d.view(num_decodes, -1,
+                                                self.head_dim),
             )

-            if envs.VLLM_USE_V1:
-                ssd_output_list.insert(
-                    0,
-                    hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
-                                         self.head_dim))
-            else:
-                ssd_output_list.append(
-                    hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
-                                         self.head_dim))
-
-        # Merge prefill and decode outputs before passing to gated MLP
-        hidden_states = torch.vstack(ssd_output_list)
-
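The deleted `torch.vstack` was the memcpy this change avoids: `vstack` allocates a fresh tensor and copies every row, whereas the kernels now write through the preallocated views. A toy before/after comparison (invented sizes, stand-in writes):

```python
import torch

# Old path: build a list, then vstack -> new allocation plus a full copy.
out_d = torch.ones(2, 8)          # stand-in decode outputs
out_p = torch.full((3, 8), 2.0)   # stand-in prefill outputs
merged_old = torch.vstack([out_d, out_p])

# New path: kernels write straight into views of one preallocated buffer.
merged_new = torch.empty(5, 8)
view_d, view_p = torch.split(merged_new, [2, 3], dim=0)
view_d.copy_(out_d)  # stands in for selective_state_update(..., out=view_d)
view_p.copy_(out_p)  # stands in for mamba_chunk_scan_combined(..., out=view_p)
assert torch.equal(merged_new, merged_old)
```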
         # 4. gated MLP
         # GatedRMSNorm internally applying SiLU to the gate
         # SiLU is applied internally before normalization, unlike standard
         # norm usage
-        hidden_states = self.norm(hidden_states, gate[:num_actual_tokens])
+        hidden_states = self.norm(preallocated_ssm_out,
+                                  gate[:num_actual_tokens])

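For reference, the gated-norm semantics the comment describes, as a hedged sketch (assumed formulation standing in for the module behind `self.norm`, not vLLM's actual class): SiLU is applied to the gate and multiplied in before the RMS normalization.

```python
import torch
import torch.nn.functional as F

# Assumed formulation for illustration; not vLLM's actual norm module.
def gated_rms_norm(x: torch.Tensor, gate: torch.Tensor,
                   weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    x = x * F.silu(gate)  # gate applied before normalization
    rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rms * weight
```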
         # 5. Final linear projection
         output[:num_actual_tokens], _ = self.out_proj(hidden_states)