An additional approach to implementing rules #338
Comments
This would be useful. Zygote uses this approach in a few places (with […]). For another example, […] With the suggested approach, we would just reimplement […] But is this any easier than enabling a rule to call back into an AD, as discussed in JuliaDiff/ChainRulesCore.jl#68?
It feels like it has a slightly different set of requirements that might be a bit simpler to handle, in a similar vein to JuliaDiff/ChainRulesCore.jl#270. Edit: I'm glad to hear that you found it easy to arrive at a use case!
Doesn't […]
No, it doesn't. At some level you need to hit primitives that have defined adjoints/rrules, but you don't need to do that for the entire function. A really nice example of this was the now-removed workaround for […]
I have so many use cases for this. 🙂
Oh, I see, that's pretty neat! Thanks for the example.
With the now-merged mechanism for calling back into AD, I found this pattern to be quite useful:

```julia
using ChainRulesCore

# Define the rule for `complicated_f` by running AD on an equivalent, AD-friendly `simple_f`.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(complicated_f), args...)
    return rrule_via_ad(config, simple_f, args...)
end
```

This could probably be wrapped in a convenience macro like […]. I'm currently trying it for differentiating through config constructors of physics solvers, which in our case are large structs with heterogeneous field types, few of which are differentiable. Other fields involve a lot of pre-computation, including calls into non-Julia code, mutation, FFT size heuristics and many sanity-check assertions, all of which it is nice to bypass explicitly when constructing pullbacks.
There are basically two reasons to implement rules:
[…]
For 1, we obviously can't get around defining rules. For 2, however, we tend to implement rules in the same way as for 1: by completely overriding any particular AD and just telling it how to differentiate a thing. One thing that we've not explored to any great extent is rewriting code to make it more AD-friendly, and then just saying "run AD on this".
Leaving aside concerns about the best way to achieve a code rewrite for a minute, suppose that you wished to implement an `rrule` for `*(::AbstractMatrix, ::Diagonal)`. `LinearAlgebra` implements this as follows: […]

The problem from the perspective of a reverse-mode AD tool (that doesn't know how to handle mutation) is that the underlying implementation of this non-mutating operation is mutating. However, it is really quite clear how a non-mutating version of this operation could be implemented by looking at the definition of `rmul!`. Specifically, something like […]

This is the kind of code that we could plausibly hope to run one of our current (or near-future) reverse-mode AD tools on and have it do something sensible, whereas there was really no hope with `LinearAlgebra`'s definition. Moreover, this kind of approach seems simpler for the rule-writer: rather than having to know how to differentiate a function, the rule-implementer just needs to know how to rewrite the primal pass in a way that is more friendly towards AD.
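For concreteness, here is a minimal sketch of the two styles, assuming plain dense `Matrix` inputs. Both function names are hypothetical, and neither is the exact code from `LinearAlgebra` or from this issue: the mutating version simply calls the real `rmul!` on a copy of `A`, while the non-mutating version expresses the same column scaling as a single broadcast, which is the kind of code an AD without mutation support can plausibly handle.

```julia
using LinearAlgebra

# Hypothetical mutating style: allocate an output, copy A into it, then scale
# its columns in place with rmul!. An AD that can't handle mutation fails here.
function mul_diag_mutating(A::AbstractMatrix, D::Diagonal)
    T = promote_type(eltype(A), eltype(D))
    return rmul!(copyto!(similar(A, T), A), D)
end

# Hypothetical non-mutating rewrite: column j of A * D is A[:, j] * D.diag[j],
# so the whole product is one broadcast against the diagonal laid out as a row.
function mul_diag_nonmutating(A::AbstractMatrix, D::Diagonal)
    # Guard against broadcasting silently expanding a single column of A.
    size(A, 2) == size(D, 1) ||
        throw(DimensionMismatch("A has $(size(A, 2)) columns, D has size $(size(D))"))
    return A .* transpose(D.diag)
end
```

Both produce the same values as `A * D`; only the second avoids in-place writes, so the rule-writer's job reduces to supplying it and letting AD do the rest.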
This kind of approach is only valuable if there's functionality that could easily be rewritten in an AD-friendly manner. My hypothesis is that there is a lot of functionality in Base / the standard libraries that satisfies this, because it was implemented by […] `gemm!` […] `gemm` […]

This could provide a very simple partial solution to #232 by removing the need for generic rules in favour of code rewrites, which are much more straightforward to write and let the AD system auto-generate appropriate cotangents.
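To illustrate what "appropriate cotangents" could mean for the `Diagonal` example, the sketch below hand-writes the pullback that reverse-mode AD would be expected to derive from the broadcast rewrite. Real-valued inputs are assumed, all names are hypothetical, and this is not output from Zygote or ChainRules. The cotangent for the `Diagonal` argument comes out as a vector matching its diagonal, whereas a generic dense rule for `A * B` produces a full matrix for `B`.

```julia
using LinearAlgebra

# Hypothetical AD-friendly forward pass for C = A * Diagonal(d).
mul_diag(A::AbstractMatrix, d::AbstractVector) = A .* transpose(d)

# Pullback derived by hand from C[i, j] = A[i, j] * d[j] (real inputs).
function mul_diag_pullback(A::AbstractMatrix, d::AbstractVector, ΔC::AbstractMatrix)
    ΔA = ΔC .* transpose(d)          # ∂C[i, j] / ∂A[i, j] = d[j]
    Δd = vec(sum(ΔC .* A; dims=1))   # ∂C[i, j] / ∂d[j] = A[i, j], summed over i
    return ΔA, Δd
end

# Generic dense rule for C = A * B: ΔB = A' * ΔC is a full matrix, and its
# off-diagonal entries have no counterpart in a Diagonal B.
generic_ΔB(A::AbstractMatrix, ΔC::AbstractMatrix) = A' * ΔC
```

The diagonal of `generic_ΔB(A, ΔC)` agrees with `Δd`; the rewrite simply never produces the off-diagonal entries in the first place.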
Thoughts on the principle? We can discuss implementation details once we've established whether or not we basically like the idea.