Skip to content

Commit

Permalink
The from_tflite() function should accept None as default value of inp…
Browse files Browse the repository at this point in the history
…ut_names and output_names. (#1967)

* The from_tflite() function should not change the value of None to an empty list for input_names and output_names.
* Change the way to validate a list is None or Emtpy.

Signed-off-by: Jay Zhang <jiz@microsoft.com>
  • Loading branch information
fatcat-z authored Jun 15, 2022
1 parent 9cea907 commit 89c4c5c
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 10 deletions.
24 changes: 20 additions & 4 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,27 @@ def test_tflite(self):

x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
model_proto, _ = tf2onnx.convert.from_tflite("tests/models/regression/tflite/test_api_model.tflite",
input_names=['input'], output_names=['output'],
input_names=["input"], output_names=["output"],
output_path=output_path)
output_names = [n.name for n in model_proto.graph.output]
oy = self.run_onnxruntime(output_path, {"input": x_val}, output_names)
self.assertTrue(output_names[0] == "output")
actual_output_names = [n.name for n in model_proto.graph.output]
oy = self.run_onnxruntime(output_path, {"input": x_val}, actual_output_names)

self.assertTrue(actual_output_names[0] == "output")
exp_result = tf.add(x_val, x_val)
self.assertAllClose(exp_result, oy[0], rtol=0.1, atol=0.1)

@check_tf_min_version("2.0")
def test_tflite_without_input_output_names(self):
output_path = os.path.join(self.test_data_directory, "model.onnx")

x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
model_proto, _ = tf2onnx.convert.from_tflite("tests/models/regression/tflite/test_api_model.tflite",
output_path=output_path)
actual_input_names = [n.name for n in model_proto.graph.input]
actual_output_names = [n.name for n in model_proto.graph.output]
oy = self.run_onnxruntime(output_path, {actual_input_names[0]: x_val}, output_names=None)

self.assertTrue(actual_output_names[0] == "output")
exp_result = tf.add(x_val, x_val)
self.assertAllClose(exp_result, oy[0], rtol=0.1, atol=0.1)

Expand Down
4 changes: 0 additions & 4 deletions tf2onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,10 +663,6 @@ def from_tflite(tflite_path, input_names=None, output_names=None, opset=None, cu
"""
if not tflite_path:
raise ValueError("tflite_path needs to be provided")
if not input_names:
input_names = []
if not output_names:
output_names = []

with tf.device("/cpu:0"):
model_proto, external_tensor_storage = _convert_common(
Expand Down
4 changes: 2 additions & 2 deletions tf2onnx/tflite_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,9 @@ def graphs_from_tflite(tflite_path, input_names=None, output_names=None):
if is_main_g:
# Override IO in main graph
utils.check_io(input_names, output_names, output_shapes.keys())
if input_names is not None:
if input_names:
g_inputs = input_names
if output_names is not None:
if output_names:
g_outputs = output_names
g = Graph(onnx_nodes, output_shapes, dtypes, input_names=g_inputs, output_names=g_outputs,
is_subgraph=not is_main_g, graph_name=graph_name)
Expand Down

0 comments on commit 89c4c5c

Please sign in to comment.