 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from itertools import accumulate
+import itertools

-import nvtx
 import torch

-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope
-from vllm.platforms import current_platform
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.triton_utils import triton
 from vllm.utils.argparse_utils import FlexibleArgumentParser

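+# Benchmark grid: batch sizes 1/4/16/64, sequence lengths 64-512, and 32 or 48 heads.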
+batch_size_range = [2**i for i in range(0, 8, 2)]
+seq_len_range = [2**i for i in range(6, 10, 1)]
+num_heads_range = [32, 48]
+configs = list(itertools.product(batch_size_range, seq_len_range, num_heads_range))

-def benchmark_rope_kernels_multi_lora(
-    is_neox_style: bool,
-    batch_size: int,
-    seq_len: int,
-    num_heads: int,
-    head_size: int,
-    rotary_dim: int | None,
-    dtype: torch.dtype,
-    seed: int,
-    device: str,
-    max_position: int = 8192,
-    base: float = 10000,
-) -> None:
-    current_platform.seed_everything(seed)
-    torch.set_default_device(device)
-    if rotary_dim is None:
-        rotary_dim = head_size
-    # simulating serving 4 LoRAs
-    scaling_factors = [1, 2, 4, 8]
-    # batched RoPE can take multiple scaling factors
-    batched_rope = get_rope(
-        head_size,
-        rotary_dim,
-        max_position,
-        base,
-        is_neox_style,
-        {"rope_type": "linear", "factor": tuple(scaling_factors)},
+
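+# Build a Triton perf_report benchmark that compares the PyTorch-native,
+# FlashInfer, and vLLM CUDA rotary embedding implementations.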
+def get_benchmark(is_neox_style, head_size, rotary_dim, device):
+    @triton.testing.perf_report(
+        triton.testing.Benchmark(
+            x_names=["batch_size", "seq_len", "num_heads"],
+            x_vals=[list(_) for _ in configs],
+            line_arg="provider",
+            line_vals=["torch", "flashinfer", "vllm"],
+            line_names=["PyTorch", "FlashInfer", "vLLM"],
+            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
+            ylabel="us",
+            plot_name=f"rope-perf{'-neox-style' if is_neox_style else ''}",
+            args={},
+        )
     )
-    # non-batched RoPE takes only one scaling factor, we create multiple
-    # instances to simulate the same behavior
-    non_batched_ropes: list[RotaryEmbedding] = []
-    for scaling_factor in scaling_factors:
-        non_batched_ropes.append(
-            get_rope(
-                head_size,
-                rotary_dim,
-                max_position,
-                base,
-                is_neox_style,
-                {"rope_type": "linear", "factor": (scaling_factor,)},
-            )
+    def benchmark(batch_size, seq_len, num_heads, provider):
+        dtype = torch.bfloat16
+        max_position = 8192
+        base = 10000
+        rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
+        rope = rope.to(dtype=dtype, device=device)
+        cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device)
+
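+        # Random integer positions plus query/key activations shaped
+        # [batch_size, seq_len, num_heads * head_size].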
+        positions = torch.randint(0, max_position, (batch_size, seq_len), device=device)
+        query = torch.randn(
+            (batch_size, seq_len, num_heads * head_size), dtype=dtype, device=device
         )
+        key = torch.randn_like(query)

-    positions = torch.randint(0, max_position, (batch_size, seq_len))
-    query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype)
-    key = torch.randn_like(query)
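+        # p50 / p20 / p80 latency quantiles reported by triton.testing.do_bench.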
+        quantiles = [0.5, 0.2, 0.8]

-    # create query offsets for batched RoPE, we concat multiple kv cache
-    # together and each query needs to find the right kv cache of its type
-    offset_map = torch.tensor(
-        list(
-            accumulate(
-                [0]
-                + [
-                    max_position * scaling_factor * 2
-                    for scaling_factor in scaling_factors[:-1]
-                ]
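+        # Each provider is timed on cloned query/key so that in-place kernels do
+        # not mutate the shared inputs across repeated calls.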
+        if provider == "torch":
+            ms, min_ms, max_ms = triton.testing.do_bench(
+                lambda: rope.forward_native(positions, query.clone(), key.clone()),
+                quantiles=quantiles,
             )
-        )
-    )
-    query_types = torch.randint(
-        0, len(scaling_factors), (batch_size, seq_len), device=device
-    )
-    # map query types to offsets
-    query_offsets = offset_map[query_types]
-    # the kernel takes flattened offsets
-    flatten_offsets = query_offsets.flatten()
+        elif provider == "flashinfer":
+            ms, min_ms, max_ms = triton.testing.do_bench(
+                lambda: torch.ops.vllm.flashinfer_rotary_embedding(
+                    positions,
+                    query.clone(),
+                    key.clone(),
+                    head_size,
+                    cos_sin_cache,
+                    is_neox_style,
+                ),
+                quantiles=quantiles,
+            )
+        else:
+            ms, min_ms, max_ms = triton.testing.do_bench(
+                lambda: rope.forward_cuda(positions, query.clone(), key.clone()),
+                quantiles=quantiles,
+            )
+
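+        # Convert do_bench's milliseconds to microseconds to match the "us" ylabel.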
+        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

-    # batched queries of the same type together for non-batched RoPE
-    queries = [query[query_types == i] for i in range(len(scaling_factors))]
-    keys = [key[query_types == i] for i in range(len(scaling_factors))]
-    packed_qkr = zip(queries, keys, non_batched_ropes)
-    # synchronize before start timing
-    torch.cuda.synchronize()
-    with nvtx.annotate("non-batched", color="yellow"):
-        for q, k, r in packed_qkr:
-            r.forward(positions, q, k)
-    torch.cuda.synchronize()
-    with nvtx.annotate("batched", color="green"):
-        batched_rope.forward(positions, query, key, flatten_offsets)
-    torch.cuda.synchronize()
+    return benchmark


 if __name__ == "__main__":
@@ -116,17 +95,12 @@ def benchmark_rope_kernels_multi_lora(
     parser.add_argument(
         "--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0"
     )
+    parser.add_argument("--save-path", type=str, default="./configs/rope/")
     args = parser.parse_args()
-    print(args)

-    benchmark_rope_kernels_multi_lora(
-        is_neox_style=args.is_neox_style,
-        batch_size=args.batch_size,
-        seq_len=args.seq_len,
-        num_heads=args.num_heads,
-        head_size=args.head_size,
-        rotary_dim=args.rotary_dim,
-        dtype=getattr(torch, args.dtype),
-        seed=args.seed,
-        device=args.device,
+    # Get the benchmark function
+    benchmark = get_benchmark(
+        args.is_neox_style, args.head_size, args.rotary_dim, args.device
     )
+    # Run performance benchmark
+    benchmark.run(print_data=True, save_path=args.save_path)
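+    # Triton prints the results table and writes CSV/plot output under --save-path.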