
Is it possible to extract pullback as a normal function (not closure)? #202

Closed
dfdx opened this issue Aug 5, 2020 · 5 comments


@dfdx

dfdx commented Aug 5, 2020

I was thinking about integrating ChainRules into Yota. However, there's a major obstacle: the pullback functions returned by rrule are closures, and Yota doesn't support closures.
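For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of a closure-based pullback. It does not use ChainRulesCore: `my_rrule` is a hypothetical stand-in for `rrule`, and the `nothing` in the returned tuple stands in for the tangent of the function itself.

```julia
# Hypothetical stand-in for ChainRules' `rrule`: returns the primal result
# together with a pullback closure.
function my_rrule(::typeof(sin), x::Real)
    y = sin(x)
    # The pullback closes over `x`. Its type is generated by the compiler,
    # so there is no stable, user-visible operator name to refer to.
    sin_pullback(ȳ) = (nothing, ȳ * cos(x))
    return y, sin_pullback
end

y, pb = my_rrule(sin, 0.5)
_, x̄ = pb(1.0)   # the cotangent for x, i.e. cos(0.5)
```

The state the reverse pass needs (`x`) lives inside `pb`, which is exactly what makes it hard to describe the reverse step as a named, stateless operator.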

There are several reasons for not supporting closures (or at least not making them the default mechanism), but the main one is poor interoperability with other systems. For example, one big thing I want from an AD system is the ability to serialize computational graphs to the ONNX format. The ONNX spec describes nodes as:

Computation nodes are comprised of a name, the name of an operator that it invokes, a list of named inputs, a list of named outputs, and a list of attributes.

So every operator should be uniquely identified just by its name, and the name should be known beforehand. Closures don't seem to match this rule. It's OK to have some graphs that are not serializable to ONNX, but as far as I understand, using ChainRules means that nearly all graphs would use closures.
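To illustrate the mismatch, here is a hypothetical record mirroring the node fields listed in the quoted ONNX spec (it is not the API of ONNX.jl or any other package). A primitive with a fixed name maps directly onto an operator name, whereas a closure's type name is compiler-generated.

```julia
# Hypothetical record mirroring the ONNX node fields quoted above;
# not the API of any ONNX package.
struct NodeSketch
    name::String
    op_type::String                 # must be an operator name known beforehand
    inputs::Vector{String}
    outputs::Vector{String}
    attributes::Dict{String,Any}
end

# A named primitive maps cleanly onto a predefined operator.
cos_node = NodeSketch("n1", "Cos", ["x"], ["y"], Dict{String,Any}())

# A closure's type name is generated by the compiler (it starts with '#'),
# so it cannot be looked up as a predefined operator.
make_pullback(x) = ȳ -> ȳ * cos(x)
closure_name = string(nameof(typeof(make_pullback(0.5))))
```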

I started to think about whether it's possible to "extract" a pullback and convert it to an ordinary function. I even have a rough plan for rewriting rrules using IRTools, for example. But I'm 95% sure it won't end well :)
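One way to picture the "extraction": split the rule into a forward function that returns the primal result plus explicit state, and a top-level, named pullback that receives that state as an ordinary argument. This is a hand-written sketch, not an automated IRTools rewrite; `sin_forward` and `sin_pullback` are hypothetical names.

```julia
# Forward part (hypothetical): returns the result and the state the
# reverse pass needs, instead of hiding that state in a closure.
sin_forward(x::Real) = (sin(x), x)

# Reverse part (hypothetical): an ordinary top-level function with a
# stable name that receives the state explicitly.
sin_pullback(ȳ, x) = ȳ * cos(x)

y, state = sin_forward(0.5)
x̄ = sin_pullback(1.0, state)   # same value a closure-based pullback would give
```

Because `sin_pullback` has a fixed name, a graph node could refer to it the same way it refers to any other primitive, with `state` passed as a regular input.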

But maybe I'm missing something? Maybe there are previous discussions or implementations in other AD systems which somehow get around this issue?

@oxinabox oxinabox transferred this issue from JuliaDiff/ChainRules.jl Aug 6, 2020
@oxinabox
Member

oxinabox commented Aug 6, 2020

cf dfdx/Yota.jl#65

For example, one big thing I want to have from AD system is the ability to serialize computational graphs to ONNX format.

That sounds super cool.
Have you got any prototype or full version of this working already for Yota?
I feel like without that, it is hard to have a proper discussion.
It seems like a super hard problem, and that ChainRules using closures would be just one of many challenges to overcome.
If it is something that is actively being worked on, I am happy to talk and work it out.
But otherwise it might be something that we can put off solving until ChainRules v2.0.

So every operator should be uniquely identified just by its name, and the name should be known beforehand.
...I started to think if it's possible to "extract" pullback and convert it to an ordinary function

I can't work out how exactly Yota handles sharing arbitrary intermediate state between the forward and reverse passes,
for example, the need for F in both passes in the logabsdet example in the docs.
But assuming you have a way to do that, could one define the pullback function for everything that has a ChainRules rule as:

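# `state` is whatever the forward pass stored on the tape for this call;
# here it is the pullback closure returned by `rrule`, so applying it to ȳ
# produces the cotangents.
function chainrules_pullback(ȳ, state)
    return state(ȳ)
end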

where state is that arbitrary intermediate state that you have stored on the tape,
and in this instance contains the pullback closure that is returned from rrule?
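A minimal sketch of how that proposal could be used, assuming the tape stores the closure as the node's state. The `tape` vector of `(op, state)` pairs is a hypothetical simplification, not Yota's real tape, and `my_rrule` is a hypothetical stand-in for `rrule`.

```julia
function chainrules_pullback(ȳ, state)
    return state(ȳ)
end

# Hypothetical stand-in for `rrule`, returning a closure-based pullback.
function my_rrule(::typeof(sin), x)
    pullback = ȳ -> (nothing, ȳ * cos(x))
    return sin(x), pullback
end

# Forward pass: keep the closure as opaque state next to a fixed op name.
y, pb = my_rrule(sin, 0.5)
tape = [(:chainrules_pullback, pb)]

# Reverse pass: the recorded operation always has the same name;
# only its state argument differs between nodes.
op, state = tape[1]
_, x̄ = chainrules_pullback(1.0, state)
```

The operator name is now known beforehand, but the state itself is still a closure, so this alone does not make the graph ONNX-serializable.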

@dfdx
Author

dfdx commented Aug 7, 2020

Have you got any prototype or full version of this working already for Yota?

Not yet. The plan is to reuse parts of Flux/ONNX.jl, but implement both import and export of graphs. All in all, ONNX follows the same ideas as Yota, but proper integration will surely take a lot of time.

But otherwise it might be something that we can put off solving until ChainRules v2.0.

I'm totally fine with it. I think ONNX should land in Yota before the integration with ChainRules because, as you mentioned, ONNX is riskier, so it's worth working out as early as possible. Consider this issue a pre-check of possible solution(s).

I can't work out how exactly Yota handles sharing arbitrary intermediate state between the forward and reverse passes [...]

For now, let's just assume it's doable. To give you some intuition, it's pretty easy to rewrite any primitive on the tape to make any state explicit and then reuse it later. Common subexpression elimination and other optimizations should take care of most of the resulting inefficiencies.

but assuming you have a way to do that then could one define the pullback function for all things that have chainrules as being [...]

In this example, is state just the pullback function?

@oxinabox
Member

oxinabox commented Aug 11, 2020

In this example, is state just the pullback function?

Yes, the pullback closure.

@dfdx
Author

dfdx commented Aug 11, 2020

Yeah, it may work then. The general workflow will be:

  1. During the forward pass, record on the tape all primitives defined via Yota's own @diffrule, plus functions for which rrule() is overloaded.
  2. Rewrite all rrule-based primitives into actual calls to rrule, capturing both the result and the pullback.
  3. Define a @diffrule which, during the reverse pass, records chainrules_pullback() with the corresponding state.
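The three steps can be sketched with a toy tape. Everything here is hypothetical: the list-of-tuples tape is not Yota's data structure, `my_rrule` stands in for `rrule`, and only a single chain of unary operations is supported.

```julia
# Hypothetical stand-ins for ChainRules' rrule on two unary functions.
function my_rrule(::typeof(sin), x)
    pullback = ȳ -> (nothing, ȳ * cos(x))
    return sin(x), pullback
end
function my_rrule(::typeof(exp), x)
    y = exp(x)
    pullback = ȳ -> (nothing, ȳ * y)
    return y, pullback
end

chainrules_pullback(ȳ, state) = state(ȳ)

# Steps 1-2: run each recorded primitive through my_rrule, keeping the
# result and storing the pullback as opaque per-node state.
function forward(fs, x)
    tape = Any[]
    for f in fs
        x, pb = my_rrule(f, x)
        push!(tape, (f, pb))
    end
    return x, tape
end

# Step 3: the reverse pass walks the tape backwards, always calling the
# same named function, chainrules_pullback, with the stored state.
function reverse_pass(tape, ȳ)
    for (_, pb) in Iterators.reverse(tape)
        _, ȳ = chainrules_pullback(ȳ, pb)
    end
    return ȳ
end

y, tape = forward([sin, exp], 0.5)   # y = exp(sin(0.5))
x̄ = reverse_pass(tape, 1.0)          # d/dx exp(sin(x)) at x = 0.5
```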

This pullback/state may not be compatible with ONNX, but it still sounds good to have ONNX-only and ChainRules-only computational graphs for the first version.

I think this solution is a good start, so I'm closing this issue. Next steps for me are to actually implement integration with ONNX and then revisit integration with ChainRules.

@dfdx dfdx closed this as completed Aug 11, 2020
@oxinabox
Member

One idea I have for dealing with ONNX would be to special-case @scalar_rule.
We have a lot more control over what kind of program structure is used when the @scalar_rule macro is used to declare the rule.
So we could probably change that code to break some of those pullbacks out as separate functions that are referenced from the rrule.
But we will see once we have the ONNX work to start with.
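One possible shape for that idea, written by hand rather than generated by @scalar_rule: the pullback becomes a named callable struct whose fields hold the state, so every rule's reverse operation has a stable type name. `SinPullback` and `my_rrule` are hypothetical names, not anything ChainRules defines.

```julia
# Hypothetical named pullback type: its fields are the state the reverse
# pass needs, and its name is known before any graph is built.
struct SinPullback{T<:Real}
    x::T
end

# Calling the struct performs the pullback, so it can still be returned
# from an rrule-style function in place of a closure.
(p::SinPullback)(ȳ) = (nothing, ȳ * cos(p.x))

my_rrule(::typeof(sin), x::Real) = (sin(x), SinPullback(x))

y, pb = my_rrule(sin, 0.5)
_, x̄ = pb(1.0)
```

The trade-off is that whoever generates the struct (here, presumably the macro) must decide up front which values become fields; for rules declared through @scalar_rule that set is fairly constrained, which is what makes this seem plausible there.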
