WIP: use ChainRules in Forward #752
This is an implementation in line with the way that the reverse-mode one is done. Once JuliaDiff/ChainRulesCore.jl#182 is complete, we could also just utilise that, and I think the implementation will probably be a bit simpler.
return :(__pushforward(dargs, f, args...))
end

# g = try _lookup_grad(T) catch e e end
I'm really not sure what I'm doing with the edges stuff here -- this is currently copied from the reverse-mode definition. @oxinabox or @DhairyaLGandhi could you comment on what sort of thing I'm going to need to do here? In particular, I'm not sure what `_lookup_grad` is doing, so I'm not sure how to reason about adapting the reverse-mode implementation.
test/forward/forward.jl
Outdated
# end
# end
# x
# end == 0

@test D(x -> abs(x+2im), 1) == gradient(x -> abs(x+2im), 1)[1]
This doesn't currently pass. I'm pretty sure that the `Forward` implementation is doing the correct thing, and `gradient` is wrong. `D` (correctly) returns a `Real` as the derivative w.r.t. `x`, while `gradient` returns a `Complex`. @sethaxen do you agree that Zygote is doing the wrong thing here?
ChainRulesCore has an issue open about this: JuliaDiff/ChainRulesCore.jl#176. It's also related to JuliaDiff/ChainRules.jl#232. The issue is that our rules for addition between real and complex (`+(x, 2im)`) will pull back a complex adjoint to `x`.
So we'd expect these to do the same thing, but
julia> gradient(x -> abs(x+2im), 1)[1]
0.4472135954999579 + 0.8944271909999159im
julia> gradient(x -> abs(complex(x, 2)), 1)[1]
0.4472135954999579
The first issue above links to some Zygote discussions. Essentially, Zygote treats all reals as actually complex, so from that perspective it is doing the right thing. I just don't think that's the right way to go.
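To make the two conventions concrete, here is a minimal sketch in plain Julia (no Zygote or ChainRules involved; the helper names `f`, `central_diff` and `complex_adjoint` are made up for illustration). For real `x`, `abs(x + 2im) == sqrt(x^2 + 4)`, so the derivative w.r.t. `x` is `x / sqrt(x^2 + 4)`. If the input is instead treated as complex, the adjoint of `abs` at `z` is `z / abs(z)`, and its real part is that same real derivative.

```julia
# Plain-Julia sketch; helper names are illustrative, not Zygote or ChainRules API.
f(x) = abs(x + 2im)

# Central finite difference along the real axis only.
central_diff(g, x; h=1e-6) = (g(x + h) - g(x - h)) / (2h)

# Adjoint of `abs` at a complex point z when the input is treated as complex:
# d|z|/d(re z) + im * d|z|/d(im z) = z / |z|.
complex_adjoint(z) = z / abs(z)

x = 1.0
real_derivative = x / sqrt(x^2 + 4)   # analytic derivative w.r.t. real x
fd = central_diff(f, x)               # numerical check of the same quantity
adj = complex_adjoint(x + 2im)        # what the "everything is complex" view gives

println(real_derivative)
println(fd)
println(adj)
println(real(adj))
```

Under these assumptions the real part of the complex adjoint agrees with the real derivative; the imaginary part is the extra component that shows up in the `x + 2im` result above.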
Ahh interesting. It's interesting that `Zygote.Forward` doesn't naturally do the same thing as `Zygote`.
I guess that the question is what the right way forward for us is with `Forward`. Mike previously had some extra code for `abs` to make it do the Zygote-y thing.
I wonder whether it's acceptable to utilise a different convention in `Forward` than what `Zygote` currently does...
@DhairyaLGandhi what are your thoughts on the way that Zygote currently handles complex numbers / have you had a chance to review the various ChainRules discussions about complex numbers?
Most recent commit fixed one bug, but re-introduced some tests that expose us to having to differentiate through
"""
is_kwfunc(::Vararg{Any}) = false
is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k===Core.kwftype(f)
This code already exists for reverse mode. Should we just import it from there?
Yes. I'll need to refactor to avoid all of this duplication. I just copied + pasted to start with because I was tired haha
end

"""
chain_frule(f, args...)
This doc is wrong but the code is right: `dargs` includes `dself, darg1, darg2, ...` and `args` includes `foo, arg1, arg2, ...`.
chain_frule_f = iskw ? :chain_frule_kw : :chain_frule
if hascr
    return :($chain_frule_f(dargs, f, args...))
Where have my kwargs gone?
Just not implemented yet, see check-boxes at the top :)
function has_chain_frule(T)
  m = meta(Tuple{typeof(frule), T.parameters...})

  if m.method !== chainrules_frule_fallback
I wonder if we should just abstract this to be `has_chainrule(kind, T)`:
- function has_chain_frule(T)
-   m = meta(Tuple{typeof(frule), T.parameters...})
-   if m.method !== chainrules_frule_fallback
+ function has_chain_rule(kind, T)
+   m = meta(Tuple{typeof(kind), T.parameters...})
+   if m.method !== chainrules_rule_fallback(kind)
foo(x) = x

# This intentionally has the wrong definition so that we can detect if Zygote is using it.
ChainRulesCore.frule((_, dx), ::typeof(foo), x) = x, 2 * dx
The reverse-mode ChainRules tests set a variable that can later be read to see whether the rule was hit.
This should be tested the same way.
Oh nice, I'll do that.
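A minimal sketch of that flag-based check, in plain Julia with no Zygote or ChainRulesCore dependency. The names `my_frule`, `foo_frule_hit` and `forward_derivative` are hypothetical stand-ins (the real test would overload `ChainRulesCore.frule` and go through Zygote's forward mode); the point is only that the rule can return the correct tangent and still be detected through a side effect.

```julia
# Stand-in for an frule-style lookup: the generic fallback returns `nothing`,
# meaning "no rule defined". (Hypothetical name, illustration only.)
my_frule(dargs, f, args...) = nothing

foo(x) = x

# Flag that the rule sets when it is called; the test reads it afterwards.
foo_frule_hit = Ref(false)

function my_frule((_, dx), ::typeof(foo), x)
    foo_frule_hit[] = true
    return foo(x), dx   # correct tangent, no deliberately wrong factor needed
end

# Stand-in for a forward-mode driver that uses a custom rule when one exists.
function forward_derivative(f, x)
    res = my_frule((nothing, one(x)), f, x)
    res === nothing && error("no rule for $f in this sketch")
    return res[2]
end

foo_frule_hit[] = false            # reset before the check
d = forward_derivative(foo, 2.0)
println(d)
println(foo_frule_hit[])
```

Compared with a deliberately wrong `2 * dx`, this keeps the rule mathematically correct, so the same definition cannot silently skew other tests.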
@@ -76,6 +76,10 @@ function global_set(ref, val)
    ref.mod, ref.name, val)
end

function ChainRules.frule(dargs, ::typeof(Zygote.global_set), ref, val::Nothing)
  return Zygote.global_set(ref, val), nothing
- return Zygote.global_set(ref, val), nothing
+ return Zygote.global_set(ref, val), DoesNotExist()
Returns a the (primal) value of `f(args...)` and tangent, by invoking
`ChainRules.frule(f, args...)`.
"""
@noinline chain_frule(dargs, args...) = frule(dargs, args...)
We need a translation layer replacing `Composite`s with `NamedTuple`s and `AbstractZero`s with `nothing`, and vice versa, like we have for reverse mode.
We can probably just abstract that and call it.
Yup, on the TODO list :)
Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
Thanks for the review @oxinabox. The main thing I'm still unsure about is what
AFAICT
Closing because I've definitely not got time to work on this at the minute, and don't anticipate doing so for a while. If someone else wants to have a go at this, they should feel free to do so.
Addresses #712.
TODO:
A separate PR will be needed to move rules that are currently in `Zygote.Forward` to `ChainRules`. Unlike the reverse-mode rules there aren't a particularly large number of them, so it shouldn't be a massive job.