
Conversation


@chanh chanh commented Apr 10, 2025

This code inside apply_penalties does advanced indexing on a tensor with a boolean mask, which triggers nonzero, and nonzero currently requires a CPU sync in PyTorch.
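For context, here is a minimal sketch of the pattern (paraphrased from the flagged lines below, not the exact vLLM code; "penalties" stands in for the repetition-penalty tensor built from the prompt/output masks, and a CUDA device is assumed):

```python
import torch

logits = torch.randn(8, 32000, device="cuda")
penalties = torch.full_like(logits, 1.2)

# In-place assignment through a boolean mask is advanced indexing: PyTorch
# must call nonzero() on the mask to build integer indices, which blocks
# the CPU until the GPU reports how many elements matched.
pos = logits > 0
logits[pos] /= penalties[pos]
logits[~pos] *= penalties[~pos]
```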

Running with torch.cuda.set_sync_debug_mode("warn"), PyTorch confirms this:

/home/coder/vllm/venv/lib/python3.10/site-packages/torch/cuda/__init__.py:1067: UserWarning: Synchronization debug mode is a prototype feature and does not yet detect all synchronizing operations (Triggered internally at /pytorch/torch/csrc/cuda/Module.cpp:915.)
  torch._C._cuda_set_sync_debug_mode(debug_mode)
/home/coder/vllm/vllm/model_executor/layers/utils.py:52: UserWarning: called a synchronizing CUDA operation (Triggered internally at /pytorch/c10/cuda/CUDAFunctions.cpp:152.)
  logits[logits > 0] /= torch.where(prompt_mask | output_mask,
/home/coder/vllm/vllm/model_executor/layers/utils.py:54: UserWarning: called a synchronizing CUDA operation (Triggered internally at /pytorch/c10/cuda/CUDAFunctions.cpp:152.)
  logits[logits <= 0] *= torch.where(prompt_mask | output_mask,
/home/coder/vllm/vllm/v1/worker/gpu_model_runner.py:1153: UserWarning: called a synchronizing CUDA operation (Triggered internally at /pytorch/c10/cuda/CUDAFunctions.cpp:152.)
  valid_sampled_token_ids = sampled_token_ids.tolist()

This seems to be a known issue and was discussed in pytorch/pytorch#12461:

nonzero, which is called in this conversion, has a legitimate synchronization: it is necessary to pass information from the device about how many non-zero elements were found in the boolean index tensor, as this information is later required on the CPU to resize the index tensor and to configure launch parameters/kernel arguments for subsequent kernels. I'm not sure this sync can be avoided, because if the mask comes as the result of an operation on the GPU, the CPU has no way of knowing the number of nonzeros in the mask, which is objectively needed.
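Concretely, the output shape of nonzero depends on data that lives on the GPU, so the host must block before it can even see the size of the result (a tiny illustration, again assuming a CUDA device):

```python
import torch

mask = torch.rand(1000, device="cuda") > 0.5
idx = mask.nonzero()  # shape is (num_true, 1); num_true is only known on the
print(idx.shape)      # device, so the host must synchronize to learn it
```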

By refactoring the code to avoid the indexing, we can remove the sync and allow much more of the sampling-phase CPU work to overlap with the forward pass on the GPU, providing an 8% speedup to decoding for smaller models.
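One way to express the refactor (a sketch of the idea, not necessarily the exact code merged here): compute both branches elementwise and select with torch.where, which stays entirely on the GPU and never materializes integer indices:

```python
import torch

def apply_repetition_penalty(logits: torch.Tensor,
                             penalties: torch.Tensor) -> torch.Tensor:
    # Elementwise select instead of boolean-mask indexing: no nonzero call,
    # no CPU/GPU sync, so the host can keep queuing kernels ahead of the GPU.
    return torch.where(logits > 0, logits / penalties, logits * penalties)
```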

Before:

============ Serving Benchmark Result ============
Successful requests:                     100       
Benchmark duration (s):                  103.22    
Total input tokens:                      100000    
Total generated tokens:                  10000     
Request throughput (req/s):              0.97      
Output token throughput (tok/s):         96.88     
Total Token throughput (tok/s):          1065.73   
---------------Time to First Token----------------
Mean TTFT (ms):                          37.21     
Median TTFT (ms):                        32.09     
P99 TTFT (ms):                           71.54     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          6.74      
Median TPOT (ms):                        6.67      
P99 TPOT (ms):                           7.20      
---------------Inter-token Latency----------------
Mean ITL (ms):                           6.74      
Median ITL (ms):                         6.69      
P99 ITL (ms):                            7.93      
==================================================

After:

============ Serving Benchmark Result ============
Successful requests:                     100       
Benchmark duration (s):                  103.17    
Total input tokens:                      100000    
Total generated tokens:                  10000     
Request throughput (req/s):              0.97      
Output token throughput (tok/s):         96.93     
Total Token throughput (tok/s):          1066.19   
---------------Time to First Token----------------
Mean TTFT (ms):                          35.62     
Median TTFT (ms):                        30.71     
P99 TTFT (ms):                           60.89     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          6.18      
Median TPOT (ms):                        6.11      
P99 TPOT (ms):                           6.50      
---------------Inter-token Latency----------------
Mean ITL (ms):                           6.18      
Median ITL (ms):                         6.12      
P99 ITL (ms):                            7.43      
==================================================

Benchmark:

VLLM_FLASH_ATTN_VERSION=3 VLLM_USE_V1=1 vllm serve Qwen/Qwen2.5-1.5B-Instruct --enable-prefix-caching --dtype float16 --disable-log-requests -O3

vllm bench serve \
        --model Qwen/Qwen2.5-1.5B-Instruct \
        --request-rate 1 \
        --num-prompts 100 \
        --random-input-len 1000 \
        --random-output-len 100 \
        --tokenizer Qwen/Qwen2.5-1.5B-Instruct \
        --ignore-eos

Chanh Nguyen added 2 commits April 10, 2025 21:17
Signed-off-by: Chanh Nguyen <cnguyen@linkedin.com>
Signed-off-by: Chanh Nguyen <cnguyen@linkedin.com>
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@chanh chanh marked this pull request as ready for review April 11, 2025 04:12
@chanh chanh changed the title Speed up decode by remove synchronizing operation in sampler [Core] Speed up decode by remove synchronizing operation in sampler Apr 18, 2025
@WoosukKwon WoosukKwon self-assigned this Apr 21, 2025
Collaborator

@WoosukKwon WoosukKwon left a comment

@chanh Sorry for the late review. This is really great! Nice optimization!

@WoosukKwon WoosukKwon added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 21, 2025
@WoosukKwon WoosukKwon enabled auto-merge (squash) April 21, 2025 16:28
@WoosukKwon WoosukKwon merged commit 299ebb6 into vllm-project:main Apr 21, 2025
61 checks passed
frieda-huang pushed a commit to frieda-huang/vllm that referenced this pull request Apr 23, 2025
…llm-project#16436)

Signed-off-by: Chanh Nguyen <cnguyen@linkedin.com>
Co-authored-by: Chanh Nguyen <cnguyen@linkedin.com>
Signed-off-by: Frieda (Jingying) Huang <jingyingfhuang@gmail.com>
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
…llm-project#16436)

Signed-off-by: Chanh Nguyen <cnguyen@linkedin.com>
Co-authored-by: Chanh Nguyen <cnguyen@linkedin.com>
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
…llm-project#16436)

Signed-off-by: Chanh Nguyen <cnguyen@linkedin.com>
Co-authored-by: Chanh Nguyen <cnguyen@linkedin.com>
adobrzyn pushed a commit to HabanaAI/vllm-fork that referenced this pull request Apr 30, 2025
…llm-project#16436)

Signed-off-by: Chanh Nguyen <cnguyen@linkedin.com>
Co-authored-by: Chanh Nguyen <cnguyen@linkedin.com>
Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
…llm-project#16436)

Signed-off-by: Chanh Nguyen <cnguyen@linkedin.com>
Co-authored-by: Chanh Nguyen <cnguyen@linkedin.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>