diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index ca05954227f88..991e3a8a00320 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2117,6 +2117,7 @@ def create_convert_map(self): "aten::to": self.to, "aten::squeeze": self.squeeze, "aten::unsqueeze": self.unsqueeze, + "aten::unsqueeze_": self.unsqueeze, "aten::cat": self.concatenate, "aten::slice": self.slice, "aten::split": self.split, diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index f76c697a2c810..7cdd450448cac 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -447,8 +447,16 @@ class Unsqueeze1(Module): def forward(self, *args): return args[0].unsqueeze(2) + class Unsqueeze2(Module): + def forward(self, *args): + _ = args[0].unsqueeze_(2) + # Check whether operations after inplace unsqueeze works as expected + y = args[0].squeeze(2) + return torch.add(y, y) + input_data = torch.rand(input_shape).float() verify_model(Unsqueeze1().float().eval(), input_data=input_data) + verify_model(Unsqueeze2().float().eval(), input_data=input_data) @tvm.testing.uses_gpu