[onnx] lowering round, fix pad op, support more layernorm pattern (#491)
YellowHCH authored Dec 31, 2024
1 parent 5447a6d commit 0e39d32
Showing 6 changed files with 158 additions and 3 deletions.
@@ -48,7 +48,11 @@ struct OFCanonicalizerPass
void runOnOperation() override {
ModuleOp module = getOperation();
// Canonicalization
LogicalResult converged = applyPatternsAndFoldGreedily(module, patterns);
// TODO: revert this after bumping llvm to a version newer than 2024-12-22
GreedyRewriteConfig config;
config.useTopDownTraversal = true;
LogicalResult converged =
applyPatternsAndFoldGreedily(module, patterns, config);
// Canonicalization is best-effort. Non-convergence is not a pass failure.
if (testConvergence && failed(converged))
signalPassFailure();
@@ -375,6 +375,26 @@ Value createResize(PatternRewriter &rewriter, Location loc, Value input,
// LayerNorm
//===----------------------------------------------------------------------===//

// NOTE: This method already exists in upstream llvm. Replace this with the
// upstream version once llvm is upgraded.
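// Returns a one-valued attribute of the given float, index, integer, or
// shaped type; returns a null attribute for unsupported types.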
TypedAttr getOneAttr(PatternRewriter &rewriter, Type type) {
if (llvm::isa<FloatType>(type))
return rewriter.getFloatAttr(type, 1.0);
if (llvm::isa<IndexType>(type))
return rewriter.getIndexAttr(1);
if (llvm::isa<IntegerType>(type))
return rewriter.getIntegerAttr(
type, APInt(llvm::cast<IntegerType>(type).getWidth(), 1));
if (llvm::isa<RankedTensorType, VectorType>(type)) {
auto vtType = llvm::cast<ShapedType>(type);
auto element = getOneAttr(rewriter, vtType.getElementType());
if (!element)
return {};
return DenseElementsAttr::get(vtType, element);
}
return {};
}

Value createSqueezedValue(PatternRewriter &rewriter, Location loc, Value input,
int axis) {
RankedTensorType inputType =
@@ -518,6 +538,24 @@ Value createLayerNormWithoutLastAdd(PatternRewriter &rewriter, Location loc,
return createLayerNorm(rewriter, loc, input, scale, B, axes, epsilon_attr);
}

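// Handles the reduced LayerNorm pattern that ends at the Div (no trailing
// scale Mul or bias Add): synthesizes scale = 1 and bias = 0 constants for
// the normalized dimension and reuses the common createLayerNorm helper.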
Value createLayerNormWithoutLastMulAdd(PatternRewriter &rewriter, Location loc,
Value input, Value axes,
Attribute epsilon_attr) {
auto inputType = llvm::cast<ShapedType>(input.getType());
auto axesValue = onnx_mlir::getElementAttributeFromONNXValue(axes)
.getValues<APInt>()[0]
.getSExtValue();
if (axesValue < 0)
axesValue += inputType.getRank();
auto biasType = RankedTensorType::get({inputType.getShape()[axesValue]},
inputType.getElementType());
Attribute zero = rewriter.getZeroAttr(biasType);
Attribute one = getOneAttr(rewriter, biasType);
Value B = rewriter.create<ONNXConstantOp>(loc, Attribute(), zero);
Value scale = rewriter.create<ONNXConstantOp>(loc, Attribute(), one);
return createLayerNorm(rewriter, loc, input, scale, B, axes, epsilon_attr);
}

//===----------------------------------------------------------------------===//
// GeLU
//===----------------------------------------------------------------------===//
@@ -783,6 +821,8 @@ struct OFRewriteToCustomCallPass
std::make_unique<RewriteLayerNormWithNoneEps>(context));
validOpSet[getLayerNormName()].emplace_back(
std::make_unique<RewriteLayerNormWithoutLastAdd>(context));
validOpSet[getLayerNormName()].emplace_back(
std::make_unique<RewriteLayerNormWithoutLastMulAdd>(context));
validOpSet[getLayerNormName()].emplace_back(
std::make_unique<RewriteInstanceNorm>(context));
validOpSet[getOneHotName()].emplace_back(
@@ -28,6 +28,7 @@ def IsOneSizeElements : Constraint<And<[
def TrueBoolAttr : Constraint<CPred<"$0.getValue() == true">, "this BoolAttr should be true">;
def SameTwoValuesOrAttrs : Constraint<CPred<"$0 == $1">, "two values or attrs are actually the same">;
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">, "value has exactly one use">;
def HasMultiUse : Constraint<CPred<"!$0.hasOneUse()">, "value has multiple uses">;

def isScalarConstantTensor :
Constraint<CPred<"onnx_mlir::isScalarConstantTensor($_self)">,
@@ -225,6 +226,32 @@ def RewriteLayerNormWithoutLastAdd : Pat<
(SameTwoValuesOrAttrs $input_zero_mean_0, $input_zero_mean_1),
]>;

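// LayerNorm core without the trailing scale Mul / bias Add:
//   y = (x - mean(x)) / sqrt(mean((x - mean(x))^2) + epsilon)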
def RewriteLayerNormWithoutLastMulAdd : Pat<
(ONNXDivOp:$results
(ONNXSubOp:$input_zero_mean_0
$input,
(ONNXReduceMeanOp $input, $axes_0, $_, $_)
),
(ONNXSqrtOp
(ONNXAddOp
(ONNXReduceMeanOp
(ONNXMulOp
$input_zero_mean_1,
$input_zero_mean_1
), $axes_1, $_, $_
),
(ONNXConstantOp $_, $epsilon_attr, $_, $_, $_, $_, $_, $_)
)
)
),
(NativeCodeCall<"createLayerNormWithoutLastMulAdd($_builder, $_loc, $0, $1, $2)"> $input, $axes_0, $epsilon_attr),
[(IsOneSizeElements $epsilon_attr),
(isScalarConstantTensor:$axes_0), (isScalarConstantTensor:$axes_1),
(SameTwoIntegerScalarConstantValues $axes_0, $axes_1),
(SameTwoValuesOrAttrs $input_zero_mean_0, $input_zero_mean_1),
(HasMultiUse $results__0),
]>;

//===----------------------------------------------------------------------===//
// GeLU Pattern
//===----------------------------------------------------------------------===//
@@ -128,6 +128,31 @@ func.func @test_layer_norm_without_last_add(%arg0: tensor<1x3xf32>) -> tensor<1x

// -----

func.func @test_layer_norm_without_last_muladd(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> {
%550 = "onnx.Constant"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
%551 = "onnx.Constant"() {value = dense<9.99999974E-6> : tensor<f32>} : () -> tensor<f32>
%526 = "onnx.Constant"() {value = dense<[0.15, 0.2, 0.25]> : tensor<3xf32>} : () -> tensor<3xf32>
%c1 = onnx.Constant dense<-1> : tensor<1xi64>
%c2 = onnx.Constant dense<-1> : tensor<1xi64>
%963 = "onnx.ReduceMean"(%arg0, %c1) : (tensor<1x3xf32>, tensor<1xi64>) -> tensor<1x1xf32>
%964 = "onnx.Sub"(%arg0, %963) {onnx_node_name = "Sub_537"} : (tensor<1x3xf32>, tensor<1x1xf32>) -> tensor<1x3xf32>
%965 = "onnx.Mul"(%964, %964) : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32>
%966 = "onnx.ReduceMean"(%965, %c2) : (tensor<1x3xf32>, tensor<1xi64>) -> tensor<1x1xf32>
%967 = "onnx.Add"(%966, %551) {onnx_node_name = "Add_542"} : (tensor<1x1xf32>, tensor<f32>) -> tensor<1x1xf32>
%968 = "onnx.Sqrt"(%967) {onnx_node_name = "Sqrt_543"} : (tensor<1x1xf32>) -> tensor<1x1xf32>
%969 = "onnx.Div"(%964, %968) {onnx_node_name = "Div_544"} : (tensor<1x3xf32>, tensor<1x1xf32>) -> tensor<1x3xf32>
%970 = "onnx.Mul"(%969, %526) {onnx_node_name = "Mul_545"} : (tensor<1x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
// onnx.Add is folded here
return %970 : tensor<1x3xf32>
// CHECK-LABEL: @test_layer_norm_without_last_muladd(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[1.500000e-01, 2.000000e-01, 2.500000e-01]> : tensor<3xf32>
// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<0.000000e+00> : tensor<3xf32>
// CHECK-NEXT: %2 = stablehlo.custom_call @byteir.layer_norm(%arg0, [[VAR_0_]], [[VAR_1_]]) {byteir_attrs = {axis = [1], epsilon = 9.9999997473787516E-6 : f64}} : (tensor<1x3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
// CHECK-NEXT: return %2 : tensor<1x3xf32>
}

// -----

func.func @test_layer_norm_squeeze(%arg0: tensor<2x4x3xf32>) -> tensor<2x4x3xf32> {
%c1 = onnx.Constant dense<-1> : tensor<1xi64>
%c2 = onnx.Constant dense<-1> : tensor<1xi64>
@@ -1,5 +1,5 @@
diff --git a/src/Conversion/ONNXToStablehlo/Math/Elementwise.cpp b/src/Conversion/ONNXToStablehlo/Math/Elementwise.cpp
index b5b58f2b..8b32b1a5 100644
index b5b58f2b..35a2ca28 100644
--- a/src/Conversion/ONNXToStablehlo/Math/Elementwise.cpp
+++ b/src/Conversion/ONNXToStablehlo/Math/Elementwise.cpp
@@ -81,6 +81,11 @@ struct StablehloDialectOp<ONNXNegOp> {
@@ -14,11 +14,45 @@ index b5b58f2b..8b32b1a5 100644
template <>
struct StablehloDialectOp<ONNXPowOp> {
using Op = stablehlo::PowOp;
@@ -444,6 +449,7 @@ void populateLoweringONNXElementwiseOpToStablehloPattern(
@@ -288,6 +293,28 @@ struct ONNXElementwiseUnaryOpLoweringToStablehlo<ONNXLeakyReluOp>
}
};

+// ONNXRoundOp(x) is lowered to stablehlo.round_nearest_even(x)
+template <>
+struct ONNXElementwiseUnaryOpLoweringToStablehlo<ONNXRoundOp>
+ : public ConversionPattern {
+ ONNXElementwiseUnaryOpLoweringToStablehlo(MLIRContext *ctx)
+ : ConversionPattern(ONNXRoundOp::getOperationName(), 1, ctx) {}
+ LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ Location loc = op->getLoc();
+ ONNXRoundOpAdaptor adaptor(operands, op->getAttrDictionary());
+ Value inp = adaptor.getX();
+ ShapedType inpType = inp.getType().dyn_cast_or_null<ShapedType>();
+ if (inpType == nullptr)
+ return failure();
+ Type resultType = *op->result_type_begin();
+ Value resultOp =
+ rewriter.create<stablehlo::RoundNearestEvenOp>(loc, resultType, inp);
+ rewriter.replaceOp(op, resultOp);
+ return success();
+ }
+};
+
template <>
struct ONNXElementwiseUnaryOpLoweringToStablehlo<ONNXCastOp>
: public ConversionPattern {
@@ -444,10 +471,12 @@ void populateLoweringONNXElementwiseOpToStablehloPattern(
ONNXElementwiseUnaryOpLoweringToStablehlo<ONNXLeakyReluOp>,
ONNXElementwiseUnaryOpLoweringToStablehlo<ONNXLogOp>,
ONNXElementwiseUnaryOpLoweringToStablehlo<ONNXNegOp>,
+ ONNXElementwiseUnaryOpLoweringToStablehlo<ONNXNotOp>,
ONNXElementwiseUnaryOpLoweringToStablehlo<ONNXSigmoidOp>,
ONNXElementwiseUnaryOpLoweringToStablehlo<ONNXSinOp>,
ONNXElementwiseUnaryOpLoweringToStablehlo<ONNXSqrtOp>,
ONNXElementwiseUnaryOpLoweringToStablehlo<ONNXReluOp>,
+ ONNXElementwiseUnaryOpLoweringToStablehlo<ONNXRoundOp>,
ONNXElementwiseUnaryOpLoweringToStablehlo<ONNXTanhOp>,
ONNXElementwiseCompareBinaryOpLoweringToStablehlo<ONNXEqualOp>,
ONNXElementwiseCompareBinaryOpLoweringToStablehlo<ONNXGreaterOp>,
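For illustration only (not part of this commit), a minimal sketch of what the new Round lowering produces; the function name is made up, and ONNX Round uses round-half-to-even semantics, which stablehlo.round_nearest_even matches:

// Before (ONNX dialect); @example_round is a hypothetical test function:
func.func @example_round(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
  %0 = "onnx.Round"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  return %0 : tensor<2x3xf32>
}
// After the elementwise lowering, onnx.Round becomes a single stablehlo op:
//   %0 = stablehlo.round_nearest_even %arg0 : tensor<2x3xf32>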
@@ -0,0 +1,25 @@
diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Pad.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Pad.cpp
index b00edc4a..4454ee76 100644
--- a/src/Dialect/ONNX/ONNXOps/Tensor/Pad.cpp
+++ b/src/Dialect/ONNX/ONNXOps/Tensor/Pad.cpp
@@ -45,9 +45,17 @@ LogicalResult ONNXPadOpShapeHelper::computeShape() {
// Calculate output dimension sizes.
for (uint64_t i = 0; i < dataRank; i++) {
// Get begin/end pads.
- SymbolIndexExpr padBegin(createIE->getIntFromArrayAsSymbol(padsOperand, i));
- SymbolIndexExpr padEnd(
- createIE->getIntFromArrayAsSymbol(padsOperand, i + dataRank));
+ auto padBeginIE = createIE->getIntFromArrayAsSymbol(padsOperand, i);
+ if (padBeginIE.isUndefined()) {
+ return failure();
+ }
+ SymbolIndexExpr padBegin(padBeginIE);
+ auto padEndIE =
+ createIE->getIntFromArrayAsSymbol(padsOperand, i + dataRank);
+ if (padEndIE.isUndefined()) {
+ return failure();
+ }
+ SymbolIndexExpr padEnd(padEndIE);
if (padBegin.isUndefined() || padEnd.isUndefined())
return op->emitError("pad parameter could not be processed");
// Get input dim.

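For context, a hedged sketch (not from this commit) of the kind of IR the Pad fix guards against: when the pads operand is not a compile-time constant, getIntFromArrayAsSymbol can yield an undefined IndexExpr, and computeShape now returns failure early instead of wrapping that undefined expression in a SymbolIndexExpr. The function name is made up, and the onnx.Pad operand list (data, pads, constant_value, axes) is assumed from recent opsets:

func.func @pad_dynamic_pads(%data: tensor<1x3xf32>, %pads: tensor<4xi64>) -> tensor<?x?xf32> {
  %cval = "onnx.Constant"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
  %none = "onnx.NoValue"() {value} : () -> none
  // %pads is a block argument, so no constant begin/end values can be read
  // during shape inference and the per-dimension IndexExpr stays undefined.
  %0 = "onnx.Pad"(%data, %pads, %cval, %none) {mode = "constant"} : (tensor<1x3xf32>, tensor<4xi64>, tensor<f32>, none) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}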