From cddf5cc862329d237a036be4b1c21ba338329447 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Thu, 25 Feb 2021 19:39:07 -0800 Subject: [PATCH] [torch] Add narrow operator --- python/tvm/relay/frontend/pytorch.py | 14 ++++++++++ tests/python/frontend/pytorch/test_forward.py | 26 +++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index fdebd2f50e68..a471639da623 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -473,6 +473,19 @@ def slice(self, inputs, input_types): data, begin=begin, end=end, strides=strides, slice_mode="end" ) + def narrow(self, inputs, input_types): + # Inputs are: + # 0 - the tensor to narrow + # 1 - the dimension along which to narrow + # 2 - the starting dimension + # 3 - the distance to the ending dimension + # Lets find the ending dimension + end = self.add(inputs[2:4], input_types[2:4]) + stride = 1 + slice_input = inputs[:3] + [end, stride] + slice_types = input_types + ["int32"] + return self.slice(slice_input, slice_types) + def split(self, inputs, input_types): data = inputs[0] split_size = int(inputs[1]) @@ -2222,6 +2235,7 @@ def create_convert_map(self): "aten::unsqueeze_": self.unsqueeze, "aten::cat": self.concatenate, "aten::slice": self.slice, + "aten::narrow": self.narrow, "aten::split": self.split, "aten::split_with_sizes": self.split_with_sizes, "aten::select": self.select, diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 90604751d4f1..aeecfbc5b23e 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1505,6 +1505,31 @@ def forward(self, x): verify_model(SliceWithStride2(), input_data=torch.randn(4, 4)) +@tvm.testing.uses_gpu +def test_forward_narrow(): + torch.set_grad_enabled(False) + input_shape = [3, 3] + + class Narrow1(Module): + def forward(self, *args): + return torch.narrow(args[0], 0, 0, 2) + + class Narrow2(Module): + def forward(self, *args): + return torch.narrow(args[0], 1, 1, 2) + + class Narrow3(Module): + def forward(self, *args): + begin = torch.tensor(2) - torch.tensor(1) + length = torch.tensor(1) * torch.tensor(2) + return torch.narrow(args[0], 1, begin, length) + + input_data = torch.rand(input_shape).float() + verify_model(Narrow1(), input_data=input_data) + verify_model(Narrow2(), input_data=input_data) + verify_model(Narrow3(), input_data=input_data) + + @tvm.testing.uses_gpu def test_forward_mean(): torch.set_grad_enabled(False) @@ -3758,6 +3783,7 @@ def test_fn(is_sorted, return_inverse, return_counts): test_forward_avgpool3d() test_forward_dropout() test_forward_slice() + test_forward_narrow() test_forward_mean() test_forward_expand() test_forward_pow()