     for stages in num_stages
 ]
 
-from torch.library import triton_op, wrap_triton
 
-
-@triton_op("torchao::triton_fp8_row_major_jagged_rowwise_scales", mutates_args={})
-def triton_fp8_row_major_jagged_rowwise_scales(
+@torch.library.custom_op(
+    "torchao::triton_fp8_per_group_rowwise_scales", mutates_args={}
+)
+def triton_fp8_per_group_rowwise_scales(
     hp_tensor: torch.Tensor,
     offsets: torch.Tensor,
     output_dtype: torch.dtype = torch.float8_e4m3fn,
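
For context: the hunk above swaps the `triton_op`/`wrap_triton` wrappers for the plain `torch.library.custom_op` decorator, which registers the function as an opaque op with the PyTorch dispatcher and derives its schema from the type annotations. A minimal sketch of the pattern, using a hypothetical toy op (`mylib::scale_rows` is an illustrative name, not part of this change):

```python
import torch

# Hypothetical toy op showing the torch.library.custom_op pattern.
# torch.compile treats the op as a black box, so the body is free to
# launch Triton kernels directly.
@torch.library.custom_op("mylib::scale_rows", mutates_args={})
def scale_rows(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor  # eager implementation

y = scale_rows(torch.randn(4, 8), 2.0)
```
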
@@ -95,7 +95,7 @@ def triton_fp8_row_major_jagged_rowwise_scales(
         triton.cdiv(m, meta["BLOCK_SIZE"]),
         offsets.numel(),
     )
-    wrap_triton(_triton_fp8_row_major_jagged_rowwise_scales)[grid](
+    _triton_fp8_per_group_rowwise_scales_kernel[grid](
         hp_tensor,
         offsets,
         output_buffer,
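
With `custom_op` in place, the `wrap_triton` indirection is no longer needed and the jitted kernel is launched directly using the usual grid-callable convention: the grid lambda receives the resolved meta-parameters (including autotuned ones), so the launch shape adapts to the chosen `BLOCK_SIZE`. A minimal sketch with a hypothetical toy kernel (requires a CUDA device):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _toy_copy_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide slice.
    offs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)

x = torch.randn(1000, device="cuda")
out = torch.empty_like(x)
# The grid callable is evaluated with the kernel's meta-parameters, so the
# number of program instances follows BLOCK_SIZE.
grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
_toy_copy_kernel[grid](x, out, x.numel(), BLOCK_SIZE=256)
```
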
@@ -117,6 +117,24 @@ def triton_fp8_row_major_jagged_rowwise_scales(
     return output_buffer, scales_buffer
 
 
+@triton_fp8_per_group_rowwise_scales.register_fake
+def _fake_triton_fp8_per_group_rowwise_scales_kernel(
+    hp_tensor: torch.Tensor,
+    offsets: torch.Tensor,
+    output_dtype: torch.dtype = torch.float8_e4m3fn,
+    round_scales_to_power_of_2: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    assert hp_tensor.ndim == 2, "input tensor must be 2D"
+    m, k = hp_tensor.shape
+    n_groups = offsets.numel()
+    output = torch.empty_like(hp_tensor, dtype=output_dtype).as_strided(
+        (m, k),  # shape
+        (k, 1),  # stride
+    )
+    scales = torch.empty((m * n_groups), dtype=torch.float32, device=hp_tensor.device)
+    return output, scales
+
+
 # This kernel is used on grad_output.t() which has shape (K, M),
 # before the calculation `grad_B = grad_output_t @ input`.
 # However, in this code, we use the conventional dim names (M, K)
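
The `register_fake` block added above is what keeps the new op traceable: during `torch.compile`, the fake implementation runs on fake/meta tensors and must reproduce the real op's output shapes, strides, dtypes, and devices (hence the `as_strided` call mirroring the kernel's output layout) without touching data or launching the kernel. A minimal sketch of the pattern with a hypothetical toy op:

```python
import torch

@torch.library.custom_op("mylib::toy_cast_fp8", mutates_args={})
def toy_cast_fp8(x: torch.Tensor) -> torch.Tensor:
    return x.to(torch.float8_e4m3fn)  # real eager implementation

@toy_cast_fp8.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Same shape/dtype/device contract as the real op; no data, no kernel.
    return torch.empty_like(x, dtype=torch.float8_e4m3fn)
```
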
@@ -125,7 +143,7 @@ def triton_fp8_row_major_jagged_rowwise_scales(
 # to recompile on `token` dim (K, in this case) changes.
 @triton.autotune(configs=kernel_configs_2D, key=["M"])
 @triton.jit
-def _triton_fp8_row_major_jagged_rowwise_scales(
+def _triton_fp8_per_group_rowwise_scales_kernel(
     input_ptr,
     offsets_ptr,
     out_ptr,
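
The autotune keying in the context above is a deliberate trade-off: with `key=["M"]`, tuning results are cached per value of `M` only, so the tokens-per-expert dim (which changes every iteration) reuses the already-tuned config instead of triggering a retune. A minimal sketch of the pattern (toy configs and kernel; the real `kernel_configs_2D` is defined earlier in this file):

```python
import triton
import triton.language as tl

# Hypothetical config sweep standing in for kernel_configs_2D.
toy_configs = [
    triton.Config({"BLOCK_SIZE": bs}, num_warps=w)
    for bs in (128, 256)
    for w in (4, 8)
]

# key=["M"]: re-benchmark configs only when M changes; changes in the
# variable token dim reuse the cached choice.
@triton.autotune(configs=toy_configs, key=["M"])
@triton.jit
def _toy_kernel(x_ptr, out_ptr, M, K, BLOCK_SIZE: tl.constexpr):
    offs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < M * K
    tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)
```
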
@@ -215,8 +233,10 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
     tl.store(out_ptr + out_offs, fp8_data, mask=block_mask)
 
 
-@triton_op("torchao::triton_fp8_col_major_jagged_colwise_scales", mutates_args={})
-def triton_fp8_col_major_jagged_colwise_scales(
+@torch.library.custom_op(
+    "torchao::triton_fp8_per_group_colwise_scales", mutates_args={}
+)
+def triton_fp8_per_group_colwise_scales(
     hp_tensor: torch.Tensor,
     offsets: torch.Tensor,
     output_dtype: torch.dtype = torch.float8_e4m3fn,
@@ -263,7 +283,7 @@ def triton_fp8_col_major_jagged_colwise_scales(
         triton.cdiv(n, meta["BLOCK_SIZE"]),
         offsets.numel(),
     )
-    wrap_triton(_triton_fp8_col_major_jagged_colwise_scales)[grid](
+    _triton_fp8_per_group_colwise_scales_kernel[grid](
         hp_tensor,
         offsets,
         output_buffer,
@@ -285,13 +305,33 @@ def triton_fp8_col_major_jagged_colwise_scales(
     return output_buffer, scales_buffer
 
 
+@triton_fp8_per_group_colwise_scales.register_fake
+def _fake_triton_fp8_per_group_colwise_scales(
+    hp_tensor: torch.Tensor,
+    offsets: torch.Tensor,
+    output_dtype: torch.dtype = torch.float8_e4m3fn,
+    round_scales_to_power_of_2: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    assert hp_tensor.ndim == 2, "input tensor must be 2D"
+    k, n = hp_tensor.shape
+    n_groups = offsets.numel()
+    output_buffer = torch.empty_like(
+        hp_tensor, dtype=output_dtype, device=hp_tensor.device
+    ).as_strided(hp_tensor.size(), (1, k))
+
+    scales_buffer = torch.empty(
+        (n * n_groups), dtype=torch.float32, device=hp_tensor.device
+    )
+    return output_buffer, scales_buffer
+
+
 # This kernel is used on `input` which has shape (M, K),
 # before the calculation `grad_B = grad_output_t @ input`.
 # The tokens per expert will vary per iteration, so don't want
 # to recompile on `token` dim (M) changes.
 @triton.autotune(configs=kernel_configs_2D, key=["K"])
 @triton.jit
-def _triton_fp8_col_major_jagged_colwise_scales(
+def _triton_fp8_per_group_colwise_scales_kernel(
     input_ptr,
     offsets_ptr,
     out_ptr,
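
With fake impls registered for both ops, calls should now trace cleanly under `torch.compile`. A hypothetical smoke test; the CUDA device, dtypes, and the end-offset semantics of `offsets` are assumptions, not taken from this PR:

```python
import torch

x = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
# Assumed semantics: cumulative end indices of each token group.
offsets = torch.tensor([64, 192, 256], device="cuda", dtype=torch.int32)

@torch.compile
def quantize_rowwise(x, offsets):
    return torch.ops.torchao.triton_fp8_per_group_rowwise_scales(
        x, offsets, torch.float8_e4m3fn, False
    )

fp8_data, scales = quantize_rowwise(x, offsets)  # no graph breaks expected
```
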