From 7f8b54a3c71ef56e32d8b65e3ffc398ba4597235 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 2 Aug 2021 22:51:47 -0400 Subject: [PATCH 1/4] sum which allows 2nd derivatives --- src/rulesets/Base/mapreduce.jl | 40 ++++++++++++++++++++++++++------- test/rulesets/Base/mapreduce.jl | 36 ++++++++++++++++++++++------- 2 files changed, 60 insertions(+), 16 deletions(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index c007687b3..741afe569 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -1,24 +1,48 @@ ##### -##### `sum` +##### `sum(x)` ##### function frule((_, ẋ), ::typeof(sum), x; dims=:) return sum(x; dims=dims), sum(ẋ; dims=dims) end -function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number} +function rrule(::typeof(sum), x::AbstractArray; dims=:) + project = ProjectTo(x) y = sum(x; dims=dims) - function sum_pullback(ȳ) - # broadcasting the two works out the size no-matter `dims` - x̄ = InplaceableThunk( - x -> x .+= ȳ, - @thunk(broadcast(last∘tuple, x, ȳ)), + function sum_pullback(dy_raw) + dy = unthunk(dy_raw) + x_thunk = InplaceableThunk( + # Protect `dy` from broadcasting, for when `x` is an array of arrays: + dx -> dx .+= (dims isa Colon ? Ref(dy) : dy), + @thunk project(_unsum(x, dy, dims)) # `_unsum` handles Ref internally ) - return (NoTangent(), x̄) + return (NoTangent(), x_thunk) end return y, sum_pullback end +# This broadcasts `dy` to the shape of `x`, and should preserve e.g. CuArrays, StaticArrays. +# Ideally this would only need `typeof(x)` not `x`, but `similar` only has a suitable method +# when `eltype(x) == eltype(dy)`, which isn't guaranteed. +_unsum(x, dy, dims) = broadcast(last∘tuple, x, dy) +_unsum(x, dy, ::Colon) = broadcast(last∘tuple, x, Ref(dy)) + +# Allow for second derivatives of `sum`, by writing rules for `_unsum`: + +function frule((_, _, dydot, _), ::typeof(_unsum), x, dy, dims) + return _unsum(x, dy, dims), _unsum(x, dydot, dims) +end + +function rrule(::typeof(_unsum), x, dy, dims) + z = _unsum(x, dy, dims) + _unsum_pullback(dz) = (NoTangent(), NoTangent(), sum(unthunk(dz); dims=dims), NoTangent()) + return z, _unsum_pullback +end + +##### +##### `sum(f, x)` +##### + # Can't map over Adjoint/Transpose Vector function rrule( config::RuleConfig{>:HasReverseMode}, diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 37b9c2e0b..a1047bbab 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -1,12 +1,32 @@ @testset "Maps and Reductions" begin - @testset "sum" begin - sizes = (3, 4, 7) - @testset "dims = $dims" for dims in (:, 1) - @testset "Array{$N, $T}" for N in eachindex(sizes), T in (Float64, ComplexF64) - x = randn(T, sizes[1:N]...) - test_frule(sum, x; fkwargs=(;dims=dims)) - test_rrule(sum, x; fkwargs=(;dims=dims)) - end + @testset "sum(x; dims=$dims)" for dims in (:, 2, (1,3)) + # Forward + test_frule(sum, rand(5); fkwargs=(;dims=dims)) + test_frule(sum, rand(ComplexF64, 2,3,4); fkwargs=(;dims=dims)) + + # Reverse + test_rrule(sum, rand(5); fkwargs=(;dims=dims)) + test_rrule(sum, rand(ComplexF64, 2,3,4); fkwargs=(;dims=dims)) + + # Structured matrices + test_rrule(sum, rand(5)'; fkwargs=(;dims=dims)) + y, back = rrule(sum, UpperTriangular(rand(5,5)); dims=dims) + unthunk(back(y*(1+im))[2]) isa UpperTriangular{Float64} + + # Function allowing for 2nd derivatives + for x in (rand(5), rand(2,3,4)) + dy = maximum(x; dims=dims) + test_frule(ChainRules._unsum, x, dy, dims) + test_rrule(ChainRules._unsum, x, dy, dims) + end + + # Arrays of arrays + for x in ([rand(ComplexF64, 3) for _ in 1:4], [rand(3) for _ in 1:2, _ in 1:3, _ in 1:4]) + test_rrule(sum, x; fkwargs=(;dims=dims), check_inferred=false) + + dy = sum(x; dims=dims) + ddy = rrule(ChainRules._unsum, x, dy, dims)[2](x)[3] + @test size(ddy) == size(dy) end end From 56e6a31e703132822c3d8963a514dc9a459b343d Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 3 Aug 2021 17:03:32 -0400 Subject: [PATCH 2/4] add some sum(::Array{Bool}) rules --- src/rulesets/Base/nondiff.jl | 3 +++ test/rulesets/Base/mapreduce.jl | 11 ++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/rulesets/Base/nondiff.jl b/src/rulesets/Base/nondiff.jl index 7b8eaa3eb..065aaeb6a 100644 --- a/src/rulesets/Base/nondiff.jl +++ b/src/rulesets/Base/nondiff.jl @@ -76,6 +76,9 @@ @non_differentiable similar(::AbstractArray{Bool}, ::Any...) @non_differentiable stride(::AbstractArray{Bool}, ::Any) @non_differentiable strides(::AbstractArray{Bool}) +@non_differentiable sum(::AbstractArray{Bool}) +@non_differentiable sum(::Any, ::AbstractArray{Bool}) +@non_differentiable sum(::typeof(abs2), ::AbstractArray{Bool}) # avoids an ambiguity @non_differentiable vcat(::AbstractArray{Bool}...) @non_differentiable vec(::AbstractArray{Bool}) @non_differentiable Vector(::AbstractArray{Bool}) diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index a1047bbab..4b5290b0f 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -12,7 +12,11 @@ test_rrule(sum, rand(5)'; fkwargs=(;dims=dims)) y, back = rrule(sum, UpperTriangular(rand(5,5)); dims=dims) unthunk(back(y*(1+im))[2]) isa UpperTriangular{Float64} + @test_skip test_rrule(sum, UpperTriangular(rand(5,5)) ⊢ randn(5,5); fkwargs=(;dims=dims), check_inferred=false) # Problem: in add!! Evaluated: isapprox + # Boolean -- via @non_differentiable + test_rrule(sum, randn(5) .> 0; fkwargs=(;dims=dims)) + # Function allowing for 2nd derivatives for x in (rand(5), rand(2,3,4)) dy = maximum(x; dims=dims) @@ -39,6 +43,9 @@ test_rrule(sum, abs2, x; fkwargs=(;dims=dims)) end end + + # Boolean -- via @non_differentiable, test that this isn't ambiguous + test_rrule(sum, abs2, randn(5) .> 0; fkwargs=(;dims=dims)) end # sum abs2 @testset "sum(f, xs)" begin @@ -69,6 +76,9 @@ # make sure we preserve type for Diagonal _, pb = rrule(ADviaRuleConfig(), sum, abs, Diagonal([1.0, -3.0])) @test pb(1.0)[3] isa Diagonal + + # Boolean -- via @non_differentiable, test that this isn't ambiguous + @test_skip test_rrule(sum, sqrt, randn(5) .> 0; fkwargs=(;dims=dims)) # MethodError: no method matching real(::NoTangent) end @testset "prod" begin @@ -120,7 +130,6 @@ @test unthunk(rrule(prod, UpperTriangular(ones(T,2,2)))[2](1.0)[2]) == UpperTriangular([0.0 0; 1 0]) # Symmetric -- at least this doesn't have zeros, still an unlikely combination - xs = Symmetric(rand(T,4,4)) @test unthunk(rrule(prod, Symmetric(T[1 2; -333 4]))[2](1.0)[2]) == [16 8; 8 4] # TODO debug why these fail https://github.com/JuliaDiff/ChainRules.jl/issues/475 From 2ff02b6eccfab7f41fc71fedc80d07791312391e Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 3 Aug 2021 18:29:06 -0400 Subject: [PATCH 3/4] fix tests, add a few --- test/rulesets/Base/mapreduce.jl | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 4b5290b0f..91926a671 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -42,23 +42,29 @@ test_frule(sum, abs2, x; fkwargs=(;dims=dims)) test_rrule(sum, abs2, x; fkwargs=(;dims=dims)) end - end - # Boolean -- via @non_differentiable, test that this isn't ambiguous - test_rrule(sum, abs2, randn(5) .> 0; fkwargs=(;dims=dims)) + # Boolean -- via @non_differentiable, test that this isn't ambiguous + test_rrule(sum, abs2, randn(5) .> 0; fkwargs=(;dims=dims)) + end end # sum abs2 @testset "sum(f, xs)" begin # This calls back into AD test_rrule(sum, abs, [-4.0, 2.0, 2.0]) + test_rrule(sum, cbrt, randn(5)) test_rrule(sum, Multiplier(2.0), [2.0, 4.0, 8.0]) + # Complex numbers + test_rrule(sum, sqrt, rand(ComplexF64, 5)) + test_rrule(sum, abs, rand(ComplexF64, 3, 4)) # complex -> real + # inference fails for array of arrays test_rrule(sum, sum, [[2.0, 4.0], [4.0,1.9]]; check_inferred=false) # dims kwarg test_rrule(sum, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=1)) test_rrule(sum, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=2)) + test_rrule(sum, sqrt, rand(ComplexF64, 3, 4); fkwargs=(;dims=(1,))) test_rrule(sum, abs, @SVector[1.0, -3.0]) @@ -78,7 +84,10 @@ @test pb(1.0)[3] isa Diagonal # Boolean -- via @non_differentiable, test that this isn't ambiguous - @test_skip test_rrule(sum, sqrt, randn(5) .> 0; fkwargs=(;dims=dims)) # MethodError: no method matching real(::NoTangent) + test_rrule(sum, sqrt, randn(5) .> 0) + test_rrule(sum, sqrt, randn(5,5) .> 0; fkwargs=(;dims=1)) + # ... and Bool produced by function + @test_skip test_rrule(sum, iszero, randn(5)) # DimensionMismatch("second dimension of A, 1, does not match length of x, 0") end @testset "prod" begin From 66612b32cecfe6ccf2350160c7000066a9c87008 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 3 Aug 2021 21:33:41 -0400 Subject: [PATCH 4/4] v1.5.1 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 05de6001e..9032af418 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.5.0" +version = "1.5.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"