
Broadcast gradient error #1132

Closed
cossio opened this issue Dec 10, 2021 · 7 comments · Fixed by JuliaDiff/DiffRules.jl#74

Comments

@cossio
Contributor

cossio commented Dec 10, 2021

[image: screenshot of the error, not transcribed]

@cossio
Contributor Author

cossio commented Dec 10, 2021

Could be related to #1001, which I think adds the ForwardDiff code path for broadcasted operations.

@cossio
Contributor Author

cossio commented Dec 10, 2021

I tried adding a rule like this:

import DiffRules, IrrationalConstants, SpecialFunctions

∂logerfcx(x) = 2 * (x - inv(SpecialFunctions.erfcx(x)) / IrrationalConstants.sqrtπ)
DiffRules.@define_diffrule SpecialFunctions.logerfcx(x) = :(∂logerfcx($x))

But I get the exact same error afterwards.

Related: https://discourse.julialang.org/t/issue-with-forwarddiff-custom-ad-rule/72886
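For reference, the formula used in ∂logerfcx above follows from the definition erfcx(x) = e^{x²} erfc(x); a quick derivation:

```latex
\frac{d}{dx}\operatorname{erfc}(x) = -\frac{2}{\sqrt{\pi}}\,e^{-x^2}
\quad\Longrightarrow\quad
\frac{d}{dx}\operatorname{erfcx}(x) = 2x\operatorname{erfcx}(x) - \frac{2}{\sqrt{\pi}},
```

so

```latex
\frac{d}{dx}\log\operatorname{erfcx}(x)
= 2x - \frac{2}{\sqrt{\pi}\,\operatorname{erfcx}(x)}
= 2\left(x - \frac{\operatorname{erfcx}(x)^{-1}}{\sqrt{\pi}}\right),
```

which matches the rule defined above.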

@DhairyaLGandhi
Member

DhairyaLGandhi commented Dec 10, 2021

Worth noting that broadcasting isn't picking up adjoints defined by hand, due to recent changes that special-case numeric arrays. I think that should be considered a bug.

julia> function fn(x)
         x ^ 2
       end
fn (generic function with 1 method)

julia> gradient(x -> sum(fn.(x)), ones(3,3))
([2.0 2.0 2.0; 2.0 2.0 2.0; 2.0 2.0 2.0],)

julia> gradient(x -> sum(fn.(x)), 3.)
(6.0,)

julia> gradient(x -> sum(fn(x)), 3.)
(6.0,)

julia> Zygote.@adjoint function fn(x)
         fn(x), Δ -> (1.,)
       end

julia> gradient(x -> sum(fn.(x)), ones(3,3))
([2.0 2.0 2.0; 2.0 2.0 2.0; 2.0 2.0 2.0],)

julia> gradient(x -> sum(fn(x)), 3.)
(1.0,)

Note that the rule is indeed picked up for the non-broadcasted version of the function. This is because Zygote uses ForwardDiff in its internal broadcasting call. If this were replaced by Zygote.Forward, I think we could make mixed mode work better.

@ToucheSir
Member

Assuming it works, Zygote's built-in forward mode would still require an @tangent definition instead of an @adjoint. So in that sense you have to pick your poison: @tangent, @define_diffrule or maybe even frule.
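For concreteness, the frule option might look like the following sketch, assuming ChainRulesCore and SpecialFunctions are available (untested; whether Zygote's broadcast path consults ChainRules here is exactly what this issue is about):

```julia
using ChainRulesCore, SpecialFunctions

# Forward-mode rule for logerfcx: frule receives the input tangents
# (self-tangent first, then ẋ) and returns (primal, output tangent).
function ChainRulesCore.frule((_, ẋ), ::typeof(SpecialFunctions.logerfcx), x::Real)
    Ω = SpecialFunctions.logerfcx(x)
    ∂ = 2x - 2 / (sqrt(π) * SpecialFunctions.erfcx(x))  # d/dx log(erfcx(x))
    return Ω, ∂ * ẋ
end
```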

@cossio
Contributor Author

cossio commented Dec 11, 2021

> Assuming it works, Zygote's built-in forward mode would still require an @tangent definition instead of an @adjoint. So in that sense you have to pick your poison: @tangent, @define_diffrule or maybe even frule.

Oh, can you point to an example where @tangent is explained?
I can't find anything in the docs https://fluxml.ai/Zygote.jl/stable/search/?q=tangent.

@mcabbott
Member

I think forward vs reverse is a distraction here, as for simple R^2 -> R functions they will hardly differ. The problem is that Zygote tends to break type inference, and doing that in a tight loop (like broadcasting) murders performance.

I have no idea whether Zygote's abandoned forward mode would do any better; it relies on the same IRTools tricks, but perhaps inference is easier. @cossio this was never documented and never really got beyond a proof of principle.

The present code calls Broadcast.combine_eltypes, which runs type inference on the primal. A safer way to guard the ForwardDiff fast path would be to similarly check that the dual function infers to return a Dual number. It will infer to Union{} if there are missing methods, as in this case, and then you could take the all-Zygote fallback path.
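The inference guard suggested above can be sketched in plain Julia, using a toy dual type as a stand-in for ForwardDiff.Dual (purely illustrative, not Zygote's actual code):

```julia
# Sketch of the proposed guard: Base.promote_op runs type inference on the
# call. If `f` has no method applicable to dual inputs, the return type
# infers to Union{}, signalling that the dual fast path should be skipped
# in favour of the all-Zygote fallback.
struct TinyDual <: Real
    val::Float64
    der::Float64
end
Base.:*(a::TinyDual, b::TinyDual) =
    TinyDual(a.val * b.val, a.der * b.val + a.val * b.der)

square(x) = x * x            # generic: works on TinyDual through `*`
floatonly(x::Float64) = x^2  # no TinyDual method: infers to Union{}

# Take the dual fast path only if the call infers to a non-Union{} type.
dual_path_ok(f) = Base.promote_op(f, TinyDual) !== Union{}
```

Here `dual_path_ok(square)` is `true`, while `dual_path_ok(floatonly)` is `false`, so a guard like this would route `floatonly` to the fallback path instead of erroring.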

@ToucheSir
Member

ToucheSir commented Dec 11, 2021

https://github.com/FluxML/Zygote.jl/blob/master/src/forward/lib.jl

It isn't explained anywhere that I know of. You could try to piece together the internals from #503, but based on the comments there I imagine you've read through it already. As Michael mentioned, development kind of fell off a cliff after the first PR and for whatever reason there hasn't been enough appetite to pick it back up.

Edit: #752 may be of interest.

cossio pushed a commit to cossio/RestrictedBoltzmannMachines.jl that referenced this issue Dec 11, 2021
We need
JuliaDiff/DiffRules.jl#74,
otherwise we hit
FluxML/Zygote.jl#1132.

This got merged in DiffRules 1.8.
So even though we don't import DiffRules directly, we add it as a dep
in Project.toml, just to add the compat bound.