@@ -1038,6 +1038,175 @@ def to_mxfp8_dim1_kernel(
     # TODO(future): mask this store
     tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)
 
+@triton.autotune(
+    configs=_get_mxfp8_dim1_kernel_autotune_configs(),
+    key=["n_rows", "n_cols", "INNER_BLOCK_SIZE"],
+)
+@triton.jit
+def to_mxfp8_dim0_kernel(
+    x_ptr,  # pointer to input tensor
+    output_ptr,  # pointer to output tensor (row-normalized)
+    row_scale_ptr,  # pointer to store row-wise maximum absolute values
+    n_rows,  # number of rows in the tensor
+    n_cols,  # number of columns in the tensor
+    ROW_TILE_SIZE: tl.constexpr,
+    COL_TILE_SIZE: tl.constexpr,
+    INNER_BLOCK_SIZE: tl.constexpr,  # should be 32 for MX
+):
+    """
+    Quantizes a high precision tensor to mxfp8 rowwise (1x32 scaling granularity).
+
+    This is the counterpart to to_mxfp8_dim1_kernel, which does columnwise
+    quantization. Instead of transposing and scaling across columns, this
+    kernel scales across rows.
+    """
+
+    BLOCKS_PER_COL_TILE: tl.constexpr = COL_TILE_SIZE // INNER_BLOCK_SIZE
+
+    # Get program IDs
+    pid_row = tl.program_id(0)
+    pid_col = tl.program_id(1)
+
+    # Calculate the starting row and column for this tile
+    start_row = pid_row * ROW_TILE_SIZE
+    start_col = pid_col * COL_TILE_SIZE
+
+    # Create offsets for the block
+    row_offsets = tl.arange(0, ROW_TILE_SIZE)
+    col_offsets = tl.arange(0, COL_TILE_SIZE)
+
+    # Compute global row/col positions
+    rows = start_row + row_offsets[:, None]
+    cols = start_col + col_offsets[None, :]
+
+    # Create masks for out-of-bounds accesses
+    row_mask = rows < n_rows
+    col_mask = cols < n_cols
+    mask = row_mask & col_mask
+
+    # Compute memory offsets for row-major layout (rows, cols)
+    row_major_offsets = (rows * n_cols + cols).to(tl.int32)
+
+    # Load the entire block in a single operation
+    # shape: (ROW_TILE_SIZE, COL_TILE_SIZE)
+    x_block = tl.load(x_ptr + row_major_offsets, mask=mask)
+
+    # Reshape to the inner block size for rowwise scaling
+    # shape: (ROW_TILE_SIZE, COL_TILE_SIZE) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
+    x_block_r = x_block.reshape(
+        ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE
+    )
+
+    # Calculate the absolute values of elements in the block
+    x_block_abs_r = tl.abs(x_block_r)
+
+    # Find the maximum absolute value for each inner block (across columns)
+    # shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
+    row_scale_r, row_scale_e8m0_r = _triton_calculate_scale(x_block_abs_r, axis=1)
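+    # (_triton_calculate_scale returns both the scale used for division below
+    # and its e8m0 encoding, which is what ultimately gets stored to
+    # row_scale_ptr)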
+
+    # Divide each inner block by its scale, broadcasting row_scale_r to match
+    # x_block_r's shape
+    # x_block_r shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
+    # row_scale_r shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, 1)
+    row_normalized_r = x_block_r / row_scale_r[:, None]
+
+    # Reshape back to the original tile size
+    row_normalized = tl.reshape(row_normalized_r, ROW_TILE_SIZE, COL_TILE_SIZE)
+
+    # Quantize to float8
+    row_normalized = row_normalized.to(tl.float8e4nv)
+
+    # Store the row-normalized result in row-major format
+    tl.store(output_ptr + row_major_offsets, row_normalized, mask=mask)
+
+    # Reshape row_scale_e8m0_r for storage
+    # shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
+    row_scale_e8m0 = row_scale_e8m0_r.reshape(ROW_TILE_SIZE * BLOCKS_PER_COL_TILE)
+
+    row_scale_start_offsets = (
+        (pid_row * ROW_TILE_SIZE * (n_cols // COL_TILE_SIZE))
+        * BLOCKS_PER_COL_TILE  # number of inner blocks in all rows above this tile
+        + pid_col * BLOCKS_PER_COL_TILE  # block-column offset within the row
+    )
+
+    row_scale_start_ptr = row_scale_ptr + row_scale_start_offsets
+
+    # Calculate row_scale_indices
+    row_scale_indices = tl.arange(0, ROW_TILE_SIZE * BLOCKS_PER_COL_TILE)
+
+    # Number of scale values belonging to the other column tiles in each row;
+    # we need to jump over them after every BLOCKS_PER_COL_TILE values
+    jump_vals_per_row = (n_cols - COL_TILE_SIZE) // INNER_BLOCK_SIZE
+
+    # example transformation (specifics depend on tile sizes):
+    # [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 1, 4, 5, 8, 9, 12, 13]
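+    # (the example above corresponds to BLOCKS_PER_COL_TILE = 2 and
+    # jump_vals_per_row = 2, e.g. COL_TILE_SIZE = 64, INNER_BLOCK_SIZE = 32,
+    # n_cols = 128)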
+    row_scale_indices = row_scale_indices + (
+        (row_scale_indices // BLOCKS_PER_COL_TILE) * jump_vals_per_row
+    )
+
+    # Store the scales
+    tl.store(row_scale_start_ptr + row_scale_indices, row_scale_e8m0)
+
+
+@triton_op("torchao::triton_to_mxfp8_dim0", mutates_args={})
+def triton_to_mxfp8_dim0(
+    x: torch.Tensor, inner_block_size: int = 32
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Input:
+    * `x` - input tensor, in row-major memory layout
+    * `inner_block_size` - size of tiles to scale across, default is 32 for MX recipes
+
+    Output:
+    * `output` - the `float8_e4m3fn` values of `x` cast to mxfp8 across dim0 (rowwise)
+    * `row_scale` - the `e8m0` scale values used to cast `x` to mxfp8 across dim0
+    """
+    assert x.is_contiguous(), "`x` must be contiguous"
+    assert inner_block_size <= 32
+
+    # Get tensor shape
+    n_rows, n_cols = x.shape
+
+    # Masking of loads and stores is not well tested yet, so for now enforce
+    # shapes which do not need masking. Note that this condition depends on
+    # the max values of ROW_TILE_SIZE and COL_TILE_SIZE, which are autotuned
+    # above.
+    # TODO(future): implement and test masking and remove this restriction
+    max_row_tile_size = 128
+    max_col_tile_size = 128
+    assert n_rows % max_row_tile_size == 0, "unsupported"
+    assert n_cols % max_col_tile_size == 0, "unsupported"
+
+    # Create the output tensor
+    output = torch.empty(
+        (n_rows, n_cols), dtype=torch.float8_e4m3fn, device=x.device
+    )
+
+    # Create the scale tensor for rowwise scaling
+    row_scale = torch.empty(
+        (n_rows, n_cols // inner_block_size, 1),
+        dtype=torch.uint8,
+        device=x.device,
+    )
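+    # Note: the scales are allocated as uint8 here and reinterpreted as
+    # float8_e8m0fnu in the return statement below.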
+
+    # Calculate grid dimensions based on tile size
+    grid = lambda META: (
+        triton.cdiv(n_rows, META["ROW_TILE_SIZE"]),
+        triton.cdiv(n_cols, META["COL_TILE_SIZE"]),
+    )
+
+    # Launch the kernel
+    wrap_triton(to_mxfp8_dim0_kernel)[grid](
+        x_ptr=x,
+        output_ptr=output,
+        row_scale_ptr=row_scale,
+        n_rows=n_rows,
+        n_cols=n_cols,
+        INNER_BLOCK_SIZE=inner_block_size,
+    )
+
+    return (
+        output,
+        row_scale.view(torch.float8_e8m0fnu),
+    )
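+
+# Example usage of triton_to_mxfp8_dim0 (an illustrative sketch; shapes are
+# chosen to satisfy the divisibility asserts above, and a CUDA device is
+# assumed):
+#
+#   x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
+#   data, scales = triton_to_mxfp8_dim0(x)
+#   # data has shape (128, 256) and dtype torch.float8_e4m3fn
+#   # scales has shape (128, 8, 1) and dtype torch.float8_e8m0fnu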
+
 @triton_op("torchao::triton_to_mxfp8_dim1", mutates_args={})
 def triton_to_mxfp8_dim1(
     x: torch.Tensor, inner_block_size: int = 32