Skip to content

Commit

Permalink
[torch] Add narrow operator
Browse files Browse the repository at this point in the history
  • Loading branch information
apivovarov committed Feb 26, 2021
1 parent 09b0c8e commit fe931bc
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
14 changes: 14 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,19 @@ def slice(self, inputs, input_types):
data, begin=begin, end=end, strides=strides, slice_mode="end"
)

def narrow(self, inputs, input_types):
# Inputs are:
# 0 - the tensor to narrow
# 1 - the dimension along which to narrow
# 2 - the starting dimension
# 3 - the distance to the ending dimension
# Lets find the ending dimension
end = self.add(inputs[2:4], input_types[2:4])
stride = 1
slice_input = inputs[:3] + [end, stride]
slice_types = input_types + ["int32"]
return self.slice(slice_input, slice_types)

def split(self, inputs, input_types):
data = inputs[0]
split_size = int(inputs[1])
Expand Down Expand Up @@ -2214,6 +2227,7 @@ def create_convert_map(self):
"aten::unsqueeze_": self.unsqueeze,
"aten::cat": self.concatenate,
"aten::slice": self.slice,
"aten::narrow": self.narrow,
"aten::split": self.split,
"aten::split_with_sizes": self.split_with_sizes,
"aten::select": self.select,
Expand Down
26 changes: 26 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1496,6 +1496,31 @@ def forward(self, x):
verify_model(SliceWithStride2(), input_data=torch.randn(4, 4))


@tvm.testing.uses_gpu
def test_forward_narrow():
torch.set_grad_enabled(False)
input_shape = [3, 3]

class Narrow1(Module):
def forward(self, *args):
return torch.narrow(args[0], 0, 0, 2)

class Narrow2(Module):
def forward(self, *args):
return torch.narrow(args[0], 1, 1, 2)

class Narrow3(Module):
def forward(self, *args):
begin = torch.tensor(2) - torch.tensor(1)
length = torch.tensor(1) * torch.tensor(2)
return torch.narrow(args[0], 1, begin, length)

input_data = torch.rand(input_shape).float()
verify_model(Narrow1(), input_data=input_data)
verify_model(Narrow2(), input_data=input_data)
verify_model(Narrow3(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_mean():
torch.set_grad_enabled(False)
Expand Down Expand Up @@ -3749,6 +3774,7 @@ def test_fn(is_sorted, return_inverse, return_counts):
test_forward_avgpool3d()
test_forward_dropout()
test_forward_slice()
test_forward_narrow()
test_forward_mean()
test_forward_expand()
test_forward_pow()
Expand Down

0 comments on commit fe931bc

Please sign in to comment.