diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 3af1051344ef..d95a9122ad59 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1245,27 +1245,36 @@ def _impl(inputs, input_types): return _op.nn.dense(data0, data1_t) return _impl + def _expand(): def _impl(inputs, input_types): data_in = inputs[0] if isinstance(data_in, _expr.Expr): - shape = _infer_shape(data_in) + shape = list(_infer_shape(data_in)) ndims = len(shape) sizes = _infer_shape(inputs[1]) out = inputs[0] + out_dims = len(sizes) + if ndims < out_dims: + num_newaxis = out_dims - ndims + out = _op.expand_dims(out, axis=0, num_newaxis=num_newaxis) + shape = [1] * num_newaxis + shape + for i in range(ndims): - if sizes[i] in {-1, shape[i]}: + if sizes[i] == -1 or sizes[i] == shape[i]: continue data = list() for temp in range(sizes[i]): data.append(out) - call = _op.tensor.concatenate(data, i) - return call + out = _op.tensor.concatenate(data, i) + + return out return _impl + def _int(): def _impl(inputs, input_types): if isinstance(inputs[0], _expr.Expr): diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index e1c276bb95f6..82a027f45c34 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -902,15 +902,24 @@ def forward(self, *args): def test_forward_expand(): torch.set_grad_enabled(False) - input_shape = [1, 3, 10, 10] class Expand1(Module): def forward(self, *args): return args[0].expand((3, -1, -1, -1)) + input_shape = [1, 3, 10, 10] input_data = torch.rand(input_shape).float() verify_model(Expand1().float().eval(), input_data=input_data) + class Expand2(Module): + def forward(self, *args): + return args[0].expand((3, 3, 3, 1)) + + input_shape = [3, 1] + input_data = torch.rand(input_shape).float() + verify_model(Expand2().float().eval(), input_data=input_data) + + def test_forward_pow(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10]