From bdeaf2458411bb2af4df64f8cbba0abe99af4510 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 17 Jul 2020 13:23:32 +0100 Subject: [PATCH 1/9] add getindex rrule --- src/rulesets/Base/array.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 4803438f7..ba5d73f22 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -102,3 +102,25 @@ function rrule(::typeof(fill), value::Any, dims::Int...) end return fill(value, dims), fill_pullback end + +##### +##### getindex +##### + +function rrule(::typeof(getindex), x::Array{<:Number}, inds::Union{Int, Vararg{Int}}) + y = getindex(x, inds...) + function getindex_pullback(ȳ) + function getindex_add!(Δ) + Δ[inds...] .+= ȳ; + return Δ + end + + x̄ = InplaceableThunk( + @thunk(getindex_add!(zeros(x))), + getindex_add! + ) + return (NO_FIELDS, x̄, (DoesNotExist() for _ in inds)...) + end + + return y, getindex_pullback +end \ No newline at end of file From 47bf51722ebad37923a9a919821906511043cf84 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 17 Jul 2020 14:59:05 +0100 Subject: [PATCH 2/9] correct and test getindex rrule --- src/rulesets/Base/array.jl | 6 +++--- test/rulesets/Base/array.jl | 16 +++++++++++++++- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index ba5d73f22..2c720e4dc 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -107,16 +107,16 @@ end ##### getindex ##### -function rrule(::typeof(getindex), x::Array{<:Number}, inds::Union{Int, Vararg{Int}}) +function rrule(::typeof(getindex), x::Array{<:Number}, inds::Vararg{Int}) y = getindex(x, inds...) function getindex_pullback(ȳ) function getindex_add!(Δ) - Δ[inds...] .+= ȳ; + Δ[inds...] = Δ[inds...] .+ ȳ return Δ end x̄ = InplaceableThunk( - @thunk(getindex_add!(zeros(x))), + @thunk(getindex_add!(zero(x))), getindex_add! ) return (NO_FIELDS, x̄, (DoesNotExist() for _ in inds)...) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index b34c71b98..8ffcdee38 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -66,7 +66,7 @@ end (ds, dv, dd) = pullback(ones(4)) @test ds === NO_FIELDS @test dd isa DoesNotExist - @test extern(dv) == 4 + @test extern(dv) == 4 y, pullback = rrule(fill, 2.0, (3, 3, 3)) @test y == fill(2.0, (3, 3, 3)) @@ -75,3 +75,17 @@ end @test dd isa DoesNotExist @test dv ≈ 27.0 end + +@testset "getindex" begin + x = [1.0 2.0 3.0; 10.0 20.0 30.0] + ind = [2,3] + ȳ = 7.2 + x̄_fd, = j′vp(ChainRulesTestUtils._fdm, a->getindex(a, ind...), ȳ, x) + y, pullback = rrule(getindex, x, ind...) + _, x̄_ad, = pullback(ȳ) + + @test unthunk(x̄_ad) ≈ x̄_fd + + x_like = x .+ 1.0 + @test x̄_ad.add!(copy(x_like)) ≈ x_like + x̄_fd +end \ No newline at end of file From e85da6f2f97c04091ed94300bf9366d7d6dee454 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 15 Oct 2020 19:06:50 +0100 Subject: [PATCH 3/9] Use rrule_test --- src/rulesets/Base/array.jl | 2 +- test/rulesets/Base/array.jl | 17 +++++------------ 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 2c720e4dc..da16f6f94 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -123,4 +123,4 @@ function rrule(::typeof(getindex), x::Array{<:Number}, inds::Vararg{Int}) end return y, getindex_pullback -end \ No newline at end of file +end diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 8ffcdee38..95a6ca9f5 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -66,7 +66,7 @@ end (ds, dv, dd) = pullback(ones(4)) @test ds === NO_FIELDS @test dd isa DoesNotExist - @test extern(dv) == 4 + @test extern(dv) == 4 y, pullback = rrule(fill, 2.0, (3, 3, 3)) @test y == fill(2.0, (3, 3, 3)) @@ -78,14 +78,7 @@ end @testset "getindex" begin x = [1.0 2.0 3.0; 10.0 20.0 30.0] - ind = [2,3] - ȳ = 7.2 - x̄_fd, = j′vp(ChainRulesTestUtils._fdm, a->getindex(a, ind...), ȳ, x) - y, pullback = rrule(getindex, x, ind...) - _, x̄_ad, = pullback(ȳ) - - @test unthunk(x̄_ad) ≈ x̄_fd - - x_like = x .+ 1.0 - @test x̄_ad.add!(copy(x_like)) ≈ x_like + x̄_fd -end \ No newline at end of file + x̄ = [1.4 2.5 3.7; 10.5 20.1 30.2] + rrule_test(getindex, 2.3, (x, x̄), (2, nothing)) + rrule_test(getindex, 2.3, (x, x̄), (2, nothing), (1, nothing)) +end From ba9a148368396ac1adb9901719a3abaa7ae0cd2c Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 16 Oct 2020 16:57:20 +0100 Subject: [PATCH 4/9] Add tests of overlapping indexes --- test/rulesets/Base/array.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 95a6ca9f5..0328783a6 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -81,4 +81,8 @@ end x̄ = [1.4 2.5 3.7; 10.5 20.1 30.2] rrule_test(getindex, 2.3, (x, x̄), (2, nothing)) rrule_test(getindex, 2.3, (x, x̄), (2, nothing), (1, nothing)) + + # overlapping indexes + rrule_test(getindex, 2.3, (x, x̄), (2, nothing), (2, nothing)) + rrule_test(getindex, 2.3, (x, x̄), (2, nothing), (1, nothing), (2, nothing)) end From 1665f83f782333b2365ea834c0342e4a5f60f312 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 16 Oct 2020 17:30:17 +0100 Subject: [PATCH 5/9] Remove incorrect comment about overlapping indexs --- test/rulesets/Base/array.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 0328783a6..5379f3693 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -82,7 +82,6 @@ end rrule_test(getindex, 2.3, (x, x̄), (2, nothing)) rrule_test(getindex, 2.3, (x, x̄), (2, nothing), (1, nothing)) - # overlapping indexes rrule_test(getindex, 2.3, (x, x̄), (2, nothing), (2, nothing)) rrule_test(getindex, 2.3, (x, x̄), (2, nothing), (1, nothing), (2, nothing)) end From 10a8f0d347e448fd4436ecb7d3093372599894e9 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 16 Oct 2020 17:58:46 +0100 Subject: [PATCH 6/9] Move indexing to own file --- src/ChainRules.jl | 1 + src/rulesets/Base/array.jl | 22 ---------------------- src/rulesets/Base/indexing.jl | 21 +++++++++++++++++++++ test/rulesets/Base/array.jl | 10 ---------- test/rulesets/Base/indexing.jl | 7 +++++++ test/runtests.jl | 1 + 6 files changed, 30 insertions(+), 32 deletions(-) create mode 100644 src/rulesets/Base/indexing.jl create mode 100644 test/rulesets/Base/indexing.jl diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 09dce121c..1e0989b40 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -32,6 +32,7 @@ include("rulesets/Base/fastmath_able.jl") include("rulesets/Base/evalpoly.jl") include("rulesets/Base/array.jl") include("rulesets/Base/arraymath.jl") +include("rulesets/Base/indexing.jl") include("rulesets/Base/mapreduce.jl") include("rulesets/Statistics/statistics.jl") diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index da16f6f94..4803438f7 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -102,25 +102,3 @@ function rrule(::typeof(fill), value::Any, dims::Int...) end return fill(value, dims), fill_pullback end - -##### -##### getindex -##### - -function rrule(::typeof(getindex), x::Array{<:Number}, inds::Vararg{Int}) - y = getindex(x, inds...) - function getindex_pullback(ȳ) - function getindex_add!(Δ) - Δ[inds...] = Δ[inds...] .+ ȳ - return Δ - end - - x̄ = InplaceableThunk( - @thunk(getindex_add!(zero(x))), - getindex_add! - ) - return (NO_FIELDS, x̄, (DoesNotExist() for _ in inds)...) - end - - return y, getindex_pullback -end diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl new file mode 100644 index 000000000..a888c501a --- /dev/null +++ b/src/rulesets/Base/indexing.jl @@ -0,0 +1,21 @@ +##### +##### getindex +##### + +function rrule(::typeof(getindex), x::Array{<:Number}, inds::Vararg{Int}) + y = getindex(x, inds...) + function getindex_pullback(ȳ) + function getindex_add!(Δ) + Δ[inds...] = Δ[inds...] .+ ȳ + return Δ + end + + x̄ = InplaceableThunk( + @thunk(getindex_add!(zero(x))), + getindex_add! + ) + return (NO_FIELDS, x̄, (DoesNotExist() for _ in inds)...) + end + + return y, getindex_pullback +end diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 5379f3693..b34c71b98 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -75,13 +75,3 @@ end @test dd isa DoesNotExist @test dv ≈ 27.0 end - -@testset "getindex" begin - x = [1.0 2.0 3.0; 10.0 20.0 30.0] - x̄ = [1.4 2.5 3.7; 10.5 20.1 30.2] - rrule_test(getindex, 2.3, (x, x̄), (2, nothing)) - rrule_test(getindex, 2.3, (x, x̄), (2, nothing), (1, nothing)) - - rrule_test(getindex, 2.3, (x, x̄), (2, nothing), (2, nothing)) - rrule_test(getindex, 2.3, (x, x̄), (2, nothing), (1, nothing), (2, nothing)) -end diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl new file mode 100644 index 000000000..94ddad823 --- /dev/null +++ b/test/rulesets/Base/indexing.jl @@ -0,0 +1,7 @@ +@testset "getindex" begin + x = [1.0 2.0 3.0; 10.0 20.0 30.0] + x̄ = [1.4 2.5 3.7; 10.5 20.1 30.2] + rrule_test(getindex, 2.3, (x, x̄), (2, nothing)) + rrule_test(getindex, 2.3, (x, x̄), (2, nothing), (1, nothing)) + rrule_test(getindex, 2.3, (x, x̄), (2, nothing), (2, nothing)) +end diff --git a/test/runtests.jl b/test/runtests.jl index 3c815ccc1..21aaee18d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,6 +24,7 @@ println("Testing ChainRules.jl") include(joinpath("rulesets", "Base", "evalpoly.jl")) include(joinpath("rulesets", "Base", "array.jl")) include(joinpath("rulesets", "Base", "arraymath.jl")) + include(joinpath("rulesets", "Base", "indexing.jl")) include(joinpath("rulesets", "Base", "mapreduce.jl")) end From 631cdf79a8c8b4eb382486ed42f26a6bbeb20aa8 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 16 Oct 2020 19:59:24 +0100 Subject: [PATCH 7/9] handle all indexing on Arrays more tests --- src/rulesets/Base/indexing.jl | 12 +++++-- test/rulesets/Base/indexing.jl | 60 +++++++++++++++++++++++++++++++--- 2 files changed, 64 insertions(+), 8 deletions(-) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index a888c501a..3d21e5e2e 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -2,11 +2,17 @@ ##### getindex ##### -function rrule(::typeof(getindex), x::Array{<:Number}, inds::Vararg{Int}) - y = getindex(x, inds...) +function rrule(::typeof(getindex), x::Array, inds...) + # removes any logical indexing, CartesianIndex etc + # leaving us just with a tuple of Int, Arrays of Int and Ranges of Int + plain_inds = Base.to_indices(x, inds) + y = getindex(x, plain_inds...) function getindex_pullback(ȳ) function getindex_add!(Δ) - Δ[inds...] = Δ[inds...] .+ ȳ + # this a optimizes away for simple cases + for (ȳ_ii, ii) in zip(ȳ, Iterators.product(plain_inds...)) + Δ[ii...] += ȳ_ii + end return Δ end diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 94ddad823..ab2ded33e 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -1,7 +1,57 @@ @testset "getindex" begin - x = [1.0 2.0 3.0; 10.0 20.0 30.0] - x̄ = [1.4 2.5 3.7; 10.5 20.1 30.2] - rrule_test(getindex, 2.3, (x, x̄), (2, nothing)) - rrule_test(getindex, 2.3, (x, x̄), (2, nothing), (1, nothing)) - rrule_test(getindex, 2.3, (x, x̄), (2, nothing), (2, nothing)) + @testset "getindex(::Matrix{<:Number},...)" begin + x = [1.0 2.0 3.0; 10.0 20.0 30.0] + x̄ = [1.4 2.5 3.7; 10.5 20.1 30.2] + full_ȳ = [7.4 5.5 2.7; 8.5 11.1 4.2] + + @testset "single element" begin + rrule_test(getindex, 2.3, (x, x̄), (2, nothing)) + rrule_test(getindex, 2.3, (x, x̄), (2, nothing), (1, nothing)) + rrule_test(getindex, 2.3, (x, x̄), (2, nothing), (2, nothing)) + + rrule_test(getindex, 2.3, (x, x̄), (CartesianIndex(2, 3), nothing)) + end + + @testset "slice/index postions" begin + rrule_test(getindex, [2.3, 3.1], (x, x̄), (2:3, nothing)) + rrule_test(getindex, [2.3, 3.1], (x, x̄), (3:-1:2, nothing)) + rrule_test(getindex, [2.3, 3.1], (x, x̄), ([3,2], nothing)) + rrule_test(getindex, [2.3, 3.1], (x, x̄), ([2,3], nothing)) + + rrule_test(getindex, [2.3 3.1; 4.1 5.1], (x, x̄), (1:2, nothing), (2:3, nothing)) + rrule_test(getindex, [2.3 3.1; 4.1 5.1], (x, x̄), (:, nothing), (2:3, nothing)) + + rrule_test(getindex, [2.3, 3.1], (x, x̄), (2:3, nothing), (1, nothing)) + rrule_test(getindex, [2.3, 3.1], (x, x̄), (1, nothing), (2:3, nothing)) + + rrule_test(getindex, [2.3 3.1; 4.1 5.1], (x, x̄), (1:2, nothing), (2:3, nothing)) + rrule_test(getindex, [2.3 3.1; 4.1 5.1], (x, x̄), (:, nothing), (2:3, nothing)) + + + rrule_test(getindex, full_ȳ, (x, x̄), (:, nothing), (:, nothing)) + rrule_test(getindex, full_ȳ[:], (x, x̄), (:, nothing)) + end + + @testset "masking" begin + rrule_test(getindex, full_ȳ, (x, x̄), (trues(size(x)), nothing)) + rrule_test(getindex, full_ȳ[:], (x, x̄), (trues(length(x)), nothing)) + + mask = falses(size(x)) + mask[2,3] = true + mask[1,2] = true + rrule_test(getindex, [2.3, 3.1], (x, x̄), (mask, nothing)) + + rrule_test( + getindex, full_ȳ[1,:], (x, x̄), ([true, false], nothing), (:, nothing) + ) + end + + @testset "By position with repeated elements" begin + rrule_test(getindex, [2.3, 3.1], (x, x̄), ([2, 2], nothing)) + rrule_test(getindex, [2.3, 3.1, 4.1], (x, x̄), ([2, 2, 2], nothing)) + rrule_test( + getindex, [2.3 3.1; 4.1 5.1], (x, x̄), ([2,2], nothing), ([3,3], nothing) + ) + end + end end From 680bd02bcbac6dda3a5dfb9f0574da07c03df3d9 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 16 Oct 2020 20:15:07 +0100 Subject: [PATCH 8/9] fix typo in test --- test/rulesets/Base/indexing.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index ab2ded33e..d8ba38ba0 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -21,8 +21,8 @@ rrule_test(getindex, [2.3 3.1; 4.1 5.1], (x, x̄), (1:2, nothing), (2:3, nothing)) rrule_test(getindex, [2.3 3.1; 4.1 5.1], (x, x̄), (:, nothing), (2:3, nothing)) - rrule_test(getindex, [2.3, 3.1], (x, x̄), (2:3, nothing), (1, nothing)) - rrule_test(getindex, [2.3, 3.1], (x, x̄), (1, nothing), (2:3, nothing)) + rrule_test(getindex, [2.3, 3.1], (x, x̄), (1:2, nothing), (1, nothing)) + rrule_test(getindex, [2.3, 3.1], (x, x̄), (1, nothing), (1:2, nothing)) rrule_test(getindex, [2.3 3.1; 4.1 5.1], (x, x̄), (1:2, nothing), (2:3, nothing)) rrule_test(getindex, [2.3 3.1; 4.1 5.1], (x, x̄), (:, nothing), (2:3, nothing)) From 47b840149d87bca5bfa4abf817ce3db16bc9de4f Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 16 Oct 2020 21:47:48 +0100 Subject: [PATCH 9/9] Restrict to arrays of numbers --- src/rulesets/Base/indexing.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 3d21e5e2e..8cf0067aa 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -2,7 +2,7 @@ ##### getindex ##### -function rrule(::typeof(getindex), x::Array, inds...) +function rrule(::typeof(getindex), x::Array{<:Number}, inds...) # removes any logical indexing, CartesianIndex etc # leaving us just with a tuple of Int, Arrays of Int and Ranges of Int plain_inds = Base.to_indices(x, inds)