diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 35c857a3d77f..e44653ff1ba9 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1039,8 +1039,8 @@ def _impl(inputs, attr, params): # otherwise its value is get from params try: axes = _get_list_param(params, inputs[1]) - except (IndexError, KeyError): - axes = None + except (IndexError, KeyError, AttributeError): + axes = _infer_value_simulated(inputs[1], params).asnumpy() return _op.transpose(inputs[0], axes=axes) return _impl diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 17db2f5cc9a8..e02532fa748b 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -2114,6 +2114,22 @@ def _test_forward_transpose(ishape, axes=None): compare_tf_with_tvm(data, 'transpose_data:0', 'transpose:0') +def _test_forward_tranapose_axes_input(ishape, axes): + data = np.random.uniform(size=ishape).astype(np.float32) + axes_np = np.array(axes).astype(np.int32) + + with tf.Graph().as_default(): + in1 = tf.placeholder( + shape=data.shape, dtype=data.dtype, name="transpose_data") + + const1 = tf.constant(axes_np, dtype=tf.int32) + + # make axes an input to tf.transpose, but not an input to the graph, + # so it can be extracted with infer_value_simulated + axes = tf.reverse(const1, axis=[-1]) + tf.transpose(in1, axes) + + compare_tf_with_tvm([data], ['transpose_data:0'], 'transpose:0') def test_forward_transpose(): _test_forward_transpose((2, 3, 4), (1, 2, 0)) @@ -2122,6 +2138,8 @@ def test_forward_transpose(): _test_forward_transpose((2, 3, 4), (1, 2, 0)) _test_forward_transpose((2, 3, 4), (0, 1, 2)) _test_forward_transpose((2, 3, 4, 5), (3, 0, 1, 2)) + _test_forward_tranapose_axes_input((2, 3, 4), (1, 2, 0)) + _test_forward_tranapose_axes_input((2, 3, 4, 5), (3, 0, 1, 2)) def test_forward_ceil():