From 9c5a84fa2c230447a2c0020b0d67614b7a58cd40 Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Wed, 27 Mar 2019 15:00:11 -0700 Subject: [PATCH] Add an internal helper function to do more in-place updating MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently the `∇(x̄, f, Arg{N}, args...)` method updates `x̄` with the result of `∇(f, Arg{N}, args...)`. This is done in-place for some functions `f` but not all. In the case of the fallback method, we can use dispatch to determine whether it's safe to do this in-place, thereby hopefully saving some allocations. --- src/core.jl | 14 +++++++++++++- test/core.jl | 22 ++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/src/core.jl b/src/core.jl index 29b70c12..59452f2b 100644 --- a/src/core.jl +++ b/src/core.jl @@ -179,7 +179,19 @@ output and `ȳ` the reverse-mode sensitivity of `y`. ∇(y::Node, ȳ) = propagate(tape(y), reverse_tape(y, ȳ)) @inline ∇(y::Node{<:∇Scalar}) = ∇(y, one(unbox(y))) -@inline ∇(x̄, f, ::Type{Arg{N}}, args...) where N = x̄ + ∇(f, Arg{N}, args...) +# This is a fallback method where we don't necessarily know what we'll be adding and whether +# we can update the value in-place, so we'll try to be clever and dispatch. +@inline ∇(x̄, f, ::Type{Arg{N}}, args...) where {N} = update!(x̄, ∇(f, Arg{N}, args...)) + +# Update regular arrays in-place. Structured array types should not be updated in-place, +# even though it technically "works" (https://github.com/JuliaLang/julia/issues/31674), +# so we'll only permit mutating addition for `Array`s, e.g. `Vector` and `Matrix`. +# Mixed array and scalar adds should not occur, as sensitivities should always have the +# same shape, so we won't bother allowing e.g. updating an array with a scalar on the RHS. +update!(x̄::Array{T,N}, y::AbstractArray{S,N}) where {T,S,N} = x̄ .+= y + +# Fall back to using regular addition +update!(x̄, y) = x̄ + y """ ∇(f; get_output::Bool=false) diff --git a/test/core.jl b/test/core.jl index 3de763c3..8f8eaed2 100644 --- a/test/core.jl +++ b/test/core.jl @@ -190,4 +190,26 @@ let @test oned_container(Dict("a"=>5.0, "b"=>randn(3))) == Dict("a"=>1.0, "b"=>ones(3)) end +# To ensure we end up using the fallback machinery for ∇(x̄, f, ...) we'll define a new +# function and setup for it to use in the testset below +quad(A::Matrix, B::Matrix) = B'A*B +@explicit_intercepts quad Tuple{Matrix, Matrix} +Nabla.∇(::typeof(quad), ::Type{Arg{1}}, p, Y, Ȳ, A::Matrix, B::Matrix) = B*Ȳ*B' +Nabla.∇(::typeof(quad), ::Type{Arg{2}}, p, Y, Ȳ, A::Matrix, B::Matrix) = A*B*Ȳ' + A'B*Ȳ + +@testset "Mutating values in the tape" begin + rng = MersenneTwister(123456) + n = 5 + A = Leaf(Tape(), randn(rng, n, n)) + B = randn(rng, n, n) + Q = quad(A, B) + QQ = quad(Q, B) + rt = ∇(QQ, Matrix(1.0I, n, n)) + oldvals = map(deepcopy∘unbox, getfield(rt, :tape)) + Nabla.propagate(Q, rt) # This triggers a mutating addition + newvals = map(unbox, getfield(rt, :tape)) + @test !(oldvals[1] ≈ newvals[1]) + @test oldvals[2:end] ≈ newvals[2:end] +end + end # testset "core"