-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for tflite arg_min and arg_max #5992
Conversation
* this implementation supports only the case when the axis is a scalar * tflite 1.13 removes all dims of size 1, Relay doesn't do this * WARNING: every newer version of tflite > 1.13 needs keepdims=TRUE
keepdims set to False and added some checks Note the unit tests emmitted following warning: /workspace/src/te/schedule/bound.cc:119: not in feed graph consumer = compute(T_multiply_red_temp, 0x53f5050)
@d-smirnov can you get some other ppl to review as well? |
python/tvm/relay/frontend/tflite.py
Outdated
def _convert_arg_min_max(self, relay_op, op): | ||
"""Generic method to convert TFLite arg_min_max""" | ||
try: | ||
from tflite.Operator import Operator |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove this. not required
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is something like this?
def convert_arg_max(self, op):
"""Convert TFLite ARG_MAX"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized ARG_MAX operator is not supported yet.')
return self._convert_unary_elemwise(_op.arg_max, op)
Removed quantized argmin due to inablility to provide proper test case
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Thanks @d-smirnov @siju-samuel @MarisaKirisame This is merged! |
* [Relay][Frontend][TFLite] Add parser support for arg_min_max * this implementation supports only the case when the axis is a scalar * tflite 1.13 removes all dims of size 1, Relay doesn't do this * WARNING: every newer version of tflite > 1.13 needs keepdims=TRUE * Migrated to tflite 2.1.0 keepdims set to False and added some checks Note the unit tests emmitted following warning: /workspace/src/te/schedule/bound.cc:119: not in feed graph consumer = compute(T_multiply_red_temp, 0x53f5050) * linter * Removed quantized argmin Removed quantized argmin due to inablility to provide proper test case * added negative ranges * re-trigger CI Co-authored-by: Ina_Dobreva <Ina.Dobreva@arm.com>
* [Relay][Frontend][TFLite] Add parser support for arg_min_max * this implementation supports only the case when the axis is a scalar * tflite 1.13 removes all dims of size 1, Relay doesn't do this * WARNING: every newer version of tflite > 1.13 needs keepdims=TRUE * Migrated to tflite 2.1.0 keepdims set to False and added some checks Note the unit tests emmitted following warning: /workspace/src/te/schedule/bound.cc:119: not in feed graph consumer = compute(T_multiply_red_temp, 0x53f5050) * linter * Removed quantized argmin Removed quantized argmin due to inablility to provide proper test case * added negative ranges * re-trigger CI Co-authored-by: Ina_Dobreva <Ina.Dobreva@arm.com>
"Add support for tflite arg_min and arg_max"
This is an updated version of abandoned PR [Relay][Frontend][TFLite] Add parser support for arg_min_max #4704 rebased on top of current master with small updates related to tflite 2.1.0.
Please note unit tests for arg_max and arg_min emit following warning:
/workspace/src/te/schedule/bound.cc:119: not in feed graph consumer = compute(T_multiply_red_temp, 0x53f5050)