diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 09dce121c..1e0989b40 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -32,6 +32,7 @@ include("rulesets/Base/fastmath_able.jl") include("rulesets/Base/evalpoly.jl") include("rulesets/Base/array.jl") include("rulesets/Base/arraymath.jl") +include("rulesets/Base/indexing.jl") include("rulesets/Base/mapreduce.jl") include("rulesets/Statistics/statistics.jl") diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl new file mode 100644 index 000000000..8cf0067aa --- /dev/null +++ b/src/rulesets/Base/indexing.jl @@ -0,0 +1,27 @@ +##### +##### getindex +##### + +function rrule(::typeof(getindex), x::Array{<:Number}, inds...) + # removes any logical indexing, CartesianIndex etc + # leaving us just with a tuple of Int, Arrays of Int and Ranges of Int + plain_inds = Base.to_indices(x, inds) + y = getindex(x, plain_inds...) + function getindex_pullback(ȳ) + function getindex_add!(Δ) + # this a optimizes away for simple cases + for (ȳ_ii, ii) in zip(ȳ, Iterators.product(plain_inds...)) + Δ[ii...] += ȳ_ii + end + return Δ + end + + x̄ = InplaceableThunk( + @thunk(getindex_add!(zero(x))), + getindex_add! + ) + return (NO_FIELDS, x̄, (DoesNotExist() for _ in inds)...) + end + + return y, getindex_pullback +end diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl new file mode 100644 index 000000000..d8ba38ba0 --- /dev/null +++ b/test/rulesets/Base/indexing.jl @@ -0,0 +1,57 @@ +@testset "getindex" begin + @testset "getindex(::Matrix{<:Number},...)" begin + x = [1.0 2.0 3.0; 10.0 20.0 30.0] + x̄ = [1.4 2.5 3.7; 10.5 20.1 30.2] + full_ȳ = [7.4 5.5 2.7; 8.5 11.1 4.2] + + @testset "single element" begin + rrule_test(getindex, 2.3, (x, x̄), (2, nothing)) + rrule_test(getindex, 2.3, (x, x̄), (2, nothing), (1, nothing)) + rrule_test(getindex, 2.3, (x, x̄), (2, nothing), (2, nothing)) + + rrule_test(getindex, 2.3, (x, x̄), (CartesianIndex(2, 3), nothing)) + end + + @testset "slice/index postions" begin + rrule_test(getindex, [2.3, 3.1], (x, x̄), (2:3, nothing)) + rrule_test(getindex, [2.3, 3.1], (x, x̄), (3:-1:2, nothing)) + rrule_test(getindex, [2.3, 3.1], (x, x̄), ([3,2], nothing)) + rrule_test(getindex, [2.3, 3.1], (x, x̄), ([2,3], nothing)) + + rrule_test(getindex, [2.3 3.1; 4.1 5.1], (x, x̄), (1:2, nothing), (2:3, nothing)) + rrule_test(getindex, [2.3 3.1; 4.1 5.1], (x, x̄), (:, nothing), (2:3, nothing)) + + rrule_test(getindex, [2.3, 3.1], (x, x̄), (1:2, nothing), (1, nothing)) + rrule_test(getindex, [2.3, 3.1], (x, x̄), (1, nothing), (1:2, nothing)) + + rrule_test(getindex, [2.3 3.1; 4.1 5.1], (x, x̄), (1:2, nothing), (2:3, nothing)) + rrule_test(getindex, [2.3 3.1; 4.1 5.1], (x, x̄), (:, nothing), (2:3, nothing)) + + + rrule_test(getindex, full_ȳ, (x, x̄), (:, nothing), (:, nothing)) + rrule_test(getindex, full_ȳ[:], (x, x̄), (:, nothing)) + end + + @testset "masking" begin + rrule_test(getindex, full_ȳ, (x, x̄), (trues(size(x)), nothing)) + rrule_test(getindex, full_ȳ[:], (x, x̄), (trues(length(x)), nothing)) + + mask = falses(size(x)) + mask[2,3] = true + mask[1,2] = true + rrule_test(getindex, [2.3, 3.1], (x, x̄), (mask, nothing)) + + rrule_test( + getindex, full_ȳ[1,:], (x, x̄), ([true, false], nothing), (:, nothing) + ) + end + + @testset "By position with repeated elements" begin + rrule_test(getindex, [2.3, 3.1], (x, x̄), ([2, 2], nothing)) + rrule_test(getindex, [2.3, 3.1, 4.1], (x, x̄), ([2, 2, 2], nothing)) + rrule_test( + getindex, [2.3 3.1; 4.1 5.1], (x, x̄), ([2,2], nothing), ([3,3], nothing) + ) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 3c815ccc1..21aaee18d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,6 +24,7 @@ println("Testing ChainRules.jl") include(joinpath("rulesets", "Base", "evalpoly.jl")) include(joinpath("rulesets", "Base", "array.jl")) include(joinpath("rulesets", "Base", "arraymath.jl")) + include(joinpath("rulesets", "Base", "indexing.jl")) include(joinpath("rulesets", "Base", "mapreduce.jl")) end