
Add SetBatchSize pass to MLIR converter #690

Merged: 3 commits merged into main on Dec 3, 2021
Conversation

@Tombana (Collaborator) commented on Dec 2, 2021:

What do these changes do?

This adds a pass to the converter that sets the dynamic batch size to 1 on the input layer.

When batch_size is not explicitly set to 1, it remains a ? (wildcard) throughout all MLIR passes. Only at the final MLIR-to-flatbuffer conversion stage are all the ? dimensions converted to 1:
https://github.com/tensorflow/tensorflow/blob/v2.7.0/tensorflow/compiler/mlir/lite/flatbuffer_export.cc#L844

This causes broken Reshape and Concatenation ops: the MLIR passes correctly infer the Reshape output shape to something like ?x?x2, and then the flatbuffer exporter simply overrides both wildcards with 1 (see the illustration after the list below). Other artifacts:

  • left-over StridedSlice ops in the graph (which should have been folded away)
  • invalid Reshape shapes (screenshot in the original PR)
  • invalid Concatenation shapes (screenshot in the original PR)
  • I forgot the exact details, but something like depthwise layers with dilation also produced shape mismatches.
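
To make the Reshape failure concrete, here is a hypothetical fragment (illustrative only, not taken from this PR). Shape inference correctly produces a partly dynamic result type, and the exporter then clamps every wildcard to 1:

// %shape : tensor<3xi32> is computed elsewhere in the graph
%0 = "tf.Reshape"(%arg0, %shape) : (tensor<?x8x2xf32>, tensor<3xi32>) -> tensor<?x?x2xf32>

In the exported flatbuffer the result is declared as 1x1x2 (2 elements) while the input still carries 1x8x2 = 16 elements, so the Reshape no longer type-checks.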

The graphdef converter code already has a batch_size=1 fix in Python, but the saved model converter did not have an equivalent yet.
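
Conceptually, the new pass rewrites the entry function's signature as follows (a minimal sketch in TF MLIR syntax; function bodies elided, argument shapes illustrative):

// Before the pass: dynamic batch dimension on the inputs.
func @main(%arg0: tensor<?x6xf32>, %arg1: tensor<?x4xf32>) { ... }
// After SetBatchSize: the leading wildcard is pinned to 1, so later
// passes and the exporter see static shapes from the start.
func @main(%arg0: tensor<1x6xf32>, %arg1: tensor<1x4xf32>) { ... }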

How Has This Been Tested?

MLIR FileCheck tests have been added.
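
For reference, a FileCheck test of this kind is driven by a RUN line along these lines (the opt tool and pass flag names here are placeholders, not necessarily the ones used in this repository):

// RUN: tf-opt %s -set-batch-size | FileCheck %s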

@Tombana requested a review from @lgeiger on Dec 2, 2021, 12:52.
@lgeiger added the internal-improvement (Internal Improvements and Maintenance) label on Dec 3, 2021.
@lgeiger (Member) left a comment:


Very cool, this is really useful!

I just have a few minor suggestions to improve readability; other than that, this looks great!

Comment on lines 20 to 21:
// CHECK: %arg0: tensor<1x6xf32>
// CHECK: %arg1: tensor<1x4xf32>

Should we also run shape inference (-tf-shape-inference) in this test and verify that the shapes are propagated correctly? Alternatively, we can keep it as is and simplify the test cases further, to only test the actual change in the input tensor size.
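
With shape inference appended to the pipeline, the checks could also assert on shapes inside the function body, roughly like this (illustrative sketch; the flag for the new pass is a placeholder, while -tf-shape-inference is the upstream flag mentioned above):

// RUN: tf-opt %s -set-batch-size -tf-shape-inference | FileCheck %s
// CHECK: %arg0: tensor<1x6xf32>
// CHECK: "tf.MatMul"{{.*}} -> tensor<1x6xf32>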

Comment on lines 28 to 29:
%0 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<!tf.resource<tensor<6xf32>>>) -> tensor<6xf32>
%1 = "tf.ReadVariableOp"(%arg3) {device = ""} : (tensor<!tf.resource<tensor<4x6xf32>>>) -> tensor<4x6xf32>

Let's remove tf.ReadVariableOp and tf.Identity ops to improve readability of the test case.
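
A trimmed test case along those lines (illustrative sketch, not the final committed test) could look like:

func @main(%arg0: tensor<?x6xf32>) -> tensor<?x6xf32> {
  %0 = "tf.Relu"(%arg0) : (tensor<?x6xf32>) -> tensor<?x6xf32>
  return %0 : tensor<?x6xf32>
}
// CHECK: %arg0: tensor<1x6xf32>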

Tombana and others added 2 commits on Dec 3, 2021, 14:29.
@Tombana enabled auto-merge (squash) on Dec 3, 2021, 14:29.
@Tombana merged commit 8e29cde into main on Dec 3, 2021.
@Tombana deleted the converter_fix_batchsize branch on Dec 3, 2021, 18:33.
Labels: internal-improvement (Internal Improvements and Maintenance)
2 participants: @Tombana, @lgeiger