@@ -51,31 +51,33 @@ def calculate_diff(
 ):
     """Calculate the difference between Inductor and CUDA implementations."""
     device = torch.device("cuda")
-    x = torch.rand((batch_size * hidden_size, 4096), dtype=dtype, device=device)
+    x = torch.randn((batch_size, hidden_size), dtype=dtype, device=device)
 
     quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False)
 
     torch_out, torch_scale = bench_compile(quant_fp8.forward_native)(x)
     torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x)
     cuda_out, cuda_scale = quant_fp8.forward_cuda(x)
 
-    out_allclose = lambda o1, o2: torch.allclose(
-        o1.to(torch.float32),
-        o2.to(torch.float32),
-        rtol=1e-3,
-        atol=1e-5,
-    )
-    scale_allclose = lambda s1, s2: torch.allclose(s1, s2, rtol=1e-3, atol=1e-5)
-
-    if (
-        out_allclose(cuda_out, torch_out)
-        and scale_allclose(cuda_scale, torch_scale)
-        and out_allclose(cuda_out, torch_eager_out)
-        and scale_allclose(cuda_scale, torch_eager_scale)
-    ):
+    try:
+        torch.testing.assert_close(
+            cuda_out.to(torch.float32),
+            torch_out.to(torch.float32),
+            rtol=1e-3,
+            atol=1e-5,
+        )
+        torch.testing.assert_close(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5)
+        torch.testing.assert_close(
+            cuda_out.to(torch.float32),
+            torch_eager_out.to(torch.float32),
+            rtol=1e-3,
+            atol=1e-5,
+        )
+        torch.testing.assert_close(cuda_scale, torch_eager_scale, rtol=1e-3, atol=1e-5)
         print("✅ All implementations match")
-    else:
+    except AssertionError as e:
         print("❌ Implementations differ")
+        print(e)
 
 
 configs = []
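A note on the `allclose` → `assert_close` swap above: `torch.testing.assert_close` raises an `AssertionError` whose message reports the largest absolute/relative difference and how many elements fell outside tolerance, which the new `except` branch prints. A minimal standalone sketch (not part of this PR) of the difference:

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.1])

# torch.allclose only yields a bare boolean...
print(torch.allclose(a, b, rtol=1e-3, atol=1e-5))  # False

# ...while assert_close raises with a diagnostic message on mismatch,
# e.g. "Greatest absolute difference: 0.1 at index (2,) ..."
try:
    torch.testing.assert_close(a, b, rtol=1e-3, atol=1e-5)
except AssertionError as e:
    print(e)
```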
@@ -91,7 +93,7 @@ def benchmark_quantization(
 ):
     device = torch.device("cuda")
 
-    x = torch.randn(batch_size * hidden_size, 4096, device=device, dtype=dtype)
+    x = torch.randn(batch_size, hidden_size, device=device, dtype=dtype)
 
     quantiles = [0.5, 0.2, 0.8]
     quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major)
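Both shape changes above fix the same issue: the old code built a `(batch_size * hidden_size, 4096)` input, so `hidden_size` only multiplied the row count and every case actually quantized rows of hardcoded width 4096. (The first hunk also switches `torch.rand` to `torch.randn`, so the check sees signed, roughly normal activations rather than uniform non-negative ones.) A small sketch, mine and purely illustrative, of what the two calls produce:

```python
import torch

batch_size, hidden_size = 16, 2048

# Old: hidden_size inflates the row count; the width is always 4096.
old_x = torch.rand((batch_size * hidden_size, 4096))
# New: a (batch, hidden) activation with the shape the benchmark advertises.
new_x = torch.randn((batch_size, hidden_size))

print(old_x.shape)  # torch.Size([32768, 4096])
print(new_x.shape)  # torch.Size([16, 2048])
```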
@@ -157,21 +159,21 @@ def geo_speedup(group: pd.DataFrame) -> pd.Series:
     )
     parser.add_argument("-c", "--check", action="store_true")
     parser.add_argument(
-        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
+        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="bfloat16"
     )
     parser.add_argument(
         "--hidden-sizes",
         type=int,
         nargs="+",
-        default=None,
-        help="Hidden sizes to benchmark (default: 1,16,64,128,256,512,1024,2048,4096)",
+        default=[896, 1024, 2048, 4096, 7168],
+        help="Hidden sizes to benchmark",
     )
     parser.add_argument(
         "--batch-sizes",
         type=int,
         nargs="+",
-        default=None,
-        help="Batch sizes to benchmark (default: 1,16,32,64,128)",
+        default=[1, 16, 128, 512, 1024],
+        help="Batch sizes to benchmark",
     )
     parser.add_argument(
         "--group-sizes",
@@ -192,8 +194,8 @@ def geo_speedup(group: pd.DataFrame) -> pd.Series:
 
     dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]
 
-    hidden_sizes = args.hidden_sizes or [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
-    batch_sizes = args.batch_sizes or [1, 16, 32, 64, 128]
+    hidden_sizes = args.hidden_sizes
+    batch_sizes = args.batch_sizes
 
     if args.group_sizes is not None:
         group_shapes = []
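With the defaults now baked into argparse, the `or`-fallbacks become redundant, hence the simplification in the last hunk. Assuming the script sweeps the cross product of the two lists (the usual pattern behind the `configs = []` accumulator above; the loop itself is not shown in this diff), the new defaults trade sweep breadth for model-realistic sizes:

```python
import itertools

hidden_sizes = [896, 1024, 2048, 4096, 7168]  # new defaults
batch_sizes = [1, 16, 128, 512, 1024]

# 5 x 5 = 25 shape combinations, versus 9 x 5 = 45 with the old defaults.
configs = list(itertools.product(batch_sizes, hidden_sizes))
print(len(configs))  # 25
```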