Skip to content

Conversation

@xyang16
Copy link
Contributor

@xyang16 xyang16 commented Nov 7, 2025

Purpose

Currently running the benchmark_rope.py script failed with the error (error pasted below). This PR fixes the benchmark script:

Error:

Traceback (most recent call last):
  File "/root/workspace/vllm-project/vllm/benchmarks/kernels/benchmark_rope.py", line 192, in <module>
    benchmark_rope_kernels_multi_lora(
  File "/root/workspace/vllm-project/vllm/benchmarks/kernels/benchmark_rope.py", line 96, in benchmark_rope_kernels_multi_lora
    r.forward(positions, q, k)
  File "/root/workspace/vllm-project/vllm/vllm/model_executor/custom_op.py", line 46, in forward
    return self._forward_method(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/workspace/vllm-project/vllm/vllm/model_executor/layers/rotary_embedding/base.py", line 165, in forward_cuda
    ops.rotary_embedding(
  File "/root/workspace/vllm-project/vllm/vllm/_custom_ops.py", line 322, in rotary_embedding
    torch.ops._C.rotary_embedding(
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1255, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_device.py", line 103, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1255, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: query, key and positions must have the same batch_size and seq_len

Test Plan

python3 benchmarks/kernels/benchmark_rope.py

Test Result

Benchmark script is able to generate benchmark result after the change.

rope-perf-neox-style:
    batch_size  seq_len  num_heads      PyTorch  FlashInfer        vLLM
0          1.0     64.0       32.0   238.848001   22.175999   12.352000
1          1.0     64.0       48.0   237.103999   19.743999   13.696000
2          1.0    128.0       32.0   234.863997   20.959999   15.552000
3          1.0    128.0       48.0   235.328004   23.232000   18.048000
4          1.0    256.0       32.0   235.344000   26.432000   16.672000
5          1.0    256.0       48.0   235.136002   31.392001   20.800000
6          1.0    512.0       32.0   234.959997   37.184000   21.984000
7          1.0    512.0       48.0   233.424000   53.183999   25.184000
8          4.0     64.0       32.0   235.551998   20.256000   16.736001
9          4.0     64.0       48.0   235.487998   21.088000   18.112000
10         4.0    128.0       32.0   234.847993   24.320001   19.296000
11         4.0    128.0       48.0   234.528005   31.199999   25.184000
12         4.0    256.0       32.0   233.903997   33.280000   27.424000
13         4.0    256.0       48.0   235.679999   44.032000   35.071999
14         4.0    512.0       32.0   240.656003   55.712000   43.423999
15         4.0    512.0       48.0   245.503999   77.280000   59.567999
16        16.0     64.0       32.0   234.608002   29.696001   28.767999
17        16.0     64.0       48.0   235.023998   38.816001   35.071999
18        16.0    128.0       32.0   242.160000   50.207999   43.391999
19        16.0    128.0       48.0   245.951995   71.135998   60.959999
20        16.0    256.0       32.0   317.375988   90.623997   78.592002
21        16.0    256.0       48.0   437.824011  129.280001  109.327998
22        16.0    512.0       32.0   572.303981  168.479994  140.223995
23        16.0    512.0       48.0   821.743995  244.128004  200.800002
24        64.0     64.0       32.0   314.943999   92.352003   77.280000
25        64.0     64.0       48.0   440.064013  133.407995  108.239997
26        64.0    128.0       32.0   568.607986  169.504002  143.232003
27        64.0    128.0       48.0   819.648027  247.263998  202.335998
28        64.0    256.0       32.0  1069.247961  324.544013  265.248001
29        64.0    256.0       48.0  1547.855973  480.832011  384.992003
30        64.0    512.0       32.0  2040.031910  631.232023  514.464021
31        64.0    512.0       48.0  3005.200028  938.271999  755.136013

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added the performance Performance-related issues label Nov 7, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

The pull request successfully fixes the failing benchmark_rope.py script by refactoring it to use Triton's benchmarking utilities and removing the obsolete batched RoPE benchmark. The new implementation is much cleaner and more robust. My review includes a few suggestions to improve the script's usability and reproducibility by correctly handling command-line arguments for data type and seed, and by removing other arguments that are no longer used.

@xyang16 xyang16 force-pushed the bench branch 5 times, most recently from 0d74316 to b875c59 Compare November 7, 2025 21:02
Signed-off-by: Xin Yang <xyangx@amazon.com>
Copy link
Member

@yewentao256 yewentao256 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for the work!

@yewentao256 yewentao256 added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 10, 2025
@ProExpertProg ProExpertProg merged commit 57201a6 into vllm-project:main Nov 11, 2025
23 checks passed
@xyang16 xyang16 deleted the bench branch November 12, 2025 19:14
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Nov 13, 2025
Signed-off-by: Xin Yang <xyangx@amazon.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

performance Performance-related issues ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants