diff --git a/Project.toml b/Project.toml index 05de6001e..9032af418 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.5.0" +version = "1.5.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index c007687b3..741afe569 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -1,24 +1,48 @@ ##### -##### `sum` +##### `sum(x)` ##### function frule((_, ẋ), ::typeof(sum), x; dims=:) return sum(x; dims=dims), sum(ẋ; dims=dims) end -function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number} +function rrule(::typeof(sum), x::AbstractArray; dims=:) + project = ProjectTo(x) y = sum(x; dims=dims) - function sum_pullback(ȳ) - # broadcasting the two works out the size no-matter `dims` - x̄ = InplaceableThunk( - x -> x .+= ȳ, - @thunk(broadcast(last∘tuple, x, ȳ)), + function sum_pullback(dy_raw) + dy = unthunk(dy_raw) + x_thunk = InplaceableThunk( + # Protect `dy` from broadcasting, for when `x` is an array of arrays: + dx -> dx .+= (dims isa Colon ? Ref(dy) : dy), + @thunk project(_unsum(x, dy, dims)) # `_unsum` handles Ref internally ) - return (NoTangent(), x̄) + return (NoTangent(), x_thunk) end return y, sum_pullback end +# This broadcasts `dy` to the shape of `x`, and should preserve e.g. CuArrays, StaticArrays. +# Ideally this would only need `typeof(x)` not `x`, but `similar` only has a suitable method +# when `eltype(x) == eltype(dy)`, which isn't guaranteed. +_unsum(x, dy, dims) = broadcast(last∘tuple, x, dy) +_unsum(x, dy, ::Colon) = broadcast(last∘tuple, x, Ref(dy)) + +# Allow for second derivatives of `sum`, by writing rules for `_unsum`: + +function frule((_, _, dydot, _), ::typeof(_unsum), x, dy, dims) + return _unsum(x, dy, dims), _unsum(x, dydot, dims) +end + +function rrule(::typeof(_unsum), x, dy, dims) + z = _unsum(x, dy, dims) + _unsum_pullback(dz) = (NoTangent(), NoTangent(), sum(unthunk(dz); dims=dims), NoTangent()) + return z, _unsum_pullback +end + +##### +##### `sum(f, x)` +##### + # Can't map over Adjoint/Transpose Vector function rrule( config::RuleConfig{>:HasReverseMode}, diff --git a/src/rulesets/Base/nondiff.jl b/src/rulesets/Base/nondiff.jl index 7b8eaa3eb..065aaeb6a 100644 --- a/src/rulesets/Base/nondiff.jl +++ b/src/rulesets/Base/nondiff.jl @@ -76,6 +76,9 @@ @non_differentiable similar(::AbstractArray{Bool}, ::Any...) @non_differentiable stride(::AbstractArray{Bool}, ::Any) @non_differentiable strides(::AbstractArray{Bool}) +@non_differentiable sum(::AbstractArray{Bool}) +@non_differentiable sum(::Any, ::AbstractArray{Bool}) +@non_differentiable sum(::typeof(abs2), ::AbstractArray{Bool}) # avoids an ambiguity @non_differentiable vcat(::AbstractArray{Bool}...) @non_differentiable vec(::AbstractArray{Bool}) @non_differentiable Vector(::AbstractArray{Bool}) diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 37b9c2e0b..91926a671 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -1,12 +1,36 @@ @testset "Maps and Reductions" begin - @testset "sum" begin - sizes = (3, 4, 7) - @testset "dims = $dims" for dims in (:, 1) - @testset "Array{$N, $T}" for N in eachindex(sizes), T in (Float64, ComplexF64) - x = randn(T, sizes[1:N]...) - test_frule(sum, x; fkwargs=(;dims=dims)) - test_rrule(sum, x; fkwargs=(;dims=dims)) - end + @testset "sum(x; dims=$dims)" for dims in (:, 2, (1,3)) + # Forward + test_frule(sum, rand(5); fkwargs=(;dims=dims)) + test_frule(sum, rand(ComplexF64, 2,3,4); fkwargs=(;dims=dims)) + + # Reverse + test_rrule(sum, rand(5); fkwargs=(;dims=dims)) + test_rrule(sum, rand(ComplexF64, 2,3,4); fkwargs=(;dims=dims)) + + # Structured matrices + test_rrule(sum, rand(5)'; fkwargs=(;dims=dims)) + y, back = rrule(sum, UpperTriangular(rand(5,5)); dims=dims) + unthunk(back(y*(1+im))[2]) isa UpperTriangular{Float64} + @test_skip test_rrule(sum, UpperTriangular(rand(5,5)) ⊢ randn(5,5); fkwargs=(;dims=dims), check_inferred=false) # Problem: in add!! Evaluated: isapprox + + # Boolean -- via @non_differentiable + test_rrule(sum, randn(5) .> 0; fkwargs=(;dims=dims)) + + # Function allowing for 2nd derivatives + for x in (rand(5), rand(2,3,4)) + dy = maximum(x; dims=dims) + test_frule(ChainRules._unsum, x, dy, dims) + test_rrule(ChainRules._unsum, x, dy, dims) + end + + # Arrays of arrays + for x in ([rand(ComplexF64, 3) for _ in 1:4], [rand(3) for _ in 1:2, _ in 1:3, _ in 1:4]) + test_rrule(sum, x; fkwargs=(;dims=dims), check_inferred=false) + + dy = sum(x; dims=dims) + ddy = rrule(ChainRules._unsum, x, dy, dims)[2](x)[3] + @test size(ddy) == size(dy) end end @@ -18,20 +42,29 @@ test_frule(sum, abs2, x; fkwargs=(;dims=dims)) test_rrule(sum, abs2, x; fkwargs=(;dims=dims)) end + + # Boolean -- via @non_differentiable, test that this isn't ambiguous + test_rrule(sum, abs2, randn(5) .> 0; fkwargs=(;dims=dims)) end end # sum abs2 @testset "sum(f, xs)" begin # This calls back into AD test_rrule(sum, abs, [-4.0, 2.0, 2.0]) + test_rrule(sum, cbrt, randn(5)) test_rrule(sum, Multiplier(2.0), [2.0, 4.0, 8.0]) + # Complex numbers + test_rrule(sum, sqrt, rand(ComplexF64, 5)) + test_rrule(sum, abs, rand(ComplexF64, 3, 4)) # complex -> real + # inference fails for array of arrays test_rrule(sum, sum, [[2.0, 4.0], [4.0,1.9]]; check_inferred=false) # dims kwarg test_rrule(sum, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=1)) test_rrule(sum, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=2)) + test_rrule(sum, sqrt, rand(ComplexF64, 3, 4); fkwargs=(;dims=(1,))) test_rrule(sum, abs, @SVector[1.0, -3.0]) @@ -49,6 +82,12 @@ # make sure we preserve type for Diagonal _, pb = rrule(ADviaRuleConfig(), sum, abs, Diagonal([1.0, -3.0])) @test pb(1.0)[3] isa Diagonal + + # Boolean -- via @non_differentiable, test that this isn't ambiguous + test_rrule(sum, sqrt, randn(5) .> 0) + test_rrule(sum, sqrt, randn(5,5) .> 0; fkwargs=(;dims=1)) + # ... and Bool produced by function + @test_skip test_rrule(sum, iszero, randn(5)) # DimensionMismatch("second dimension of A, 1, does not match length of x, 0") end @testset "prod" begin @@ -100,7 +139,6 @@ @test unthunk(rrule(prod, UpperTriangular(ones(T,2,2)))[2](1.0)[2]) == UpperTriangular([0.0 0; 1 0]) # Symmetric -- at least this doesn't have zeros, still an unlikely combination - xs = Symmetric(rand(T,4,4)) @test unthunk(rrule(prod, Symmetric(T[1 2; -333 4]))[2](1.0)[2]) == [16 8; 8 4] # TODO debug why these fail https://github.com/JuliaDiff/ChainRules.jl/issues/475