
Reverse over ForwardDiff differentiation: Zygote vs. ReverseDiff #1218

Open
nmheim opened this issue May 6, 2022 · 5 comments
Labels
second order zygote over zygote, or otherwise

Comments

@nmheim

nmheim commented May 6, 2022

I would like to compute the gradient of a function that contains a ForwardDiff.gradient/ForwardDiff.derivative call.
Computing that higher-order gradient entirely with ForwardDiff works, but it will be slow in my case because it is a many-input function.
MWE below.

using ForwardDiff
using ReverseDiff
using Zygote

f(x::Real,y::AbstractVector) = x*sum(y)

df(x,y) = ForwardDiff.derivative(x->f(x,y), x)

# this returns (nothing,)
Zygote.gradient(y->df(0.1,y), rand(5))

# this works; returning vector of ones
ForwardDiff.gradient(y->df(0.1,y), rand(5))

# this works as well
ReverseDiff.gradient(y->df(0.1,y), rand(5))

It would be really great to get this to work, because my use case involves complex numbers and ReverseDiff does not seem to handle them well.

Tagging @mcabbott, as we discussed this on Slack. I would be happy to help in any way I can.

@nmheim
Author

nmheim commented May 6, 2022

As this was part of the discussion on Slack: I get the same behaviour when using ForwardDiff.gradient inside the differentiated function:

f(x::AbstractVector,y::AbstractVector) = sum(x)*sum(y)

df(x,y) = ForwardDiff.gradient(x->f(x,y), x)[1]

x = [0.1]
y = rand(5)

# this returns (nothing,)
Zygote.gradient(y->df(x,y), y)

# this works; returning vector of ones
ForwardDiff.gradient(y->df(x,y), y)

# this works as well
ReverseDiff.gradient(y->df(x,y), y)

@mcabbott
Member

mcabbott commented May 7, 2022

Thanks for making an issue.

The nothing comes from #968, which sends these calls to Zygote.forwarddiff(f, x) (forward over forward), which does not, and cannot, track variables closed over in its f. IMO the minimal solution here is to remove that and let some of these cases error again, as the silent nothing is too often surprising.
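
A minimal sketch of that behaviour, calling Zygote.forwarddiff directly (an editorial illustration, not code from the thread): the closed-over y never reaches the forward pass, so its gradient comes back as nothing.

h(x, y) = Zygote.forwarddiff(x -> sum(x) * sum(y), x)  # `y` is only closed over

Zygote.gradient(x -> h(x, [1.0, 2.0]), [3.0, 4.0])  # fine: `x` is the explicit argument
Zygote.gradient(y -> h([3.0, 4.0], y), [1.0, 2.0])  # (nothing,): `y` is invisible to the forward pass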

In #769 there was also a proposal to automatically turn Zygote over ForwardDiff around into ForwardDiff over Zygote. I'm not sure whether this can handle all cases, or can really be made to work.
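
For the MWE above, that reordering done by hand looks roughly like this (a sketch relying on symmetry of mixed partials, not the automatic mechanism proposed in #769):

f(x::Real, y::AbstractVector) = x * sum(y)

# forward (ForwardDiff, in x) on the outside, reverse (Zygote, in y) on the inside
d2f(x, y) = ForwardDiff.derivative(x -> Zygote.gradient(y -> f(x, y), y)[1], x)

d2f(0.1, rand(5))  # a vector of ones, matching the ForwardDiff and ReverseDiff results above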

Alternatively, maybe we can figure out what ReverseDiff is doing. The last example is the surprising one, since ForwardDiff.gradient inside makes an array for its output and then writes into it.

Xref #1189 about this same problem (and others).

@mcabbott changed the title from "Reverse over forward differentiation" to "Reverse over ForwardDiff differentiation: Zygote vs. ReverseDiff" on May 9, 2022
@mcabbott
Member

Ok I see. ReverseDiff is dropping down to scalar TrackedReal, which it can trace individually through mutation:

julia> ReverseDiff.gradient([1,2,3.0]) do x
         y = zeros(@show(eltype(x)), 3)
         y[1] = x[1] + x[2]^2
         sum(y)
       end
eltype(x) = ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}
3-element Vector{Float64}:
 1.0
 4.0
 0.0

Second derivatives:

julia> ForwardDiff.gradient([1, 2]) do x
         ForwardDiff.gradient(x) do y
           (@show(x[1]) * @show(y[2]))^2
         end |> sum
       end
x[1] = Dual{ForwardDiff.Tag{var"#85#87", Int64}}(1,1,0)
y[2] = Dual{ForwardDiff.Tag{var"#86#88"{Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#85#87", Int64}, Int64, 2}}}, ForwardDiff.Dual{ForwardDiff.Tag{var"#85#87", Int64}, Int64, 2}}}(Dual{ForwardDiff.Tag{var"#85#87", Int64}}(2,0,1),Dual{ForwardDiff.Tag{var"#85#87", Int64}}(0,0,0),Dual{ForwardDiff.Tag{var"#85#87", Int64}}(1,0,0))
2-element Vector{Int64}:
 8
 2

julia> ReverseDiff.gradient([1, 2]) do x
         ForwardDiff.gradient(x) do y
           (@show(x[1]) * @show(y[2]))^2
         end |> sum
       end
x[1] = TrackedReal<7VO>(1, 0, DVN, 1, 890)
y[2] = Dual{ForwardDiff.Tag{var"#90#92"{ReverseDiff.TrackedArray{Int64, Int64, 1, Vector{Int64}, Vector{Int64}}}, ReverseDiff.TrackedReal{Int64, Int64, ReverseDiff.TrackedArray{Int64, Int64, 1, Vector{Int64}, Vector{Int64}}}}}(TrackedReal<Le4>(2, 0, DVN, 2, 890),TrackedReal<7aD>(0, 0, ---, ---),TrackedReal<G74>(1, 0, ---, ---))
2-element Vector{Int64}:
 8
 2

julia> Zygote.gradient([1, 2]) do x
         ForwardDiff.gradient(x) do y
           (@show(x[1]) * @show(y[2]))^2
         end |> sum
       end
x[1] = 1
y[2] = Dual{ForwardDiff.Tag{var"#94#96"{Vector{Int64}}, ForwardDiff.Dual{Nothing, Int64, 2}}}(Dual{Nothing}(2,0,1),Dual{Nothing}(0,0,0),Dual{Nothing}(1,0,0))
([0.0, 2.0],)

@nmheim
Author

nmheim commented May 10, 2022

Is there any chance that this will make it into Zygote?

@ToucheSir
Member

Zygote uses a completely different mechanism for its AD, so I doubt ReverseDiff's approach would be directly applicable (if at all).
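
One illustrative workaround (a sketch, not something proposed in the thread): since Zygote's AD is rule-based, a hand-written ChainRulesCore.rrule for df lets Zygote skip over the ForwardDiff call entirely; the mixed second derivatives are computed with ForwardDiff inside the pullback.

using ChainRulesCore, ForwardDiff, Zygote

f(x::Real, y::AbstractVector) = x * sum(y)
df(x, y) = ForwardDiff.derivative(x -> f(x, y), x)

# Hypothetical hand-written rule: the pullback reuses ForwardDiff for the
# second derivatives ∂²f/∂x² and ∂²f/∂y∂x.
function ChainRulesCore.rrule(::typeof(df), x, y)
    function df_pullback(Δ)
        ∂x = Δ * ForwardDiff.derivative(x2 -> df(x2, y), x)
        ∂y = Δ .* ForwardDiff.gradient(y2 -> df(x, y2), y)
        return (NoTangent(), ∂x, ∂y)
    end
    return df(x, y), df_pullback
end

Zygote.gradient(y -> df(0.1, y), rand(5))  # now a vector of ones instead of (nothing,)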
