diff --git a/third_party/nvfuser/csrc/parser.cpp b/third_party/nvfuser/csrc/parser.cpp index f375742077330..e4e22287e24cf 100644 --- a/third_party/nvfuser/csrc/parser.cpp +++ b/third_party/nvfuser/csrc/parser.cpp @@ -1560,6 +1560,47 @@ class IrParser { nullptr); } } + { + if (isOptionEnabled(EnableOption::GraphOp)) { + auto ptr_op = getOperatorForLiteral( + "aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"); + REGISTER_PARSE_RULE( + ptr_op, + { + MemoryFormat format; + std::list<Val*> list_val; + std::tie(format, list_val) = getPWFormatValues( + MemoryFormat::Contiguous(), + value_map[node->inputs()[0]->unique()], + value_map[node->inputs()[1]->unique()]); + auto input = list_val.front(); + list_val.pop_front(); + auto index = list_val.front(); + list_val.pop_front(); + Val* out = index_select( + input->as<TensorView>(), 0, index->as<TensorView>()); + value_map.emplace( + node->output()->unique(), ValueHolder(out, format)); + }, + [](const Node* node) -> bool { + if (auto tensor_type = + node->inputs()[0]->type()->cast<TensorType>()) { + // index_select doesn't support 0-dim tensors + if (tensor_type->dim() == 0u) { + return false; + } + } + for (const auto& val : node->inputs()) { + auto tensor_type = val->type()->cast<TensorType>(); + if (tensor_type && is_zero_sized_tensor(tensor_type)) { + return false; + } + } + return true; + }, + nullptr); + } + } { if (isOptionEnabled(EnableOption::GraphOp)) { auto ptr_op = getOperatorForLiteral( diff --git a/third_party/nvfuser/csrc/python_frontend/python_bindings.cpp b/third_party/nvfuser/csrc/python_frontend/python_bindings.cpp index 1635648d25074..968ea992adf32 100644 --- a/third_party/nvfuser/csrc/python_frontend/python_bindings.cpp +++ b/third_party/nvfuser/csrc/python_frontend/python_bindings.cpp @@ -1346,6 +1346,26 @@ void initNvFuserPythonBindings(PyObject* module) { py::arg("index"), py::arg("dim"), py::return_value_policy::reference); + nvf_ops.def( + "index", + [](nvfuser::FusionDefinition::Operators& self, + nvfuser::Tensor arg, + nvfuser::Tensor
index) -> nvfuser::Tensor { + FUSER_PERF_SCOPE("Operators.index"); + nvfuser::FusionDefinition* fd = self.fusion_definition; + nvfuser::Tensor output = fd->defineTensor(arg.dims); + fd->defineRecord(new nvfuser::IndexSelectOpRecord( + { + fd->recordingState(arg()), + fd->recordingState(index()), + }, + {fd->recordingState(output())}, + 0)); + return output; + }, + py::arg("arg"), + py::arg("index"), + py::return_value_policy::reference); nvf_ops.def( "gather", [](nvfuser::FusionDefinition::Operators& self, diff --git a/third_party/nvfuser/csrc/type_inference.cpp b/third_party/nvfuser/csrc/type_inference.cpp index 44d3a02d8308e..29798381119d9 100644 --- a/third_party/nvfuser/csrc/type_inference.cpp +++ b/third_party/nvfuser/csrc/type_inference.cpp @@ -459,6 +459,7 @@ class NaiveTypePropagator { case aten::transpose_copy: case aten::unsqueeze_copy: case aten::index_select: + case aten::index: case aten::gather: case aten::view_copy: { auto out_type = node->input(0)->type()->cast<TensorType>(); diff --git a/third_party/nvfuser/python_tests/test_python_frontend.py b/third_party/nvfuser/python_tests/test_python_frontend.py index d5f8149554cf7..599afbdbbddc3 100644 --- a/third_party/nvfuser/python_tests/test_python_frontend.py +++ b/third_party/nvfuser/python_tests/test_python_frontend.py @@ -628,6 +628,26 @@ def fusion_func(fd: FusionDefinition) : eager_out = torch.gather(inputs[0] + inputs[1], 0, inputs[2]) self.assertEqual(eager_out, nvf_out[0]) + def test_index(self): + inputs = [ + torch.randn(8, 16, device='cuda'), + torch.randn(8, 16, device='cuda'), + torch.randint(0, 8, (4, ), device="cuda").to(dtype=torch.long) + ] + + def fusion_func(fd: FusionDefinition) : + t0 = fd.from_pytorch(inputs[0]) + t1 = fd.from_pytorch(inputs[1]) + t2 = fd.from_pytorch(inputs[2]) + t3 = fd.ops.add(t0, t1) + t4 = fd.ops.index(t3, t2) + fd.add_output(t4) + + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + tmp = inputs[0] + inputs[1] + eager_out = tmp[inputs[2]] + self.assertEqual(eager_out,
nvf_out[0]) + def test_index_select(self): inputs = [ torch.randn(8, 16, device='cuda'), diff --git a/third_party/nvfuser/python_tests/test_torchscript.py b/third_party/nvfuser/python_tests/test_torchscript.py index db3435c3f4cb1..477a6bc65bb47 100644 --- a/third_party/nvfuser/python_tests/test_torchscript.py +++ b/third_party/nvfuser/python_tests/test_torchscript.py @@ -4201,6 +4201,22 @@ def t(x: torch.Tensor, y: torch.Tensor, ind: torch.Tensor): t_jit = torch.jit.script(t) self._run_training_helper(t_jit, t, grad, x, y, ind) + @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_index_function(self): + def t(x: torch.Tensor, y: torch.Tensor, ind: torch.Tensor): + o = torch.mul(x, y) + o = o[ind] + return o + + x = torch.randn([68, 128], dtype=torch.float, device="cuda") + y = torch.randn_like(x) + ind = torch.randint(0, 68, (130,), device="cuda").to(dtype=torch.int) + + t_jit = torch.jit.script(t) + self._run_helper(t_jit, t, x, y, ind) + @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective")