Add SetBatchSize pass to MLIR converter (#690)
* Add SetBatchSize pass to MLIR converter

Co-authored-by: Lukas Geiger <lukas.geiger94@gmail.com>
Tombana and lgeiger authored Dec 3, 2021
1 parent 06d6ece commit 8e29cde
Showing 5 changed files with 142 additions and 0 deletions.
15 changes: 15 additions & 0 deletions larq_compute_engine/mlir/BUILD
@@ -279,6 +279,20 @@ cc_library(
alwayslink = 1,
)

cc_library(
name = "set_batch_size",
srcs = [
"transforms/set_batch_size.cc",
],
hdrs = [
"transforms/passes.h",
],
deps = [
"@llvm-project//mlir:StandardOps",
],
alwayslink = 1,
)

cc_library(
name = "lce_tfl_passes",
srcs = ["tf_tfl_passes.cc"],
@@ -292,6 +306,7 @@ cc_library(
":larq_compute_engine_optimize",
":larq_compute_engine_prepare",
":larq_compute_engine_quantize",
":set_batch_size",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
51 changes: 51 additions & 0 deletions larq_compute_engine/mlir/tests/set_batch_size.mlir
@@ -0,0 +1,51 @@
// RUN: lce-tf-opt %s -mlir-setbatchsize -verify-diagnostics | FileCheck %s

// CHECK-LABEL: @simple
func @simple(%arg0: tensor<?x6xf32>, %arg1: tensor<2x6xf32>) -> (tensor<?x6xf32>) {
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<?x6xf32>, tensor<2x6xf32>) -> tensor<?x6xf32>
return %0 : tensor<?x6xf32>
// CHECK: %arg0: tensor<1x6xf32>
// Check that the 'batch' size of the second argument is *not* changed to 1
// CHECK: %arg1: tensor<2x6xf32>
}

// This is an IR dump of the following simple 2-input model. It is included to ensure
// that the pass does not destroy the extra function attributes that are present.
// (A standalone Keras sketch of this model is given after this test file.)
// img1 = tf.keras.layers.Input(shape=(4,))
// img2 = tf.keras.layers.Input(shape=(6,))
// x = tf.keras.layers.Dense(6)(img1) + img2
// return tf.keras.Model([img1, img2], x)
// Both inputs have a dynamic batch size

// CHECK-LABEL: @dual_input_model
func @dual_input_model(%arg0: tensor<?x6xf32> {tf_saved_model.index_path = ["input_2"]}, %arg1: tensor<?x4xf32> {tf_saved_model.index_path = ["input_1"]}, %arg2: tensor<!tf.resource<tensor<6xf32>>> {tf_saved_model.bound_input = @"dense/bias"}, %arg3: tensor<!tf.resource<tensor<4x6xf32>>> {tf_saved_model.bound_input = @"dense/kernel"}) -> (tensor<?x6xf32> {tf_saved_model.index_path = ["tf.__operators__.add"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_2:0,serving_default_input_1:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
%0 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<!tf.resource<tensor<6xf32>>>) -> tensor<6xf32>
%1 = "tf.ReadVariableOp"(%arg3) {device = ""} : (tensor<!tf.resource<tensor<4x6xf32>>>) -> tensor<4x6xf32>
%2 = "tf.MatMul"(%arg1, %1) {device = "", transpose_a = false, transpose_b = false} : (tensor<?x4xf32>, tensor<4x6xf32>) -> tensor<?x6xf32>
%3 = "tf.BiasAdd"(%2, %0) {data_format = "NHWC", device = ""} : (tensor<?x6xf32>, tensor<6xf32>) -> tensor<?x6xf32>
%4 = "tf.AddV2"(%3, %arg0) {device = ""} : (tensor<?x6xf32>, tensor<?x6xf32>) -> tensor<?x6xf32>
%5 = "tf.Identity"(%4) {device = ""} : (tensor<?x6xf32>) -> tensor<?x6xf32>
%6 = "tf.Identity"(%5) {device = ""} : (tensor<?x6xf32>) -> tensor<?x6xf32>
return %6 : tensor<?x6xf32>
// CHECK: %arg0: tensor<1x6xf32> {tf_saved_model.index_path = ["input_2"]}
// CHECK: %arg1: tensor<1x4xf32> {tf_saved_model.index_path = ["input_1"]}
// The resource objects and attributes should be unchanged
// CHECK: %arg2: tensor<!tf.resource<tensor<6xf32>>> {tf_saved_model.bound_input = @"dense/bias"}, %arg3: tensor<!tf.resource<tensor<4x6xf32>>> {tf_saved_model.bound_input = @"dense/kernel"}) -> (tensor<?x6xf32> {tf_saved_model.index_path = ["tf.__operators__.add"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_2:0,serving_default_input_1:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
}

// This is the same model, but one of the two inputs has been given a fixed batch size in Python

// CHECK-LABEL: @dual_input_one_fixed_size
func @dual_input_one_fixed_size(%arg0: tensor<?x6xf32> {tf_saved_model.index_path = ["input_2"]}, %arg1: tensor<1x4xf32> {tf_saved_model.index_path = ["input_1"]}, %arg2: tensor<!tf.resource<tensor<6xf32>>> {tf_saved_model.bound_input = @"dense/bias"}, %arg3: tensor<!tf.resource<tensor<4x6xf32>>> {tf_saved_model.bound_input = @"dense/kernel"}) -> (tensor<?x6xf32> {tf_saved_model.index_path = ["tf.__operators__.add"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_2:0,serving_default_input_1:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
%0 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<!tf.resource<tensor<6xf32>>>) -> tensor<6xf32>
%1 = "tf.ReadVariableOp"(%arg3) {device = ""} : (tensor<!tf.resource<tensor<4x6xf32>>>) -> tensor<4x6xf32>
%2 = "tf.MatMul"(%arg1, %1) {device = "", transpose_a = false, transpose_b = false} : (tensor<1x4xf32>, tensor<4x6xf32>) -> tensor<1x6xf32>
%3 = "tf.BiasAdd"(%2, %0) {data_format = "NHWC", device = ""} : (tensor<1x6xf32>, tensor<6xf32>) -> tensor<1x6xf32>
%4 = "tf.AddV2"(%3, %arg0) {device = ""} : (tensor<1x6xf32>, tensor<?x6xf32>) -> tensor<?x6xf32>
%5 = "tf.Identity"(%4) {device = ""} : (tensor<?x6xf32>) -> tensor<?x6xf32>
%6 = "tf.Identity"(%5) {device = ""} : (tensor<?x6xf32>) -> tensor<?x6xf32>
return %6 : tensor<?x6xf32>
// CHECK: %arg0: tensor<1x6xf32> {tf_saved_model.index_path = ["input_2"]}
// CHECK: %arg1: tensor<1x4xf32> {tf_saved_model.index_path = ["input_1"]}
// CHECK: %arg2: tensor<!tf.resource<tensor<6xf32>>> {tf_saved_model.bound_input = @"dense/bias"}, %arg3: tensor<!tf.resource<tensor<4x6xf32>>> {tf_saved_model.bound_input = @"dense/kernel"}) -> (tensor<?x6xf32> {tf_saved_model.index_path = ["tf.__operators__.add"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_2:0,serving_default_input_1:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
}
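For reference, the Keras snippets embedded in the test comments above correspond roughly to the following standalone model-building code. This is a minimal sketch under standard TensorFlow 2.x; the batch_size=1 argument in the second model is an assumption about how the fixed-batch-size variant was generated.

import tensorflow as tf

# Dual-input model where both inputs have a dynamic batch size
# (corresponds to @dual_input_model).
img1 = tf.keras.layers.Input(shape=(4,))
img2 = tf.keras.layers.Input(shape=(6,))
x = tf.keras.layers.Dense(6)(img1) + img2
dynamic_model = tf.keras.Model([img1, img2], x)

# Same model, but one input is given a fixed batch size in Python
# (corresponds to @dual_input_one_fixed_size).
img1_fixed = tf.keras.layers.Input(shape=(4,), batch_size=1)
img2_dyn = tf.keras.layers.Input(shape=(6,))
y = tf.keras.layers.Dense(6)(img1_fixed) + img2_dyn
fixed_batch_model = tf.keras.Model([img1_fixed, img2_dyn], y)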
4 changes: 4 additions & 0 deletions larq_compute_engine/mlir/tf_tfl_passes.cc
@@ -92,6 +92,10 @@ void AddTFToLCETFLConversionPasses(
pass_manager->addPass(
mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());

// Set the batch size of the function input to 1 and let shape inference
// propagate this in the next pass.
pass_manager->addPass(mlir::CreateSetBatchSizePass());

// Add a shape inference pass to optimize away the unnecessary casts.
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());

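With this pipeline step in place, a Keras model exported with a dynamic batch dimension should end up as a TFLite model whose input batch dimension is 1. A minimal sketch of checking this, assuming the standard larq_compute_engine.convert_keras_model entry point:

import larq_compute_engine as lce
import tensorflow as tf

inp = tf.keras.layers.Input(shape=(6,))  # batch dimension is dynamic (None)
model = tf.keras.Model(inp, tf.keras.layers.Dense(6)(inp))

flatbuffer = lce.convert_keras_model(model)  # runs the conversion pipeline above

interpreter = tf.lite.Interpreter(model_content=flatbuffer)
print(interpreter.get_input_details()[0]["shape"])  # expected: [1 6]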
4 changes: 4 additions & 0 deletions larq_compute_engine/mlir/transforms/passes.h
@@ -28,6 +28,10 @@ std::unique_ptr<OperationPass<FuncOp>> CreateLCEQuantizePass();
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeLCEPass();

} // namespace TFL

// Creates an instance of the TensorFlow dialect SetBatchSize pass
std::unique_ptr<OperationPass<FuncOp>> CreateSetBatchSizePass();

} // namespace mlir

#endif // LARQ_COMPUTE_ENGINE_MLIR_PASSES_H_
68 changes: 68 additions & 0 deletions larq_compute_engine/mlir/transforms/set_batch_size.cc
@@ -0,0 +1,68 @@
#include "mlir/Pass/Pass.h"

// This pass will set the batch dimension of all inputs of the outermost
// function to 1, leaving it to shape inference to do the rest.

namespace mlir {

namespace {

mlir::Type SetBatchSize(mlir::Type type) {
auto tensor_type = type.dyn_cast<mlir::TensorType>();
if (tensor_type && tensor_type.hasRank()) {
auto shape = tensor_type.getShape();
if (shape.size() > 0 && shape[0] == ShapedType::kDynamicSize) {
// Create a new shape but set the first dimension to 1
llvm::SmallVector<int64_t, 4> shape_new(shape.begin(), shape.end());
shape_new[0] = 1;

return tensor_type.clone(shape_new);
}
}
return nullptr;
}

struct SetBatchSizePass : public PassWrapper<SetBatchSizePass, FunctionPass> {
void runOnFunction() override {
FuncOp func = getFunction();

// We have to edit both the function signature (mlir::Type) *and* the
// function arguments (mlir::Value)

// mlir::FunctionType is a TableGen-autogenerated MLIR type
mlir::FunctionType signature = func.getType();

// Create a mutable copy of the input types, since getInputs returns an
// immutable llvm::ArrayRef<mlir::Type>
std::vector<mlir::Type> signature_inputs(signature.getInputs());

for (auto& input_type : signature_inputs) {
auto new_type = SetBatchSize(input_type);
if (new_type) input_type = new_type;
}

auto signature_new = mlir::FunctionType::get(
signature.getContext(), signature_inputs, signature.getResults());
// Set the new signature
func.typeAttr(mlir::TypeAttr::get(signature_new));

// Now apply the same change to the mlir::Value objects
for (mlir::BlockArgument arg : func.getArguments()) {
// mlir::BlockArgument is a subclass of mlir::Value
auto new_type = SetBatchSize(arg.getType());
if (new_type) arg.setType(new_type);
}
}
};

} // namespace

// Creates an instance of the SetBatchSize pass.
std::unique_ptr<OperationPass<FuncOp>> CreateSetBatchSizePass() {
return std::make_unique<SetBatchSizePass>();
}

static PassRegistration<SetBatchSizePass> pass("mlir-setbatchsize",
"Set batch size to 1");

} // namespace mlir
