 import torch

 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    w8a8_block_fp8_matmul,
+    apply_w8a8_block_fp8_linear,
+)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    CUTLASS_BLOCK_FP8_SUPPORTED,
 )
 from vllm.platforms import current_platform
 from vllm.triton_utils import triton as vllm_triton
 ]


-def build_w8a8_block_fp8_runner(M, N, K, block_size, device):
+def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
     """Build runner function for w8a8 block fp8 matmul."""
     factor_for_scale = 1e-2

     fp8_info = torch.finfo(torch.float8_e4m3fn)
     fp8_max, fp8_min = fp8_info.max, fp8_info.min

     # Create random FP8 tensors
-    A_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
-    A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+    A_ref = (torch.rand(M, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max

-    B_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
-    B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+    B_ref = (torch.rand(N, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
+    B = B_ref.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

     # Create scales
     block_n, block_k = block_size[0], block_size[1]
     n_tiles = (N + block_n - 1) // block_n
     k_tiles = (K + block_k - 1) // block_k

-    As = torch.rand(M, k_tiles, dtype=torch.float32, device=device) * factor_for_scale
     Bs = (
         torch.rand(n_tiles, k_tiles, dtype=torch.float32, device=device)
         * factor_for_scale
     )

+    # SM90 CUTLASS requires row-major format for scales
+    if use_cutlass and current_platform.is_device_capability(90):
+        Bs = Bs.T.contiguous()
+
     def run():
-        return w8a8_block_fp8_matmul(A, B, As, Bs, block_size, torch.bfloat16)
+        if use_cutlass:
+            return apply_w8a8_block_fp8_linear(
+                A_ref, B, block_size, Bs, cutlass_block_fp8_supported=True
+            )
+        else:
+            return apply_w8a8_block_fp8_linear(
+                A_ref, B, block_size, Bs, cutlass_block_fp8_supported=False
+            )

     return run


+# Determine available providers
+available_providers = ["torch-bf16", "w8a8-block-fp8-triton"]
+plot_title = "BF16 vs W8A8 Block FP8 GEMMs"
+
+if CUTLASS_BLOCK_FP8_SUPPORTED:
+    available_providers.append("w8a8-block-fp8-cutlass")
+
+
 @vllm_triton.testing.perf_report(
     vllm_triton.testing.Benchmark(
         x_names=["batch_size"],
         x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
         x_log=False,
         line_arg="provider",
-        line_vals=["torch-bf16", "w8a8-block-fp8"],
-        line_names=["torch-bf16", "w8a8-block-fp8"],
+        line_vals=available_providers,
+        line_names=available_providers,
         ylabel="TFLOP/s (larger is better)",
         plot_name="BF16 vs W8A8 Block FP8 GEMMs",
         args={},
@@ -85,11 +105,22 @@ def benchmark_tflops(batch_size, provider, N, K, block_size=(128, 128)):
         ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
             lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
         )
-    else:  # w8a8-block-fp8
-        run_w8a8 = build_w8a8_block_fp8_runner(M, N, K, block_size, device)
+    elif provider == "w8a8-block-fp8-triton":
+        run_w8a8_triton = build_w8a8_block_fp8_runner(
+            M, N, K, block_size, device, use_cutlass=False
+        )
+        ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
+            lambda: run_w8a8_triton(), quantiles=quantiles
+        )
+    elif provider == "w8a8-block-fp8-cutlass":
+        run_w8a8_cutlass = build_w8a8_block_fp8_runner(
+            M, N, K, block_size, device, use_cutlass=True
+        )
         ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
-            lambda: run_w8a8(), quantiles=quantiles
+            lambda: run_w8a8_cutlass(), quantiles=quantiles
         )
+    else:
+        raise ValueError(f"Unknown provider: {provider}")

     to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
     return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
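
Note: a minimal driver sketch for the benchmark above, assuming triton's standard perf_report/Mark.run interface; the __main__ block and the (N, K) shapes here are illustrative only and are not part of this diff:

if __name__ == "__main__":
    # Hypothetical weight shapes; the actual script may sweep a different set of (N, K) pairs.
    for N, K in [(4096, 4096), (8192, 8192)]:
        print(f"Benchmarking N={N}, K={K}")
        # perf_report wraps benchmark_tflops in a Mark object whose run() forwards
        # extra kwargs (here N and K) to the benchmarked function.
        benchmark_tflops.run(print_data=True, show_plots=False, N=N, K=K)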