[onnx] add eps_outside_sqrt for l2_norm
Connor-XY committed Mar 21, 2024
1 parent 30768c7 commit 687b65d
Showing 4 changed files with 80 additions and 9 deletions.
19 changes: 10 additions & 9 deletions compiler/doc/byteir_mhlo_custom_call.md
@@ -2,15 +2,15 @@

ByteIR compiler introduces several coarse-grained ops to improve pattern-matching rewriting during compilation.

ByteIR re-uses the mhlo custom call op definition with a ByteIR prefix in `call_target_name`,
instead of defining another new dialect.

ByteIR implements this conversion in the frontends, instead of putting it in the ByteIR compiler.

## Rationales
### Need for coarse-grained ops

Introducing coarse-grained ops provides several benefits:
* it simplifies pattern-matching during rewriting, whether for optimization or lowering;
* it allows high-level information to be encoded with coarse-grained ops, helping optimization;
* it provides an intuitive mapping from frontends to IR, helping debuggability;
@@ -33,14 +33,14 @@ Implementing coarse-grained op conversion in frontends can provide several benefits as follows,

## Additional op definition

A coarse-grained op kind is defined through `call_target_name` with a prefix.

```call_target_name = "byteir.softmax" or "tf.DynamicPartition"```

If an op is generic across frontends, which is the common case, it uses a `byteir` prefix.
If an op is frontend-specific, it uses a frontend-specific prefix, such as `tf` or `pytorch`.

Further information needed for a given coarse-grained op is encoded in a dictionary attribute, called `byteir_attrs`, which includes all named attributes; a construction sketch follows the attribute example below.

**Op Attribute**:
* ```byteir_attrs = {approximate = "none"}```, or ```byteir_attrs = {}``` if there are no attributes
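
As an illustration only (not taken from this document or commit), a frontend rewrite might attach such a dictionary with standard MLIR builder APIs; the helper name and attribute values below are assumptions, and they simply mirror the `byteir.l2_norm` attributes defined later:

```c++
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"

// Sketch: decorate an already-created coarse-grained custom call with the
// `byteir_attrs` dictionary described above. Names and values are illustrative.
void attachByteirAttrs(mlir::OpBuilder &builder, mlir::Operation *customCallOp) {
  llvm::SmallVector<mlir::NamedAttribute> namedAttrs;
  namedAttrs.push_back(
      builder.getNamedAttr("axis", builder.getI64ArrayAttr({1})));
  namedAttrs.push_back(
      builder.getNamedAttr("eps_outside_sqrt", builder.getF64FloatAttr(1e-12)));
  customCallOp->setAttr("byteir_attrs", builder.getDictionaryAttr(namedAttrs));
}
```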
@@ -56,7 +56,7 @@ Further needed infomation for a given coarse-grained op are encoded in a diction
  - axis: I64ArrayAttr
  - eps_outside_sqrt: Optional\<BoolAttr> (see the note below)
- Results(1 or 3):
  - output: Tensor
  - mean: Optional\<Tensor>
  - inv_std_dev: Optional\<Tensor>
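
Note: the doc does not spell out how `eps_outside_sqrt` changes the layer-norm computation; a hedged reading, based only on the attribute name, is that it moves epsilon from inside to outside the square root of the variance:

$$
\text{output} = \frac{x - \mathrm{mean}}{\sqrt{\mathrm{var} + \epsilon}} \quad\text{(default)}
\qquad
\text{output} = \frac{x - \mathrm{mean}}{\sqrt{\mathrm{var}} + \epsilon} \quad\text{(eps\_outside\_sqrt set)}
$$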

@@ -65,6 +65,7 @@ Further needed infomation for a given coarse-grained op are encoded in a diction
  - input: Tensor
- Attrs
  - epsilon: F64Attr
  - eps_outside_sqrt: F64Attr (see the note below)
  - axis: I64ArrayAttr
- Results:
  - output: Tensor
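
Note: the semantics of `eps_outside_sqrt` for `byteir.l2_norm` can be read off the rewrite pattern added in this commit (a `Div` of the input by a lower-bounded `Clip` of its `ReduceL2` result); the formula below is that reading, not a statement from the original doc:

$$
\text{output} = \frac{x}{\max\bigl(\lVert x \rVert_2,\ \epsilon_{\text{outside}}\bigr)},
\qquad
\lVert x \rVert_2 = \sqrt{\textstyle\sum_{i \in \text{axis}} x_i^2}
$$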
@@ -116,10 +117,10 @@ Further needed infomation for a given coarse-grained op are encoded in a diction
  - select_last_index: BoolAttr
- Results:
  - output: Optional\<Tensor>
  - indices: IntTensor


### byteir.top_k
- Operands:
  - input: Tensor
- Attrs
@@ -130,7 +131,7 @@ Further needed infomation for a given coarse-grained op are encoded in a diction
  - output: Tensor
  - indices: IntTensor

### byteir.erf
- Operands:
  - input: Tensor
- Results:
@@ -157,6 +157,40 @@ Value createL2NormWithoutEps(PatternRewriter &rewriter, Location loc,
  return customCallOp.getResults()[0];
}

Value createL2NormWithOutsideSqrtEps(PatternRewriter &rewriter, Location loc,
                                     Value input, Value axes, Value epsValue) {
  RankedTensorType inputType =
      input.getType().dyn_cast_or_null<RankedTensorType>();
  assert(inputType != nullptr && "L2Norm input type must be ranked");

  // Read the reduction axis from the constant `axes` operand.
  ElementsAttr axis_attr = onnx_mlir::getElementAttributeFromONNXValue(axes);
  int64_t axis = axis_attr.getValues<APInt>()[0].getSExtValue();
  // canonicalize axis to be positive
  if (axis < 0) {
    axis = inputType.getRank() + axis;
  }
  // Read the scalar epsilon (the Clip lower bound) and only accept small values.
  ElementsAttr epsilon_attr =
      onnx_mlir::getElementAttributeFromONNXValue(epsValue);
  double epsilon =
      (*epsilon_attr.getValues<APFloat>().begin()).convertToDouble();
  assert(0 < epsilon && epsilon < 1e-7 && "epsilon out of range for L2Norm");

  // Emit the coarse-grained op as a stablehlo custom call; axis and epsilon
  // travel in the byteir_attrs dictionary set below.
  std::string call_target_name = getL2NormNameWithPrefix();
  stablehlo::CustomCallOp customCallOp =
      rewriter.create<mlir::stablehlo::CustomCallOp>(
          loc, llvm::ArrayRef<Type>{inputType}, llvm::ArrayRef<Value>{input},
          call_target_name, false, rewriter.getStringAttr(""),
          stablehlo::CustomCallApiVersion::API_VERSION_ORIGINAL,
          rewriter.getArrayAttr(llvm::ArrayRef<mlir::Attribute>{}), nullptr,
          nullptr, rewriter.getArrayAttr(llvm::ArrayRef<mlir::Attribute>{}));
  DictionaryAttrWrapper attrs(rewriter.getContext());
  attrs.setAttr("eps_outside_sqrt", rewriter.getF64FloatAttr(epsilon));
  attrs.setAttr("axis", rewriter.getI64ArrayAttr({axis}));
  customCallOp->setAttr(BYTEIR_ATTRS, getCleanAttr(attrs));

  return customCallOp.getResults()[0];
}

//===----------------------------------------------------------------------===//
// Quantize/Dequantize
//===----------------------------------------------------------------------===//
@@ -646,6 +680,8 @@ struct OFRewriteToCustomCallPass
std::make_unique<RewriteL2NormPat1>(context));
validOpSet[getL2NormName()].emplace_back(
std::make_unique<RewriteL2NormPat2>(context));
validOpSet[getL2NormName()].emplace_back(
std::make_unique<RewriteL2NormPat3>(context));
validOpSet[getQuantizeName()].emplace_back(
std::make_unique<RewriteQuantize>(context));
validOpSet[getDequantizeName()].emplace_back(
@@ -32,6 +32,9 @@ def HasOneUse : Constraint<CPred<"$0.hasOneUse()">, "value has exactly one use">
def isScalarConstantTensor :
  Constraint<CPred<"onnx_mlir::isScalarConstantTensor($_self)">,
    "Value is produced by a dense ONNXConstantOp and has size one">;
def isNoneValue :
  Constraint<CPred<"onnx_mlir::isNoneValue($_self)">,
    "Value is a None value">;
def SameTwoIntegerScalarConstantValues :
  Constraint<CPred<"onnx_mlir::getElementAttributeFromONNXValue($0).getValues<APInt>()[0].getSExtValue() == onnx_mlir::getElementAttributeFromONNXValue($1).getValues<APInt>()[0].getSExtValue()">,
    "Two integer scalar constant Values have the same integer scalar">;
@@ -61,6 +64,19 @@ def RewriteL2NormPat2 : Pat<
(NativeCodeCall<"createL2NormWithoutEps($_builder, $_loc, $0, $1)"> $input, $axes),
[(isScalarConstantTensor:$axes), (TrueBoolAttr $keep_dims)]>;

def RewriteL2NormPat3 : Pat<
  (ONNXDivOp
    $input,
    (ONNXExpandOp
      (ONNXClipOp
        (ONNXReduceL2Op $input, $axes, $keep_dims, $noop_with_empty_axes), $min, $max
      ),
      (ONNXConstantOp $_, $_, $_, $_, $_, $_, $_, $_) // should be the shape of $input
    )
  ),
  (NativeCodeCall<"createL2NormWithOutsideSqrtEps($_builder, $_loc, $0, $1, $2)"> $input, $axes, $min),
  [(isScalarConstantTensor:$axes), (TrueBoolAttr $keep_dims), (isNoneValue:$max)]>;
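
The subgraph matched by `RewriteL2NormPat3`, namely Div(x, Expand(Clip(ReduceL2(x), min=eps), shape)), normalizes x by its L2 norm with eps acting as a lower bound on that norm. A standalone numeric sketch of the computation (plain C++, not part of any changed file; the values are illustrative):

```c++
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<double> x = {3.0, 4.0}; // one row, reduced along axis 1
  const double eps = 1e-12;

  double norm = 0.0;                  // onnx.ReduceL2: sqrt of the sum of squares
  for (double v : x)
    norm += v * v;
  norm = std::sqrt(norm);

  double denom = std::max(norm, eps); // onnx.Clip with only a lower bound
  for (double v : x)
    std::printf("%f ", v / denom);    // onnx.Div after the onnx.Expand broadcast
  std::printf("\n");                  // prints 0.600000 0.800000
  return 0;
}
```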

//===----------------------------------------------------------------------===//
// Quantize/Dequantize Pattern
//===----------------------------------------------------------------------===//
@@ -241,6 +241,24 @@ func.func @test_l2_norm_pat2(%1146: tensor<12x128xf32>) -> tensor<12x128xf32> {

// -----

func.func @test_l2_norm_pat3(%arg0: tensor<16x128xf32>) -> tensor<16x128xf32> {
  %0 = onnx.Constant dense<1> : tensor<1xi64>
  %1 = onnx.Constant dense<9.99999996E-13> : tensor<f32>
  %2 = "onnx.ReduceL2"(%arg0, %0) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<16x128xf32>, tensor<1xi64>) -> tensor<16x1xf32>
  %3 = "onnx.NoValue"() {value} : () -> none
  %4 = "onnx.Clip"(%2, %1, %3) : (tensor<16x1xf32>, tensor<f32>, none) -> tensor<16x1xf32>
  %5 = onnx.Constant dense<[16, 128]> : tensor<2xi64>
  %6 = "onnx.Expand"(%4, %5) : (tensor<16x1xf32>, tensor<2xi64>) -> tensor<16x128xf32>
  %7 = "onnx.Div"(%arg0, %6) : (tensor<16x128xf32>, tensor<16x128xf32>) -> tensor<16x128xf32>
  return %7 : tensor<16x128xf32>
// CHECK-LABEL: @test_l2_norm_pat3
// CHECK-SAME: (%arg0: tensor<16x128xf32>) -> tensor<16x128xf32> {
// CHECK: %0 = stablehlo.custom_call @byteir.l2_norm(%arg0) {byteir_attrs = {axis = [1], eps_outside_sqrt = 9.999999960041972E-13 : f64}} : (tensor<16x128xf32>) -> tensor<16x128xf32>
// CHECK: return %0 : tensor<16x128xf32>
}

// -----

func.func @test_quantize_per_tensor(%arg0: tensor<16x3x256x256xf32>) -> tensor<16x3x256x256xi8> {
%291 = stablehlo.constant dense<0.0207054354> : tensor<f32>
%292 = stablehlo.constant dense<0> : tensor<i8>
