[fx] make fx.wrap idempotent #104838
Conversation
Previously, calling `torch.fx.wrap()` twice on the same function would add two entries to `_wrapped_fns_to_patch`. When tracing, the patcher would process both entries; on the second one it would double-wrap the function (e.g. `wrap(wrap(orig_fn))`).

This makes the wrapping observable after the trace: while a Patcher instance normally "reverts" the wrapping once tracing finishes, the double-wrapped function only unwinds from `wrap(wrap(orig_fn))` to `wrap(orig_fn)`. This happens to work for ordinary fx use (after all, the wrapper function behaves exactly like the original function), but it upsets torch.package, which is not expecting to see a weird wrapper function in the graph.
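A minimal sketch of the setup described above; the module and function names are illustrative, not taken from the PR:

```python
import torch
import torch.fx

def my_custom_op(x):
    return x + 1

# Register my_custom_op as a leaf function so symbolic tracing records
# a call to it rather than tracing through its body. Before this PR,
# calling wrap() twice on the same name queued two patch entries and
# the patcher wrapped the function twice during tracing.
torch.fx.wrap("my_custom_op")
torch.fx.wrap("my_custom_op")

class M(torch.nn.Module):
    def forward(self, x):
        return my_custom_op(x)

gm = torch.fx.symbolic_trace(M())
print(gm.graph)  # my_custom_op should appear as a single call_function node
```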
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/104838
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures (as of commit 9ffe96b)
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Wondering when this will be synced to the internal codebase?
cc @osalpekar can you comment on the time to sync this to internal? I'm not sure of the current diff-train SLA.
Since this was merged during the weekend, it will be picked up in tonight's (07/10) diff-train import and landed tomorrow (07/11).
This PR adds a dictionary to deduplicate `wrap()` calls, ensuring that the patcher operates only once per (frame, function) pair.
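A sketch of the deduplication idea, independent of the actual torch.fx internals; the names and structure below are assumptions for illustration, not the PR's real code:

```python
# Hypothetical sketch: keep patch registrations in a dict keyed by
# (id-of-globals, function-name), so a repeated wrap() registration
# is a no-op and the patcher never double-wraps a function.
_wrapped_fns_to_patch = {}  # (id(globals), fn_name) -> (globals, fn_name)

def register_wrap(frame_globals, fn_name):
    key = (id(frame_globals), fn_name)
    if key not in _wrapped_fns_to_patch:
        _wrapped_fns_to_patch[key] = (frame_globals, fn_name)

# Registering the same (frame, function) pair twice leaves one entry,
# so the patcher applies the wrap exactly once during tracing.
g = globals()
register_wrap(g, "my_custom_op")
register_wrap(g, "my_custom_op")
assert len(_wrapped_fns_to_patch) == 1
```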