Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some Array rules #491

Merged
merged 24 commits into from
Aug 2, 2021
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.4.0"
version = "1.5.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
39 changes: 39 additions & 0 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,42 @@
#####
##### constructors
#####

# Allocating an uninitialised array, e.g. `Array{Float64}(undef, 5)`, is declared
# non-differentiable for every `T<:Array` and any trailing `args...`.
ChainRules.@non_differentiable (::Type{T} where {T<:Array})(::UndefInitializer, args...)

# Reverse rule for constructing `T<:Array` from an existing array `x`.
# Returns the constructed array together with a pullback. The pullback yields
# `NoTangent()` for the constructor and, for `x`, the incoming cotangent passed
# through the projector built from `x`.
function rrule(::Type{T}, x::AbstractArray) where {T<:Array}
    # Build the projector from the *input*, before it is converted to `T`.
    to_x = ProjectTo(x)
    function Array_pullback(ȳ)
        x̄ = to_x(ȳ)
        return (NoTangent(), x̄)
    end
    y = T(x)
    return y, Array_pullback
end

#####
##### `vect`
#####

# The zero-argument call `Base.vect()` takes no inputs; declare it non-differentiable.
@non_differentiable Base.vect()

# Case of uniform type `T`: the data passes straight through,
# so no projection should be required.
function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N}
    # All `N` arguments share the type `T`. The pullback splits the cotangent `ȳ`
    # into an `N`-tuple and returns one entry per argument, after the
    # `NoTangent()` for `Base.vect` itself.
    function vect_pullback(ȳ)
        X̄ = NTuple{N}(ȳ)
        return (NoTangent(), X̄...)
    end
    y = Base.vect(X...)
    return y, vect_pullback
end

# Numbers and arrays are often promoted, to make a uniform vector.
# ProjectTo here reverses this
# Reverse rule for `Base.vect` when the `N` arguments are numbers and/or numeric arrays
# that need not share a type. Returns `Base.vect(X...)` and a pullback producing
# `NoTangent()` for `Base.vect` followed by one cotangent per argument.
function rrule(
::typeof(Base.vect),
X::Vararg{Union{Number,AbstractArray{<:Number}}, N},
) where {N}
# One projector per argument, built from the original (un-promoted) inputs.
projects = map(ProjectTo, X)
function vect_pullback(ȳ)
# The n-th entry of the cotangent is passed through the n-th argument's projector.
X̄ = ntuple(n -> projects[n](ȳ[n]), N)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering whether projects[n] gets handled well. @code_warntype seems happy... and my attempts to make things easier to unroll all make it slower:

julia> @btime rrule(Base.vect, 1,2,3)[2]($(rand(3)))
  28.684 ns (1 allocation: 112 bytes)
(NoTangent(), 0.7437540971290453, 0.6835525631785602, 0.29678387383966687)

julia> @btime rrule(Base.vect, 1+im,2+im,3+im)[2]($(rand(3)))
  235.140 ns (6 allocations: 416 bytes)
(NoTangent(), 0.9212083670665245 + 0.0im, 0.989459216141123 + 0.0im, 0.8454719840778347 + 0.0im)

julia> @btime rrule(Base.vect, 1, 2+3im, 4.0)[2]($(rand(3)))
  609.760 ns (6 allocations: 320 bytes)
(NoTangent(), 0.2914057312235363, 0.23309219863512798 + 0.0im, 0.08023319383991401)

julia> struct StaticGetter{i} end; @inline (::StaticGetter{i})(v) where {i} = v[i]; # from Zygote

julia> function rrule(
            ::typeof(Base.vect),
            X::Vararg{Union{Number,AbstractArray{<:Number}}, N},
        ) where {N}
            valN = Val(N)
            projects = map(ProjectTo, X)
            function vect_pullback(ȳ)
                X̄ = ntuple(n -> StaticGetter{n}()(projects)(ȳ[n]), valN)
                return (NoTangent(), X̄...)
            end
            return Base.vect(X...), vect_pullback
        end
rrule (generic function with 723 methods)

julia> @btime rrule(Base.vect, 1, 2+3im, 4.0)[2]($(rand(3)))
  1.442 μs (11 allocations: 448 bytes)
(NoTangent(), 0.7692288886268137, 0.39993377443044065 + 0.0im, 0.6341039234276757)

return (NoTangent(), X̄...)
end
return Base.vect(X...), vect_pullback
end

#####
##### `reshape`
#####
Expand Down
36 changes: 36 additions & 0 deletions test/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,39 @@
@testset "constructors" begin

# We can't use test_rrule here (as it's currently implemented) because the elements of
willtebbutt marked this conversation as resolved.
Show resolved Hide resolved
# the array have arbitrary values. The only thing we can do is ensure that we're getting
# `NoTangent`s back, and that the forwards pass produces the correct thing still.
# Issue: https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/202
@testset "undef" begin
val, pullback = rrule(Array{Float64}, undef, 5)
@test size(val) == (5, )
@test val isa Array{Float64, 1}
# One `NoTangent()` each for the constructor, `undef`, and the length argument.
@test pullback(randn(5)) == (NoTangent(), NoTangent(), NoTangent())
end
# Construction from an existing array: plain, structured (`Diagonal`), lazily
# transposed, and with an element-type change (real input to `ComplexF64`).
@testset "from existing array" begin
test_rrule(Array, randn(2, 5))
test_rrule(Array, Diagonal(randn(5)))
test_rrule(Matrix, Diagonal(randn(5)))
test_rrule(Matrix, transpose(randn(4)))
test_rrule(Array{ComplexF64}, randn(3))
end
end

@testset "vect" begin
# Zero-argument call.
test_rrule(Base.vect)
# All arguments share one type: scalars, then matrices of differing sizes.
@testset "homogeneous type" begin
test_rrule(Base.vect, 5.0, 4.0, 3.0)
test_rrule(Base.vect, randn(2, 2), randn(3, 3))
end
# Arguments of differing types.
@testset "inhomogeneous type" begin
test_rrule(
Base.vect, 5.0, 3f0;
atol=1e-6, rtol=1e-6, check_inferred=VERSION>=v"1.6",
) # tolerance due to Float32.
# NOTE(review): inference checking is switched off for the mixed scalar/array
# case; the reason is not stated here — confirm before re-enabling.
test_rrule(Base.vect, 5.0, randn(3, 3); check_inferred=false)
end
willtebbutt marked this conversation as resolved.
Show resolved Hide resolved
end

@testset "reshape" begin
test_rrule(reshape, rand(4, 5), (2, 10))
test_rrule(reshape, rand(4, 5), 2, 10)
Expand Down