Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Frontend][TFLite] Add parser support for arg_min_max #4704

Closed
wants to merge 3 commits into from

Conversation

inadob
Copy link
Contributor

@inadob inadob commented Jan 14, 2020

  • this implementation supports only the case when the axis is a scalar
  • tflite 1.13 removes all dims of size 1, Relay doesn't do this
  • WARNING: every newer version of tflite > 1.13 needs keepdims=False
    in the corresponding Relay operation. If there is a way to check the tflite version that was used to
    convert the model passed to the parser, I will be happy to implement that.

inadob and others added 2 commits January 14, 2020 15:01
* this implementation supports only the case when the axis is a scalar
* tflite 1.13 removes all dims of size 1, Relay doesn't do this
* WARNING: every newer version of tflite > 1.13 needs keepdims=TRUE
@inadob
Copy link
Contributor Author

inadob commented Jan 14, 2020

can you please review this PR @anijain2305 @FrozenGene

Copy link
Contributor

@anijain2305 anijain2305 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left a few comments. Sorry for such a late review.

input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)
axis_tensor = input_tensors[1]
# we support the case when the axis is a scalar not a tensor
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A check will be helpful here along with the comment

output_dtype = arg_min_max_options.OutputType()

# set keepdims to True since tflite 1.13 removes all dims of size 1
# WARNING: all other versions of tflite > 1.13 need keepdims=False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a TFLite version check?

Copy link
Contributor Author

@inadob inadob Feb 4, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue here is that we can't use the same TFLite version check as we do in the tests since I do not think there is a way to get such information from a model that you pass to the parser.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After doing some further investigation, it turned out that arg_min and arg_max are fundamentally broken in TFL 1.13 since the behaviour described in the in flatbuff file is different from what is actually calculated if we do the inference directly (as in all our parser tests).

You can see the changes between TFL 1.13 and TFL 1.14 for these ops here: tensorflow/tensorflow@1d0552e

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tqchen tqchen added the status: need update need update based on feedbacks label Feb 6, 2020
@tqchen
Copy link
Member

tqchen commented Feb 26, 2020

ping @inadob @anijain2305 please followup

@tqchen
Copy link
Member

tqchen commented Jun 26, 2020

ping @inadob @anijain2305

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
status: need review status: need update need update based on feedbacks
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants