Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: projector implementation (returning a closure) #382

Closed
wants to merge 15 commits into from

Conversation

mzgubic
Copy link
Member

@mzgubic mzgubic commented Jun 24, 2021

Alternative to #380.

Some observations (comments welcome):

  • while there is some repetition, this was less fiddly to write and test than WIP: project implementation #380. Hopefully it is also easier to read.
  • returning the closure means we don't have to carry the primal type in the pullback anymore (apart from in the fallback case). cc @CarloLucibello
  • we can compose these, see e.g. projector(::Diagonal), where we create a projV to project the vector representing the diagonal. cc @mcabbott

I imagine we also want to add some extra requirements to ChainRulesTestUtils, to make sure that the rules always return a correct differential. Something like _is_appropriate(primal, tangent) where the unthunk(tangent) must either be the same type as the primal, a Tangent, or an AbstractZero.

@codecov-commenter
Copy link

codecov-commenter commented Jun 24, 2021

Codecov Report

Merging #382 (7801e19) into master (4b1a2f6) will increase coverage by 0.65%.
The diff coverage is 97.61%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #382      +/-   ##
==========================================
+ Coverage   89.12%   89.78%   +0.65%     
==========================================
  Files          14       15       +1     
  Lines         561      607      +46     
==========================================
+ Hits          500      545      +45     
- Misses         61       62       +1     
Impacted Files Coverage Δ
src/ChainRulesCore.jl 100.00% <ø> (ø)
src/projection.jl 97.43% <97.43%> (ø)
src/differentials/abstract_zero.jl 87.50% <100.00%> (+0.83%) ⬆️
src/differentials/thunks.jl 95.00% <100.00%> (+0.31%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 4b1a2f6...7801e19. Read the comment docs.

Copy link
Member

@mcabbott mcabbott left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems simpler than I thought, maybe something like this is the way to go. I left some comments...

Comment on lines +52 to +54
# Tangent
function projector(::Type{<:Tangent}, x::T) where {T}
project(dx) = Tangent{T}(; ((k, getproperty(dx, k)) for k in fieldnames(T))...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's still not clear to me what's going to call this. Clearly we will not have x::Tangent in the forward pass. So this thing is perhaps trying to serve several functions, and perhaps they can be clarified.

Comment on lines +29 to +32
project(dx::AbstractZero) = zero(x)
project(dx::AbstractThunk) = project(unthunk(dx))
return project
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if there should be some struct Project which is returned, in part to avoid writing these out every time.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you clarify how this would work?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on lines +26 to +28
function projector(::Type{T}, x::T) where {T<:Real}
project(dx::Real) = T(dx)
project(dx::Number) = T(real(dx)) # to avoid InexactError
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is too tight, as projector(2)(3.5) is going to be an InexactError right? As is projector(false)(1.5).

And more generally, what if (say) I want to put dual numbers into the pullback? My impression is that that should be allowed. Which is what led me to think that only known problems should be projected out, like dx::Complex when x::Real, or anything when x::Bool. But it would be nice if the door were open for packages to add to the list of "things which get projected like Complex -> Real".

Copy link
Member Author

@mzgubic mzgubic Jun 24, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that sounds like a relatively serious downside

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems to have made it into the tagged version:

julia> ProjectTo(1)(2.5)
ERROR: InexactError: Int64(2.5)

(jl_5kFIPa) pkg> st ChainRulesCore
      Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_5kFIPa/Project.toml`
  [d360d2e6] ChainRulesCore v0.10.11

function projector(::Type{<:Symmetric{<:Any, M}}, x::Symmetric) where {M}
projM = projector(M, parent(x))
uplo = Symbol(x.uplo)
project(dx::AbstractMatrix) = Symmetric(projM(dx), uplo)
Copy link
Member

@mcabbott mcabbott Jun 24, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is right, you need to symmetrise, not merely to apply the wrapper.

There's a fairly efficient one here:
https://github.com/FluxML/Zygote.jl/pull/965/files#diff-9bc4a61f220da7bc58a4009fe88887b5b584b3d6139c68b0e13cbdbcd21f7289R48

Comment on lines +41 to +45
function projector(::Type{Array{T, N}}, x::Array{T, N}) where {T, N}
sizex = size(x)
projT = projector(zero(T))
project(dx::Array{T, N}) = dx # identity
project(dx::AbstractArray) = project(collect(dx)) # from Diagonal
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I also wonder if this is the right behaviour. Maybe the ability to reproduce a similar dense array is desirable sometimes, but making the default projector materialise when it doesn't have to seems odd --- shouldn't we preserve Diagonal or Fill backwards as many steps as possible, by default?

But again maybe this is trying to serve multiple purposes which perhaps can be clarified.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe there ought to be abstract types involved, something like:

projector(x::Real) = projector(Real, x)
projector(x::Bool) = projector(Nothing, x)

projector(x::AbstractArray{<:Real}) = projector(AbstractArray{Real}, x)
projector(x::AbstractArray) = projector(AbstractArray, x)

where projector(AbstractArray, x)(dx) may reshape but won't do more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the method which specifically wants the output to be a dense array, i.e. where x is a Matrix in projector(x) call. When x is a Diagonal, a different projector method would be hit.

I couldn't quite see how to generalise the method for an arbitrary AbstractArray (see how Diagonal and Symmetric) cases are different. My plan was to just add the dispatch for any type that we need to make ChainRules rules work.

@mzgubic
Copy link
Member Author

mzgubic commented Jul 6, 2021

closed in favour of #385

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants