Gradient Interface Design #628
With JAX/stax, there's no straightforward way to tie weights right now short of writing your own
Ok, I think I finally have this figured out. The trick is that you need the gradient of the module as if it were any other struct. Then global variables – fields of the module – can be treated just like any other parameter / struct field. This means that we can go full on with closures being layers, and vice versa. This doesn't directly address James' concern, except in that the "custom layer" is now a one-line lambda and everything just works. Might be worth having some challenge cases, but I think it will look pretty good.
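A minimal sketch of the "closures are layers" idea as it already works with Zygote's structural gradients, assuming the closure is built inside a function so its parameters are captured fields (the names `make_layer`, `W`, `b` are made up for illustration):

```julia
using Zygote

function make_layer(n_in, n_out)
    W, b = randn(n_out, n_in), randn(n_out)
    x -> tanh.(W * x .+ b)      # W and b are stored as fields of the closure
end

layer = make_layer(10, 5)
x = randn(10)

# The structural gradient of the closure is a named tuple with one entry
# per captured field, e.g. l̄.W and l̄.b here.
l̄ = gradient(l -> sum(l(x)), layer)[1]
```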
Would this allow for model/layer inheritance?
In what sense? This does make it ridiculously easy to extend layer behaviour, e.g.

```julia
l1 = Dense(10, 5)
l2 = x -> relu.(l1(x))
```
That does look really nice. I'm thinking something more along the lines of building a library like this: https://github.com/rusty1s/pytorch_geometric. The code example there shows an edge convolution layer inheriting from a message passing layer. This could be done by composition, but sometimes inheritance (or traits) works better. Another example would be having a user specialize an "abstract transformer" type. In the general sense it would be differentiating (differentiable) programs written with full use of the type system.
Yeah, we are of course limited by what Julia can express here, so still can't do Python-style inheritance exactly. But it'd be easy to do this via composition (which can mean a library layer wrapping a custom one, as well as the other way around), e.g.

```julia
mlp = Chain(Dense(...), ...)
model = MessagePassing() do x
  mlp(x) |> softmax
end
```

(I'm just making something up here, because I don't actually know what MessagePassing needs – you'd probably need to pass two functions in – but hopefully that gives a flavour).
That makes sense. What about making MessagePassing an abstract type? I'm not sure if that would make sense in this instance, but let's say generally you have functions transform1 and transform2, where transform1 calls transform2 and both take an abstractlayer1. Then you have a user subtype abstractlayer1 with layer1, include some learnable parameters, and maybe overload one or more of the functions in the lattice.
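A hypothetical sketch of the pattern being described, using the comment's made-up names (AbstractLayer1, transform1, transform2, Layer1):

```julia
# Library side: an abstract layer type plus a small lattice of functions on it.
abstract type AbstractLayer1 end

transform1(l::AbstractLayer1, x) = transform2(l, x) .+ 1   # calls the overloadable hook
transform2(l::AbstractLayer1, x) = x                        # default behaviour

# User side: a concrete subtype with learnable parameters,
# overloading one function in the lattice.
struct Layer1 <: AbstractLayer1
    W::Matrix{Float64}
end

transform2(l::Layer1, x) = l.W * x
```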
I think the important point from an AD perspective is that if you can write it down, it will work :) Feel free to open a new issue to discuss specifics of layer design. I'm not personally a big fan of using abstract types to simulate inheritance, but I'm happy to discuss that / how it interacts with AD anyway.
A question about this proposal in terms of functional AD: how does it compare with a functional `grad(f)` operator that returns the gradient function itself? That form is nice because I can define a function and then hand it to `grad`. In the proposal, how would this functional gradient look?
Yeah, both ways of writing this are equivalent (and identical in terms of performance etc.). You can implement

```julia
grad(f) = x -> gradient(f, x)[1]
```

and, conversely,

```julia
gradient(f, x) = grad(f)(x)
```

(modulo some tweaks). I also prefer not to provide curried versions of functions where possible, but that's really more of a style thing. If we can come up with a good name for it I wouldn't oppose having a curried version as well.
I've been playing with this interface recently with my own stuff, using `gradient(() -> m(input), m)` where `m` is a model. So I'm wondering if we could make Zygote return a typed gradient that we can dispatch on, e.g.

```julia
foo(::Grad{<:Dense}) = nothing      # blabla
foo(x::Grad{<:Chain}) = foreach(foo, x.layers)
```
669: using Zygote r=MikeInnes a=MikeInnes

Otherwise known as "break all the things". This will be a huge change so I'm beginning to prepare now, even though Zygote is still a couple of months off from being really ready. **Do not try this at home** (yet) – this branch is eventually aimed at beta testers, but isn't even ready for that yet.

The idea is to break as little code as possible, which means supporting the current `Params` API; but I also want to start prototyping the nicer things discussed in #628 and other issues.

Blocking issues:

* [x] Get the tests passing.
* [x] Check tests on GPU.
* [x] Rewrite all the docs.
* [x] Cache invalidation (JuliaLabs/Cassette.jl#6).
* [x] Moving over adjoints (FluxML/Zygote.jl#81).
* [x] General Zygote robustness.

Nice to have:

* [ ] Robust nested AD (may not be a blocker if one can still use Tracker with Flux).
* [x] Zygote support for modules / globals as discussed in #628, along with #637.
* [x] Better train/test mode as in #643.

If you're the kind of person who ignores triangular road signs, you can try this with

```julia
]add Flux#zygote Zygote#master
```

Co-authored-by: Mike J Innes <mike.j.innes@gmail.com>
Co-authored-by: Elliot Saba <staticfloat@gmail.com>
Co-authored-by: thebhatman <manjunathbhat9920@gmail.com>
Does anyone know if this hit a dead end or was ruled to be too complex to implement? I think the "closures are layers" aspect would help immensely for FluxML/FluxML-Community-Call-Minutes#10, while the "structural approach" would make
Not a dead end – it's my preference to use this interface more and make it stable. We have https://github.com/FluxML/XLA.jl/pull/5/files#diff-2da3a01fb49af8d3ca12681d630be0e89f22536d6cb8322f6f6d239699bfd28f and FluxML/Optimisers.jl#3, which just need the rules to be ported over now. We can move ahead with this fairly swiftly.
@MikeInnes Thanks for all your work! In JAX one can compute gradients with respect to nested dictionaries; a simple example is in the README here: https://github.com/anhinga/jax-pytree-example. I wonder how difficult it would be to do something similar in Zygote.jl.
Have you tried doing this? All the pieces appear to be in place, and if you run into issues please file them at Zygote. That said, since your JAX example only uses string keys, it would be far more efficient to use namedtuples for the same purpose in Julia.
Yeah it works well already, e.g.

```julia
julia> gradient(x -> x["foo"]["bar"]^2, Dict("foo" => Dict("bar" => 5)))
(Dict("foo" => Dict("bar" => 10)),)
```
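For comparison, a sketch of the namedtuple version suggested above (the expected output is shown as a comment, assuming Zygote returns nested namedtuples here):

```julia
using Zygote

# Nested NamedTuples avoid the string-keyed Dict overhead entirely.
gradient(x -> x.foo.bar^2, (foo = (bar = 5,),))
# expected: ((foo = (bar = 10,),),)
```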
@ToucheSir @MikeInnes Thanks! (My mistake was trying to coerce a dictionary into Params.)
We expose a pullback/vjp-based API for gradients (`y, back = forward(f, x); x̄ = back(ȳ)`), with `x̄ = gradient(f, x)` as simple syntactic sugar on top of this. This interface is pretty awesome – `gradient` aligns nicely with the mathematical and intuitive notions of a derivative operator, it naturally expresses nested derivatives, and you can build pretty much any other AD-related functionality (checkpointing, forward mode, gradient hooks, etc.) on top of pullbacks, without having to go into AD internals. So far I haven't come across anything that pullbacks can't do straightforwardly; in one case the PyTorch-style `back!` may be slightly more convenient, but it's overall more cumbersome and requires more knowledge of internals.

However, a challenge of the "mathematical" gradient operator is that it's cumbersome to pass in all our parameter arrays explicitly (`gradient(resnet, W1, b1, W2, b2, ...)`). So we need to be able to handle taking gradients of large models without it being cumbersome.

There are currently two ideas about how to do this: the structural approach and the implicit approach.

Edit: Since writing this I have convinced myself that we can get the convenience of implicit params by slightly generalising the structural approach. I think this gives us a clear path forward, though unfortunately it does mean additional API churn.
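A minimal sketch of how the sugar could be defined on top of the pullback API, using the `forward` name from this issue (newer Zygote releases call it `pullback`); `my_gradient` is a hypothetical name:

```julia
using Zygote: forward   # `pullback` in later Zygote releases

# Run the forward pass to get the result and its pullback, then seed with 1.
function my_gradient(f, x)
    y, back = forward(f, x)
    back(one(y))[1]        # x̄: the gradient with respect to x
end

my_gradient(x -> 3x^2, 2.0)   # 12.0
```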
Structural Gradients
The first approach (which Zygote will support whatever happens, and could be added to Flux) is to take the gradients w.r.t. some structure containing all of the parameters. The structure could be a dictionary or list, but it's usually convenient to combine the weight structure with the definition of the forward pass. This is effectively a closure, which we refer to as a "layer". Layers can contain other layers (we often call a compound layer a "model", but there's no fundamental difference). Taking a gradient looks like this:
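A minimal sketch, assuming Flux's `Chain`/`Dense` layers and a `loss`, `x`, and `y` defined elsewhere:

```julia
using Flux   # Chain, Dense, relu; Flux re-exports Zygote's gradient

m = Chain(Dense(10, 5, relu), Dense(5, 2))

# The gradient m̄ mirrors the structure of the model itself: a Chain-shaped
# object whose fields hold the gradients of each layer's parameters.
m̄ = gradient(m -> loss(m(x), y), m)[1]
```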
This looks pretty weird at first but makes a lot of sense once it clicks. One then carries out the update step `m .+= m̄`.

Implicit Gradients
The implicit approach is what Flux supports natively, though it works in Zygote as well. In this case we ask for gradients of a shapeless set of parameters, which are implicitly used at some point during the forward pass. In this case we have something more like:
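A minimal sketch in the same spirit, assuming the `m`, `loss`, `x`, and `y` from above and Flux's `params`:

```julia
using Flux   # params; Flux re-exports Zygote's gradient

θ = params(m)                          # shapeless set of all parameter arrays in m
θ̄ = gradient(() -> loss(m(x), y), θ)   # gradients keyed by the parameters themselves
```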
`θ̄` is a dictionary from param to gradient. One then loops over the parameters, doing `p .+= θ̄[p]`.

Do we need implicit parameters?
Implicit parameters have some downsides. They feel somewhat less clean and functional than structural ones. They do not support scalars or immutable arrays well, which are needed for more restrictive backends like TPUs; and supporting both means having more than one way to do things.
However, implicit parameters have a huge advantage: they make it easy to write "script-like" models. I see them as being a lot like global variables: sure they're a bit unclean, but sometimes it's just convenient to build up a model gradually in a notebook, without a lot of structure (and if I have one non-negotiable rule of API design, it's that you should never have to define a struct/class or satisfy an interface to use a library). Our VAE model is a nice example of this style which I think would be made significantly more cumbersome otherwise.
A potential solution is to make it easier to define "anonymous layers". In my ideal world these would also just be closures, but unfortunately this isn't workable (see discussion below) – closures don't explicitly store parameters when they are closed over from global scope, making them invisible to structural AD. Functions that return closures would be completely fine, but the distinction is too subtle / tied to implementation details.

Other concerns
A couple of other small subtleties.
In the implicit style parameter identity matters, which means we can reuse parameters when creating a model. For example:
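A minimal sketch of the kind of reuse meant here, assuming Flux's `Dense` and `Chain`:

```julia
using Flux

d = Dense(10, 10)
m = Chain(d, d)   # the same layer object, and hence the same weights, used twice
```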
In the implicit parameter style the weights of `d` are shared, but in the structural version we get two separate gradients for `d` at each point in the chain, and we'd have to construct the chain inside the forward pass to get the gradient we want. Similar issues come up in nested AD. I don't think either behaviour is more correct or better – both are well-defined and predictable – but they are different.

[This is further complicated by the fact that in-place updates mean the weight is effectively shared even in the structural case, just with weird semantics in optimiser state.]
This has some relevance to RNNs, which currently weight-tie the initial and current state fields. I don't think this needs to result in user-facing changes though. It's also possible to make a nice pure-functional RNN interface (see e.g. JAX); you just can't abstract over RNNs quite like we currently do, and state needs to be a bit more explicitly managed (which isn't necessarily a deal breaker, but worth considering).
[To be clear though, while the structural approach is more functional, it does not force differentiated programs to be functional themselves, so something very like our current RNN design is still possible.]