Skip to content

Commit

Permalink
[PYTORCH]expand bug fix (apache#5576)
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel authored May 13, 2020
1 parent ca15a83 commit 8eb6584
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
17 changes: 13 additions & 4 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,27 +1245,36 @@ def _impl(inputs, input_types):
return _op.nn.dense(data0, data1_t)
return _impl


def _expand():
def _impl(inputs, input_types):
data_in = inputs[0]
if isinstance(data_in, _expr.Expr):
shape = _infer_shape(data_in)
shape = list(_infer_shape(data_in))

ndims = len(shape)
sizes = _infer_shape(inputs[1])
out = inputs[0]

out_dims = len(sizes)
if ndims < out_dims:
num_newaxis = out_dims - ndims
out = _op.expand_dims(out, axis=0, num_newaxis=num_newaxis)
shape = [1] * num_newaxis + shape

for i in range(ndims):
if sizes[i] in {-1, shape[i]}:
if sizes[i] == -1 or sizes[i] == shape[i]:
continue
data = list()
for temp in range(sizes[i]):
data.append(out)
call = _op.tensor.concatenate(data, i)

return call
out = _op.tensor.concatenate(data, i)

return out
return _impl


def _int():
def _impl(inputs, input_types):
if isinstance(inputs[0], _expr.Expr):
Expand Down
11 changes: 10 additions & 1 deletion tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,15 +902,24 @@ def forward(self, *args):

def test_forward_expand():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]

class Expand1(Module):
def forward(self, *args):
return args[0].expand((3, -1, -1, -1))

input_shape = [1, 3, 10, 10]
input_data = torch.rand(input_shape).float()
verify_model(Expand1().float().eval(), input_data=input_data)

class Expand2(Module):
def forward(self, *args):
return args[0].expand((3, 3, 3, 1))

input_shape = [3, 1]
input_data = torch.rand(input_shape).float()
verify_model(Expand2().float().eval(), input_data=input_data)


def test_forward_pow():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
Expand Down

0 comments on commit 8eb6584

Please sign in to comment.