Conversation

@cyang49
Contributor

@cyang49 cyang49 commented Jul 16, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

This PR uses a preallocated output tensor for the SSM output from both the decode and prefill paths, instead of allocating individual tensors and then concatenating them with torch.vstack. We observed that the original approach causes an unnecessary D2D copy.
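For illustration, a minimal sketch of the difference (tensor names, sizes, and layout are hypothetical, not the actual mamba_mixer2 code):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
num_prefill_tokens, num_decode_tokens, hidden_size = 100, 28, 4096

# Before: each path allocates its own output, then torch.vstack allocates a
# third tensor and copies both into it (the unnecessary D2D copy).
prefill_out = torch.randn(num_prefill_tokens, hidden_size, device=device)
decode_out = torch.randn(num_decode_tokens, hidden_size, device=device)
ssm_out = torch.vstack([prefill_out, decode_out])

# After: one preallocated output tensor; each path writes into its own slice,
# so no concatenation (and no extra copy) is needed.
ssm_out = torch.empty(num_prefill_tokens + num_decode_tokens, hidden_size, device=device)
prefill_view = ssm_out[:num_prefill_tokens]   # passed to the prefill path as its output buffer
decode_view = ssm_out[num_prefill_tokens:]    # passed to the decode path as its output buffer
prefill_view.copy_(prefill_out)               # stand-in for the kernel writing its results
decode_view.copy_(decode_out)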


Test Plan

  • Test with benchmark_serving.py and observe the throughput change; ideally a slight improvement should be observed.
  • Test with lm_eval to make sure the output is still correct.

Test Result

Experiments were done on a single H100-80GB GPU.

benchmark_serving.py

# server
vllm serve ibm-ai-platform/Bamba-9B-v2 --port 9998
# client
python benchmarks/benchmark_serving.py --model ibm-ai-platform/Bamba-9B-v2 --backend vllm --dataset-name sharegpt --dataset-path /net/storage149/mnt/md0/ccyang/github.com/ShareGPT_V3/ShareGPT_V3_unfiltered_cleaned_split.json --ignore-eos --port 9998

Before (#1c3198b)

============ Serving Benchmark Result ============
Successful requests:                     983       
Benchmark duration (s):                  44.69     
Total input tokens:                      209731    
Total generated tokens:                  195084    
Request throughput (req/s):              22.00     
Output token throughput (tok/s):         4365.18   
Total Token throughput (tok/s):          9058.10 

After

============ Serving Benchmark Result ============
Successful requests:                     983       
Benchmark duration (s):                  44.01     
Total input tokens:                      209731    
Total generated tokens:                  195084    
Request throughput (req/s):              22.34     
Output token throughput (tok/s):         4432.88   
Total Token throughput (tok/s):          9198.58   

No performance degradation; total token throughput improved slightly (9058.10 → 9198.58 tok/s).

lm_eval

# Command
lm_eval --model vllm  --model_args pretrained=ibm-ai-platform/Bamba-9B-v2,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.95 --batch_size auto --trust_remote_code  --cache_requests true --tasks gsm8k

Before (#1c3198b)

vllm (pretrained=ibm-ai-platform/Bamba-9B-v2,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.95,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.4162|±  |0.0136|
|     |       |strict-match    |     5|exact_match|↑  |0.4132|±  |0.0136|

After

vllm (pretrained=ibm-ai-platform/Bamba-9B-v2,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.95,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.4162|±  |0.0136|
|     |       |strict-match    |     5|exact_match|↑  |0.4132|±  |0.0136|

(Optional) Documentation Update

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a performance optimization by pre-allocating the SSM output tensor, which avoids an unnecessary device-to-device copy. The approach is sound and the changes are well-contained. I've identified one critical issue related to tensor sharding that would cause an assertion failure when using tensor parallelism. Addressing this should make the implementation robust.

@cyang49 cyang49 marked this pull request as ready for review July 16, 2025 20:19
@cyang49 cyang49 changed the title [Model] preallocate SSM output tensor to avoid d2d copy overhead [Model] Mamba2 preallocate SSM output tensor to avoid d2d copy overhead Jul 16, 2025
@DarkLight1337
Member

cc @tlrmchlsmth @tdoublep

@mergify

mergify bot commented Jul 21, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @cyang49.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 21, 2025
@cyang49 cyang49 force-pushed the pr_mamba2_vstack branch from f9ab16e to 5f73b79 on July 21, 2025 14:51
@mergify mergify bot removed the needs-rebase label Jul 21, 2025
@cyang49 cyang49 force-pushed the pr_mamba2_vstack branch 4 times, most recently from 875c81f to 3873218 on July 23, 2025 15:09
Member

@tlrmchlsmth tlrmchlsmth left a comment


This looks like a reasonable optimization.

My main comment is that this leaves the interface to the mamba_ssm functions more complicated than they were before. Now they support both in-place updating and out-of-place allocation of the outputs. And we need to handle those two cases in a few different places.

Could we change it to always be in-place instead?

@cyang49
Contributor Author

cyang49 commented Jul 30, 2025

This looks like a reasonable optimization.

My main comment is that this leaves the interface to the mamba_ssm functions more complicated than they were before. Now they support both in-place updating and out-of-place allocation of the outputs. And we need to handle those two cases in a few different places.

Could we change it to always be in-place instead?

I think I kept the original logic as a fallback, but you're right, we can remove it. I will push a simplified version if it is safe to remove.
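A rough sketch of the two calling conventions under discussion (function and argument names are hypothetical, not the actual mamba_ssm signatures):

from typing import Optional
import torch

def ssm_op_both_paths(x: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Current shape of the interface: in-place when out is given, out-of-place fallback otherwise.
    if out is None:
        out = torch.empty_like(x)  # fallback: allocate here; the caller concatenates later
    out.copy_(x)                   # stand-in for the real SSM computation
    return out

def ssm_op_in_place(x: torch.Tensor, out: torch.Tensor) -> None:
    # Always-in-place variant: the caller must supply the output buffer.
    out.copy_(x)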

@cyang49
Contributor Author

cyang49 commented Jul 30, 2025

@tlrmchlsmth
There are two other uses, in plamo2.py and phi4flash.py.
If I make the kernel only support in-place updates, they will need to be changed too.

  • plamo2 has similar logic to mamba_mixer2, so it should work after applying similar changes.
  • phi4flash looks quite different, though.

@cyang49
Contributor Author

cyang49 commented Jul 31, 2025

I tried to run both plamo2 and phi4flash on main (not the PR branch), and they both failed to run.
I think for now we should keep the out-of-place allocation for compatibility, because I can't check correctness if we keep only the in-place update path.

@cyang49 cyang49 force-pushed the pr_mamba2_vstack branch from 3873218 to b165a18 on July 31, 2025 16:50
@cyang49 cyang49 requested a review from WoosukKwon as a code owner July 31, 2025 16:50
@cyang49
Contributor Author

cyang49 commented Jul 31, 2025

Fixed the models that call the affected kernels

plamo2

lm_eval --model vllm --model_args pretrained=pfnet/plamo-2.1-2b-cpt,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.95,max_model_len=8192 --batch_size auto --trust_remote_code --cache_requests true --tasks gsm8k
vllm (pretrained=pfnet/plamo-2.1-2b-cpt,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.95,max_model_len=8192,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5982|±  |0.0135|
|     |       |strict-match    |     5|exact_match|↑  |0.5951|±  |0.0135|

phi4flash

VLLM_ATTENTION_BACKEND=DIFFERENTIAL_FLASH_ATTN lm_eval --model vllm  --model_args pretrained=microsoft/Phi-4-mini-flash-reasoning,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.95,enable_prefix_caching=False,enable_chunked_prefill=False,max_model_len=8192 --batch_size auto --trust_remote_code  --cache_requests true --tasks gsm8k
vllm (pretrained=microsoft/Phi-4-mini-flash-reasoning,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.95,enable_prefix_caching=False,enable_chunked_prefill=False,max_model_len=8192,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5239|±  |0.0138|
|     |       |strict-match    |     5|exact_match|↑  |0.4837|±  |0.0138|

@tlrmchlsmth tlrmchlsmth added the ready label (ONLY add when PR is ready to merge/full CI is needed) Jul 31, 2025
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) July 31, 2025 19:34
auto-merge was automatically disabled August 1, 2025 18:13

Head branch was pushed to by a user without write access

cyang49 added 5 commits August 1, 2025 21:13
Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
@vllm-bot vllm-bot merged commit b690e34 into vllm-project:main Aug 2, 2025
39 of 45 checks passed
@cyang49 cyang49 deleted the pr_mamba2_vstack branch August 4, 2025 11:53
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
…ad (vllm-project#21075)

Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
state_batch_indices=None,
pad_slot_id=PAD_SLOT_ID,
preallocated_ssm_out=None):
out=None):
Contributor


I think out needs to be a required argument now, because it is not allocated within the function anymore.
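A minimal illustration of the point, with a hypothetical signature rather than the actual kernel wrapper:

# If out defaults to None but is never allocated inside, calling without it fails at the first write.
def chunk_scan_optional_out(x, out=None):
    out[:] = x      # TypeError when out is None
    return out

# Making the output buffer a required keyword-only argument makes the contract explicit.
def chunk_scan_required_out(x, *, out):
    out[:] = x
    return out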

Contributor Author


Good point. Will address this in an upcoming PR.

jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
…ad (vllm-project#21075)

Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
noamgat pushed a commit to noamgat/vllm that referenced this pull request Aug 9, 2025
…ad (vllm-project#21075)

Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Noam Gat <noamgat@gmail.com>
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
…ad (vllm-project#21075)

Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Paul Pak <paulpak58@gmail.com>
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
…ad (vllm-project#21075)

Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Diego-Castan <diego.castan@ibm.com>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
…ad (vllm-project#21075)

Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
…ad (vllm-project#21075)

Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>