Fix some MLIR build errors
Tombana committed Jun 15, 2024
1 parent 68e7076 commit f940562
Showing 6 changed files with 19 additions and 28 deletions.
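Most of the diff is mechanical: every TableGen rewrite pattern drops its trailing (addBenefit N) argument, presumably because the positional signature of the DRR Pat/Pattern classes changed in newer MLIR. The remaining two changes adjust includes: lce_ops.h gains mlir/Bytecode/BytecodeOpInterface.h and optimize.cc drops llvm/ADT/Optional.h. A minimal sketch of the recurring before/after shape (SomeOp and OtherOp are placeholder names, not ops from this repository):

// Old form: the pattern benefit was passed as a trailing positional argument.
def : Pat<(SomeOp $x), (OtherOp $x),
          [(HasOneUse $x)], (addBenefit 100)>;

// New form used throughout this commit: only the constraint list remains,
// so the pattern keeps the default benefit.
def : Pat<(SomeOp $x), (OtherOp $x),
          [(HasOneUse $x)]>;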
1 change: 1 addition & 0 deletions larq_compute_engine/mlir/ir/lce_ops.h
@@ -3,6 +3,7 @@

#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Bytecode/BytecodeOpInterface.h"

// clang-format off
#include "larq_compute_engine/mlir/ir/lce_dialect.h.inc"
@@ -55,6 +55,6 @@ class WriteBitpackedActivationsPat<ConstantStrAttr padding_type, string pad_valu
padding_type,
$stride_height,
$stride_width),
-[(HasOneUse $output)], (addBenefit 10)>;
+[(HasOneUse $output)]>;
def : WriteBitpackedActivationsPat<TFL_PAD_Valid, "0">;
def : WriteBitpackedActivationsPat<TFL_PAD_Same, "1">;
12 changes: 4 additions & 8 deletions larq_compute_engine/mlir/transforms/fuse_padding.td
@@ -43,8 +43,7 @@ def : Pat<(TFL_Conv2DOp:$conv_output
[(HasOneUse $pad_output),
(NoBatchAndChannelPadding $paddings),
(SamePaddingHeight $paddings, $input, $conv_output, $stride_h),
-(SamePaddingWidth $paddings, $input, $conv_output, $stride_w)],
-(addBenefit 100)>;
+(SamePaddingWidth $paddings, $input, $conv_output, $stride_w)]>;


// PadV2 > Conv2D
@@ -74,8 +73,7 @@ def : Pat<(TFL_Conv2DOp:$conv_output
(ConstFloatValueIs<"0.0"> $pad_values),
(NoBatchAndChannelPadding $paddings),
(SamePaddingHeight $paddings, $input, $conv_output, $stride_h),
-(SamePaddingWidth $paddings, $input, $conv_output, $stride_w)],
-(addBenefit 100)>;
+(SamePaddingWidth $paddings, $input, $conv_output, $stride_w)]>;

// Pad > DepthwiseConv2D
def : Pat<(TFL_DepthwiseConv2DOp:$conv_output
@@ -104,8 +102,7 @@ def : Pat<(TFL_DepthwiseConv2DOp:$conv_output
[(HasOneUse $pad_output),
(NoBatchAndChannelPadding $paddings),
(SamePaddingHeight $paddings, $input, $conv_output, $stride_h),
-(SamePaddingWidth $paddings, $input, $conv_output, $stride_w)],
-(addBenefit 100)>;
+(SamePaddingWidth $paddings, $input, $conv_output, $stride_w)]>;

// PadV2 > DepthwiseConv2D
def : Pat<(TFL_DepthwiseConv2DOp:$conv_output
@@ -136,5 +133,4 @@ def : Pat<(TFL_DepthwiseConv2DOp:$conv_output
(ConstFloatValueIs<"0.0"> $pad_values),
(NoBatchAndChannelPadding $paddings),
(SamePaddingHeight $paddings, $input, $conv_output, $stride_h),
-(SamePaddingWidth $paddings, $input, $conv_output, $stride_w)],
-(addBenefit 100)>;
+(SamePaddingWidth $paddings, $input, $conv_output, $stride_w)]>;
1 change: 0 additions & 1 deletion larq_compute_engine/mlir/transforms/optimize.cc
@@ -4,7 +4,6 @@
#include "larq_compute_engine/mlir/ir/lce_ops.h"
#include "larq_compute_engine/mlir/transforms/common.h"
#include "larq_compute_engine/mlir/transforms/passes.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
17 changes: 7 additions & 10 deletions larq_compute_engine/mlir/transforms/optimize_patterns_common.td
@@ -18,24 +18,21 @@ def : Pat<(LQ_QuantizeOp
$input,
(Arith_ConstantOp ConstantValue<"0.0f">))),
(LQ_QuantizeOp $input),
-[(HasOneUse $ge_op)],
-(addBenefit 150)>;
+[(HasOneUse $ge_op)]>;

def : Pat<(LQ_QuantizeOp
(TFL_GreaterEqualOp:$ge_op
$input,
$threshold)),
(LQ_QuantizeOp
(TFL_SubOp $input, $threshold, TFL_AF_None)),
-[(HasOneUse $ge_op)],
-(addBenefit 100)>;
+[(HasOneUse $ge_op)]>;

def : Pat<(LQ_QuantizeOp
(TFL_LessEqualOp:$ge_op $lhs, $rhs)),
(LQ_QuantizeOp
(TFL_GreaterEqualOp $rhs, $lhs)),
-[(HasOneUse $ge_op)],
-(addBenefit 100)>;
+[(HasOneUse $ge_op)]>;

// TODO: Check shapes before fusing
multiclass FuseAddOrSubWithBConv2D<Op binaryOp> {
@@ -70,7 +67,7 @@ multiclass FuseAddOrSubWithBConv2D<Op binaryOp> {
$padding,
$stride_height,
$stride_width),
-[(HasOneUse $output)], (addBenefit 100)>;
+[(HasOneUse $output)]>;
}
foreach binaryOp = [TFL_AddOp, TFL_SubOp] in
defm : FuseAddOrSubWithBConv2D<binaryOp>;
@@ -109,7 +106,7 @@ multiclass FuseMulOrDivWithBConv2D<Op binaryOp> {
$padding,
$stride_height,
$stride_width),
-[(HasOneUse $conv_output)], (addBenefit 100)>;
+[(HasOneUse $conv_output)]>;
}
foreach binaryOp = [TFL_DivOp, TFL_MulOp] in
defm : FuseMulOrDivWithBConv2D<binaryOp>;
@@ -146,7 +143,7 @@ multiclass FuseActFnIntoConvOpPat<Op ActFnOp, ConstantStrAttr ActFnAttr> {
$padding,
$stride_height,
$stride_width),
-[(HasOneUse $conv_output)], (addBenefit 100)>;
+[(HasOneUse $conv_output)]>;
def : Pat<(ActFnOp
(LQ_Bconv2dOp:$conv_output
$input,
@@ -176,7 +173,7 @@ multiclass FuseActFnIntoConvOpPat<Op ActFnOp, ConstantStrAttr ActFnAttr> {
$padding,
$stride_height,
$stride_width),
-[(HasOneUse $conv_output)], (addBenefit 100)>;
+[(HasOneUse $conv_output)]>;
}
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
[TFL_Relu1Op, TFL_AF_Relu1],
14 changes: 6 additions & 8 deletions larq_compute_engine/mlir/transforms/prepare_patterns_common.td
@@ -37,7 +37,7 @@ multiclass QuantDequantPatterns<Op SelectOp> {
$select_op,
$select_op,
/*use 32bit*/ConstBoolAttrFalse)))),
-[], (addBenefit 100)>;
+[]>;
def : Pat<(SelectOp:$select_op
$cond,
(Arith_ConstantOp ConstantValue<"-1.0f">),
@@ -51,17 +51,17 @@
$select_op,
$select_op,
/*use 32bit*/ConstBoolAttrFalse)))),
-[], (addBenefit 100)>;
+[]>;
}
foreach SelectOp = [TF_SelectOp, TF_SelectV2Op] in
defm : QuantDequantPatterns<SelectOp>;

// A fallback for the old version of `ste_sign` that uses a specific `tf.sign`
// based implementation of `larq.math.sign`.
def : Pat<(TF_SignOp (TF_AddV2Op (TF_SignOp $arg), $c)),
-(LQ_DequantizeOp (LQ_QuantizeOp $arg)), [], (addBenefit 100)>;
+(LQ_DequantizeOp (LQ_QuantizeOp $arg)), []>;
def : Pat<(TF_SignOp (TF_AddV2Op $c, (TF_SignOp $arg))),
-(LQ_DequantizeOp (LQ_QuantizeOp $arg)), [], (addBenefit 100)>;
+(LQ_DequantizeOp (LQ_QuantizeOp $arg)), []>;

// Copied from legalize_patterns.td
class I32VectorElementsAttr<int len> : ElementsAttrBase<
@@ -123,8 +123,7 @@ class PrepareBConvPadValue0Pat<ConstantStrAttr padding_type> :
ExtractI32At<1>:$strides,
ExtractI32At<2>:$strides),
[(BinaryFilter $filter),
-(ValidFilterShape $dequantized_input, $filter_op)],
-(addBenefit 90)>;
+(ValidFilterShape $dequantized_input, $filter_op)]>;
def : PrepareBConvPadValue0Pat<TFL_PAD_Valid>;

def ConstFloatValueIsOne : Constraint<
@@ -166,5 +165,4 @@ def : Pat<(TF_Conv2DOp:$output
[(BinaryFilter $filter),
(ConstFloatValueIsOne $pad_values),
(SamePadding $paddings, $input, $output, $strides),
-(ValidFilterShape $dequantized_input, $filter_op)],
-(addBenefit 90)>;
+(ValidFilterShape $dequantized_input, $filter_op)]>;
