Skip to content
This repository has been archived by the owner on Apr 18, 2023. It is now read-only.

Add an internal helper function to do more in-place updating #145

Merged
merged 1 commit into from
Apr 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,19 @@ output and `ȳ` the reverse-mode sensitivity of `y`.
∇(y::Node, ȳ) = propagate(tape(y), reverse_tape(y, ȳ))
@inline ∇(y::Node{<:∇Scalar}) = ∇(y, one(unbox(y)))

@inline ∇(x̄, f, ::Type{Arg{N}}, args...) where N = x̄ + ∇(f, Arg{N}, args...)
# This is a fallback method where we don't necessarily know what we'll be adding and whether
# we can update the value in-place, so we'll try to be clever and dispatch.
@inline ∇(x̄, f, ::Type{Arg{N}}, args...) where {N} = update!(x̄, ∇(f, Arg{N}, args...))

# Update regular arrays in-place. Structured array types should not be updated in-place,
# even though it technically "works" (https://github.com/JuliaLang/julia/issues/31674),
# so we'll only permit mutating addition for `Array`s, e.g. `Vector` and `Matrix`.
# Mixed array and scalar adds should not occur, as sensitivities should always have the
# same shape, so we won't bother allowing e.g. updating an array with a scalar on the RHS.
update!(x̄::Array{T,N}, y::AbstractArray{S,N}) where {T,S,N} = x̄ .+= y

# Fall back to using regular addition
update!(x̄, y) = x̄ + y

"""
∇(f; get_output::Bool=false)
Expand Down
22 changes: 22 additions & 0 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,4 +190,26 @@ let
@test oned_container(Dict("a"=>5.0, "b"=>randn(3))) == Dict("a"=>1.0, "b"=>ones(3))
end

# To ensure we end up using the fallback machinery for ∇(x̄, f, ...) we'll define a new
# function and setup for it to use in the testset below
quad(A::Matrix, B::Matrix) = B'A*B
@explicit_intercepts quad Tuple{Matrix, Matrix}
Nabla.∇(::typeof(quad), ::Type{Arg{1}}, p, Y, Ȳ, A::Matrix, B::Matrix) = B*Ȳ*B'
Nabla.∇(::typeof(quad), ::Type{Arg{2}}, p, Y, Ȳ, A::Matrix, B::Matrix) = A*B*Ȳ' + A'B*Ȳ

@testset "Mutating values in the tape" begin
rng = MersenneTwister(123456)
n = 5
A = Leaf(Tape(), randn(rng, n, n))
B = randn(rng, n, n)
Q = quad(A, B)
QQ = quad(Q, B)
rt = ∇(QQ, Matrix(1.0I, n, n))
oldvals = map(deepcopy∘unbox, getfield(rt, :tape))
Nabla.propagate(Q, rt) # This triggers a mutating addition
newvals = map(unbox, getfield(rt, :tape))
@test !(oldvals[1] ≈ newvals[1])
@test oldvals[2:end] ≈ newvals[2:end]
end

end # testset "core"