Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JIT]Support adv indexing using list. #37848

Closed
wants to merge 1 commit into from

Conversation

ailzhang
Copy link
Contributor

@ailzhang ailzhang commented May 5, 2020

We used to only support indexing through

  • numbers like x[0, 1]
  • tuple like x[(0, 1)]
  • tensor like x[torch.tensor([0, 1])]

This PR adds support for indexing through list which is equivalent to tensor.

  • x[[0, 1, 5]]
  • x[[0, 1], [0, 1]]
  • x[[[0, 1], [0, 1]], [[0, 1], [0, 1]]]

Note for x[[0, 1, 5]] we had a bug in AST conversion code so we used to treat it like x[0, 1, 5] which means it might accidentally run and produce wrong result(fixes #37286 fixes #18616), now that it's fixed we probably want to mark it as BC breaking.

@ailzhang ailzhang requested a review from apaszke as a code owner May 5, 2020 18:36
@ailzhang ailzhang changed the title [JIT]Fix adv indexing using list. [JIT]Support adv indexing using list. May 5, 2020
@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label May 5, 2020
@ailzhang ailzhang added the module: bc-breaking Related to a BC-breaking change label May 5, 2020
@ailzhang ailzhang requested review from eellison and suo May 5, 2020 18:52
Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ailzhang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@eellison eellison removed their request for review May 5, 2020 20:05
Copy link
Member

@suo suo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!! Left some comments in line about some utilities you might find helpful

test/test_jit.py Outdated Show resolved Hide resolved
@@ -571,7 +571,9 @@ def build_ExtSlice(ctx, base, extslice):
base = build_expr(ctx, expr.value)
sub_type = type(expr.slice)
if sub_type is ast.Index:
if isinstance(expr.slice.value, ast.Tuple) or isinstance(expr.slice.value, ast.List):
if isinstance(expr.slice.value, ast.Tuple):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you are making a change to the python frontend, you probably need to make one to the string frontend as well. self.checkScript will make sure the behavior is the same between Python, TS+python, TS+no python.

Copy link
Contributor Author

@ailzhang ailzhang May 5, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw I switched to use self.checkScript in the test and it passed. Is string frontend the one that takes in function strings and parse it? If so this change might be shared by both since it's done at python ast level I think?

torch/csrc/jit/frontend/ir_emitter.cpp Outdated Show resolved Hide resolved
@ailzhang ailzhang force-pushed the fix_jit_adv_indexing_list branch from 3aa7050 to adcc03d Compare May 5, 2020 22:46
@ailzhang ailzhang requested a review from suo May 5, 2020 22:50
@ailzhang ailzhang force-pushed the fix_jit_adv_indexing_list branch from adcc03d to 9307068 Compare May 5, 2020 22:51
@ailzhang ailzhang force-pushed the fix_jit_adv_indexing_list branch from 9307068 to a3ff0fd Compare May 5, 2020 22:53
@dr-ci
Copy link

dr-ci bot commented May 5, 2020

💊 Build failures summary and remediations

As of commit a3ff0fd (more details on the Dr. CI page):


  • 2/2 failures possibly* introduced in this PR
    • 1/2 non-CircleCI failure(s)

🕵️ 1 new failure recognized by patterns

The following build failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_macos_10_13_py3_test (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

May 05 16:57:40 [E request_callback_impl.cpp:99] Received error while processing request type 15: size mismatch, m1: [3 x 3], m2: [6 x 6] at ../aten/src/TH/generic/THTensorMath.cpp:41
May 05 16:57:37   test_debug_info (__main__.DistAutogradTestWithSpawn) ... skip (0.005s) 
May 05 16:57:38   test_dist_autograd_profiling (__main__.DistAutogradTestWithSpawn) ... ok (1.110s) 
May 05 16:57:39   test_embedding_bag_with_no_grad_tensors (__main__.DistAutogradTestWithSpawn) ... [W pybind_utils.h:712] Warning: Using sparse tensors in TorchScript is experimental. Many optimization pathways have not been thoroughly tested with sparse tensors. Please include the fact that the network is running sparse tensors in any bug reports submitted. (function operator()) 
May 05 16:57:39 [W pybind_utils.h:712] Warning: Using sparse tensors in TorchScript is experimental. Many optimization pathways have not been thoroughly tested with sparse tensors. Please include the fact that the network is running sparse tensors in any bug reports submitted. (function operator()) 
May 05 16:57:39 [W pybind_utils.h:712] Warning: Using sparse tensors in TorchScript is experimental. Many optimization pathways have not been thoroughly tested with sparse tensors. Please include the fact that the network is running sparse tensors in any bug reports submitted. (function operator()) 
May 05 16:57:39 [W pybind_utils.h:712] Warning: Using sparse tensors in TorchScript is experimental. Many optimization pathways have not been thoroughly tested with sparse tensors. Please include the fact that the network is running sparse tensors in any bug reports submitted. (function operator()) 
May 05 16:57:39 ok (1.178s) 
May 05 16:57:40   test_error_in_context (__main__.DistAutogradTestWithSpawn) ... [E request_callback_impl.cpp:99] Received error while processing request type 15: size mismatch, m1: [3 x 3], m2: [6 x 6] at ../aten/src/TH/generic/THTensorMath.cpp:41 
May 05 16:57:40 [E request_callback_impl.cpp:99] Received error while processing request type 15: size mismatch, m1: [3 x 3], m2: [6 x 6] at ../aten/src/TH/generic/THTensorMath.cpp:41 
May 05 16:57:40 [E request_callback_impl.cpp:99] Received error while processing request type 15: size mismatch, m1: [3 x 3], m2: [6 x 6] at ../aten/src/TH/generic/THTensorMath.cpp:41 
May 05 16:57:40 [E request_callback_impl.cpp:99] Received error while processing request type 15: size mismatch, m1: [3 x 3], m2: [6 x 6] at ../aten/src/TH/generic/THTensorMath.cpp:41 
May 05 16:57:40 ok (1.157s) 
May 05 16:57:41   test_grad_copy_sparse_indices_extra_ref (__main__.DistAutogradTestWithSpawn) ... /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/nn/functional.py:1890: UserWarning: Argument order of nn.functional.embedding_bag was changed. Usage `embedding_bag(weight, input, ...)` is deprecated, and should now be `embedding_bag(input, weight, ...)`. 
May 05 16:57:41   warnings.warn("Argument order of nn.functional.embedding_bag was changed. " 
May 05 16:57:41 /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/nn/functional.py:1890: UserWarning: Argument order of nn.functional.embedding_bag was changed. Usage `embedding_bag(weight, input, ...)` is deprecated, and should now be `embedding_bag(input, weight, ...)`. 
May 05 16:57:41   warnings.warn("Argument order of nn.functional.embedding_bag was changed. " 
May 05 16:57:41 /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/nn/functional.py:1890: UserWarning: Argument order of nn.functional.embedding_bag was changed. Usage `embedding_bag(weight, input, ...)` is deprecated, and should now be `embedding_bag(input, weight, ...)`. 
May 05 16:57:41   warnings.warn("Argument order of nn.functional.embedding_bag was changed. " 
May 05 16:57:41 /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/nn/functional.py:1890: UserWarning: Argument order of nn.functional.embedding_bag was changed. Usage `embedding_bag(weight, input, ...)` is deprecated, and should now be `embedding_bag(input, weight, ...)`. 
May 05 16:57:41   warnings.warn("Argument order of nn.functional.embedding_bag was changed. " 
May 05 16:57:41 [W pybind_utils.h:712] Warning: Using sparse tensors in TorchScript is experimental. Many optimization pathways have not been thoroughly tested with sparse tensors. Please include the fact that the network is running sparse tensors in any bug reports submitted. (function operator()) 

2 failures confirmed as flaky and can be ignored:

  • pytorch_xla_linux_bionic_py3_6_clang9_build
  • pytorch_linux_xenial_py3_6_gcc5_4_build

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker.

See how this bot performed.

This comment has been revised 5 times.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ailzhang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Member

@suo suo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome :) thanks!

if (subscript_expr.kind() == TK_NONE) {
type_hint = NoneType::get();
Value* index;
if (subscript_expr.kind() == TK_LIST_LITERAL) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One question, maybe as a followup to this PR: seems like we should be able to handle any type of list, not just a list literal? e.g., the follow should work:

def f(x):
    ls = [0]
    ls.append(1)
    ls.append(2)
    return x[ls]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea good point, it should be possible (maybe by unwrapping a list out of the TK_VAR). I'll take a look and send a followup PR if it works. :D

@facebook-github-bot
Copy link
Contributor

@ailzhang merged this pull request in dd61821.

facebook-github-bot pushed a commit that referenced this pull request May 8, 2020
Summary:
Followup of #37848 I realized that it's better to condition on `Value` type instead of token type. So now it also support indexing through list variables (used to be list literal only).
Also apparently our eager frontend accept indexing with float list as well, so matched this edge case behavior as well.
Pull Request resolved: #37966

Reviewed By: suo

Differential Revision: D21439642

Pulled By: ailzhang

fbshipit-source-id: cedb8431ef38747d4aa9909a6bbf8e954dbe0e25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Merged module: bc-breaking Related to a BC-breaking change oncall: jit Add this issue/PR to JIT oncall triage queue
Projects
None yet
4 participants