Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make python decomp for native_batch_norm CompositeImplicitAutograd, remove native_batch_norm from core aten opset #107791

Closed
wants to merge 1 commit into from

Conversation

andrewor14
Copy link
Contributor

Summary:
(From Brian Hirsh)

Description copied from what I put in a comment in this PR: #106329

So, the slightly-contentious idea behind this PR is that lower in the stack, I updated torch._decomps.get_decomps() to check not only the decomp table to see if a given op has a decomposition available, but to also check the dispatcher for any decomps registered to the CompositeImplicitAutograd key (link: https://github.com/pytorch/pytorch/pull/105865/files#diff-7008e894af47c01ee6b8eb94996363bd6c5a43a061a2c13a472a2f8a9242ad43R190)

There's one problem though: we don't actually make any hard guarantees that a given key in the dispatcher points does or does not point to a decomposition. We do rely pretty heavily, however, on the fact that everything registered to the CompositeImplicitAutograd key is in fact a decomposition into other ops.

QAT would like this API to faithfully return "the set of all decomps that would have run if we had traced through the dispatcher". However, native_batch_norm is an example of an op that has a pre-autograd decomp registered to it (through op.py_impl(), but the decomp is registered directly to the Autograd key instead of being registered to the CompositeImplicitAutograd key.

If we want to provide a guarantee to QAT that they can programatically access all decomps that would have run during tracing, then we need to make sure that every decomp we register to the Autograd key is also registered to the CompositeImplicitAutograd key.

This might sound kind of painful (since it requires auditing), but I think in practice this basically only applies to native_batch_norm.

Test Plan: python test/test_decomp.py

Differential Revision: D48607575

@pytorch-bot
Copy link

pytorch-bot bot commented Aug 23, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/107791

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4f44fd7 with merge base 3022a39 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D48607575

…emove native_batch_norm from core aten opset (pytorch#107791)

Summary:
Pull Request resolved: pytorch#107791

(From Brian Hirsh)

Description copied from what I put in a comment in this PR: pytorch#106329

So, the slightly-contentious idea behind this PR is that lower in the stack, I updated torch._decomps.get_decomps() to check not only the decomp table to see if a given op has a decomposition available, but to also check the dispatcher for any decomps registered to the CompositeImplicitAutograd key (link: https://github.com/pytorch/pytorch/pull/105865/files#diff-7008e894af47c01ee6b8eb94996363bd6c5a43a061a2c13a472a2f8a9242ad43R190)

There's one problem though: we don't actually make any hard guarantees that a given key in the dispatcher points does or does not point to a decomposition. We do rely pretty heavily, however, on the fact that everything registered to the CompositeImplicitAutograd key is in fact a decomposition into other ops.

QAT would like this API to faithfully return "the set of all decomps that would have run if we had traced through the dispatcher". However, native_batch_norm is an example of an op that has a pre-autograd decomp registered to it (through op.py_impl(), but the decomp is registered directly to the Autograd key instead of being registered to the CompositeImplicitAutograd key.

If we want to provide a guarantee to QAT that they can programatically access all decomps that would have run during tracing, then we need to make sure that every decomp we register to the Autograd key is also registered to the CompositeImplicitAutograd key.

This might sound kind of painful (since it requires auditing), but I think in practice this basically only applies to native_batch_norm.

Test Plan: python test/test_decomp.py

Differential Revision: D48607575

fbshipit-source-id: 965696eab4119f41d46f1ce4ecb1c20cc788545e
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D48607575

@facebook-github-bot
Copy link
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 24, 2023
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@andrewor14 andrewor14 deleted the export-D48607575 branch August 24, 2023 16:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants