Can't use broadcasting on non-primitives #3
Comments
I came up with a technique to solve this for ReverseDiff: we replace the broadcasted op with a forward-mode AD'd version, then cache the intermediate derivatives for use in the backward pass. Some of us are writing a paper on the technique now, but until that's ready, here's a prototype implementation that gets good performance on the GPU: https://github.com/jrevels/MixedModeBroadcastAD.jl (cool package, btw 🙂)
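To make the idea concrete, here is a rough sketch of mixed-mode broadcasting in plain Julia. It is not the code from MixedModeBroadcastAD.jl or ReverseDiff: the `Dual` type and the `bcast_forward` / `bcast_backward` helpers are hypothetical names made up for illustration, and only the handful of operations that `logistic` needs are defined.

```julia
# Hypothetical minimal dual number (NOT the type used by ForwardDiff or ReverseDiff).
struct Dual
    val::Float64   # primal value
    der::Float64   # derivative with respect to the scalar input
end

# Only the operations needed by `logistic` below.
Base.:-(a::Dual) = Dual(-a.val, -a.der)
Base.exp(a::Dual) = (e = exp(a.val); Dual(e, e * a.der))
Base.:+(c::Real, a::Dual) = Dual(c + a.val, a.der)
Base.:/(c::Real, a::Dual) = Dual(c / a.val, -c * a.der / a.val^2)

# Forward pass: run f elementwise on dual numbers, return the values
# and cache the elementwise derivatives f'(x[i]).
function bcast_forward(f, xs::AbstractVector{<:Real})
    duals = map(x -> f(Dual(x, 1.0)), xs)
    ys = map(d -> d.val, duals)
    cache = map(d -> d.der, duals)
    return ys, cache
end

# Backward pass: elementwise chain rule that reads only the cache,
# so f itself never has to be written to a tape.
bcast_backward(dy, cache) = dy .* cache

logistic(x) = 1 / (1 + exp(-x))

xs = [-1.0, 0.0, 2.0]
ys, cache = bcast_forward(logistic, xs)
dxs = bcast_backward(ones(length(xs)), cache)
```

The point of the split is that the reverse pass only sees one broadcast node plus a cached array of derivatives, no matter how many scalar operations the function performs internally.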
I would love to read the paper! Fortunately, it's not much of an issue for me at the moment, so I have some time before it gets critical :) In theory, I could detect broadcasting on non-primitives in advance and apply the same trick during the forward pass as in the reverse pass: call the function on the first elements of the arrays, write the ops to a "minitape", and then rewrite the minitape onto the main tape for arrays. But that doesn't sound very robust, so mixed-mode broadcasting may be the way to go.
gradient for mea() and sum() with keywords
Many features including broadcasting are now handled by ChainRules, so closing this issue as outdated.
Say we have a function `logistic(::Real)` and no wrapper that would write it to the tape (like in scalar.jl). If we broadcast it on `TArray`, it will be written to the tape as a `Bcast` operation.

During differentiation of `Bcast` we run the function in question on the first element of `TArray`, but `logistic(x[1])` doesn't record the function `logistic` to the minitape. Instead it records the list of underlying operations (e.g. 5 underlying operations), so ops on the minitape won't be correctly mapped back to ops on the tape itself, and differentiation will fail.

Possible solutions are:

- … `broadcast` through the function. For example, we can convert `TArray{<:Real}` to `Array{TReal}`, run broadcasting and then assemble `TArray` back. This, however, sounds quite fragile.

I'm going to start with (1), but leave this issue open for a while.