WIP: use ChainRules in Forward #752

Closed

@willtebbutt willtebbutt commented Aug 1, 2020

Addresses #712.

TODO:

  • Get nested AD working
  • Ensure that kwargs work correctly
  • "Unwrap" ChainRules types into Zygote type for compatibility
  • Remove all excess rules
  • Tidy up

A separate PR will be needed to move rules that are currently in Zygote.Forward to ChainRules. Unlike the reverse-mode rules there aren't a particularly large number of them, so it shouldn't be a massive job.

Member Author

@willtebbutt willtebbutt left a comment

This is an implementation in line with the way that the reverse-mode one is done. Once JuliaDiff/ChainRulesCore.jl#182 is complete, we could also just utilise that, and I think the implementation will probably be a bit simpler.

return :(__pushforward(dargs, f, args...))
end

# g = try _lookup_grad(T) catch e e end
Member Author

I'm really not sure what I'm doing with the edges stuff here -- this is currently copied from the reverse-mode definition. @oxinabox or @DhairyaLGandhi could you comment on what sort of thing I'm going to need to do here? In particular, I'm not sure what _lookup_grad is doing, so I'm not sure how to reason about adapting the reverse-mode implementation.

# end
# end
# x
# end == 0

@test D(x -> abs(x+2im), 1) == gradient(x -> abs(x+2im), 1)[1]
Member Author

This doesn't currently pass. I'm pretty sure that the Forward implementation is doing the correct thing, and gradient is wrong. D (correctly) returns a Real as the derivative w.r.t x, while gradient returns a Complex. @sethaxen do you agree that Zygote is doing the wrong thing here?

Contributor

ChainRulesCore has an issue open about this: JuliaDiff/ChainRulesCore.jl#176. It's also related to JuliaDiff/ChainRules.jl#232. The issue is that our rules for addition between real and complex (+(x, 2im)) will pull back a complex adjoint to x.
So we'd expect these two calls to give the same result, but they don't:

julia> gradient(x -> abs(x+2im), 1)[1]
0.4472135954999579 + 0.8944271909999159im

julia> gradient(x -> abs(complex(x, 2)), 1)[1]
0.4472135954999579

The first issue above links to some Zygote discussions. Essentially, Zygote treats all reals as actually complex, so from that perspective it is doing the right thing. I just don't think that's the right way to go.
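The discrepancy can be reproduced by hand without any AD package. The following is a minimal sketch in plain Base Julia; the names `complex_adjoint`, `real_derivative` and `fd` are made up for illustration. It assumes that the complex value printed above is z / |z| for z = x + 2im, which matches the numbers shown, and that the real-valued answer is its real part.

```julia
# Plain Base Julia sketch; no Zygote or ChainRules involved.
f(x) = abs(x + 2im)

z = 1 + 2im

# Matches the complex value printed by the first `gradient` call above.
complex_adjoint = z / abs(z)

# Projecting onto the real input recovers the second printed value.
real_derivative = real(complex_adjoint)

# Central finite difference along the real axis, as an independent check.
h = 1e-6
fd = (f(1 + h) - f(1 - h)) / (2h)
```

Under these assumptions, the finite difference agrees with the real part only, which is what a derivative with respect to a real `x` would be expected to return.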

Member Author

Ahh interesting. It's interesting that Zygote.Forward doesn't naturally do the same thing as Zygote.

I guess that the question is what the right way forward for us is with Forward. Mike previously had some extra code for abs to make it do the Zygote-y thing.

I wonder whether it's acceptable to utilise a different convention in Forward than what Zygote currently does...

@DhairyaLGandhi what are your thoughts on the way that Zygote currently handles complex numbers / have you had a chance to review the various ChainRules discussions about complex numbers?

@willtebbutt
Member Author

Most recent commit fixed one bug, but re-introduced some tests that expose us to having to differentiate through broadcast, which isn't something that Zygote.Forward currently handles. I'll need to give that some thought...


"""
is_kwfunc(::Vararg{Any}) = false
is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k===Core.kwftype(f)
Member

This code already exists for reverse mode; should we just import it from there?

Member Author

Yes. I'll need to refactor to avoid all of this duplication. I just copied + pasted to start with because I was tired haha

end

"""
chain_frule(f, args...)
Member

This doc is wrong, but the code is right.
dargs includes dself, darg1, darg2, ... and args includes foo, arg1, arg2, ...
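To make the argument layout concrete, here is a small package-free sketch. `my_frule` and `mul` are hypothetical stand-ins, not the ChainRules API; the sketch only mirrors the convention described in the comment, with the tangents first (including one for the function itself), then the function and its primal arguments.

```julia
# Hypothetical stand-in for an frule-style function (not the real ChainRules API).
# Layout: my_frule((dself, darg1, darg2), f, arg1, arg2) -> (primal, tangent)
mul(a, b) = a * b

function my_frule((dself, da, db), ::typeof(mul), a, b)
    y = mul(a, b)
    dy = da * b + a * db   # product rule; dself is unused because mul has no fields
    return y, dy
end

# Differentiate with respect to the first argument only:
# seed da = 1, db = 0, and use `nothing` as the tangent of the function itself.
y, dy = my_frule((nothing, 1.0, 0.0), mul, 3.0, 4.0)
```

Note that the tangent tuple is one element longer than the list of primal arguments, because it also carries the tangent of the function.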


chain_frule_f = iskw ? :chain_frule_kw : :chain_frule
if hascr
return :($chain_frule_f(dargs, f, args...))
Member

where are my kwargs gone?

Member Author

Just not implemented yet, see check-boxes at the top :)

Comment on lines +84 to +87
function has_chain_frule(T)
m = meta(Tuple{typeof(frule), T.parameters...})

if m.method !== chainrules_frule_fallback
Member

I wonder if we should just abstract this to be has_chainrule(kind, T)

Suggested change
-function has_chain_frule(T)
-m = meta(Tuple{typeof(frule), T.parameters...})
-if m.method !== chainrules_frule_fallback
+function has_chain_rule(kind, T)
+m = meta(Tuple{typeof(kind), T.parameters...})
+if m.method !== chainrules_rule_fallback(kind)
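The idea behind this suggestion, checking whether dispatch lands on the generic fallback or on a specific rule, parameterised by the kind of rule, can be sketched without Zygote's internal `meta`. The version below swaps in Base's `which` for the lookup, and every name in it (`my_frule`, `my_rrule`, `fallback_method`, `has_rule`) is hypothetical; it is an illustration of the pattern, not the PR's implementation.

```julia
# Hypothetical stand-ins for frule/rrule; the generic methods mean "no rule defined".
my_frule(dargs, f, args...) = nothing
my_rrule(f, args...) = nothing

# One specific forward rule, for sin only.
my_frule((_, dx), ::typeof(sin), x::Real) = (sin(x), cos(x) * dx)

# Find each kind's fallback method using a signature that only the fallback can match.
fallback_method(::typeof(my_frule)) = which(my_frule, Tuple{Nothing, Nothing})
fallback_method(::typeof(my_rrule)) = which(my_rrule, Tuple{Nothing})

# T is the tuple type of the arguments that would be passed to `kind`.
has_rule(kind, T) = which(kind, T) !== fallback_method(kind)

sin_sig = Tuple{Tuple{Nothing, Float64}, typeof(sin), Float64}
cos_sig = Tuple{Tuple{Nothing, Float64}, typeof(cos), Float64}
```

With these definitions, `has_rule(my_frule, sin_sig)` should be true while `has_rule(my_frule, cos_sig)` and `has_rule(my_rrule, Tuple{typeof(sin), Float64})` should be false, so a single helper covers both kinds of rule.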

foo(x) = x

# This intentionally has the wrong definition so that we can detect if Zygote is using it.
ChainRulesCore.frule((_, dx), ::typeof(foo), x) = x, 2 * dx
Member

The tests in the reverse-mode ChainRules test set assign a variable that can later be read to see whether the rule was hit.
This should be tested the same way.
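A rough sketch of that flag-based testing pattern, using only Base Julia: `my_frule` and `rule_was_hit` are hypothetical names, and in the real test the call would go through the AD entry point rather than calling the rule directly.

```julia
foo(x) = x

# Flag that the rule flips when it runs, so a test can observe that it was used.
const rule_was_hit = Ref(false)

# Hypothetical stand-in for an frule; it has the correct definition,
# since detection no longer depends on returning a deliberately wrong tangent.
function my_frule((_, dx), ::typeof(foo), x)
    rule_was_hit[] = true
    return foo(x), dx
end

rule_was_hit[] = false                      # reset before the call under test
y, dy = my_frule((nothing, 1.0), foo, 2.0)  # a real test would call the AD entry point here
```

Compared with an intentionally wrong rule, this keeps the rule itself correct and makes the assertion about which code path ran explicit.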

Member Author

Oh nice, I'll do that.

@@ -76,6 +76,10 @@ function global_set(ref, val)
ref.mod, ref.name, val)
end

function ChainRules.frule(dargs, ::typeof(Zygote.global_set), ref, val::Nothing)
return Zygote.global_set(ref, val), nothing
Member

Suggested change
-return Zygote.global_set(ref, val), nothing
+return Zygote.global_set(ref, val), DoesNotExist()

Returns the (primal) value of `f(args...)` and tangent, by invoking
`ChainRules.frule(f, args...)`.
"""
@noinline chain_frule(dargs, args...) = frule(dargs, args...)
Member

We need a translation layer replacing Composites with NamedTuples and AbstractZeros with nothing,
and vice versa.

Like we have for reverse mode.
Can probably just abstract and call that
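The shape of such a translation layer might look like the sketch below. `MyZero` and `MyComposite` are hypothetical stand-ins for the ChainRules zero and composite types, and `to_zygote` / `to_cr` are made-up names; the existing reverse-mode helpers are not reproduced here.

```julia
# Hypothetical stand-ins for ChainRules' differential types.
struct MyZero end                        # plays the role of an AbstractZero
struct MyComposite{T<:NamedTuple}        # plays the role of a Composite
    fields::T
end

# ChainRules-style -> Zygote-style: zeros become `nothing`, composites become NamedTuples.
to_zygote(::MyZero) = nothing
to_zygote(c::MyComposite) = map(to_zygote, c.fields)
to_zygote(x) = x

# Zygote-style -> ChainRules-style: the reverse mapping.
to_cr(::Nothing) = MyZero()
to_cr(nt::NamedTuple) = MyComposite(map(to_cr, nt))
to_cr(x) = x

zy = to_zygote(MyComposite((a = 1.0, b = MyZero())))
cr = to_cr((a = 1.0, b = nothing))
```

Both directions recurse through nested fields, so the same pair of functions could in principle be shared by forward and reverse mode.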

Member Author

Yup, on the TODO list :)

Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
@willtebbutt
Member Author

Thanks for the review @oxinabox . The main thing I'm still unsure about is what _lookup_grad was doing in the reverse mode implementation. (See my first comment)

@oxinabox
Member

oxinabox commented Aug 5, 2020

AFAICT, _lookup_grad in reverse mode, despite the name, does not look up the gradient;
it performs the core of the source-code transform that does AD. It is defined in emit.jl.
Roughly speaking, the matching function in forward mode is dual, which is also terribly named.

@willtebbutt
Member Author

Closing because I've definitely not got time to work on this at the minute, and don't anticipate doing so for a while. If someone else wants to have a go at this, they should feel free to do so.

@willtebbutt willtebbutt deleted the wct/forward-cr branch January 16, 2021 12:44