Add support for index_select OP (frontend implementation) #2184
Conversation
…ent out call index_select func as it is in backend PR
jjsjann123 left a comment
Looks about right.
There seem to be some unrelated changes under torch/csrc/jit/passes/; let's revert those.
```cpp
"aten::logit(Tensor self, float? eps=None) -> Tensor",
"aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor",
"aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor",
"aten::index_select(Tensor self, int dim, Tensor index) -> Tensor",
```
Nope, you don't want it here. This is the legacy fuser, not nvfuser.
```cpp
case aten::log10:
case aten::frac:
case aten::lerp:
case aten::index_select:
```
Why do we need index_select here to remove the guard?
```cpp
    dim_value.has_value(), "dim in index_select is not valid");
auto index = list_val.front();
list_val.pop_front();
Val* out = nullptr;
```
Nitpick: can we just put an assert(false) here with a "not implemented" message? I suspect this will give us a confusing segfault if index_select is somehow accidentally enabled.
```cpp
auto arg3 =
    fd.getFusionState(args_.at(1).index)->template as<Nvf::TensorView>();

Nvf::Val* output = nullptr;
```
ditto.
Force-pushed from d033d72 to c0228b7
test/test_jit_cuda_fuser.py (outdated)
```python
indies_tv = torch.randint(0, lookup_size, (num_elements,), dtype=torch.float, device="cuda").to(dtype=torch.int)
sbf = torch.rand(num_elements, feat_dim, dtype=torch.float, device="cuda")

def failure(x_kj, idx_kj, sbf):
```
We should be able to support runtime scalar input, e.g. https://github.com/pytorch/pytorch/blob/391b593ca262432ccba1939f7448275cfd4f62e6/test/test_jit_cuda_fuser.py#L1106-L1110
runtime_dim shouldn't be a failure if the scalar is passed properly:
```python
def func(x_kj, idx_kj, sbf, dim: int):
    sbf_res = torch.index_select(x_kj, dim, idx_kj) * sbf
    sbf_res = sbf_res + 17
    return sbf_res
```
And you would also need to run it like `self._run_helper(t_jit, t, lookup_tv, indies_tv, sbf, 0)`.
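The runtime-scalar pattern the reviewer is asking for can be sketched without the fuser or CUDA: `dim` is an argument of the scripted function, supplied at call time, rather than a constant baked into the graph. This is an illustrative sketch only (CPU tensors, made-up shapes), not the PR's actual test:

```python
import torch

def func(x_kj, idx_kj, sbf, dim: int):
    # dim arrives as a runtime scalar, so the same scripted
    # function can select along any dimension at call time
    sbf_res = torch.index_select(x_kj, dim, idx_kj) * sbf
    sbf_res = sbf_res + 17
    return sbf_res

t_jit = torch.jit.script(func)

x = torch.rand(4, 3)
idx = torch.tensor([1, 0, 2, 3])
sbf = torch.rand(4, 3)

# pass dim=0 as the trailing runtime scalar, mirroring
# self._run_helper(t_jit, t, lookup_tv, indies_tv, sbf, 0)
out = t_jit(x, idx, sbf, 0)
eager = func(x, idx, sbf, 0)
```

The scripted and eager results should match elementwise; in the real test, `_run_helper` additionally checks that the fusion group was created.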
jjsjann123 left a comment
LGTM. Minor comments on the proper test for runtime scalar; the test still isn't quite right, but the code logic looks about right. I'm stamping it.
Force-pushed from eaa1fbe to eadcb1c
This PR is the frontend implementation of the index_select(lookup_tv, dim, index_tv) fusion.
It depends on the index_select frontend PR #2183.
The op definition is here: https://pytorch.org/docs/stable/generated/torch.index_select.html
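Per the linked docs, index_select returns a tensor whose entries along `dim` are the rows/columns of the input picked out by `index`, i.e. `out.select(dim, i) == input.select(dim, index[i])`. A minimal pure-Python sketch of that semantics on nested lists (illustrative only, limited to dim 0 and 1, not the PR's implementation):

```python
def index_select(src, index, dim):
    """Pure-Python reference for torch.index_select semantics (dim 0 or 1)."""
    if dim == 0:
        # pick whole rows by index
        return [src[i] for i in index]
    # dim == 1: pick columns within each row
    return [[row[i] for i in index] for row in src]

x = [[1, 2, 3],
     [4, 5, 6]]
rows = index_select(x, [1, 0], 0)  # reorder rows
cols = index_select(x, [2, 0], 1)  # pick columns 2 and 0
```

Note that the output shape matches the input everywhere except along `dim`, where it takes the length of `index`; indices may repeat, so the output can be larger than the input along that dimension.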