
Commit ceebbbe

s3woz, bohnstingl, and tdoublep authored and committed
[V1] [Hybrid] Mamba2 Automatic Prefix Caching (vllm-project#25752)
Signed-off-by: Stanislaw Wozniak <stw@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Thomas Ortner <boh@zurich.ibm.com>
Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
1 parent ccc16ad commit ceebbbe

File tree

18 files changed: +917 -147 lines changed


tests/models/language/generation/test_hybrid.py

Lines changed: 413 additions & 1 deletion
Large diffs are not rendered by default.

vllm/config/cache.py

Lines changed: 2 additions & 1 deletion
@@ -92,7 +92,8 @@ class CacheConfig:
     mamba_page_size_padded: Optional[int] = None
     """ Optional override for mamba page size; used by hybrid mamba/attention
     models to ensure exact alignment with attention page size."""
-
+    mamba_block_size: Optional[int] = None
+    """Size of a contiguous cache block in number of tokens for mamba cache."""
     mamba_cache_dtype: MambaDType = "auto"
     """The data type to use for the Mamba cache (both the conv as well as the
     ssm state). If set to 'auto', the data type will be inferred from the model
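
For reference, a minimal sketch of how the new field might be set when constructing a CacheConfig directly. The surrounding keyword arguments and values are illustrative, not taken from this commit.

# Hedged sketch (not from this commit): setting the new mamba_block_size field
# on CacheConfig; all values shown are illustrative.
from vllm.config import CacheConfig

cache_config = CacheConfig(
    block_size=16,             # attention block size, illustrative value
    mamba_block_size=1024,     # new field: mamba cache block granularity in tokens
    mamba_cache_dtype="auto",  # conv/ssm state dtype inferred from the model
)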

vllm/engine/arg_utils.py

Lines changed: 6 additions & 1 deletion
@@ -1563,7 +1563,12 @@ def _set_default_args(self, usage_context: UsageContext,
                self.enable_prefix_caching = False

            if self.enable_prefix_caching is None:
-                self.enable_prefix_caching = True
+                # Disable prefix caching default for hybrid models
+                # since the feature is still experimental.
+                if model_config.is_hybrid:
+                    self.enable_prefix_caching = False
+                else:
+                    self.enable_prefix_caching = True
        else:

            pooling_type = model_config.pooler_config.pooling_type
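
Because prefix caching now defaults to off for hybrid models, users who want to try the experimental path have to opt in explicitly. A possible sketch using the offline LLM entrypoint; the model name is only an example of a hybrid mamba/attention model and is not part of this commit.

# Sketch: opting in to the (experimental) prefix caching path for a hybrid model.
from vllm import LLM

llm = LLM(
    model="ibm-granite/granite-4.0-tiny-preview",  # example hybrid model
    enable_prefix_caching=True,                    # override the hybrid default (False)
)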

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 105 additions & 8 deletions
@@ -489,6 +489,9 @@ def forward_cuda(
        # stay the same and reused for all mamba layers in the same iteration
        attn_metadata: AttentionMetadata = forward_context.attn_metadata

+        assert self.cache_config is not None
+        mamba_block_size = self.cache_config.mamba_block_size
+        prefix_caching_enabled = self.cache_config.enable_prefix_caching
        if attn_metadata is not None:
            assert isinstance(attn_metadata, dict)
            attn_metadata = attn_metadata[self.prefix]
@@ -573,6 +576,25 @@ def forward_cuda(
            dim=0,
        )

+        if prefix_caching_enabled:
+            # If prefix caching is enabled, retrieve the relevant variables
+            # for prefill and decode
+            last_state_idx_d, last_state_idx_p = torch.split(
+                attn_metadata.last_state_idx, [num_decodes, num_prefills],
+                dim=0)
+            current_last_idx_d, current_last_idx_p = torch.split(
+                attn_metadata.current_last_idx, [num_decodes, num_prefills],
+                dim=0)
+            # Prefill-only variables:
+            current_first_idx_p = attn_metadata.current_first_idx_p
+            context_lens_p = attn_metadata.context_lens_p
+            last_computed_offset_p = attn_metadata.last_computed_offset_p
+        else:
+            last_state_idx_d, last_state_idx_p = None, None
+            current_last_idx_d, current_last_idx_p = None, None
+            current_first_idx_p = None
+            context_lens_p = None
+
        # Preallocate output tensor to avoid memcpy cost for merging prefill
        # and decode outputs
        preallocated_ssm_out = torch.empty(
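
The split above relies on the metadata tensors listing decode requests first and prefill requests second. A self-contained illustration of that split, with made-up values:

# Standalone illustration of the torch.split above; values are made up.
import torch

num_decodes, num_prefills = 2, 3
last_state_idx = torch.tensor([0, 1,      # decode rows come first
                               0, 2, 1])  # then prefill rows
last_state_idx_d, last_state_idx_p = torch.split(
    last_state_idx, [num_decodes, num_prefills], dim=0)
print(last_state_idx_d.tolist())  # [0, 1]
print(last_state_idx_p.tolist())  # [0, 2, 1]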
@@ -592,8 +614,17 @@ def forward_cuda(
        # Process prefill requests
        if has_prefill:
            # 2. Convolution sequence transformation
-            # - "cache_indices" updates the conv_state cache in positions
-            #   pointed to by "state_indices_tensor"
+            # - It will read the initial state for every sequence that has
+            #   "has_initial_states_p" == True from "cache_indices", using
+            #   "state_indices_tensor_p".
+            # - It updates the "conv_state" cache in positions pointed
+            #   to by "state_indices_tensor_p".
+            #   In particular, it always writes the state at the
+            #   sequence end.
+            #   In addition, if "current_first_idx_p" and "current_last_idx_p"
+            #   are provided (which are pointers into
+            #   "state_indices_tensor_p"), it writes additional cache
+            #   states aligned at "block_size_to_align".
            x = hidden_states_B_C_p.transpose(
                0, 1)  # this is the form that causal-conv see
            hidden_states_B_C_p = causal_conv1d_fn(
@@ -604,6 +635,11 @@ def forward_cuda(
                conv_states=conv_state,
                has_initial_state=has_initial_states_p,
                cache_indices=state_indices_tensor_p,
+                current_first_idx=current_first_idx_p,
+                current_last_idx=current_last_idx_p,
+                initial_state_idx=last_state_idx_p,
+                context_lens=context_lens_p,
+                block_size_to_align=mamba_block_size,
                metadata=attn_metadata,
                query_start_loc=query_start_loc_p).transpose(
                    0, 1)[:num_prefill_tokens]
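
As a rough mental model of the block-aligned cache writes described in the comment above, the following pure-Python helper (hypothetical, not part of the kernel or this commit) computes which token counts of a single prefill request would receive a state snapshot: every block boundary crossed by the newly processed tokens, plus the sequence end.

# Hypothetical helper, not part of the commit: the block-alignment arithmetic
# for one prefill request, under the assumptions stated above.
def aligned_state_positions(context_len: int, num_new_tokens: int,
                            block_size: int) -> list[int]:
    """Token counts at which a state snapshot would be written: every
    block_size boundary crossed during this prefill, plus the final state
    at the sequence end."""
    end = context_len + num_new_tokens
    first_boundary = (context_len // block_size + 1) * block_size
    positions = list(range(first_boundary, end + 1, block_size))
    if not positions or positions[-1] != end:
        positions.append(end)  # the state at the sequence end is always written
    return positions

print(aligned_state_positions(context_len=1000, num_new_tokens=600,
                              block_size=1024))  # [1024, 1600]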
@@ -614,9 +650,13 @@ def forward_cuda(
            # 3. State Space Model sequence transformation
            initial_states = None
            if (has_initial_states_p is not None and prep_initial_states):
+                kernel_ssm_indices = state_indices_tensor_p
+                if prefix_caching_enabled:
+                    kernel_ssm_indices = state_indices_tensor_p.gather(
+                        1, last_state_idx_p.unsqueeze(1)).squeeze(1)
                initial_states = torch.where(
                    has_initial_states_p[:, None, None, None],
-                    ssm_state[state_indices_tensor_p], 0)
+                    ssm_state[kernel_ssm_indices], 0)

            # NOTE: final output is an in-place update of out tensor
            varlen_states = mamba_chunk_scan_combined_varlen(
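
The gather above selects, per prefill request, the single cache block that holds its last computed state, so the kernel reads exactly one initial state per sequence. A small standalone illustration with made-up block ids:

# Illustration of the gather that picks one initial-state block per request;
# block ids and offsets below are made up.
import torch

state_indices_tensor_p = torch.tensor([[ 7,  8,  9],    # blocks of request 0
                                       [12, 13, 14]])   # blocks of request 1
last_state_idx_p = torch.tensor([1, 0])  # offset of the last cached state per request
kernel_ssm_indices = state_indices_tensor_p.gather(
    1, last_state_idx_p.unsqueeze(1)).squeeze(1)
print(kernel_ssm_indices.tolist())  # [8, 12]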
@@ -638,26 +678,82 @@ def forward_cuda(
                cu_chunk_seqlens=cu_chunk_seqlen_p,
                last_chunk_indices=last_chunk_indices_p,
                initial_states=initial_states,
+                return_intermediate_states=prefix_caching_enabled,
                dt_softplus=True,
                dt_limit=(0.0, float("inf")),
                out=preallocated_ssm_out_p.view(num_prefill_tokens, -1,
                                                self.head_dim),
                state_dtype=ssm_state.dtype)

-            # update ssm states
-            # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
-            ssm_state[state_indices_tensor_p] = varlen_states
+            if prefix_caching_enabled:
+                # Save states for sequences with more than just the final state:
+                n_blocks_to_fill = current_last_idx_p - current_first_idx_p
+                for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1):
+                    cache_blocks_to_fill = state_indices_tensor_p[
+                        seq_idx, current_first_idx_p[seq_idx]:
+                        current_first_idx_p[seq_idx] +
+                        n_blocks_to_fill[seq_idx]]
+                    # chunks = [0 1 2 3 4 5 6 ...]
+                    # First aligned chunk would typically be:
+                    #   mamba_block_size = 1024, chunk_size = 256
+                    #   1024 // 256 - 1 --> chunks[3]
+                    # But when last chunk wasn't block aligned:
+                    #   - last_computed_offset_p[seq_idx] // chunk_size
+                    #   e.g. 1000 // 256 -> 3 completed --> store chunk[0]
+                    #   e.g. 513 // 256 -> 2 completed --> store chunk[1] (skip 1)
+                    #   e.g. 256 // 256 -> 1 completed --> store chunk[2] (skip 2)
+                    #   e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3)
+                    chunk_stride = mamba_block_size // chunk_size
+                    first_aligned_chunk = \
+                        torch.concat([torch.zeros(1, \
+                            dtype=last_chunk_indices_p.dtype, \
+                            device=last_chunk_indices_p.device), \
+                            last_chunk_indices_p + 1])[seq_idx] \
+                        + chunk_stride - 1 \
+                        - last_computed_offset_p[seq_idx] // chunk_size
+                    from_where = varlen_states[
+                        first_aligned_chunk:first_aligned_chunk +
+                        n_blocks_to_fill[seq_idx] * chunk_stride:chunk_stride]
+                    ssm_state[cache_blocks_to_fill] = from_where
+
+                # For all seqs, store the last state (Note: might be partial):
+                ssm_state[state_indices_tensor_p.gather(1,
+                    current_last_idx_p.unsqueeze(1)).squeeze(1)] = \
+                    varlen_states[last_chunk_indices_p]
+            else:
+                # update ssm states
+                # - varlen state is a (num_prefills, nheads, headdim, dstate)
+                #   tensor
+                ssm_state[state_indices_tensor_p] = varlen_states

        # Process decode requests
        if has_decode:
+            if prefix_caching_enabled:
+                state_indices_tensor_d_input = \
+                    state_indices_tensor_d.gather(1,
+                        last_state_idx_d.unsqueeze(1)).squeeze(1)
+                state_indices_tensor_d_output = \
+                    state_indices_tensor_d.gather(1,
+                        current_last_idx_d.unsqueeze(1)).squeeze(1)
+                # Note:
+                #   for decode always: current_first_idx_d == current_last_idx_d
+                #   at block boundaries: current_first_idx_d > last_state_idx_d
+            else:
+                # Without caching, read and write in-place to the same blocks:
+                state_indices_tensor_d_input = state_indices_tensor_d
+                state_indices_tensor_d_output = state_indices_tensor_d
+
            # 2. Convolution sequence transformation
            hidden_states_B_C_d = causal_conv1d_update(
                hidden_states_B_C_d,
                conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
-                conv_state_indices=state_indices_tensor_d)
+                conv_state_indices=state_indices_tensor_d,
+                current_last_idx=current_last_idx_d,
+                initial_state_idx=last_state_idx_d,
+            )

            hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(
                hidden_states_B_C_d)
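
The index arithmetic in the loop above can be checked in isolation: with return_intermediate_states enabled, varlen_states holds one state per chunk, and only every chunk_stride-th chunk lands on a mamba block boundary. A plain-Python recomputation of the examples in the comment, for the first sequence in the batch (whose chunk range starts at 0):

# Recomputing the "store chunk[...]" examples from the comment above in plain
# Python; numbers match the comment, the first sequence starts at chunk 0.
mamba_block_size = 1024
chunk_size = 256
chunk_stride = mamba_block_size // chunk_size  # 4 chunks per mamba block

seq_start_chunk = 0  # concat([0, last_chunk_indices_p + 1])[seq_idx] for seq 0
for last_computed_offset in (1000, 513, 256, 10):
    completed = last_computed_offset // chunk_size   # chunks already in the cache
    first_aligned_chunk = seq_start_chunk + chunk_stride - 1 - completed
    print(f"offset {last_computed_offset:4d} -> store chunk[{first_aligned_chunk}]")
# offset 1000 -> store chunk[0]
# offset  513 -> store chunk[1]
# offset  256 -> store chunk[2]
# offset   10 -> store chunk[3]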
@@ -689,7 +785,8 @@ def forward_cuda(
                z=None,
                dt_bias=dt_bias,
                dt_softplus=True,
-                state_batch_indices=state_indices_tensor_d,
+                state_batch_indices=state_indices_tensor_d_input,
+                dst_state_batch_indices=state_indices_tensor_d_output,
                out=preallocated_ssm_out_d.view(num_decodes, -1,
                                                self.head_dim),
            )
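
For decode, the separate input/output index tensors let a step that crosses a mamba block boundary read the previous state from the last cached block while writing the updated state into a fresh block; away from a boundary both gathers resolve to the same block. A toy illustration with made-up block ids:

# Toy illustration of the decode-time source/destination block selection;
# block ids and offsets below are made up.
import torch

state_indices_tensor_d = torch.tensor([[20, 21, 22]])  # blocks of one decode request
last_state_idx_d = torch.tensor([1])     # block holding the previous state
current_last_idx_d = torch.tensor([2])   # block that receives the updated state
src = state_indices_tensor_d.gather(1, last_state_idx_d.unsqueeze(1)).squeeze(1)
dst = state_indices_tensor_d.gather(1, current_last_idx_d.unsqueeze(1)).squeeze(1)
print(src.tolist(), dst.tolist())  # [21] [22]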
