Skip to content

Commit

Permalink
Fix TFLite RESHAPE assert (apache#4320)
Browse files Browse the repository at this point in the history
  • Loading branch information
apivovarov authored and Xingyu Zhou committed Nov 26, 2019
1 parent 7d041bd commit ba8c981
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def convert_reshape(self, op):

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"
assert input_tensors, "input tensors should not be empty"
input_tensor = input_tensors[0]
input_tensor_idx = input_tensor.tensor_idx

Expand Down
21 changes: 21 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
except ImportError:
from tensorflow.contrib import lite as interpreter_wrapper

from tvm.contrib.download import download_testdata
import tvm.relay.testing.tf as tf_testing
from packaging import version as package_version

Expand Down Expand Up @@ -1137,6 +1138,25 @@ def test_forward_ssd_mobilenet_v1():
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5)

#######################################################################
# MediaPipe
# -------------

def test_forward_mediapipe_hand_landmark():
"""Test MediaPipe 2D hand landmark TF Lite model."""
# MediaPipe 2D hand landmark TF
tflite_model_file = download_testdata(
"https://github.com/google/mediapipe/raw/master/mediapipe/models/hand_landmark.tflite",
"hand_landmark.tflite")
with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()
data = np.random.uniform(size=(1, 256, 256, 3)).astype('float32')
tflite_output = run_tflite_graph(tflite_model_buf, data)
tvm_output = run_tvm_graph(tflite_model_buf, data, 'input_1', num_output=2)
for i in range(2):
tvm.testing.assert_allclose(np.squeeze(tvm_output[i]), np.squeeze(tflite_output[i]),
rtol=1e-5, atol=1e-5)

#######################################################################
# Main
# ----
Expand Down Expand Up @@ -1192,6 +1212,7 @@ def test_forward_ssd_mobilenet_v1():
test_forward_inception_v3_net()
test_forward_inception_v4_net()
test_forward_ssd_mobilenet_v1()
test_forward_mediapipe_hand_landmark()

# End to End quantized
test_forward_qnn_inception_v1_net()
Expand Down

0 comments on commit ba8c981

Please sign in to comment.