Remove anonymous namespaces in MLIR transforms #768

Merged 1 commit on Feb 20, 2023
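The same refactor is applied to every transform pass below: the file-wide anonymous namespace is removed, and each tblgen-generated .inc include is instead wrapped in a pass-specific named namespace, presumably so that the generated populateWithGenerated() functions keep distinct names across translation units once they lose internal linkage. A minimal before/after sketch (pass, namespace, and file names here are illustrative, not taken from this diff):

// Before: the generated patterns and the pass lived in one anonymous namespace.
namespace {
#include "transforms/generated_example.inc"  // defines populateWithGenerated()

void ExamplePass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  TFL::populateWithGenerated(patterns);
  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
}  // namespace

// After: no anonymous namespace; only the generated include gets its own namespace.
namespace example_pass {
#include "transforms/generated_example.inc"
}  // namespace example_pass

void ExamplePass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  example_pass::populateWithGenerated(patterns);
  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}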
2 changes: 2 additions & 0 deletions larq_compute_engine/mlir/BUILD
@@ -287,6 +287,7 @@ cc_library(
cc_library(
name = "larq_compute_engine_prepare",
srcs = [
"transforms/common.h",
"transforms/generated_prepare_target_arm.inc",
"transforms/generated_prepare_target_other.inc",
"transforms/prepare_tf.cc",
@@ -310,6 +311,7 @@ cc_library(
cc_library(
name = "larq_compute_engine_optimize",
srcs = [
"transforms/common.h",
"transforms/generated_bitpack_activations.inc",
"transforms/generated_optimize_target_arm.inc",
"transforms/generated_optimize_target_other.inc",
8 changes: 3 additions & 5 deletions larq_compute_engine/mlir/transforms/bitpack_weights.cc
@@ -10,8 +10,6 @@
namespace mlir {
namespace TFL {

namespace {

struct BitpackWeightsLCE
: public PassWrapper<BitpackWeightsLCE, OperationPass<mlir::func::FuncOp>> {
llvm::StringRef getArgument() const final {
@@ -30,18 +28,18 @@ bool IsConv2DFilter(TypedAttr filter) {
filter_type.getShape().size() == 4;
}

namespace bitpackweights {
#include "larq_compute_engine/mlir/transforms/generated_bitpack_weights.inc"
} // namespace bitpackweights

void BitpackWeightsLCE::runOnOperation() {
RewritePatternSet patterns(&getContext());
auto func = getOperation();

TFL::populateWithGenerated(patterns);
bitpackweights::populateWithGenerated(patterns);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

} // namespace

// Creates an instance of the TensorFlow dialect BitpackWeights pass.
std::unique_ptr<OperationPass<mlir::func::FuncOp>>
CreateBitpackWeightsLCEPass() {
19 changes: 19 additions & 0 deletions larq_compute_engine/mlir/transforms/common.h
@@ -0,0 +1,19 @@
#pragma once

#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"

namespace mlir {
namespace TFL {

inline bool IsConstantValue(Attribute values, float expected_value) {
if (!values.isa<DenseElementsAttr>()) return false;

for (auto value : values.cast<DenseElementsAttr>().getValues<float>()) {
if (value != expected_value) return false;
}
return true;
}

} // namespace TFL
} // namespace mlir
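The new transforms/common.h hoists the IsConstantValue helper that optimize.cc and prepare_tf.cc each carried a private copy of (both copies are deleted below). A short usage sketch, assuming the usual MLIR headers (mlir/IR/Builders.h, mlir/IR/BuiltinTypes.h) and an illustrative 2x2 splat attribute:

mlir::MLIRContext context;
mlir::Builder builder(&context);
auto type = mlir::RankedTensorType::get({2, 2}, builder.getF32Type());
auto ones = mlir::DenseElementsAttr::get(type, llvm::ArrayRef<float>{1.0f, 1.0f, 1.0f, 1.0f});
mlir::TFL::IsConstantValue(ones, 1.0f);                          // true: every element equals 1.0f
mlir::TFL::IsConstantValue(ones, 0.0f);                          // false: the elements are not 0.0f
mlir::TFL::IsConstantValue(builder.getI32IntegerAttr(1), 1.0f);  // false: not a DenseElementsAttr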
8 changes: 3 additions & 5 deletions larq_compute_engine/mlir/transforms/fuse_padding.cc
@@ -6,8 +6,6 @@
namespace mlir {
namespace TFL {

namespace {

bool NoBatchAndChannelPadding(Attribute paddings_attr) {
auto paddings = GetValidPadAttr(paddings_attr);
if (!paddings) return false;
@@ -33,7 +31,9 @@ bool IsSamePaddingPartial(Attribute paddings_attr, Value input, Value output,
output_shape[dimension], stride);
}

namespace fuse_padding {
#include "larq_compute_engine/mlir/transforms/generated_fuse_padding.inc"
}

// Prepare LCE operations in functions for subsequent legalization.
struct FusePadding
@@ -49,16 +49,14 @@ struct FusePadding
auto* ctx = &getContext();
RewritePatternSet patterns(ctx);
auto func = getOperation();
populateWithGenerated(patterns);
fuse_padding::populateWithGenerated(patterns);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<::mlir::TFL::TensorFlowLiteDialect>();
}
};

} // namespace

// Creates an instance of the TensorFlow dialect FusePadding pass.
std::unique_ptr<OperationPass<mlir::func::FuncOp>> CreateFusePaddingPass() {
return std::make_unique<FusePadding>();
4 changes: 0 additions & 4 deletions larq_compute_engine/mlir/transforms/legalize_tflite.cc
@@ -7,8 +7,6 @@
namespace mlir {
namespace TFL {

namespace {

struct LegalizeLCE
: public PassWrapper<LegalizeLCE, OperationPass<mlir::func::FuncOp>> {
llvm::StringRef getArgument() const final { return "tfl-legalize-lce"; }
@@ -55,8 +53,6 @@ void LegalizeLCE::runOnOperation() {
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

} // namespace

// Creates an instance of the LegalizeLCE pass.
std::unique_ptr<OperationPass<mlir::func::FuncOp>> CreateLegalizeLCEPass() {
return std::make_unique<LegalizeLCE>();
8 changes: 3 additions & 5 deletions larq_compute_engine/mlir/transforms/op_removal.cc
@@ -8,8 +8,6 @@
namespace mlir {
namespace TFL {

namespace {

// Op removal of pass through ops to make following patterns easier and enable
// early constant folding
struct OpRemoval
@@ -21,18 +19,18 @@ struct OpRemoval
void runOnOperation() override;
};

namespace op_removal {
#include "larq_compute_engine/mlir/transforms/generated_op_removal.inc"
} // namespace op_removal

void OpRemoval::runOnOperation() {
RewritePatternSet patterns(&getContext());
auto func = getOperation();

TFL::populateWithGenerated(patterns);
op_removal::populateWithGenerated(patterns);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

} // namespace

// Creates an instance of the TensorFlow dialect OpRemoval pass.
std::unique_ptr<OperationPass<mlir::func::FuncOp>> CreateOpRemovalPass() {
return std::make_unique<OpRemoval>();
26 changes: 7 additions & 19 deletions larq_compute_engine/mlir/transforms/optimize.cc
@@ -2,6 +2,7 @@

#include "larq_compute_engine/core/bitpacking/bitpack.h"
#include "larq_compute_engine/mlir/ir/lce_ops.h"
#include "larq_compute_engine/mlir/transforms/common.h"
#include "larq_compute_engine/mlir/transforms/passes.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
@@ -16,8 +17,6 @@
namespace mlir {
namespace TFL {

namespace {

// Optimize LCE operations in functions.
struct OptimizeLCE
: public PassWrapper<OptimizeLCE, OperationPass<mlir::func::FuncOp>> {
@@ -38,15 +37,6 @@ struct OptimizeLCE
clEnumValN(LCETarget::XCORE, "xcore", "XCORE target"))};
};

bool IsConstantValue(Attribute values, float expected_value) {
if (!values.isa<DenseElementsAttr>()) return false;

for (auto value : values.cast<DenseElementsAttr>().getValues<float>()) {
if (value != expected_value) return false;
}
return true;
}

/**
* =================================================
* Computing thresholds for writing bitpacked output
@@ -254,15 +244,15 @@ DenseElementsAttr GetBitpackedOutputThresholds(
return DenseElementsAttr::get(type, thresholds);
}

namespace target_arm {
namespace optimize_target_arm {
#include "larq_compute_engine/mlir/transforms/generated_optimize_target_arm.inc"
}

namespace target_other {
namespace optimize_target_other {
#include "larq_compute_engine/mlir/transforms/generated_optimize_target_other.inc"
}

namespace bitpack_activations {
namespace optimize_bitpack_activations {
#include "larq_compute_engine/mlir/transforms/generated_bitpack_activations.inc"
}

@@ -271,17 +261,15 @@ void OptimizeLCE::runOnOperation() {
auto func = getOperation();

if (target_ == LCETarget::ARM) {
target_arm::populateWithGenerated(patterns);
optimize_target_arm::populateWithGenerated(patterns);
} else {
target_other::populateWithGenerated(patterns);
optimize_target_other::populateWithGenerated(patterns);
}
bitpack_activations::populateWithGenerated(patterns);
optimize_bitpack_activations::populateWithGenerated(patterns);

(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

} // namespace

// Creates an instance of the TensorFlow dialect OptimizeLCE pass.
std::unique_ptr<OperationPass<mlir::func::FuncOp>> CreateOptimizeLCEPass(
const LCETarget target) {
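Note that the inner namespaces are also renamed (target_arm becomes optimize_target_arm, and so on): prepare_tf.cc below wraps its own generated includes in prepare_target_arm and prepare_target_other. Presumably the per-pass prefixes keep the now externally visible populateWithGenerated() definitions of the two passes distinct when they are linked into one binary. A toy, self-contained illustration with simplified signatures (names only, not the real pattern populators):

#include <cstdio>

// Two pattern populators with external linkage can coexist because their
// namespaces differ; identical fully qualified names would clash at link time.
namespace optimize_target_arm {
void populateWithGenerated() { std::puts("optimize patterns"); }
}  // namespace optimize_target_arm

namespace prepare_target_arm {
void populateWithGenerated() { std::puts("prepare patterns"); }
}  // namespace prepare_target_arm

int main() {
  optimize_target_arm::populateWithGenerated();
  prepare_target_arm::populateWithGenerated();
}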
21 changes: 5 additions & 16 deletions larq_compute_engine/mlir/transforms/prepare_tf.cc
@@ -1,5 +1,6 @@
#include "larq_compute_engine/core/types.h"
#include "larq_compute_engine/mlir/ir/lce_ops.h"
#include "larq_compute_engine/mlir/transforms/common.h"
#include "larq_compute_engine/mlir/transforms/padding.h"
#include "larq_compute_engine/mlir/transforms/passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -14,8 +15,6 @@
namespace mlir {
namespace TFL {

namespace {

using compute_engine::core::bitpacking_bitwidth;

// Prepare LCE operations in functions for subsequent legalization.
@@ -39,14 +38,6 @@ struct PrepareLCE
clEnumValN(LCETarget::XCORE, "xcore", "XCORE target"))};
};

bool IsConstantValue(Attribute values, float expected_value) {
if (!values.isa<DenseElementsAttr>()) return false;

for (auto value : values.cast<DenseElementsAttr>().getValues<float>()) {
if (value != expected_value) return false;
}
return true;
}
DenseElementsAttr GetConstantVector(TypedAttr filter, float val) {
auto filter_type = filter.getType().cast<ShapedType>();
auto filter_shape = filter_type.getShape();
@@ -162,11 +153,11 @@ IntegerAttr GetNumChannels(Builder& b, Value output_val) {
return b.getI32IntegerAttr(shape_vector[shape_vector.size() - 1]);
}

namespace target_arm {
namespace prepare_target_arm {
#include "larq_compute_engine/mlir/transforms/generated_prepare_target_arm.inc"
}

namespace target_other {
namespace prepare_target_other {
#include "larq_compute_engine/mlir/transforms/generated_prepare_target_other.inc"
}

@@ -181,16 +172,14 @@ void PrepareLCE::runOnOperation() {
patterns.add<ConvertTFDilatedConvOp<TF::Conv2DOp>>(ctx);

if (target_ == LCETarget::ARM) {
target_arm::populateWithGenerated(patterns);
prepare_target_arm::populateWithGenerated(patterns);
} else {
target_other::populateWithGenerated(patterns);
prepare_target_other::populateWithGenerated(patterns);
}

(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

} // namespace

// Creates an instance of the TensorFlow dialect PrepareLCE pass.
std::unique_ptr<OperationPass<mlir::func::FuncOp>> CreatePrepareLCEPass(
const LCETarget target) {
7 changes: 3 additions & 4 deletions larq_compute_engine/mlir/transforms/quantize.cc
@@ -12,8 +12,6 @@ namespace TFL {
//===----------------------------------------------------------------------===//
// The actual Quantize Pass.
//
namespace {

// Applies quantization on the model in TFL dialect.
struct LCEQuantizePass
: public PassWrapper<LCEQuantizePass, OperationPass<mlir::func::FuncOp>> {
@@ -24,15 +22,16 @@ struct LCEQuantizePass
void runOnOperation() override;
};

namespace lce_quantize {
#include "larq_compute_engine/mlir/transforms/generated_quantize.inc"
}

void LCEQuantizePass::runOnOperation() {
RewritePatternSet patterns(&getContext());
auto func = getOperation();
TFL::populateWithGenerated(patterns);
lce_quantize::populateWithGenerated(patterns);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}
} // namespace

// Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass.
std::unique_ptr<OperationPass<mlir::func::FuncOp>> CreateLCEQuantizePass() {
4 changes: 0 additions & 4 deletions larq_compute_engine/mlir/transforms/set_batch_size.cc
@@ -6,8 +6,6 @@

namespace mlir {

namespace {

mlir::Type SetBatchSize(mlir::Type type) {
auto tensor_type = type.dyn_cast<mlir::TensorType>();
if (tensor_type && tensor_type.hasRank()) {
@@ -59,8 +57,6 @@ struct SetBatchSizePass
}
};

} // namespace

// Creates an instance of the ZeroPointCompatibility pass.
std::unique_ptr<OperationPass<mlir::func::FuncOp>> CreateSetBatchSizePass() {
return std::make_unique<SetBatchSizePass>();
4 changes: 0 additions & 4 deletions larq_compute_engine/mlir/transforms/translate_tflite.cc
@@ -22,8 +22,6 @@ static llvm::StringRef ConvertPaddingAttr(tflite::Padding padding_type) {
namespace mlir {
namespace TFL {

namespace {

struct TranslateToLCE
: public PassWrapper<TranslateToLCE, OperationPass<mlir::func::FuncOp>> {
llvm::StringRef getArgument() const final { return "lce-translate-tfl"; }
@@ -90,8 +88,6 @@ void TranslateToLCE::runOnOperation() {
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

} // namespace

// Creates an instance of the TranslateToLCE pass.
std::unique_ptr<OperationPass<mlir::func::FuncOp>> CreateTranslateToLCEPass() {
return std::make_unique<TranslateToLCE>();