Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement OpFromGraph in PyTorch backend #956
Implement OpFromGraph in PyTorch backend #956
Changes from 5 commits
c4b20ec
9c64320
fdd5d5c
d98e68e
10a841f
b29be45
cefec02
0f18d8d
File filter
Filter by extension
Conversations
Jump to
There are no files selected for viewing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you need to compile the inner function? Is that a thing in PyTorch?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was following what numba does where it jits the inner function - we could remove the inner torch.compile and just return op.fgraph if that seems more reasonable. That will still lead to some c-linker issues fwiw.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I removed the inner function, you only need to do indexing if the number of return values is more than 1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Numba can only have inner compiled functions, I don't know if that's a requirement in pytorch, and whether it has any advantages. We don't do it for JAX
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do not see / know of any requirement to have an inner compiled function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this because of the two inner functions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Has something bizarre to do with the combination of
fgraph_fn
being a bunch of nested functions, and this inner function being nested. The bigger part of that torch compiler isn't super great at handling conditionals user closure variables, at least in pytensor. It would probably need a much deeper dive. It looks like it might be something that can happen with other functions.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's worrisome. What error did you get without this disabling?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have the error above in a comment, but it's essentially going to say the generated code from pytensor can't find some functions (all the inner functions returned in torch dispatch)