Is it possible to extract pullback as a normal function (not closure)? #202
That sounds super cool.
I can't work out how exactly Yota handles sharing arbitrary intermediate state between the forward and reverse passes.

```julia
function chainrules_pullback(ȳ, state)
    return state(ȳ)
end
```

where `state` is that arbitrary intermediate state that you have stored on the tape, …
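A minimal sketch of this idea, assuming a hand-written stand-in for `rrule`. The name `my_rrule` is hypothetical, and its pullback omits the tangent for the function itself that real ChainRules pullbacks return as their first element. The closure is stored on a toy tape as plain data, and the only operation the reverse pass calls is the named function `chainrules_pullback`.

```julia
# Hypothetical stand-in for a ChainRules-style rrule: returns the primal
# result and a pullback closure that captures `x`.
# (Simplified: real ChainRules pullbacks also return a tangent for `sin` itself.)
function my_rrule(::typeof(sin), x)
    y = sin(x)
    pullback(ȳ) = (ȳ * cos(x),)   # closure over x
    return y, pullback
end

# Ordinary top-level function; the closure is passed in as opaque "state".
chainrules_pullback(ȳ, state) = state(ȳ)

# Forward pass: record the pullback on a toy tape as data.
x = 0.5
y, state = my_rrule(sin, x)
tape = Any[state]

# Reverse pass: only the named function appears as an operation.
(dx,) = chainrules_pullback(1.0, tape[1])
```

The graph itself then only mentions `chainrules_pullback` by name, although the state it carries is still a closure.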
Not yet. The plan is to reuse parts of Flux/ONNX.jl, but to implement both import and export of graphs. Overall, ONNX follows the same ideas as Yota, but a proper integration will surely take a lot of time.
I'm totally fine with that. I think ONNX support should land in Yota before the ChainRules integration because, as you mentioned, ONNX is riskier, so it's worth working it out as early as possible. Consider this issue a pre-check of possible solutions.
For now, let's just assume it's doable. To give you some intuition, it's pretty easy to rewrite any primitive in the tape to make any state explicit and then reuse it later. Common subexpression elimination and other optimizations should take care of most of the resulting inefficiencies.
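To illustrate what making state explicit could look like, here is a sketch that assumes a hypothetical split of `sin` into `sin_forward`/`sin_backward` (not Yota's actual API). The forward function returns its result together with a plain tuple, and the backward function is an ordinary named function of `(ȳ, state)`, so no closure is involved.

```julia
# Hypothetical explicit-state primitive: the forward pass returns the result
# plus a plain tuple holding whatever the backward pass needs.
sin_forward(x) = (sin(x), (x,))

# The backward pass is an ordinary named function of (ȳ, state).
function sin_backward(ȳ, state)
    (x,) = state
    return ȳ * cos(x)
end

# Composite example: f(x) = sin(sin(x)), recorded with explicit state.
function f_with_grad(x)
    y1, s1 = sin_forward(x)
    y2, s2 = sin_forward(y1)
    # Reverse pass, seeded with ȳ = 1.0
    ȳ1 = sin_backward(1.0, s2)
    x̄ = sin_backward(ȳ1, s1)
    return y2, x̄
end

y, g = f_with_grad(0.3)
```

Because each backward step is a named function applied to data, a tape built this way could in principle be serialized node by node.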
In this example, is …
Yes. The pullback closure …
Yeah, it may work then. The general workflow will be:
This pullback/state approach may not be compatible with ONNX, but it still sounds reasonable to have ONNX-only and ChainRules-only computational graphs for the first version. I think this solution is a good start, so I'm closing this issue. My next steps are to implement the ONNX integration and then revisit the ChainRules integration.
One idea I have for dealing with ONNX would be to do special things for …
I was thinking about integrating ChainRules into Yota. However, there's a major obstacle: the pullback functions returned by `rrule` are closures, and Yota doesn't support closures. There are several reasons for not supporting closures (or at least not making them the default mechanism), but the main one is poor interoperability with other systems. For example, one big thing I want from an AD system is the ability to serialize computational graphs to the ONNX format. The ONNX spec describes nodes as:
So every operator should be uniquely identified just by its name, and the name should be known beforehand. Closures don't seem to fit this rule. It's OK to have some graphs that are not serializable to ONNX, but as far as I understand, using ChainRules means that nearly all graphs would use closures.
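The naming constraint can be illustrated with a toy graph in which each node refers to its operator only by a `Symbol` that is looked up in a registry known beforehand. `Node`, `OPS`, and `run_graph` are hypothetical names for this sketch, and the serialized form is an ad hoc string, not ONNX's actual encoding.

```julia
# Hypothetical minimal graph: each node refers to its operator by name only,
# loosely mirroring the rule that an operator is identified by its name.
struct Node
    op::Symbol            # operator name, e.g. :Mul
    inputs::Vector{Int}   # indices of earlier nodes (0 means "graph input")
end

# Registry of operators known beforehand (names chosen for illustration).
const OPS = Dict{Symbol,Function}(
    :Mul => (a, b) -> a * b,
    :Sin => sin,
)

# Evaluate nodes in order, resolving each operator through the registry.
function run_graph(nodes::Vector{Node}, x)
    vals = Float64[]
    for n in nodes
        args = [i == 0 ? x : vals[i] for i in n.inputs]
        push!(vals, OPS[n.op](args...))
    end
    return vals[end]
end

# sin(x) * x, written out as plain text made only of names and indices.
g = [Node(:Sin, [0]), Node(:Mul, [1, 0])]
serialized = join(["$(n.op)($(join(n.inputs, ",")))" for n in g], ";")
```

A closure has no such stable name: its type is generated by the compiler (printed as something like `var"#1#2"`), so it cannot be looked up in a registry like `OPS` after a round trip.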
I started to wonder whether it's possible to "extract" a pullback and convert it to an ordinary function. I even have a rough plan for rewriting `rrule`s using IRTools, for example. But I'm 95% sure it won't end well :) But maybe I'm missing something? Maybe there are previous discussions or implementations in other AD systems that somehow get around this issue?
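One partial route toward "extracting" a pullback is sketched below. It relies on how Julia lowers closures, which is an implementation detail. A Julia closure is an instance of a compiler-generated struct whose fields are the captured variables, so the captured state can be read out with `fieldnames` and `getfield`. This only recovers the data. Rewriting the closure body into a named function, such as the hypothetical `sin_pullback_explicit` below, would still need IR-level work, e.g. with IRTools. Captured variables that are reassigned are stored as `Core.Box` and would need extra handling.

```julia
# A pullback closure in the style of an rrule (hypothetical example).
function make_pullback(x)
    pullback(ȳ) = ȳ * cos(x)
    return pullback
end

pb = make_pullback(0.5)

# The closure's fields are its captured variables, so the "state" can be read out.
captured = fieldnames(typeof(pb))
state = Tuple(getfield(pb, f) for f in captured)

# The same computation written by hand as an ordinary named function of (ȳ, state).
sin_pullback_explicit(ȳ, state) = ȳ * cos(state[1])
```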