From 59104e7d10fa101ef68f11ad2bdfef831da0f4b8 Mon Sep 17 00:00:00 2001 From: masahi Date: Tue, 18 Aug 2020 23:06:33 +0900 Subject: [PATCH] [Torch] Support index_select (#6295) * support index select * minor fix Co-authored-by: masa --- python/tvm/relay/frontend/pytorch.py | 1 + tests/python/frontend/pytorch/test_forward.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index a1cabcd5ae22..235cec0f096d 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2056,6 +2056,7 @@ def _get_convert_map(prelude): "aten::len" : _list_len(prelude), "aten::type_as" : _type_as(), "aten::gather" : _gather(), + "aten::index_select" : _select(), } return convert_map diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 88203f560641..2302f0fdb74a 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -973,6 +973,7 @@ def forward(self, *args): verify_model(View2().float().eval(), input_data=input_data) verify_model(View3().float().eval(), input_data=input_data) + def test_forward_select(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -981,9 +982,26 @@ class Select1(Module): def forward(self, *args): return args[0].select(1, 1) + class IndexedSelect(Module): + def __init__(self, inp, dim): + super().__init__() + self.inp = inp + self.dim = dim + if torch.cuda.is_available(): + self.inp = self.inp.cuda() + + def forward(self, index): + return torch.index_select(self.inp, self.dim, index) + input_data = torch.rand(input_shape).float() verify_model(Select1().float().eval(), input_data=input_data) + x = torch.randn(3, 4) + indices = torch.tensor([0, 2]) + verify_model(IndexedSelect(x, 0).eval(), input_data=indices) + verify_model(IndexedSelect(x, 1).eval(), input_data=indices) + + def test_forward_clone(): torch.set_grad_enabled(False) input_shape = [10]