
Can't use broadcasting on non-primitives #3

Closed

dfdx opened this issue Jul 26, 2018 · 3 comments

Comments

dfdx (Owner) commented Jul 26, 2018

Say we have a function logistic(::Real) and no wrapper that writes it to the tape (like the ones in scalar.jl). If we broadcast it over a TArray, it will be recorded on the tape as:

```julia
record!(tape, Bcast, logistic, (x,))
```

During differentiation of Bcast we run the function in question on the first element of the TArray. However, since logistic(x[1]) records to the minitape not the function logistic itself but the list of its underlying operations (e.g. 5 elementary ops), the ops on the minitape can't be correctly mapped back to ops on the main tape and differentiation fails.
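
For concreteness, here is a minimal sketch of the mismatch; the `logistic` definition below is an assumption for illustration, not taken from scalar.jl:

```julia
# An ordinary Julia function with no tape-recording wrapper (assumed for illustration).
logistic(x::Real) = 1 / (1 + exp(-x))

# Tracing logistic(x[1]) on the minitape records the elementary steps
#   t1 = -x[1];  t2 = exp(t1);  t3 = 1 + t2;  t4 = 1 / t3
# rather than a single `logistic` call, so these minitape ops have no
# one-to-one counterpart in the single Bcast entry on the main tape.
```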

Possible solutions are:

  1. Forbid broadcasting on non-primitives. This may actually be fine for the closed system that Yota aims to be right now, but it will most likely annoy a broader audience.
  2. Push the broadcast through the function. For example, we can convert TArray{<:Real} to Array{TReal}, run the broadcast and then assemble the TArray back (see the sketch at the end of this comment). This, however, sounds quite fragile.
  3. Write the operations to a minitape and then rewrite the calls as corresponding broadcast ops on the main tape. The disadvantage is that we would execute these ops twice - for the first element and for the whole tensor - which is undesirable for dynamic graphs.

I'm going to start with (1), but leave this issue open for a while.
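
For reference, a minimal, self-contained sketch of option (2). The STReal/STArray types and push_broadcast below are stand-ins invented for illustration; Yota's actual TReal/TArray record to a tape rather than to a plain vector of op names.

```julia
# Stand-in "tracked" scalar and array types (hypothetical, for illustration only).
struct STReal
    ops::Vector{String}   # shared log standing in for the tape
    val::Float64
end

struct STArray
    ops::Vector{String}
    data::Vector{Float64}
end

# The few scalar primitives that logistic needs, each recording what it did.
Base.:-(a::STReal)          = (push!(a.ops, "neg"); STReal(a.ops, -a.val))
Base.exp(a::STReal)         = (push!(a.ops, "exp"); STReal(a.ops, exp(a.val)))
Base.:+(a::Real, b::STReal) = (push!(b.ops, "add"); STReal(b.ops, a + b.val))
Base.:/(a::Real, b::STReal) = (push!(b.ops, "div"); STReal(b.ops, a / b.val))

logistic(x) = 1 / (1 + exp(-x))   # a non-primitive: no recording wrapper of its own

# Option (2): tracked array -> array of tracked scalars -> broadcast -> tracked array.
function push_broadcast(f, x::STArray)
    xs = [STReal(x.ops, v) for v in x.data]      # the Array{TReal}-like view
    ys = f.(xs)                                  # every elementary op gets recorded
    return STArray(x.ops, [y.val for y in ys])   # reassemble the tracked array
end

x = STArray(String[], [0.0, 1.0, 2.0])
y = push_broadcast(logistic, x)
@show y.data          # ≈ [0.5, 0.731, 0.881]
@show length(x.ops)   # 12: four elementary ops per element
```

Here the broadcast machinery itself takes care of applying the function element-wise; all the sketch does is swap which level (array vs. element) carries the tracking.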

jrevels commented Aug 13, 2018

I came up with a technique to solve this for ReverseDiff: we replace the broadcasted op with a forward-mode AD'd version, then cache the intermediate derivatives for use in the backward pass. Some of us are writing a paper on the technique now, but until that's ready, here's a prototype implementation that gets good performance on the GPU: https://github.com/jrevels/MixedModeBroadcastAD.jl
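
For anyone skimming, a rough sketch of the gist using ForwardDiff's dual numbers; this is not the MixedModeBroadcastAD.jl implementation, and fwd_broadcast/bwd_broadcast/logistic are names made up for illustration. The idea: run the broadcast once in forward mode so each element also yields its local derivative, cache those derivatives, and reuse them in the backward pass.

```julia
using ForwardDiff: Dual, value, partials

logistic(x) = 1 / (1 + exp(-x))   # stand-in for the broadcasted non-primitive

# Forward pass: seed each element with a unit perturbation, so the broadcast
# yields both the primal value and the element-wise derivative in one sweep.
function fwd_broadcast(f, x::AbstractArray)
    duals = f.(Dual.(x, one.(x)))
    y  = value.(duals)          # primal outputs
    dy = partials.(duals, 1)    # cached per-element derivatives df/dx
    return y, dy
end

# Backward pass for an element-wise op is then just a multiplication by the cache.
bwd_broadcast(ȳ, dy) = ȳ .* dy

x = [0.0, 1.0, 2.0]
y, dy = fwd_broadcast(logistic, x)
x̄ = bwd_broadcast(ones(length(y)), dy)   # gradient of sum(logistic.(x)) w.r.t. x
@show y x̄
```

This only covers the single-input, element-wise case; the linked prototype is where the GPU-oriented version lives.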

(cool package, btw 🙂)

dfdx (Owner, Author) commented Aug 13, 2018

I would love to read the paper! Fortunately, it's not much of an issue for me at the moment, so I have some time before it gets critical :)

In theory, I could detect broadcasting on non-primitives in advance and apply during the forward pass the same trick as in the reverse pass - call the function on the first elements of the arrays, write the ops to a "minitape", and then rewrite the minitape onto the main tape in terms of whole arrays. But that doesn't sound very robust, so mixed-mode broadcasting may be the way to go.
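
A very rough sketch of that forward-pass idea, with the minitape reduced to a plain list of single-argument elementary functions (replay_as_broadcast and the minitape contents are hypothetical, just to show the rewrite):

```julia
# Trace of logistic(x[1]) reduced to a list of single-argument elementary steps
# (in the real system this would come from recording to a minitape).
minitape = [x -> -x, exp, t -> 1 + t, t -> 1 / t]

# Rewriting: replay each scalar step as a broadcast over the whole array; on the
# real tape each step would become a corresponding Bcast entry.
function replay_as_broadcast(minitape, x::AbstractArray)
    y = x
    for op in minitape
        y = op.(y)
    end
    return y
end

x = [0.0, 1.0, 2.0]
@show replay_as_broadcast(minitape, x)   # matches logistic.(x) = 1 ./ (1 .+ exp.(-x))
```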

dfdx added a commit that referenced this issue Jan 20, 2019
gradient for mean() and sum() with keywords
dfdx (Owner, Author) commented Jul 3, 2021

Many features, including broadcasting, are now handled by ChainRules, so I'm closing this issue as outdated.

dfdx closed this as completed Jul 3, 2021