Add SetBatchSize pass to MLIR converter (#690)
* Add SetBatchSize pass to MLIR converter

Co-authored-by: Lukas Geiger <lukas.geiger94@gmail.com>
Showing 5 changed files with 142 additions and 0 deletions.
@@ -0,0 +1,51 @@
// RUN: lce-tf-opt %s -mlir-setbatchsize -verify-diagnostics | FileCheck %s

// CHECK-LABEL: @simple
func @simple(%arg0: tensor<?x6xf32>, %arg1: tensor<2x6xf32>) -> (tensor<?x6xf32>) {
  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<?x6xf32>, tensor<2x6xf32>) -> tensor<?x6xf32>
  return %0 : tensor<?x6xf32>
  // CHECK: %arg0: tensor<1x6xf32>
  // Check that the 'batch' size of the second argument is *not* changed to 1
  // CHECK: %arg1: tensor<2x6xf32>
}

// This is an IR dump from the following simple 2-input model.
// It is here to ensure that the pass does not destroy the extra function attributes that are present.
//   img1 = tf.keras.layers.Input(shape=(4,))
//   img2 = tf.keras.layers.Input(shape=(6,))
//   x = tf.keras.layers.Dense(6)(img1) + img2
//   return tf.keras.Model([img1, img2], x)
// Both inputs have a dynamic batch size.

// CHECK-LABEL: @dual_input_model
func @dual_input_model(%arg0: tensor<?x6xf32> {tf_saved_model.index_path = ["input_2"]}, %arg1: tensor<?x4xf32> {tf_saved_model.index_path = ["input_1"]}, %arg2: tensor<!tf.resource<tensor<6xf32>>> {tf_saved_model.bound_input = @"dense/bias"}, %arg3: tensor<!tf.resource<tensor<4x6xf32>>> {tf_saved_model.bound_input = @"dense/kernel"}) -> (tensor<?x6xf32> {tf_saved_model.index_path = ["tf.__operators__.add"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_2:0,serving_default_input_1:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
  %0 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<!tf.resource<tensor<6xf32>>>) -> tensor<6xf32>
  %1 = "tf.ReadVariableOp"(%arg3) {device = ""} : (tensor<!tf.resource<tensor<4x6xf32>>>) -> tensor<4x6xf32>
  %2 = "tf.MatMul"(%arg1, %1) {device = "", transpose_a = false, transpose_b = false} : (tensor<?x4xf32>, tensor<4x6xf32>) -> tensor<?x6xf32>
  %3 = "tf.BiasAdd"(%2, %0) {data_format = "NHWC", device = ""} : (tensor<?x6xf32>, tensor<6xf32>) -> tensor<?x6xf32>
  %4 = "tf.AddV2"(%3, %arg0) {device = ""} : (tensor<?x6xf32>, tensor<?x6xf32>) -> tensor<?x6xf32>
  %5 = "tf.Identity"(%4) {device = ""} : (tensor<?x6xf32>) -> tensor<?x6xf32>
  %6 = "tf.Identity"(%5) {device = ""} : (tensor<?x6xf32>) -> tensor<?x6xf32>
  return %6 : tensor<?x6xf32>
  // CHECK: %arg0: tensor<1x6xf32> {tf_saved_model.index_path = ["input_2"]}
  // CHECK: %arg1: tensor<1x4xf32> {tf_saved_model.index_path = ["input_1"]}
  // The resource objects and attributes should be unchanged
  // CHECK: %arg2: tensor<!tf.resource<tensor<6xf32>>> {tf_saved_model.bound_input = @"dense/bias"}, %arg3: tensor<!tf.resource<tensor<4x6xf32>>> {tf_saved_model.bound_input = @"dense/kernel"}) -> (tensor<?x6xf32> {tf_saved_model.index_path = ["tf.__operators__.add"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_2:0,serving_default_input_1:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
}

// This is the same model, but one of the two inputs has been given a fixed batch size in Python.

// CHECK-LABEL: @dual_input_one_fixed_size
func @dual_input_one_fixed_size(%arg0: tensor<?x6xf32> {tf_saved_model.index_path = ["input_2"]}, %arg1: tensor<1x4xf32> {tf_saved_model.index_path = ["input_1"]}, %arg2: tensor<!tf.resource<tensor<6xf32>>> {tf_saved_model.bound_input = @"dense/bias"}, %arg3: tensor<!tf.resource<tensor<4x6xf32>>> {tf_saved_model.bound_input = @"dense/kernel"}) -> (tensor<?x6xf32> {tf_saved_model.index_path = ["tf.__operators__.add"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_2:0,serving_default_input_1:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
  %0 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<!tf.resource<tensor<6xf32>>>) -> tensor<6xf32>
  %1 = "tf.ReadVariableOp"(%arg3) {device = ""} : (tensor<!tf.resource<tensor<4x6xf32>>>) -> tensor<4x6xf32>
  %2 = "tf.MatMul"(%arg1, %1) {device = "", transpose_a = false, transpose_b = false} : (tensor<1x4xf32>, tensor<4x6xf32>) -> tensor<1x6xf32>
  %3 = "tf.BiasAdd"(%2, %0) {data_format = "NHWC", device = ""} : (tensor<1x6xf32>, tensor<6xf32>) -> tensor<1x6xf32>
  %4 = "tf.AddV2"(%3, %arg0) {device = ""} : (tensor<1x6xf32>, tensor<?x6xf32>) -> tensor<?x6xf32>
  %5 = "tf.Identity"(%4) {device = ""} : (tensor<?x6xf32>) -> tensor<?x6xf32>
  %6 = "tf.Identity"(%5) {device = ""} : (tensor<?x6xf32>) -> tensor<?x6xf32>
  return %6 : tensor<?x6xf32>
  // CHECK: %arg0: tensor<1x6xf32> {tf_saved_model.index_path = ["input_2"]}
  // CHECK: %arg1: tensor<1x4xf32> {tf_saved_model.index_path = ["input_1"]}
  // CHECK: %arg2: tensor<!tf.resource<tensor<6xf32>>> {tf_saved_model.bound_input = @"dense/bias"}, %arg3: tensor<!tf.resource<tensor<4x6xf32>>> {tf_saved_model.bound_input = @"dense/kernel"}) -> (tensor<?x6xf32> {tf_saved_model.index_path = ["tf.__operators__.add"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_2:0,serving_default_input_1:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
}
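
For reference, the dual-input model described in the test comments above can be reconstructed with the following Python sketch. This script is illustrative and not part of the commit; the save path is hypothetical.

import tensorflow as tf

# Rebuild the two-input Keras model from the test comments.
img1 = tf.keras.layers.Input(shape=(4,))
img2 = tf.keras.layers.Input(shape=(6,))
x = tf.keras.layers.Dense(6)(img1) + img2
model = tf.keras.Model([img1, img2], x)

# Exported as a SavedModel, the entry function's inputs get a dynamic
# batch dimension (tensor<?x4xf32>, tensor<?x6xf32>), matching the IR above.
model.save("/tmp/dual_input_model")

# For the @dual_input_one_fixed_size test, one input would instead be
# created with a fixed batch size, e.g.:
#   img1 = tf.keras.layers.Input(shape=(4,), batch_size=1)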
@@ -0,0 +1,68 @@
#include "mlir/Pass/Pass.h"

// This pass will set the batch dimension of all inputs of the outermost
// function to 1, leaving it to shape inference to do the rest.

namespace mlir {

namespace {

mlir::Type SetBatchSize(mlir::Type type) {
  auto tensor_type = type.dyn_cast<mlir::TensorType>();
  if (tensor_type && tensor_type.hasRank()) {
    auto shape = tensor_type.getShape();
    if (shape.size() > 0 && shape[0] == ShapedType::kDynamicSize) {
      // Create a new shape but set the first dimension to 1
      llvm::SmallVector<int64_t, 4> shape_new(shape.begin(), shape.end());
      shape_new[0] = 1;

      return tensor_type.clone(shape_new);
    }
  }
  return nullptr;
}

struct SetBatchSizePass : public PassWrapper<SetBatchSizePass, FunctionPass> {
  void runOnFunction() override {
    FuncOp func = getFunction();

    // We have to edit both the function signature (mlir::Type) *and* the
    // function arguments (mlir::Value).

    // mlir::FunctionType is a TableGen-autogenerated MLIR type.
    mlir::FunctionType signature = func.getType();

    // Create a mutable copy of the input types, since getInputs returns an
    // immutable llvm::ArrayRef<mlir::Type>.
    std::vector<mlir::Type> signature_inputs(signature.getInputs());

    for (auto& input_type : signature_inputs) {
      auto new_type = SetBatchSize(input_type);
      if (new_type) input_type = new_type;
    }

    auto signature_new = mlir::FunctionType::get(
        signature.getContext(), signature_inputs, signature.getResults());
    // Set the new signature
    func.typeAttr(mlir::TypeAttr::get(signature_new));

    // Now apply the same change to the mlir::Value objects
    for (mlir::BlockArgument arg : func.getArguments()) {
      // mlir::BlockArgument is a subclass of mlir::Value
      auto new_type = SetBatchSize(arg.getType());
      if (new_type) arg.setType(new_type);
    }
  }
};

}  // namespace

// Creates an instance of the SetBatchSize pass.
std::unique_ptr<OperationPass<FuncOp>> CreateSetBatchSizePass() {
  return std::make_unique<SetBatchSizePass>();
}

static PassRegistration<SetBatchSizePass> pass("mlir-setbatchsize",
                                               "Set batch size to 1");

}  // namespace mlir
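
With the pass registered under the "mlir-setbatchsize" flag, it runs as part of the converter's MLIR pipeline. Below is a minimal end-to-end sketch from Python, assuming the pass is reached through Larq Compute Engine's public convert_keras_model entry point; the commit title indicates the pass was added to that converter, but the exact pipeline wiring is internal and not shown in this diff.

import larq_compute_engine as lce

# Converting the model from the earlier sketch runs the MLIR converter;
# with this commit, dynamic batch dimensions on the entry function are
# fixed to 1 before shape inference propagates static shapes. (That the
# pass participates in this pipeline is an assumption from the commit
# title, not something this diff shows.)
flatbuffer_bytes = lce.convert_keras_model(model)
with open("model.tflite", "wb") as f:
    f.write(flatbuffer_bytes)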