[v1] Support mamba2 #19327
Conversation
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀
Hello @heheda12345, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
Summary of Changes
Hello! gemini-code-assist here to provide a summary of this pull request. This PR introduces the initial support for Mamba2 models within vLLM's v1 architecture. The core idea is to adapt the v1 KV cache management and attention backend mechanisms to handle the state-space model (SSM) states used by Mamba2, rather than the traditional KV states used by transformers. Key changes include reusing the KVCacheManager with a new SingleTypeKVCacheManager specifically for Mamba states, wrapping Mamba-specific input preparation into a new attention backend, and adjusting the batch processing order for v1's persistent batching preference for decode requests.
Highlights
- Mamba2 Support in v1: Adds the foundational support for running Mamba2 models using the vLLM v1 architecture.
- KV Cache Management Adaptation: Reuses the existing v1 `KVCacheManager` but introduces a new `MambaManager` (a `SingleTypeKVCacheManager`) to handle the unique state requirements of Mamba2 models, replacing the separate `MambaCacheManager` used in v0.
- New Attention Backend: Implements a `Mamba2AttentionBackend` to encapsulate the Mamba2-specific logic for preparing input metadata and handling the continuous batching process.
- Batch Reordering for v1: Modifies the batch processing order within the new Mamba2 backend to prioritize decode requests before prefill requests, aligning with v1's persistent batching strategy.
- Known Limitations: Notes that prefix caching and speculative decoding are not yet supported for Mamba2 in this initial v1 implementation. Support is currently limited to Mamba2 only, not Mamba1 or Minimax.
Changelog
- vllm/model_executor/layers/mamba/mamba_mixer2.py
  - Added imports for v1 environment variables (`envs`) and configuration (`get_current_vllm_config`) (lines 9, 11).
  - Imported `Mamba2AttentionMetadata` from the new v1 backend (line 32).
  - Added `prefix` and `chunk_size` parameters to the `MambaMixer2` constructor for v1 integration (lines 247-248).
  - Stored `conv_kernel_size` as an instance attribute (line 281).
  - Added v1-specific logic in `__init__` to register the layer in the compilation context and initialize a placeholder `kv_cache` (lines 420-434).
  - Modified `forward_cuda` to retrieve metadata and KV cache states based on the v1 environment flag (`envs.VLLM_USE_V1`) (lines 451-481).
  - Added a special case in `forward_cuda` for v1 profile runs when `attn_metadata` is None (lines 514-522).
  - Adjusted the splitting order of prefill and decode tokens/metadata in `forward_cuda` based on the v1 environment flag (lines 533-572).
  - Updated `forward_cuda` to use the retrieved `conv_state`, `ssm_state`, `state_indices_tensor`, `has_initial_states_p`, `prep_initial_states`, `chunk_size`, `seq_idx_p`, `chunk_indices_p`, `chunk_offsets_p`, and `query_start_loc_p` instead of directly accessing `mamba_cache_params` or `mamba2_metadata` (lines 598, 614-615, 627, 631-634, 644, 654, 680).
  - Changed the order of appending prefill/decode outputs to `ssd_output_list` for v1 to put decode first (lines 693-701).
  - Added a `get_state_shape` method to compute the shapes of the convolution and temporal SSM states for KV cache allocation (lines 716-742).
- vllm/model_executor/models/mamba2.py
  - Added an import for v1 environment variables (`envs`) (line 11).
  - Removed the `SupportsV0Only` interface from the `Mamba2ForCausalLM` class definition (line 200).
  - Added a `prefix` parameter to the `Mamba2DecoderLayer` constructor (line 48).
  - Passed the layer `prefix` and model `chunk_size` to the `MambaMixer2` constructor (lines 65-66).
  - Passed the layer `prefix` to the `Mamba2DecoderLayer` constructor within `make_layers` (line 115).
  - Modified the `forward` method of `Mamba2ForCausalLM` to conditionally initialize and use `MambaCacheManager` based on the v1 environment flag (`envs.VLLM_USE_V1`), making `mamba_cache_params` optional (lines 255-268).
  - Made `mamba_cache_params` optional in the `Mamba2DecoderLayer` forward call (line 165).
- vllm/v1/attention/backends/mamba_attn.py
  - Created a new file to define the Mamba2-specific attention backend for v1.
  - Includes a `get_mamba2_chunk_size` helper function (lines 21-27).
  - Defines `Mamba2AttentionMetadataBuilder`, which handles batch reordering (decode before prefill) and builds `Mamba2AttentionMetadata` (lines 30-166).
  - Defines `Mamba2AttentionBackend`, which provides the builder class (lines 169-173).
  - Defines the `Mamba2AttentionMetadata` dataclass to hold Mamba2-specific metadata for v1, including prefill/decode counts, token counts, query start locations, sequence lengths, initial state flags, chunking info, and state indices (lines 176-192).
- vllm/v1/core/single_type_kv_cache_manager.py
  - Imported `MambaSpec` (line 11).
  - Added a `_null_block` attribute to `SingleTypeKVCacheManager` (line 55).
  - Added the `MambaManager` class, a `SingleTypeKVCacheManager` subclass for Mamba (lines 396-433).
  - Implemented `find_longest_cache_hit` in `MambaManager` to always return empty lists, indicating no prefix caching support (lines 398-416).
  - Implemented `remove_skipped_blocks` in `MambaManager` as a no-op, assuming one block per request (lines 418-422).
  - Implemented `get_num_common_prefix_blocks` in `MambaManager` to always return 0 (lines 424-426).
  - Overrode `allocate_new_blocks` in `MambaManager` to assert that only one block is allocated per request (lines 428-433).
  - Added `MambaSpec: MambaManager` to the `spec_manager_map` (line 439).
- vllm/v1/kv_cache_interface.py
  - Imported `prod` from `math` (line 6).
  - Added the `MambaSpec` dataclass inheriting from `KVCacheSpec` (lines 158-178).
  - The `MambaSpec` defines the shapes and dtype of Mamba states, calculates total elements and page size, and specifies memory usage (lines 160-178).
- vllm/v1/worker/gpu_model_runner.py
  - Added imports for `MambaMixer2`, `get_dtype_size`, `Mamba2AttentionBackend`, and `MambaSpec` (lines 31, 41, 43, 47).
  - Modified `initialize_attn_backend` to check for `MambaSpec` and use `Mamba2AttentionBackend` accordingly (lines 2025-2060).
  - Modified `_reshape_kv_cache_tensors` to handle `MambaSpec` by iterating through the defined shapes and viewing the raw tensor data into the correct state tensors (lines 2177-2192).
  - Modified `get_kv_cache_spec` to iterate through `MambaMixer2` layers if present (lines 2303-2322).
  - Added checks in `get_kv_cache_spec` to raise `NotImplementedError` for Mamba if speculative decoding, CUDA graph (unless enforce_eager), or prefix caching is enabled (lines 2306-2314).
  - For Mamba layers, `get_kv_cache_spec` creates a `MambaSpec` using the state shapes from the layer and the KV cache dtype, and sets the block size to `max_model_len` to ensure one block per request (lines 2315-2322).
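To make the memory accounting above concrete, here is a minimal, self-contained sketch of how a Mamba state spec can compute its per-request page size from the state shapes, in the spirit of the `MambaSpec` described in the changelog. The class name, field names, and example shapes below are illustrative assumptions, not vLLM's actual definitions.

```python
from dataclasses import dataclass
from math import prod

import torch


@dataclass
class MambaStateSpecSketch:
    """Illustrative sketch only; not vLLM's actual MambaSpec."""
    shapes: tuple[tuple[int, ...], ...]  # e.g. (conv_state_shape, temporal_ssm_state_shape)
    dtype: torch.dtype
    block_size: int  # set to max_model_len so each request occupies exactly one block

    @property
    def num_elements(self) -> int:
        # Total number of state elements a single request needs, across all states.
        return sum(prod(shape) for shape in self.shapes)

    @property
    def page_size_bytes(self) -> int:
        # Bytes required to hold one request's conv + SSM state.
        elem_bytes = torch.empty((), dtype=self.dtype).element_size()
        return self.num_elements * elem_bytes


# Illustrative numbers only: a (8192, 4) conv state and a (128, 64, 128) SSM state in fp16.
spec = MambaStateSpecSketch(shapes=((8192, 4), (128, 64, 128)),
                            dtype=torch.float16,
                            block_size=4096)
print(spec.page_size_bytes)  # 2 bytes/element * (8192*4 + 128*64*128) elements
```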
Code Review
This pull request introduces initial support for Mamba2 in the v1 architecture, which is a significant and welcome addition. The changes are generally well-structured, with clear distinctions for v1-specific logic. The integration of Mamba states into the v1 KV cache and attention metadata framework using MambaSpec, MambaManager, and Mamba2AttentionMetadataBuilder appears sound.
A key design choice for Mamba v1 is the use of a single "block" (with block_size=max_model_len) per sequence for Mamba states (convolutional and SSM states). This simplifies the initial implementation by avoiding complex paging logic for these states, treating them more like monolithic allocations per sequence.
The PR also correctly adapts the prefill/decode ordering to match v1's preference (decode before prefill).
I've identified a couple of areas for potential improvement or clarification, detailed in the comments. The known limitations (no prefix caching, spec decode, full CUDA graph compatibility for Mamba) are well-documented in the PR description and code TODOs.
Overall, this is a solid foundation for Mamba2 support in v1. Good work!
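To illustrate the "one block per request" design described above, here is a small, self-contained sketch (hypothetical class and method names, not vLLM's actual `MambaManager`): each request is pinned to a single state slot for its whole lifetime, and prefix-cache lookups always miss.

```python
class MambaSlotManagerSketch:
    """Illustrative sketch of the one-slot-per-request idea; not vLLM's MambaManager."""

    def __init__(self, num_slots: int):
        self.free_slots = list(range(num_slots))   # one slot == one "block" of Mamba state
        self.slot_of_request: dict[str, int] = {}

    def find_longest_cache_hit(self, token_ids: list[int]) -> list[int]:
        # No prefix caching for Mamba state in this initial version: always miss.
        return []

    def allocate(self, request_id: str) -> int:
        # Exactly one block per request, allocated once and kept until the request finishes.
        assert request_id not in self.slot_of_request, "request already has a state slot"
        slot = self.free_slots.pop()
        self.slot_of_request[request_id] = slot
        return slot

    def free(self, request_id: str) -> None:
        self.free_slots.append(self.slot_of_request.pop(request_id))
```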
Summary of Findings
- Clarity of `chunk_size` default and assertion: In `MambaMixer2.__init__`, `chunk_size` defaults to -1, but an assertion later requires it to be something other than -1 if `VLLM_USE_V1` is true. This could be confusing or lead to errors if not handled carefully by callers.
- Code Duplication in Batch Reordering: The `reorder_batch` method in `Mamba2AttentionMetadataBuilder` is noted as being copied from `FlashInferMetadataBuilder`. This presents a maintainability concern and should be refactored into a common utility or base class method.
- Commented-out Debug Code: There are commented-out `print` statements in `vllm/model_executor/layers/mamba/mamba_mixer2.py` (lines 581-592) which should ideally be removed before merging. (Severity: low, not commented inline due to settings)
Merge Readiness
This PR lays a strong foundation for Mamba2 support in v1. The identified issues are mostly of medium severity, focusing on potential confusion in parameter handling and code maintainability. Addressing these points would enhance the clarity and robustness of the implementation.
Given the current state and the nature of the feedback, I would recommend addressing the chunk_size default/assertion concern and considering the refactoring of the duplicated code before merging. The PR description also notes that model accuracy evaluation and CUDA graph support are pending; these are important for overall feature completeness but may be out of scope for this initial PR.
As an AI reviewer, I am not authorized to approve pull requests. Please ensure that other human reviewers inspect and approve these changes before merging.
    use_rms_norm: bool = True,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    chunk_size: int = -1,  # the chunk size used by v1
The default value for `chunk_size` is -1, with a comment indicating it's "the chunk size used by v1". However, later in the `__init__` method (line 430), there's an assertion `assert chunk_size != -1, "chunk_size must be set for v1"` specifically within the `if envs.VLLM_USE_V1:` block.
This seems a bit contradictory. If `chunk_size` must be set to a valid value (not -1) for v1, what is the intended behavior if `envs.VLLM_USE_V1` is true and `chunk_size` is left at its default of -1? Will this always lead to an assertion error during initialization in a v1 context?
Perhaps the default should be `None`, and the assertion could check for `chunk_size is None` or `chunk_size == -1` to ensure it's explicitly set when v1 is active? Or is the expectation that callers will always provide a valid `chunk_size` when `VLLM_USE_V1` is true?
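A minimal sketch of the suggestion above, using `None` instead of a `-1` sentinel (the class name and parameters are illustrative, not the real `MambaMixer2` signature):

```python
from typing import Optional


class MambaMixer2Sketch:
    """Illustrative only: shows the None-default idea, not vLLM's actual class."""

    def __init__(self, chunk_size: Optional[int] = None, use_v1: bool = True):
        if use_v1 and chunk_size is None:
            # Fail fast with a clear message instead of relying on a -1 sentinel.
            raise ValueError("chunk_size must be provided when VLLM_USE_V1 is enabled")
        self.chunk_size = chunk_size
```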
    # NOTE (Chen): Copied from FlashInferMetadataBuilder. Should be
    # refactored later to avoid code duplication.
The comment here acknowledges that the reorder_batch logic is copied from FlashInferMetadataBuilder and should be refactored. This is good to note.
To improve maintainability and reduce redundancy, could we consider creating a utility function or a base class method for this batch reordering logic if it's indeed common across multiple attention metadata builders? This would help avoid potential inconsistencies if this logic needs to be updated in the future.
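As a rough illustration of what such a shared utility could look like (the helper name and the dict-based interface are assumptions, not existing vLLM code), the common part is just partitioning requests into decodes and prefills:

```python
def split_decodes_and_prefills_sketch(
        num_scheduled_tokens: dict[str, int]) -> tuple[list[str], list[str]]:
    """Partition request ids into decodes (exactly one scheduled token) and prefills.

    A base metadata-builder class could expose this so the FlashInfer, MLA, and Mamba2
    builders don't each re-implement the same reorder_batch partitioning logic.
    """
    decodes = [req_id for req_id, n in num_scheduled_tokens.items() if n == 1]
    prefills = [req_id for req_id, n in num_scheduled_tokens.items() if n > 1]
    return decodes, prefills
```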
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        # NOTE (Chen): Copied from FlashInferMetadataBuilder. Should be
nit: FlashInferMetadataBuilder actually copied this from MLACommonMetadataBuilder so we should probably add that to the comment too
    if not self.vllm_config.model_config.enforce_eager:
        raise NotImplementedError(
            "Mamba with cuda graph is not supported yet.")
Looks like we need to wrap MambaMixer2 in a custom op and then add it to splitting_ops here
Lines 4169 to 4173 in c1c7dbb:

    if not self.splitting_ops:
        self.splitting_ops = [] if self.full_cuda_graph else [
            "vllm.unified_attention",
            "vllm.unified_attention_with_output",
        ]
Yes, but my concern is more about performance. It seems we shouldn't put too much logic into the non-CUDA-graph region.
https://github.com/vllm-project/vllm/blame/c1c7dbbeeb6d4f0155d25b673f2063bfb14b37b9/vllm/attention/layer.py#L218-L219
Do you prefer naive CUDA graph support with known performance problems, or should we just leave it as future work?
You're right, that's going to be a problem.
I think we should get partial CUDA graphs working for Mamba models first and then iterate on that base to reduce overheads. I'm OK landing this PR without CUDA graph support as well
I prefer to move the code to a custom op in a separate PR to keep the edit history easier to follow.
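For reference, a rough sketch of the direction discussed above: wrap the Mamba2 mixer computation in a custom op so the compiler can split the CUDA graph around it, then add that op name to `splitting_ops`. The op name, the stub body, and the config wiring below are assumptions for illustration, not the code this PR (or the follow-up) actually adds.

```python
import torch
from torch.library import custom_op


# Hypothetical custom op wrapping the eager Mamba2 mixer path (illustrative only).
@custom_op("vllm::mamba_mixer2_sketch", mutates_args=())
def mamba_mixer2_sketch(hidden_states: torch.Tensor) -> torch.Tensor:
    # The real op would run the conv + SSD kernels here; this stub just copies.
    return hidden_states.clone()


@mamba_mixer2_sketch.register_fake
def _(hidden_states: torch.Tensor) -> torch.Tensor:
    # Shape/dtype propagation for graph capture under torch.compile.
    return torch.empty_like(hidden_states)


# The compilation config could then split graphs at this op, e.g.:
splitting_ops = [
    "vllm.unified_attention",
    "vllm.unified_attention_with_output",
    "vllm.mamba_mixer2_sketch",  # hypothetical addition
]
```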
I had one comment on the chunk_size where I think we can simplify things, but otherwise LGTM!
    def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
        from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
        layers = get_layers_from_vllm_config(vllm_config, MambaMixer2)
        chunk_sizes = set(layer.chunk_size for layer in layers.values())
        assert len(
            chunk_sizes) == 1, "All Mamba2 layers must have the same chunk size"
        return chunk_sizes.pop()
I think it will greatly simplify things if we get the chunk_size from the model config.
I've noticed a lot of places where we have to handle it, but chunk_size should just be a constant value, see here: https://huggingface.co/mistralai/Mamba-Codestral-7B-v0.1/blob/main/config.json#L7
But it seems that the chunk size is saved under different keys for different models. For example, this model uses `mamba_chunk_size`. Is there a unified key I can use?
vllm/vllm/model_executor/models/falcon_h1.py, line 456 in d70bc7c:

    chunk_size=self.config.mamba_chunk_size,
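One possible way to handle the differing config keys, sketched below (the helper and the key list are assumptions, not an agreed-upon vLLM API), is a small fallback lookup over the known attribute names:

```python
def get_chunk_size_from_hf_config_sketch(hf_config) -> int:
    """Illustrative fallback over the chunk-size keys used by different model configs."""
    for key in ("chunk_size", "mamba_chunk_size"):
        value = getattr(hf_config, key, None)
        if value is not None:
            return int(value)
    raise ValueError("Could not find a Mamba2 chunk size in the model config")
```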
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
@tlrmchlsmth I've updated the PR, except for the chunk_size and CUDA graph issues. Help wanted on simplifying chunk_size. I've also updated the SSM test to cover the v1 implementation.
    if not self.vllm_config.model_config.enforce_eager:
        raise NotImplementedError(
            "Mamba with cuda graph is not supported yet.")
> I prefer to move the code to a custom op in a separate PR to keep the edit history easier to follow.

Sounds good to me, let's land it and then we can iterate.
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Thanks for your hard work! Can you update the V1 guide with the latest status?
I think it is still a work in progress, as Mamba1 is not supported yet.
Hi @heheda12345, I see that we have updated spec_decode under v1. I don't know whether spec_decode in v1 will work with Mamba as-is or whether it will need extra updates.
For prefix caching, the major concern is that we should only cache the Mamba state of some specific tokens, as the Mamba state per token is too large. The exact strategy is still under discussion, but I think it is solvable.
Purpose
This PR adds the initial support for Mamba2 in v1. Differences from v0:
- Reuses the v1 KVCacheManager with a new MambaManager (a SingleTypeKVCacheManager) for Mamba states, instead of the separate MambaCacheManager used in v0.
- Wraps the Mamba2-specific input preparation in a new Mamba2AttentionBackend.
- Reorders the batch so that decode requests come before prefill requests, matching v1's persistent batching.

Known limitations:
- Prefix caching and speculative decoding are not supported yet.
- CUDA graphs are not supported yet (enforce_eager is required).
- Only Mamba2 is supported, not Mamba1 or Minimax.
v1 mamba support RFC: #17140
Test Plan
    HF_ALLOW_CODE_EVAL=1 lm_eval --model vllm \
        --model_args pretrained=mistralai/Mamba-Codestral-7B-v0.1,enforce_eager=True,enable_prefix_caching=False \
        --tasks humaneval \
        --device cuda:0 \
        --batch_size auto \
        --confirm_run_unsafe_code

Test Result