Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: projector implementation (returning a closure) #382

Closed
wants to merge 15 commits into from
3 changes: 2 additions & 1 deletion src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMod
export frule_via_ad, rrule_via_ad
# definition helper macros
export @non_differentiable, @scalar_rule, @thunk, @not_implemented
export canonicalize, extern, unthunk # differential operations
export canonicalize, extern, unthunk, projector # differential operations
export add!! # gradient accumulation operations
# differentials
export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk
Expand All @@ -26,6 +26,7 @@ include("differentials/notimplemented.jl")

include("differential_arithmetic.jl")
include("accumulation.jl")
include("projection.jl")

include("config.jl")
include("rules.jl")
Expand Down
1 change: 1 addition & 0 deletions src/differentials/abstract_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Base.:/(z::AbstractZero, ::Any) = z
Base.convert(::Type{T}, x::AbstractZero) where T <: Number = zero(T)

# Indexing a zero tangent with any key returns that same zero.
function Base.getindex(z::AbstractZero, k)
    return z
end

# Reading any property of a zero tangent returns that same zero.
function Base.getproperty(z::AbstractZero, f::Symbol)
    return z
end

"""
ZeroTangent() <: AbstractZero
Expand Down
4 changes: 4 additions & 0 deletions src/differentials/thunks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ end

# Evaluate the thunk by calling the function stored in its `f` field.
@inline function unthunk(x::Thunk)
    return x.f()
end

# `:f` is read directly as a field of the thunk; every other property name is
# looked up on the unthunked value instead.
function Base.getproperty(a::Thunk, f::Symbol)
    if f === :f
        return getfield(a, f)
    end
    return getproperty(unthunk(a), f)
end

function Base.show(io::IO, x::Thunk)
    return print(io, "Thunk($(repr(x.f)))")
end

"""
Expand All @@ -209,6 +211,8 @@ end

# Unthunking an `InplaceableThunk` unthunks its `val` field.
function unthunk(x::InplaceableThunk)
    return unthunk(x.val)
end

# `:val` and `:add!` are read directly as fields; every other property name is
# looked up on the unthunked value instead.
function Base.getproperty(a::InplaceableThunk, f::Symbol)
    if f === :val || f === :add!
        return getfield(a, f)
    end
    return getproperty(unthunk(a), f)
end

Base.show(io::IO, x::InplaceableThunk) =
    print(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))")
77 changes: 77 additions & 0 deletions src/projection.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
using LinearAlgebra: Diagonal, diag

"""
projector([T::Type], x)

Returns a `project(dx)` closure which maps `dx` onto type `T`, such that it is the
same size as `x`. If `T` is not provided, it is assumed to be the type of `x`.

It's necessary to have `x` to ensure that it's possible to project e.g. `AbstractZero`s
onto `Array`s -- this wouldn't be possible with type information alone because neither
`AbstractZero`s nor `T` know what size of `Array` to produce.
"""
function projector end

# When no target type is given, project onto the type of `x` itself.
function projector(x)
    return projector(typeof(x), x)
end

# Fallback for arbitrary types (e.g. structs).
#
# The returned closure passes a `dx` that already has type `T` through unchanged,
# replaces an `AbstractZero` with `zero(x)`, and unthunks an `AbstractThunk` before
# projecting the unthunked value.
function projector(::Type{T}, x::T) where {T}
    function project(dx::T)
        return dx
    end
    function project(dx::AbstractZero)
        return zero(x)
    end
    function project(dx::AbstractThunk)
        return project(unthunk(dx))
    end
    return project
end

# Numbers
function projector(::Type{T}, x::T) where {T<:Real}
project(dx::Real) = T(dx)
project(dx::Number) = T(real(dx)) # to avoid InexactError
Comment on lines +26 to +28
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is too tight, as projector(2)(3.5) is going to be an InexactError right? As is projector(false)(1.5).

And more generally, what if (say) I want to put dual numbers into the pullback? My impression is that that should be allowed. Which is what led me to think that only known problems should be projected out, like dx::Complex when x::Real, or anything when x::Bool. But it would be nice if the door were open for packages to add to the list of "things which get projected like Complex -> Real".

Copy link
Member Author

@mzgubic mzgubic Jun 24, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that sounds like a relatively serious downside

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems to have made it into the tagged version:

julia> ProjectTo(1)(2.5)
ERROR: InexactError: Int64(2.5)

(jl_5kFIPa) pkg> st ChainRulesCore
      Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_5kFIPa/Project.toml`
  [d360d2e6] ChainRulesCore v0.10.11

project(dx::AbstractZero) = zero(x)
project(dx::AbstractThunk) = project(unthunk(dx))
return project
end
Comment on lines +29 to +32
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if there should be some struct Project which is returned, in part to avoid writing these out every time.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you clarify how this would work?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# Projector for number types: a `Number` tangent is converted by calling the
# constructor `T`, an `AbstractZero` becomes `zero(x)`, and an `AbstractThunk` is
# unthunked and then projected.
function projector(::Type{T}, x::T) where {T<:Number}
    function project(dx::Number)
        return T(dx)
    end
    function project(dx::AbstractZero)
        return zero(x)
    end
    function project(dx::AbstractThunk)
        return project(unthunk(dx))
    end
    return project
end

# Arrays
function projector(::Type{Array{T, N}}, x::Array{T, N}) where {T, N}
sizex = size(x)
projT = projector(zero(T))
project(dx::Array{T, N}) = dx # identity
project(dx::AbstractArray) = project(collect(dx)) # from Diagonal
Comment on lines +41 to +45
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I also wonder if this is the right behaviour. Maybe the ability to reproduce a similar dense array is desirable sometimes, but making the default projector materialise when it doesn't have to seems odd --- shouldn't we preserve Diagonal or Fill backwards as many steps as possible, by default?

But again maybe this is trying to serve multiple purposes which perhaps can be clarified.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe there ought to be abstract types involved, something like:

projector(x::Real) = projector(Real, x)
projector(x::Bool) = projector(Nothing, x)

projector(x::AbstractArray{<:Real}) = projector(AbstractArray{Real}, x)
projector(x::AbstractArray) = projector(AbstractArray, x)

where projector(AbstractArray, x)(dx) may reshape but won't do more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the method which specifically wants the output to be a dense array, i.e. where x is a Matrix in projector(x) call. When x is a Diagonal, a different projector method would be hit.

I couldn't quite see how to generalise the method for an arbitrary AbstractArray (see how the Diagonal and Symmetric cases are different). My plan was to just add the dispatch for any type that we need to make ChainRules rules work.

project(dx::Array) = projT.(dx) # from different element type
project(dx::AbstractZero) = zeros(T, sizex...)
project(dx::AbstractThunk) = project(unthunk(dx))
return project
end

# Tangent
function projector(::Type{<:Tangent}, x::T) where {T}
project(dx) = Tangent{T}(; ((k, getproperty(dx, k)) for k in fieldnames(T))...)
Comment on lines +52 to +54
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's still not clear to me what's going to call this. Clearly we will not have x::Tangent in the forward pass. So this thing is perhaps trying to serve several functions, and perhaps they can be clarified.

return project
end

# Diagonal
#
# Builds a projector for the diagonal's storage type `V` and uses it on whatever
# diagonal can be read from `dx`: `diag(dx)` for a matrix, `dx.diag` for a `Tangent`,
# and the zero itself for an `AbstractZero`. The projected vector is then wrapped
# in `Diagonal`. Thunks are unthunked and projected again.
function projector(::Type{<:Diagonal{<:Any, V}}, x::Diagonal) where {V}
    diagprojector = projector(V, diag(x))
    function project(dx::AbstractMatrix)
        return Diagonal(diagprojector(diag(dx)))
    end
    function project(dx::Tangent)
        return Diagonal(diagprojector(dx.diag))
    end
    function project(dx::AbstractZero)
        return Diagonal(diagprojector(dx))
    end
    function project(dx::AbstractThunk)
        return project(unthunk(dx))
    end
    return project
end

# Symmetric
function projector(::Type{<:Symmetric{<:Any, M}}, x::Symmetric) where {M}
projM = projector(M, parent(x))
uplo = Symbol(x.uplo)
project(dx::AbstractMatrix) = Symmetric(projM(dx), uplo)
Copy link
Member

@mcabbott mcabbott Jun 24, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is right, you need to symmetrise, not merely to apply the wrapper.

There's a fairly efficient one here:
https://github.com/FluxML/Zygote.jl/pull/965/files#diff-9bc4a61f220da7bc58a4009fe88887b5b584b3d6139c68b0e13cbdbcd21f7289R48

project(dx::Tangent) = Symmetric(projM(dx.data), uplo)
project(dx::AbstractZero) = Symmetric(projM(dx), uplo)
project(dx::AbstractThunk) = project(unthunk(dx))
return project
end
114 changes: 114 additions & 0 deletions test/projection.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Minimal single-field struct used by the tests below as an example of a
# user-defined (non-number, non-array) type.
struct Fred
    a::Float64
end

# `zero` is provided for both the type and its instances; the instance method
# defers to the type method.
Base.zero(::Type{Fred}) = Fred(zero(Float64))
Base.zero(::Fred) = zero(Fred)

# Tests for `projector`: each testset builds a projector from a primal value (and
# sometimes an explicit target type) and checks what it returns for plain values,
# `ZeroTangent()`, thunks, and values of a different but related type.
@testset "projection" begin
# `Fred` is a plain struct, so `projector(Fred(...))` uses the generic fallback.
@testset "fallback" begin
@test Fred(1.2) == projector(Fred(3.2))(Fred(1.2))
@test Fred(0.0) == projector(Fred(3.2))(ZeroTangent())
@test Fred(3.2) == projector(Fred(-0.2))(@thunk(Fred(3.2)))
end

@testset "to Real" begin
# Float64
@test 3.2 == projector(1.0)(3.2)
@test 0.0 == projector(1.1)(ZeroTangent())
@test 3.2 == projector(1.0)(@thunk(3.2))

# down
@test 3.2 == projector(1.0)(3.2 + 3im)
@test 3.2f0 == projector(1.0f0)(3.2)
@test 3.2f0 == projector(1.0f0)(3.2 - 3im)

# up
@test 2.0 == projector(2.0)(2.0f0)
end

@testset "to Number" begin
# Complex
@test 2.0 + 0.0im == projector(1.0im)(2.0 + 0.0im)

# down
@test 2.0 + 0.0im == projector(1.0im)(2.0)
@test 0.0 + 0.0im == projector(1.0im)(ZeroTangent())
@test 0.0 + 0.0im == projector(1.0im)(@thunk(ZeroTangent()))

# up
@test 2.0 + 0.0im == projector(2.0 + 1.0im)(2.0)
end

@testset "to Array" begin
# to an array of numbers
@test [1.0 2.0; 3.0 4.0] == projector(zeros(2, 2))([1.0 2.0; 3.0 4.0])
@test zeros(2, 2) == projector([1.0 2; 3 4])(ZeroTangent())
@test zeros(2) == projector([1.0, 2.0])(@thunk(ZeroTangent()))
@test [1.0f0 2; 3 4] == projector(zeros(Float32, 2, 2))([1.0 2; 3 4])
@test [1.0 0; 0 4] == projector(zeros(2, 2))(Diagonal([1.0, 4]))

# to an array of structs
@test [Fred(0.0), Fred(0.0)] == projector([Fred(0.0), Fred(0.0)])([Fred(0.0), Fred(0.0)])
@test [Fred(0.0), Fred(0.0)] == projector([Fred(0.0), Fred(0.0)])([ZeroTangent(), ZeroTangent()])
@test [Fred(0.0), Fred(3.2)] == projector([Fred(0.0), Fred(0.0)])([ZeroTangent(), @thunk(Fred(3.2))])
@test [Fred(0.0), Fred(0.0)] == projector([Fred(1.0), Fred(2.0)])(ZeroTangent())
@test [Fred(0.0), Fred(0.0)] == projector([Fred(0.0), Fred(0.0)])(@thunk(ZeroTangent()))
diagfreds = [Fred(1.0) Fred(0.0); Fred(0.0) Fred(4.0)]
@test diagfreds == projector(diagfreds)(Diagonal([Fred(1.0), Fred(4.0)]))
end

# Here the target type `Tangent` is passed explicitly instead of being taken from
# the type of the primal value.
@testset "to Tangent" begin
@test Tangent{Fred}(; a = 3.2,) == projector(Tangent, Fred(3.2))(Fred(3.2))
@test Tangent{Fred}(; a = ZeroTangent(),) == projector(Tangent, Fred(3.2))(ZeroTangent())
@test Tangent{Fred}(; a = ZeroTangent(),) == projector(Tangent, Fred(3.2))(@thunk(ZeroTangent()))

@test projector(Tangent, Diagonal(zeros(2)))(Diagonal([1.0f0, 2.0f0])) isa Tangent
@test projector(Tangent, Diagonal(zeros(2)))(ZeroTangent()) isa Tangent
@test projector(Tangent, Diagonal(zeros(2)))(@thunk(ZeroTangent())) isa Tangent
end

@testset "to Diagonal" begin
# Zero-valued diagonals of several element types, used as both primal and expected value.
d_F64 = Diagonal([0.0, 0.0])
d_F32 = Diagonal([0.0f0, 0.0f0])
d_C64 = Diagonal([0.0 + 0im, 0.0])
d_Fred = Diagonal([Fred(0.0), Fred(0.0)])

# from Matrix
@test d_F64 == projector(d_F64)(zeros(2, 2))
@test d_F64 == projector(d_F64)(zeros(Float32, 2, 2))
@test d_F64 == projector(d_F64)(zeros(ComplexF64, 2, 2))

# from Diagonal of Numbers
@test d_F64 == projector(d_F64)(d_F64)
@test d_F64 == projector(d_F64)(d_F32)
@test d_F64 == projector(d_F64)(d_C64)

# from Diagonal of AbstractTangent
@test d_F64 == projector(d_F64)(ZeroTangent())
@test d_C64 == projector(d_C64)(ZeroTangent())
@test d_F64 == projector(d_F64)(@thunk(ZeroTangent()))
@test d_F64 == projector(d_F64)(Diagonal([ZeroTangent(), ZeroTangent()]))
@test d_F64 == projector(d_F64)(Diagonal([ZeroTangent(), @thunk(ZeroTangent())]))

# from Diagonal of structs
@test d_Fred == projector(d_Fred)(ZeroTangent())
@test d_Fred == projector(d_Fred)(@thunk(ZeroTangent()))
@test d_Fred == projector(d_Fred)(Diagonal([ZeroTangent(), ZeroTangent()]))

# from Tangent
@test d_F64 == projector(d_F64)(Tangent{Diagonal}(;diag=[0.0, 0.0]))
@test d_F64 == projector(d_F64)(Tangent{Diagonal}(;diag=[0.0f0, 0.0f0]))
@test d_F64 == projector(d_F64)(Tangent{Diagonal}(;diag=[ZeroTangent(), @thunk(ZeroTangent())]))
end

@testset "to Symmetric" begin
# NOTE(review): `data` is not itself symmetric, and the expected values are built by
# wrapping the same `data` in `Symmetric`. These tests therefore check that the
# wrapper and `uplo` are applied, not that a non-symmetric input gets symmetrised.
data = [1.0 2; 3 4]
@test Symmetric(data) == projector(Symmetric(data))(data)
@test Symmetric(data, :L) == projector(Symmetric(data, :L))(data)
@test Symmetric(Diagonal(data)) == projector(Symmetric(data))(Diagonal(diag(data)))

@test Symmetric(zeros(2, 2)) == projector(Symmetric(data))(ZeroTangent())
@test Symmetric(zeros(2, 2)) == projector(Symmetric(data))(@thunk(ZeroTangent()))
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using Test
end

include("accumulation.jl")
include("projection.jl")

include("rules.jl")
include("rule_definition_tools.jl")
Expand Down