Skip to content

Commit

Permalink
Some Array rules (#491)
Browse files Browse the repository at this point in the history
* Bump minor version

* Add undef Array constructor

* construct Array from existing AbstractArray

* vect implementation

* Bump precision

* Additional tests

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>

* Fix undef tests

* Constraint vect implementation

* Add float-only test

* Link to non_differentiable tests issue

* Type-stable `vect` implementation

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>

* type stable vect pullback

* Update src/rulesets/Base/array.jl

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>

* Test Union{Number, AbstractArray}

* Don't test inference below 1.6

* Update src/rulesets/Base/array.jl

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>

* Update src/rulesets/Base/array.jl

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>

* Style fix

* Style fix

* Update src/rulesets/Base/array.jl

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>

* Style fix

* Add an extra test

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
  • Loading branch information
willtebbutt and mcabbott authored Aug 2, 2021
1 parent 6df028a commit 7593339
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 0 deletions.
39 changes: 39 additions & 0 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,42 @@
#####
##### constructors
#####

ChainRules.@non_differentiable (::Type{T} where {T<:Array})(::UndefInitializer, args...)

function rrule(::Type{T}, x::AbstractArray) where {T<:Array}
project_x = ProjectTo(x)
Array_pullback(ȳ) = (NoTangent(), project_x(ȳ))
return T(x), Array_pullback
end

#####
##### `vect`
#####

@non_differentiable Base.vect()

# Case of uniform type `T`: the data passes straight through,
# so no projection should be required.
function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N}
vect_pullback(ȳ) = (NoTangent(), NTuple{N}(ȳ)...)
return Base.vect(X...), vect_pullback
end

# Numbers and arrays are often promoted, to make a uniform vector.
# ProjectTo here reverses this
function rrule(
::typeof(Base.vect),
X::Vararg{Union{Number,AbstractArray{<:Number}}, N},
) where {N}
projects = map(ProjectTo, X)
function vect_pullback(ȳ)
= ntuple(n -> projects[n](ȳ[n]), N)
return (NoTangent(), X̄...)
end
return Base.vect(X...), vect_pullback
end

#####
##### `reshape`
#####
Expand Down
37 changes: 37 additions & 0 deletions test/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,40 @@
@testset "constructors" begin

# We can't use test_rrule here (as it's currently implemented) because the elements of
# the array have arbitrary values. The only thing we can do is ensure that we're getting
# `ZeroTangent`s back, and that the forwards pass produces the correct thing still.
# Issue: https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/202
@testset "undef" begin
val, pullback = rrule(Array{Float64}, undef, 5)
@test size(val) == (5, )
@test val isa Array{Float64, 1}
@test pullback(randn(5)) == (NoTangent(), NoTangent(), NoTangent())
end
@testset "from existing array" begin
test_rrule(Array, randn(2, 5))
test_rrule(Array, Diagonal(randn(5)))
test_rrule(Matrix, Diagonal(randn(5)))
test_rrule(Matrix, transpose(randn(4)))
test_rrule(Array{ComplexF64}, randn(3))
end
end

@testset "vect" begin
test_rrule(Base.vect)
@testset "homogeneous type" begin
test_rrule(Base.vect, (5.0, ), (4.0, ))
test_rrule(Base.vect, 5.0, 4.0, 3.0)
test_rrule(Base.vect, randn(2, 2), randn(3, 3))
end
@testset "inhomogeneous type" begin
test_rrule(
Base.vect, 5.0, 3f0;
atol=1e-6, rtol=1e-6, check_inferred=VERSION>=v"1.6",
) # tolerance due to Float32.
test_rrule(Base.vect, 5.0, randn(3, 3); check_inferred=false)
end
end

@testset "reshape" begin
test_rrule(reshape, rand(4, 5), (2, 10))
test_rrule(reshape, rand(4, 5), 2, 10)
Expand Down

2 comments on commit 7593339

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/42027

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.5.0 -m "<description of version>" 7593339839832ffb7ac83401918390ffe9d4eb42
git push origin v1.5.0

Please sign in to comment.