Skip to content

Commit

Permalink
Fix tensor array in pytorch frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinthesun committed Jun 11, 2020
1 parent 70c245b commit 0644343
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,12 @@ def tensor_array_concat(lst, axis):
assert axis == 0, "Tensor array concat supported only for axis 0"
tensor_array, shape = _convert_to_tensor_array(lst, prelude)
concat_shape = (Any(),) + shape[1:]
static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape)
static_tensor_array_ops.define_tensor_get_data(concat_shape)

concat = prelude.get_var_static('tensor_array_concat', "float32", shape)
concatenated = concat(tensor_array)
get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape)

static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", concat_shape)
static_tensor_array_ops.register()
get_tensor = prelude.get_var_static('tensor_get_data', "float32", concat_shape)
return get_tensor(concatenated)

def _impl(inputs, input_types):
Expand Down Expand Up @@ -1610,14 +1610,14 @@ def _impl(inputs, input_types):
def _tensor_array_stack(prelude):
def _impl(inputs, input_types):
tensor_array, shape = _convert_to_tensor_array(inputs[0], prelude)

stacked_shape = (Any(),) + shape
stack = prelude.get_var_static('tensor_array_stack', "float32", shape)
stacked = stack(tensor_array)

stacked_shape = (Any(),) + shape
static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape)
static_tensor_array_ops.define_tensor_get_data(stacked_shape)
# passing stacked_shape below gives "'Prelude' object has no attribute" error
get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape)
static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", stacked_shape)
static_tensor_array_ops.register()
get_tensor = prelude.get_var_static('tensor_get_data', "float32", stacked_shape)
return get_tensor(stacked)
return _impl

Expand Down

0 comments on commit 0644343

Please sign in to comment.