
An additional approach to implementing rules #338

Open
@willtebbutt


There are basically two reasons to implement rules:

  1. to define AD: for example, somewhere you have to tell an AD system how to differentiate addition and multiplication of floats;
  2. to make AD faster, without changing the semantics.
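As a rough illustration of reason 1, a rule for multiplying two floats can be written as a function that returns the primal result together with a pullback. The sketch below is hand-rolled and `mul_rule` is a hypothetical name. ChainRulesCore's `rrule` follows a broadly similar (result, pullback) pattern, but this is not that API.

```julia
# Hypothetical, hand-rolled rule (not ChainRulesCore's API): return the primal
# result together with a pullback mapping an output cotangent to input cotangents.
function mul_rule(x::Float64, y::Float64)
    z = x * y
    # d(x*y)/dx = y and d(x*y)/dy = x, so the pullback scales the cotangent accordingly.
    mul_pullback(zbar) = (zbar * y, zbar * x)
    return z, mul_pullback
end

z, back = mul_rule(3.0, 4.0)
println((z, back(1.0)))  # prints (12.0, (4.0, 3.0))
```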

For 1 we obviously can't get around defining rules. For 2, however, we tend to implement rules in the same way as for 1: by completely overriding whichever AD is in use and just telling it how to differentiate a thing. One thing that we've not explored to any great extent is re-writing code to make it more AD-friendly, and then just saying "run AD on this".

Leaving aside concerns about the best way to achieve a code re-write for a minute, suppose that you wished to implement an rrule for *(::AbstractMatrix, ::Diagonal). LinearAlgebra implements this as follows:

(*)(A::AbstractMatrix, D::Diagonal) =
    rmul!(copyto!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A), D)

From the perspective of a reverse-mode AD tool that doesn't know how to handle mutation, the problem is that this non-mutating operation is implemented in terms of mutating ones (copyto! and rmul!). However, looking at the definition of rmul!, it is quite clear how a non-mutating version of this operation could be implemented, since all it does is scale column j of A by D.diag[j]. Specifically, something like

A .* permutedims(D.diag)

This is the kind of code that we could plausibly hope to run one of our current (or near-future) reverse-mode AD tools on and have it do something sensible, whereas there was really no hope with LinearAlgebra's definition.
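A quick check, assuming only the LinearAlgebra standard library, that the two versions agree: LinearAlgebra's route copies A and then scales the copy's columns in place with rmul!, whereas the re-write produces the same matrix by broadcasting against a 1×n row and never mutates anything. The matrices are arbitrary example values.

```julia
using LinearAlgebra

A = [1.0 2.0; 3.0 4.0; 5.0 6.0]  # 3×2 dense matrix (example values)
D = Diagonal([10.0, 100.0])      # 2×2 diagonal matrix

# LinearAlgebra-style route: copy A, then scale the copy's columns in place.
B = copy(A)
rmul!(B, D)                      # mutates B: column j is multiplied by D.diag[j]

# Non-mutating re-write: broadcasting against a 1×2 row scales each column.
C = A .* permutedims(D.diag)

println(C)                       # [10.0 200.0; 30.0 400.0; 50.0 600.0]
println(B == C && A * D == C)    # true
```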

Moreover, this kind of approach seems simpler for the rule-writer: rather than having to know how to differentiate a function, the rule-implementer just needs to know how to re-write the primal pass in a way that is more friendly towards AD.
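One way the rule-writer's side of this could look, sketched under the assumption of a dispatch-based hook: `ad_friendly` below is a hypothetical function, not part of ChainRules or any existing package. An AD tool would call `ad_friendly(f, args...)` in place of `f(args...)` and differentiate whatever code that runs; the rule-writer adds a method containing only an equivalent primal computation, with no derivatives written by hand.

```julia
using LinearAlgebra

# Hypothetical hook (not an existing API). By default it just calls the function.
ad_friendly(f, args...) = f(args...)

# The "rule": an AD-friendly, non-mutating re-write of *(::AbstractMatrix, ::Diagonal).
# No derivative information is supplied; the AD tool would differentiate this body.
ad_friendly(::typeof(*), A::AbstractMatrix, D::Diagonal) = A .* permutedims(D.diag)

A = rand(3, 2)
D = Diagonal(rand(2))
println(ad_friendly(*, A, D) ≈ A * D)  # true: same primal result
println(ad_friendly(+, 1.0, 2.0))      # 3.0: unrelated calls fall through unchanged
```

Whether dispatch on a hook like this is the right mechanism is one of the implementation details left open here.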

This kind of approach is only valuable if there is functionality that can easily be re-written in an AD-friendly manner. My hypothesis is that a lot of functionality in Base and the standard libraries satisfies this, because it was implemented by

  1. implementing a mutating version of a function (e.g. gemm!)
  2. implementing the non-mutating version of a function in terms of the mutating version. (e.g. gemm)
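A toy version of this pattern, with hypothetical names chosen for illustration (they are not functions in Base): a mutating kernel `scale_rows!`, a non-mutating `scale_rows` built on top of it in the same style as the LinearAlgebra definition above, and an AD-friendly re-write `scale_rows_ad` that computes the same thing without mutation.

```julia
# Step 1: a mutating kernel (hypothetical name, for illustration only).
function scale_rows!(Y::AbstractMatrix, s::AbstractVector, X::AbstractMatrix)
    for j in axes(X, 2), i in axes(X, 1)
        Y[i, j] = s[i] * X[i, j]
    end
    return Y
end

# Step 2: the non-mutating version, implemented by allocating an output and
# calling the mutating kernel. Opaque to AD tools that can't handle mutation.
scale_rows(s::AbstractVector, X::AbstractMatrix) =
    scale_rows!(similar(X, promote_type(eltype(s), eltype(X))), s, X)

# AD-friendly re-write of step 2: same result, expressed as a pure broadcast.
scale_rows_ad(s::AbstractVector, X::AbstractMatrix) = s .* X

s = [2.0, 3.0]
X = [1.0 2.0; 3.0 4.0]
println(scale_rows(s, X) == scale_rows_ad(s, X))  # true
```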

This could provide a very simple partial solution to #232 by replacing the need for generic rules with code re-writes, which are much more straightforward to achieve and let the AD system auto-generate appropriate cotangents.
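To give a feel for what auto-generated cotangents would look like here, the sketch below writes out by hand the pullback that a reverse-mode AD tool would be expected to derive from the re-written primal Y = A .* permutedims(d), with d = D.diag. No particular AD package is assumed, and the function name is hypothetical. Note that the cotangent for the diagonal comes out as a length-n vector matching D.diag, rather than as a dense n×n matrix, simply as a consequence of how the primal was written.

```julia
# Hand-written equivalent of what reverse-mode AD would derive for the
# re-written primal Y = A .* permutedims(d). Hypothetical name; no AD package used.
function rewritten_mul_with_pullback(A::AbstractMatrix, d::AbstractVector)
    Y = A .* permutedims(d)                 # Y[i, j] = A[i, j] * d[j]
    function pullback(Ybar::AbstractMatrix)
        Abar = Ybar .* permutedims(d)       # dY[i, j]/dA[i, j] = d[j]
        dbar = vec(sum(Ybar .* A; dims=1))  # dY[i, j]/dd[j] = A[i, j], summed over i
        return Abar, dbar
    end
    return Y, pullback
end

A = [1.0 2.0; 3.0 4.0; 5.0 6.0]
d = [0.5, -1.5]
Ybar = ones(3, 2)
Y, back = rewritten_mul_with_pullback(A, d)
Abar, dbar = back(Ybar)
println(dbar)  # [9.0, 12.0]: one entry per diagonal element

# Central finite-difference check of dbar[1] for L(d) = sum(Ybar .* Y(d)).
h = 1e-6
L(dv) = sum(Ybar .* (A .* permutedims(dv)))
fd = (L(d .+ [h, 0.0]) - L(d .- [h, 0.0])) / (2h)
println(isapprox(fd, dbar[1]; rtol=1e-6))  # true
```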

Thoughts on the principle? We can discuss implementation details once we've established whether or not we basically like the idea.
