
Support chunked frule #92

Open
oxinabox opened this issue Jan 12, 2020 · 20 comments
Labels
design: Requires some design before changes are made
forward-mode: Related to use of ChainRules for ForwardMode AD

Comments

@oxinabox
Member

oxinabox commented Jan 12, 2020

From #90 it seems that @YingboMa wants frule to be callable on a Vector of sensitivities for the same primal value,
and to get a vector of sensitivities back,
but without broadcasting? (Presumably because broadcasting would also recompute the forward primal.)

I don't properly understand this yet.
So this thread is for @YingboMa, @shashi, or @ChrisRackauckas to explain it.

This might need another redesign of frule, similar to solving #74.
Maybe we went too far there, since broadcasting the pushforward would presumably solve that case.

@ChrisRackauckas
Member

It's just like reverse mode. The partials can be anything, like an array. A nice way to handle chunking is to have the partials be a matrix; then a lot of operations will naturally push forward all the Jv products simultaneously, where the v's are the columns of the seed matrix.

I don't get where broadcasting comes in. I didn't know it was broadcasting, but if it is doing that automatically then it shouldn't: exploiting the linear algebra is crucial to doing this fast, just like in reverse mode.
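
To spell that out with a toy example (a hedged sketch, not ChainRules API; A, f, and seeds are made up for illustration): for a linear map, applying the pushforward to a matrix of seeds computes every Jv product in one matrix-matrix multiply.

using LinearAlgebra

A = rand(4, 3)
f(x) = A * x                       # pushforward of f at any x is dx -> A * dx

x = rand(3)
seeds = Matrix{Float64}(I, 3, 3)   # columns are the seed vectors v

Jvs = A * seeds                    # all J*v products in one BLAS call
@assert Jvs ≈ hcat([A * seeds[:, i] for i in 1:3]...)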

@oxinabox
Member Author

See broadcast:
#93

@oxinabox
Member Author

Idea is that people want to concurrently propagate multiple sensitivity values for the same primal value.
Makes sense, since one can replace the directional derivative vector v in a vjp / jvp with a matrix whose columns are each a directional derivative.

And if you do this with a basis matrix, you get the Jacobian in that basis.
For the standard basis I that is just the explicit Jacobian, rather than an implicit thing.
It gets interesting if you use some other basis, though.
Since a sparse Jacobian lives in a subspace of the same space as the dense Jacobian,
it is expressible with a smaller basis.
The whole sparse-diff business is about detecting that sparsity and working out the basis one needs.
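
For instance (a toy sketch with a made-up f): if the Jacobian is known to be diagonal, a single seed column of ones recovers every nonzero entry, so the basis needed is much smaller than the standard one.

f(x) = x .^ 2                    # Jacobian is Diagonal(2 .* x)

x = [1.0, 2.0, 3.0]
seed = ones(3)                   # one column instead of the 3-column identity
jvp = 2 .* x .* seed             # pushforward of f at x applied to the seed
@assert jvp == [2.0, 4.0, 6.0]   # exactly the Jacobian's diagonal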


Now, what does this mean in practical terms?
For some function y = f(x)
with sensitivities dx and dy:

If those are just one derivative each (the current case we think of),
then the shape information of dy matches y.

If they are collections of derivatives, i.e. dx = [dx1 dx2 dx3 ...],
then dy has outer shape matching the collection dx,
and inner shape still matching y.

Right now this more or less works in practice
if y and x are scalars or vectors,
and thus their differential types are also scalars or vectors.

Once they are not, it gets harder.
And maybe we need to think about how to represent collections of differentials.

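To make the shape convention concrete (sizes made up for illustration):

x   = rand(3)         # primal, size (3,)
dx  = rand(3)         # one differential: shape matches x
dxs = rand(3, 4)      # chunk of 4 differentials: inner shape (3,), outer size 4

# for y = f(x) with y of size (2,), a chunked dy would then have size (2, 4):
# inner shape matching y, outer shape matching the chunk of dx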

@willtebbutt
Member

willtebbutt commented Jan 13, 2020

It seems to me that it should be fairly straightforward to generalise our existing framework to account for this.

I completely understand what @ChrisRackauckas is getting at, but it would be helpful if you could specify the extent of your interface as it stands:

  • how are chunks defined for differentials of a primal::Real?
  • how are chunks defined for differentials of a primal::Vector{<:Real}? (I'm assuming Matrix{<:Real}?)
  • how are chunks defined for a primal::Matrix{<:Real}, and for more general primal::Arrays?
  • are there any other types of differentials that you support?

We'll need to come up with some interface that plays nicely with all of the above, and with our Zero and Composite as well. Zero should be quite straightforward to extend, and the way in which we choose to extend Composite will presumably be constrained by the need to represent e.g. a chunk of differentials of a primal::Vector{<:Real} with a Matrix{<:Real}.

One other question, @ChrisRackauckas @YingboMa @shashi: assuming that you do indeed represent a chunk of differentials w.r.t. a primal::Vector{<:Real} as a Matrix{<:Real}, how do you discern between a Matrix{<:Real} that represents this and one which is a single differential w.r.t. a Matrix{<:Real} primal? If you've not explicitly addressed this problem, could you explain why it hasn't been problematic for you so far?

As with #91 , I'm keen to address this quickly so that we can all press forwards.

@oxinabox
Member Author

The main reason to want this is to compute Jacobians, or nontrivial parts of Jacobians,
i.e. propagating directional derivatives (columns) forwards.

One thing this buys is that rather than a loop of Matrix*Vector operations,
one hits the much faster Matrix*Matrix operations, with good cache locality etc.

Furthermore, because forward mode has a fused pushforward (#74),
one can't just compute the pushforward and call it in a loop with different seeds
without also redoing the work that computes the primal.

Whereas this is less important with the separate pullback in reverse mode, because there you don't redo the primal computation.
Still would be good though.
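
A minimal sketch of why (a hand-rolled rule with hypothetical names, not the real frule): a fused rule returns the primal and the tangent together, so looping it over k seeds evaluates the primal k times.

const PRIMAL_CALLS = Ref(0)

function fused_frule_demo(x, dx)
    PRIMAL_CALLS[] += 1              # count primal evaluations
    y = sum(abs, x)                  # the primal work
    dy = sum(sign.(x) .* dx)         # its pushforward
    return y, dy
end

x = rand(3)
seeds = [rand(3) for _ in 1:4]
tangents = [fused_frule_demo(x, dx)[2] for dx in seeds]
PRIMAL_CALLS[]                       # == 4: the primal was redone once per seed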

@YingboMa
Member

When chunking, we have ndims(partials) == ndims(primal) + 1.

We don't force Vector or Matrix; we only use <:AbstractArray. I am wondering if we could have

Base.ndims(::Zero) = 0
Base.ndims(::Type{Zero}) = 0  # the generated function below sees types, not instances
# inferable `argmax(map(ndims, partials))`
@generated _argmax_ndims(partials) = :(partials[$(argmax(map(ndims, partials.parameters)))])
ChainRules.extern(::Zero, partials) = zero(_argmax_ndims(partials))

which gives

julia> extern(Zero(), (1, Zero(), [1 2; 3 4], [1, 2]))
2×2 Array{Int64,2}:
 0  0
 0  0

so Zero can still be strong.
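
As a quick illustration of the ndims(partials) == ndims(primal) + 1 convention (sizes made up):

x  = rand(3)         # ndims(x)  == 1
dx = rand(3, 4)      # ndims(dx) == 2: a chunk of 4 vector-shaped seeds

X  = rand(2, 2)      # ndims(X)  == 2
dX = rand(2, 2, 4)   # ndims(dX) == 3: a chunk of 4 matrix-shaped seeds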

@oxinabox
Member Author

Can we chunk some sensitivities and not others?

@YingboMa
Member

YingboMa commented Jan 15, 2020

FD2 doesn't do forward-mode AD on multiple arguments. It is possible to do, but be aware that when two chunks meet, the larger one wins.

@YingboMa
Member

Actually, I don't think the above extern solves the actual problem with Zero. What do we do if we get multiple partials from frule? We need to know the actual computation that happened to Zero to know the correct size for partials.

@willtebbutt
Member

When chunking, we have ndims(partials) == ndims(primal) + 1.

Makes sense.

We don't force Vector or Matrix; we only use <:AbstractArray. I am wondering if we could have

Also makes sense. Could you provide an example of where this behaviour is currently being exploited to accelerate inference? Do we have any frules that are already supporting this pattern?

Given our intention to drop extern as a thing, I was thinking something more along the lines of

struct ChunkedZero <: AbstractDifferential
    chunk_size::Int
end

Coupled with the primal, I'm pretty sure this provides enough information to know what to do; i.e. you've retained the chunk-size information, which is the only extra bit of information you need AFAICT.

As regards chunking for structured stuff, I don't know if there's really anything that we need to do -- presumably everything just works out recursively...
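
For example, a hypothetical materialisation against the primal (the materialise name and the trailing chunk dimension are assumptions, not settled API; it reuses the ChunkedZero struct above):

# given the primal, a ChunkedZero carries enough information to produce
# a dense zero chunk with one extra trailing dimension
materialise(dz::ChunkedZero, primal) =
    zeros(eltype(primal), size(primal)..., dz.chunk_size)

materialise(ChunkedZero(4), rand(3))   # 3×4 zero matrix: 4 zero seeds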

@YingboMa
Member

Could you provide an example of where this behaviour is currently being exploited to accelerate inference?

accelerate inference?

Do we have any frules that are already supporting this pattern?

julia> using LinearAlgebra

julia> x = rand(3);

julia> dx = rand(3, 3);

julia> frule(BLAS.asum, x, Zero(), dx)
(0.7486352425663505, 5.584182515140222)

julia> sum(i->frule(BLAS.asum, x, Zero(), dx[:, i])[2], 1:3)
5.584182515140222

I was thinking something more along the lines of ...

Yes, I agree, ChunkedZero is a great idea.

@oxinabox
Member Author

Would we end up with x*Zero() = ChunkedZero(ndims(x))

@willtebbutt
Member

willtebbutt commented Jan 16, 2020

accelerate inference?

Sorry, had my probabilistic programming hat on by accident. I meant examples where chunked computations are done at the minute, and are faster than naively iterating over each e.g. column vector of differentials. The asum example you've given is sufficient -- I guess we just lucked out that that implementation happens to do the right thing already, because of broadcasting.

Would we end up with x*Zero() = ChunkedZero(ndims(x))

Sorry @oxinabox could you expand on this? I'm not quite sure what is meant.

@YingboMa
Member

ndims(x) isn't the chunk size though. Do you mean axes(x, dims)? (so that it doesn't even assume 1-based indexing.)

@oxinabox
Member Author

A question is how we support chunked frule in the case of the author only writing a non-chunked frule.
One option is to just demand that they always write the chunked version.

If we don't do that, another option is to have:

  • a non-fused frule that returns a pushforward
  • a fused frule that does not, but falls back to calling the non-fused version if not defined
  • a chunked frule which, if not defined, first tries to get a pushforward from the non-fused frule and call that in a loop, but if that is not defined either, falls back to looping over the fused frule (a rough sketch follows below)

This seems intense though.
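
Here is a hedged sketch of that dispatch chain (all names hypothetical, with a trivial stand-in rule for sum so it runs):

unfused_frule(f, x) = nothing                             # generic: no rule known
unfused_frule(::typeof(sum), x) = (sum(x), dx -> sum(dx)) # primal + reusable pushforward

fused_frule(f, x, dx) = (f(x), sum(dx))                   # stand-in fused rule for the demo

function chunked_frule(f, x, dxs::AbstractMatrix)
    r = unfused_frule(f, x)
    if r !== nothing
        y, pushforward = r                                # primal computed once
        return y, [pushforward(dxs[:, i]) for i in axes(dxs, 2)]
    end
    cols = [fused_frule(f, x, dxs[:, i]) for i in axes(dxs, 2)]  # primal redone per seed
    return cols[1][1], [c[2] for c in cols]
end

chunked_frule(sum, rand(3), rand(3, 4))   # (primal, 4 tangents)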

This is a motivation for having both fused and non-fused frule,
though admittedly not the only one; one might potentially also want that for #102.

@oxinabox
Member Author

I think we just don't want to support non-chunked frule.
Or rather, it's not our problem.
If the user type-constrains their rule so it doesn't accept chunks, it can fall back to the AD,
or the user's code can be corrected.
We should document this once we have the story straight.

@oxinabox
Member Author

oxinabox commented Jan 19, 2020

One challenge is supporting chunking on functions that don't use their inputs, e.g. zero-arg constructors.
I think they can just encode their chunkiness by passing an array of dselfs.

This is needed to support forward-mode mutation.
It's different from the Zero() case (though similar), in that it's not a case of "you don't need any":
they need as much space as they need.
Though potentially we can just push this problem up to the AD.

@ettersi
Contributor

ettersi commented Jul 14, 2020

Has there been a conclusion on this?

@oxinabox
Member Author

Not really.
We still only support scalar primals + vector sensitivities for chunked forward mode.

@ettersi
Contributor

ettersi commented Jul 14, 2020

Ok. Here are my two cents.

From the discussion in JuliaDiff/ChainRules.jl#232, it is becoming increasingly clear to me that it would probably be good if every Julia type X had a (documented) associated differential type DX. Some examples:

  • X <: Any -> DX == Composite{X}
  • X <: Number -> DX == float(X)
  • X <: AbstractArray -> DX == AbstractArray
  • X == Diagonal -> DX <: Diagonal

This scheme can then straightforwardly be extended to include chunking: every type should also specify a chunked differential type CDX, and guarantee that frule and rrule handle these types correctly. Continuing the above examples:

  • X <: Any -> CDX == VectorOrTuple{Composite{X}}
  • X <: Number -> CDX == VectorOrTuple{float(X)}
  • X <: AbstractArray{T,N} -> CDX == AbstractArray{T,N+1}
  • X == Diagonal -> CDX == [Dunno, should be worked out]

I further agree with @oxinabox that frule and rrule should implement only the chunked interface to avoid implementing each rule twice.
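
A sketch of what that mapping could look like as trait functions (difftype and chunked_difftype are made-up names, not ChainRulesCore API):

# single differential type for a primal type
difftype(::Type{T}) where {T<:Number}              = float(T)
difftype(::Type{<:AbstractArray{T,N}}) where {T,N} = AbstractArray{float(T),N}

# chunked differential type: a collection for scalars, one extra
# dimension for arrays
chunked_difftype(::Type{T}) where {T<:Number}              = Vector{float(T)}
chunked_difftype(::Type{<:AbstractArray{T,N}}) where {T,N} = AbstractArray{float(T),N + 1}

difftype(Float32)                # Float32
chunked_difftype(Matrix{Int})    # AbstractArray{Float64,3}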

@mcabbott added the design label on Jul 19, 2022