@@ -388,10 +388,15 @@ def forward_cuda(
         # mamba2_metadata contains metadata necessary for the mamba2 triton
         # kernels to operate in continuous batching and in chunked prefill
         # modes; they are computed at top-level model forward since they
-        # are the same and reused for all mamba layers in the same iteration
+        # stay the same and reused for all mamba layers in the same iteration
         attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
 
-        seq_len, _ = hidden_states.shape
+        num_prefills = attn_metadata.num_prefills  # request count
+        num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
+        num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
+        has_prefill = num_prefills > 0
+        has_decode = num_decodes > 0
+
         groups_time_state_size = self.n_groups * self.ssm_state_size
 
         # 1. Gated MLP's linear projection
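A minimal sketch (made-up sizes, not part of the diff) of how these counts partition the flattened varlen batch, assuming prefill tokens come first followed by one token per decode request:

import torch

# Hypothetical sizes, chosen only for illustration.
num_prefills, num_prefill_tokens, num_decodes, hidden_size = 2, 7, 3, 16

# Flattened varlen batch: prefill tokens first, then one token per decode.
hidden_states = torch.randn(num_prefill_tokens + num_decodes, hidden_size)

# The same token-dimension split the commit applies to hidden_states_B_C and dt.
prefill_tokens, decode_tokens = torch.split(
    hidden_states, [num_prefill_tokens, num_decodes], dim=0)
assert prefill_tokens.shape == (num_prefill_tokens, hidden_size)
assert decode_tokens.shape == (num_decodes, hidden_size)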
@@ -406,44 +411,32 @@ def forward_cuda(
             dim=-1,
         )
 
-        # 2. Convolution sequence transformation
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                                self.conv1d.weight.size(2))
 
-        if mamba2_metadata.has_prefill:
-            # |---------- N-1 iteration --------|
-            # |---------------- N iteration ---------------------|
-            # |- tokenA -|......................|-- newTokens ---|
-            # |---------- context_len ----------|
-            # |-------------------- seq_len ---------------------|
-            #                                   |-- query_len ---|
-
-            # - "cache_indices" updates the conv_state cache in positions
-            #   pointed to by "mamba_cache_params.state_indices_tensor"
-            hidden_states_B_C = causal_conv1d_fn(
-                hidden_states_B_C.transpose(0, 1),
-                conv_weights,
-                self.conv1d.bias,
-                activation=self.activation,
-                conv_states=mamba_cache_params.conv_state,
-                has_initial_state=mamba2_metadata.has_initial_states,
-                cache_indices=mamba_cache_params.state_indices_tensor,
-                query_start_loc=attn_metadata.query_start_loc).transpose(
-                    0, 1)[:seq_len]
-
-            # TODO: Why is this needed?
-            hidden_states_B_C = hidden_states_B_C.contiguous()
-        else:
-            hidden_states_B_C = causal_conv1d_update(
-                hidden_states_B_C,
-                mamba_cache_params.conv_state,
-                conv_weights,
-                self.conv1d.bias,
-                self.activation,
-                conv_state_indices=mamba_cache_params.state_indices_tensor)
+        # Separate prefill and decode by splitting varlen input
+        # Split along token dimension
+        hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
+            hidden_states_B_C,
+            [num_prefill_tokens, num_decodes],
+            dim=0,
+        )
+        dt_p, dt_d = torch.split(
+            dt,
+            [num_prefill_tokens, num_decodes],
+            dim=0,
+        )
+        # Split along batch dimension
+        state_indices_tensor_p, state_indices_tensor_d = torch.split(
+            mamba_cache_params.state_indices_tensor,
+            [num_prefills, num_decodes],
+            dim=0,
+        )
+        query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
+                             if has_prefill else None)
 
         # - get hidden_states, B and C after depthwise convolution.
-        hidden_states, B, C = torch.split(
+        split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
             hidden_states_B_C,
             [
                 self.intermediate_size // self.tp_size,
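The query_start_loc_p slice above relies on query_start_loc holding cumulative token offsets per request with prefill requests ordered first; a small sketch under that assumption (values invented):

import torch

# Assumed layout: two prefills of 4 and 3 tokens, then three single-token decodes.
query_start_loc = torch.tensor([0, 4, 7, 8, 9, 10])
num_prefills = 2

query_start_loc_p = query_start_loc[:num_prefills + 1]  # tensor([0, 4, 7])
num_prefill_tokens = int(query_start_loc_p[-1])         # 7 prefill tokens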
@@ -453,32 +446,56 @@ def forward_cuda(
             dim=-1,
         )
 
-        # 3. State Space Model sequence transformation
-        if mamba2_metadata.has_prefill:
+        ssd_output_list = []
+
+        # Process prefill requests
+        if has_prefill:
+            # 2. Convolution sequence transformation
+            # - "cache_indices" updates the conv_state cache in positions
+            #   pointed to by "mamba_cache_params.state_indices_tensor"
+            hidden_states_B_C_p = causal_conv1d_fn(
+                hidden_states_B_C_p.transpose(0, 1),
+                conv_weights,
+                self.conv1d.bias,
+                activation=self.activation,
+                conv_states=mamba_cache_params.conv_state,
+                has_initial_state=mamba2_metadata.has_initial_states,
+                cache_indices=state_indices_tensor_p,
+                query_start_loc=query_start_loc_p).transpose(
+                    0, 1)[:num_prefill_tokens]
+
+            # TODO: Why is this needed?
+            hidden_states_B_C_p = hidden_states_B_C_p.contiguous()
+            hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
+                hidden_states_B_C_p)
+
+            # 3. State Space Model sequence transformation
             initial_states = None
             if (mamba2_metadata.has_initial_states is not None
                     and mamba2_metadata.prep_initial_states):
                 # making a copy of the states
                 initial_states = torch.where(
                     mamba2_metadata.has_initial_states[:, None, None, None],
-                    mamba_cache_params.ssm_state[
-                        mamba_cache_params.state_indices_tensor], 0)
+                    mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
 
             scan_output, varlen_state = mamba_chunk_scan_combined(
-                hidden_states.view(1, seq_len, self.num_heads // self.tp_size,
-                                   self.head_dim),
-                dt.unsqueeze(0),
+                hidden_states_p.view(1, num_prefill_tokens,
+                                     self.num_heads // self.tp_size,
+                                     self.head_dim),
+                dt_p.unsqueeze(0),
                 self.A,
-                B.view(1, seq_len, self.n_groups // self.tp_size, -1),
-                C.view(1, seq_len, self.n_groups // self.tp_size, -1),
+                B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
+                         -1),
+                C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
+                         -1),
                 chunk_size=mamba2_metadata.chunk_size,
                 D=self.D,
                 z=None,
                 dt_bias=self.dt_bias,
                 seq_idx=mamba2_metadata.seq_idx,
                 chunk_indices=mamba2_metadata.chunk_indices,
                 chunk_offsets=mamba2_metadata.chunk_offsets,
-                cu_seqlens=attn_metadata.query_start_loc,
+                cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1],
                 initial_states=initial_states,
                 return_varlen_states=True,
                 return_final_states=False,
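In the prefill branch above, initial_states gathers cached SSM states per request and zeroes those without prior state; a standalone sketch of that torch.where pattern, with invented shapes:

import torch

# Invented shapes for illustration only.
cache_slots, nheads, headdim, dstate = 10, 4, 8, 16
ssm_state = torch.randn(cache_slots, nheads, headdim, dstate)

state_indices_tensor_p = torch.tensor([2, 5, 7])        # cache slot per prefill request
has_initial_states = torch.tensor([True, False, True])  # per prefill request

# Gather cached states per request; zero them where no prior state exists.
initial_states = torch.where(
    has_initial_states[:, None, None, None],
    ssm_state[state_indices_tensor_p], 0)
assert torch.all(initial_states[1] == 0)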
@@ -487,52 +504,65 @@ def forward_cuda(
             )
 
             # update ssm states
-            # - varlen state is a (batch, nheads, headdim, dstate) tensor
-            mamba_cache_params.ssm_state[
-                mamba_cache_params.state_indices_tensor] = varlen_state
+            # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
+            mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state
 
             # - reshape
-            hidden_states = scan_output.view(seq_len, -1)
-        else:
+            ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
 
+        # Process decode requests
+        if has_decode:
+            # 2. Convolution sequence transformation
+            hidden_states_B_C_d = causal_conv1d_update(
+                hidden_states_B_C_d,
+                mamba_cache_params.conv_state,
+                conv_weights,
+                self.conv1d.bias,
+                self.activation,
+                conv_state_indices=state_indices_tensor_d)
+
+            hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(
+                hidden_states_B_C_d)
+
+            # 3. State Space Model sequence transformation
             n_groups = self.n_groups // self.tp_size
-            A = self.A[:, None, ...][:, :, None].expand(
+            A_d = self.A[:, None, ...][:, :, None].expand(
                 -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
-            dt = dt[:, :, None].expand(-1, -1, self.head_dim)
+            dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
             dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
-            D = self.D[:, None, ...].expand(-1, self.head_dim)
-            B = B.view(-1, n_groups, B.shape[1] // n_groups)
-            C = C.view(-1, n_groups, C.shape[1] // n_groups)
-            hidden_states_reshaped = hidden_states.view(
+            D_d = self.D[:, None, ...].expand(-1, self.head_dim)
+            B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
+            C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
+            hidden_states_d = hidden_states_d.view(
                 -1, self.num_heads // self.tp_size, self.head_dim)
 
-            # - the hidden is reshaped into number of current batches
-            # - in this case there is no more prefill, so the batches gen
-            #   1 token at a time
-            # - thus hidden will be (bs, num_heads, head_dim)
+            # - the hidden is reshaped into (bs, num_heads, head_dim)
             # - mamba_cache_params.ssm_state's slots will be selected
-            #   using "mamba_cache_params.state_indices_tensor", just as
-            #   above in the prefill case
+            #   using state_indices_tensor_d
 
-            hidden_states = selective_state_update(
+            hidden_states_d = selective_state_update(
                 mamba_cache_params.ssm_state,
-                hidden_states_reshaped,
-                dt,
-                A,
-                B,
-                C,
-                D,
+                hidden_states_d,
+                dt_d,
+                A_d,
+                B_d,
+                C_d,
+                D_d,
                 z=None,
                 dt_bias=dt_bias,
                 dt_softplus=True,
-                state_batch_indices=mamba_cache_params.state_indices_tensor,
+                state_batch_indices=state_indices_tensor_d,
             )
-            hidden_states = hidden_states.view(
-                -1, (self.num_heads // self.tp_size) * self.head_dim)
+            ssd_output_list.append(
+                hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
+                                     self.head_dim))
+
+        # Merge prefill and decode outputs before passing to gated MLP
+        hidden_states = torch.vstack(ssd_output_list)
 
-        # # 4. gated MLP
+        # 4. gated MLP
         hidden_states = self.norm(hidden_states, gate)
 
-        # # 5. Final linear projection
+        # 5. Final linear projection
         out, _ = self.out_proj(hidden_states)
         return out
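Because prefill tokens precede decode tokens in the flattened batch, stacking the two partial outputs with torch.vstack restores the original token order before the gated norm; a toy sketch with invented sizes:

import torch

prefill_out = torch.randn(7, 32)  # num_prefill_tokens x hidden (invented)
decode_out = torch.randn(3, 32)   # num_decodes x hidden (invented)

# Row order matches the input split: prefill rows first, then decode rows.
merged = torch.vstack([prefill_out, decode_out])
assert merged.shape == (10, 32)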