WIP: projector implementation (returning a closure) #382
@@ -0,0 +1,77 @@
using LinearAlgebra: Diagonal, diag

"""
    projector([T::Type], x)

Returns a `project(dx)` closure which maps `dx` onto type `T`, such that it is the
same size as `x`. If `T` is not provided, it is assumed to be the type of `x`.

It's necessary to have `x` to ensure that it's possible to project e.g. `AbstractZero`s
onto `Array`s -- this wouldn't be possible with type information alone because neither
`AbstractZero`s nor `T` know what size of `Array` to produce.
"""
function projector end

projector(x) = projector(typeof(x), x)

# fallback (structs)
function projector(::Type{T}, x::T) where T
    project(dx::T) = dx
    project(dx::AbstractZero) = zero(x)
    project(dx::AbstractThunk) = project(unthunk(dx))
    return project
end

# Numbers
function projector(::Type{T}, x::T) where {T<:Real}
    project(dx::Real) = T(dx)
    project(dx::Number) = T(real(dx)) # to avoid InexactError
    project(dx::AbstractZero) = zero(x)
    project(dx::AbstractThunk) = project(unthunk(dx))
    return project
end
Comment on lines +29 to +32
I wonder if there should be some
Could you clarify how this would work?
One attempt is here: https://gist.github.com/mcabbott/8a84086cc604d34b5e8dff2eb3839f3a

function projector(::Type{T}, x::T) where {T<:Number}
    project(dx::Number) = T(dx)
    project(dx::AbstractZero) = zero(x)
    project(dx::AbstractThunk) = project(unthunk(dx))
    return project
end

# Arrays
function projector(::Type{Array{T, N}}, x::Array{T, N}) where {T, N}
    sizex = size(x)
    projT = projector(zero(T))
    project(dx::Array{T, N}) = dx # identity
    project(dx::AbstractArray) = project(collect(dx)) # from Diagonal
Comment on lines +41 to +45
Here I also wonder if this is the right behaviour. Maybe the ability to reproduce a similar dense array is desirable sometimes, but making the default projector materialise when it doesn't have to seems odd --- shouldn't we preserve Diagonal or Fill backwards as many steps as possible, by default? But again maybe this is trying to serve multiple purposes which perhaps can be clarified.
Maybe there ought to be abstract types involved, something like:
where
This is the method which specifically wants the output to be a dense array, i.e. where I couldn't quite see how to generalise the method for an arbitrary

    project(dx::Array) = projT.(dx) # from different element type
    project(dx::AbstractZero) = zeros(T, sizex...)
    project(dx::AbstractThunk) = project(unthunk(dx))
    return project
end

# Tangent
function projector(::Type{<:Tangent}, x::T) where {T}
    project(dx) = Tangent{T}(; ((k, getproperty(dx, k)) for k in fieldnames(T))...)
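
To make the concern in the review comment on the `Array` method concrete, here is a minimal sketch that uses only `LinearAlgebra` and does not call `projector` itself. It shows what the `project(dx::AbstractArray) = project(collect(dx))` method does to a structured input: `collect` allocates a full dense matrix, so the `Diagonal` structure is lost from that point backwards.

```julia
using LinearAlgebra: Diagonal

d = Diagonal([1.0, 4.0])   # structured: stores only the two diagonal entries
dense = collect(d)         # materialises all four entries as a Matrix{Float64}

dense isa Matrix{Float64}  # true
dense == d                 # true: same values, but the structure is gone
```

Whether materialising is the right default is the open question in the thread; the sketch only illustrates the cost being discussed.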
Comment on lines +52 to +54
It's still not clear to me what's going to call this. Clearly we will not have

    return project
end

# Diagonal
function projector(::Type{<:Diagonal{<:Any, V}}, x::Diagonal) where {V}
    projV = projector(V, diag(x))
    project(dx::AbstractMatrix) = Diagonal(projV(diag(dx)))
    project(dx::Tangent) = Diagonal(projV(dx.diag))
    project(dx::AbstractZero) = Diagonal(projV(dx))
    project(dx::AbstractThunk) = project(unthunk(dx))
    return project
end

# Symmetric
function projector(::Type{<:Symmetric{<:Any, M}}, x::Symmetric) where {M}
    projM = projector(M, parent(x))
    uplo = Symbol(x.uplo)
    project(dx::AbstractMatrix) = Symmetric(projM(dx), uplo)
I don't think this is right, you need to symmetrise, not merely to apply the wrapper. There's a fairly efficient one here:

    project(dx::Tangent) = Symmetric(projM(dx.data), uplo)
    project(dx::AbstractZero) = Symmetric(projM(dx), uplo)
    project(dx::AbstractThunk) = project(unthunk(dx))
    return project
end
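
The review comment on the `Symmetric` method can be illustrated with a small sketch. The helper `symmetrise` below is hypothetical and not part of this PR, and it is not the implementation the reviewer links to. It assumes a real-valued `dx` and uses the fact that the orthogonal projection of a square matrix onto the symmetric matrices is `(dx + dx') / 2`, whereas the `Symmetric` wrapper alone just reads one triangle and ignores the other.

```julia
using LinearAlgebra: Symmetric

# Hypothetical helper (assumption, real-valued dx only): average dx with its
# transpose, then wrap, so both triangles contribute to the result.
symmetrise(dx::AbstractMatrix, uplo::Symbol=:U) =
    Symmetric((dx .+ transpose(dx)) ./ 2, uplo)

dx = [1.0 2.0; 3.0 4.0]
wrapped = Symmetric(dx)      # reads only the upper triangle: [1 2; 2 4]
projected = symmetrise(dx)   # averages both triangles:      [1 2.5; 2.5 4]
```

With the wrapper alone, the `3.0` in the lower triangle is silently discarded, which is the behaviour the comment objects to.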

@@ -0,0 +1,114 @@
struct Fred
    a::Float64
end

Base.zero(::Fred) = Fred(0.0)
Base.zero(::Type{Fred}) = Fred(0.0)

@testset "projection" begin
    @testset "fallback" begin
        @test Fred(1.2) == projector(Fred(3.2))(Fred(1.2))
        @test Fred(0.0) == projector(Fred(3.2))(ZeroTangent())
        @test Fred(3.2) == projector(Fred(-0.2))(@thunk(Fred(3.2)))
    end

    @testset "to Real" begin
        # Float64
        @test 3.2 == projector(1.0)(3.2)
        @test 0.0 == projector(1.1)(ZeroTangent())
        @test 3.2 == projector(1.0)(@thunk(3.2))

        # down
        @test 3.2 == projector(1.0)(3.2 + 3im)
        @test 3.2f0 == projector(1.0f0)(3.2)
        @test 3.2f0 == projector(1.0f0)(3.2 - 3im)

        # up
        @test 2.0 == projector(2.0)(2.0f0)
    end

    @testset "to Number" begin
        # Complex
        @test 2.0 + 0.0im == projector(1.0im)(2.0 + 0.0im)

        # down
        @test 2.0 + 0.0im == projector(1.0im)(2.0)
        @test 0.0 + 0.0im == projector(1.0im)(ZeroTangent())
        @test 0.0 + 0.0im == projector(1.0im)(@thunk(ZeroTangent()))

        # up
        @test 2.0 + 0.0im == projector(2.0 + 1.0im)(2.0)
    end

    @testset "to Array" begin
        # to an array of numbers
        @test [1.0 2.0; 3.0 4.0] == projector(zeros(2, 2))([1.0 2.0; 3.0 4.0])
        @test zeros(2, 2) == projector([1.0 2; 3 4])(ZeroTangent())
        @test zeros(2) == projector([1.0, 2.0])(@thunk(ZeroTangent()))
        @test [1.0f0 2; 3 4] == projector(zeros(Float32, 2, 2))([1.0 2; 3 4])
        @test [1.0 0; 0 4] == projector(zeros(2, 2))(Diagonal([1.0, 4]))

        # to an array of structs
        @test [Fred(0.0), Fred(0.0)] == projector([Fred(0.0), Fred(0.0)])([Fred(0.0), Fred(0.0)])
        @test [Fred(0.0), Fred(0.0)] == projector([Fred(0.0), Fred(0.0)])([ZeroTangent(), ZeroTangent()])
        @test [Fred(0.0), Fred(3.2)] == projector([Fred(0.0), Fred(0.0)])([ZeroTangent(), @thunk(Fred(3.2))])
        @test [Fred(0.0), Fred(0.0)] == projector([Fred(1.0), Fred(2.0)])(ZeroTangent())
        @test [Fred(0.0), Fred(0.0)] == projector([Fred(0.0), Fred(0.0)])(@thunk(ZeroTangent()))
        diagfreds = [Fred(1.0) Fred(0.0); Fred(0.0) Fred(4.0)]
        @test diagfreds == projector(diagfreds)(Diagonal([Fred(1.0), Fred(4.0)]))
    end

    @testset "to Tangent" begin
        @test Tangent{Fred}(; a = 3.2,) == projector(Tangent, Fred(3.2))(Fred(3.2))
        @test Tangent{Fred}(; a = ZeroTangent(),) == projector(Tangent, Fred(3.2))(ZeroTangent())
        @test Tangent{Fred}(; a = ZeroTangent(),) == projector(Tangent, Fred(3.2))(@thunk(ZeroTangent()))

        @test projector(Tangent, Diagonal(zeros(2)))(Diagonal([1.0f0, 2.0f0])) isa Tangent
        @test projector(Tangent, Diagonal(zeros(2)))(ZeroTangent()) isa Tangent
        @test projector(Tangent, Diagonal(zeros(2)))(@thunk(ZeroTangent())) isa Tangent
    end

    @testset "to Diagonal" begin
        d_F64 = Diagonal([0.0, 0.0])
        d_F32 = Diagonal([0.0f0, 0.0f0])
        d_C64 = Diagonal([0.0 + 0im, 0.0])
        d_Fred = Diagonal([Fred(0.0), Fred(0.0)])

        # from Matrix
        @test d_F64 == projector(d_F64)(zeros(2, 2))
        @test d_F64 == projector(d_F64)(zeros(Float32, 2, 2))
        @test d_F64 == projector(d_F64)(zeros(ComplexF64, 2, 2))

        # from Diagonal of Numbers
        @test d_F64 == projector(d_F64)(d_F64)
        @test d_F64 == projector(d_F64)(d_F32)
        @test d_F64 == projector(d_F64)(d_C64)

        # from Diagonal of AbstractTangent
        @test d_F64 == projector(d_F64)(ZeroTangent())
        @test d_C64 == projector(d_C64)(ZeroTangent())
        @test d_F64 == projector(d_F64)(@thunk(ZeroTangent()))
        @test d_F64 == projector(d_F64)(Diagonal([ZeroTangent(), ZeroTangent()]))
        @test d_F64 == projector(d_F64)(Diagonal([ZeroTangent(), @thunk(ZeroTangent())]))

        # from Diagonal of structs
        @test d_Fred == projector(d_Fred)(ZeroTangent())
        @test d_Fred == projector(d_Fred)(@thunk(ZeroTangent()))
        @test d_Fred == projector(d_Fred)(Diagonal([ZeroTangent(), ZeroTangent()]))

        # from Tangent
        @test d_F64 == projector(d_F64)(Tangent{Diagonal}(;diag=[0.0, 0.0]))
        @test d_F64 == projector(d_F64)(Tangent{Diagonal}(;diag=[0.0f0, 0.0f0]))
        @test d_F64 == projector(d_F64)(Tangent{Diagonal}(;diag=[ZeroTangent(), @thunk(ZeroTangent())]))
    end

    @testset "to Symmetric" begin
        data = [1.0 2; 3 4]
        @test Symmetric(data) == projector(Symmetric(data))(data)
        @test Symmetric(data, :L) == projector(Symmetric(data, :L))(data)
        @test Symmetric(Diagonal(data)) == projector(Symmetric(data))(Diagonal(diag(data)))

        @test Symmetric(zeros(2, 2)) == projector(Symmetric(data))(ZeroTangent())
        @test Symmetric(zeros(2, 2)) == projector(Symmetric(data))(@thunk(ZeroTangent()))
    end
end

I think this is too tight, as `projector(2)(3.5)` is going to be an InexactError, right? As is `projector(false)(1.5)`.
And more generally, what if (say) I want to put dual numbers into the pullback? My impression is that that should be allowed. Which is what led me to think that only known problems should be projected out, like `dx::Complex` when `x::Real`, or anything when `x::Bool`. But it would be nice if the door were open for packages to add to the list of "things which get projected like Complex -> Real".

Yeah that sounds like a relatively serious downside
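
A minimal sketch of the problem raised in this thread, written without ChainRulesCore. The names `tight` and `loose` are hypothetical: `tight` mimics the `project(dx::Real) = T(dx)` method from the diff, and `loose` is one possible reading of "only known problems should be projected out", assuming that dropping the imaginary part is the only fix wanted. The `x::Bool` case and the extension hook for other packages are left out.

```julia
# Mimics the PR's Real method: always convert dx to the primal's type T.
tight(x::T, dx::Real) where {T<:Real} = T(dx)

# Hypothetical looser rule: only project Complex -> Real, pass other numbers through.
loose(x::Real, dx::Complex) = real(dx)
loose(x::Real, dx::Number) = dx

# Int(3.5) cannot be represented exactly, so the tight rule throws.
threw = try
    tight(2, 3.5)
    false
catch err
    err isa InexactError
end

loose(2, 3.5)          # 3.5, no conversion attempted
loose(1.0, 3.2 + 3im)  # 3.2, imaginary part dropped
```

Because `loose` never converts to `T`, a dual number or other `Number` subtype would pass through unchanged, which is the behaviour the first comment argues for.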

Seems to have made it into the tagged version: