Commit b875c59

Fix rotary embedding benchmark script
Signed-off-by: Xin Yang <xyangx@amazon.com>
1 parent da786e3

1 file changed: 64 additions & 90 deletions
@@ -1,97 +1,76 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from itertools import accumulate
+import itertools
 
-import nvtx
 import torch
 
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope
-from vllm.platforms import current_platform
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.triton_utils import triton
 from vllm.utils.argparse_utils import FlexibleArgumentParser
 
+batch_size_range = [2**i for i in range(0, 8, 2)]
+seq_len_range = [2**i for i in range(6, 10, 1)]
+num_heads_range = [32, 48]
+configs = list(itertools.product(batch_size_range, seq_len_range, num_heads_range))
 
-def benchmark_rope_kernels_multi_lora(
-    is_neox_style: bool,
-    batch_size: int,
-    seq_len: int,
-    num_heads: int,
-    head_size: int,
-    rotary_dim: int | None,
-    dtype: torch.dtype,
-    seed: int,
-    device: str,
-    max_position: int = 8192,
-    base: float = 10000,
-) -> None:
-    current_platform.seed_everything(seed)
-    torch.set_default_device(device)
-    if rotary_dim is None:
-        rotary_dim = head_size
-    # silulating serving 4 LoRAs
-    scaling_factors = [1, 2, 4, 8]
-    # batched RoPE can take multiple scaling factors
-    batched_rope = get_rope(
-        head_size,
-        rotary_dim,
-        max_position,
-        base,
-        is_neox_style,
-        {"rope_type": "linear", "factor": tuple(scaling_factors)},
+
+def get_benchmark(is_neox_style, head_size, rotary_dim, device):
+    @triton.testing.perf_report(
+        triton.testing.Benchmark(
+            x_names=["batch_size", "seq_len", "num_heads"],
+            x_vals=[list(_) for _ in configs],
+            line_arg="provider",
+            line_vals=["torch", "flashinfer", "vllm"],
+            line_names=["PyTorch", "FlashInfer", "vLLM"],
+            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
+            ylabel="us",
+            plot_name=f"rope-perf{'-neox-style' if is_neox_style else ''}",
+            args={},
+        )
     )
-    # non-batched RoPE takes only one scaling factor, we create multiple
-    # instances to simulate the same behavior
-    non_batched_ropes: list[RotaryEmbedding] = []
-    for scaling_factor in scaling_factors:
-        non_batched_ropes.append(
-            get_rope(
-                head_size,
-                rotary_dim,
-                max_position,
-                base,
-                is_neox_style,
-                {"rope_type": "linear", "factor": (scaling_factor,)},
-            )
+    def benchmark(batch_size, seq_len, num_heads, provider):
+        dtype = torch.bfloat16
+        max_position = 8192
+        base = 10000
+        rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
+        rope = rope.to(dtype=dtype, device=device)
+        cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device)
+
+        positions = torch.randint(0, max_position, (batch_size, seq_len), device=device)
+        query = torch.randn(
+            (batch_size, seq_len, num_heads * head_size), dtype=dtype, device=device
         )
+        key = torch.randn_like(query)
 
-    positions = torch.randint(0, max_position, (batch_size, seq_len))
-    query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype)
-    key = torch.randn_like(query)
+        quantiles = [0.5, 0.2, 0.8]
 
-    # create query offsets for batched RoPE, we concat multiple kv cache
-    # together and each query needs to find the right kv cache of its type
-    offset_map = torch.tensor(
-        list(
-            accumulate(
-                [0]
-                + [
-                    max_position * scaling_factor * 2
-                    for scaling_factor in scaling_factors[:-1]
-                ]
+        if provider == "torch":
+            ms, min_ms, max_ms = triton.testing.do_bench(
+                lambda: rope.forward_native(positions, query.clone(), key.clone()),
+                quantiles=quantiles,
             )
-        )
-    )
-    query_types = torch.randint(
-        0, len(scaling_factors), (batch_size, seq_len), device=device
-    )
-    # map query types to offsets
-    query_offsets = offset_map[query_types]
-    # the kernel takes flattened offsets
-    flatten_offsets = query_offsets.flatten()
+        elif provider == "flashinfer":
+            ms, min_ms, max_ms = triton.testing.do_bench(
+                lambda: torch.ops.vllm.flashinfer_rotary_embedding(
+                    positions,
+                    query.clone(),
+                    key.clone(),
+                    head_size,
+                    cos_sin_cache,
+                    is_neox_style,
+                ),
+                quantiles=quantiles,
+            )
+        else:
+            ms, min_ms, max_ms = triton.testing.do_bench(
+                lambda: rope.forward_cuda(positions, query.clone(), key.clone()),
+                quantiles=quantiles,
+            )
+
+        return 1000 * ms, 1000 * max_ms, 1000 * min_ms
 
-    # batched queries of the same type together for non-batched RoPE
-    queries = [query[query_types == i] for i in range(len(scaling_factors))]
-    keys = [key[query_types == i] for i in range(len(scaling_factors))]
-    packed_qkr = zip(queries, keys, non_batched_ropes)
-    # synchronize before start timing
-    torch.cuda.synchronize()
-    with nvtx.annotate("non-batched", color="yellow"):
-        for q, k, r in packed_qkr:
-            r.forward(positions, q, k)
-    torch.cuda.synchronize()
-    with nvtx.annotate("batched", color="green"):
-        batched_rope.forward(positions, query, key, flatten_offsets)
-    torch.cuda.synchronize()
+    return benchmark
 
 
 if __name__ == "__main__":
@@ -116,17 +95,12 @@ def benchmark_rope_kernels_multi_lora(
     parser.add_argument(
         "--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0"
     )
+    parser.add_argument("--save-path", type=str, default="./configs/rope/")
     args = parser.parse_args()
-    print(args)
 
-    benchmark_rope_kernels_multi_lora(
-        is_neox_style=args.is_neox_style,
-        batch_size=args.batch_size,
-        seq_len=args.seq_len,
-        num_heads=args.num_heads,
-        head_size=args.head_size,
-        rotary_dim=args.rotary_dim,
-        dtype=getattr(torch, args.dtype),
-        seed=args.seed,
-        device=args.device,
+    # Get the benchmark function
+    benchmark = get_benchmark(
+        args.is_neox_style, args.head_size, args.rotary_dim, args.device
     )
+    # Run performance benchmark
+    benchmark.run(print_data=True, save_path=args.save_path)
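
For context, the rewritten script follows the standard triton.testing harness: perf_report sweeps the (batch_size, seq_len, num_heads) configs and draws one line per provider, while do_bench times a callable and returns the requested quantiles in milliseconds. Below is a minimal, self-contained sketch of that pattern; the softmax workload, sizes, and provider names are illustrative only and are not part of this commit.

import torch

import triton


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["n"],  # problem size swept along the x-axis
        x_vals=[2**i for i in range(10, 14)],
        line_arg="provider",  # one plotted line per provider
        line_vals=["eager", "compiled"],
        line_names=["PyTorch eager", "torch.compile"],
        styles=[("blue", "-"), ("red", "-")],
        ylabel="us",
        plot_name="softmax-perf",  # illustrative name
        args={},
    )
)
def bench_softmax(n, provider):
    # Illustrative workload, not from the commit.
    x = torch.randn(n, n, device="cuda", dtype=torch.bfloat16)
    fn = torch.softmax if provider == "eager" else torch.compile(torch.softmax)
    # do_bench reports milliseconds; quantiles=[0.5, 0.2, 0.8] yields
    # (median, p20, p80), which is then scaled to microseconds to match ylabel.
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: fn(x, dim=-1), quantiles=[0.5, 0.2, 0.8]
    )
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    bench_softmax.run(print_data=True)

As in the updated script, run(print_data=True, save_path=...) prints the timing table and, when save_path is given, also writes the plot and a CSV of results. Note the query.clone()/key.clone() calls in the commit's benchmarked lambdas: vLLM's rotary embedding CUDA op updates query and key in place, so cloning keeps each timed iteration working on fresh inputs.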
