From 687b65d54310ec1905a7ddc87c735f538748947b Mon Sep 17 00:00:00 2001
From: "yan.xu0210"
Date: Fri, 22 Mar 2024 06:08:23 +0800
Subject: [PATCH] [onnx] add eps_outside_sqrt for l2_norm

---
 compiler/doc/byteir_mhlo_custom_call.md       | 19 +++++-----
 .../src/Conversion/OFRewriteToCustomCall.cpp  | 36 +++++++++++++++++++
 .../src/Conversion/OFRewriteToCustomCall.td   | 16 +++++++++
 .../test/of_rewrite_to_custom_call.mlir       | 18 ++++++++++
 4 files changed, 80 insertions(+), 9 deletions(-)

diff --git a/compiler/doc/byteir_mhlo_custom_call.md b/compiler/doc/byteir_mhlo_custom_call.md
index e0e80b608..0cc5887c9 100644
--- a/compiler/doc/byteir_mhlo_custom_call.md
+++ b/compiler/doc/byteir_mhlo_custom_call.md
@@ -2,15 +2,15 @@
 
 ByteIR compiler introduces several coarse-grained ops to improve pattern-matching rewriting during compilation.
 
-ByteIR implements in the way of re-using mhlo custom call op definition with a ByteIR prefix in `call_target_name`, 
+ByteIR re-uses the mhlo custom call op definition with a ByteIR prefix in `call_target_name`,
 instead of defining a new dialect.
 ByteIR implements this conversion in the frontends, instead of putting it in the ByteIR compiler.
 
-## Rationales 
+## Rationales
 ### Need of coarse-grained ops
-Introduction of coarse-grained ops can provide several benefits as follows, 
+Introducing coarse-grained ops provides several benefits:
 * it simplifies pattern matching during rewriting, whether for optimization or lowering;
 * it allows high-level information to be encoded in coarse-grained ops, helping optimization;
 * it provides an intuitive mapping from frontends to IR, helping debuggability;
@@ -33,14 +33,14 @@ Implementing coarse-grained op conversion in frontends can provide several benef
 
 ## Additional op definition
 
-A coarse-grained op kind is defined through with a prefix. 
+A coarse-grained op kind is defined with a prefix in `call_target_name`:
 
 ```call_target_name = "byteir.softmax" or "tf.DynamicPartition"```
 
 If an op is generic across frontends, which is usually the case, it uses the `byteir` prefix.
 If an op is frontend-specific, it uses a frontend-specific prefix, such as `tf` or `pytorch`.
 
-Further needed infomation for a given coarse-grained op are encoded in a dictionary attribute, called `byteir_attrs`, which includes all named attributes. 
+Further information needed by a given coarse-grained op is encoded in a dictionary attribute, called `byteir_attrs`, which includes all named attributes.
 
 **Op Attribute**:
 * ```byteir_attrs = {approximate = "none"}``` or ```byteir_attrs = {}``` if no attribute
@@ -56,7 +56,7 @@ Further needed infomation for a given coarse-grained op are encoded in a diction
   - axis: I64ArrayAttr
   - eps_outside_sqrt: Optional\<BoolAttr\>
 - Results(1 or 3):
-  - output: Tensor 
+  - output: Tensor
   - mean: Optional\<Tensor\>
   - inv_std_dev: Optional\<Tensor\>
 
@@ -65,6 +65,7 @@ Further needed infomation for a given coarse-grained op are encoded in a diction
   - input: Tensor
 - Attrs
   - epsilon: F64Attr
+  - eps_outside_sqrt: F64Attr
   - axis: I64ArrayAttr
 - Results:
   - output: Tensor
@@ -116,10 +117,10 @@ Further needed infomation for a given coarse-grained op are encoded in a diction
   - select_last_index: BoolAttr
 - Results:
   - output: Optional\<Tensor\>
-  - indices: IntTensor 
+  - indices: IntTensor
 
 
-### byteir.top_k 
+### byteir.top_k
 - Operands:
   - input: Tensor
 - Attrs
   - k: I64Attr
   - axis: I64ArrayAttr
   - sorted: BoolAttr
 - Results:
   - output: Tensor
   - indices: IntTensor
 
-### byteir.erf 
+### byteir.erf
 - Operands:
   - input: Tensor
 - Results:
   - output: Tensor
diff --git a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.cpp b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.cpp
index 7689ec822..26a2b7668 100644
--- a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.cpp
+++ b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.cpp
@@ -157,6 +157,40 @@ Value createL2NormWithoutEps(PatternRewriter &rewriter, Location loc,
   return customCallOp.getResults()[0];
 }
 
+Value createL2NormWithOutsideSqrtEps(PatternRewriter &rewriter, Location loc,
+                                     Value input, Value axes, Value epsValue) {
+  RankedTensorType inputType =
+      input.getType().dyn_cast_or_null<RankedTensorType>();
+  assert(inputType != nullptr && "L2Norm input type must be ranked");
+
+  ElementsAttr axis_attr = onnx_mlir::getElementAttributeFromONNXValue(axes);
+  int64_t axis = axis_attr.getValues<APInt>()[0].getSExtValue();
+  // canonicalize axis to be positive
+  if (axis < 0) {
+    axis = inputType.getRank() + axis;
+  }
+  ElementsAttr epsilon_attr =
+      onnx_mlir::getElementAttributeFromONNXValue(epsValue);
+  double epsilon =
+      (*epsilon_attr.getValues<APFloat>().begin()).convertToDouble();
+  assert(0 < epsilon && epsilon < 1e-7 && "epsilon out of range for L2Norm");
+
+  std::string call_target_name = getL2NormNameWithPrefix();
+  stablehlo::CustomCallOp customCallOp =
+      rewriter.create<stablehlo::CustomCallOp>(
+          loc, llvm::ArrayRef<Type>{inputType}, llvm::ArrayRef<Value>{input},
+          call_target_name, false, rewriter.getStringAttr(""),
+          stablehlo::CustomCallApiVersion::API_VERSION_ORIGINAL,
+          rewriter.getArrayAttr(llvm::ArrayRef<Attribute>{}), nullptr,
+          nullptr, rewriter.getArrayAttr(llvm::ArrayRef<Attribute>{}));
+  DictionaryAttrWrapper attrs(rewriter.getContext());
+  attrs.setAttr("eps_outside_sqrt", rewriter.getF64FloatAttr(epsilon));
+  attrs.setAttr("axis", rewriter.getI64ArrayAttr({axis}));
+  customCallOp->setAttr(BYTEIR_ATTRS, getCleanAttr(attrs));
+
+  return customCallOp.getResults()[0];
+}
+
 //===----------------------------------------------------------------------===//
 // Quantize/Dequantize
 //===----------------------------------------------------------------------===//
@@ -646,6 +680,8 @@ struct OFRewriteToCustomCallPass
             std::make_unique<RewriteL2NormPat1>(context));
         validOpSet[getL2NormName()].emplace_back(
             std::make_unique<RewriteL2NormPat2>(context));
+        validOpSet[getL2NormName()].emplace_back(
+            std::make_unique<RewriteL2NormPat3>(context));
         validOpSet[getQuantizeName()].emplace_back(
             std::make_unique(context));
         validOpSet[getDequantizeName()].emplace_back(
             std::make_unique(context));
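A minimal sketch of the IR emitted by `createL2NormWithOutsideSqrtEps` may help here. The function name `@l2_norm_sketch`, the 16x128 shape, and the epsilon value are illustrative; only the `@byteir.l2_norm` target and the `byteir_attrs` layout are taken from the code above:

```mlir
// out = x / max(||x||_2 along axis 1, eps): the epsilon is applied to the
// norm itself, outside the square root, hence the attribute name.
func.func @l2_norm_sketch(%x: tensor<16x128xf32>) -> tensor<16x128xf32> {
  %0 = stablehlo.custom_call @byteir.l2_norm(%x) {byteir_attrs = {axis = [1], eps_outside_sqrt = 1.000000e-12 : f64}} : (tensor<16x128xf32>) -> tensor<16x128xf32>
  return %0 : tensor<16x128xf32>
}
```

The rewrite only fires when the matched `onnx.Clip` has a `min` operand and no `max` operand: `min` supplies `eps_outside_sqrt`, while an upper bound would change the computation away from a pure L2 normalization. This is what the `isNoneValue` constraint in the TableGen pattern below checks.
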
diff --git a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.td b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.td
index 927871fa3..aaecd3224 100644
--- a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.td
+++ b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.td
@@ -32,6 +32,9 @@ def HasOneUse : Constraint<CPred<"$0.hasOneUse()">, "value has exactly one use">
 def isScalarConstantTensor : Constraint<CPred<"onnx_mlir::isScalarConstantTensor($0)">,
     "Value is produced by a dense ONNXConstantOp and has size one">;
+def isNoneValue :
+  Constraint<CPred<"$0.getType().isa<NoneType>()">,
+  "Value is a None value">;
 def SameTwoIntegerScalarConstantValues : Constraint<CPred<"onnx_mlir::getElementAttributeFromONNXValue($0).getValues<APInt>()[0].getSExtValue() == onnx_mlir::getElementAttributeFromONNXValue($1).getValues<APInt>()[0].getSExtValue()">,
     "Two integer scalar constant Values have the same integer scalar">;
@@ -61,6 +64,19 @@ def RewriteL2NormPat2 : Pat<
   (NativeCodeCall<"createL2NormWithoutEps($_builder, $_loc, $0, $1)"> $input, $axes),
   [(isScalarConstantTensor:$axes), (TrueBoolAttr $keep_dims)]>;
 
+def RewriteL2NormPat3 : Pat<
+  (ONNXDivOp
+    $input,
+    (ONNXExpandOp
+      (ONNXClipOp
+        (ONNXReduceL2Op $input, $axes, $keep_dims, $noop_with_empty_axes), $min, $max
+      ),
+      (ONNXConstantOp $_, $_, $_, $_, $_, $_, $_, $_) // should be the shape of $input
+    )
+  ),
+  (NativeCodeCall<"createL2NormWithOutsideSqrtEps($_builder, $_loc, $0, $1, $2)"> $input, $axes, $min),
+  [(isScalarConstantTensor:$axes), (TrueBoolAttr $keep_dims), (isNoneValue:$max)]>;
+
 //===----------------------------------------------------------------------===//
 // Quantize/Dequantize Pattern
 //===----------------------------------------------------------------------===//
diff --git a/frontends/onnx-frontend/onnx-frontend/test/of_rewrite_to_custom_call.mlir b/frontends/onnx-frontend/onnx-frontend/test/of_rewrite_to_custom_call.mlir
index 2cfd75d08..a8b82eca5 100644
--- a/frontends/onnx-frontend/onnx-frontend/test/of_rewrite_to_custom_call.mlir
+++ b/frontends/onnx-frontend/onnx-frontend/test/of_rewrite_to_custom_call.mlir
@@ -241,6 +241,24 @@ func.func @test_l2_norm_pat2(%1146: tensor<12x128xf32>) -> tensor<12x128xf32> {
 
 // -----
 
+func.func @test_l2_norm_pat3(%arg0: tensor<16x128xf32>) -> tensor<16x128xf32> {
+  %0 = onnx.Constant dense<1> : tensor<1xi64>
+  %1 = onnx.Constant dense<9.99999996E-13> : tensor<f32>
+  %2 = "onnx.ReduceL2"(%arg0, %0) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<16x128xf32>, tensor<1xi64>) -> tensor<16x1xf32>
+  %3 = "onnx.NoValue"() {value} : () -> none
+  %4 = "onnx.Clip"(%2, %1, %3) : (tensor<16x1xf32>, tensor<f32>, none) -> tensor<16x1xf32>
+  %5 = onnx.Constant dense<[16, 128]> : tensor<2xi64>
+  %6 = "onnx.Expand"(%4, %5) : (tensor<16x1xf32>, tensor<2xi64>) -> tensor<16x128xf32>
+  %7 = "onnx.Div"(%arg0, %6) : (tensor<16x128xf32>, tensor<16x128xf32>) -> tensor<16x128xf32>
+  return %7 : tensor<16x128xf32>
+// CHECK-LABEL: @test_l2_norm_pat3
+// CHECK-SAME: (%arg0: tensor<16x128xf32>) -> tensor<16x128xf32> {
+// CHECK: %0 = stablehlo.custom_call @byteir.l2_norm(%arg0) {byteir_attrs = {axis = [1], eps_outside_sqrt = 9.999999960041972E-13 : f64}} : (tensor<16x128xf32>) -> tensor<16x128xf32>
+// CHECK: return %0 : tensor<16x128xf32>
+}
+
+// -----
+
 func.func @test_quantize_per_tensor(%arg0: tensor<16x3x256x256xf32>) -> tensor<16x3x256x256xi8> {
   %291 = stablehlo.constant dense<0.0207054354> : tensor<f32>
   %292 = stablehlo.constant dense<0> : tensor
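Two details of the new test are worth spelling out. First, the CHECK line's `9.999999960041972E-13` is the source constant `9.99999996E-13` after the f32 value is widened to the f64 stored in `byteir_attrs`. Second, the emitted custom call is equivalent to the hand-written StableHLO below; this reference expansion is illustrative only (it is not part of the patch or the test suite, and the shapes mirror the test case):

```mlir
// Reference expansion of @byteir.l2_norm with eps_outside_sqrt:
//   out = x / max(sqrt(sum(x*x, axis=1)), eps)
func.func @l2_norm_reference(%x: tensor<16x128xf32>) -> tensor<16x128xf32> {
  %zero = stablehlo.constant dense<0.000000e+00> : tensor<f32>
  %eps = stablehlo.constant dense<9.99999996E-13> : tensor<16xf32>
  // sum of squares along axis 1
  %sq = stablehlo.multiply %x, %x : tensor<16x128xf32>
  %sum = stablehlo.reduce(%sq init: %zero) applies stablehlo.add across dimensions = [1] : (tensor<16x128xf32>, tensor<f32>) -> tensor<16xf32>
  %norm = stablehlo.sqrt %sum : tensor<16xf32>
  // epsilon applied outside the sqrt: clamp the norm from below,
  // exactly what the matched onnx.Clip(min = eps, max = none) does
  %safe = stablehlo.maximum %norm, %eps : tensor<16xf32>
  %den = stablehlo.broadcast_in_dim %safe, dims = [0] : (tensor<16xf32>) -> tensor<16x128xf32>
  %out = stablehlo.divide %x, %den : tensor<16x128xf32>
  return %out : tensor<16x128xf32>
}
```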