Add support for aten::index operator #2432

Open · wants to merge 2 commits into base: devel
41 changes: 41 additions & 0 deletions third_party/nvfuser/csrc/parser.cpp
@@ -1560,6 +1560,47 @@ class IrParser {
            nullptr);
      }
    }
    {
      if (isOptionEnabled(EnableOption::GraphOp)) {
        auto ptr_op = getOperatorForLiteral(
            "aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor");
        REGISTER_PARSE_RULE(
            ptr_op,
            {
              MemoryFormat format;
              std::list<Val*> list_val;
              std::tie(format, list_val) = getPWFormatValues(
                  MemoryFormat::Contiguous(),
                  value_map[node->inputs()[0]->unique()],
                  value_map[node->inputs()[1]->unique()]);
              auto input = list_val.front();
              list_val.pop_front();
              auto index = list_val.front();
Review comment (Collaborator): index here is a list of tensors, so we need to access the first and only item instead of the whole thing.
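A minimal sketch of what the reviewer suggests, assuming the indices list is produced by a prim::ListConstruct node in the TorchScript graph (an assumption; the diff does not show it):

// Hypothetical unwrap of the single index tensor; assumes
// node->inputs()[1] comes from a prim::ListConstruct.
auto* list_node = node->inputs()[1]->node();
TORCH_INTERNAL_ASSERT(
    list_node->kind() == prim::ListConstruct,
    "expected the indices to come from a ListConstruct node");
auto* index_value = list_node->inputs()[0]; // first and only entry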

              list_val.pop_front();
              Val* out = index_select(
Review comment (Collaborator): Awesome. Looks like we are already parsing it as index_select.

Reply (Author): But I didn't add a restriction for the cases that cannot be converted...

                  input->as<TensorView>(), 0, index->as<TensorView>());
              value_map.emplace(
                  node->output()->unique(), ValueHolder(out, format));
            },
            [](const Node* node) -> bool {
              if (auto tensor_type =
                      node->inputs()[0]->type()->cast<TensorType>()) {
                // index_select doesn't support 0-dim tensors
                if (tensor_type->dim() == 0u) {
                  return false;
                }
              }
Review comment (Collaborator): We should also return false when the second input's length is > 1.
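A hedged sketch of that extra guard, under the same prim::ListConstruct assumption as above:

// Reject multi-tensor advanced indexing, which index_select cannot express.
auto* list_node = node->inputs()[1]->node();
if (list_node->kind() == prim::ListConstruct &&
    list_node->inputs().size() > 1) {
  return false;
}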

              for (const auto& val : node->inputs()) {
                auto tensor_type = val->type()->cast<TensorType>();
                if (tensor_type && is_zero_sized_tensor(tensor_type)) {
                  return false;
                }
              }
              return true;
            },
            nullptr);
Review comment (Collaborator): Should return a gatherOp type here.
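If the trailing nullptr in this rule is the usual OperatorType callback slot, the suggestion might look like the sketch below; the exact enum spelling OperatorType::GatherOp is inferred from the comment and may differ in the codebase:

// Hypothetical: replace the trailing nullptr so the scheduler treats
// this rule as a gather-style op. The enum value name is assumed
// from the review comment, not confirmed against the source.
[](const Node* node) -> OperatorType { return OperatorType::GatherOp; }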

      }
    }
    {
      if (isOptionEnabled(EnableOption::GraphOp)) {
        auto ptr_op = getOperatorForLiteral(
20 changes: 20 additions & 0 deletions third_party/nvfuser/csrc/python_frontend/python_bindings.cpp
@@ -1346,6 +1346,26 @@ void initNvFuserPythonBindings(PyObject* module) {
py::arg("index"),
py::arg("dim"),
py::return_value_policy::reference);
nvf_ops.def(
"index",
[](nvfuser::FusionDefinition::Operators& self,
nvfuser::Tensor arg,
nvfuser::Tensor index) -> nvfuser::Tensor {
Review comment (Collaborator): This signature doesn't look right. It should be a list of Tensors instead.

Reply (Author): This signature isn't right. I don't know how to declare the list-of-Tensor type...
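On the author's question: pybind11 converts a Python list to std::vector once <pybind11/stl.h> is included, so a hedged sketch of the list-based signature could look like this (everything except pybind11's own API is taken from, or assumed to match, the surrounding diff):

#include <pybind11/stl.h>  // enables Python list <-> std::vector conversion

nvf_ops.def(
    "index",
    [](nvfuser::FusionDefinition::Operators& self,
       nvfuser::Tensor arg,
       std::vector<nvfuser::Tensor> indices) -> nvfuser::Tensor {
      FUSER_PERF_SCOPE("Operators.index");
      // Mirror the parser-side restriction: exactly one index tensor.
      TORCH_CHECK(indices.size() == 1, "only one index tensor is supported");
      nvfuser::FusionDefinition* fd = self.fusion_definition;
      nvfuser::Tensor output = fd->defineTensor(arg.dims);
      fd->defineRecord(new nvfuser::IndexSelectOpRecord(
          {fd->recordingState(arg()), fd->recordingState(indices[0]())},
          {fd->recordingState(output())},
          0));
      return output;
    },
    py::arg("arg"),
    py::arg("indices"),
    py::return_value_policy::reference);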

        FUSER_PERF_SCOPE("Operators.index");
        nvfuser::FusionDefinition* fd = self.fusion_definition;
        nvfuser::Tensor output = fd->defineTensor(arg.dims);
        fd->defineRecord(new nvfuser::IndexSelectOpRecord(
            {
                fd->recordingState(arg()),
                fd->recordingState(index()),
            },
            {fd->recordingState(output())},
            0));
        return output;
      },
      py::arg("arg"),
      py::arg("index"),
      py::return_value_policy::reference);
  nvf_ops.def(
      "gather",
      [](nvfuser::FusionDefinition::Operators& self,
1 change: 1 addition & 0 deletions third_party/nvfuser/csrc/type_inference.cpp
@@ -459,6 +459,7 @@ class NaiveTypePropagator {
      case aten::transpose_copy:
      case aten::unsqueeze_copy:
      case aten::index_select:
      case aten::index:
      case aten::gather:
      case aten::view_copy: {
        auto out_type = node->input(0)->type()->cast<TensorType>();
20 changes: 20 additions & 0 deletions third_party/nvfuser/python_tests/test_python_frontend.py
@@ -628,6 +628,26 @@ def fusion_func(fd: FusionDefinition) :
        eager_out = torch.gather(inputs[0] + inputs[1], 0, inputs[2])
        self.assertEqual(eager_out, nvf_out[0])

    def test_index(self):
        inputs = [
            torch.randn(8, 16, device='cuda'),
            torch.randn(8, 16, device='cuda'),
            torch.randint(0, 8, (4, ), device="cuda").to(dtype=torch.long)
        ]

        def fusion_func(fd: FusionDefinition) :
            t0 = fd.from_pytorch(inputs[0])
            t1 = fd.from_pytorch(inputs[1])
            t2 = fd.from_pytorch(inputs[2])
            t3 = fd.ops.add(t0, t1)
            t4 = fd.ops.index(t3, t2)
Review comment (Collaborator): This really isn't the expected index signature?

            fd.add_output(t4)

        nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
        tmp = inputs[0] + inputs[1]
        eager_out = tmp[inputs[2]]
        self.assertEqual(eager_out, nvf_out[0])
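Note on the eager reference: for a single integer index tensor applied to the first dimension, tmp[inputs[2]] is equivalent to torch.index_select(tmp, 0, inputs[2]); that equivalence is what lets the parser lower this aten::index pattern onto index_select.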

    def test_index_select(self):
        inputs = [
            torch.randn(8, 16, device='cuda'),
16 changes: 16 additions & 0 deletions third_party/nvfuser/python_tests/test_torchscript.py
@@ -4201,6 +4201,22 @@ def t(x: torch.Tensor, y: torch.Tensor, ind: torch.Tensor):
        t_jit = torch.jit.script(t)
        self._run_training_helper(t_jit, t, grad, x, y, ind)

    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_index_function(self):
        def t(x: torch.Tensor, y: torch.Tensor, ind: torch.Tensor):
            o = torch.mul(x, y)
            o = o[ind]
            return o

        x = torch.randn([68, 128], dtype=torch.float, device="cuda")
        y = torch.randn_like(x)
        ind = torch.randint(0, 68, (130,), device="cuda").to(dtype=torch.int)

        t_jit = torch.jit.script(t)
        self._run_helper(t_jit, t, x, y, ind)

    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")