From a22bdb121d1da3ca242be9b336c28e610443b25c Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Wed, 27 Mar 2019 15:00:11 -0700 Subject: [PATCH] Add an internal helper function to do more in-place updating MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently the `∇(x̄, f, Arg{N}, args...)` method, which updates `x̄` with the result of `∇(f, Arg{N}, args...)`. This is done in-place for some functions `f` but not all. In the case of the fallback method, we can use dispatch to determine whether it's safe to do this in-place. --- src/core.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/core.jl b/src/core.jl index f84ffaf1..ac9aa85c 100644 --- a/src/core.jl +++ b/src/core.jl @@ -175,7 +175,14 @@ output and `ȳ` the reverse-mode sensitivity of `y`. ∇(y::Node, ȳ) = propagate(tape(y), reverse_tape(y, ȳ)) @inline ∇(y::Node{<:∇Scalar}) = ∇(y, one(unbox(y))) -@inline ∇(x̄, f, ::Type{Arg{N}}, args...) where N = x̄ + ∇(f, Arg{N}, args...) +# This is a fallback method where we don't necessarily know what we'll be adding and whether +# we can update the value in-place, so we'll try to be clever and dispatch. +@inline ∇(x̄, f, ::Type{Arg{N}}, args...) where {N} = update!(x̄, ∇(f, Arg{N}, args...)) +# Use broadcast for mixed array/scalar operations +@inline update!(x̄::∇Scalar, y::∇Array) = x̄ .+ y +@inline update!(x̄::∇Array, y::∇ArrayOrScalar) = (x̄ .= x̄ .+ y) +# Otherwise use regular addition +@inline update!(x̄, y) = x̄ + y """ ∇(f; get_output::Bool=false)