Broadcast gradient error #1132
Could be related to #1001, which I think adds the ForwardDiff code path for broadcasted operations.
I tried adding a rule like this:

import DiffRules, IrrationalConstants, SpecialFunctions
∂logerfcx(x) = 2 * (x - inv(SpecialFunctions.erfcx(x)) / IrrationalConstants.sqrtπ)
DiffRules.@define_diffrule SpecialFunctions.logerfcx(x) = :(∂logerfcx($x))

But I get the exact same error afterwards. Related: https://discourse.julialang.org/t/issue-with-forwarddiff-custom-ad-rule/72886
Worth noting that broadcasting isn't picking up adjoints defined by hand, due to recent changes that special-case numeric arrays. I think that should be considered a bug.

julia> function fn(x)
x ^ 2
end
fn (generic function with 1 method)
julia> gradient(x -> sum(fn.(x)), ones(3,3))
([2.0 2.0 2.0; 2.0 2.0 2.0; 2.0 2.0 2.0],)
julia> gradient(x -> sum(fn.(x)), 3.)
(6.0,)
julia> gradient(x -> sum(fn(x)), 3.)
(6.0,)
julia> Zygote.@adjoint function fn(x)
fn(x), Δ -> (1.,)
end
julia> gradient(x -> sum(fn.(x)), ones(3,3))
([2.0 2.0 2.0; 2.0 2.0 2.0; 2.0 2.0 2.0],)
julia> gradient(x -> sum(fn(x)), 3.)
(1.0,)

Note that the rule is indeed picked up for the non-broadcasted version of the function. This is also due to using ForwardDiff in the internal call to broadcasting in Zygote. If this were replaced by Zygote.Forward, we could make mixed-mode work better, I think.
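To illustrate why a hand-written adjoint can be bypassed, here is a minimal sketch of a forward-mode broadcast in plain Julia. It is not Zygote's or ForwardDiff's actual code: `MiniDual` is a hypothetical stand-in for a dual number type, and `fn` is written as `x * x` rather than `x ^ 2` only to keep the sketch small. The point is that the function is called directly on dual numbers, so no reverse-mode rule attached to `fn` is ever consulted.

```julia
# Hypothetical minimal dual number: a value plus its derivative.
# This is an illustration only, not ForwardDiff's Dual type.
struct MiniDual
    val::Float64
    der::Float64
end

# Product rule: (a * b)' = a' * b + a * b'
Base.:*(a::MiniDual, b::MiniDual) =
    MiniDual(a.val * b.val, a.der * b.val + a.val * b.der)

# Same function as in the session above, written with * for simplicity.
fn(x) = x * x

# Forward-mode broadcast: seed every element with derivative 1.0 and
# call fn directly on the dual numbers. Any reverse-mode rule defined
# for fn plays no part in this computation.
xs = ones(3, 3)
duals = fn.(MiniDual.(xs, 1.0))
grads = [d.der for d in duals]
```

Under these assumptions `grads` holds the true derivative of `fn` at each element, whatever a custom adjoint for `fn` would have returned, which matches the broadcasted result in the session above.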
Assuming it works, Zygote's built-in forward mode would still require an |
Oh, can you point to an example where |
I think forward vs reverse is a distraction here, as for simple R^2 -> R functions they will hardly differ. The problem is that Zygote tends to break type inference, and doing that in a tight loop (like broadcasting) murders performance. I have no idea whether Zygote's abandoned forward mode would do any better; it relies on the same IRTools tricks but perhaps inference is easier. @cossio this was never documented, never really got beyond proof of principle. The present code calls |
It isn't explained anywhere that I know of. You could try to piece together the internals from #503, but based on the comments there I imagine you've read through it already. As Michael mentioned, development kind of fell off a cliff after the first PR and for whatever reason there hasn't been enough appetite to pick it back up. Edit: #752 may be of interest.
We need JuliaDiff/DiffRules.jl#74, otherwise we hit FluxML/Zygote.jl#1132. It was merged in DiffRules 1.8, so even though we don't import DiffRules directly, we add it as a dependency in Project.toml just to set the compat bound.