From 0969181f9a051722dbce3887b910ac340be7d945 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Tue, 21 Apr 2020 03:57:13 -0700 Subject: [PATCH] Add ability to have multiple copies of same input to onnx_inputs. (#5389) --- python/tvm/relay/frontend/onnx.py | 3 +-- tests/python/frontend/onnx/test_forward.py | 11 ++++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 527a1ed2f07b..245b3853ae90 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -57,8 +57,7 @@ def __setitem__(self, item, value): if isinstance(item, int): self.input_dict[self.input_keys[item]] = value elif isinstance(item, str): - if item not in self.input_dict: - self.input_keys.append(item) + self.input_keys.append(item) self.input_dict[item] = value else: raise ValueError("Only integer and string indexed writes allowed.") diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 2c0849451a25..c06aa50538f4 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1366,16 +1366,16 @@ def test_binary_ops(): dtype = "float32" out_shape = in_shape - def verify_binary_ops(op, x, y, out_np, broadcast=None): + def verify_binary_ops(op, x, y, out_np, x_name='in1', y_name='in2', broadcast=None): if broadcast is None: - z = helper.make_node(op, ['in1', 'in2'], ['out']) + z = helper.make_node(op, [x_name, y_name], ['out']) else: - z = helper.make_node(op, ['in1', 'in2'], ['out'], broadcast=1) + z = helper.make_node(op, [x_name, y_name], ['out'], broadcast=1) graph = helper.make_graph([z], '_test', - inputs=[helper.make_tensor_value_info("in1", + inputs=[helper.make_tensor_value_info(x_name, TensorProto.FLOAT, list(in_shape)), - helper.make_tensor_value_info("in2", + helper.make_tensor_value_info(y_name, TensorProto.FLOAT, list(in_shape))], outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))]) @@ -1393,6 +1393,7 @@ def verify_binary_ops(op, x, y, out_np, broadcast=None): verify_binary_ops("Sub", x, z, x - z, broadcast=True) verify_binary_ops("Mul", x, y, x * y, broadcast=None) verify_binary_ops("Mul", x, z, x * z, broadcast=True) + verify_binary_ops("Mul", x, x, x * x, x_name='in1', y_name='in1', broadcast=None) verify_binary_ops("Div", x, y, x / y, broadcast=None) verify_binary_ops("Div", x, z, x / z, broadcast=True) verify_binary_ops("Sum", x, y, x + y, broadcast=None)