diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 71554a8a5bab..aa8150fac71c 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1193,6 +1193,12 @@ def _fill(self, node: fx.Node) -> relax.Var: value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) + def _index_select(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] + index = self.env[node.args[2]] + return self.block_builder.emit(relax.op.take(x, index, dim)) + def _new_ones(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) self_var = args[0] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 0f1dc11787da..5f6c4c902401 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -389,7 +389,6 @@ def create_convert_map( "reshape.default": self._reshape, # tensor creation "_to_copy.default": self._to_copy, - "lift_fresh_copy.default": self._to_copy, "detach.default": self._detach, "detach_.default": self._detach, "arange.start": self._arange, @@ -398,6 +397,8 @@ def create_convert_map( "empty.memory_format": self._empty, "empty_like.default": self._empty_like, "fill.Scalar": self._fill, + "index_select.default": self._index_select, + "lift_fresh_copy.default": self._to_copy, "new_ones.default": self._new_ones, "one_hot.default": self._one_hot, # other diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 022a7bffea80..99cde790d63e 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -485,12 +485,6 @@ def _full(self, node: fx.Node) -> relax.Var: ) ) - def _index_select(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - dim = node.args[1] - index = self.env[node.args[2]] - return self.block_builder.emit(relax.op.take(x, index, dim)) - def _inplace_masked_fill(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] mask = self.env[node.args[1]] diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 64babdc43a5c..19b8f80a2390 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -467,5 +467,17 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_index_select(target, dev): + class IndexSelectModel(nn.Module): + def forward(self, x): + indices = torch.tensor([0, 2]) + return torch.index_select(x, 0, indices) + + raw_data = np.random.rand(3, 4).astype("float32") + torch_module = IndexSelectModel().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + if __name__ == "__main__": tvm.testing.main()