@@ -1229,37 +1229,50 @@ def Arith_ScalingExtFOp
12291229 let summary = "Upcasts input floats using provided scales values following "
12301230 "OCP MXFP Spec";
12311231 let description = [{
1232- This operation upcasts input floating-point values using provided scale
1233- values. It expects both scales and the input operand to be of the same shape,
1234- making the operation elementwise. Scales are usually calculated per block
1235- following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
1236-
1237- If scales are calculated per block where blockSize != 1, then scales may
1238- require broadcasting to make this operation elementwise. For example, let's
1239- say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
1240- assuming quantization happens on the last axis, the input can be reshaped to
1241- `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
1242- per block on the last axis. Therefore, scales will be of shape
1243- `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
1244- shape as long as it is broadcast compatible with the input, e.g.,
1245- `<1 x 1 x ... (dimN/blockSize) x 1>`.
1246-
1247- In this example, before calling into `arith.scaling_extf`, scales must be
1248- broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
1249- that there could be multiple quantization axes. Internally,
1250- `arith.scaling_extf` would perform the following:
1232+ This operation upcasts input floating-point values using provided scale
1233+ values. It expects both scales and the input operand to be of the same shape,
1234+ making the operation elementwise. Scales are usually calculated per block
1235+ following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
12511236
1252- ```
1253- resultTy = get_type(result)
1254- scaleTy = get_type(scale)
1255- inputTy = get_type(input)
1256- scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
1257- scale.extf = arith.extf(scale.exponent) : f8E8M0 to resultTy
1258- input.extf = arith.extf(input) : inputTy to resultTy
1259- result = arith.mulf(scale.extf, input.extf)
1237+ If scales are calculated per block where blockSize != 1, then scales may
1238+ require broadcasting to make this operation elementwise. For example, let's
1239+ say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
1240+ assuming quantization happens on the last axis, the input can be reshaped to
1241+ `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
1242+ per block on the last axis. Therefore, scales will be of shape
1243+ `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
1244+ shape as long as it is broadcast compatible with the input, e.g.,
1245+ `<1 x 1 x ... (dimN/blockSize) x 1>`.
1246+
1247+ In this example, before calling into `arith.scaling_extf`, scales must be
1248+ broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
1249+ that there could be multiple quantization axes. Internally,
1250+ `arith.scaling_extf` would perform the following:
1251+
1252+ ```mlir
1253+ // Cast scale to result type.
1254+ %0 = arith.truncf %1 : f32 to f8E8M0FNU
1255+ %1 = arith.extf %0 : f8E8M0FNU to f16
1256+
1257+ // Cast input to result type.
1258+ %2 = arith.extf %3 : f4E2M1FN to f16
1259+
1260+ // Perform scaling
1261+ %3 = arith.mulf %2, %1 : f16
12601262 ```
12611263 It propagates NaN values. Therefore, if either scale or the input element
12621264 contains NaN, then the output element value will also be a NaN.
1265+
1266+ Example:
1267+
1268+ ```mlir
1269+ // Upcast from f4E2M1FN to f32.
1270+ %a = arith.scaling_extf %b, %c : f4E2M1FN, f8E8M0FNU to f32
1271+
1272+ // Element-wise upcast with broadcast (blockSize = 32).
1273+ %f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU>
1274+ %h = arith.scaling_extf %i, %f : vector<32xf4E2M1FN>, vector<32xf8E8M0FNU> to vector<32xbf16>
1275+ ```
12631276 }];
12641277 let hasVerifier = 1;
12651278 let assemblyFormat =
@@ -1397,14 +1410,27 @@ def Arith_ScalingTruncFOp
13971410 that there could be multiple quantization axes. Internally,
13981411 `arith.scaling_truncf` would perform the following:
13991412
1413+ ```mlir
1414+ // Cast scale to input type.
1415+ %0 = arith.truncf %1 : f32 to f8E8M0FNU
1416+ %1 = arith.extf %0 : f8E8M0FNU to f16
1417+
1418+ // Perform scaling.
1419+ %3 = arith.divf %2, %1 : f16
1420+
1421+ // Cast to result type.
1422+ %4 = arith.truncf %3 : f16 to f4E2M1FN
14001423 ```
1401- scaleTy = get_type(scale)
1402- inputTy = get_type(input)
1403- resultTy = get_type(result)
1404- scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
1405- scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputTy
1406- result = arith.divf(input, scale.extf)
1407- result.cast = arith.truncf(result, resultTy)
1424+
1425+ Example:
1426+
1427+ ```mlir
1428+ // Downcast from f32 to f4E2M1FN.
1429+ %a = arith.scaling_truncf %b, %c : f32, f8E8M0FNU to f4E2M1FN
1430+
1431+ // Element-wise downcast with broadcast (blockSize = 32).
1432+ %f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU>
1433+ %h = arith.scaling_truncf %i, %f : vector<32xbf16>, vector<32xf8E8M0FNU> to vector<32xf4E2M1FN>
14081434 ```
14091435 }];
14101436 let hasVerifier = 1;
0 commit comments