77"""
88Triton kernels for scaling high precision tensors to float8.
99"""
10+ from enum import Enum
1011
1112import torch
12-
1313import triton
1414import triton .language as tl
1515
3131}
3232
3333
+class KernelAlgorithm(Enum):
+    """Enum for FP8 conversion strategy."""
+
+    # use atomic max to compute global amax between blocks
+    ATOMIC_MAX = "atomic_max"
+
+    # reduce shared buffer containing local block amaxes to find global amax
+    REDUCTION = "reduction"
+
+
+kernel_configs = [
+    triton.Config({"BLOCK_SIZE": 128}, num_warps=1),
+    triton.Config({"BLOCK_SIZE": 256}, num_warps=2),
+    triton.Config({"BLOCK_SIZE": 512}, num_warps=4),
+]
+
+
+# --- atomic max version of kernel ---
+@triton.autotune(configs=kernel_configs, key=["num_elements"])
+@triton.jit
+def _block_amax_atomic(
+    input_ptr,
+    amax_ptr,
+    num_elements,
+    input_dtype: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    EPS: tl.constexpr,
+):
+    # compute local amax for each block
+    block_id = tl.program_id(axis=0)
+    block_start = block_id * BLOCK_SIZE
+    block_offs = block_start + tl.arange(0, BLOCK_SIZE)
+    block_mask = block_offs < num_elements
+    vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
+    block_amax = tl.max(tl.abs(vals))
+    tl.atomic_max(amax_ptr, block_amax)
+
+
+@triton.jit
+def _fp8_scale_atomic(
+    amax_ptr,
+    scale_out_ptr,
+    fp8_dtype_max,
+    EPS: tl.constexpr,
+):
+    # load previously computed global amax
+    global_amax = tl.load(amax_ptr)
+
+    # compute scale, must be fp32
+    scale = (fp8_dtype_max / tl.clamp(global_amax, min=EPS, max=float("inf"))).to(
+        tl.float32
+    )
+
+    # store scale for use in Float8Tensor constructor
+    scale_off = tl.arange(0, 1)
+    tl.store(scale_out_ptr + scale_off, scale)
+
+
+@triton.autotune(configs=kernel_configs, key=["num_elements"])
 @triton.jit
-def _block_amax(
+def _to_fp8_atomic(
+    input_ptr,
+    scale_ptr,
+    amax_ptr,
+    out_ptr,
+    num_elements,
+    fp8_dtype_min,
+    fp8_dtype_max,
+    input_dtype: tl.constexpr,
+    output_dtype: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    EPS: tl.constexpr,
+):
+    block_id = tl.program_id(axis=0)
+
+    # load scale
+    scale = tl.load(scale_ptr)
+
+    # load block of input tensor
+    block_start = block_id * BLOCK_SIZE
+    block_offs = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = block_offs < num_elements
+    vals = tl.load(input_ptr + block_offs, mask=mask).to(input_dtype)
+
+    # perform conversion
+    vals = vals * scale
+    fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(output_dtype)
+    tl.store(out_ptr + block_offs, fp8_vals, mask=mask)
+
+
+# --- reduction version of kernel ---
+@triton.jit
+def _block_amax_reduction(
     input_ptr,
     block_amaxes_ptr,
     num_elements,
@@ -46,12 +137,12 @@ def _block_amax(
     block_offs = block_start + tl.arange(0, BLOCK_SIZE)
     block_mask = block_offs < num_elements
     vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
-    block_amax = tl.max(tl.abs(vals), axis=0)
+    block_amax = tl.max(tl.abs(vals))
     tl.store(block_amaxes_ptr + block_id, block_amax)
 
 
 @triton.jit
-def _fp8_scale(
+def _fp8_scale_reduction(
     block_amaxes_ptr,
     scale_out_ptr,
     num_elements,
@@ -75,7 +166,7 @@ def _fp8_scale(
 
 
 @triton.jit
-def _to_fp8(
+def _to_fp8_reduction(
     input_ptr,
     scale_ptr,
     out_ptr,
@@ -108,12 +199,10 @@ def triton_hp_tensor_to_float8_dynamic(
     fp8_dtype: torch.dtype,
     linear_mm_config: LinearMMConfig,
     gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
+    algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
 ) -> Float8Tensor:
-
     assert hp_tensor.is_contiguous(), "tensor must be contiguous"
 
-    BLOCK_SIZE = 8  # TODO(danielvegamyhre): tune this for perf
-
     num_elements = hp_tensor.numel()
     orig_shape = hp_tensor.shape
     flattened_input = hp_tensor.flatten()
@@ -126,47 +215,86 @@ def triton_hp_tensor_to_float8_dynamic(
 
     # allocate memory for computed scale, local block maxes, and output fp8 tensor
     scale_out = torch.empty((1,), dtype=torch.float32, device=hp_tensor.device)
-    block_amaxes = torch.zeros(
-        (num_elements // BLOCK_SIZE,), dtype=torch.float32, device=hp_tensor.device
-    )
+
     fp8_output = torch.empty_like(
         flattened_input, dtype=fp8_dtype, device=hp_tensor.device
     )
 
-    # compute local amax for each block
     grid = lambda meta: (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)
-    _block_amax[grid](
-        flattened_input,
-        block_amaxes,
-        num_elements,
-        input_dtype=tl_input_dtype,
-        BLOCK_SIZE=BLOCK_SIZE,
-        EPS=EPS,
-    )
 
-    # calculate global amax across all blocks and use it to compute scale
-    _fp8_scale[(1, 1, 1)](
-        block_amaxes,
-        scale_out,
-        num_elements,
-        fp8_dtype_max,
-        BLOCK_SIZE=BLOCK_SIZE,
-        EPS=EPS,
-    )
+    if algo == KernelAlgorithm.ATOMIC_MAX:
+        global_amax = torch.zeros((1,), dtype=torch.float32, device=hp_tensor.device)
+        # compute global amax to be used for scaling
+        _block_amax_atomic[grid](
+            flattened_input,
+            global_amax,
+            num_elements,
+            input_dtype=tl_input_dtype,
+            EPS=EPS,
+        )
 
-    # perform conversion
-    _to_fp8[grid](
-        flattened_input,
-        scale_out,
-        fp8_output,
-        num_elements,
-        fp8_dtype_min,
-        fp8_dtype_max,
-        input_dtype=tl_input_dtype,
-        output_dtype=tl_output_dtype,
-        BLOCK_SIZE=BLOCK_SIZE,
-        EPS=EPS,
-    )
+        # compute scale for fp8 conversion
+        _fp8_scale_atomic[(1, 1, 1)](
+            global_amax,
+            scale_out,
+            fp8_dtype_max,
+            EPS=EPS,
+        )
+
+        # perform conversion and store scale for use in Float8Tensor
+        _to_fp8_atomic[grid](
+            flattened_input,
+            scale_out,
+            global_amax,
+            fp8_output,
+            num_elements,
+            fp8_dtype_min,
+            fp8_dtype_max,
+            input_dtype=tl_input_dtype,
+            output_dtype=tl_output_dtype,
+            EPS=EPS,
+        )
+    elif algo == KernelAlgorithm.REDUCTION:
+        max_block_size = 512
+        BLOCK_SIZE = min(max_block_size, num_elements)
+        block_amaxes = torch.zeros(
+            (triton.cdiv(num_elements, BLOCK_SIZE),), dtype=torch.float32, device=hp_tensor.device
+        )
+        # compute local amax for each block
+        _block_amax_reduction[grid](
+            flattened_input,
+            block_amaxes,
+            num_elements,
+            input_dtype=tl_input_dtype,
+            BLOCK_SIZE=BLOCK_SIZE,
+            EPS=EPS,
+        )
+
+        # calculate global amax across all blocks and use it to compute scale
+        _fp8_scale_reduction[(1, 1, 1)](
+            block_amaxes,
+            scale_out,
+            num_elements,
+            fp8_dtype_max,
+            BLOCK_SIZE=BLOCK_SIZE,
+            EPS=EPS,
+        )
+
+        # perform conversion
+        _to_fp8_reduction[grid](
+            flattened_input,
+            scale_out,
+            fp8_output,
+            num_elements,
+            fp8_dtype_min,
+            fp8_dtype_max,
+            input_dtype=tl_input_dtype,
+            output_dtype=tl_output_dtype,
+            BLOCK_SIZE=BLOCK_SIZE,
+            EPS=EPS,
+        )
+    else:
+        raise ValueError(f"Unsupported kernel algorithm: {algo}")
 
     return Float8Tensor(
         fp8_output.reshape(orig_shape),
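
For context, a minimal usage sketch of the new algo switch follows. The import locations and the bare LinearMMConfig() construction are assumptions for illustration; neither is shown in this diff.

import torch

# assumed import locations; adjust to wherever these symbols actually live
from torchao.float8.float8_tensor import LinearMMConfig
from fp8_triton_kernels import KernelAlgorithm, triton_hp_tensor_to_float8_dynamic

hp_tensor = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")

# atomic-max strategy: one global fp32 amax updated across blocks via tl.atomic_max
fp8_atomic = triton_hp_tensor_to_float8_dynamic(
    hp_tensor,
    torch.float8_e4m3fn,
    LinearMMConfig(),
    algo=KernelAlgorithm.ATOMIC_MAX,
)

# reduction strategy: per-block amaxes written to a buffer, then reduced to a global amax
fp8_reduction = triton_hp_tensor_to_float8_dynamic(
    hp_tensor,
    torch.float8_e4m3fn,
    LinearMMConfig(),
    algo=KernelAlgorithm.REDUCTION,
)

The atomic path skips the per-block amax buffer and the separate reduction launch, at the cost of contention on a single fp32 value; the reduction path trades an extra buffer and kernel launch for contention-free block-local maxes.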