From 914bd92700b609b40d13eff98f51df423c8073cd Mon Sep 17 00:00:00 2001 From: WT Date: Wed, 24 Feb 2021 22:57:10 +0000 Subject: [PATCH 01/15] Sketch project implementation --- src/ChainRulesCore.jl | 1 + src/projection.jl | 55 +++++++++++++++++++++++++++++++++++++++++++ test/projection.jl | 3 +++ test/runtests.jl | 1 + 4 files changed, 60 insertions(+) create mode 100644 src/projection.jl create mode 100644 test/projection.jl diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index e3bd53deb..5156a83fd 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -26,6 +26,7 @@ include("differentials/notimplemented.jl") include("differential_arithmetic.jl") include("accumulation.jl") +include("projection.jl") include("config.jl") include("rules.jl") diff --git a/src/projection.jl b/src/projection.jl new file mode 100644 index 000000000..51cce00c3 --- /dev/null +++ b/src/projection.jl @@ -0,0 +1,55 @@ +using LinearAlgebra: Diagonal, diag + +""" + project(T::Type, x, dx) + +"project" `dx` onto type `T` such that it is the same size as `x`. + +It's necessary to have `x` to ensure that it's possible to project e.g. `AbstractZero`s +onto `Array`s -- this wouldn't be possible with type information alone because the neither +`AbstractZero`s nor `T` know what size of `Array` to produce. +""" +function project end + +# Number-types +project(::Type{T}, x::T, dx::T) where {T<:Real} = dx + +project(::Type{T}, x::T, dx::AbstractZero) where {T<:Real} = zero(x) + +project(::Type{T}, x::T, dx::AbstractThunk) where {T<:Real} = project(x, unthunk(dx)) + + + +# Arrays +project(::Type{Array{T, N}}, x::Array{T, N}, dx::Array{T, N}) where {T<:Real, N} = dx + +project(::Type{<:Array{T}}, x::Array, dx::Array) where {T} = project.(Ref(T), x, dx) + +function project(::Type{T}, x::Array, dx::AbstractArray) where {T<:Array} + return project(T, x, collect(dx)) +end + +function project(::Type{<:Array{T}}, x::Array, dx::AbstractZero) where {T} + return project.(Ref(T), x, Ref(dx)) +end + + + +# Diagonal +function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractMatrix) where {V} + return Diagonal(project(V, diag(x), diag(dx))) +end + +function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::Composite) where {V} + return Diagonal(project(V, diag(x), dx.diag)) +end + +function project(::Type{<:Composite}, x::Diagonal, dx::Diagonal) + return Composite{typeof(x)}(diag=diag(dx)) +end + + + +# One use for this functionality is to make it easy to define addition between two different +# representations of the same tangent. This also makes it clear that the +Base.:(+)(x::Composite{<:Diagonal}, y::Diagonal) = x + project(typeof(x), x, y) diff --git a/test/projection.jl b/test/projection.jl new file mode 100644 index 000000000..25b38fc1d --- /dev/null +++ b/test/projection.jl @@ -0,0 +1,3 @@ +@testset "projection" begin + +end diff --git a/test/runtests.jl b/test/runtests.jl index 090a85828..f98e971b2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,6 +16,7 @@ using Test end include("accumulation.jl") + include("projection.jl") include("rules.jl") include("rule_definition_tools.jl") From 06678a4f64e222b5ca352e2600bae6fde439b253 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 22 Jun 2021 11:54:38 +0100 Subject: [PATCH 02/15] change Composite to Tangent --- src/projection.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 51cce00c3..ab77978ef 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -40,16 +40,16 @@ function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractMatrix) return Diagonal(project(V, diag(x), diag(dx))) end -function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::Composite) where {V} +function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::Tangent) where {V} return Diagonal(project(V, diag(x), dx.diag)) end -function project(::Type{<:Composite}, x::Diagonal, dx::Diagonal) - return Composite{typeof(x)}(diag=diag(dx)) +function project(::Type{<:Tangent}, x::Diagonal, dx::Diagonal) + return Tangent{typeof(x)}(diag=diag(dx)) end # One use for this functionality is to make it easy to define addition between two different # representations of the same tangent. This also makes it clear that the -Base.:(+)(x::Composite{<:Diagonal}, y::Diagonal) = x + project(typeof(x), x, y) +Base.:(+)(x::Tangent{<:Diagonal}, y::Diagonal) = x + project(typeof(x), x, y) From c58f974909386311fa16d2dbc8b3eb84717504e9 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 22 Jun 2021 11:54:57 +0100 Subject: [PATCH 03/15] export project --- src/ChainRulesCore.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 5156a83fd..7ba0570ab 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -10,7 +10,7 @@ export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMod export frule_via_ad, rrule_via_ad # definition helper macros export @non_differentiable, @scalar_rule, @thunk, @not_implemented -export canonicalize, extern, unthunk # differential operations +export canonicalize, extern, unthunk, project # differential operations export add!! # gradient accumulation operations # differentials export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk From 00020e3d00afd1b9654552164d7e023893e6f25b Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 22 Jun 2021 11:55:24 +0100 Subject: [PATCH 04/15] make T optional --- src/projection.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index ab77978ef..c9fa6b14f 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -1,9 +1,10 @@ using LinearAlgebra: Diagonal, diag """ - project(T::Type, x, dx) + project([T::Type], x, dx) -"project" `dx` onto type `T` such that it is the same size as `x`. +"project" `dx` onto type `T` such that it is the same size as `x`. If `T` is not provided, +it is assumed to be the type of `x`. It's necessary to have `x` to ensure that it's possible to project e.g. `AbstractZero`s onto `Array`s -- this wouldn't be possible with type information alone because the neither @@ -11,6 +12,8 @@ onto `Array`s -- this wouldn't be possible with type information alone because t """ function project end +project(x, dx) = project(typeof(x), x, dx) + # Number-types project(::Type{T}, x::T, dx::T) where {T<:Real} = dx From 37f9253a767909e441750415e4c818f806dc8d07 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 22 Jun 2021 12:05:01 +0100 Subject: [PATCH 05/15] add tests and Complex --- src/projection.jl | 2 ++ test/projection.jl | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/src/projection.jl b/src/projection.jl index c9fa6b14f..e8088b8c4 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -17,6 +17,8 @@ project(x, dx) = project(typeof(x), x, dx) # Number-types project(::Type{T}, x::T, dx::T) where {T<:Real} = dx +project(::Type{T}, x::T, dx::Complex) where {T<:Real} = real(dx) + project(::Type{T}, x::T, dx::AbstractZero) where {T<:Real} = zero(x) project(::Type{T}, x::T, dx::AbstractThunk) where {T<:Real} = project(x, unthunk(dx)) diff --git a/test/projection.jl b/test/projection.jl index 25b38fc1d..1fbfeeccc 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -1,3 +1,10 @@ @testset "projection" begin + @testset "Number types" begin + @test 3.2 == project(1.0, 3.2) + @test 3.2 == project(1.0, 3.2 + 3im) + @test 3.2f0 == project(Float32, 1.0f0, 3.2 - 3im) + @test 0.0 == project(1.1, ZeroTangent()) + @test 3.2 == project(1.0, @thunk(3.2)) + end end From 4e1b79d19d8eae25ab3245bc745139bfc8c5e26f Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 22 Jun 2021 15:50:24 +0100 Subject: [PATCH 06/15] workout the edge cases --- src/projection.jl | 30 +++++++++-------- test/projection.jl | 81 +++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 92 insertions(+), 19 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index e8088b8c4..ff5a8d101 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -14,47 +14,49 @@ function project end project(x, dx) = project(typeof(x), x, dx) -# Number-types -project(::Type{T}, x::T, dx::T) where {T<:Real} = dx +# identity +project(::Type{T}, x::T, dx::T) where T = dx -project(::Type{T}, x::T, dx::Complex) where {T<:Real} = real(dx) +### AbstractZero +project(::Type{T}, x::T, dx::AbstractZero) where T = zero(x) -project(::Type{T}, x::T, dx::AbstractZero) where {T<:Real} = zero(x) +### AbstractThunk +project(::Type{T}, x::T, dx::AbstractThunk) where T = project(x, unthunk(dx)) -project(::Type{T}, x::T, dx::AbstractThunk) where {T<:Real} = project(x, unthunk(dx)) +### Number-types +project(::Type{T}, x::T, dx::T2) where {T<:Number, T2<:Number} = T(dx) +project(::Type{T}, x::T, dx::Complex) where {T<:Real} = T(real(dx)) # Arrays project(::Type{Array{T, N}}, x::Array{T, N}, dx::Array{T, N}) where {T<:Real, N} = dx +# for project([Foo(0.0), Foo(0.0)], [ZeroTangent(), ZeroTangent()]) project(::Type{<:Array{T}}, x::Array, dx::Array) where {T} = project.(Ref(T), x, dx) +# for project(rand(2, 2), Diagonal(rand(2))) function project(::Type{T}, x::Array, dx::AbstractArray) where {T<:Array} return project(T, x, collect(dx)) end +# for project([Foo(0.0), Foo(0.0)], ZeroTangent()) function project(::Type{<:Array{T}}, x::Array, dx::AbstractZero) where {T} return project.(Ref(T), x, Ref(dx)) end - -# Diagonal +## Diagonal function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractMatrix) where {V} return Diagonal(project(V, diag(x), diag(dx))) end - function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::Tangent) where {V} return Diagonal(project(V, diag(x), dx.diag)) end +function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractZero) where {V} + return Diagonal(project(V, diag(x), dx)) +end function project(::Type{<:Tangent}, x::Diagonal, dx::Diagonal) return Tangent{typeof(x)}(diag=diag(dx)) end - - - -# One use for this functionality is to make it easy to define addition between two different -# representations of the same tangent. This also makes it clear that the -Base.:(+)(x::Tangent{<:Diagonal}, y::Diagonal) = x + project(typeof(x), x, y) diff --git a/test/projection.jl b/test/projection.jl index 1fbfeeccc..68308c396 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -1,10 +1,81 @@ +struct Foo + a::Float64 +end + +Base.zero(::Foo) = Foo(0.0) +Base.zero(::Type{Foo}) = "F0" + @testset "projection" begin - @testset "Number types" begin - @test 3.2 == project(1.0, 3.2) - @test 3.2 == project(1.0, 3.2 + 3im) - @test 3.2f0 == project(Float32, 1.0f0, 3.2 - 3im) - @test 0.0 == project(1.1, ZeroTangent()) + #identity + @test Foo(1.2) == project(Foo(-0.2), Foo(1.2)) + @test 3.2 == project(1.0, 3.2) + @test 2.0 + 0.0im == project(1.0im, 2.0) + + @testset "From AbstractZero" begin + @testset "to numbers" begin + @test 0.0 == project(1.1, ZeroTangent()) + @test 0.0f0 == project(1.1f0, ZeroTangent()) + end + + @testset "to arrays (dense and structured)" begin + @test zeros(2, 2) == project([1.0 2; 3 4], ZeroTangent()) + @test Diagonal(zeros(2)) == project(Diagonal([1.0, 4]), ZeroTangent()) + @test Diagonal(zeros(ComplexF64, 2)) == project(Diagonal([1.0 + 0im, 4]), ZeroTangent()) + end + + @testset "to structs" begin + @test Foo(0.0) == project(Foo(3.2), ZeroTangent()) + end + + @testset "to arrays of structs" begin + @test [Foo(0.0), Foo(0.0)] == project([Foo(0.0), Foo(0.0)], ZeroTangent()) + @test Diagonal([Foo(0.0), Foo(0.0)]) == project(Diagonal([Foo(3.2,), Foo(4.2)]), ZeroTangent()) + end + end + + @testset "From AbstractThunk" begin @test 3.2 == project(1.0, @thunk(3.2)) + @test Foo(3.2) == project(Foo(-0.2), @thunk(Foo(3.2))) + @test zeros(2) == project([1.0, 2.0], @thunk(ZeroTangent())) + @test Diagonal([Foo(0.0), Foo(0.0)]) == project(Diagonal([Foo(3.2,), Foo(4.2)]), @thunk(ZeroTangent())) + end + + @testset "To number types" begin + @testset "to subset" begin + @test 3.2 == project(1.0, 3.2 + 3im) + @test 3.2f0 == project(1.0f0, 3.2) + @test 3.2f0 == project(1.0f0, 3.2 - 3im) + end + + @testset "to superset" begin + @test 2.0 + 0.0im == project(2.0 + 1.0im, 2.0) + @test 2.0 == project(2.0, 2.0f0) + end + end + + @testset "To Arrays" begin + # change eltype + @test [1.0 2.0; 3.0 4.0] == project(zeros(2, 2), [1.0 2.0; 3.0 4.0]) + @test [1.0f0 2; 3 4] == project(zeros(Float32, 2, 2), [1.0 2; 3 4]) + + # from a structured array + @test [1.0 0; 0 4] == project(zeros(2, 2), Diagonal([1.0, 4])) + + # from an array of specials + @test [Foo(0.0), Foo(0.0)] == project([Foo(0.0), Foo(0.0)], [ZeroTangent(), ZeroTangent()]) end + + @testset "Diagonal" begin + d = Diagonal([1.0, 4.0]) + t = Tangent{Diagonal}(;diag=[1.0, 4.0]) + @test d == project(d, [1.0 2; 3 4]) + @test d == project(d, t) + @test project(Tangent, d, d) isa Tangent + + @test Diagonal([Foo(0.0), Foo(0.0)]) == project(Diagonal([Foo(3.2,), Foo(4.2)]), Diagonal([ZeroTangent(), ZeroTangent()])) + @test Diagonal([Foo(0.0), Foo(0.0)]) == project(Diagonal([Foo(3.2,), Foo(4.2)]), @thunk(ZeroTangent())) + end + + # how to project to Upper/Lower Symmetric end From 7dc58ee3457830552c6adbadcc47deb269e5615e Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 22 Jun 2021 18:10:27 +0100 Subject: [PATCH 07/15] rename dummy struct --- test/projection.jl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/test/projection.jl b/test/projection.jl index 68308c396..030808195 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -1,14 +1,14 @@ -struct Foo +struct Fred a::Float64 end -Base.zero(::Foo) = Foo(0.0) -Base.zero(::Type{Foo}) = "F0" +Base.zero(::Fred) = Fred(0.0) +Base.zero(::Type{Fred}) = "F0" @testset "projection" begin #identity - @test Foo(1.2) == project(Foo(-0.2), Foo(1.2)) + @test Fred(1.2) == project(Fred(-0.2), Fred(1.2)) @test 3.2 == project(1.0, 3.2) @test 2.0 + 0.0im == project(1.0im, 2.0) @@ -25,20 +25,20 @@ Base.zero(::Type{Foo}) = "F0" end @testset "to structs" begin - @test Foo(0.0) == project(Foo(3.2), ZeroTangent()) + @test Fred(0.0) == project(Fred(3.2), ZeroTangent()) end @testset "to arrays of structs" begin - @test [Foo(0.0), Foo(0.0)] == project([Foo(0.0), Foo(0.0)], ZeroTangent()) - @test Diagonal([Foo(0.0), Foo(0.0)]) == project(Diagonal([Foo(3.2,), Foo(4.2)]), ZeroTangent()) + @test [Fred(0.0), Fred(0.0)] == project([Fred(0.0), Fred(0.0)], ZeroTangent()) + @test Diagonal([Fred(0.0), Fred(0.0)]) == project(Diagonal([Fred(3.2,), Fred(4.2)]), ZeroTangent()) end end @testset "From AbstractThunk" begin @test 3.2 == project(1.0, @thunk(3.2)) - @test Foo(3.2) == project(Foo(-0.2), @thunk(Foo(3.2))) + @test Fred(3.2) == project(Fred(-0.2), @thunk(Fred(3.2))) @test zeros(2) == project([1.0, 2.0], @thunk(ZeroTangent())) - @test Diagonal([Foo(0.0), Foo(0.0)]) == project(Diagonal([Foo(3.2,), Foo(4.2)]), @thunk(ZeroTangent())) + @test Diagonal([Fred(0.0), Fred(0.0)]) == project(Diagonal([Fred(3.2,), Fred(4.2)]), @thunk(ZeroTangent())) end @testset "To number types" begin @@ -63,7 +63,7 @@ Base.zero(::Type{Foo}) = "F0" @test [1.0 0; 0 4] == project(zeros(2, 2), Diagonal([1.0, 4])) # from an array of specials - @test [Foo(0.0), Foo(0.0)] == project([Foo(0.0), Foo(0.0)], [ZeroTangent(), ZeroTangent()]) + @test [Fred(0.0), Fred(0.0)] == project([Fred(0.0), Fred(0.0)], [ZeroTangent(), ZeroTangent()]) end @testset "Diagonal" begin @@ -73,8 +73,8 @@ Base.zero(::Type{Foo}) = "F0" @test d == project(d, t) @test project(Tangent, d, d) isa Tangent - @test Diagonal([Foo(0.0), Foo(0.0)]) == project(Diagonal([Foo(3.2,), Foo(4.2)]), Diagonal([ZeroTangent(), ZeroTangent()])) - @test Diagonal([Foo(0.0), Foo(0.0)]) == project(Diagonal([Foo(3.2,), Foo(4.2)]), @thunk(ZeroTangent())) + @test Diagonal([Fred(0.0), Fred(0.0)]) == project(Diagonal([Fred(3.2,), Fred(4.2)]), Diagonal([ZeroTangent(), ZeroTangent()])) + @test Diagonal([Fred(0.0), Fred(0.0)]) == project(Diagonal([Fred(3.2,), Fred(4.2)]), @thunk(ZeroTangent())) end # how to project to Upper/Lower Symmetric From 3345ba91778286d9f560f6fbb7ec472bcb7767f4 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Wed, 23 Jun 2021 18:23:50 +0100 Subject: [PATCH 08/15] rename project to projector --- src/projection.jl | 60 +++++++++++++++++++++++++---------------------- 1 file changed, 32 insertions(+), 28 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index ff5a8d101..340f5ca46 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -1,62 +1,66 @@ using LinearAlgebra: Diagonal, diag """ - project([T::Type], x, dx) + projector([T::Type], x, dx) -"project" `dx` onto type `T` such that it is the same size as `x`. If `T` is not provided, +"projector" `dx` onto type `T` such that it is the same size as `x`. If `T` is not provided, it is assumed to be the type of `x`. -It's necessary to have `x` to ensure that it's possible to project e.g. `AbstractZero`s +It's necessary to have `x` to ensure that it's possible to projector e.g. `AbstractZero`s onto `Array`s -- this wouldn't be possible with type information alone because the neither `AbstractZero`s nor `T` know what size of `Array` to produce. -""" -function project end +""" # TODO change docstring to reflect projecor returns a closure +function projector end -project(x, dx) = project(typeof(x), x, dx) +projector(x, dx) = projector(typeof(x), x, dx) # identity -project(::Type{T}, x::T, dx::T) where T = dx +projector(::Type{T}, x::T, dx::T) where T = identity ### AbstractZero -project(::Type{T}, x::T, dx::AbstractZero) where T = zero(x) +projector(::Type{T}, x::T, dx::AbstractZero) where T = _ -> zero(x) ### AbstractThunk -project(::Type{T}, x::T, dx::AbstractThunk) where T = project(x, unthunk(dx)) +projector(::Type{T}, x::T, dx::AbstractThunk) where T = projector(x, unthunk(dx)) ### Number-types -project(::Type{T}, x::T, dx::T2) where {T<:Number, T2<:Number} = T(dx) -project(::Type{T}, x::T, dx::Complex) where {T<:Real} = T(real(dx)) +projector(::Type{T}, x::T, dx::T2) where {T<:Number, T2<:Number} = dx -> T(dx) +projector(::Type{T}, x::T, dx::Complex) where {T<:Real} = dx -> T(real(dx)) # Arrays -project(::Type{Array{T, N}}, x::Array{T, N}, dx::Array{T, N}) where {T<:Real, N} = dx +projector(::Type{Array{T, N}}, x::Array{T, N}, dx::Array{T, N}) where {T<:Real, N} = identity -# for project([Foo(0.0), Foo(0.0)], [ZeroTangent(), ZeroTangent()]) -project(::Type{<:Array{T}}, x::Array, dx::Array) where {T} = project.(Ref(T), x, dx) +# for projector([Foo(0.0), Foo(0.0)], [ZeroTangent(), ZeroTangent()]) +projector(::Type{<:Array{T}}, x::Array, dx::Array) where {T} = projector.(Ref(T), x, dx) # TODO -# for project(rand(2, 2), Diagonal(rand(2))) -function project(::Type{T}, x::Array, dx::AbstractArray) where {T<:Array} - return project(T, x, collect(dx)) +# for projector(rand(2, 2), Diagonal(rand(2))) +function projector(::Type{T}, x::Array, dx::AbstractArray) where {T<:Array} + return projector(T, x, collect(dx)) end -# for project([Foo(0.0), Foo(0.0)], ZeroTangent()) -function project(::Type{<:Array{T}}, x::Array, dx::AbstractZero) where {T} - return project.(Ref(T), x, Ref(dx)) +# for projector([Foo(0.0), Foo(0.0)], ZeroTangent()) +function projector(::Type{<:Array{T}}, x::Array, dx::AbstractZero) where {T} + return projector.(Ref(T), x, Ref(dx)) # TODO end ## Diagonal -function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractMatrix) where {V} - return Diagonal(project(V, diag(x), diag(dx))) +function projector(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractMatrix) where {V} + d = diag(x) + return dx -> Diagonal(projector(V, d, diag(dx))) end -function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::Tangent) where {V} - return Diagonal(project(V, diag(x), dx.diag)) +function projector(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::Tangent) where {V} + d = diag(x) + return dx -> Diagonal(projector(V, d, dx.diag)) end -function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractZero) where {V} - return Diagonal(project(V, diag(x), dx)) +function projector(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractZero) where {V} + d = diag(x) + return dx -> Diagonal(projector(V, d, dx)) end -function project(::Type{<:Tangent}, x::Diagonal, dx::Diagonal) - return Tangent{typeof(x)}(diag=diag(dx)) +function projector(::Type{<:Tangent}, x::Diagonal, dx::Diagonal) + T = typeof(x) + return dx -> Tangent{T}(diag=diag(dx)) end From 31d81edc819ce0c1967eafcd124b5a0389487395 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 24 Jun 2021 10:16:09 +0100 Subject: [PATCH 09/15] move to projector --- src/ChainRulesCore.jl | 2 +- src/differentials/abstract_zero.jl | 1 + src/differentials/thunks.jl | 2 + src/projection.jl | 93 +++++++++--------- test/projection.jl | 145 +++++++++++++++++------------ 5 files changed, 139 insertions(+), 104 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 7ba0570ab..9cf6b7c31 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -10,7 +10,7 @@ export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMod export frule_via_ad, rrule_via_ad # definition helper macros export @non_differentiable, @scalar_rule, @thunk, @not_implemented -export canonicalize, extern, unthunk, project # differential operations +export canonicalize, extern, unthunk, projector # differential operations export add!! # gradient accumulation operations # differentials export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk diff --git a/src/differentials/abstract_zero.jl b/src/differentials/abstract_zero.jl index 01dbfc8f3..fb5342dd2 100644 --- a/src/differentials/abstract_zero.jl +++ b/src/differentials/abstract_zero.jl @@ -27,6 +27,7 @@ Base.:/(z::AbstractZero, ::Any) = z Base.convert(::Type{T}, x::AbstractZero) where T <: Number = zero(T) Base.getindex(z::AbstractZero, k) = z +Base.getproperty(z::AbstractZero, f::Symbol) = z """ ZeroTangent() <: AbstractZero diff --git a/src/differentials/thunks.jl b/src/differentials/thunks.jl index 781fff60c..ed6a3d35c 100644 --- a/src/differentials/thunks.jl +++ b/src/differentials/thunks.jl @@ -37,6 +37,8 @@ Base.imag(a::AbstractThunk) = imag(unthunk(a)) Base.Complex(a::AbstractThunk) = Complex(unthunk(a)) Base.Complex(a::AbstractThunk, b::AbstractThunk) = Complex(unthunk(a), unthunk(b)) +Base.getproperty(a::AbstractThunk, f::Symbol) = f === :f ? getfield(a, f) : getproperty(unthunk(a), f) + Base.mapreduce(f, op, a::AbstractThunk; kws...) = mapreduce(f, op, unthunk(a); kws...) function Base.mapreduce(f, op, itr, a::AbstractThunk; kws...) return mapreduce(f, op, itr, unthunk(a); kws...) diff --git a/src/projection.jl b/src/projection.jl index 340f5ca46..daedb0175 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -1,9 +1,9 @@ using LinearAlgebra: Diagonal, diag """ - projector([T::Type], x, dx) + projector([T::Type], x) -"projector" `dx` onto type `T` such that it is the same size as `x`. If `T` is not provided, +"project" `dx` onto type `T` such that it is the same size as `x`. If `T` is not provided, it is assumed to be the type of `x`. It's necessary to have `x` to ensure that it's possible to projector e.g. `AbstractZero`s @@ -12,55 +12,62 @@ onto `Array`s -- this wouldn't be possible with type information alone because t """ # TODO change docstring to reflect projecor returns a closure function projector end -projector(x, dx) = projector(typeof(x), x, dx) +projector(x) = projector(typeof(x), x) -# identity -projector(::Type{T}, x::T, dx::T) where T = identity - -### AbstractZero -projector(::Type{T}, x::T, dx::AbstractZero) where T = _ -> zero(x) - -### AbstractThunk -projector(::Type{T}, x::T, dx::AbstractThunk) where T = projector(x, unthunk(dx)) - - -### Number-types -projector(::Type{T}, x::T, dx::T2) where {T<:Number, T2<:Number} = dx -> T(dx) -projector(::Type{T}, x::T, dx::Complex) where {T<:Real} = dx -> T(real(dx)) +# fallback +function projector(::Type{T}, x::T) where T + println("to Any") + project(dx::T) = dx + project(dx::AbstractZero) = zero(x) + project(dx::AbstractThunk) = project(unthunk(dx)) + return project +end +# Numbers +function projector(::Type{T}, x::T) where {T<:Real} + println("to Real") + project(dx::Real) = T(dx) + project(dx::Number) = T(real(dx)) # to avoid InexactError + project(dx::AbstractZero) = zero(x) + project(dx::AbstractThunk) = project(unthunk(dx)) + return project +end +function projector(::Type{T}, x::T) where {T<:Number} + println("to Number") + project(dx::Number) = T(dx) + project(dx::AbstractZero) = zero(x) + project(dx::AbstractThunk) = project(unthunk(dx)) + return project +end # Arrays -projector(::Type{Array{T, N}}, x::Array{T, N}, dx::Array{T, N}) where {T<:Real, N} = identity - -# for projector([Foo(0.0), Foo(0.0)], [ZeroTangent(), ZeroTangent()]) -projector(::Type{<:Array{T}}, x::Array, dx::Array) where {T} = projector.(Ref(T), x, dx) # TODO - -# for projector(rand(2, 2), Diagonal(rand(2))) -function projector(::Type{T}, x::Array, dx::AbstractArray) where {T<:Array} - return projector(T, x, collect(dx)) +function projector(::Type{Array{T, N}}, x::Array{T, N}) where {T, N} + println("to Array") + element = zero(eltype(x)) + project(dx::Array{T, N}) = dx # identity + project(dx::AbstractArray) = project(collect(dx)) # from Diagonal + project(dx::Array) = projector(element).(dx) # from different element type + project(dx::AbstractZero) = zero(x) + project(dx::AbstractThunk) = project(unthunk(dx)) + return project end -# for projector([Foo(0.0), Foo(0.0)], ZeroTangent()) -function projector(::Type{<:Array{T}}, x::Array, dx::AbstractZero) where {T} - return projector.(Ref(T), x, Ref(dx)) # TODO +# Tangent +function projector(::Type{<:Tangent}, x) + println("to Tangent") + keys = fieldnames(typeof(x)) + project(dx) = Tangent{typeof(x)}(; ((k, getproperty(dx, k)) for k in keys)...) + return project end - -## Diagonal -function projector(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractMatrix) where {V} - d = diag(x) - return dx -> Diagonal(projector(V, d, diag(dx))) -end -function projector(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::Tangent) where {V} +# Diagonal +function projector(::Type{<:Diagonal{<:Any, V}}, x::Diagonal) where {V} + println("to Diagonal") d = diag(x) - return dx -> Diagonal(projector(V, d, dx.diag)) -end -function projector(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractZero) where {V} - d = diag(x) - return dx -> Diagonal(projector(V, d, dx)) + project(dx::AbstractMatrix) = Diagonal(projector(V, d)(diag(dx))) + project(dx::Tangent) = Diagonal(projector(V, d)(dx.diag)) + project(dx::AbstractZero) = Diagonal(projector(V, d)(dx)) + project(dx::AbstractThunk) = project(unthunk(dx)) + return project end -function projector(::Type{<:Tangent}, x::Diagonal, dx::Diagonal) - T = typeof(x) - return dx -> Tangent{T}(diag=diag(dx)) -end diff --git a/test/projection.jl b/test/projection.jl index 030808195..d581e7100 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -3,78 +3,103 @@ struct Fred end Base.zero(::Fred) = Fred(0.0) -Base.zero(::Type{Fred}) = "F0" +Base.zero(::Type{Fred}) = Fred(0.0) @testset "projection" begin - - #identity - @test Fred(1.2) == project(Fred(-0.2), Fred(1.2)) - @test 3.2 == project(1.0, 3.2) - @test 2.0 + 0.0im == project(1.0im, 2.0) - - @testset "From AbstractZero" begin - @testset "to numbers" begin - @test 0.0 == project(1.1, ZeroTangent()) - @test 0.0f0 == project(1.1f0, ZeroTangent()) - end - - @testset "to arrays (dense and structured)" begin - @test zeros(2, 2) == project([1.0 2; 3 4], ZeroTangent()) - @test Diagonal(zeros(2)) == project(Diagonal([1.0, 4]), ZeroTangent()) - @test Diagonal(zeros(ComplexF64, 2)) == project(Diagonal([1.0 + 0im, 4]), ZeroTangent()) - end - - @testset "to structs" begin - @test Fred(0.0) == project(Fred(3.2), ZeroTangent()) - end - - @testset "to arrays of structs" begin - @test [Fred(0.0), Fred(0.0)] == project([Fred(0.0), Fred(0.0)], ZeroTangent()) - @test Diagonal([Fred(0.0), Fred(0.0)]) == project(Diagonal([Fred(3.2,), Fred(4.2)]), ZeroTangent()) - end + @testset "fallback" begin + @test Fred(1.2) == projector(Fred(3.2))(Fred(1.2)) + @test Fred(0.0) == projector(Fred(3.2))(ZeroTangent()) + @test Fred(3.2) == projector(Fred(-0.2))(@thunk(Fred(3.2))) end - @testset "From AbstractThunk" begin - @test 3.2 == project(1.0, @thunk(3.2)) - @test Fred(3.2) == project(Fred(-0.2), @thunk(Fred(3.2))) - @test zeros(2) == project([1.0, 2.0], @thunk(ZeroTangent())) - @test Diagonal([Fred(0.0), Fred(0.0)]) == project(Diagonal([Fred(3.2,), Fred(4.2)]), @thunk(ZeroTangent())) - end + @testset "to Real" begin + # Float64 + @test 3.2 == projector(1.0)(3.2) + @test 0.0 == projector(1.1)(ZeroTangent()) + @test 3.2 == projector(1.0)(@thunk(3.2)) - @testset "To number types" begin - @testset "to subset" begin - @test 3.2 == project(1.0, 3.2 + 3im) - @test 3.2f0 == project(1.0f0, 3.2) - @test 3.2f0 == project(1.0f0, 3.2 - 3im) - end - - @testset "to superset" begin - @test 2.0 + 0.0im == project(2.0 + 1.0im, 2.0) - @test 2.0 == project(2.0, 2.0f0) - end + # down + @test 3.2 == projector(1.0)(3.2 + 3im) + @test 3.2f0 == projector(1.0f0)(3.2) + @test 3.2f0 == projector(1.0f0)(3.2 - 3im) + + # up + @test 2.0 == projector(2.0)(2.0f0) end - @testset "To Arrays" begin - # change eltype - @test [1.0 2.0; 3.0 4.0] == project(zeros(2, 2), [1.0 2.0; 3.0 4.0]) - @test [1.0f0 2; 3 4] == project(zeros(Float32, 2, 2), [1.0 2; 3 4]) + @testset "to Number" begin + # Complex + @test 2.0 + 0.0im == projector(1.0im)(2.0 + 0.0im) + + # down + @test 2.0 + 0.0im == projector(1.0im)(2.0) + @test 0.0 + 0.0im == projector(1.0im)(ZeroTangent()) + @test 0.0 + 0.0im == projector(1.0im)(@thunk(ZeroTangent())) - # from a structured array - @test [1.0 0; 0 4] == project(zeros(2, 2), Diagonal([1.0, 4])) + # up + @test 2.0 + 0.0im == projector(2.0 + 1.0im)(2.0) + end + + @testset "to Array" begin + # to an array of numbers + @test [1.0 2.0; 3.0 4.0] == projector(zeros(2, 2))([1.0 2.0; 3.0 4.0]) + @test zeros(2, 2) == projector([1.0 2; 3 4])(ZeroTangent()) + @test zeros(2) == projector([1.0, 2.0])(@thunk(ZeroTangent())) + @test [1.0f0 2; 3 4] == projector(zeros(Float32, 2, 2))([1.0 2; 3 4]) + @test [1.0 0; 0 4] == projector(zeros(2, 2))(Diagonal([1.0, 4])) + + # to a array of structs + @test [Fred(0.0), Fred(0.0)] == projector([Fred(0.0), Fred(0.0)])([Fred(0.0), Fred(0.0)]) + @test [Fred(0.0), Fred(0.0)] == projector([Fred(0.0), Fred(0.0)])([ZeroTangent(), ZeroTangent()]) + @test [Fred(0.0), Fred(3.2)] == projector([Fred(0.0), Fred(0.0)])([ZeroTangent(), @thunk(Fred(3.2))]) + @test [Fred(0.0), Fred(0.0)] == projector([Fred(1.0), Fred(2.0)])(ZeroTangent()) + @test [Fred(0.0), Fred(0.0)] == projector([Fred(0.0), Fred(0.0)])(@thunk(ZeroTangent())) + diagfreds = [Fred(1.0) Fred(0.0); Fred(0.0) Fred(4.0)] + @test diagfreds == projector(diagfreds)(Diagonal([Fred(1.0), Fred(4.0)])) + end - # from an array of specials - @test [Fred(0.0), Fred(0.0)] == project([Fred(0.0), Fred(0.0)], [ZeroTangent(), ZeroTangent()]) + @testset "to Diagonal" begin + d_F64 = Diagonal([0.0, 0.0]) + d_F32 = Diagonal([0.0f0, 0.0f0]) + d_C64 = Diagonal([0.0 + 0im, 0.0]) + d_Fred = Diagonal([Fred(0.0), Fred(0.0)]) + + # from Matrix + @test d_F64 == projector(d_F64)(zeros(2, 2)) + @test d_F64 == projector(d_F64)(zeros(Float32, 2, 2)) + @test d_F64 == projector(d_F64)(zeros(ComplexF64, 2, 2)) + + # from Diagonal of Numbers + @test d_F64 == projector(d_F64)(d_F64) + @test d_F64 == projector(d_F64)(d_F32) + @test d_F64 == projector(d_F64)(d_C64) + + # from Diagonal of AbstractTangent + @test d_F64 == projector(d_F64)(ZeroTangent()) + @test d_C64 == projector(d_C64)(ZeroTangent()) + @test d_F64 == projector(d_F64)(@thunk(ZeroTangent())) + @test d_F64 == projector(d_F64)(Diagonal([ZeroTangent(), ZeroTangent()])) + @test d_F64 == projector(d_F64)(Diagonal([ZeroTangent(), @thunk(ZeroTangent())])) + + # from Diagonal of structs + @test d_Fred == projector(d_Fred)(ZeroTangent()) + @test d_Fred == projector(d_Fred)(@thunk(ZeroTangent())) + @test d_Fred == projector(d_Fred)(Diagonal([ZeroTangent(), ZeroTangent()])) + + # from Tangent + @test d_F64 == projector(d_F64)(Tangent{Diagonal}(;diag=[0.0, 0.0])) + @test d_F64 == projector(d_F64)(Tangent{Diagonal}(;diag=[0.0f0, 0.0f0])) + @test d_F64 == projector(d_F64)(Tangent{Diagonal}(;diag=[ZeroTangent(), @thunk(ZeroTangent())])) end - @testset "Diagonal" begin - d = Diagonal([1.0, 4.0]) - t = Tangent{Diagonal}(;diag=[1.0, 4.0]) - @test d == project(d, [1.0 2; 3 4]) - @test d == project(d, t) - @test project(Tangent, d, d) isa Tangent + @testset "to Tangent" begin + @test Tangent{Fred}(; a = 3.2,) == projector(Tangent, Fred(3.2))(Fred(3.2)) + @test Tangent{Fred}(; a = ZeroTangent(),) == projector(Tangent, Fred(3.2))(ZeroTangent()) + @test Tangent{Fred}(; a = ZeroTangent(),) == projector(Tangent, Fred(3.2))(@thunk(ZeroTangent())) - @test Diagonal([Fred(0.0), Fred(0.0)]) == project(Diagonal([Fred(3.2,), Fred(4.2)]), Diagonal([ZeroTangent(), ZeroTangent()])) - @test Diagonal([Fred(0.0), Fred(0.0)]) == project(Diagonal([Fred(3.2,), Fred(4.2)]), @thunk(ZeroTangent())) + @test projector(Tangent, Diagonal(zeros(2)))(Diagonal([1.0f0, 2.0f0])) isa Tangent + @test projector(Tangent, Diagonal(zeros(2)))(ZeroTangent()) isa Tangent + @test projector(Tangent, Diagonal(zeros(2)))(@thunk(ZeroTangent())) isa Tangent end # how to project to Upper/Lower Symmetric From 2ea4845df1702c89b9d454e05eaeb4f6bd3d5b9f Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 24 Jun 2021 15:32:27 +0100 Subject: [PATCH 10/15] do not close over x (other than in the general case) --- src/projection.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index daedb0175..407d80cc0 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -44,29 +44,29 @@ end function projector(::Type{Array{T, N}}, x::Array{T, N}) where {T, N} println("to Array") element = zero(eltype(x)) + sizex = size(x) project(dx::Array{T, N}) = dx # identity project(dx::AbstractArray) = project(collect(dx)) # from Diagonal project(dx::Array) = projector(element).(dx) # from different element type - project(dx::AbstractZero) = zero(x) + project(dx::AbstractZero) = zeros(T, sizex...) project(dx::AbstractThunk) = project(unthunk(dx)) return project end # Tangent -function projector(::Type{<:Tangent}, x) +function projector(::Type{<:Tangent}, x::T) where {T} println("to Tangent") - keys = fieldnames(typeof(x)) - project(dx) = Tangent{typeof(x)}(; ((k, getproperty(dx, k)) for k in keys)...) + project(dx) = Tangent{T}(; ((k, getproperty(dx, k)) for k in fieldnames(T))...) return project end # Diagonal function projector(::Type{<:Diagonal{<:Any, V}}, x::Diagonal) where {V} println("to Diagonal") - d = diag(x) - project(dx::AbstractMatrix) = Diagonal(projector(V, d)(diag(dx))) - project(dx::Tangent) = Diagonal(projector(V, d)(dx.diag)) - project(dx::AbstractZero) = Diagonal(projector(V, d)(dx)) + projV = projector(V, diag(x)) + project(dx::AbstractMatrix) = Diagonal(projV(diag(dx))) + project(dx::Tangent) = Diagonal(projV(dx.diag)) + project(dx::AbstractZero) = Diagonal(projV(dx)) project(dx::AbstractThunk) = project(unthunk(dx)) return project end From 465e1d7bb7b2ba7e2e0e2b954e7ccd30a6f26c49 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 24 Jun 2021 15:34:16 +0100 Subject: [PATCH 11/15] update docstring --- src/projection.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 407d80cc0..308f38276 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -3,13 +3,13 @@ using LinearAlgebra: Diagonal, diag """ projector([T::Type], x) -"project" `dx` onto type `T` such that it is the same size as `x`. If `T` is not provided, -it is assumed to be the type of `x`. +Returns a `project(dx)` closure which maps `dx` onto type `T`, such that it is the +same size as `x`. If `T` is not provided, it is assumed to be the type of `x`. -It's necessary to have `x` to ensure that it's possible to projector e.g. `AbstractZero`s +It's necessary to have `x` to ensure that it's possible to project e.g. `AbstractZero`s onto `Array`s -- this wouldn't be possible with type information alone because the neither `AbstractZero`s nor `T` know what size of `Array` to produce. -""" # TODO change docstring to reflect projecor returns a closure +""" function projector end projector(x) = projector(typeof(x), x) From 0a06dce733176c7e06c3b9ec4cdfb190092b1b5f Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 24 Jun 2021 17:26:11 +0100 Subject: [PATCH 12/15] fix getproperty --- src/differentials/thunks.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/differentials/thunks.jl b/src/differentials/thunks.jl index ed6a3d35c..00f8d18c7 100644 --- a/src/differentials/thunks.jl +++ b/src/differentials/thunks.jl @@ -37,8 +37,6 @@ Base.imag(a::AbstractThunk) = imag(unthunk(a)) Base.Complex(a::AbstractThunk) = Complex(unthunk(a)) Base.Complex(a::AbstractThunk, b::AbstractThunk) = Complex(unthunk(a), unthunk(b)) -Base.getproperty(a::AbstractThunk, f::Symbol) = f === :f ? getfield(a, f) : getproperty(unthunk(a), f) - Base.mapreduce(f, op, a::AbstractThunk; kws...) = mapreduce(f, op, unthunk(a); kws...) function Base.mapreduce(f, op, itr, a::AbstractThunk; kws...) return mapreduce(f, op, itr, unthunk(a); kws...) @@ -190,6 +188,8 @@ end @inline unthunk(x::Thunk) = x.f() +Base.getproperty(a::Thunk, f::Symbol) = f === :f ? getfield(a, f) : getproperty(unthunk(a), f) + Base.show(io::IO, x::Thunk) = print(io, "Thunk($(repr(x.f)))") """ @@ -211,6 +211,8 @@ end unthunk(x::InplaceableThunk) = unthunk(x.val) +Base.getproperty(a::InplaceableThunk, f::Symbol) = f in (:val, :add!) ? getfield(a, f) : getproperty(unthunk(a), f) + function Base.show(io::IO, x::InplaceableThunk) return print(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))") end From d822b020a7c31ce57d03df6b603dc2c18538f410 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 24 Jun 2021 17:26:43 +0100 Subject: [PATCH 13/15] add to Tangent and to Symmetric --- src/projection.jl | 15 +++++++++++++-- test/projection.jl | 26 +++++++++++++++++--------- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 308f38276..86dd7ec81 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -14,9 +14,9 @@ function projector end projector(x) = projector(typeof(x), x) -# fallback +# fallback (structs) function projector(::Type{T}, x::T) where T - println("to Any") + println("to Any, T=$T") project(dx::T) = dx project(dx::AbstractZero) = zero(x) project(dx::AbstractThunk) = project(unthunk(dx)) @@ -71,3 +71,14 @@ function projector(::Type{<:Diagonal{<:Any, V}}, x::Diagonal) where {V} return project end +# Symmetric +function projector(::Type{<:Symmetric{<:Any, M}}, x::Symmetric) where {M} + println("to Symetric") + projM = projector(M, parent(x)) + uplo = Symbol(x.uplo) + project(dx::AbstractMatrix) = Symmetric(projM(dx), uplo) + project(dx::Tangent) = Symmetric(projM(dx.data), uplo) + project(dx::AbstractZero) = Symmetric(projM(dx), uplo) + project(dx::AbstractThunk) = project(unthunk(dx)) + return project +end diff --git a/test/projection.jl b/test/projection.jl index d581e7100..4a5e57217 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -58,6 +58,16 @@ Base.zero(::Type{Fred}) = Fred(0.0) @test diagfreds == projector(diagfreds)(Diagonal([Fred(1.0), Fred(4.0)])) end + @testset "to Tangent" begin + @test Tangent{Fred}(; a = 3.2,) == projector(Tangent, Fred(3.2))(Fred(3.2)) + @test Tangent{Fred}(; a = ZeroTangent(),) == projector(Tangent, Fred(3.2))(ZeroTangent()) + @test Tangent{Fred}(; a = ZeroTangent(),) == projector(Tangent, Fred(3.2))(@thunk(ZeroTangent())) + + @test projector(Tangent, Diagonal(zeros(2)))(Diagonal([1.0f0, 2.0f0])) isa Tangent + @test projector(Tangent, Diagonal(zeros(2)))(ZeroTangent()) isa Tangent + @test projector(Tangent, Diagonal(zeros(2)))(@thunk(ZeroTangent())) isa Tangent + end + @testset "to Diagonal" begin d_F64 = Diagonal([0.0, 0.0]) d_F32 = Diagonal([0.0f0, 0.0f0]) @@ -92,15 +102,13 @@ Base.zero(::Type{Fred}) = Fred(0.0) @test d_F64 == projector(d_F64)(Tangent{Diagonal}(;diag=[ZeroTangent(), @thunk(ZeroTangent())])) end - @testset "to Tangent" begin - @test Tangent{Fred}(; a = 3.2,) == projector(Tangent, Fred(3.2))(Fred(3.2)) - @test Tangent{Fred}(; a = ZeroTangent(),) == projector(Tangent, Fred(3.2))(ZeroTangent()) - @test Tangent{Fred}(; a = ZeroTangent(),) == projector(Tangent, Fred(3.2))(@thunk(ZeroTangent())) + @testset "to Symmetric" begin + data = [1.0 2; 3 4] + @test Symmetric(data) == projector(Symmetric(data))(data) + @test Symmetric(data, :L) == projector(Symmetric(data, :L))(data) + @test Symmetric(Diagonal(data)) == projector(Symmetric(data))(Diagonal(diag(data))) - @test projector(Tangent, Diagonal(zeros(2)))(Diagonal([1.0f0, 2.0f0])) isa Tangent - @test projector(Tangent, Diagonal(zeros(2)))(ZeroTangent()) isa Tangent - @test projector(Tangent, Diagonal(zeros(2)))(@thunk(ZeroTangent())) isa Tangent + @test Symmetric(zeros(2, 2)) == projector(Symmetric(data))(ZeroTangent()) + @test Symmetric(zeros(2, 2)) == projector(Symmetric(data))(@thunk(ZeroTangent())) end - - # how to project to Upper/Lower Symmetric end From 25a7ceeea230438656a41c3ea55080f6949ac4aa Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 24 Jun 2021 17:27:20 +0100 Subject: [PATCH 14/15] remove debug strings --- src/projection.jl | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 86dd7ec81..203342387 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -16,7 +16,6 @@ projector(x) = projector(typeof(x), x) # fallback (structs) function projector(::Type{T}, x::T) where T - println("to Any, T=$T") project(dx::T) = dx project(dx::AbstractZero) = zero(x) project(dx::AbstractThunk) = project(unthunk(dx)) @@ -25,7 +24,6 @@ end # Numbers function projector(::Type{T}, x::T) where {T<:Real} - println("to Real") project(dx::Real) = T(dx) project(dx::Number) = T(real(dx)) # to avoid InexactError project(dx::AbstractZero) = zero(x) @@ -33,7 +31,6 @@ function projector(::Type{T}, x::T) where {T<:Real} return project end function projector(::Type{T}, x::T) where {T<:Number} - println("to Number") project(dx::Number) = T(dx) project(dx::AbstractZero) = zero(x) project(dx::AbstractThunk) = project(unthunk(dx)) @@ -42,7 +39,6 @@ end # Arrays function projector(::Type{Array{T, N}}, x::Array{T, N}) where {T, N} - println("to Array") element = zero(eltype(x)) sizex = size(x) project(dx::Array{T, N}) = dx # identity @@ -55,14 +51,12 @@ end # Tangent function projector(::Type{<:Tangent}, x::T) where {T} - println("to Tangent") project(dx) = Tangent{T}(; ((k, getproperty(dx, k)) for k in fieldnames(T))...) return project end # Diagonal function projector(::Type{<:Diagonal{<:Any, V}}, x::Diagonal) where {V} - println("to Diagonal") projV = projector(V, diag(x)) project(dx::AbstractMatrix) = Diagonal(projV(diag(dx))) project(dx::Tangent) = Diagonal(projV(dx.diag)) @@ -73,7 +67,6 @@ end # Symmetric function projector(::Type{<:Symmetric{<:Any, M}}, x::Symmetric) where {M} - println("to Symetric") projM = projector(M, parent(x)) uplo = Symbol(x.uplo) project(dx::AbstractMatrix) = Symmetric(projM(dx), uplo) From 7801e19ca8590b58da9a3700e4081000947ee0f9 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 24 Jun 2021 17:41:10 +0100 Subject: [PATCH 15/15] separate out the projector --- src/projection.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 203342387..171cca1d9 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -39,11 +39,11 @@ end # Arrays function projector(::Type{Array{T, N}}, x::Array{T, N}) where {T, N} - element = zero(eltype(x)) sizex = size(x) + projT = projector(zero(T)) project(dx::Array{T, N}) = dx # identity project(dx::AbstractArray) = project(collect(dx)) # from Diagonal - project(dx::Array) = projector(element).(dx) # from different element type + project(dx::Array) = projT.(dx) # from different element type project(dx::AbstractZero) = zeros(T, sizex...) project(dx::AbstractThunk) = project(unthunk(dx)) return project