Skip to content

Commit 1e78d33

Browse files
[mlir][arith][nfc] Adding examples to scaling_extf/truncf descriptions (#163980)
Signed-off-by: Muzammiluddin Syed <muzasyed@amd.com>
1 parent a99e32b commit 1e78d33

File tree

1 file changed

+60
-34
lines changed

1 file changed

+60
-34
lines changed

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 60 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,37 +1229,50 @@ def Arith_ScalingExtFOp
12291229
let summary = "Upcasts input floats using provided scales values following "
12301230
"OCP MXFP Spec";
12311231
let description = [{
1232-
This operation upcasts input floating-point values using provided scale
1233-
values. It expects both scales and the input operand to be of the same shape,
1234-
making the operation elementwise. Scales are usually calculated per block
1235-
following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
1236-
1237-
If scales are calculated per block where blockSize != 1, then scales may
1238-
require broadcasting to make this operation elementwise. For example, let's
1239-
say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
1240-
assuming quantization happens on the last axis, the input can be reshaped to
1241-
`<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
1242-
per block on the last axis. Therefore, scales will be of shape
1243-
`<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
1244-
shape as long as it is broadcast compatible with the input, e.g.,
1245-
`<1 x 1 x ... (dimN/blockSize) x 1>`.
1246-
1247-
In this example, before calling into `arith.scaling_extf`, scales must be
1248-
broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
1249-
that there could be multiple quantization axes. Internally,
1250-
`arith.scaling_extf` would perform the following:
1232+
This operation upcasts input floating-point values using provided scale
1233+
values. It expects both scales and the input operand to be of the same shape,
1234+
making the operation elementwise. Scales are usually calculated per block
1235+
following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
12511236

1252-
```
1253-
resultTy = get_type(result)
1254-
scaleTy = get_type(scale)
1255-
inputTy = get_type(input)
1256-
scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
1257-
scale.extf = arith.extf(scale.exponent) : f8E8M0 to resultTy
1258-
input.extf = arith.extf(input) : inputTy to resultTy
1259-
result = arith.mulf(scale.extf, input.extf)
1237+
If scales are calculated per block where blockSize != 1, then scales may
1238+
require broadcasting to make this operation elementwise. For example, let's
1239+
say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
1240+
assuming quantization happens on the last axis, the input can be reshaped to
1241+
`<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
1242+
per block on the last axis. Therefore, scales will be of shape
1243+
`<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
1244+
shape as long as it is broadcast compatible with the input, e.g.,
1245+
`<1 x 1 x ... (dimN/blockSize) x 1>`.
1246+
1247+
In this example, before calling into `arith.scaling_extf`, scales must be
1248+
broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
1249+
that there could be multiple quantization axes. Internally,
1250+
`arith.scaling_extf` would perform the following:
1251+
1252+
```mlir
1253+
// Cast scale to result type.
1254+
%0 = arith.truncf %1 : f32 to f8E8M0FNU
1255+
%1 = arith.extf %0 : f8E8M0FNU to f16
1256+
1257+
// Cast input to result type.
1258+
%2 = arith.extf %3 : f4E2M1FN to f16
1259+
1260+
// Perform scaling
1261+
%3 = arith.mulf %2, %1 : f16
12601262
```
12611263
It propagates NaN values. Therefore, if either scale or the input element
12621264
contains NaN, then the output element value will also be a NaN.
1265+
1266+
Example:
1267+
1268+
```mlir
1269+
// Upcast from f4E2M1FN to f32.
1270+
%a = arith.scaling_extf %b, %c : f4E2M1FN, f8E8M0FNU to f32
1271+
1272+
// Element-wise upcast with broadcast (blockSize = 32).
1273+
%f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU>
1274+
%h = arith.scaling_extf %i, %f : vector<32xf4E2M1FN>, vector<32xf8E8M0FNU> to vector<32xbf16>
1275+
```
12631276
}];
12641277
let hasVerifier = 1;
12651278
let assemblyFormat =
@@ -1397,14 +1410,27 @@ def Arith_ScalingTruncFOp
13971410
that there could be multiple quantization axes. Internally,
13981411
`arith.scaling_truncf` would perform the following:
13991412

1413+
```mlir
1414+
// Cast scale to input type.
1415+
%0 = arith.truncf %1 : f32 to f8E8M0FNU
1416+
%1 = arith.extf %0 : f8E8M0FNU to f16
1417+
1418+
// Perform scaling.
1419+
%3 = arith.divf %2, %1 : f16
1420+
1421+
// Cast to result type.
1422+
%4 = arith.truncf %3 : f16 to f4E2M1FN
14001423
```
1401-
scaleTy = get_type(scale)
1402-
inputTy = get_type(input)
1403-
resultTy = get_type(result)
1404-
scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
1405-
scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputTy
1406-
result = arith.divf(input, scale.extf)
1407-
result.cast = arith.truncf(result, resultTy)
1424+
1425+
Example:
1426+
1427+
```mlir
1428+
// Downcast from f32 to f4E2M1FN.
1429+
%a = arith.scaling_truncf %b, %c : f32, f8E8M0FNU to f4E2M1FN
1430+
1431+
// Element-wise downcast with broadcast (blockSize = 32).
1432+
%f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU>
1433+
%h = arith.scaling_truncf %i, %f : vector<32xbf16>, vector<32xf8E8M0FNU> to vector<32xf4E2M1FN>
14081434
```
14091435
}];
14101436
let hasVerifier = 1;

0 commit comments

Comments
 (0)