
Explain the Abstract Primals Problem #343

Open · oxinabox opened this issue May 4, 2021 · 11 comments
Labels: documentation (Improvements or additions to documentation)

oxinabox commented May 4, 2021

There is a problem that comes up to do with abstract primals, most commonly in the case of AbstractArrays.
We don't have a good explanation of what the problem is anywhere; it is scattered across various issues and PRs.
@willtebbutt has spent a bunch of time thinking about it.

I propose that we open a docs PR that clearly explains it, with examples etc., as part of the design docs section.
Once we have that PR open, we can talk more about solving it.

@mcabbott and I were discussing this on Slack (they might post their notes on this later).
The below is roughly extracted from that. It discusses a lot of the problem, though it isn't super clear; after all, the whole discussion exists because we don't have a clear explanation of what the problem is.


Rough ugly notes:

The problem is that we want to define rules not just for Arrays but also for StaticArrays, whose nearest common supertype is AbstractArray.
But if you do that, then people say this is bad because it will mean Diagonal will take O(N^2) rather than O(N).
One could say that the user opted into this, since they used an AbstractMatrix, and so it is on the method author (in this case the rule author) to provide an optimized method if appropriate.
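To make the cost concern concrete, here is a minimal sketch of the kind of generic rule in question (simplified relative to any real ChainRules rule, and assuming real eltypes):

using ChainRulesCore, LinearAlgebra

# Simplified sketch of a generic matmul rule (real eltypes assumed).
function ChainRulesCore.rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
    function times_pullback(Ȳ)
        # For A::Diagonal, Ȳ * B' is a dense O(N^3) matmul producing N^2
        # entries, of which only the N diagonal ones are meaningful.
        return (NoTangent(), Ȳ * B', A' * Ȳ)
    end
    return A * B, times_pullback
end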

But there is a greater problem.
IIRC some operations on FillArrays give the wrong answer, not just the wrong time complexity, if you treat them as generic Arrays.
Mike and Will T had a big argument about it.

However:
If you only define rules on fundamental array types, like Array and maybe StaticArray and GPUArrays, you get the correct time complexity, and you avoid defining things so generally that they break weirder array types.
And the AD will decompose the wrapper arrays correctly and get at what is inside.
FillArrays work correctly if you treat them like structs (as do all wrapper arrays).
@willtebbutt has spent ages thinking about this.
Though we understand the problem, I think we still don't have a great solution.
I am hoping Will will write something that we can put in the docs. (We have a few nice writeups like that on there.)
I think we might actually need to formally introduce the idea of a fundamental array type, maybe as a trait in ChainRules.
Or maybe just add a StaticArrays dep (it's basically a stdlib at this point) and then we can have a union for them. See the sketch below.
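A rough sketch of what such a trait might look like (entirely hypothetical; every name here is invented for illustration):

# Hypothetical trait marking "fundamental" array types, i.e. ones whose
# methods bottom out in ccall and so need hand-written rules.
is_fundamental_array(::Type{<:AbstractArray}) = false
is_fundamental_array(::Type{<:Array}) = true
# and, if the deps were added:
# is_fundamental_array(::Type{<:StaticArrays.StaticArray}) = true
# is_fundamental_array(::Type{<:GPUArrays.AbstractGPUArray}) = true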

Now, defining things only on concrete types seems unidiomatic.
Julia code is normally defined for the general case, and then multiple dispatch is used to provide optimizations for specific cases.
However, AD has two ways to achieve functionality, rules and generation, and we want to ensure that overall the most specific one gets hit.
In Julia the most specific method takes precedence over less specific ones, and the most specific method is the one most customised for that type.
With AD, the most specific functionality is always available: it is to let the AD run, which generates code for the exact input.
So when a rule is applied, it is actually getting in the way of the most specific code.
E.g. a rule defined for AbstractMatrix prevents the AD from generating the more specific functionality for Symmetric{Diagonal{...}}.

When I say wrapper arrays I mean anything that has a parent array (according to the parent function).
Though it really is more general: it is anything which has the method under consideration defined without resorting to ccall.
As long as it doesn't resort to ccall, the AD will be able to generate a pullback for the method.
That generated pullback will call something that we will have an optimised rule for. It might not even be the same function, but it will be something, and so we are solid.
Remember, AD systems generally do generate code that is optimal if they don't error, except where there is specific domain expertise the rule author can apply.
So we only need the rule to catch the erroring case.
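For example, just illustrating the parent criterion:

using LinearAlgebra

A = rand(3, 3)
S = Symmetric(A)
parent(S) === A             # true: Symmetric is a wrapper around A
d = rand(3)
parent(Diagonal(d)) === d   # true: the parent is the vector of diagonal elements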

Julia's specialisation rules do apply to AD rules and to code generated by AD. But the AD doesn't get to generate its code if it hits a rule.
And the generated code would be more specialised than something hitting a rule for AbstractMatrix. This is the case where the AD on a Diagonal would get to break things down according to the primal method definition, and would end up hitting a rule for Vector rather than one for matrices;
whereas the compiler's specialisation of an rrule for AbstractMatrix is not so specialised, and ends up looking at a bunch of zero elements.
The AD would have done the better thing if the rule hadn't been defined,
because AD systems are good at finding derivatives.
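A hypothetical illustration of that, supposing the primal method were sum(D::Diagonal) = sum(D.diag), comparing roughly what AD would derive against a generic rule:

using LinearAlgebra, ChainRulesCore

# Roughly what AD would derive by tracing a primal method like
#     sum(D::Diagonal) = sum(D.diag)
# it bottoms out in the rule for sum(::Vector), giving an O(N) pullback.
function ad_style_sum(D::Diagonal)
    y = sum(D.diag)
    pullback(ȳ) = Tangent{typeof(D)}(diag = fill(ȳ, length(D.diag)))
    return y, pullback
end

# Roughly what a generic rrule for sum(::AbstractMatrix) does instead:
# an O(N^2) pullback, mostly cotangents of structural zeros.
function generic_sum(A::AbstractMatrix)
    y = sum(A)
    pullback(ȳ) = fill(ȳ, size(A))
    return y, pullback
end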

Optimal, as in identical to the code someone would write for this concrete input by hand.
AD doesn't always achieve it, because sometimes there is domain knowledge to apply. But in simple cases, like decomposing a function on wrapper arrays, it does.

As I said: if you have domain knowledge then you can do better.
But generally that domain knowledge can be applied to a "fundamental" array type (like Array, and maybe StaticArray, GPUArray), and it will still end up benefiting the wrapper array types.
And the generated code for the pullback, between the wrapper type and the type that has the domain-knowledge rule, will be optimal (in the sense of being basically identical to what a human would write to do this).

oxinabox added the documentation label on May 4, 2021
mcabbott commented May 4, 2021

I wrote some tidier comments here: JuliaDiff/ChainRules.jl#337 (comment). This takes norm as the particular example, because I think it's helpful to be concrete, and times lots of things, because factors of 10^5 are hard to ignore in the real world. I think the lessons are fairly general, but not totally general, and trying to elucidate what the boundaries are might be useful.

The "big argument" about FillArray is here: FluxML/Zygote.jl#863 and actually quite civil! Tl;dr is that I don't find time-complexity arguments all that compelling, but do think preserving (some) structural constraints is a good idea.

oxinabox commented

This is also discussed in some detail here.
JuliaDiff/ChainRules.jl#232 (comment)


mzgubic commented Jun 15, 2021

plan

Having read the discussions in JuliaDiff/ChainRules.jl#337, #347, and JuliaDiff/ChainRules.jl#232, it seems that:

  1. We want to write abstractly typed rules.
  2. We want a mechanism (projection, rebasis, name tbc) to make sure the output of those abstractly typed rules is the same as the AD would produce. In the *(::Diagonal, ::Matrix) example this would be Tangent{Diagonal}(...) rather than a Matrix (sketched below).
  3. We want a mechanism to manually opt out of these abstractly typed rules. There are two reasons for that: the first is that it is easier to opt out (a one-liner) than to opt in (writing a rule); the other is that the performance penalty from using a generic rule is (generally) smaller than that of AD doing something it is really bad at (like repeated getindex).

Is that a fair conclusion?

It appears that while automatic opt-in/out decisions via hasmethod are attractive, they have some pretty significant issues, and we have decided against them.
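To make point 2 of the plan concrete, this is roughly the projection step for the *(::Diagonal, ::Matrix) example (a sketch; the helper name is hypothetical):

using LinearAlgebra, ChainRulesCore

# A generic rule for * computes a dense cotangent dA = Ȳ * B'. Projecting it
# onto the Diagonal structure recovers what AD-through-the-primal would give.
function project_to_diagonal_tangent(D::Diagonal, dA::AbstractMatrix)
    return Tangent{typeof(D)}(diag = diag(dA))
end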

aside

One idea that has been mentioned but not discussed extensively is the second point in JuliaDiff/ChainRules.jl#232 (comment), i.e. how to prevent abstractly typed rules from overriding what would have been an efficient AD pullback derived from a specialised forward pass.

Could we require the signature of the rrule to match exactly the signature of the forward pass? Note that this can be done either in individual AD systems, or in ChainRules.

Doing this would draw a parallel to normal dispatch. I suppose we win in cases where AD is more efficient at transforming the specialised forward pass than the fallback rrule. And similarly we lose in cases where it is slower (because it is written with loops) or does not work at all.

Overall I still think the plan above is better. But thought I'd bring it up in case people have opinions.

willtebbutt commented

We want a mechanism to manually opt out of these abstractly typed rules. There are two reasons for that: The first reason is that it is easier to opt out (one liner) than to opt in (writing a rule). The other reason is that performance penalty from using a generic rule is (generally) smaller than that of AD doing something it is really bad at (like repeated getindex).

Overall I still think the plan above is better. But thought I'd bring it up in case people have opinions.

I'm still very uneasy about the idea of having to opt-out of rules -- my experience has generally been that, if I've written a specialised method for some type, I want AD to have a crack at it.

We are all in agreement that if you define a type and don't implement a specialised method of a particular function for it, then you want to hit the generic projected fallback. That feels like a good tradeoff.

I'm finding it hard to know how to make progress on the in-between ground though. Maybe we need sketches of code in both cases or something?


mzgubic commented Jun 15, 2021

If I understand correctly, in the case where we define a specialised method for some type, we want to do one of two things:

  1. Prefer AD to work out the pullback. If that is bad, we can always define a rule which does the more optimal thing.

  2. Prefer the abstractly typed rule (status quo). If the abstractly typed rule is bad, we can either define a better rule (already possible), or opt out and run AD on it (needs work).

I don't have enough experience to guess which option is better generally. On the other hand, opt-out seems to require less effort than opt-in, so I have a slight preference for that.

What are some examples in which 1) is better? I thought it would be better for *(::Diagonal, ::Matrix), but it actually turns out that if the fallback rrule for (::typeof(*), A::AbstractVecOrMat{<:CommutativeMulNumber}, B::AbstractVecOrMat{<:CommutativeMulNumber}) is commented out, Zygote actually errors:

julia> n = 5;

julia> d = Diagonal(rand(n));

julia> m = rand(n,n);

julia> gradient(d, m -> sum(d*m), d, m)
ERROR: MethodError: objects of type Diagonal{Float64, Vector{Float64}} are not callable
Use square brackets [] for indexing an Array.
Stacktrace:
 [1] macro expansion
   @ ~/JuliaEnvs/Zygote.jl/src/compiler/interface2.jl:0 [inlined]
 [2] _pullback(::Zygote.Context, ::Diagonal{Float64, Vector{Float64}}, ::var"#5#6", ::Diagonal{Float64, Vector{Float64}}, ::Matrix{Float64})
   @ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface2.jl:9
 [3] _pullback(::Diagonal{Float64, Vector{Float64}}, ::Function, ::Diagonal{Float64, Vector{Float64}}, ::Matrix{Float64})
   @ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface.jl:34
 [4] pullback(::Diagonal{Float64, Vector{Float64}}, ::Function, ::Diagonal{Float64, Vector{Float64}}, ::Vararg{Any, N} where N)
   @ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface.jl:40
 [5] gradient(::Diagonal{Float64, Vector{Float64}}, ::Function, ::Vararg{Any, N} where N)
   @ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface.jl:58
 [6] top-level scope
   @ REPL[8]:1

EDIT:

I think you missed some brackets:

Oh that's embarrassing...

Anyway, it seems that while the rrule scales terribly, the AD pullback is only faster above n ~ 1000.

Using the fallback rule

julia> @btime Zygote.pullback((d,m) -> d*m, d, m) setup=(n=5; d=Diagonal(rand(n)); m=rand(n,n))
  69.041 ns (1 allocation: 288 bytes)

julia> @btime Zygote.pullback((d,m) -> d*m, d, m) setup=(n=100; d=Diagonal(rand(n)); m=rand(n,n))
  10.750 μs (2 allocations: 78.20 KiB)

julia> @btime Zygote.pullback((d,m) -> d*m, d, m) setup=(n=1000; d=Diagonal(rand(n)); m=rand(n,n))
  1.895 ms (2 allocations: 7.63 MiB)

while AD (commenting out the rule) gets

julia> @btime Zygote.pullback((d,m) -> d*m, d, m) setup=(n=5; d=Diagonal(rand(n)); m=rand(n,n))
  52.701 μs (487 allocations: 23.45 KiB)

julia> @btime Zygote.pullback((d,m) -> d*m, d, m) setup=(n=100; d=Diagonal(rand(n)); m=rand(n,n))
  69.860 μs (489 allocations: 179.30 KiB)

julia> @btime Zygote.pullback((d,m) -> d*m, d, m) setup=(n=1000; d=Diagonal(rand(n)); m=rand(n,n))
  2.914 ms (489 allocations: 15.28 MiB)

mcabbott commented Jun 15, 2021

I think you missed some brackets:

julia> Zygote.gradient((d,m) -> sum(d*m), d, m)[1]
5×5 Matrix{Float64}:
 3.24163  1.55452  2.24316  2.41607  2.82723
 3.24163  1.55452  2.24316  2.41607  2.82723
 3.24163  1.55452  2.24316  2.41607  2.82723
 3.24163  1.55452  2.24316  2.41607  2.82723
 3.24163  1.55452  2.24316  2.41607  2.82723

Arguably this is mathematically wrong (as it is nonzero off-diagonal), not just computationally inefficient (n^2 not n).

But this is an easy rule to add, partly because ChainRules depends on LinearAlgebra. The specialised method for the forward pass is this, which is not AD-friendly:

(*)(D::Diagonal, A::AbstractMatrix) =
    lmul!(D, copyto!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A))
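For reference, here is a sketch of what such a specialised rule could look like (real eltypes assumed; not necessarily the exact rule that would be added):

using ChainRulesCore, LinearAlgebra

function ChainRulesCore.rrule(::typeof(*), D::Diagonal, A::AbstractMatrix)
    function diagonal_times_pullback(Ȳ)
        # dD[i] = sum_j Ȳ[i,j] * A[i,j], computed in O(N^2) without
        # forming a dense Ȳ * A'
        dD = Tangent{typeof(D)}(diag = vec(sum(Ȳ .* A, dims = 2)))
        dA = D' * Ȳ   # just a row scaling, no dense matmul
        return (NoTangent(), dD, dA)
    end
    return D * A, diagonal_times_pullback
end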

opt-out seems to require less effort than opt-in

I think the crucial asymmetry is that, if you are defining a specialised method for f on type T and thinking about whether it is AD-friendly, then by definition you have both f and T loaded, and it doesn't seem crazy to ask you to load ChainRulesCore.jl to add a one-line opt-out.

Whereas if I want the generic f rule to apply to my T, I may not have f loaded at all in the package which defines T. Or conversely, if my f(x::AbstractArray) treats every such array as a regular array of numbers, then I may want it to work with many Ts which are not defined anywhere connected to my package.


mzgubic commented Jun 15, 2021

Whereas if I want the generic f rule to apply to my T, I may not have f loaded at all in the package which defines T. Or conversely, if my f(x::AbstractArray) treats every such as a regular array of numbers, then I may want it to work with many Ts which are not defined anywhere connected to my package.

I think we all agree that in this case (where no f(::T) method is defined) we want to hit the fallback rule, rather than let AD work.

As far as I understand, the question is: in the case where both the primal f(::T) and a fallback rule for f are defined, do we prefer the fallback rule or the AD-generated pullback?


mcabbott commented Jun 15, 2021

Ok. So I agree you can imagine some automated rule that says "skip the abstract rule if there is a more specific primal". But I think there are two problems with that idea. One is that there are a great many more-specific primals: methods(*) lists possibly hundreds of special cases, many of which (like the one I pasted in) are not AD-friendly, but most of which are fairly well served by the abstract rule. The other is that the dispatch to more- and less-friendly routines often happens in a later function than the one which has the rule defined; in the norm discussion, it's not the first ::AbstractArray method which gets specialised, but all the things it calls.


mzgubic commented Jun 15, 2021

These are good points.

I can imagine that, for some f and some set of arguments it can be called with, the primal methods split into:

  1. those in which ADing through the primal method errors
  2. those in which ADing through the primal method is slower than using the fallback rule
  3. those in which using a fallback rule is slower than ADing through the primal
  4. those in which using a fallback rule returns a wrong answer

Preferring fallback rules over AD

I suppose an argument for this option is that we want to prefer things working out of the box (even if less efficiently) over AD running into problems and throwing up a stacktrace. Making sure that things are efficient (or even tractable) should be a secondary concern that would require some action: opting out of rules, or writing a custom rule.

In this case:

  • We prefer the fallback rule over AD in case a specialised primal method is defined. (fixes 1 and 2)
  • We need a mechanism for opting out of a fallback rule (return nothing?), to run AD instead (fixes 3)
  • Fix 4 by doing the projection/rebasis inside fallback rules

Opt-outs here are one-liners (see the sketch below).
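For illustration, such an opt-out could be as simple as this, using the return-nothing mechanism floated above (a sketch, relying on AD systems treating a nothing result from rrule as "no rule applies"):

using ChainRulesCore, LinearAlgebra

# A more specific method that returns nothing, shadowing the generic rule
# and telling the AD system to differentiate the primal method instead.
ChainRulesCore.rrule(::typeof(*), D::Diagonal, B::Matrix) = nothing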

Preferring AD over fallback rules

If we decide to prefer ADing over fallback rules, then:

  • We prefer AD over the fallback rule, meaning 3 and 4 are solved
  • We need a mechanism to opt in to using a fallback rule, to solve 1 and 2.

What would this opt-in mechanism look like? One option is to define a more specific rule. This could be quite repetitive and might result in a large number of rules. Some refactoring would be needed for each function, to define some _fallback_rrule_f() which is then called by the specialisations that opt in.

This sounds like it would be more code than opting-out. @willtebbutt did I get this right?


The information that we are missing is what fraction of methods falls into each of categories 1-4, and how important they are. Getting accurate numbers on this is hard; guesstimating them could be possible (but not by me).

Did I miss any arguments in the above?

mcabbott commented

I don't know that there's a useful summary count, but I do think discussing particular examples is good, else it's easy to talk past each other. (Although I'm not sure I want to inflict the norm mega-thread on anyone.)

Fix 4 by doing the projection/rebasis

Going complex -> real, or full -> diagonal, is a projection, not a change of basis. I believe this is correct for many examples, such as the Diagonal * Matrix above. I hope it will be true in general, i.e. that an abstract rule for f(::AbstractMatrix) which is mathematically correct for Matrix will, once projected, always be correct for special matrices. Are there evil counter-examples?
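For instance, both of those maps are linear and idempotent, which is what makes them projections rather than changes of basis (illustrative sketches; the names are invented):

using LinearAlgebra

project_real(dx::Complex) = real(dx)                   # complex -> real
project_diag(dX::AbstractMatrix) = Diagonal(diag(dX))  # full -> diagonal

dX = rand(3, 3)
project_diag(project_diag(dX)) == project_diag(dX)     # true: idempotent, as a projection should be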


willtebbutt commented Jun 15, 2021

This sounds like it would be more code than opting-out. @willtebbutt did I get this right?

I think that's a fair summary.

I'm becoming more convinced by the idea that making the opt-out mechanism work, provided that we deal with the projections / representational changes for the sake of correctness properly, is the way forward. I don't particularly like it, but I think I agree it's the lesser of two evils.

Are there evil counter-examples?

@wesselb and I have been discussing this a bit, and I suspect that you can side-step this issue by being careful about how you define the tangent space of any given variable. Taking the Diagonal{Float64, Vector{Float64}} example, defining the tangent space to be the Diagonal{Float64, Vector{Float64}} matrices (of the same size) deals with this issue. I think it follows from this, and the linearity of cotangents, that you can just drop the off-diagonal elements.
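A quick numerical check of that last claim (dropping the off-diagonal elements), using the fact that a cotangent only ever gets paired with tangents living in the Diagonal tangent space:

using LinearAlgebra

dX = rand(3, 3)          # dense cotangent, e.g. from a generic AbstractMatrix rule
dD = Diagonal(rand(3))   # a tangent living in the Diagonal tangent space

# Pairings against in-subspace tangents cannot see the off-diagonal part:
dot(dX, dD) ≈ dot(Diagonal(diag(dX)), dD)   # true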
