Description
There are basically two reasons to implement rules:
- to define AD. For example, you do have to tell an AD system somewhere how to differentiate addition and multiplication of floats,
- to make AD faster, without changing the semantics.
For the first we obviously can't get around defining rules. For the second, however, we tend to implement rules in the same way as for the first: by completely overriding the AD system and just telling it how to differentiate a thing. One thing that we've not explored to any great extent is re-writing code to make it more AD-friendly, and then just saying "run AD on this".
Leaving aside concerns about the best way to achieve a code re-write for a minute, suppose that you wished to implement an rrule for *(::AbstractMatrix, ::Diagonal). LinearAlgebra implements this as follows:
(*)(A::AbstractMatrix, D::Diagonal) =
rmul!(copyto!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A), D)
The problem from the perspective of a reverse-mode AD tool (that doesn't know how to handle mutation) is that the underlying implementation of this non-mutating operation is mutating. However, it is really quite clear how a non-mutating version of this operation could be implemented by looking at the definition of rmul!. Specifically, something like
A .* permutedims(D.diag)
This is the kind of code that we could plausibly hope to run one of our current (or near-future) reverse-mode AD tools on, and have it do something sensible, whereas there was really no hope with LinearAlgebra's definition.
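As a quick sanity check on the re-write above, here is a minimal sketch that compares it against the built-in method. The function name ad_friendly_mul is hypothetical and used only for illustration; the only library assumed is the LinearAlgebra standard library.

```julia
using LinearAlgebra

# Hypothetical name, for illustration only: a non-mutating
# equivalent of A * D when D is a Diagonal.
ad_friendly_mul(A::AbstractMatrix, D::Diagonal) = A .* permutedims(D.diag)

A = [1.0 2.0; 3.0 4.0; 5.0 6.0]
D = Diagonal([10.0, 100.0])

# permutedims turns the diagonal into a 1×n row, so broadcasting
# scales column j of A by D.diag[j], which is what A * D computes.
Y = ad_friendly_mul(A, D)
@assert Y == A * D
```

Nothing in ad_friendly_mul writes into an existing array, which is the property a mutation-averse reverse-mode tool needs.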
Moreover, this kind of approach seems simpler for the rule-writer: rather than having to know how to differentiate a function, the rule-implementer just needs to know how to re-write the primal pass in a way that is more friendly towards AD.
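For contrast, this is roughly what the rule-writer has to work out by hand under the current approach. It is a sketch only: mul_with_pullback is a hypothetical helper written without ChainRulesCore so that it stays self-contained, and it is not how any existing rule is actually defined.

```julia
using LinearAlgebra

# Hypothetical helper: primal result plus a hand-derived pullback for Y = A * D.
# With Y[i, j] = A[i, j] * d[j], the cotangents are
#   dA[i, j] = dY[i, j] * d[j]
#   dd[j]    = sum over i of dY[i, j] * A[i, j]
function mul_with_pullback(A::AbstractMatrix, D::Diagonal)
    d = D.diag
    Y = A .* permutedims(d)
    function pullback(dY)
        dA = dY .* permutedims(d)
        dd = vec(sum(dY .* A; dims=1))
        return dA, dd
    end
    return Y, pullback
end

A = [1.0 2.0; 3.0 4.0]
D = Diagonal([10.0, 100.0])
Y, back = mul_with_pullback(A, D)
dA, dd = back(ones(2, 2))
```

With the re-write approach, only the one-line primal would need to be written; the pullback would be left to the AD tool.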
This kind of approach is only valuable if there's functionality that could be easily re-written in an AD-friendly manner. My hypothesis is that there is lots of functionality in Base / the standard libraries that satisfies this because it was implemented by
- implementing a mutating version of a function (e.g. gemm!)
- implementing the non-mutating version of a function in terms of the mutating version (e.g. gemm)
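The pattern described in the list above can be sketched as follows. All three functions are hypothetical and exist only to illustrate the shape of the problem; they are not taken from Base or the standard libraries.

```julia
# Step 1: a mutating implementation (hypothetical example).
function center!(x::AbstractVector)
    m = sum(x) / length(x)
    x .-= m          # in-place update, which mutation-averse AD cannot trace
    return x
end

# Step 2: the non-mutating version, defined in terms of the mutating one.
center(x::AbstractVector) = center!(copy(x))

# An AD-friendly re-write of the same operation, with no mutation at all.
center_rewrite(x::AbstractVector) = x .- sum(x) / length(x)

x = [1.0, 2.0, 3.0]
@assert center(x) == center_rewrite(x)
@assert x == [1.0, 2.0, 3.0]   # neither version modifies its argument
```

From the caller's point of view center and center_rewrite are interchangeable, but only the second avoids mutation internally.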
This could provide a very simple partial solution to #232 by replacing the need for generic rules with code re-writes, which are much more straightforward to achieve and let the AD system auto-generate appropriate cotangents.
Thoughts on the principle? We can discuss implementation details once we've established whether or not we basically like the idea.