Add SetBatchSize pass to MLIR converter #690
Conversation
Very cool, this is really useful!
I just have a few minor suggestions to improve readability; other than that, this looks great!
// CHECK: %arg0: tensor<1x6xf32>
// CHECK: %arg1: tensor<1x4xf32>
Should we also run shape inference (`-tf-shape-inference`) in this test and verify that the shapes are propagated correctly? Alternatively, we could keep it like this and simplify the test cases further to only test the actual change in the input tensor size.
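For example, the RUN line could be extended along these lines (a rough sketch; the tool name and the flag for the new pass are placeholders, only `-tf-shape-inference` is an existing TF pass flag):

```mlir
// RUN: tf-opt %s -set-batch-size -tf-shape-inference | FileCheck %s
// (-set-batch-size is a hypothetical flag name for the new pass)

// CHECK: %arg0: tensor<1x6xf32>
// CHECK: %arg1: tensor<1x4xf32>
// Plus a check on a downstream op to verify the shape was propagated, e.g.:
// CHECK: -> tensor<1x2x3xf32>
```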
%0 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<!tf.resource<tensor<6xf32>>>) -> tensor<6xf32>
%1 = "tf.ReadVariableOp"(%arg3) {device = ""} : (tensor<!tf.resource<tensor<4x6xf32>>>) -> tensor<4x6xf32>
Let's remove the `tf.ReadVariableOp` and `tf.Identity` ops to improve readability of the test case.
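Something along these lines might already be enough for the test case (a simplified sketch; the function signature and checks are illustrative only):

```mlir
func @main(%arg0: tensor<?x6xf32>, %arg1: tensor<?x4xf32>) -> (tensor<?x6xf32>, tensor<?x4xf32>) {
  return %arg0, %arg1 : tensor<?x6xf32>, tensor<?x4xf32>
}

// CHECK: %arg0: tensor<1x6xf32>
// CHECK: %arg1: tensor<1x4xf32>
```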
Co-authored-by: Lukas Geiger <lukas.geiger94@gmail.com>
What do these changes do?
This adds a pass to the converter that sets the dynamic batch size to 1 on the input layer.
When `batch_size` is not explicitly set to 1, it will remain a `?` (wildcard) throughout all MLIR passes. Only at the final MLIR-to-flatbuffer conversion stage are all the `?` simply converted to `1`: https://github.com/tensorflow/tensorflow/blob/v2.7.0/tensorflow/compiler/mlir/lite/flatbuffer_export.cc#L844
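To illustrate (a simplified sketch, not the actual converter output; the function name, ops, and shapes are made up to match the test above), a SavedModel imported without a fixed batch size looks roughly like this in the tf dialect:

```mlir
// Before the pass: the batch dimension is a dynamic wildcard.
func @serving_default(%arg0: tensor<?x6xf32>) -> tensor<?x2x3xf32> {
  %shape = "tf.Const"() {value = dense<[-1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
  %0 = "tf.Reshape"(%arg0, %shape) : (tensor<?x6xf32>, tensor<3xi32>) -> tensor<?x2x3xf32>
  return %0 : tensor<?x2x3xf32>
}

// After the SetBatchSize pass (followed by shape inference), the wildcard
// on the input is replaced and propagated downstream:
// func @serving_default(%arg0: tensor<1x6xf32>) -> tensor<1x2x3xf32>
```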
This causes broken Reshape and Concatenation ops: the MLIR passes will correctly infer the Reshape output shape to something like `?x?x2`, and then the flatbuffer exporter simply overrides both `?` with `1`. Other artifacts are `StridedSlice` ops in the graph (which should've been folded out).
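As a concrete illustration of the Reshape breakage (a hypothetical sketch, not taken from the actual converter output):

```mlir
// The reshape target shape is computed at runtime, so shape inference can only
// narrow the result type to tensor<?x?x2xf32>:
func @reshape_example(%arg0: tensor<?x6xf32>, %shape: tensor<3xi32>) -> tensor<?x?x2xf32> {
  %0 = "tf.Reshape"(%arg0, %shape) : (tensor<?x6xf32>, tensor<3xi32>) -> tensor<?x?x2xf32>
  return %0 : tensor<?x?x2xf32>
}

// At flatbuffer export every remaining ? is written out as 1, so the result is
// exported as 1x1x2 even though the correct shape for batch size 1 would be
// 1x3x2. Fixing the input batch dimension to 1 up front lets shape inference
// resolve the non-batch dimension instead.
```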
The graphdef converter code has a `batch_size=1` fix in Python, but the saved model converter did not have something like this yet.

How Has This Been Tested?
MLIR FileCheck tests have been added.