From aaaea9560dabaddec19ed0a7ba76315d441183a6 Mon Sep 17 00:00:00 2001 From: Klas Magnusson Date: Wed, 12 Apr 2023 17:55:30 +0200 Subject: [PATCH] Fix None shape error when one input to ConcatV2 has a shape. (#2135) * Fix None shape error when one input to ConcatV2 has a shape. I tried to convert a network where one of the inputs to a concatenation along dimension -1 had shape None. The other input did however have a shape. The conversion failed because the code only looked att the shape of the first input to determine what positive axis value to use in the concatenation. If the order of the inputs had been reversed, the conversion would have worked. I have now changed the code to look at the shapes of both input nodes. With the new code, I can convert the network. I have also verified that the resulting onnx-file works. Signed-off-by: Klas Magnusson --------- Signed-off-by: Klas Magnusson --- tests/test_backend.py | 17 +++++++++++++++++ tf2onnx/onnx_opset/tensor.py | 9 +++++++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/tests/test_backend.py b/tests/test_backend.py index 1060574c1..225f9461a 100755 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -1685,6 +1685,23 @@ def func(x1, x2, x3): return tf.identity(x_, name=_TFOUTPUT) self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, "input3:0": x_val3}) + def test_concat_negative_axis_none_shape(self): + x_val = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=np.float32).reshape((2, 3)) + y_val = np.array([7.0, 8.0, 9.0, 10.0, 11.0, 12.0], dtype=np.float32).reshape((2, 3)) + s1_val = np.array([1, 1], dtype=np.int32) + s2_val = np.array([1, 1], dtype=np.int32) + def func(): + x = tf_placeholder(tf.float32, [2, 3], name=_TFINPUT) + y = tf_placeholder(tf.float32, [2, 3], name=_TFINPUT1) + s1 = tf_placeholder(tf.int32, [2], name="input3") + s2 = tf_placeholder(tf.int32, [2], name="input4") + s = tf.add(s1, s2) + x_with_none_shape = tf.slice(x, [0, 0], s) + t = tf.concat([x_with_none_shape, y], -1) + return tf.identity(t, name=_TFOUTPUT) + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, "input3:0": s1_val, "input4:0": s2_val}, + as_session=True, premade_placeholders=True) + def test_concat_const_string(self): x_val1 = np.array([["Hello world", "abc"], ["def", "♦♥♠♣"]], dtype=str) const_val = np.array([["Hello there", "wxyz"], ["", "π"]], dtype=str) diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py index ca4daf0d1..bb9b0cb52 100644 --- a/tf2onnx/onnx_opset/tensor.py +++ b/tf2onnx/onnx_opset/tensor.py @@ -297,8 +297,13 @@ def version_1(cls, ctx, node, **kwargs): ctx.remove_input(node, node.input[-1], len(node.input) - 1) if axis_val < 0: # onnxruntime does not support -1 axis, but TF supports. - input_shape = ctx.get_shape(node.input[0]) - utils.make_sure(input_shape is not None, "shape of {} is None".format(node.input[0])) + input_shape = None + for node_input in node.input: + input_shape = ctx.get_shape(node_input) + if input_shape is not None: + break + utils.make_sure(input_shape is not None, + "the shapes of the following inputs are None: {}".format(', '.join(node.input))) axis_val = len(input_shape) + axis_val node.set_attr("axis", axis_val)