From a8a675d86eb8cf44b56f23b72d6561be0cd62500 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Tue, 9 Jun 2020 17:40:06 -0700 Subject: [PATCH] Fix tensor array in pytorch frontend --- python/tvm/relay/frontend/pytorch.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index e74d58efeaec8..d5da24e8447cf 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -211,12 +211,12 @@ def tensor_array_concat(lst, axis): assert axis == 0, "Tensor array concat supported only for axis 0" tensor_array, shape = _convert_to_tensor_array(lst, prelude) concat_shape = (Any(),) + shape[1:] - static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) - static_tensor_array_ops.define_tensor_get_data(concat_shape) - concat = prelude.get_var_static('tensor_array_concat', "float32", shape) concatenated = concat(tensor_array) - get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape) + + static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", concat_shape) + static_tensor_array_ops.register() + get_tensor = prelude.get_var_static('tensor_get_data', "float32", concat_shape) return get_tensor(concatenated) def _impl(inputs, input_types): @@ -1610,14 +1610,14 @@ def _impl(inputs, input_types): def _tensor_array_stack(prelude): def _impl(inputs, input_types): tensor_array, shape = _convert_to_tensor_array(inputs[0], prelude) + + stacked_shape = (Any(),) + shape stack = prelude.get_var_static('tensor_array_stack', "float32", shape) stacked = stack(tensor_array) - stacked_shape = (Any(),) + shape - static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) - static_tensor_array_ops.define_tensor_get_data(stacked_shape) - # passing stacked_shape below gives "'Prelude' object has no attribute" error - get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape) + static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", stacked_shape) + static_tensor_array_ops.register() + get_tensor = prelude.get_var_static('tensor_get_data', "float32", stacked_shape) return get_tensor(stacked) return _impl