make python decomp for native_batch_norm CompositeImplicitAutograd, remove native_batch_norm from core aten opset #107791
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/107791
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 4f44fd7 with merge base 3022a39. This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D48607575
Force-pushed from a688576 to 4f44fd7
@pytorchbot merge (Initiating merge automatically since Phabricator Diff has merged)
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Summary:
(From Brian Hirsh)
Description copied from what I put in a comment in this PR: #106329
So, the slightly contentious idea behind this PR is that lower in the stack, I updated torch._decomps.get_decomps() to check not only the decomp table for whether a given op has a decomposition available, but also the dispatcher for any decomps registered to the CompositeImplicitAutograd key (link: https://github.com/pytorch/pytorch/pull/105865/files#diff-7008e894af47c01ee6b8eb94996363bd6c5a43a061a2c13a472a2f8a9242ad43R190)
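A minimal sketch of that two-step lookup (this assumes the torch._decomp.decomposition_table dict and the torch._C._dispatch_has_kernel_for_dispatch_key binding, both of which exist in PyTorch; the has_decomp helper name is hypothetical, not the actual code from the PR):

```python
import torch
from torch._decomp import decomposition_table

def has_decomp(op: torch._ops.OpOverload) -> bool:
    # Hypothetical helper mirroring the lookup order described above.
    # 1) Explicit Python decomposition registered in the decomp table?
    if op in decomposition_table:
        return True
    # 2) Otherwise ask the dispatcher: by convention, anything registered
    #    to the CompositeImplicitAutograd key is a decomposition into
    #    other ops.
    return torch._C._dispatch_has_kernel_for_dispatch_key(
        op.name(), "CompositeImplicitAutograd"
    )

# Should report True for ops whose kernel is CompositeImplicitAutograd,
# e.g. the .vec upsample overloads.
print(has_decomp(torch.ops.aten.upsample_bilinear2d.vec))
```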
There's one problem, though: we don't actually make any hard guarantees about whether a given key in the dispatcher does or does not point to a decomposition. We do rely pretty heavily, however, on the fact that everything registered to the CompositeImplicitAutograd key is in fact a decomposition into other ops.
QAT would like this API to faithfully return "the set of all decomps that would have run if we had traced through the dispatcher". However, native_batch_norm is an example of an op that has a pre-autograd decomp registered to it (through op.py_impl()), but the decomp is registered directly to the Autograd key instead of to the CompositeImplicitAutograd key.
If we want to provide a guarantee to QAT that they can programmatically access all decomps that would have run during tracing, then we need to make sure that every decomp we register to the Autograd key is also registered to the CompositeImplicitAutograd key.
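As a hedged sketch of what that dual registration can look like (this mirrors, but is not, the actual PR code: the function name is hypothetical and the decomposition body is elided; since PyTorch now ships a real registration for this op, running this as-is could collide with it):

```python
import torch
from torch._C import DispatchKey

aten = torch.ops.aten

# py_impl() returns a decorator and can be stacked, so the same Python
# decomp becomes discoverable under CompositeImplicitAutograd while still
# running pre-autograd via the Autograd key.
@aten.native_batch_norm.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.native_batch_norm.default.py_impl(DispatchKey.Autograd)
def native_batch_norm_decomp(input, weight, bias, running_mean,
                             running_var, training, momentum, eps):
    # ... decompose into smaller aten ops here ...
    raise NotImplementedError("sketch only")
```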
This might sound kind of painful (since it requires auditing), but I think in practice this basically only applies to native_batch_norm.
Test Plan: python test/test_decomp.py
Differential Revision: D48607575