Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend][TFLite] fix detection_postprocess's non_max_suppression_attrs["force_suppress"] #12593

Merged
merged 2 commits into from
Sep 19, 2022

Conversation

czh978
Copy link
Contributor

@czh978 czh978 commented Aug 25, 2022

I found the value should be "True"

Copy link
Contributor

@leandron leandron left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix @czh978.

Do you mind providing some context on why force_supress needs to be set this way? Can you give one example of something that wasn't working before this change, and works after the change?

@czh978
Copy link
Contributor Author

czh978 commented Aug 29, 2022

Thanks for the fix @czh978.

Do you mind providing some context on why force_supress needs to be set this way? Can you give one example of something that wasn't working before this change, and works after the change?

if "use_regular_nms" in custom_options:
if custom_options["use_regular_nms"]:
raise tvm.error.OpAttributeUnImplemented(
"use_regular_nms=True is not yet supported for operator {}.".format(
"TFLite_Detection_PostProcess"
)
)

The value of use_regular_nms is always Flase,becasue use_regular_nms=True is not yet supported in tvm。
https://github.com/tensorflow/tensorflow/blob/47c541330813f575057d7af90ef55985baca87eb/tensorflow/lite/kernels/detection_postprocess.cc#L882-L887
The result will suppress all detections regardless of class_id in tflite,when the value of use_regular_nms is always Flase.
Therefore,the value of force_suppress should be true.
tflite_graph_with_postprocess.zip
There is a model,you can try adjust the code in test_forward.py of tflite,where is in
def test_detection_postprocess():
"""Detection PostProcess"""
tf_model_file = tf_testing.get_workload_official(
"http://download.tensorflow.org/models/object_detection/"
"ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
"ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/tflite_graph.pb",
)
converter = tf.lite.TFLiteConverter.from_frozen_graph(
tf_model_file,
input_arrays=["raw_outputs/box_encodings", "raw_outputs/class_predictions"],
output_arrays=[
"TFLite_Detection_PostProcess",
"TFLite_Detection_PostProcess:1",
"TFLite_Detection_PostProcess:2",
"TFLite_Detection_PostProcess:3",
],
input_shapes={
"raw_outputs/box_encodings": (1, 1917, 4),
"raw_outputs/class_predictions": (1, 1917, 91),
},
)
converter.allow_custom_ops = True
converter.inference_type = tf.lite.constants.FLOAT
tflite_model = converter.convert()
np.random.seed(0)
box_encodings = np.random.uniform(size=(1, 1917, 4)).astype("float32")
class_predictions = np.random.uniform(size=(1, 1917, 91)).astype("float32")
tflite_output = run_tflite_graph(tflite_model, [box_encodings, class_predictions])

there is my example:
1661768456(1)
It will get a Mismatched elements.

Copy link
Contributor

@leandron leandron left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any publicly available network that we can observe this behaviour and create a test case?

Copy link
Contributor

@leandron leandron left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding a test case @czh978. I'm happy for this to be merged once CI is green.

The last thing we need is for the PR description to be improved with a description of your findings that as use_regular_nms is set as False in TVM, we need to have force_supress to be True, to avoid all detections to be suppressed.

Please expand that with the reasoning why this change is needed and point out that it reproduces on networks like ssd_mobilenet_v2 and other examples if you know any.

…rs["force_suppress"]

Since tvm only supports operators detection_postprocess use_regular_nms
is false, which will suppress boxes that exceed the threshold regardless
of the class when implementing NMS in tflite, in order for the results
of tvm and tflite to be consistent, we need to set force_suppress to
True.
…rs[force_suppress]

Added a test case that reproduces inconsistent results between tvm and tflite
When the force_suppress is false,it will get a good result if you set the force_suppress as true
@czh978
Copy link
Contributor Author

czh978 commented Sep 5, 2022

Thanks for adding a test case @czh978. I'm happy for this to be merged once CI is green.

The last thing we need is for the PR description to be improved with a description of your findings that as use_regular_nms is set as False in TVM, we need to have force_supress to be True, to avoid all detections to be suppressed.

Please expand that with the reasoning why this change is needed and point out that it reproduces on networks like ssd_mobilenet_v2 and other examples if you know any.

I've modified the commit message to meet the requirements of the new RFC, could you review it if you're available?Thanks.

@czh978
Copy link
Contributor Author

czh978 commented Sep 15, 2022

@leandron @AndrewZhaoLuo

@AndrewZhaoLuo AndrewZhaoLuo merged commit 60cf692 into apache:main Sep 19, 2022
xinetzone pushed a commit to daobook/tvm that referenced this pull request Nov 25, 2022
…trs["force_suppress"] (apache#12593)

* [Frontend][TFLite]fix detection_postprocess's non_max_suppression_attrs["force_suppress"]

Since tvm only supports operators detection_postprocess use_regular_nms
is false, which will suppress boxes that exceed the threshold regardless
of the class when implementing NMS in tflite, in order for the results
of tvm and tflite to be consistent, we need to set force_suppress to
True.

* [Frontend][TFLite]fix detection_postprocess's non_max_suppression_attrs[force_suppress]

Added a test case that reproduces inconsistent results between tvm and tflite
When the force_suppress is false,it will get a good result if you set the force_suppress as true
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants