From f940562622fbfac3d4b114656d0f7aa4d08f4f54 Mon Sep 17 00:00:00 2001
From: Tom Bannink
Date: Sat, 15 Jun 2024 18:02:41 +0200
Subject: [PATCH] Fix some MLIR build errors

Add the missing mlir/Bytecode/BytecodeOpInterface.h include in lce_ops.h,
drop the llvm/ADT/Optional.h include from optimize.cc, and remove the
positional (addBenefit ...) arguments from the TableGen DRR patterns.
---
 larq_compute_engine/mlir/ir/lce_ops.h               |  1 +
 .../transforms/bitpack_activations_patterns.td      |  2 +-
 .../mlir/transforms/fuse_padding.td                 | 12 ++++--------
 larq_compute_engine/mlir/transforms/optimize.cc     |  1 -
 .../mlir/transforms/optimize_patterns_common.td     | 17 +++++++----------
 .../mlir/transforms/prepare_patterns_common.td      | 14 ++++++--------
 6 files changed, 19 insertions(+), 28 deletions(-)

diff --git a/larq_compute_engine/mlir/ir/lce_ops.h b/larq_compute_engine/mlir/ir/lce_ops.h
index f19dd81b..26656e1b 100644
--- a/larq_compute_engine/mlir/ir/lce_ops.h
+++ b/larq_compute_engine/mlir/ir/lce_ops.h
@@ -3,6 +3,7 @@
 
 #include "mlir/Dialect/Quant/QuantTypes.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Bytecode/BytecodeOpInterface.h"
 
 // clang-format off
 #include "larq_compute_engine/mlir/ir/lce_dialect.h.inc"
diff --git a/larq_compute_engine/mlir/transforms/bitpack_activations_patterns.td b/larq_compute_engine/mlir/transforms/bitpack_activations_patterns.td
index c8dda3c2..f6beab39 100644
--- a/larq_compute_engine/mlir/transforms/bitpack_activations_patterns.td
+++ b/larq_compute_engine/mlir/transforms/bitpack_activations_patterns.td
@@ -55,6 +55,6 @@ class WriteBitpackedActivationsPat
-      [(HasOneUse $output)], (addBenefit 100)>;
+      [(HasOneUse $output)]>;
 
 def : WriteBitpackedActivationsPat;
 def : WriteBitpackedActivationsPat;
diff --git a/larq_compute_engine/mlir/transforms/fuse_padding.td b/larq_compute_engine/mlir/transforms/fuse_padding.td
index 0aab22ae..57d7be0a 100644
--- a/larq_compute_engine/mlir/transforms/fuse_padding.td
+++ b/larq_compute_engine/mlir/transforms/fuse_padding.td
@@ -43,8 +43,7 @@ def : Pat<(TFL_Conv2DOp:$conv_output
           [(HasOneUse $pad_output),
            (NoBatchAndChannelPadding $paddings),
            (SamePaddingHeight $paddings, $input, $conv_output, $stride_h),
-           (SamePaddingWidth $paddings, $input, $conv_output, $stride_w)],
-          (addBenefit 100)>;
+           (SamePaddingWidth $paddings, $input, $conv_output, $stride_w)]>;
 
 // PadV2 > Conv2D
 def : Pat<(TFL_Conv2DOp:$conv_output
@@ -74,8 +73,7 @@ def : Pat<(TFL_Conv2DOp:$conv_output
            (ConstFloatValueIs<"0.0"> $pad_values),
            (NoBatchAndChannelPadding $paddings),
            (SamePaddingHeight $paddings, $input, $conv_output, $stride_h),
-           (SamePaddingWidth $paddings, $input, $conv_output, $stride_w)],
-          (addBenefit 100)>;
+           (SamePaddingWidth $paddings, $input, $conv_output, $stride_w)]>;
 
 // Pad > DepthwiseConv2D
 def : Pat<(TFL_DepthwiseConv2DOp:$conv_output
@@ -104,8 +102,7 @@ def : Pat<(TFL_DepthwiseConv2DOp:$conv_output
           [(HasOneUse $pad_output),
            (NoBatchAndChannelPadding $paddings),
            (SamePaddingHeight $paddings, $input, $conv_output, $stride_h),
-           (SamePaddingWidth $paddings, $input, $conv_output, $stride_w)],
-          (addBenefit 100)>;
+           (SamePaddingWidth $paddings, $input, $conv_output, $stride_w)]>;
 
 // PadV2 > DepthwiseConv2D
 def : Pat<(TFL_DepthwiseConv2DOp:$conv_output
@@ -136,5 +133,4 @@ def : Pat<(TFL_DepthwiseConv2DOp:$conv_output
            (ConstFloatValueIs<"0.0"> $pad_values),
            (NoBatchAndChannelPadding $paddings),
            (SamePaddingHeight $paddings, $input, $conv_output, $stride_h),
-           (SamePaddingWidth $paddings, $input, $conv_output, $stride_w)],
-          (addBenefit 100)>;
+           (SamePaddingWidth $paddings, $input, $conv_output, $stride_w)]>;
diff --git a/larq_compute_engine/mlir/transforms/optimize.cc b/larq_compute_engine/mlir/transforms/optimize.cc
index 8b43a790..9646a3d0 100644
--- a/larq_compute_engine/mlir/transforms/optimize.cc
+++ b/larq_compute_engine/mlir/transforms/optimize.cc
@@ -4,7 +4,6 @@
 #include "larq_compute_engine/mlir/ir/lce_ops.h"
 #include "larq_compute_engine/mlir/transforms/common.h"
 #include "larq_compute_engine/mlir/transforms/passes.h"
-#include "llvm/ADT/Optional.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
diff --git a/larq_compute_engine/mlir/transforms/optimize_patterns_common.td b/larq_compute_engine/mlir/transforms/optimize_patterns_common.td
index 27c8de45..cdf6487e 100644
--- a/larq_compute_engine/mlir/transforms/optimize_patterns_common.td
+++ b/larq_compute_engine/mlir/transforms/optimize_patterns_common.td
@@ -18,8 +18,7 @@ def : Pat<(LQ_QuantizeOp
                   $input,
                   (Arith_ConstantOp ConstantValue<"0.0f">))),
           (LQ_QuantizeOp $input),
-          [(HasOneUse $ge_op)],
-          (addBenefit 150)>;
+          [(HasOneUse $ge_op)]>;
 
 def : Pat<(LQ_QuantizeOp
               (TFL_GreaterEqualOp:$ge_op
@@ -27,15 +26,13 @@ def : Pat<(LQ_QuantizeOp
                   $threshold)),
           (LQ_QuantizeOp
               (TFL_SubOp $input, $threshold, TFL_AF_None)),
-          [(HasOneUse $ge_op)],
-          (addBenefit 100)>;
+          [(HasOneUse $ge_op)]>;
 
 def : Pat<(LQ_QuantizeOp
               (TFL_LessEqualOp:$ge_op $lhs, $rhs)),
           (LQ_QuantizeOp
               (TFL_GreaterEqualOp $rhs, $lhs)),
-          [(HasOneUse $ge_op)],
-          (addBenefit 100)>;
+          [(HasOneUse $ge_op)]>;
 
 // TODO: Check shapes before fusing
 multiclass FuseAddOrSubWithBConv2D {
@@ -70,7 +67,7 @@ multiclass FuseAddOrSubWithBConv2D {
                 $padding,
                 $stride_height,
                 $stride_width),
-            [(HasOneUse $output)], (addBenefit 100)>;
+            [(HasOneUse $output)]>;
 }
 foreach binaryOp = [TFL_AddOp, TFL_SubOp] in
   defm : FuseAddOrSubWithBConv2D;
@@ -109,7 +106,7 @@ multiclass FuseMulOrDivWithBConv2D {
                 $padding,
                 $stride_height,
                 $stride_width),
-            [(HasOneUse $conv_output)], (addBenefit 100)>;
+            [(HasOneUse $conv_output)]>;
 }
 foreach binaryOp = [TFL_DivOp, TFL_MulOp] in
   defm : FuseMulOrDivWithBConv2D;
@@ -146,7 +143,7 @@ multiclass FuseActFnIntoConvOpPat {
                 $padding,
                 $stride_height,
                 $stride_width),
-            [(HasOneUse $conv_output)], (addBenefit 100)>;
+            [(HasOneUse $conv_output)]>;
   def : Pat<(ActFnOp
                 (LQ_Bconv2dOp:$conv_output
                     $input,
@@ -176,7 +173,7 @@
                 $padding,
                 $stride_height,
                 $stride_width),
-            [(HasOneUse $conv_output)], (addBenefit 100)>;
+            [(HasOneUse $conv_output)]>;
 }
 foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
                      [TFL_Relu1Op, TFL_AF_Relu1],
diff --git a/larq_compute_engine/mlir/transforms/prepare_patterns_common.td b/larq_compute_engine/mlir/transforms/prepare_patterns_common.td
index 430c0e49..3eed4cf6 100644
--- a/larq_compute_engine/mlir/transforms/prepare_patterns_common.td
+++ b/larq_compute_engine/mlir/transforms/prepare_patterns_common.td
@@ -37,7 +37,7 @@ multiclass QuantDequantPatterns {
                 $select_op,
                 $select_op,
                 /*use 32bit*/ConstBoolAttrFalse)))),
-      [], (addBenefit 100)>;
+      []>;
   def : Pat<(SelectOp:$select_op
                 $cond,
                 (Arith_ConstantOp ConstantValue<"-1.0f">),
@@ -51,7 +51,7 @@ multiclass QuantDequantPatterns {
                 $select_op,
                 $select_op,
                 /*use 32bit*/ConstBoolAttrFalse)))),
-      [], (addBenefit 100)>;
+      []>;
 }
 foreach SelectOp = [TF_SelectOp, TF_SelectV2Op] in
   defm : QuantDequantPatterns;
@@ -59,9 +59,9 @@ foreach SelectOp = [TF_SelectOp, TF_SelectV2Op] in
 // A fallback for the old version of `ste_sign` that uses a specific `tf.sign`
 // based implementation of `larq.math.sign`.
 def : Pat<(TF_SignOp (TF_AddV2Op (TF_SignOp $arg), $c)),
-          (LQ_DequantizeOp (LQ_QuantizeOp $arg)), [], (addBenefit 100)>;
+          (LQ_DequantizeOp (LQ_QuantizeOp $arg)), []>;
 def : Pat<(TF_SignOp (TF_AddV2Op $c, (TF_SignOp $arg))),
-          (LQ_DequantizeOp (LQ_QuantizeOp $arg)), [], (addBenefit 100)>;
+          (LQ_DequantizeOp (LQ_QuantizeOp $arg)), []>;
 
 // Copied from legalize_patterns.td
 class I32VectorElementsAttr : ElementsAttrBase<
@@ -123,8 +123,7 @@ class PrepareBConvPadValue0Pat :
           ExtractI32At<1>:$strides,
           ExtractI32At<2>:$strides),
       [(BinaryFilter $filter),
-       (ValidFilterShape $dequantized_input, $filter_op)],
-      (addBenefit 90)>;
+       (ValidFilterShape $dequantized_input, $filter_op)]>;
 def : PrepareBConvPadValue0Pat;
 
 def ConstFloatValueIsOne : Constraint<
@@ -166,5 +165,4 @@ def : Pat<(TF_Conv2DOp:$output
           [(BinaryFilter $filter),
            (ConstFloatValueIsOne $pad_values),
            (SamePadding $paddings, $input, $output, $strides),
-           (ValidFilterShape $dequantized_input, $filter_op)],
-          (addBenefit 90)>;
+           (ValidFilterShape $dequantized_input, $filter_op)]>;
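Note (illustration, not part of the commit): every .td change above has the same shape, dropping the trailing positional (addBenefit N) argument from a DRR Pat. A minimal sketch with hypothetical SourceOp and TargetOp ops follows. The commented-out variant assumes, without verifying against the MLIR revision targeted here, that newer DRR Pat declarations take a supplemental-results list argument before the benefit, so a benefit could alternatively be kept by passing an empty list in that slot.

    // Before: the benefit is passed as the fourth positional argument.
    def : Pat<(SourceOp:$src $x), (TargetOp $x),
              [(HasOneUse $src)], (addBenefit 100)>;

    // After, as done throughout this patch: drop the benefit and use the default.
    def : Pat<(SourceOp:$src $x), (TargetOp $x),
              [(HasOneUse $src)]>;

    // Assumed alternative (unverified): keep the benefit by filling the
    // supplemental-results slot with an empty list.
    // def : Pat<(SourceOp:$src $x), (TargetOp $x),
    //           [(HasOneUse $src)], [], (addBenefit 100)>;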