Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply ChainRulesCore.jl's projection operators #153

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.12.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "1"
julia = "1"

[extras]
Expand Down
1 change: 1 addition & 0 deletions src/FillArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,7 @@ end
include("fillalgebra.jl")
include("fillbroadcast.jl")
include("trues.jl")
include("chainrules.jl")

##
# print
Expand Down
30 changes: 30 additions & 0 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import ChainRulesCore: ProjectTo, NoTangent

"""
ProjectTo(::Fill) -> ProjectTo{Fill}
ProjectTo(::Ones) -> ProjectTo{NoTangent}

Most FillArrays arrays store one number, and so their gradients under automatic
differentiation represent the variation of this one number.

The exception is those like `Ones` and `Zeros` whose type fixes their value,
which have no graidient.
"""
ProjectTo(x::Fill{<:Number}) = ProjectTo{Fill}(; element = ProjectTo(getindex_value(x)), axes = axes(x))

ProjectTo(x::AbstractFill{Bool}) = ProjectTo{NoTangent}() # Bool is always regarded as categorical

ProjectTo(x::Zeros) = ProjectTo{NoTangent}()
ProjectTo(x::Ones) = ProjectTo{NoTangent}()

function (project::ProjectTo{Fill})(dx::AbstractArray)
for d in 1:max(ndims(dx), length(project.axes))
size(dx, d) == length(get(project.axes, d, 1)) || throw(_projection_mismatch(axes_x, size(dx)))
end
Fill(mean(dx), project.axes) # Note that mean(dx::Fill) is optimised

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need another rule for the constructor to multiply the mean dx by the length of the vector? Think of x -> sum(Fill(x, 3)).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the equivalent of https://github.com/FluxML/Zygote.jl/blob/e6a86745d66b5974eaafa8a8f28bcd4b100374df/src/lib/array.jl#L17

If the constructor is close to where the Fill is used, then perhaps it's a little wasteful to first project like this, and then un-create. But not so serious.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

end

function _projection_mismatch(axes_x::Tuple, size_dx::Tuple)
size_x = map(length, axes_x)
DimensionMismatch("variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx")
end
14 changes: 13 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
using FillArrays, LinearAlgebra, SparseArrays, StaticArrays, Random, Base64, Test, Statistics

using FillArrays, StaticArrays, ChainRulesCore, Base64
using LinearAlgebra, SparseArrays, Random, Statistics, Test # standard libraries

import FillArrays: AbstractFill, RectDiagonal, SquareEye

@testset "fill array constructors and convert" begin
Expand Down Expand Up @@ -1323,3 +1326,12 @@ end
@test cor(Fill(3, 4, 5)) ≈ cor(fill(3, 4, 5)) nans=true
@test cor(Fill(3, 4, 5), dims=2) ≈ cor(fill(3, 4, 5), dims=2) nans=true
end

@testset "ChainRules integration" begin
@test ProjectTo(Fill(1,2,3))(ones(2,3)) === Fill(1.0, 2, 3)
@test ProjectTo(Fill(1,2,3))(ones(2,3,1) .+ im) === Fill(1.0, 2, 3)
@test ProjectTo(Fill(1,2,3))(Fill(1+im, 2,3)) === Fill(1.0, 2, 3)

@test ProjectTo(Eye(3))(rand(3,3)) === NoTangent()
@test ProjectTo(Zeros(3))(rand(3)) === NoTangent()
end