@@ -489,6 +489,9 @@ def forward_cuda(
489489 # stay the same and reused for all mamba layers in the same iteration
490490 attn_metadata : AttentionMetadata = forward_context .attn_metadata
491491
492+ assert self .cache_config is not None
493+ mamba_block_size = self .cache_config .mamba_block_size
494+ prefix_caching_enabled = self .cache_config .enable_prefix_caching
492495 if attn_metadata is not None :
493496 assert isinstance (attn_metadata , dict )
494497 attn_metadata = attn_metadata [self .prefix ]
@@ -573,6 +576,25 @@ def forward_cuda(
573576 dim = 0 ,
574577 )
575578
579+ if prefix_caching_enabled :
580+ # If prefix caching is enabled, retrieve the relevant variables
581+ # for prefill and decode
582+ last_state_idx_d , last_state_idx_p = torch .split (
583+ attn_metadata .last_state_idx , [num_decodes , num_prefills ],
584+ dim = 0 )
585+ current_last_idx_d , current_last_idx_p = torch .split (
586+ attn_metadata .current_last_idx , [num_decodes , num_prefills ],
587+ dim = 0 )
588+ # Prefill-only variables:
589+ current_first_idx_p = attn_metadata .current_first_idx_p
590+ context_lens_p = attn_metadata .context_lens_p
591+ last_computed_offset_p = attn_metadata .last_computed_offset_p
592+ else :
593+ last_state_idx_d , last_state_idx_p = None , None
594+ current_last_idx_d , current_last_idx_p = None , None
595+ current_first_idx_p = None
596+ context_lens_p = None
597+
576598 # Preallocate output tensor to avoid memcpy cost for merging prefill
577599 # and decode outputs
578600 preallocated_ssm_out = torch .empty (
@@ -592,8 +614,17 @@ def forward_cuda(
592614 # Process prefill requests
593615 if has_prefill :
594616 # 2. Convolution sequence transformation
595- # - "cache_indices" updates the conv_state cache in positions
596- # pointed to by "state_indices_tensor"
617+ # - It will read the initial states for every sequence,
618+ # that has "has_initial_states_p" == True,
619+ # from "cache_indices", using "state_indices_tensor_p".
620+ # - It updates the "conv_state" cache in positions pointed
621+ # to by "state_indices_tensor_p".
622+ # In particular, it will always write the state at the
623+ # sequence end.
624+ # In addition, "current_first_idx_p" and "current_last_idx_p"
625+ # are provided (which are pointers into
626+ # "state_indices_tensor_p"), it will write additional cache
627+ # states aligned at "block_size_to_align".
597628 x = hidden_states_B_C_p .transpose (
598629 0 , 1 ) # this is the form that causal-conv see
599630 hidden_states_B_C_p = causal_conv1d_fn (
@@ -604,6 +635,11 @@ def forward_cuda(
604635 conv_states = conv_state ,
605636 has_initial_state = has_initial_states_p ,
606637 cache_indices = state_indices_tensor_p ,
638+ current_first_idx = current_first_idx_p ,
639+ current_last_idx = current_last_idx_p ,
640+ initial_state_idx = last_state_idx_p ,
641+ context_lens = context_lens_p ,
642+ block_size_to_align = mamba_block_size ,
607643 metadata = attn_metadata ,
608644 query_start_loc = query_start_loc_p ).transpose (
609645 0 , 1 )[:num_prefill_tokens ]
@@ -614,9 +650,13 @@ def forward_cuda(
614650 # 3. State Space Model sequence transformation
615651 initial_states = None
616652 if (has_initial_states_p is not None and prep_initial_states ):
653+ kernel_ssm_indices = state_indices_tensor_p
654+ if prefix_caching_enabled :
655+ kernel_ssm_indices = state_indices_tensor_p .gather (
656+ 1 , last_state_idx_p .unsqueeze (1 )).squeeze (1 )
617657 initial_states = torch .where (
618658 has_initial_states_p [:, None , None , None ],
619- ssm_state [state_indices_tensor_p ], 0 )
659+ ssm_state [kernel_ssm_indices ], 0 )
620660
621661 # NOTE: final output is an in-place update of out tensor
622662 varlen_states = mamba_chunk_scan_combined_varlen (
@@ -638,26 +678,82 @@ def forward_cuda(
638678 cu_chunk_seqlens = cu_chunk_seqlen_p ,
639679 last_chunk_indices = last_chunk_indices_p ,
640680 initial_states = initial_states ,
681+ return_intermediate_states = prefix_caching_enabled ,
641682 dt_softplus = True ,
642683 dt_limit = (0.0 , float ("inf" )),
643684 out = preallocated_ssm_out_p .view (num_prefill_tokens , - 1 ,
644685 self .head_dim ),
645686 state_dtype = ssm_state .dtype )
646687
647- # update ssm states
648- # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
649- ssm_state [state_indices_tensor_p ] = varlen_states
688+ if prefix_caching_enabled :
689+ # Save states for sequences with more than just the final state:
690+ n_blocks_to_fill = current_last_idx_p - current_first_idx_p
691+ for seq_idx in (n_blocks_to_fill > 0 ).nonzero ().squeeze (1 ):
692+ cache_blocks_to_fill = state_indices_tensor_p [
693+ seq_idx , current_first_idx_p [seq_idx ]:
694+ current_first_idx_p [seq_idx ] +
695+ n_blocks_to_fill [seq_idx ]]
696+ # chunks = [0 1 2 3 4 5 6 ...]
697+ # First aligned chunk would typically be:
698+ # mamba_block_size = 1024, chunk_size = 256
699+ # 1024 // 256 - 1 --> chunks[3]
700+ # But when last chunk wasn't block aligned:
701+ # - last_computed_offset_p[seq_idx] // chunk_size
702+ # e.g. 1000 // 256 -> 3 completed --> store chunk[0]
703+ # e.g. 513 // 256 -> 2 completed --> store chunk[1] (skip 1)
704+ # e.g. 256 // 256 -> 1 completed --> store chunk[2] (skip 2)
705+ # e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3)
706+ chunk_stride = mamba_block_size // chunk_size
707+ first_aligned_chunk = \
708+ torch .concat ([torch .zeros (1 , \
709+ dtype = last_chunk_indices_p .dtype , \
710+ device = last_chunk_indices_p .device ), \
711+ last_chunk_indices_p + 1 ])[seq_idx ] \
712+ + chunk_stride - 1 \
713+ - last_computed_offset_p [seq_idx ] // chunk_size
714+ from_where = varlen_states [
715+ first_aligned_chunk :first_aligned_chunk +
716+ n_blocks_to_fill [seq_idx ] * chunk_stride :chunk_stride ]
717+ ssm_state [cache_blocks_to_fill ] = from_where
718+
719+ #For all seqs, store the last state (Note: might be partial):
720+ ssm_state [state_indices_tensor_p .gather (1 ,
721+ current_last_idx_p .unsqueeze (1 )).squeeze (1 )] = \
722+ varlen_states [last_chunk_indices_p ]
723+ else :
724+ # update ssm states
725+ # - varlen state is a (num_prefills, nheads, headdim, dstate)
726+ # tensor
727+ ssm_state [state_indices_tensor_p ] = varlen_states
650728
651729 # Process decode requests
652730 if has_decode :
731+ if prefix_caching_enabled :
732+ state_indices_tensor_d_input = \
733+ state_indices_tensor_d .gather (1 ,
734+ last_state_idx_d .unsqueeze (1 )).squeeze (1 )
735+ state_indices_tensor_d_output = \
736+ state_indices_tensor_d .gather (1 ,
737+ current_last_idx_d .unsqueeze (1 )).squeeze (1 )
738+ #Note:
739+ # for decode always: current_first_idx_d == current_last_idx_d
740+ # at block boundaries: current_first_idx_d > last_state_idx_d
741+ else :
742+ # Without caching, read and write in-place to the same blocks:
743+ state_indices_tensor_d_input = state_indices_tensor_d
744+ state_indices_tensor_d_output = state_indices_tensor_d
745+
653746 # 2. Convolution sequence transformation
654747 hidden_states_B_C_d = causal_conv1d_update (
655748 hidden_states_B_C_d ,
656749 conv_state ,
657750 conv_weights ,
658751 self .conv1d .bias ,
659752 self .activation ,
660- conv_state_indices = state_indices_tensor_d )
753+ conv_state_indices = state_indices_tensor_d ,
754+ current_last_idx = current_last_idx_d ,
755+ initial_state_idx = last_state_idx_d ,
756+ )
661757
662758 hidden_states_d , B_d , C_d = split_hidden_states_B_C_fn (
663759 hidden_states_B_C_d )
@@ -689,7 +785,8 @@ def forward_cuda(
689785 z = None ,
690786 dt_bias = dt_bias ,
691787 dt_softplus = True ,
692- state_batch_indices = state_indices_tensor_d ,
788+ state_batch_indices = state_indices_tensor_d_input ,
789+ dst_state_batch_indices = state_indices_tensor_d_output ,
693790 out = preallocated_ssm_out_d .view (num_decodes , - 1 ,
694791 self .head_dim ),
695792 )
0 commit comments