Skip to content

Conversation

@azazhu
Copy link
Collaborator

@azazhu azazhu commented Nov 15, 2022

This PR is backend implementation of index_select(lookup_tv, dim, index_tv) fusion:

  • support arbitrarily rank
  • support arbitrarily dim value
  • index_tv can become consumer;
  • lookup_tv cannot become consumer.

OP def is here: https://pytorch.org/docs/stable/generated/torch.index_select.html

TODO: move index calculation from codegen.cpp to lower_index.cpp

azazhu and others added 3 commits November 15, 2022 09:08
…port arbitrarily rank; 2) support arbitrarily dim value; 3) index_tv can become consumer; 4) lookup_tv cannot become consumer. TODO: move index calculation from codegen.cpp to lower_index.cpp
2) mv test_index_select.cpp to torch/csrc/jit/codegen/cuda/test/
3) refine code
@naoyam
Copy link
Collaborator

naoyam commented Nov 15, 2022

@zasdfgbnm added support of select in #2179. I think it'd be relatively straightforward to build index_select on top of his work. Please take a look at the PR, especially around here: https://github.com/csarofeen/pytorch/pull/2179/files#diff-3c08060d5a1993c9599d5cbb2d7a805258ae1b4e166682b497c4f5fd376ba5d3R199-R205

@azazhu azazhu force-pushed the fw/idx_sel_backend branch 2 times, most recently from e67cb84 to 4b2fea4 Compare November 21, 2022 08:13
@azazhu azazhu force-pushed the fw/idx_sel_backend branch from 4b2fea4 to ddd9742 Compare November 21, 2022 09:07
@azazhu
Copy link
Collaborator Author

azazhu commented Nov 21, 2022

@zasdfgbnm added support of select in #2179. I think it'd be relatively straightforward to build index_select on top of his work. Please take a look at the PR, especially around here: https://github.com/csarofeen/pytorch/pull/2179/files#diff-3c08060d5a1993c9599d5cbb2d7a805258ae1b4e166682b497c4f5fd376ba5d3R199-R205

Referred to Xiang's impl and moved lower_index from codegen.cpp to lower_index.cpp

@naoyam
Copy link
Collaborator

naoyam commented Nov 21, 2022

@azazhu Can you please also add a validation check to validateIr so that all of the lookup TVs are ensured to be fusion inputs? Something like:

for (auto expr: fusion->exprs()) {
   if (expr->isA<SelectOp>() || expr->isA<IndexSelectOp>()) {
     TORCH_CHECK(expr->input(0)->isFusionInput(), "Lookup input must be a fusion input: ", expr->toString());
   }
}

@azazhu
Copy link
Collaborator Author

azazhu commented Nov 22, 2022

@azazhu Can you please also add a validation check to validateIr so that all of the lookup TVs are ensured to be fusion inputs? Something like:

for (auto expr: fusion->exprs()) {
   if (expr->isA<SelectOp>() || expr->isA<IndexSelectOp>()) {
     TORCH_CHECK(expr->input(0)->isFusionInput(), "Lookup input must be a fusion input: ", expr->toString());
   }
}

Done. see https://github.com/csarofeen/pytorch/pull/2183/files#diff-e9719a5d2b00c30404cd4d11ebaf37dea60b9447ee893aaddb9a9ef6fd9aaebbR1313-R1323

Copy link
Collaborator

@naoyam naoyam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Just left one more minor comment.

Thanks for adding this new operator!

@naoyam naoyam merged commit 3edc643 into csarofeen:devel Nov 24, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants