diff --git a/Project.toml b/Project.toml index a3b2c3d26..bd8155f02 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.8.0" +version = "1.9.0" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/projection.jl b/src/projection.jl index c4bbf4575..e5542d4d8 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -272,18 +272,55 @@ end ##### # Ref +# Note that Ref is mutable. This causes Zygote to represent its structral tangent not as a NamedTuple, +# but as `Ref{Any}((x=val,))`. Here we use a Tangent, there is at present no mutable version, but see +# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/105 function ProjectTo(x::Ref) sub = ProjectTo(x[]) # should we worry about isdefined(Ref{Vector{Int}}(), :x)? - if sub isa ProjectTo{<:AbstractZero} + return ProjectTo{Tangent{typeof(x)}}(; x=sub) +end +(project::ProjectTo{<:Tangent{<:Ref}})(dx::Tangent) = project(Ref(first(backing(dx)))) +function (project::ProjectTo{<:Tangent{<:Ref}})(dx::Ref) + dy = project.x(dx[]) + return project_type(project)(; x=dy) +end +# Since this works like a zero-array in broadcasting, it should also accept a number: +(project::ProjectTo{<:Tangent{<:Ref}})(dx::Number) = project(Ref(dx)) + +# Tuple +function ProjectTo(x::Tuple) + elements = map(ProjectTo, x) + if elements isa NTuple{<:Any,ProjectTo{<:AbstractZero}} return ProjectTo{NoTangent}() else - return ProjectTo{Ref}(; type=typeof(x), x=sub) + return ProjectTo{Tangent{typeof(x)}}(; elements=elements) end end -(project::ProjectTo{Ref})(dx::Tangent{<:Ref}) = Tangent{project.type}(; x=project.x(dx.x)) -(project::ProjectTo{Ref})(dx::Ref) = Tangent{project.type}(; x=project.x(dx[])) -# Since this works like a zero-array in broadcasting, it should also accept a number: -(project::ProjectTo{Ref})(dx::Number) = Tangent{project.type}(; x=project.x(dx)) +# This method means that projection is re-applied to the contents of a Tangent. +# We're not entirely sure whether this is every necessary; but it should be safe, +# and should often compile away: +(project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tangent) = project(backing(dx)) +function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tuple) + len = length(project.elements) + if length(dx) != len + str = "tuple with length(x) == $len cannot have a gradient with length(dx) == $(length(dx))" + throw(DimensionMismatch(str)) + end + # Here map will fail if the lengths don't match, but gives a much less helpful error: + dy = map((f, x) -> f(x), project.elements, dx) + return project_type(project)(dy...) +end +function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::AbstractArray) + for d in 1:ndims(dx) + if size(dx, d) != get(length(project.elements), d, 1) + throw(_projection_mismatch(axes(project.elements), size(dx))) + end + end + dy = reshape(dx, axes(project.elements)) # allows for dx::OffsetArray + dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements)) + return project_type(project)(dz...) +end + ##### ##### `LinearAlgebra` diff --git a/test/projection.jl b/test/projection.jl index d0c6c7eca..592b0a84d 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -137,12 +137,29 @@ struct NoSuperType end prefvec = ProjectTo(Ref([1, 2, 3 + 4im])) # recurses into contents @test prefvec(Ref(1:3)).x isa Vector{ComplexF64} @test prefvec(Tangent{Base.RefValue}(; x=1:3)).x isa Vector{ComplexF64} - @test_skip @test_throws DimensionMismatch prefvec(Tangent{Base.RefValue}(; x=1:5)) + @test_throws DimensionMismatch prefvec(Tangent{Base.RefValue}(; x=1:5)) @test ProjectTo(Ref(true)) isa ProjectTo{NoTangent} @test ProjectTo(Ref([false]')) isa ProjectTo{NoTangent} end + @testset "Base: Tuple" begin + pt1 = ProjectTo((1.0,)) + @test pt1((1 + im,)) == Tangent{Tuple{Float64}}(1.0,) + @test pt1(pt1((1,))) == pt1(pt1((1,))) # accepts correct Tangent + @test pt1(Tangent{Any}(1)) == pt1((1,)) # accepts Tangent{Any} + @test pt1([1,]) == Tangent{Tuple{Float64}}(1.0,) # accepts Vector + @test pt1(NoTangent()) === NoTangent() + @test pt1(ZeroTangent()) === ZeroTangent() + + @test_throws Exception pt1([1, 2]) # DimensionMismatch, wrong length + @test_throws Exception pt1([]) + + pt3 = ProjectTo(([1, 2, 3], false, :gamma)) # partly non-differentiable + @test pt3((1:3, 4, 5)) == Tangent{Tuple{Vector{Int}, Bool, Symbol}}([1.0, 2.0, 3.0], NoTangent(), NoTangent()) + @test ProjectTo((true, [false])) isa ProjectTo{NoTangent} + end + @testset "Base: non-diff" begin @test ProjectTo(:a)(1) == NoTangent() @test ProjectTo('b')(2) == NoTangent()