Add support for aten::index operator #2432
base: devel
Conversation
What's the status of this PR? CC @jjsjann123
Do we actually support `index` in the backend?! I remember we went over the idea in the GNN meeting, and it's the actual codegen part that we are lacking (a frontend list-of-tensors input is also needed, of course). Asking @naoyam for codegen support.
Generally no, we don't support it.
That sounds good. I think the included example is also indicating single-tensor indexing. Looks like we are using the Python API here. Are we doing this via dynamo/primtorch or TorchScript? We can plumb it through the Python API, but frankly speaking I'm less comfortable having a "limited support" thing there, which makes it hard to get a robust runtime. So for …
We would likely need to do this anyway, as supporting …
Even with the suggested changes, we still need to somehow wire Tensor[] in our integration. We can do the proper thing and plumb it through, which might take some time. Alternatively, we can fake it by converting the list of tensors to a single tensor and working with that... How urgently do we need this working? @ftxj
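For illustration, here is a minimal sketch of the "fake it" idea in plain PyTorch (not nvfuser API): when the `aten::index` indices list contains exactly one non-None 1-D tensor, it can be collapsed to that single tensor, since the result is just an `index_select` along dim 0. The helper name `collapse_index_list` is hypothetical.

```python
import torch

def collapse_index_list(indices):
    # Hypothetical helper: given the Tensor?[] indices of aten::index,
    # return the single 1-D index tensor when that is all the list holds.
    assert len(indices) == 1 and indices[0] is not None
    assert indices[0].dim() == 1
    return indices[0]

x = torch.randn(8, 4)
idx_list = [torch.tensor([0, 3, 5])]

# In this restricted case, aten::index reduces to index_select on dim 0.
single = collapse_index_list(idx_list)
assert torch.equal(torch.ops.aten.index(x, idx_list),
                   torch.index_select(x, 0, single))
```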
list_val.pop_front();
auto index = list_val.front();
list_val.pop_front();
Val* out = index_select( |
Awesome. Looks like we are already parsing it as `index_select`.
But I didn't add restrictions for the cases that cannot be converted...
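For context, here are a couple of illustrative plain-PyTorch cases that `aten::index` accepts but that do not reduce to an `index_select`, and so would need to be rejected by the parser:

```python
import torch

x = torch.randn(4, 5)

# A 2-D index tensor: the output shape follows the index tensor's shape.
idx2d = torch.tensor([[0, 1], [2, 3]])
print(torch.ops.aten.index(x, [idx2d]).shape)  # torch.Size([2, 2, 5])

# A boolean mask: selects a data-dependent number of rows.
mask = x[:, 0] > 0
print(torch.ops.aten.index(x, [mask]).shape)   # first dim depends on the data
```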
value_map[node->inputs()[1]->unique()]);
auto input = list_val.front();
list_val.pop_front();
auto index = list_val.front(); |
`index` here is a list of tensors, so we need to access the first (and only) item instead of the whole thing.
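Schematically (a hedged sketch in plain PyTorch, not the parser code): even for a single index tensor, TorchScript packs the index into a `Tensor?[]` via `prim::ListConstruct` before calling `aten::index`, which is why the parser has to unwrap the first list element. The graph text in the comment is illustrative; exact node names vary by PyTorch version.

```python
import torch

@torch.jit.script
def f(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    return x[idx]  # lowered to aten::index(self, indices)

# The scripted graph looks roughly like:
#   %indices : Tensor?[] = prim::ListConstruct(%idx)
#   %out : Tensor = aten::index(%x, %indices)
print(f.graph)
```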
}
return true;
},
nullptr); |
should return a gatherOp type here.
if (tensor_type->dim() == 0u) {
  return false;
}
} |
We should also return false when the length of the second input is > 1.
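For reference, this is what the length > 1 case looks like in plain PyTorch: with two index tensors, `aten::index` gathers individual elements and no longer matches a single `index_select`, so the parser should bail out.

```python
import torch

x = torch.randn(4, 5)
rows = torch.tensor([0, 2])
cols = torch.tensor([1, 3])

# Two index tensors pick individual elements (x[0, 1] and x[2, 3]);
# this cannot be expressed as one index_select.
out = torch.ops.aten.index(x, [rows, cols])
print(out)
```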
"index", | ||
[](nvfuser::FusionDefinition::Operators& self, | ||
nvfuser::Tensor arg, | ||
nvfuser::Tensor index) -> nvfuser::Tensor { |
This signature doesn't look right. It should be a list of Tensors instead.
This signature isn't right. I don't know how to declare the list of Tensor type...
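A small plain-PyTorch illustration of why a list is expected here: the `aten::index` schema itself takes `Tensor?[] indices`, and passing a bare tensor is rejected. (This uses the public `torch.ops` API only; the exact error type is an assumption, so it is caught broadly.)

```python
import torch

x = torch.randn(3, 4)
idx = torch.tensor([0, 2])

torch.ops.aten.index(x, [idx])    # OK: indices is a list of optional tensors
try:
    torch.ops.aten.index(x, idx)  # a bare tensor does not match Tensor?[]
except (RuntimeError, TypeError) as e:
    print("rejected:", type(e).__name__)
```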
t1 = fd.from_pytorch(inputs[1])
t2 = fd.from_pytorch(inputs[2])
t3 = fd.ops.add(t0, t1)
t4 = fd.ops.index(t3, t2) |
This really isn't the expected `index` signature?
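For comparison, a hedged eager-mode reference for what the test computes. The concrete shapes below are made up, and it is assumed that `t0` was defined from `inputs[0]` in lines not shown in this excerpt.

```python
import torch

# Assumed stand-ins for the test's inputs: two equal-shaped tensors and a 1-D index.
inputs = [torch.randn(8, 4), torch.randn(8, 4), torch.tensor([1, 3, 6])]

summed = inputs[0] + inputs[1]
# What t4 = index(t3, t2) should compute with a single 1-D index tensor:
ref = torch.ops.aten.index(summed, [inputs[2]])
assert torch.equal(ref, torch.index_select(summed, 0, inputs[2]))
```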
Support `index`
There are many tensor indexing APIs in PyTorch. We already support the `index_select` operator, but we don't support the `index` operator, which is supported by Inductor + Triton and widely used in GNN workloads. So we need to add this operator.

Although we can share most of the logic with `index_select`, there are some difficulties. The `index` signature is `aten::index.Tensor(Tensor self, Tensor?[] indices)`, which uses the `List` data type, and nvfuser doesn't support it. So the TorchScript test will raise an error until we support `List[T]`.
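To make the `Tensor?[]` point concrete, here is a short plain-PyTorch sketch of the `aten::index` schema behavior, covering the single-tensor case this PR targets and a `None` slot that skips a dimension:

```python
import torch

# Schema: aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
x = torch.randn(5, 6)
idx = torch.tensor([0, 4])

# One 1-D index tensor: equivalent to index_select along dim 0.
assert torch.equal(torch.ops.aten.index(x, [idx]),
                   torch.index_select(x, 0, idx))

# A None entry skips a dimension: this is x[:, idx] in eager notation.
assert torch.equal(torch.ops.aten.index(x, [None, idx]), x[:, idx])
```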