
Add support for aten::index operator #2432

Open · wants to merge 2 commits into base: devel
Conversation

@ftxj commented Feb 8, 2023

There are many Tensor indexing APIs in PyTorch. We already support the index_select operator, but we don't support the index operator, which is supported by Inductor + Triton and widely used in GNNs. So we need to add this operator.

Although we can share most of the logic with index_select, there are some difficulties. The signature is aten::index.Tensor(Tensor self, Tensor?[] indices), which uses a List data type, but nvfuser doesn't support lists. So the TorchScript test will raise an error until we support List[T].

  • Support the List[T] type in nvfuser.
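For reference, the case this PR targets can be sketched in plain PyTorch eager mode (this is an illustration of the semantics, not nvfuser code): when aten::index is given a single 1-D LongTensor, it matches index_select along dim 0, which is the gather pattern common in GNNs.

```python
# Sketch: aten::index with one 1-D long index == index_select on dim 0.
import torch

x = torch.arange(12.0).reshape(4, 3)   # e.g. node features: 4 nodes, 3 channels
idx = torch.tensor([0, 2, 2])          # rows to gather (repeats allowed)

a = x[idx]                             # lowers to aten::index(self, [idx])
b = torch.index_select(x, 0, idx)      # the form nvfuser already supports
assert torch.equal(a, b)
```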

@csarofeen (Owner)

What's the status of this PR? CC @jjsjann123

@jjsjann123 (Collaborator)

Do we actually support index with the backend?!?! I remember we went over the idea at a GNN meeting, and it's the actual codegen part where we are lacking. (A frontend list of tensors as inputs is also needed, of course.)

Asking @naoyam about codegen support.

@naoyam (Collaborator) commented Feb 27, 2023

Generally no, we don't support aten::index. However, I think what @ftxj is trying to do is add limited, partial support for aten::index only when it can be converted to index_select, which is supported with limitations. That seems to be common in GNNs, so I guess it probably makes sense.

@jjsjann123 (Collaborator)

That sounds good. I think the included example also indicates single-tensor indexing.

Looks like we are using the Python API here. Are we doing this via dynamo/primtorch or TorchScript? We can plumb it through the Python API, but frankly speaking I'm less comfortable having a "limited support" thing there, which makes it hard to get a robust runtime.

So for dynamo/primtorch, we can probably hack it via a decomposition that converts a compatible index into index_select. We need to be careful not to fall into infinite recursion with the dispatch.
For TorchScript, I think we can hack it in parser.cpp: reject the index op when len(indices) > 1, and parse index as index_select.
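The decomposition idea described above could be sketched roughly like this (a hypothetical helper, not the actual primtorch decomposition): rewrite aten::index into index_select only when the indices list holds exactly one long-dtype index tensor, and return None otherwise so the caller falls back to the original op and the dispatch cannot recurse.

```python
# Hypothetical sketch of the "compatible index -> index_select" rewrite.
import torch
from typing import Optional, Sequence

def maybe_decompose_index(
    self: torch.Tensor, indices: Sequence[Optional[torch.Tensor]]
) -> Optional[torch.Tensor]:
    # Only the single-tensor case maps onto index_select; anything else
    # (multiple index tensors, None slots, bool masks) is left untouched.
    tensors = [t for t in indices if t is not None]
    if (
        len(indices) == 1
        and len(tensors) == 1
        and tensors[0].dtype == torch.long
        and tensors[0].dim() == 1
    ):
        return torch.index_select(self, 0, tensors[0])
    return None  # caller keeps the original aten::index
```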

@naoyam (Collaborator) commented Feb 27, 2023

We would likely need to do this anyway, as supporting aten::index completely would be quite challenging. I'm also working on limited support for slice, which is also part of aten::index.

@jjsjann123 (Collaborator) left a review comment:

Even with the suggested changes, we still need to somehow wire Tensor[] into our integration.

We can do the proper thing and plumb it through, which might take some time.
Alternatively, we can fake it by converting the list of tensors to a single tensor and working with that... How urgently do we need this working? @ftxj

list_val.pop_front();
auto index = list_val.front();
list_val.pop_front();
Val* out = index_select(
Review comment (Collaborator):

Awesome. Looks like we are already parsing it as index_select.

@ftxj (Author) replied:

But I didn't add restrictions for the cases that cannot be converted...

value_map[node->inputs()[1]->unique()]);
auto input = list_val.front();
list_val.pop_front();
auto index = list_val.front();
Review comment (Collaborator):

index here is a list of tensors, so we need to access the first (and only) item instead of the whole thing.

}
return true;
},
nullptr);
Review comment (Collaborator):

We should return a gatherOp type here.

if (tensor_type->dim() == 0u) {
return false;
}
}
Review comment (Collaborator):

We should also return false when the length of the second input is > 1.

"index",
[](nvfuser::FusionDefinition::Operators& self,
nvfuser::Tensor arg,
nvfuser::Tensor index) -> nvfuser::Tensor {
Review comment (Collaborator):

This signature doesn't look right. It should take a list of Tensors instead.

@ftxj (Author) replied:

This signature isn't right. I don't know how to declare the list-of-Tensor type...

t1 = fd.from_pytorch(inputs[1])
t2 = fd.from_pytorch(inputs[2])
t3 = fd.ops.add(t0, t1)
t4 = fd.ops.index(t3, t2)
Review comment (Collaborator):

This really isn't the expected index signature, is it?

@ftxj (Author) commented Feb 28, 2023

Even with the suggested changes, we still need to somehow wire Tensor[] in our integration.

We can do the proper thing and plumb it through, which might take some time. Alternatively, we can fake it by converting list of tensors to a single tensor and work with that... How urgent do we need this working? @ftxj

Supporting aten::index is not urgent at present; I can replace aten::index with aten::index_select at the Python level.
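The Python-level workaround mentioned here could look like the following sketch (the helper name is hypothetical): rewrite features[row_idx], which lowers to aten::index, as an explicit index_select call so the already-supported path is taken.

```python
# Sketch of rewriting bracket indexing into index_select at the Python level.
import torch

def gather_rows(features: torch.Tensor, row_idx: torch.Tensor) -> torch.Tensor:
    # Equivalent to features[row_idx] for a single 1-D long index,
    # but dispatches to aten::index_select instead of aten::index.
    return torch.index_select(features, 0, row_idx)
```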
