From 95a816c9078c5cc7cb08d354a069a15f5d18951c Mon Sep 17 00:00:00 2001 From: mbaret <55580676+mbaret@users.noreply.github.com> Date: Wed, 29 Apr 2020 17:13:16 +0100 Subject: [PATCH] [TFLITE] Match TFLite shape for SSD custom op (#5473) This patch ensures that the output shape from TVM's Detection_PostProcess is the same as TFLite's and expands the unit test to confirm this. Change-Id: If5db95741533f131241dfebbaa7708dbd528fe70 --- python/tvm/relay/frontend/tflite.py | 13 +++++++++---- tests/python/frontend/tflite/test_forward.py | 7 +++++++ 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index b9a165711913..66d0ff326ce0 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -2257,6 +2257,7 @@ def convert_detection_postprocess(self, op): assert len(inputs) == 3, "inputs length should be 3" cls_pred = self.get_expr(inputs[1].tensor_idx) loc_prob = self.get_expr(inputs[0].tensor_idx) + batch_size = inputs[1].tensor.Shape(0) anchor_values = self.get_tensor_value(inputs[2]) anchor_boxes = len(anchor_values) anchor_type = self.get_tensor_type_str(inputs[2].tensor.Type()) @@ -2284,7 +2285,7 @@ def convert_detection_postprocess(self, op): loc_prob = _op.concatenate( [loc_coords[1], loc_coords[0], loc_coords[3], loc_coords[2]], axis=2 ) - loc_prob = _op.reshape(loc_prob, [1, anchor_boxes*4]) + loc_prob = _op.reshape(loc_prob, [batch_size, anchor_boxes*4]) # anchor coords are in yxhw format # need to convert to ltrb @@ -2327,10 +2328,14 @@ def convert_detection_postprocess(self, op): ret = _op.vision.non_max_suppression(ret[0], ret[1], **non_max_suppression_attrs) ret = _op.vision.get_valid_counts(ret, 0) valid_count = ret[0] + # keep only the top 'max_detections' rows + ret = _op.strided_slice(ret[1], + [0, 0, 0], + [batch_size, custom_options["max_detections"], anchor_boxes]) # the output needs some reshaping to match tflite - ret = 
_op.split(ret[1], 6, axis=2) - cls_ids = ret[0] - scores = ret[1] + ret = _op.split(ret, 6, axis=2) + cls_ids = _op.reshape(ret[0], [batch_size, -1]) + scores = _op.reshape(ret[1], [batch_size, -1]) boxes = _op.concatenate([ret[3], ret[2], ret[5], ret[4]], axis=2) ret = _expr.TupleWrapper(_expr.Tuple([boxes, cls_ids, scores, valid_count]), size=4) return ret diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 7ff4c3135a91..bc3f32a2b0cd 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1731,7 +1731,14 @@ def test_detection_postprocess(): ["raw_outputs/box_encodings", "raw_outputs/class_predictions"], num_output=4) # check valid count is the same assert tvm_output[3] == tflite_output[3] + # check all the output shapes are the same + assert tvm_output[0].shape == tflite_output[0].shape + assert tvm_output[1].shape == tflite_output[1].shape + assert tvm_output[2].shape == tflite_output[2].shape valid_count = tvm_output[3][0] + # only check the valid detections are the same + # tvm has a different convention to tflite for invalid detections; it uses all -1s whereas + # tflite appears to put in nonsense data instead tvm_boxes = tvm_output[0][0][:valid_count] tvm_classes = tvm_output[1][0][:valid_count] tvm_scores = tvm_output[2][0][:valid_count]