
Conversation

@cyang49
Contributor

@cyang49 cyang49 commented Mar 25, 2025

Observed redundant computations for common metadata across mamba2 layers. This PR attempts to reduce them as much as possible. It will be done incrementally through several smaller commits.

Original: 1.5 ms of redundant computation per mamba2 layer

With this PR: a single 1.5 ms computation per model forward
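For illustration, a minimal, self-contained sketch of the idea (the helper below is a stand-in with assumed behavior, not the actual seq_idx_to_chunk_indices_offsets in vLLM): the chunk metadata depends only on the batch layout, so the model forward can compute it once and reuse it in every mamba2 layer instead of recomputing it per layer.

```python
import torch


def compute_chunk_metadata(seq_idx: torch.Tensor, chunk_size: int):
    """Stand-in for the real metadata computation: it depends only on the
    batch layout (seq_idx, chunk_size), never on the layer index."""
    num_chunks = (seq_idx.numel() + chunk_size - 1) // chunk_size
    chunk_indices = torch.arange(num_chunks)
    chunk_offsets = chunk_indices * chunk_size
    return chunk_indices, chunk_offsets


def forward_before(layers, hidden, seq_idx, chunk_size):
    # Redundant: every layer repeats the same ~1.5 ms computation.
    for layer in layers:
        chunk_indices, chunk_offsets = compute_chunk_metadata(seq_idx, chunk_size)
        hidden = layer(hidden, chunk_indices, chunk_offsets)
    return hidden


def forward_after(layers, hidden, seq_idx, chunk_size):
    # This PR's approach: compute once per model forward, reuse across layers.
    chunk_indices, chunk_offsets = compute_chunk_metadata(seq_idx, chunk_size)
    for layer in layers:
        hidden = layer(hidden, chunk_indices, chunk_offsets)
    return hidden
```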

cc @fabianlim @tlrmchlsmth @tdoublep @yury-tokpanov

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@yury-tokpanov
Contributor

Does this change only affect the case of chunked prefill?

@fabianlim
Contributor

fabianlim commented Mar 25, 2025

> Does this change only affect the case of chunked prefill?

Yes, pretty much.

@cyang49 cyang49 force-pushed the pr_mamba_opt_reuse_across_layers branch from 0048072 to 4c672ff on March 25, 2025 14:00
Member

@tlrmchlsmth tlrmchlsmth left a comment


Looks nice, and a straightforward optimization.

Are there any e2e speedup results?

@fabianlim
Contributor

@cyang49 @tlrmchlsmth the improvement should be something like:

before:
(serving benchmark screenshot; numbers not captured in this text)

after:

============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  183.77
Total input tokens:                      215196
Total generated tokens:                  193675
Request throughput (req/s):              5.44
Output token throughput (tok/s):         1053.87
Total Token throughput (tok/s):          2224.85
---------------Time to First Token----------------
Mean TTFT (ms):                          66875.64
Median TTFT (ms):                        56899.11
P99 TTFT (ms):                           166865.34
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          263.00
Median TPOT (ms):                        280.23
P99 TPOT (ms):                           402.55
---------------Inter-token Latency----------------
Mean ITL (ms):                           218.30
Median ITL (ms):                         370.27
P99 ITL (ms):                            438.36
==================================================

# metadata, we simply compute it redundantly and it
# will be silently ignored inside the mamba kernels
# if not needed.
chunk_indices, chunk_offsets = seq_idx_to_chunk_indices_offsets(
Contributor


How much slowdown does it introduce to compute this when not needed?

Contributor Author


1.5 ms on my machine. This part can be vectorized and optimized further.
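For context, a hedged sketch of the kind of guard being discussed, i.e. only paying that cost when chunked prefill with initial states is actually in play. The attribute names (num_prefills, context_lens_tensor) are assumptions for illustration, and seq_idx_to_chunk_indices_offsets refers to the function in the hunk above:

```python
import torch


def maybe_compute_chunk_metadata(attn_metadata, seq_idx, chunk_size):
    # Sketch only: attribute names are assumed, not necessarily the exact
    # vLLM metadata fields. Skip the work unless there are prefill tokens
    # whose sequences carry previously cached (initial) states.
    if getattr(attn_metadata, "num_prefills", 0) == 0:
        return None, None
    has_initial_states = attn_metadata.context_lens_tensor > 0
    if not torch.any(has_initial_states):
        return None, None
    # seq_idx_to_chunk_indices_offsets is the function from the hunk above.
    return seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size)
```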

Contributor Author


I believe the metadata for making the determination can also be exposed at this level. @fabianlim, is that right?

Contributor


Is this the check you need here: https://github.com/cyang49/vllm/blob/9b3705b3e48eb55dcc57d7dba0e288fe108e81bd/vllm/model_executor/layers/mamba/mamba_mixer2.py#L417? If so, you already have attn_metadata here, and you can check whether to compute chunk_indices/offsets.

It would also be nice if we could abstract this logic away, so that we don't need to copy-paste the same code across all mamba2-related models. I like @tlrmchlsmth's suggestion to put seq_idx and chunk_indices/offsets into a dataclass, so that we can pass a single object instead of three. We could also convert the whole block in lines 319-342 into a method in a separate file that all mamba2 models can import. We could create a separate mamba2_utils.py or put it in an existing file; I'm not sure which, though. @tlrmchlsmth, any suggestions?
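Concretely, something along these lines (a minimal sketch; the field set is illustrative and not necessarily the final Mamba2Metadata definition):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Mamba2Metadata:
    """Per-forward metadata shared by every mamba2 mixer layer (sketch)."""
    seq_idx: Optional[torch.Tensor] = None
    chunk_indices: Optional[torch.Tensor] = None
    chunk_offsets: Optional[torch.Tensor] = None
```

The model forward would build one instance and pass the same object to every mixer layer.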

Contributor


Probably putting it in mamba_mixer2.py would be enough.

Contributor


Ah, I see @tlrmchlsmth actually suggested the same thing as me :). I think option 1 is what we've been suggesting (metadata with a helper method in mamba_mixer.py).

I'm not that familiar with the chunked prefill changes in the mamba2 kernels, but isn't the need to compute chunked_indices/chunked_offsets determined by this condition here? And initial_states are computed here, and you only need attn_metadata to make this determination.

Contributor


Yes, correct, but the point is that chunked_indices and chunked_offsets only need to be computed once per model forward. If I make the determination inside the layer, we need logic to detect whether this is the first mixer layer in the model, which can be involved.

Contributor


@yury-tokpanov I just had a call with @tlrmchlsmth. Eventually we will have the metadata in the mixer, but that will take quite a bit of work, so for now we will settle for an intermediate solution with the metadata in the model forward.

Contributor

@yury-tokpanov yury-tokpanov Mar 27, 2025


A metadata class and helper method defined in the mixer, with the actual call in the forward pass of the model? That sounds good to me; I was just suggesting moving the code that is repeated between the models into the mixer.

The metadata object should live in the model and simply encapsulate seq_idx, chunked_indices, and chunked_offsets, so that we don't need to pass three things everywhere; the call to compute that metadata object would then be a one-liner in all the models.
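In other words, something like the following call pattern (a hedged sketch; prepare_mamba2_metadata is an assumed helper name and the layer signature is illustrative):

```python
# Sketch of a model forward using the shared metadata object. The helper name
# prepare_mamba2_metadata and the chunk_size attribute are assumptions.
def forward(self, hidden_states, attn_metadata):
    # One call per model forward instead of one per mixer layer.
    mamba2_metadata = prepare_mamba2_metadata(
        attn_metadata=attn_metadata,
        chunk_size=self.config.mamba_chunk_size,
    )
    for layer in self.layers:
        hidden_states = layer(hidden_states,
                              attn_metadata=attn_metadata,
                              mamba2_metadata=mamba2_metadata)
    return hidden_states
```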

Contributor Author


Added Mamba2Metadata for this

@cyang49 cyang49 force-pushed the pr_mamba_opt_reuse_across_layers branch from 9b3705b to 5dd6b23 on March 29, 2025 16:09
@cyang49 cyang49 force-pushed the pr_mamba_opt_reuse_across_layers branch 2 times, most recently from b666bd4 to ace2d4a on April 7, 2025 20:10
@cyang49 cyang49 force-pushed the pr_mamba_opt_reuse_across_layers branch from e1d12ba to 142d063 on April 9, 2025 21:52
@cyang49
Contributor Author

cyang49 commented Apr 9, 2025

@tlrmchlsmth @yury-tokpanov I refactored the metadata logic, and it affects other models that use the mamba2 mixer. Please suggest how I should test them. I searched the code and found two models: Zamba and Mamba2.

@tlrmchlsmth
Member

> @tlrmchlsmth @yury-tokpanov I refactored the metadata logic, and it affects other models that use the mamba2 mixer. Please suggest how I should test them. I searched the code and found two models: Zamba and Mamba2.

Those should be the only ones. Could you run a gsm8k eval on Zyphra/Zamba2-2.7B and mistralai/Mamba-Codestral-7B-v0.1 to make sure this doesn't affect performance?

cyang49 and others added 11 commits April 10, 2025 09:05
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
also causes gpu-cpu sync

This reverts commit d8df2b23b65fe57d109b69cc7c49e7f3c031b15a.

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
@cyang49 cyang49 force-pushed the pr_mamba_opt_reuse_across_layers branch from 4ac99e5 to 4382192 on April 10, 2025 14:13
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
@cyang49
Contributor Author

cyang49 commented Apr 10, 2025

gsm8k Evaluation Results

ibm-ai-platform/Bamba-9B

lm_eval --model vllm     --model_args pretrained=ibm-ai-platform/Bamba-9B,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9 --batch_size auto --trust_remote_code  --cache_requests true --tasks gsm8k

Main (ce8d6b7)

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.2487|±  |0.0119|
|     |       |strict-match    |     5|exact_match|↑  |0.3563|±  |0.0132|

PR

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.2487|±  |0.0119|
|     |       |strict-match    |     5|exact_match|↑  |0.3563|±  |0.0132|

mistralai/Mamba-Codestral-7B-v0.1

lm_eval --model vllm     --model_args pretrained=mistralai/Mamba-Codestral-7B-v0.1,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9 --batch_size auto --trust_remote_code  --cache_requests true --tasks gsm8k

Main (ce8d6b7)

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.4761|±  |0.0138|
|     |       |strict-match    |     5|exact_match|↑  |0.4632|±  |0.0137|

PR

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.4761|±  |0.0138|
|     |       |strict-match    |     5|exact_match|↑  |0.4632|±  |0.0137|

Zyphra/Zamba2-2.7B

lm_eval --model vllm     --model_args pretrained=Zyphra/Zamba2-2.7B,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9 --batch_size auto --trust_remote_code --cache_requests true --tasks gsm8k

Main (ce8d6b7)

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5292|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.5436|±  |0.0137|

PR

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5292|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.5436|±  |0.0137|
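The gsm8k scores are identical between main and this PR for all three models.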

@cyang49
Contributor Author

cyang49 commented Apr 10, 2025

benchmark_serving

Tested on H100-80GB HBM3

ibm-ai-platform/Bamba-9B

# Server (default configs)
vllm serve ibm-ai-platform/Bamba-9B --port 9999
# Client
python benchmarks/benchmark_serving.py --model ibm-ai-platform/Bamba-9B  --dataset-name sharegpt     --dataset-path ShareGPT_V3/ShareGPT_V3_unfiltered_cleaned_split.json --ignore-eos --port 9999

Main (ce8d6b7)

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  234.74    
Total input tokens:                      215201    
Total generated tokens:                  198343    
Request throughput (req/s):              4.26      
Output token throughput (tok/s):         844.96    
Total Token throughput (tok/s):          1761.74   
---------------Time to First Token----------------
Mean TTFT (ms):                          88342.50  
Median TTFT (ms):                        78434.12  
P99 TTFT (ms):                           219941.13 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          325.54    
Median TPOT (ms):                        342.52    
P99 TPOT (ms):                           508.29    
---------------Inter-token Latency----------------
Mean ITL (ms):                           280.50    
Median ITL (ms):                         482.95    
P99 ITL (ms):                            561.78    
==================================================

PR

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  195.15    
Total input tokens:                      215201    
Total generated tokens:                  198343    
Request throughput (req/s):              5.12      
Output token throughput (tok/s):         1016.37   
Total Token throughput (tok/s):          2119.14   
---------------Time to First Token----------------
Mean TTFT (ms):                          73498.49  
Median TTFT (ms):                        66722.67  
P99 TTFT (ms):                           180412.11 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          268.64    
Median TPOT (ms):                        281.08    
P99 TPOT (ms):                           426.19    
---------------Inter-token Latency----------------
Mean ITL (ms):                           230.79    
Median ITL (ms):                         345.81    
P99 ITL (ms):                            459.38    
==================================================

Observed a ~20% overall throughput improvement (total token throughput 2119.14 vs. 1761.74 tok/s; request throughput 5.12 vs. 4.26 req/s).

@cyang49 cyang49 marked this pull request as ready for review April 10, 2025 16:32
Member

@tlrmchlsmth tlrmchlsmth left a comment


This PR has some really nice cleanup in it now, thank you for that.

LGTM!

Comment on lines -9 to -12
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.placeholder_attn import (
    PlaceholderAttentionMetadata)
from vllm.attention.backends.xformers import XFormersMetadata
Member


great job getting rid of these imports

@tlrmchlsmth tlrmchlsmth added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Apr 10, 2025
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) April 10, 2025 16:42
@tlrmchlsmth tlrmchlsmth merged commit daefed0 into vllm-project:main Apr 10, 2025
63 checks passed
p88h pushed a commit to p88h/vllm that referenced this pull request Apr 10, 2025
…llm-project#15423)

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
yangw-dev pushed a commit to yangw-dev/vllm that referenced this pull request Apr 21, 2025
…llm-project#15423)

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
Signed-off-by: Yang Wang <elainewy@meta.com>
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
…llm-project#15423)

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
…llm-project#15423)

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
…llm-project#15423)

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
