Skip to content

Commit

Permalink
Add test_forward_ssd_mobilenet_v1 to tflite/test_forward (#3350)
Browse files Browse the repository at this point in the history
  • Loading branch information
apivovarov authored and kevinthesun committed Jun 14, 2019
1 parent 8a89177 commit 59d8ba8
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
9 changes: 3 additions & 6 deletions python/tvm/relay/testing/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,10 @@ def get_workload_official(model_url, model_sub_path):
model_sub_path:
Sub path in extracted tar for the ftozen protobuf file.
temp_dir: TempDirectory
The temporary directory object to download the content.
Returns
-------
graph_def: graphdef
graph_def is the tensorflow workload for mobilenet.
model_path: str
Full path to saved model file
"""

Expand Down Expand Up @@ -200,7 +197,7 @@ def get_workload(model_path, model_sub_path=None):
Returns
-------
graph_def: graphdef
graph_def is the tensorflow workload for mobilenet.
graph_def is the tensorflow workload.
"""

Expand Down
19 changes: 19 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,24 @@ def test_forward_inception_v4_net():
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5)

#######################################################################
# SSD Mobilenet
# -------------

def test_forward_ssd_mobilenet_v1():
"""Test the SSD Mobilenet V1 TF Lite model."""
# SSD MobilenetV1
tflite_model_file = tf_testing.get_workload_official(
"https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28_nopp.tgz",
"ssd_mobilenet_v1_coco_2018_01_28_nopp.tflite")
with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()
data = np.random.uniform(size=(1, 300, 300, 3)).astype('float32')
tflite_output = run_tflite_graph(tflite_model_buf, data)
tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5)

#######################################################################
# Main
# ----
Expand All @@ -623,3 +641,4 @@ def test_forward_inception_v4_net():
test_forward_mobilenet_v2()
test_forward_inception_v3_net()
test_forward_inception_v4_net()
test_forward_ssd_mobilenet_v1()

0 comments on commit 59d8ba8

Please sign in to comment.