# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

-from typing import Callable, Tuple
+from typing import Tuple

import fire
import torch
import triton
-from torch._inductor.utils import do_bench_using_profiling
+from triton.testing import do_bench

from torchao.prototype.mx_formats.kernels import (
    triton_to_mxfp8_dim1,
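
Note on the timing change above: do_bench_using_profiling from torch._inductor is swapped for triton.testing.do_bench, which runs a callable repeatedly and returns its runtime in milliseconds (reduced here with return_mode="median"). A minimal standalone sketch of that usage, separate from this benchmark (the matmul and shapes are illustrative only):

    import torch
    from triton.testing import do_bench

    # Illustrative only: time a bf16 matmul and report the median runtime.
    a = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
    b = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
    ms = do_bench(lambda: a @ b, return_mode="median")  # median milliseconds
    print(f"{ms * 1e3:.1f} us")  # convert ms to us, as the helper below does
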
@@ -64,29 +64,35 @@ def to_mx_dim1_reference(x_hp, block_size):
    return data_d1.t(), scale_d1


-def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
-    """Thin wrapper around do_bench_using_profiling"""
-    no_args = lambda: func(*args, **kwargs)
-    time = do_bench_using_profiling(no_args)
-    return time * 1e3
+def benchmark_cuda_function_in_microseconds(f, *args):
+    return do_bench(lambda: f(*args), return_mode="median") * 1e3


def run(
    M: int = 16384,
    K: int = 16384,
    BLOCK_SIZE: int = 32,
-    mode: str = "dim0",
+    mode: str = "dim0_floor",
):
    print(f"M {M} K {K} BLOCK_SIZE {BLOCK_SIZE}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"torch version: {torch.__version__}")
    print(f"triton version: {triton.__version__}")
    print(f"mode: {mode}")
-    assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx", "dim1_mx", "dim1_mx_triton")
+    assert mode in (
+        "dim0_floor",
+        "dim1_floor",
+        "dim0_dim1_floor",
+        "dim0_mx_floor",
+        "dim1_mx_floor",
+        "dim1_mx_triton_floor",
+        "dim1_mx_cuda_floor",
+        "dim1_mx_cuda_rceil",
+    )

    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000

-    if mode == "dim0":
+    if mode == "dim0_floor":
        scale_dim0_reference_c = torch.compile(scale_dim0_reference)
        y_d0, s_d0 = scale_dim0_reference_c(x, BLOCK_SIZE)

@@ -103,7 +109,7 @@ def run(
        bytes_rw = sum(t.numel() for t in [x, y_d0, s_d0]) * bytes_per_el_bf16
        bps = bytes_rw / (time_us / 1e6)

-    elif mode == "dim1":
+    elif mode == "dim1_floor":
        scale_dim1_reference_c = torch.compile(scale_dim1_reference)
        y_d1, s_d1 = scale_dim1_reference_c(x, BLOCK_SIZE)

@@ -120,7 +126,7 @@ def run(
        bytes_rw = sum(t.numel() for t in [x, y_d1, s_d1]) * bytes_per_el_bf16
        bps = bytes_rw / (time_us / 1e6)

-    elif mode == "dim0_dim1":
+    elif mode == "dim0_dim1_floor":
        scale_dim0_dim1_reference_c = torch.compile(scale_dim0_dim1_reference)
        y_d0, y_d1, s_d0, s_d1 = scale_dim0_dim1_reference_c(x, BLOCK_SIZE)

@@ -141,7 +147,7 @@ def run(
        )
        bps = bytes_rw / (time_us / 1e6)

-    elif mode == "dim0_mx":
+    elif mode == "dim0_mx_floor":
        to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference)
        y_d0, s_d0 = to_mx_dim0_reference_c(x, BLOCK_SIZE)

@@ -159,7 +165,7 @@ def run(
        bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
        bps = (bytes_r + bytes_w) / (time_us / 1e6)

-    elif mode == "dim1_mx":
+    elif mode == "dim1_mx_floor":
        to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
        y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)

@@ -177,7 +183,7 @@ def run(
        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
        bps = (bytes_r + bytes_w) / (time_us / 1e6)

-    elif mode == "dim1_mx_triton":
+    elif mode == "dim1_mx_triton_floor":
        y_d1, s_d1 = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)

        for _ in range(2):
@@ -194,6 +200,58 @@ def run(
        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
        bps = (bytes_r + bytes_w) / (time_us / 1e6)

+    elif mode == "dim1_mx_cuda_floor":
+        from torchao.prototype import mxfp8_cuda
+
+        _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
+            x, rowwise=False, colwise=True, scaling_mode="floor"
+        )
+
+        for _ in range(2):
+            __ = mxfp8_cuda.quantize(
+                x, rowwise=False, colwise=True, scaling_mode="floor"
+            )
+
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x: mxfp8_cuda.quantize(
+                x, rowwise=False, colwise=True, scaling_mode="floor"
+            ),
+            x,
+        )
+
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.float8_e8m0fnu
+
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
+    elif mode == "dim1_mx_cuda_rceil":
+        from torchao.prototype import mxfp8_cuda
+
+        _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
+            x, rowwise=False, colwise=True, scaling_mode="rceil"
+        )
+
+        for _ in range(2):
+            __ = mxfp8_cuda.quantize(
+                x, rowwise=False, colwise=True, scaling_mode="rceil"
+            )
+
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x: mxfp8_cuda.quantize(
+                x, rowwise=False, colwise=True, scaling_mode="rceil"
+            ),
+            x,
+        )
+
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.float8_e8m0fnu
+
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
    else:
        raise AssertionError(f"unknown mode {mode}")

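Since fire is imported at the top of the file, run() is presumably exposed as a CLI via fire.Fire, though that entrypoint is outside this diff. Under that assumption, and with a placeholder script name, the new CUDA-kernel modes could be benchmarked like:

    # Hypothetical invocation; the file name is a placeholder and the
    # torchao.prototype.mxfp8_cuda extension must be built for the target GPU.
    python cast_bench.py --M 16384 --K 16384 --BLOCK_SIZE 32 --mode dim1_mx_cuda_floor
    python cast_bench.py --mode dim1_mx_cuda_rceil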