 import argparse
+
 import logging
 
 from typing import Any, Callable, List, Optional
 
 import torch._inductor.config as inductor_config
 import triton
 
+from torch._inductor.kernel.mm import scaling_pairs, ScalingType
+
 from tritonbench.operators.fp8_gemm.persistent import blackwell_persistent_tma
 from tritonbench.utils.env_utils import get_nvidia_gpu_model, is_cuda
 
 def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
     parser.add_argument("--llama", action="store_true")
-    parser.add_argument("--scaling_rowwise", action="store_true")
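+    # Comma-separated scaling recipes for operands A and B, e.g. "RowWise,RowWise".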
+    parser.add_argument("--scaling-pair", type=str, default="TensorWise,TensorWise")
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
     parser.add_argument("--n", type=int)
@@ -65,6 +68,58 @@ def get_fp8_dtype():
     return torch.float8_e4m3fnuz
 
 
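+# Map a --scaling-pair token (e.g. "TensorWise") to the corresponding ScalingType member.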
+def get_scaling_recipe(scaling_recipe: str) -> ScalingType:
+    if scaling_recipe == "TensorWise":
+        return ScalingType.TensorWise
+    elif scaling_recipe == "RowWise":
+        return ScalingType.RowWise
+    else:
+        raise ValueError(f"Invalid scaling recipe: {scaling_recipe}")
+
+
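+# Scale an operand for fp8 quantization and return (scaled_x, scale), where the scale is a
+# float32 tensor computed tensor-wise or row-wise depending on the recipe.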
+def get_scale(
+    x: torch.Tensor,
+    scaling_recipe: ScalingType,
+    transpose: bool = False,
+    custom_scale: Optional[float] = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    def _get_scale_per_tensor(
+        x: torch.Tensor, custom_scale: Optional[float] = None
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # For tensor-wise scaling, the kernel requires a float32 scale tensor
+        if custom_scale:
+            # Return x unscaled alongside the user-provided scale
+            return x, torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
+        scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
+        x *= scale
+        return x, scale.to(torch.float32)
+
+    def _get_scale_per_row(
+        x: torch.Tensor, transpose: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if transpose:  # scale_b.shape should be [1, N]
+            scale = (
+                torch.finfo(torch.float8_e4m3fn).max
+                / x.abs().max(dim=0, keepdim=True).values
+            )
+        else:  # scale_a.shape should be [M, 1]
+            scale = (
+                torch.finfo(torch.float8_e4m3fn).max
+                / x.abs().max(dim=1, keepdim=True).values
+            )
+        x = x.mul(scale)
+        return x, scale.to(
+            torch.float32
+        )  # For row-wise scaling, the kernel requires a float32 scale tensor
+
+    match scaling_recipe:
+        case ScalingType.TensorWise:
+            return _get_scale_per_tensor(x, custom_scale=custom_scale)
+        case ScalingType.RowWise:
+            return _get_scale_per_row(x, transpose=transpose)
+        case _:
+            raise AssertionError(f"Unsupported scaling type {scaling_recipe}")
+
+
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
     DEFAULT_PRECISION = "fp8"
@@ -78,53 +133,39 @@ def __init__(
 
         self.fp8_dtype = get_fp8_dtype()
 
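+        # Parse and validate --scaling-pair; only pairs listed in torch/_inductor/kernel/mm.py::scaling_pairs are accepted.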
+        scaling_recipe_a, scaling_recipe_b = self.extra_args.scaling_pair.split(",")
+        if (scaling_recipe_a, scaling_recipe_b) not in [
+            (a.name, b.name) for a, b in scaling_pairs
+        ]:
+            raise ValueError(
+                f"Invalid scaling pair: {scaling_recipe_a}, {scaling_recipe_b}. See torch/_inductor/kernel/mm.py::scaling_pairs for valid pairs."
+            )
+        self.scaling_recipe_a = get_scaling_recipe(scaling_recipe_a)
+        self.scaling_recipe_b = get_scaling_recipe(scaling_recipe_b)
+
     def _get_dtype(self):
-        if self.extra_args.scaling_rowwise:
-            return torch.bfloat16
-        else:
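+        # fp16 for the tensor-wise/tensor-wise pair (matches the old non-rowwise path); bf16 for all other recipes.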
+        if (
+            self.scaling_recipe_a == ScalingType.TensorWise
+            and self.scaling_recipe_b == ScalingType.TensorWise
+        ):
             return torch.float16
+        return torch.bfloat16
 
     def get_input_iter(self):
-        def _get_scale_per_tensor(
-            x: torch.Tensor, custom_scale: float = None
-        ) -> torch.Tensor:
-            # For tensor-wise scaling, kernel requires a float32 scale tensor
-            if custom_scale:
-                return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
-            scale = torch.finfo(self.fp8_dtype).max / x.abs().max()
-            return scale.to(torch.float32)
-
-        def _get_scale_per_row(
-            x: torch.Tensor, transpose: bool = False
-        ) -> torch.Tensor:
-            if transpose:  # scale_b.shape should be [1, N]
-                scale = (
-                    torch.finfo(self.fp8_dtype).max
-                    / x.abs().max(dim=0, keepdim=True).values
-                )
-            else:  # scale_a.shape should be [M, 1]
-                scale = (
-                    torch.finfo(self.fp8_dtype).max
-                    / x.abs().max(dim=1, keepdim=True).values
-                )
-            return scale.to(
-                torch.float32
-            )  # For row-wise scaling, kernel requires a float32 scale tensor
-
         def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(self._get_dtype())
             b = torch.randn(n, k, device=self.device).to(self._get_dtype())
 
-            if self.extra_args.scaling_rowwise:
-                scale_a = _get_scale_per_row(a)
-                scale_b = _get_scale_per_row(b)
-            else:
-                scale_a = _get_scale_per_tensor(
-                    a, custom_scale=self.extra_args.per_tensor_scale_a
-                )
-                scale_b = _get_scale_per_tensor(
-                    b, custom_scale=self.extra_args.per_tensor_scale_b
-                )
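+            # get_scale scales each operand per its recipe and returns the float32 scale tensor expected by the kernels.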
+            a, scale_a = get_scale(
+                a,
+                self.scaling_recipe_a,
+                custom_scale=self.extra_args.per_tensor_scale_a,
+            )
+            b, scale_b = get_scale(
+                b,
+                self.scaling_recipe_b,
+                custom_scale=self.extra_args.per_tensor_scale_b,
+            )
 
             # Kernels expect dtype=float8_e4m3fn(uz)
             a = a.to(self.fp8_dtype)
@@ -198,13 +239,21 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
 
     @register_benchmark(enabled=True)
     def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
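+        # blackwell_persistent_tma takes an integer scaling mode in place of the old scaling_rowwise flag: 0 = tensor-wise, 1 = row-wise.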
+        if self.scaling_recipe_a == self.scaling_recipe_b == ScalingType.TensorWise:
+            scaling_recipe_int = 0
+        elif self.scaling_recipe_a == self.scaling_recipe_b == ScalingType.RowWise:
+            scaling_recipe_int = 1
+        else:
+            raise ValueError(
+                f"Invalid scaling pair: {self.scaling_recipe_a}, {self.scaling_recipe_b} for blackwell_persistent_tma_fp8_gemm."
+            )
         return lambda: blackwell_persistent_tma(
             a,
             b,
             scale_a,
             scale_b,
             self._get_dtype(),
-            self.extra_args.scaling_rowwise,
+            scaling_recipe_int,
         )
 
     @register_benchmark(enabled=True)