diff --git a/larq_compute_engine/mlir/BUILD b/larq_compute_engine/mlir/BUILD index d16e0ff9..d0c5c83f 100644 --- a/larq_compute_engine/mlir/BUILD +++ b/larq_compute_engine/mlir/BUILD @@ -287,6 +287,7 @@ cc_library( cc_library( name = "larq_compute_engine_prepare", srcs = [ + "transforms/common.h", "transforms/generated_prepare_target_arm.inc", "transforms/generated_prepare_target_other.inc", "transforms/prepare_tf.cc", @@ -310,6 +311,7 @@ cc_library( cc_library( name = "larq_compute_engine_optimize", srcs = [ + "transforms/common.h", "transforms/generated_bitpack_activations.inc", "transforms/generated_optimize_target_arm.inc", "transforms/generated_optimize_target_other.inc", diff --git a/larq_compute_engine/mlir/transforms/bitpack_weights.cc b/larq_compute_engine/mlir/transforms/bitpack_weights.cc index c4169438..8983bbe5 100644 --- a/larq_compute_engine/mlir/transforms/bitpack_weights.cc +++ b/larq_compute_engine/mlir/transforms/bitpack_weights.cc @@ -10,8 +10,6 @@ namespace mlir { namespace TFL { -namespace { - struct BitpackWeightsLCE : public PassWrapper> { llvm::StringRef getArgument() const final { @@ -30,18 +28,18 @@ bool IsConv2DFilter(TypedAttr filter) { filter_type.getShape().size() == 4; } +namespace bitpackweights { #include "larq_compute_engine/mlir/transforms/generated_bitpack_weights.inc" +} // namespace bitpackweights void BitpackWeightsLCE::runOnOperation() { RewritePatternSet patterns(&getContext()); auto func = getOperation(); - TFL::populateWithGenerated(patterns); + bitpackweights::populateWithGenerated(patterns); (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); } -} // namespace - // Creates an instance of the TensorFlow dialect BitpackWeights pass. std::unique_ptr> CreateBitpackWeightsLCEPass() { diff --git a/larq_compute_engine/mlir/transforms/common.h b/larq_compute_engine/mlir/transforms/common.h new file mode 100644 index 00000000..d192d9cb --- /dev/null +++ b/larq_compute_engine/mlir/transforms/common.h @@ -0,0 +1,19 @@ +#pragma once + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" + +namespace mlir { +namespace TFL { + +inline bool IsConstantValue(Attribute values, float expected_value) { + if (!values.isa()) return false; + + for (auto value : values.cast().getValues()) { + if (value != expected_value) return false; + } + return true; +} + +} // namespace TFL +} // namespace mlir diff --git a/larq_compute_engine/mlir/transforms/fuse_padding.cc b/larq_compute_engine/mlir/transforms/fuse_padding.cc index 15e1af1f..c6fced9f 100644 --- a/larq_compute_engine/mlir/transforms/fuse_padding.cc +++ b/larq_compute_engine/mlir/transforms/fuse_padding.cc @@ -6,8 +6,6 @@ namespace mlir { namespace TFL { -namespace { - bool NoBatchAndChannelPadding(Attribute paddings_attr) { auto paddings = GetValidPadAttr(paddings_attr); if (!paddings) return false; @@ -33,7 +31,9 @@ bool IsSamePaddingPartial(Attribute paddings_attr, Value input, Value output, output_shape[dimension], stride); } +namespace fuse_padding { #include "larq_compute_engine/mlir/transforms/generated_fuse_padding.inc" +} // Prepare LCE operations in functions for subsequent legalization. struct FusePadding @@ -49,7 +49,7 @@ struct FusePadding auto* ctx = &getContext(); RewritePatternSet patterns(ctx); auto func = getOperation(); - populateWithGenerated(patterns); + fuse_padding::populateWithGenerated(patterns); (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); } void getDependentDialects(DialectRegistry& registry) const override { @@ -57,8 +57,6 @@ struct FusePadding } }; -} // namespace - // Creates an instance of the TensorFlow dialect FusePadding pass. std::unique_ptr> CreateFusePaddingPass() { return std::make_unique(); diff --git a/larq_compute_engine/mlir/transforms/legalize_tflite.cc b/larq_compute_engine/mlir/transforms/legalize_tflite.cc index 82ac5c84..d5722d3d 100644 --- a/larq_compute_engine/mlir/transforms/legalize_tflite.cc +++ b/larq_compute_engine/mlir/transforms/legalize_tflite.cc @@ -7,8 +7,6 @@ namespace mlir { namespace TFL { -namespace { - struct LegalizeLCE : public PassWrapper> { llvm::StringRef getArgument() const final { return "tfl-legalize-lce"; } @@ -55,8 +53,6 @@ void LegalizeLCE::runOnOperation() { (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); } -} // namespace - // Creates an instance of the LegalizeLCE pass. std::unique_ptr> CreateLegalizeLCEPass() { return std::make_unique(); diff --git a/larq_compute_engine/mlir/transforms/op_removal.cc b/larq_compute_engine/mlir/transforms/op_removal.cc index bc90e7c2..93eb7438 100644 --- a/larq_compute_engine/mlir/transforms/op_removal.cc +++ b/larq_compute_engine/mlir/transforms/op_removal.cc @@ -8,8 +8,6 @@ namespace mlir { namespace TFL { -namespace { - // Op removal of pass through ops to make following patterns easier and enable // early constant folding struct OpRemoval @@ -21,18 +19,18 @@ struct OpRemoval void runOnOperation() override; }; +namespace op_removal { #include "larq_compute_engine/mlir/transforms/generated_op_removal.inc" +} // namespace op_removal void OpRemoval::runOnOperation() { RewritePatternSet patterns(&getContext()); auto func = getOperation(); - TFL::populateWithGenerated(patterns); + op_removal::populateWithGenerated(patterns); (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); } -} // namespace - // Creates an instance of the TensorFlow dialect OpRemoval pass. std::unique_ptr> CreateOpRemovalPass() { return std::make_unique(); diff --git a/larq_compute_engine/mlir/transforms/optimize.cc b/larq_compute_engine/mlir/transforms/optimize.cc index fa7cbda0..8b43a790 100644 --- a/larq_compute_engine/mlir/transforms/optimize.cc +++ b/larq_compute_engine/mlir/transforms/optimize.cc @@ -2,6 +2,7 @@ #include "larq_compute_engine/core/bitpacking/bitpack.h" #include "larq_compute_engine/mlir/ir/lce_ops.h" +#include "larq_compute_engine/mlir/transforms/common.h" #include "larq_compute_engine/mlir/transforms/passes.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" @@ -16,8 +17,6 @@ namespace mlir { namespace TFL { -namespace { - // Optimize LCE operations in functions. struct OptimizeLCE : public PassWrapper> { @@ -38,15 +37,6 @@ struct OptimizeLCE clEnumValN(LCETarget::XCORE, "xcore", "XCORE target"))}; }; -bool IsConstantValue(Attribute values, float expected_value) { - if (!values.isa()) return false; - - for (auto value : values.cast().getValues()) { - if (value != expected_value) return false; - } - return true; -} - /** * ================================================= * Computing thresholds for writing bitpacked output @@ -254,15 +244,15 @@ DenseElementsAttr GetBitpackedOutputThresholds( return DenseElementsAttr::get(type, thresholds); } -namespace target_arm { +namespace optimize_target_arm { #include "larq_compute_engine/mlir/transforms/generated_optimize_target_arm.inc" } -namespace target_other { +namespace optimize_target_other { #include "larq_compute_engine/mlir/transforms/generated_optimize_target_other.inc" } -namespace bitpack_activations { +namespace optimize_bitpack_activations { #include "larq_compute_engine/mlir/transforms/generated_bitpack_activations.inc" } @@ -271,17 +261,15 @@ void OptimizeLCE::runOnOperation() { auto func = getOperation(); if (target_ == LCETarget::ARM) { - target_arm::populateWithGenerated(patterns); + optimize_target_arm::populateWithGenerated(patterns); } else { - target_other::populateWithGenerated(patterns); + optimize_target_other::populateWithGenerated(patterns); } - bitpack_activations::populateWithGenerated(patterns); + optimize_bitpack_activations::populateWithGenerated(patterns); (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); } -} // namespace - // Creates an instance of the TensorFlow dialect OptimizeLCE pass. std::unique_ptr> CreateOptimizeLCEPass( const LCETarget target) { diff --git a/larq_compute_engine/mlir/transforms/prepare_tf.cc b/larq_compute_engine/mlir/transforms/prepare_tf.cc index 38645589..111fc6b2 100644 --- a/larq_compute_engine/mlir/transforms/prepare_tf.cc +++ b/larq_compute_engine/mlir/transforms/prepare_tf.cc @@ -1,5 +1,6 @@ #include "larq_compute_engine/core/types.h" #include "larq_compute_engine/mlir/ir/lce_ops.h" +#include "larq_compute_engine/mlir/transforms/common.h" #include "larq_compute_engine/mlir/transforms/padding.h" #include "larq_compute_engine/mlir/transforms/passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -14,8 +15,6 @@ namespace mlir { namespace TFL { -namespace { - using compute_engine::core::bitpacking_bitwidth; // Prepare LCE operations in functions for subsequent legalization. @@ -39,14 +38,6 @@ struct PrepareLCE clEnumValN(LCETarget::XCORE, "xcore", "XCORE target"))}; }; -bool IsConstantValue(Attribute values, float expected_value) { - if (!values.isa()) return false; - - for (auto value : values.cast().getValues()) { - if (value != expected_value) return false; - } - return true; -} DenseElementsAttr GetConstantVector(TypedAttr filter, float val) { auto filter_type = filter.getType().cast(); auto filter_shape = filter_type.getShape(); @@ -162,11 +153,11 @@ IntegerAttr GetNumChannels(Builder& b, Value output_val) { return b.getI32IntegerAttr(shape_vector[shape_vector.size() - 1]); } -namespace target_arm { +namespace prepare_target_arm { #include "larq_compute_engine/mlir/transforms/generated_prepare_target_arm.inc" } -namespace target_other { +namespace prepare_target_other { #include "larq_compute_engine/mlir/transforms/generated_prepare_target_other.inc" } @@ -181,16 +172,14 @@ void PrepareLCE::runOnOperation() { patterns.add>(ctx); if (target_ == LCETarget::ARM) { - target_arm::populateWithGenerated(patterns); + prepare_target_arm::populateWithGenerated(patterns); } else { - target_other::populateWithGenerated(patterns); + prepare_target_other::populateWithGenerated(patterns); } (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); } -} // namespace - // Creates an instance of the TensorFlow dialect PrepareLCE pass. std::unique_ptr> CreatePrepareLCEPass( const LCETarget target) { diff --git a/larq_compute_engine/mlir/transforms/quantize.cc b/larq_compute_engine/mlir/transforms/quantize.cc index 767459f1..bc970b74 100644 --- a/larq_compute_engine/mlir/transforms/quantize.cc +++ b/larq_compute_engine/mlir/transforms/quantize.cc @@ -12,8 +12,6 @@ namespace TFL { //===----------------------------------------------------------------------===// // The actual Quantize Pass. // -namespace { - // Applies quantization on the model in TFL dialect. struct LCEQuantizePass : public PassWrapper> { @@ -24,15 +22,16 @@ struct LCEQuantizePass void runOnOperation() override; }; +namespace lce_quantize { #include "larq_compute_engine/mlir/transforms/generated_quantize.inc" +} void LCEQuantizePass::runOnOperation() { RewritePatternSet patterns(&getContext()); auto func = getOperation(); - TFL::populateWithGenerated(patterns); + lce_quantize::populateWithGenerated(patterns); (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); } -} // namespace // Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass. std::unique_ptr> CreateLCEQuantizePass() { diff --git a/larq_compute_engine/mlir/transforms/set_batch_size.cc b/larq_compute_engine/mlir/transforms/set_batch_size.cc index 353370d6..7bb510c0 100644 --- a/larq_compute_engine/mlir/transforms/set_batch_size.cc +++ b/larq_compute_engine/mlir/transforms/set_batch_size.cc @@ -6,8 +6,6 @@ namespace mlir { -namespace { - mlir::Type SetBatchSize(mlir::Type type) { auto tensor_type = type.dyn_cast(); if (tensor_type && tensor_type.hasRank()) { @@ -59,8 +57,6 @@ struct SetBatchSizePass } }; -} // namespace - // Creates an instance of the ZeroPointCompatibility pass. std::unique_ptr> CreateSetBatchSizePass() { return std::make_unique(); diff --git a/larq_compute_engine/mlir/transforms/translate_tflite.cc b/larq_compute_engine/mlir/transforms/translate_tflite.cc index 3f2363a1..b96ccb97 100644 --- a/larq_compute_engine/mlir/transforms/translate_tflite.cc +++ b/larq_compute_engine/mlir/transforms/translate_tflite.cc @@ -22,8 +22,6 @@ static llvm::StringRef ConvertPaddingAttr(tflite::Padding padding_type) { namespace mlir { namespace TFL { -namespace { - struct TranslateToLCE : public PassWrapper> { llvm::StringRef getArgument() const final { return "lce-translate-tfl"; } @@ -90,8 +88,6 @@ void TranslateToLCE::runOnOperation() { (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); } -} // namespace - // Creates an instance of the TranslateToLCE pass. std::unique_ptr> CreateTranslateToLCEPass() { return std::make_unique();