Skip to content

Commit

Permalink
[Torch] Support index_select (apache#6295)
Browse files Browse the repository at this point in the history
* support index select

* minor fix

Co-authored-by: masa <masa@pop-os.localdomain>
  • Loading branch information
2 people authored and trevor-m committed Sep 3, 2020
1 parent 5d875ca commit 59104e7
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2056,6 +2056,7 @@ def _get_convert_map(prelude):
"aten::len" : _list_len(prelude),
"aten::type_as" : _type_as(),
"aten::gather" : _gather(),
"aten::index_select" : _select(),
}
return convert_map

Expand Down
18 changes: 18 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,7 @@ def forward(self, *args):
verify_model(View2().float().eval(), input_data=input_data)
verify_model(View3().float().eval(), input_data=input_data)


def test_forward_select():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
Expand All @@ -981,9 +982,26 @@ class Select1(Module):
def forward(self, *args):
return args[0].select(1, 1)

class IndexedSelect(Module):
def __init__(self, inp, dim):
super().__init__()
self.inp = inp
self.dim = dim
if torch.cuda.is_available():
self.inp = self.inp.cuda()

def forward(self, index):
return torch.index_select(self.inp, self.dim, index)

input_data = torch.rand(input_shape).float()
verify_model(Select1().float().eval(), input_data=input_data)

x = torch.randn(3, 4)
indices = torch.tensor([0, 2])
verify_model(IndexedSelect(x, 0).eval(), input_data=indices)
verify_model(IndexedSelect(x, 1).eval(), input_data=indices)


def test_forward_clone():
torch.set_grad_enabled(False)
input_shape = [10]
Expand Down

0 comments on commit 59104e7

Please sign in to comment.