Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rule for sum allowing 2nd derivatives #494

Merged
merged 4 commits into from
Aug 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.5.0"
version = "1.5.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
40 changes: 32 additions & 8 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,48 @@
#####
##### `sum`
##### `sum(x)`
#####

function frule((_, ẋ), ::typeof(sum), x; dims=:)
return sum(x; dims=dims), sum(ẋ; dims=dims)
end

function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number}
function rrule(::typeof(sum), x::AbstractArray; dims=:)
project = ProjectTo(x)
y = sum(x; dims=dims)
function sum_pullback(ȳ)
# broadcasting the two works out the size no-matter `dims`
x̄ = InplaceableThunk(
x -> x .+= ȳ,
@thunk(broadcast(last∘tuple, x, ȳ)),
function sum_pullback(dy_raw)
dy = unthunk(dy_raw)
x_thunk = InplaceableThunk(
# Protect `dy` from broadcasting, for when `x` is an array of arrays:
dx -> dx .+= (dims isa Colon ? Ref(dy) : dy),
@thunk project(_unsum(x, dy, dims)) # `_unsum` handles Ref internally
)
return (NoTangent(), )
return (NoTangent(), x_thunk)
end
return y, sum_pullback
end

# This broadcasts `dy` to the shape of `x`, and should preserve e.g. CuArrays, StaticArrays.
# Ideally this would only need `typeof(x)` not `x`, but `similar` only has a suitable method
# when `eltype(x) == eltype(dy)`, which isn't guaranteed.
_unsum(x, dy, dims) = broadcast(last∘tuple, x, dy)
_unsum(x, dy, ::Colon) = broadcast(last∘tuple, x, Ref(dy))

# Allow for second derivatives of `sum`, by writing rules for `_unsum`:

function frule((_, _, dydot, _), ::typeof(_unsum), x, dy, dims)
return _unsum(x, dy, dims), _unsum(x, dydot, dims)
end

function rrule(::typeof(_unsum), x, dy, dims)
z = _unsum(x, dy, dims)
_unsum_pullback(dz) = (NoTangent(), NoTangent(), sum(unthunk(dz); dims=dims), NoTangent())
return z, _unsum_pullback
end

#####
##### `sum(f, x)`
#####

# Can't map over Adjoint/Transpose Vector
function rrule(
config::RuleConfig{>:HasReverseMode},
Expand Down
3 changes: 3 additions & 0 deletions src/rulesets/Base/nondiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@
@non_differentiable similar(::AbstractArray{Bool}, ::Any...)
@non_differentiable stride(::AbstractArray{Bool}, ::Any)
@non_differentiable strides(::AbstractArray{Bool})
@non_differentiable sum(::AbstractArray{Bool})
@non_differentiable sum(::Any, ::AbstractArray{Bool})
@non_differentiable sum(::typeof(abs2), ::AbstractArray{Bool}) # avoids an ambiguity
@non_differentiable vcat(::AbstractArray{Bool}...)
@non_differentiable vec(::AbstractArray{Bool})
@non_differentiable Vector(::AbstractArray{Bool})
Expand Down
56 changes: 47 additions & 9 deletions test/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,36 @@
@testset "Maps and Reductions" begin
@testset "sum" begin
sizes = (3, 4, 7)
@testset "dims = $dims" for dims in (:, 1)
@testset "Array{$N, $T}" for N in eachindex(sizes), T in (Float64, ComplexF64)
x = randn(T, sizes[1:N]...)
test_frule(sum, x; fkwargs=(;dims=dims))
test_rrule(sum, x; fkwargs=(;dims=dims))
end
@testset "sum(x; dims=$dims)" for dims in (:, 2, (1,3))
# Forward
test_frule(sum, rand(5); fkwargs=(;dims=dims))
test_frule(sum, rand(ComplexF64, 2,3,4); fkwargs=(;dims=dims))

# Reverse
test_rrule(sum, rand(5); fkwargs=(;dims=dims))
test_rrule(sum, rand(ComplexF64, 2,3,4); fkwargs=(;dims=dims))

# Structured matrices
test_rrule(sum, rand(5)'; fkwargs=(;dims=dims))
y, back = rrule(sum, UpperTriangular(rand(5,5)); dims=dims)
unthunk(back(y*(1+im))[2]) isa UpperTriangular{Float64}
@test_skip test_rrule(sum, UpperTriangular(rand(5,5)) ⊢ randn(5,5); fkwargs=(;dims=dims), check_inferred=false) # Problem: in add!! Evaluated: isapprox

# Boolean -- via @non_differentiable
test_rrule(sum, randn(5) .> 0; fkwargs=(;dims=dims))

# Function allowing for 2nd derivatives
for x in (rand(5), rand(2,3,4))
dy = maximum(x; dims=dims)
test_frule(ChainRules._unsum, x, dy, dims)
test_rrule(ChainRules._unsum, x, dy, dims)
end

# Arrays of arrays
for x in ([rand(ComplexF64, 3) for _ in 1:4], [rand(3) for _ in 1:2, _ in 1:3, _ in 1:4])
test_rrule(sum, x; fkwargs=(;dims=dims), check_inferred=false)

dy = sum(x; dims=dims)
ddy = rrule(ChainRules._unsum, x, dy, dims)[2](x)[3]
@test size(ddy) == size(dy)
end
end

Expand All @@ -18,20 +42,29 @@
test_frule(sum, abs2, x; fkwargs=(;dims=dims))
test_rrule(sum, abs2, x; fkwargs=(;dims=dims))
end

# Boolean -- via @non_differentiable, test that this isn't ambiguous
test_rrule(sum, abs2, randn(5) .> 0; fkwargs=(;dims=dims))
end
end # sum abs2

@testset "sum(f, xs)" begin
# This calls back into AD
test_rrule(sum, abs, [-4.0, 2.0, 2.0])
test_rrule(sum, cbrt, randn(5))
test_rrule(sum, Multiplier(2.0), [2.0, 4.0, 8.0])

# Complex numbers
test_rrule(sum, sqrt, rand(ComplexF64, 5))
test_rrule(sum, abs, rand(ComplexF64, 3, 4)) # complex -> real

# inference fails for array of arrays
test_rrule(sum, sum, [[2.0, 4.0], [4.0,1.9]]; check_inferred=false)

# dims kwarg
test_rrule(sum, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=1))
test_rrule(sum, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=2))
test_rrule(sum, sqrt, rand(ComplexF64, 3, 4); fkwargs=(;dims=(1,)))

test_rrule(sum, abs, @SVector[1.0, -3.0])

Expand All @@ -49,6 +82,12 @@
# make sure we preserve type for Diagonal
_, pb = rrule(ADviaRuleConfig(), sum, abs, Diagonal([1.0, -3.0]))
@test pb(1.0)[3] isa Diagonal

# Boolean -- via @non_differentiable, test that this isn't ambiguous
test_rrule(sum, sqrt, randn(5) .> 0)
test_rrule(sum, sqrt, randn(5,5) .> 0; fkwargs=(;dims=1))
# ... and Bool produced by function
@test_skip test_rrule(sum, iszero, randn(5)) # DimensionMismatch("second dimension of A, 1, does not match length of x, 0")
end

@testset "prod" begin
Expand Down Expand Up @@ -100,7 +139,6 @@
@test unthunk(rrule(prod, UpperTriangular(ones(T,2,2)))[2](1.0)[2]) == UpperTriangular([0.0 0; 1 0])

# Symmetric -- at least this doesn't have zeros, still an unlikely combination

xs = Symmetric(rand(T,4,4))
@test unthunk(rrule(prod, Symmetric(T[1 2; -333 4]))[2](1.0)[2]) == [16 8; 8 4]
# TODO debug why these fail https://github.com/JuliaDiff/ChainRules.jl/issues/475
Expand Down