@@ -829,6 +829,8 @@ def _(uint8_data):
    import triton.language as tl
    from torch.library import triton_op, wrap_triton

+    print("importing triton ops")
+
    @triton.jit
    def _triton_calculate_scale(x, axis):
        # There is no good support for accessing globals from a jit'ed triton
@@ -891,13 +893,13 @@ def _get_mxfp8_dim1_kernel_autotune_configs():

    @triton.autotune(
        configs=_get_mxfp8_dim1_kernel_autotune_configs(),
-        key=["n_rows", "n_cols", "INNER_BLOCK_SIZE"],
+        key=["n_cols", "INNER_BLOCK_SIZE"],
    )
    @triton.jit
    def to_mxfp8_dim1_kernel(
        x_ptr,  # pointer to input tensor
        output_col_major_ptr,  # pointer to column-major output tensor (column-normalized)
-        col_scale_ptr,  # pointer to store column-wise maximum absolute values
+        col_scale_ptr,  # pointer to store scales
        n_rows,  # number of rows in the tensor
        n_cols,  # number of columns in the tensor
        ROW_TILE_SIZE: tl.constexpr,
@@ -1038,6 +1040,175 @@ def to_mxfp8_dim1_kernel(
        # TODO(future): mask this store
        tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)

+    @triton.autotune(
+        configs=_get_mxfp8_dim1_kernel_autotune_configs(),
+        key=["n_rows", "n_cols", "INNER_BLOCK_SIZE"],
+    )
+    @triton.jit
+    def to_mxfp8_dim0_kernel(
+        x_ptr,
+        output_ptr,
+        row_scale_ptr,
+        n_rows,
+        n_cols,
+        ROW_TILE_SIZE: tl.constexpr,
+        COL_TILE_SIZE: tl.constexpr,
+        INNER_BLOCK_SIZE: tl.constexpr,  # should be 32 for MX
+    ):
+        """
+        Quantizes a high-precision tensor to mxfp8 rowwise (1x32 scaling granularity).
+
+        This is the counterpart to to_mxfp8_dim1_kernel, which does columnwise
+        quantization. Instead of transposing and scaling across columns, this
+        kernel scales across rows.
+        """
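+        # Illustrative note (added example, not part of the original kernel):
+        # with INNER_BLOCK_SIZE=32, each row of a (128, 128) input is split
+        # into 128 // 32 = 4 inner blocks, each of which gets its own e8m0
+        # scale, for 128 * 4 = 512 scales in total.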
+
+        BLOCKS_PER_COL_TILE: tl.constexpr = COL_TILE_SIZE // INNER_BLOCK_SIZE
+
+        # Get program ID
+        pid_row = tl.program_id(0)
+        pid_col = tl.program_id(1)
+
+        # Calculate starting row and column for this tile
+        start_row = pid_row * ROW_TILE_SIZE
+        start_col = pid_col * COL_TILE_SIZE
+
+        # Create offsets for the block
+        row_offsets = tl.arange(0, ROW_TILE_SIZE)
+        col_offsets = tl.arange(0, COL_TILE_SIZE)
+
+        # Compute global row/col positions
+        rows = start_row + row_offsets[:, None]
+        cols = start_col + col_offsets[None, :]
+
+        # Create masks for out-of-bounds accesses
+        row_mask = rows < n_rows
+        col_mask = cols < n_cols
+        mask = row_mask & col_mask
+
+        # Compute memory offsets for row-major layout (rows, cols)
+        row_major_offsets = (rows * n_cols + cols).to(tl.int32)
+
+        # Load the entire block in a single operation
+        # shape: (ROW_TILE_SIZE, COL_TILE_SIZE)
+        x_block = tl.load(x_ptr + row_major_offsets, mask=mask)
+
+        # Reshape to inner tile size for rowwise scaling
+        # shape: (ROW_TILE_SIZE, COL_TILE_SIZE) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
+        x_block_r = x_block.reshape(
+            ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE
+        )
+
+        # Calculate the absolute values of elements in the block
+        x_block_abs_r = tl.abs(x_block_r)
+
+        # Find the maximum absolute value for each row (across columns)
+        # shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
+        row_scale_r, row_scale_e8m0_r = _triton_calculate_scale(x_block_abs_r, axis=1)
+
+        # Divide each inner block by its scale
+        # Broadcasting row_scale_r to match x_block_r's shape
+        # x_block_r shape (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
+        # row_scale_r shape (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, 1)
+        row_normalized_r = x_block_r / row_scale_r[:, None]
+
+        # Reshape back to original tile size
+        row_normalized = tl.reshape(row_normalized_r, ROW_TILE_SIZE, COL_TILE_SIZE)
+
+        # Quantize to float8
+        row_normalized = row_normalized.to(tl.float8e4nv)
+
+        # Store the row-normalized result in row-major format
+        tl.store(output_ptr + row_major_offsets, row_normalized, mask=mask)
+
+        # reshape row_scale_e8m0_r for proper storage
+        # shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
+        row_scale_e8m0 = row_scale_e8m0_r.reshape(ROW_TILE_SIZE * BLOCKS_PER_COL_TILE)
+
+        row_scale_start_offsets = (
+            (pid_row * ROW_TILE_SIZE * (n_cols // COL_TILE_SIZE))
+            * BLOCKS_PER_COL_TILE  # scales in all rows above this tile
+            + pid_col * BLOCKS_PER_COL_TILE  # scales to the left of this tile in its first row
+        )
+
+        row_scale_start_ptr = row_scale_ptr + row_scale_start_offsets
+
+        # calculate row_scale_indices
+        row_scale_indices = tl.arange(0, ROW_TILE_SIZE * BLOCKS_PER_COL_TILE)
+
+        # Each row contributes BLOCKS_PER_COL_TILE contiguous scale values; the
+        # scales owned by the row's other column tiles must be jumped over
+        # after every BLOCKS_PER_COL_TILE values
+        jump_vals_per_row = (n_cols - COL_TILE_SIZE) // INNER_BLOCK_SIZE
+
+        # example transformation (specifics depend on tile sizes):
+        # [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 1, 4, 5, 8, 9, 12, 13]
+        row_scale_indices = row_scale_indices + (
+            (row_scale_indices // BLOCKS_PER_COL_TILE) * jump_vals_per_row
+        )
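+        # Concrete instance of the example above (assumed values, for
+        # illustration only): with COL_TILE_SIZE=64, INNER_BLOCK_SIZE=32, and
+        # n_cols=128, we get BLOCKS_PER_COL_TILE = 64 // 32 = 2 and
+        # jump_vals_per_row = (128 - 64) // 32 = 2, which reproduces exactly
+        # the transformation shown above.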
+
+        # Store the scales
+        tl.store(row_scale_start_ptr + row_scale_indices, row_scale_e8m0)
+
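+    # A rough pure-PyTorch reference for the rowwise scaling above, left here
+    # as an illustrative sketch only. It assumes e8m0 rounding via
+    # ceil(log2(amax / 448.0)) with 448.0 as the float8_e4m3fn max, and may
+    # differ from _triton_calculate_scale in edge cases such as all-zero blocks:
+    #
+    #   def mxfp8_dim0_reference(x, block_size=32):
+    #       n_rows, n_cols = x.shape
+    #       x_r = x.reshape(n_rows, n_cols // block_size, block_size).float()
+    #       amax = x_r.abs().amax(dim=-1, keepdim=True)
+    #       exp = torch.clamp(torch.ceil(torch.log2(amax / 448.0)), -127, 127)
+    #       scale = torch.exp2(exp)
+    #       out = (x_r / scale).to(torch.float8_e4m3fn)
+    #       return out.reshape(n_rows, n_cols), scale
+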
+    @triton_op("torchao::triton_to_mxfp8_dim0", mutates_args={})
+    def triton_to_mxfp8_dim0(
+        x: torch.Tensor, inner_block_size: int = 32
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Input:
+        * `x` - input tensor, in row major memory layout
+        * `inner_block_size` - size of tiles to scale across, default is 32 for MX recipes
+
+        Output:
+        * `output`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim0 (rowwise)
+        * `row_scale`: the `e8m0` scales used to cast `x` to mxfp8 across dim0
+        """
+        assert x.is_contiguous(), "`x` must be contiguous"
+        assert inner_block_size <= 32
+
+        # Get tensor shape
+        n_rows, n_cols = x.shape
+
+        # Masking of loads and stores is not well tested yet, so for now enforce
+        # shapes which do not need masking. Note that this condition depends on
+        # the max values of ROW_TILE_SIZE and COL_TILE_SIZE, which are autotuned
+        # above.
+        # TODO(future): implement and test masking and remove this restriction
+        max_row_tile_size = 128
+        max_col_tile_size = 128
+        assert n_rows % max_row_tile_size == 0, "n_rows must be a multiple of 128"
+        assert n_cols % max_col_tile_size == 0, "n_cols must be a multiple of 128"
+
+        # Create output tensors
+        output = torch.empty(
+            (n_rows, n_cols), dtype=torch.float8_e4m3fn, device=x.device
+        )
+
+        # Create scale tensors for rowwise scaling
+        row_scale = torch.empty(
+            (n_rows, n_cols // inner_block_size, 1),
+            dtype=torch.uint8,
+            device=x.device,
+        )
+
+        # Calculate grid dimensions based on tile size
+        grid = lambda META: (
+            triton.cdiv(n_rows, META["ROW_TILE_SIZE"]),
+            triton.cdiv(n_cols, META["COL_TILE_SIZE"]),
+        )
+
+        # Launch the kernel
+        wrap_triton(to_mxfp8_dim0_kernel)[grid](
+            x_ptr=x,
+            output_ptr=output,
+            row_scale_ptr=row_scale,
+            n_rows=n_rows,
+            n_cols=n_cols,
+            INNER_BLOCK_SIZE=inner_block_size,
+        )
+
+        return (
+            output,
+            row_scale.view(torch.float8_e8m0fnu),
+        )
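+
+    # Example usage (an illustrative sketch, assuming a CUDA device; the shape
+    # below satisfies the divisibility-by-128 checks above):
+    #
+    #   x = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
+    #   data, scales = torch.ops.torchao.triton_to_mxfp8_dim0(x)
+    #   # data:   (256, 512) torch.float8_e4m3fn, row major
+    #   # scales: (256, 16, 1) torch.float8_e8m0fnu, one scale per 1x32 block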
+
    @triton_op("torchao::triton_to_mxfp8_dim1", mutates_args={})
    def triton_to_mxfp8_dim1(
        x: torch.Tensor, inner_block_size: int = 32
@@ -1459,6 +1630,12 @@ def _(scale_tensor):
        return scale_tensor.new_empty((padded_rows, padded_cols))
else:

+    def triton_to_mxfp8_dim0(
+        x: torch.Tensor,
+        inner_block_size=32,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise AssertionError("needs torch version 2.8+ and triton")
+
    def triton_to_mxfp8_dim1(
        x, inner_block_size=32
    ) -> Tuple[torch.Tensor, torch.Tensor]: