Add preliminary support for the xcore target.
AdamHillier committed Jan 26, 2021
1 parent 95fd0b0 · commit 5b5423e
Showing 13 changed files with 189 additions and 47 deletions.
34 changes: 34 additions & 0 deletions larq_compute_engine/mlir/BUILD
@@ -51,6 +51,23 @@ gentbl(
],
)

gentbl(
name = "prepare_lce_target_arm_inc_gen",
tbl_outs = [
("-gen-rewriters", "transforms/target_arm/generated_prepare.inc"),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "transforms/target_arm/prepare_patterns.td",
td_srcs = [
"ir/lce_ops.td",
"transforms/op_removal_patterns.td",
"transforms/prepare_patterns.td",
"@llvm-project//mlir:StdOpsTdFiles",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",
],
)

gentbl(
name = "optimize_lce_inc_gen",
tbl_outs = [
@@ -65,6 +82,21 @@ gentbl(
],
)

gentbl(
name = "optimize_lce_target_arm_inc_gen",
tbl_outs = [
("-gen-rewriters", "transforms/target_arm/generated_optimize.inc"),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "transforms/target_arm/optimize_patterns.td",
td_srcs = [
"ir/lce_ops.td",
"transforms/optimize_patterns.td",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",
"@llvm-project//mlir:StdOpsTdFiles",
],
)

gentbl(
name = "bitpack_weights_lce_inc_gen",
tbl_outs = [
@@ -138,6 +170,7 @@ cc_library(
srcs = [
"transforms/generated_prepare.inc",
"transforms/prepare_tf.cc",
"transforms/target_arm/generated_prepare.inc",
],
hdrs = [
"transforms/passes.h",
@@ -159,6 +192,7 @@ cc_library(
srcs = [
"transforms/generated_optimize.inc",
"transforms/optimize.cc",
"transforms/target_arm/generated_optimize.inc",
],
hdrs = [
"transforms/passes.h",
5 changes: 5 additions & 0 deletions larq_compute_engine/mlir/python/converter.py
@@ -55,6 +55,7 @@ def convert_keras_model(
*, # Require remaining arguments to be keyword-only.
inference_input_type: tf.DType = tf.float32,
inference_output_type: tf.DType = tf.float32,
target: str = "arm",
experimental_default_int8_range: Optional[Tuple[float, float]] = None,
experimental_enable_bitpacked_activations: bool = False,
) -> bytes:
@@ -73,6 +74,7 @@
must be either `tf.float32` or `tf.int8`.
inference_output_type: Data type of the output layer. Defaults to `tf.float32`,
must be either `tf.float32` or `tf.int8`.
target: Target hardware platform. Defaults to "arm", must be either "arm"
or "xcore".
experimental_default_int8_range: Tuple of floats representing `(min, max)`
range values for all arrays without a specified range. Intended for
experimenting with quantization via "dummy quantization". (default None)
@@ -98,6 +100,8 @@
"Expected `inference_output_type` to be either `tf.float32` or `tf.int8`, "
f"got {inference_output_type}."
)
if target not in ("arm", "xcore"):
raise ValueError(f'Expected `target` to be "arm" or "xcore", but got {target}.')

if not tf.executing_eagerly():
raise RuntimeError(
@@ -147,6 +151,7 @@
[tensor.shape.as_list() for tensor in input_tensors],
[get_tensor_name(tensor) for tensor in output_tensors],
should_quantize,
target,
experimental_default_int8_range,
experimental_enable_bitpacked_activations,
)
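
For context, here is a minimal usage sketch of the new `target` argument (not part of the commit; it assumes `larq-zoo` is installed and imports `convert_keras_model` from the module path of the file above):

import larq_zoo as lqz

from larq_compute_engine.mlir.python.converter import convert_keras_model

# Build an untrained QuickNet, as in converter_test.py below.
model = lqz.sota.QuickNet(weights=None)

# `target` defaults to "arm"; "xcore" selects the xcore-specific rewrite
# patterns added by this commit. Any other value raises a ValueError.
arm_flatbuffer = convert_keras_model(model, target="arm")
xcore_flatbuffer = convert_keras_model(model, target="xcore")

# The converter returns a TFLite flatbuffer as `bytes`.
with open("quicknet_xcore.tflite", "wb") as f:
    f.write(xcore_flatbuffer)
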
15 changes: 15 additions & 0 deletions larq_compute_engine/mlir/python/converter_test.py
@@ -31,6 +31,7 @@ def test_larq_zoo_models(self):
[[1, 224, 224, 3]],
["Identity"],
False,
"arm",
None,
False,
)
@@ -39,6 +40,20 @@ def test_wrong_arg(self):
with self.assertRaises(ValueError):
convert_keras_model("./model.h5")

def test_target_arg(self):
with context.eager_mode():
model = lqz.sota.QuickNet(weights=None)

# These should work
convert_keras_model(model, target="arm")
convert_keras_model(model, target="xcore")

# Anything else shouldn't
with self.assertRaises(
ValueError, msg='Expected `target` to be "arm" or "xcore"'
):
convert_keras_model(model, target="x86")


if __name__ == "__main__":
unittest.main()
13 changes: 11 additions & 2 deletions larq_compute_engine/mlir/python/graphdef_tfl_flatbuffer.cc
@@ -41,14 +41,23 @@ pybind11::bytes ConvertGraphDefToTFLiteFlatBuffer(
const std::vector<string>& input_dtypes,
const std::vector<std::vector<int>>& input_shapes,
const std::vector<string>& output_arrays, const bool should_quantize,
const pybind11::object& default_ranges,
const std::string& target_str, const pybind11::object& default_ranges,
const bool experimental_enable_bitpacked_activations) {
GraphDef graphdef;
if (!tensorflow::LoadProtoFromBuffer(std::string(graphdef_bytes), &graphdef)
.ok()) {
throw std::runtime_error("Could not load GraphDef.");
}

LCETarget target;
if (target_str == "arm") {
target = LCETarget::ARM;
} else if (target_str == "xcore") {
target = LCETarget::XCORE;
} else {
throw std::runtime_error("Invalid target.");
}

GraphImportConfig specs;
specs.prune_unused_nodes = true;
specs.convert_legacy_fed_inputs = true;
@@ -88,7 +97,7 @@
}
mlir::PassManager pm(&context);
tensorflow::AddTFToLCETFLConversionPasses(
quant_specs, &pm, experimental_enable_bitpacked_activations);
quant_specs, &pm, target, experimental_enable_bitpacked_activations);

// Convert back to outlined while format for export back to flatbuffer.
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
11 changes: 5 additions & 6 deletions larq_compute_engine/mlir/tf_tfl_passes.cc
@@ -1,6 +1,5 @@
#include "larq_compute_engine/mlir/tf_tfl_passes.h"

#include "larq_compute_engine/mlir/transforms/passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
@@ -43,8 +42,8 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,

void AddTFToLCETFLConversionPasses(
const mlir::TFL::QuantizationSpecs& quant_specs,
mlir::OpPassManager* pass_manager,
bool experimental_enable_bitpacked_activations) {
mlir::OpPassManager* pass_manager, const LCETarget target,
const bool experimental_enable_bitpacked_activations) {
mlir::TF::StandardPipelineOptions standard_pipeline_options;
standard_pipeline_options.enable_inliner = false;
standard_pipeline_options.form_clusters = false;
@@ -126,7 +125,7 @@ void AddTFToLCETFLConversionPasses(
mlir::TF::CreateLayoutOptimizationPipeline(*pass_manager,
layout_optimization_options);
// Inject Larq Compute Engine Ops
pass_manager->addPass(mlir::TFL::CreatePrepareLCEPass());
pass_manager->addPass(mlir::TFL::CreatePrepareLCEPass(target));
// Prepare for TFLite dialect, rerun canonicalization, and then legalize to
// the TFLite dialect.
pass_manager->addPass(mlir::TFL::CreatePrepareTFPass(true));
@@ -144,10 +143,10 @@
pass_manager->addPass(mlir::TF::CreateInitTextFileToImportPass());

pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass(true));
pass_manager->addPass(mlir::TFL::CreateOptimizeLCEPass(false));
pass_manager->addPass(mlir::TFL::CreateOptimizeLCEPass(target, false));
pass_manager->addPass(mlir::TFL::CreateOptimizePass());
pass_manager->addPass(mlir::TFL::CreateOptimizeLCEPass(
experimental_enable_bitpacked_activations));
target, experimental_enable_bitpacked_activations));
pass_manager->addPass(mlir::TFL::CreateBitpackWeightsLCEPass());
// This pass operates on TensorFlow ops but is triggered after legalization
// so that it can target constants introduced once TensorFlow Identity ops
5 changes: 3 additions & 2 deletions larq_compute_engine/mlir/tf_tfl_passes.h
@@ -3,6 +3,7 @@

#include <functional>

#include "larq_compute_engine/mlir/transforms/passes.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"

@@ -11,8 +12,8 @@ namespace tensorflow {
// Add the TF to TFLite passes into a pass_manager.
void AddTFToLCETFLConversionPasses(
const mlir::TFL::QuantizationSpecs& quant_specs,
mlir::OpPassManager* pass_manager,
bool experimental_enable_bitpacked_activations = false);
mlir::OpPassManager* pass_manager, const LCETarget target = LCETarget::ARM,
const bool experimental_enable_bitpacked_activations = false);

} // namespace tensorflow

33 changes: 25 additions & 8 deletions larq_compute_engine/mlir/transforms/optimize.cc
@@ -2,6 +2,7 @@

#include "larq_compute_engine/core/bitpacking/bitpack.h"
#include "larq_compute_engine/mlir/ir/lce_ops.h"
#include "larq_compute_engine/mlir/transforms/passes.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
@@ -19,15 +20,20 @@ namespace {
// Optimize LCE operations in functions.
struct OptimizeLCE : public PassWrapper<OptimizeLCE, FunctionPass> {
public:
// The default value must be true so that we can run with the optimisation in
// the file-check tests.
explicit OptimizeLCE() : experimental_enable_bitpacked_activations_(true) {}
explicit OptimizeLCE(bool experimental_enable_bitpacked_activations)
: experimental_enable_bitpacked_activations_(
// The default values must be ARM and true so that we can run with all
// patterns in the file-check tests.
explicit OptimizeLCE()
: target_(LCETarget::ARM),
experimental_enable_bitpacked_activations_(true) {}
explicit OptimizeLCE(LCETarget target,
bool experimental_enable_bitpacked_activations)
: target_(target),
experimental_enable_bitpacked_activations_(
experimental_enable_bitpacked_activations) {}
void runOnFunction() override;

private:
LCETarget target_;
bool experimental_enable_bitpacked_activations_;
};

@@ -40,7 +46,13 @@ bool IsConstantValue(Attribute values, float expected_value) {
return true;
}

namespace target_arm {
#include "larq_compute_engine/mlir/transforms/target_arm/generated_optimize.inc"
}

namespace target_other {
#include "larq_compute_engine/mlir/transforms/generated_optimize.inc"
}

/**
* =================================================
@@ -301,7 +313,11 @@ void OptimizeLCE::runOnFunction() {
auto* ctx = &getContext();
auto func = getFunction();

TFL::populateWithGenerated(ctx, patterns);
if (target_ == LCETarget::ARM) {
target_arm::populateWithGenerated(ctx, patterns);
} else {
target_other::populateWithGenerated(ctx, patterns);
}
if (experimental_enable_bitpacked_activations_) {
patterns.insert<SetBitpackedActivations>(ctx);
}
@@ -312,9 +328,10 @@

// Creates an instance of the TensorFlow dialect OptimizeLCE pass.
std::unique_ptr<OperationPass<FuncOp>> CreateOptimizeLCEPass(
bool experimental_enable_bitpacked_activations) {
const LCETarget target,
const bool experimental_enable_bitpacked_activations) {
return std::make_unique<OptimizeLCE>(
experimental_enable_bitpacked_activations);
target, experimental_enable_bitpacked_activations);
}

static PassRegistration<OptimizeLCE> pass(
22 changes: 3 additions & 19 deletions larq_compute_engine/mlir/transforms/optimize_patterns.td
@@ -9,6 +9,8 @@ def F32ElementsAttr : ElementsAttrBase<
// Checks if the value has only one user.
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;

class ConstantValue<string val> : AttrConstraint<CPred<"IsConstantValue($_self, " # val # ")">>;

// TODO: Check shapes before fusing
multiclass FuseAddOrSubWithBConv2D<dag binaryOp> {
def : Pat<(binaryOp
@@ -83,13 +85,11 @@ multiclass FuseMulOrDivWithBConv2D<dag binaryOp> {
$stride_width),
[(HasOneUse $conv_output)], (addBenefit 100)>;
}

foreach binaryOp = [TFL_DivOp, TFL_MulOp] in
defm : FuseMulOrDivWithBConv2D<binaryOp>;

class ConstantValue<string val> : AttrConstraint<CPred<"IsConstantValue($_self, " # val # ")">>;

// Fuse activation function into BConv2D
// Fuse an activation function into the BConv2D.
multiclass FuseActFnIntoConvOpPat<dag ActFnOp, dag ActFnAttr> {
def : Pat<(ActFnOp
(LQ_Bconv2dOp:$conv_output
@@ -152,23 +152,7 @@ multiclass FuseActFnIntoConvOpPat<dag ActFnOp, dag ActFnAttr> {
$stride_width),
[(HasOneUse $conv_output)], (addBenefit 100)>;
}

foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
[TFL_Relu1Op, TFL_AF_Relu1],
[TFL_Relu6Op, TFL_AF_Relu6]] in
defm : FuseActFnIntoConvOpPat<actFnPair[0], actFnPair[1]>;

def : Pat<(LQ_QuantizeOp (TFL_MaxPool2DOp: $pool_output $input,
$padding,
$stride_w,
$stride_h,
$filter_width,
$filter_height,
$fused_activation_function)),
(LQ_BMaxPool2dOp (LQ_QuantizeOp $input),
$padding,
$stride_w,
$stride_h,
$filter_width,
$filter_height),
[(HasOneUse $pool_output)]>;
6 changes: 4 additions & 2 deletions larq_compute_engine/mlir/transforms/passes.h
@@ -3,18 +3,20 @@

#include "mlir/Pass/Pass.h"

enum LCETarget { ARM = 0, XCORE = 1 };

namespace mlir {
namespace TFL {

// Creates an instance of the TensorFlow dialect OpRemoval pass.
std::unique_ptr<OperationPass<FuncOp>> CreateOpRemovalPass();

// Creates an instance of the TensorFlow dialect PrepareLCE pass.
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareLCEPass();
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareLCEPass(LCETarget target);

// Creates an instance of the TensorFlow dialect OptimizeLCE pass.
std::unique_ptr<OperationPass<FuncOp>> CreateOptimizeLCEPass(
bool experimental_enable_bitpacked_activations);
LCETarget target, bool experimental_enable_bitpacked_activations);

// Creates an instance of the TensorFlow dialect BitpackWeightsLCE pass.
std::unique_ptr<OperationPass<FuncOp>> CreateBitpackWeightsLCEPass();
(4 more changed files not shown.)
