[Model] Reduce redundant computations in mamba2 blocks for Bamba-9B #15423
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Does this change only affect the case of chunked prefill?
Yes, pretty much.
Looks nice, and a straightforward optimization.
Are there any e2e speedup results?
@cyang49 @tlrmchlsmth the improvement should be something like: (before/after profiling comparison omitted)
vllm/model_executor/models/bamba.py
Outdated
# metadata, we simply compute it redundantly and it
# will be silently ignored inside the mamba kernels
# if not needed.
chunk_indices, chunk_offsets = seq_idx_to_chunk_indices_offsets(
How much slowdown does it introduce to compute this when not needed?
1.5 ms on my machine. This part can be vectorized and optimized further.
I believe the metadata for making the determination can also be exposed to this level - @fabianlim is that right?
Is this the check you need here: https://github.com/cyang49/vllm/blob/9b3705b3e48eb55dcc57d7dba0e288fe108e81bd/vllm/model_executor/layers/mamba/mamba_mixer2.py#L417 ? If so, you already have attn_metadata here and you can check whether to compute chunk_indices/offsets.
It would also be nice if we could abstract this logic away, so that we don't need to copy-paste the same code across all mamba2-related models. I like @tlrmchlsmth's suggestion to put seq_idx and chunk_indices/offsets into a dataclass, so that we can pass a single object instead of three. We could also convert the whole block in lines 319-342 into a method in a separate file that all mamba2 models can import. We could create a separate mamba2_utils.py or put it in an existing file; not sure which, @tlrmchlsmth any suggestions?
Probably putting in mamba_mixer2.py would be enough.
ah, I see @tlrmchlsmth actually suggested the same thing as me :) I think option 1 is what we've been suggesting (metadata with a helper method in mamba_mixer.py).
I'm not that familiar with the chunked prefill changes in the mamba2 kernels, but isn't the need to compute chunked_indices/chunked_offsets determined by this condition here? And initial_states are computed here, and you only need attn_metadata to make this determination.
Yes, correct, but the point is that the chunked_indices and chunked_offsets only need to be computed once per model forward. If I make the determination inside the layer, we need logic to detect whether this was the first mixer layer in the model, which can be involved.
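For reference, a minimal sketch of such a model-level determination, assuming the batch-level info lives on attn_metadata (the attribute names `num_prefills` and `context_lens_tensor`, and the idea of gating on initial states, are assumptions here rather than the literal code from this PR):

```python
import torch


def chunk_metadata_needed(attn_metadata) -> bool:
    # chunk_indices/chunk_offsets are only consumed when a chunked-prefill
    # batch carries initial states; decode-only batches can skip the ~1.5 ms
    # of bookkeeping entirely.
    if attn_metadata.num_prefills == 0:
        return False
    context_lens = attn_metadata.context_lens_tensor
    return context_lens is not None and bool(torch.any(context_lens > 0))
```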
@yury-tokpanov I just had a call with @tlrmchlsmth. Eventually we will have metadata in the mixer, but that will take quite a bit of work, so for now we will settle for an intermediate solution with metadata in the model forward.
Metadata class and helper method definitions in the mixer, while the actual call is in the forward pass of the model? That sounds good to me; I was just suggesting to move the code that is repeated between the models into the mixer.
The metadata object should live in the model and would simply encapsulate seq_idx, chunked_indices, and chunked_offsets, so that we don't need to pass three things everywhere, and the call to compute that metadata object would be a one-liner in all the models.
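For illustration, a minimal sketch of what that single object and its one-liner preparation could look like (field names follow the discussion above; the import path and the helper's exact signature are assumptions):

```python
from dataclasses import dataclass
from typing import Optional

import torch

# Import path assumed; the helper already exists wherever the mamba2
# chunked-scan utilities live (mamba_mixer2.py was suggested above).
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
    seq_idx_to_chunk_indices_offsets)


@dataclass
class Mamba2Metadata:
    """Per-forward chunk bookkeeping shared by every mamba2 mixer layer."""
    seq_idx: Optional[torch.Tensor] = None
    chunk_indices: Optional[torch.Tensor] = None
    chunk_offsets: Optional[torch.Tensor] = None


def prepare_mamba2_metadata(seq_idx: Optional[torch.Tensor],
                            chunk_size: int) -> Mamba2Metadata:
    # Decode-only batches have no prefill tokens, so there is nothing to do.
    if seq_idx is None:
        return Mamba2Metadata()
    chunk_indices, chunk_offsets = seq_idx_to_chunk_indices_offsets(
        seq_idx, chunk_size)
    return Mamba2Metadata(seq_idx, chunk_indices, chunk_offsets)
```

Each model's forward would then build this once and hand the same object to every layer, instead of threading seq_idx, chunk_indices, and chunk_offsets around separately.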
Added Mamba2Metadata for this
@tlrmchlsmth @yury-tokpanov I refactored the metadata logic and it affects other models using the mamba2 mixer. Please suggest how I should test them. I searched the code and see 2 models: Zamba and Mamba2.
Those should be the only ones. Could you run a gsm8k eval on them?
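(For reference, one common way to run that eval is through lm-evaluation-harness's vLLM backend, e.g. `lm_eval --model vllm --model_args pretrained=<model> --tasks gsm8k`; treat the exact flags as illustrative rather than the command used for the numbers below.)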
gsm8k Evaluation Results
ibm-ai-platform/Bamba-9B: Main (ce8d6b7) vs PR
mistralai/Mamba-Codestral-7B-v0.1: Main (ce8d6b7) vs PR
Zyphra/Zamba2-2.7B: Main (ce8d6b7) vs PR
(score tables not captured here)
benchmark_serving (tested on H100-80GB HBM3)
ibm-ai-platform/Bamba-9B: Main (ce8d6b7) vs PR
Observed 20% overall throughput improvement.
This PR has some really nice cleanup in it now, thank you for that.
LGTM!
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.placeholder_attn import (
    PlaceholderAttentionMetadata)
from vllm.attention.backends.xformers import XFormersMetadata
great job getting rid of these imports
Observed redundant computations for common metadata across mamba2 layers. This PR attempts to reduce them as much as possible. It will be done incrementally through several smaller commits.
Original: 1.5ms redundant computations per mamba2 layer
With this PR: single 1.5ms computation per model forward
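In other words, the shared preparation is hoisted out of the per-layer path and into the model forward; roughly (a sketch with hypothetical names, not the literal diff):

```python
# Before: each mamba2 decoder layer rebuilt the same chunk bookkeeping,
# costing ~1.5 ms per layer.
for layer in self.layers:
    hidden_states = layer(hidden_states, attn_metadata)

# After: build it once per model forward and pass the shared object down.
mamba2_metadata = prepare_mamba2_metadata(seq_idx, chunk_size)
for layer in self.layers:
    hidden_states = layer(hidden_states, attn_metadata, mamba2_metadata)
```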
cc @fabianlim @tlrmchlsmth @tdoublep @yury-tokpanov