-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay][Frontend][TFLite] Add parser support for arg_min_max #4704
Conversation
* this implementation supports only the case when the axis is a scalar * tflite 1.13 removes all dims of size 1, Relay doesn't do this * WARNING: every newer version of tflite > 1.13 needs keepdims=TRUE
can you please review this PR @anijain2305 @FrozenGene |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left a few comments. Sorry for such a late review.
input_tensor = input_tensors[0] | ||
in_expr = self.get_expr(input_tensor.tensor_idx) | ||
axis_tensor = input_tensors[1] | ||
# we support the case when the axis is a scalar not a tensor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A check will be helpful here along with the comment
output_dtype = arg_min_max_options.OutputType() | ||
|
||
# set keepdims to True since tflite 1.13 removes all dims of size 1 | ||
# WARNING: all other versions of tflite > 1.13 need keepdims=False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add a TFLite version check?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The issue here is that we can't use the same TFLite version check as we do in the tests since I do not think there is a way to get such information from a model that you pass to the parser.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After doing some further investigation, it turned out that arg_min
and arg_max
are fundamentally broken in TFL 1.13 since the behaviour described in the in flatbuff file is different from what is actually calculated if we do the inference directly (as in all our parser tests).
You can see the changes between TFL 1.13 and TFL 1.14 for these ops here: tensorflow/tensorflow@1d0552e
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ping @inadob @anijain2305 please followup |
ping @inadob @anijain2305 |
in the corresponding Relay operation. If there is a way to check the tflite version that was used to
convert the model passed to the parser, I will be happy to implement that.