diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 2decc2180e48..7c1d34f3fd2c 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -626,8 +626,14 @@ def _impl(inputs, attr, params): def _slice(): def _impl(inputs, attr, params): - begin = _get_list_param(params, inputs[1]) - size = _get_list_param(params, inputs[2]) + try: + begin = _get_list_param(params, inputs[1]) + except (IndexError, KeyError, AttributeError): + begin = _infer_value(inputs[1], params).asnumpy().tolist()[0] + try: + size = _get_list_param(params, inputs[2]) + except (IndexError, KeyError, AttributeError): + size = _infer_value(inputs[2], params).asnumpy().tolist()[0] data_shape = attr['_input_shapes'][inputs[0]] data_dim = len(data_shape) end = size diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index db19ed4e851d..4ec8abdfb336 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -2188,6 +2188,20 @@ def test_forward_transpose(): _test_forward_tranapose_axes_input((2, 3, 4, 5), (3, 0, 1, 2)) +def _test_forward_slice_operation_input(input_value, begin_value, size_value): + input_data = np.array(input_value, dtype=np.float32) + with tf.Graph().as_default(): + input_tensor = tf.placeholder( + shape=input_data.shape, dtype=input_data.dtype, name="input") + begin_tensor = tf.expand_dims(begin_value, axis=0) + size_tensor = tf.expand_dims(size_value, axis=0) + slice_tensor = tf.slice(input_tensor, begin_tensor, size_tensor, name='slice_output') + compare_tf_with_tvm([input_data], ['input:0'], 'slice_output:0') + + +def test_forward_slice(): + _test_forward_slice_operation_input([1, 1], 0, 2) + def test_forward_ceil(): ishape = (1, 3, 10, 10) inp_array = np.random.uniform(size=ishape).astype(np.float32) @@ -2760,8 +2774,8 @@ def test_forward_add_n(): # Main # ---- if __name__ == '__main__': - # Transforms + test_forward_slice() test_forward_transpose() test_forward_reshape() test_forward_depthtospace()