diff --git a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFCanonicalizer.cpp b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFCanonicalizer.cpp
index 87a8a88e9..53f5f0950 100644
--- a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFCanonicalizer.cpp
+++ b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFCanonicalizer.cpp
@@ -48,7 +48,11 @@ struct OFCanonicalizerPass
   void runOnOperation() override {
     ModuleOp module = getOperation();
     // Canonicalization
-    LogicalResult converged = applyPatternsAndFoldGreedily(module, patterns);
+    // TODO: revert this after bumping the llvm version (after 2024-12-22)
+    GreedyRewriteConfig config;
+    config.useTopDownTraversal = true;
+    LogicalResult converged =
+        applyPatternsAndFoldGreedily(module, patterns, config);
     // Canonicalization is best-effort. Non-convergence is not a pass failure.
     if (testConvergence && failed(converged))
       signalPassFailure();
diff --git a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.cpp b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.cpp
index a2fd8ab4a..554517aec 100644
--- a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.cpp
+++ b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.cpp
@@ -375,6 +375,26 @@ Value createResize(PatternRewriter &rewriter, Location loc, Value input,
 // LayerNorm
 //===----------------------------------------------------------------------===//
 
+// NOTE: This method already exists in upstream llvm. Replace it once
+// llvm is upgraded.
+TypedAttr getOneAttr(PatternRewriter &rewriter, Type type) {
+  if (llvm::isa<FloatType>(type))
+    return rewriter.getFloatAttr(type, 1.0);
+  if (llvm::isa<IndexType>(type))
+    return rewriter.getIndexAttr(1);
+  if (llvm::dyn_cast<IntegerType>(type))
+    return rewriter.getIntegerAttr(
+        type, APInt(llvm::cast<IntegerType>(type).getWidth(), 1));
+  if (llvm::isa<RankedTensorType, VectorType>(type)) {
+    auto vtType = llvm::cast<ShapedType>(type);
+    auto element = getOneAttr(rewriter, vtType.getElementType());
+    if (!element)
+      return {};
+    return DenseElementsAttr::get(vtType, element);
+  }
+  return {};
+}
+
 Value createSqueezedValue(PatternRewriter &rewriter, Location loc, Value input,
                           int axis) {
   RankedTensorType inputType =
@@ -518,6 +538,24 @@ Value createLayerNormWithoutLastAdd(PatternRewriter &rewriter, Location loc,
   return createLayerNorm(rewriter, loc, input, scale, B, axes, epsilon_attr);
 }
 
+Value createLayerNormWithoutLastMulAdd(PatternRewriter &rewriter, Location loc,
+                                       Value input, Value axes,
+                                       Attribute epsilon_attr) {
+  auto inputType = llvm::cast<RankedTensorType>(input.getType());
+  auto axesValue = onnx_mlir::getElementAttributeFromONNXValue(axes)
+                       .getValues<APInt>()[0]
+                       .getSExtValue();
+  if (axesValue < 0)
+    axesValue += inputType.getRank();
+  auto biasType = RankedTensorType::get({inputType.getShape()[axesValue]},
+                                        inputType.getElementType());
+  Attribute zero = rewriter.getZeroAttr(biasType);
+  Attribute one = getOneAttr(rewriter, biasType);
+  Value B = rewriter.create<ONNXConstantOp>(loc, Attribute(), zero);
+  Value scale = rewriter.create<ONNXConstantOp>(loc, Attribute(), one);
+  return createLayerNorm(rewriter, loc, input, scale, B, axes, epsilon_attr);
+}
+
 //===----------------------------------------------------------------------===//
 // GeLU
 //===----------------------------------------------------------------------===//
@@ -783,6 +821,8 @@ struct OFRewriteToCustomCallPass
         std::make_unique<RewriteLayerNorm>(context));
     validOpSet[getLayerNormName()].emplace_back(
         std::make_unique<RewriteLayerNormWithoutLastAdd>(context));
+    validOpSet[getLayerNormName()].emplace_back(
+        std::make_unique<RewriteLayerNormWithoutLastMulAdd>(context));
     validOpSet[getLayerNormName()].emplace_back(
         std::make_unique(context));
     validOpSet[getOneHotName()].emplace_back(
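Note: the identity behind `createLayerNormWithoutLastMulAdd` is that `(x - mean(x)) / sqrt(mean((x - mean(x))^2) + eps)` is exactly LayerNorm with a ones scale and a zeros bias, which is what the synthesized `scale`/`B` constants encode. A standalone numeric sanity check (plain C++ sketch, not byteir code):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  // x_hat = (x - mean) / sqrt(var + eps); with scale = 1 and bias = 0,
  // LayerNorm(x) returns x_hat unchanged, so the rewrite is value-preserving.
  const double x[3] = {1.0, 2.0, 4.0};
  const double eps = 9.99999974e-6; // same epsilon as the test below
  double mean = (x[0] + x[1] + x[2]) / 3.0;
  double var = 0.0;
  for (double v : x)
    var += (v - mean) * (v - mean);
  var /= 3.0;
  for (double v : x) {
    double xhat = (v - mean) / std::sqrt(var + eps);
    std::printf("%f -> %f\n", v, 1.0 * xhat + 0.0); // scale=1, bias=0
  }
  return 0;
}
```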
diff --git a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.td b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.td
index ce1328c29..477cac87b 100644
--- a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.td
+++ b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.td
@@ -28,6 +28,7 @@ def IsOneSizeElements : Constraint, "this BoolAttr should be true">;
 def SameTwoValuesOrAttrs : Constraint, "two values or attrs are actually the same">;
 def HasOneUse : Constraint, "value has exactly one use">;
+def HasMultiUse : Constraint, "value has multiple uses">;
 
 def isScalarConstantTensor : Constraint,
@@ -225,6 +226,32 @@ def RewriteLayerNormWithoutLastAdd : Pat<
   (SameTwoValuesOrAttrs $input_zero_mean_0, $input_zero_mean_1),
   ]>;
 
+def RewriteLayerNormWithoutLastMulAdd : Pat<
+  (ONNXDivOp:$results
+    (ONNXSubOp:$input_zero_mean_0
+      $input,
+      (ONNXReduceMeanOp $input, $axes_0, $_, $_)
+    ),
+    (ONNXSqrtOp
+      (ONNXAddOp
+        (ONNXReduceMeanOp
+          (ONNXMulOp
+            $input_zero_mean_1,
+            $input_zero_mean_1
+          ), $axes_1, $_, $_
+        ),
+        (ONNXConstantOp $_, $epsilon_attr, $_, $_, $_, $_, $_, $_)
+      )
+    )
+  ),
+  (NativeCodeCall<"createLayerNormWithoutLastMulAdd($_builder, $_loc, $0, $1, $2)"> $input, $axes_0, $epsilon_attr),
+  [(IsOneSizeElements $epsilon_attr),
+   (isScalarConstantTensor:$axes_0), (isScalarConstantTensor:$axes_1),
+   (SameTwoIntegerScalarConstantValues $axes_0, $axes_1),
+   (SameTwoValuesOrAttrs $input_zero_mean_0, $input_zero_mean_1),
+   (HasMultiUse $results__0),
+  ]>;
+
 //===----------------------------------------------------------------------===//
 // GeLU Pattern
 //===----------------------------------------------------------------------===//
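Note: the predicate body of `HasMultiUse` is not shown in this diff. Judging from its description and its `HasOneUse` neighbor, it plausibly tests that a value is live but not single-use; a hypothetical C++ equivalent (an assumption, not the repo's actual CPred) would be:

```cpp
#include "mlir/IR/Value.h"

// Hypothetical equivalent of the HasMultiUse TableGen constraint: the value
// has at least one use, and more than one. The real CPred body is elided
// in this diff.
static bool hasMultiUse(mlir::Value value) {
  return !value.use_empty() && !value.hasOneUse();
}
```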
"Mul_545"} : (tensor<1x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + // onnx.Add is folded here + return %970 : tensor<1x3xf32> +// CHECK-LABEL: @test_layer_norm_without_last_muladd(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[1.500000e-01, 2.000000e-01, 2.500000e-01]> : tensor<3xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<0.000000e+00> : tensor<3xf32> +// CHECK-NEXT: %2 = stablehlo.custom_call @byteir.layer_norm(%arg0, [[VAR_0_]], [[VAR_1_]]) {byteir_attrs = {axis = [1], epsilon = 9.9999997473787516E-6 : f64}} : (tensor<1x3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<1x3xf32> +// CHECK-NEXT: return %2 : tensor<1x3xf32> +} + +// ----- + func.func @test_layer_norm_squeeze(%arg0: tensor<2x4x3xf32>) -> tensor<2x4x3xf32> { %c1 = onnx.Constant dense<-1> : tensor<1xi64> %c2 = onnx.Constant dense<-1> : tensor<1xi64> diff --git a/frontends/onnx-frontend/third_party/patches/OnnxMlirElementwise.patch b/frontends/onnx-frontend/third_party/patches/OnnxMlirElementwise.patch index f3a946fb8..c1220b5da 100644 --- a/frontends/onnx-frontend/third_party/patches/OnnxMlirElementwise.patch +++ b/frontends/onnx-frontend/third_party/patches/OnnxMlirElementwise.patch @@ -1,5 +1,5 @@ diff --git a/src/Conversion/ONNXToStablehlo/Math/Elementwise.cpp b/src/Conversion/ONNXToStablehlo/Math/Elementwise.cpp -index b5b58f2b..8b32b1a5 100644 +index b5b58f2b..35a2ca28 100644 --- a/src/Conversion/ONNXToStablehlo/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToStablehlo/Math/Elementwise.cpp @@ -81,6 +81,11 @@ struct StablehloDialectOp { @@ -14,7 +14,36 @@ index b5b58f2b..8b32b1a5 100644 template <> struct StablehloDialectOp { using Op = stablehlo::PowOp; -@@ -444,6 +449,7 @@ void populateLoweringONNXElementwiseOpToStablehloPattern( +@@ -288,6 +293,28 @@ struct ONNXElementwiseUnaryOpLoweringToStablehlo + } + }; + ++// ONNXRoundOp(x) is implemented using Stablehlo round_nearest_even(x, 0) ++template <> ++struct ONNXElementwiseUnaryOpLoweringToStablehlo ++ : public ConversionPattern { ++ ONNXElementwiseUnaryOpLoweringToStablehlo(MLIRContext *ctx) ++ : ConversionPattern(ONNXRoundOp::getOperationName(), 1, ctx) {} ++ LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ++ ConversionPatternRewriter &rewriter) const final { ++ Location loc = op->getLoc(); ++ ONNXRoundOpAdaptor adaptor(operands, op->getAttrDictionary()); ++ Value inp = adaptor.getX(); ++ ShapedType inpType = inp.getType().dyn_cast_or_null(); ++ if (inpType == nullptr) ++ return failure(); ++ Type resultType = *op->result_type_begin(); ++ Value resultOp = ++ rewriter.create(loc, resultType, inp); ++ rewriter.replaceOp(op, resultOp); ++ return success(); ++ } ++}; ++ + template <> + struct ONNXElementwiseUnaryOpLoweringToStablehlo + : public ConversionPattern { +@@ -444,10 +471,12 @@ void populateLoweringONNXElementwiseOpToStablehloPattern( ONNXElementwiseUnaryOpLoweringToStablehlo, ONNXElementwiseUnaryOpLoweringToStablehlo, ONNXElementwiseUnaryOpLoweringToStablehlo, @@ -22,3 +51,8 @@ index b5b58f2b..8b32b1a5 100644 ONNXElementwiseUnaryOpLoweringToStablehlo, ONNXElementwiseUnaryOpLoweringToStablehlo, ONNXElementwiseUnaryOpLoweringToStablehlo, + ONNXElementwiseUnaryOpLoweringToStablehlo, ++ ONNXElementwiseUnaryOpLoweringToStablehlo, + ONNXElementwiseUnaryOpLoweringToStablehlo, + ONNXElementwiseCompareBinaryOpLoweringToStablehlo, + ONNXElementwiseCompareBinaryOpLoweringToStablehlo, diff --git a/frontends/onnx-frontend/third_party/patches/OnnxMlirOnnxOpsTensorPad.patch 
diff --git a/frontends/onnx-frontend/third_party/patches/OnnxMlirOnnxOpsTensorPad.patch b/frontends/onnx-frontend/third_party/patches/OnnxMlirOnnxOpsTensorPad.patch
new file mode 100644
index 000000000..a47345d55
--- /dev/null
+++ b/frontends/onnx-frontend/third_party/patches/OnnxMlirOnnxOpsTensorPad.patch
@@ -0,0 +1,25 @@
+diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Pad.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Pad.cpp
+index b00edc4a..4454ee76 100644
+--- a/src/Dialect/ONNX/ONNXOps/Tensor/Pad.cpp
++++ b/src/Dialect/ONNX/ONNXOps/Tensor/Pad.cpp
+@@ -45,9 +45,17 @@ LogicalResult ONNXPadOpShapeHelper::computeShape() {
+   // Calculate output dimension sizes.
+   for (uint64_t i = 0; i < dataRank; i++) {
+     // Get begin/end pads.
+-    SymbolIndexExpr padBegin(createIE->getIntFromArrayAsSymbol(padsOperand, i));
+-    SymbolIndexExpr padEnd(
+-        createIE->getIntFromArrayAsSymbol(padsOperand, i + dataRank));
++    auto padBeginIE = createIE->getIntFromArrayAsSymbol(padsOperand, i);
++    if (padBeginIE.isUndefined()) {
++      return failure();
++    }
++    SymbolIndexExpr padBegin(padBeginIE);
++    auto padEndIE =
++        createIE->getIntFromArrayAsSymbol(padsOperand, i + dataRank);
++    if (padEndIE.isUndefined()) {
++      return failure();
++    }
++    SymbolIndexExpr padEnd(padEndIE);
+     if (padBegin.isUndefined() || padEnd.isUndefined())
+       return op->emitError("pad parameter could not be processed");
+     // Get input dim.
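Note: the guard bails out of shape inference before an undefined IndexExpr is wrapped in a `SymbolIndexExpr`. The computation it protects is, in spirit, `out[i] = padBegin[i] + in[i] + padEnd[i]`. A minimal sketch of the guarded logic, using `std::optional` in place of onnx-mlir's IndexExpr machinery (hypothetical helper, not the repo's API):

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// pads holds begin pads in [0, rank) and end pads in [rank, 2*rank).
// If any pad value cannot be resolved, give up (mirrors `return failure()`
// in the patch) instead of computing with an undefined expression.
std::optional<std::vector<int64_t>> padOutputShape(
    const std::vector<int64_t> &inShape,
    const std::vector<std::optional<int64_t>> &pads) {
  const size_t rank = inShape.size();
  if (pads.size() != 2 * rank)
    return std::nullopt;
  std::vector<int64_t> out(rank);
  for (size_t i = 0; i < rank; ++i) {
    if (!pads[i] || !pads[i + rank])
      return std::nullopt;
    out[i] = *pads[i] + inShape[i] + *pads[i + rank];
  }
  return out;
}
```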