Skip to content

Commit

Permalink
Fix None shape error when one input to ConcatV2 has a shape. (#2135)
Browse files Browse the repository at this point in the history
* Fix None shape error when one input to ConcatV2 has a shape.

I tried to convert a network where one of the inputs to a concatenation
along dimension -1 had shape None. The other input did however have a
shape. The conversion failed because the code only looked att the shape of
the first input to determine what positive axis value to use in the
concatenation. If the order of the inputs had been reversed, the conversion
would have worked.

I have now changed the code to look at the shapes of both input nodes. With
the new code, I can convert the network. I have also verified that the
resulting onnx-file works.

Signed-off-by: Klas Magnusson <klamag@raysearchlabs.com>
---------

Signed-off-by: Klas Magnusson <klamag@raysearchlabs.com>
  • Loading branch information
klasma authored Apr 12, 2023
1 parent 8f8d49a commit aaaea95
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
17 changes: 17 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1685,6 +1685,23 @@ def func(x1, x2, x3):
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, "input3:0": x_val3})

def test_concat_negative_axis_none_shape(self):
x_val = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=np.float32).reshape((2, 3))
y_val = np.array([7.0, 8.0, 9.0, 10.0, 11.0, 12.0], dtype=np.float32).reshape((2, 3))
s1_val = np.array([1, 1], dtype=np.int32)
s2_val = np.array([1, 1], dtype=np.int32)
def func():
x = tf_placeholder(tf.float32, [2, 3], name=_TFINPUT)
y = tf_placeholder(tf.float32, [2, 3], name=_TFINPUT1)
s1 = tf_placeholder(tf.int32, [2], name="input3")
s2 = tf_placeholder(tf.int32, [2], name="input4")
s = tf.add(s1, s2)
x_with_none_shape = tf.slice(x, [0, 0], s)
t = tf.concat([x_with_none_shape, y], -1)
return tf.identity(t, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, "input3:0": s1_val, "input4:0": s2_val},
as_session=True, premade_placeholders=True)

def test_concat_const_string(self):
x_val1 = np.array([["Hello world", "abc"], ["def", "♦♥♠♣"]], dtype=str)
const_val = np.array([["Hello there", "wxyz"], ["", "π"]], dtype=str)
Expand Down
9 changes: 7 additions & 2 deletions tf2onnx/onnx_opset/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,13 @@ def version_1(cls, ctx, node, **kwargs):
ctx.remove_input(node, node.input[-1], len(node.input) - 1)

if axis_val < 0: # onnxruntime does not support -1 axis, but TF supports.
input_shape = ctx.get_shape(node.input[0])
utils.make_sure(input_shape is not None, "shape of {} is None".format(node.input[0]))
input_shape = None
for node_input in node.input:
input_shape = ctx.get_shape(node_input)
if input_shape is not None:
break
utils.make_sure(input_shape is not None,
"the shapes of the following inputs are None: {}".format(', '.join(node.input)))
axis_val = len(input_shape) + axis_val
node.set_attr("axis", axis_val)

Expand Down

0 comments on commit aaaea95

Please sign in to comment.