Gradient Interface Design #628

Open
MikeInnes opened this issue Feb 15, 2019 · 17 comments

@MikeInnes
Member

MikeInnes commented Feb 15, 2019

We expose a pullback/vjp-based API for gradients (y, back = forward(f, x); x̄ = back(ȳ)), with x̄ = gradient(f, x) as simple syntax sugar on top of this. This interface is pretty awesome: gradient aligns nicely with the mathematical and intuitive notions of a derivative operator, it naturally expresses nested derivatives, and you can build pretty much any other AD-related functionality (checkpointing, forward mode, gradient hooks, etc.) on top of pullbacks, without having to go into AD internals. So far I haven't come across anything that pullbacks can't do straightforwardly; in one case the PyTorch-style back! may be slightly more convenient, but it's more cumbersome overall and requires more knowledge of internals.
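A minimal sketch of the pullback pattern, hand-written for one fixed function f(x) = sum(abs2, x) so that it runs without Zygote. The names `forward_sumsq` and `gradient_via` are hypothetical stand-ins for `forward(f, x)` and `gradient`; in Zygote the pullback is derived automatically (later releases name `forward` as `pullback`).

```julia
# Hand-written stand-in for `forward(f, x)` with f(x) = sum(abs2, x).
# Returns the primal output y together with a pullback closure.
function forward_sumsq(x::AbstractVector{<:Real})
    y = sum(abs2, x)
    # Pullback (vector-Jacobian product): maps an output sensitivity ȳ to
    # the input sensitivity ȳ * ∂y/∂x = ȳ * 2x, returned as a tuple with
    # one entry per argument.
    back(ȳ) = (2 .* x .* ȳ,)
    return y, back
end

# `gradient` as sugar over the pullback: run the forward pass, then seed
# the pullback with one(y), since dy/dy = 1 for a scalar output.
function gradient_via(forward, x)
    y, back = forward(x)
    return back(one(y))
end

x = [1.0, 2.0, 3.0]
y, back = forward_sumsq(x)
x̄ = back(1.0)                        # ([2.0, 4.0, 6.0],)
g = gradient_via(forward_sumsq, x)   # same result through the sugar
```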

However, a challenge of the "mathematical" gradient operator is that it's cumbersome to pass in all our parameter arrays explicitly (gradient(resnet, W1, b1, W2, b2, ...)). So we need a way to take gradients of large models without that overhead. There are currently two ideas about how to do this: the structural approach and the implicit approach.

Edit: Since writing this I have convinced myself that we can get the convenience of implicit params by slightly generalising the structural approach. I think this gives us a clear path forward, though unfortunately it does mean additional API churn.

Structural Gradients

The first approach (which Zygote will support whatever happens, and could be added to Flux) is to take the gradients w.r.t. some structure containing all of the parameters. The structure could be a dictionary or list, but it's usually convenient to combine the weight structure with the definition of the forward pass. This is effectively a closure, which we refer to as a "layer". Layers can contain other layers (we often call a compound layer a "model", but there's no fundamental difference). Taking a gradient looks like this:

m = Chain(Dense(10, 5, relu), Dense(5, 2), softmax)
x, y = ...
m̄ = gradient(m -> loss(m(x), y), m)

This looks pretty weird at first but makes a lot of sense once it clicks. One then carries out the update step m .+= m̄.
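To show what "the gradient has the same shape as the layer" means, here is a stdlib-only sketch with a hypothetical `Affine` struct standing in for `Dense` and a hand-derived gradient in place of AD. It takes a plain descent step (subtracting η times the gradient); the sign and step size are assumptions of the sketch.

```julia
# Hypothetical stand-in for a layer such as Dense: a callable struct that
# holds its parameters, so weights and forward pass live together.
struct Affine
    W::Matrix{Float64}
    b::Vector{Float64}
end
(a::Affine)(x) = a.W * x .+ a.b

# Hand-derived structural gradient of loss(m) = sum(abs2, m(x) .- y).
# The result mirrors the layer: a NamedTuple with one entry per field.
function structural_gradient(m::Affine, x, y)
    r = m(x) .- y                      # residual
    return (W = 2 .* r * x', b = 2 .* r)
end

m = Affine([1.0 0.0; 0.0 1.0], [0.0, 0.0])
x, y = [1.0, 2.0], [0.0, 0.0]
dm = structural_gradient(m, x, y)      # (W = [2.0 4.0; 4.0 8.0], b = [2.0, 4.0])

# Update step, field by field (gradient descent with step size η).
η = 0.25
m.W .-= η .* dm.W
m.b .-= η .* dm.b
```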

Implicit Gradients

The implicit approach is what Flux supports natively, though it works in Zygote as well. In this case we ask for gradients of a shapeless set of parameters, which are implicitly used at some point during the forward pass. In this case we have something more like:

m = Chain(Dense(10, 5, relu), Dense(5, 2), softmax)
x, y = ...
θ = params(m)
θ̄ = gradient(() -> loss(m(x), y), θ)

θ̄ is a dictionary from param to gradient. One then loops over the parameters, doing p .+= θ̄[p].
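A comparable sketch of the implicit style, again hand-derived rather than produced by AD. The global arrays `W` and `b`, the vector standing in for `params(m)`, and `implicit_gradient` are all hypothetical; the point is that the gradient store is an `IdDict` keyed by array identity, and the update loop mutates each array in place.

```julia
# Global arrays used implicitly by the forward pass (hypothetical model).
W = [1.0 0.0; 0.0 1.0]
b = [0.0, 0.0]
model(x) = W * x .+ b

θ = [W, b]   # stand-in for params(m): references to the arrays themselves

# Hand-derived gradient of loss = sum(abs2, model(x) .- y), stored in an
# IdDict keyed by each parameter array's identity (===), not its value.
function implicit_gradient(x, y)
    r = model(x) .- y
    g = IdDict{Any,Any}()
    g[W] = 2 .* r * x'
    g[b] = 2 .* r
    return g
end

x, y = [1.0, 2.0], [0.0, 0.0]
grads = implicit_gradient(x, y)

# Loop over the parameters, updating each array in place (descent step).
η = 0.25
for p in θ
    p .-= η .* grads[p]
end
```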

Do we need implicit parameters?

Implicit parameters have some downsides. They feel somewhat less clean and functional than structural ones. They also do not support scalars or immutable arrays well, which are needed for more restrictive backends like TPUs; supporting both styles means having more than one way to do things.
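One reason scalars fit poorly is that identity-keyed storage, which the implicit style relies on, cannot tell apart two immutable values that happen to be equal. A small illustration using only Base's `IdDict` (this is not Zygote code):

```julia
# Two distinct scalar "parameters" that happen to have the same value.
a = 3.0
c = 3.0

# Identity-keyed gradient storage, as used for implicit parameters.
g = IdDict{Any,Float64}()
g[a] = 1.0
g[c] = 2.0      # overwrites the entry for `a`, because 3.0 === 3.0
length(g)       # 1: immutable values have no identity beyond their value

# Mutable arrays do have identity, so equal-valued arrays stay distinct.
h = IdDict{Any,Float64}()
h[[3.0]] = 1.0
h[[3.0]] = 2.0
length(h)       # 2
```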

However, implicit parameters have a huge advantage: they make it easy to write "script-like" models. I see them as being a lot like global variables: sure they're a bit unclean, but sometimes it's just convenient to build up a model gradually in a notebook, without a lot of structure (and if I have one non-negotiable rule of API design, it's that you should never have to define a struct/class or satisfy an interface to use a library). Our VAE model is a nice example of this style, which I think would become significantly more cumbersome otherwise.

A potential solution is to make it easier to define "anonymous layers". In my ideal world these would also just be closures, but unfortunately this isn't workable (see discussion below) – closures don't explicitly store parameters when they are closed over from global scope, making them invisible to structural AD. Functions that return closures would be completely fine, but the distinction is too subtle / tied to implementation details.
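The closure distinction can be seen directly in Julia: a closure created inside a function stores its captured variables as fields, while an anonymous function at global scope only refers to the global binding. A small sketch (`make_layer` is a hypothetical name):

```julia
# A closure returned from a function stores what it captures as fields...
make_layer(W) = x -> W * x
f = make_layer([1.0 2.0])

# ...whereas an anonymous function at global scope refers to the global
# binding and stores nothing, so its parameters are not reachable from it.
W = [1.0 2.0]
g = x -> W * x

fieldnames(typeof(f))   # (:W,)
fieldnames(typeof(g))   # ()
getfield(f, :W)         # the captured matrix
```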

Other concerns

A couple of other small subtleties.

In the implicit style, parameter identity matters, which means we can reuse parameters when creating a model. For example:

d = Dense(10, 10, relu)
m = Chain(d, d, softmax)
dm = gradient(...)

In the implicit parameter style the weights of d are shared, but in the structural version we get two separate gradients for d, one for each position in the chain, and we'd have to construct the chain inside the forward pass to get the gradient we want. Similar issues come up in nested AD. I don't think either behaviour is more correct or better (both are well-defined and predictable), but they are different.

[This is further complicated by the fact that in-place updates mean the weight is effectively shared even in the structural case, just with weird semantics in optimiser state.]
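A hand-derived sketch of the two behaviours for the `Chain(d, d, ...)` case, using a hypothetical elementwise layer d(x) = w .* x in place of `Dense`. In the structural view the tied weight gets one gradient per position; accumulating by identity, as the implicit style does, sums them into a single entry.

```julia
# d∘d with d(x) = w .* x gives loss(w1, w2) = sum(w2 .* (w1 .* x)),
# where w1 and w2 are the weights at the two positions in the chain.
chain_grads(w1, w2, x) = (w1 = w2 .* x, w2 = w1 .* x)

w = [2.0, 3.0]
x = [1.0, 1.0]

# Structural view: the same layer used twice yields one gradient per position.
per_position = chain_grads(w, w, x)

# Implicit view: gradients are accumulated per parameter identity, so the
# tied weight ends up with a single, summed entry.
tied = IdDict{Any,Any}()
for g in (per_position.w1, per_position.w2)
    tied[w] = get(tied, w, zero(g)) .+ g
end

tied[w]   # [4.0, 6.0], i.e. d/dw sum(w .^ 2 .* x) = 2w .* x
```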

This has some relevance to RNNs, whose current implementation weight-ties the initial and current state fields. I don't think this needs to result in user-facing changes, though. It's also possible to make a nice pure-functional RNN interface (see e.g. JAX); you just can't abstract over RNNs quite like we currently do, since state needs to be managed a bit more explicitly (which isn't necessarily a deal breaker, but is worth considering).

[To be clear though, while the structural approach is more functional, it does not force differentiated programs to be functional themselves, so something very like our current RNN design is still possible.]
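A sketch of what explicit state management might look like, assuming a vanilla tanh RNN cell with hypothetical names (`RNNParams`, `rnn_cell`, `run_rnn`); it is not the current Flux RNN API. The cell is pure, and the caller threads the hidden state through the sequence:

```julia
# Hypothetical parameters for a vanilla RNN cell (tanh activation).
struct RNNParams
    Wx::Matrix{Float64}
    Wh::Matrix{Float64}
    b::Vector{Float64}
end

# Pure cell: takes the state explicitly and returns (new_state, output)
# instead of mutating a hidden-state field.
function rnn_cell(p::RNNParams, h, x)
    h_new = tanh.(p.Wx * x .+ p.Wh * h .+ p.b)
    return h_new, h_new
end

# The caller threads the state through the sequence.
function run_rnn(p::RNNParams, h0, xs)
    h = h0
    ys = Vector{Vector{Float64}}()
    for x in xs
        h, y = rnn_cell(p, h, x)
        push!(ys, y)
    end
    return h, ys
end

p = RNNParams(fill(0.5, 1, 1), zeros(1, 1), zeros(1))
h, ys = run_rnn(p, zeros(1), [[1.0], [0.0]])
```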

@jekbradbury
Contributor

With JAX/stax, there's no straightforward way to tie weights right now short of writing your own init_fn and apply_fn (and, separately, no one I know of has tried building RNNs with it, so we don't really know how nice the interface will feel for that). Having an intuitive way to do this is pretty important, and to me (especially given Julia's pervasive reference semantics) your example in "Other concerns" should definitely have the effect of tying the weights of d.

@MikeInnes
Member Author

Ok, I think I finally have this figured out. The trick is to take the gradient of the module as if it were any other struct. Then global variables (fields of the module) can be treated just like any other parameter or struct field. This means that we can go all in on closures being layers, and vice versa.

This doesn't directly address James' concern, except that the "custom layer" is now a one-line lambda and everything just works. It might be worth writing some challenge cases, but I think it will look pretty good.
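A rough illustration of treating a module like a struct, assuming a hypothetical module `Net` and a hand-derived gradient (not produced by Zygote). `getfield` reads a module global just like a struct field, so a structural gradient can mirror the module's globals as a NamedTuple:

```julia
# A module whose global W plays the role of a parameter.
module Net
W = [1.0 2.0]
predict(x) = W * x
end

# Module globals can be read like struct fields.
getfield(Net, :W) === Net.W   # true

# Hypothetical structural gradient "of the module": a NamedTuple keyed by
# global name, hand-derived for loss(x) = sum(Net.predict(x)).
module_gradient(x) = (W = ones(1) * x',)

dNet = module_gradient([1.0, 1.0])   # (W = [1.0 1.0],)
Net.W .-= 0.5 .* dNet.W              # update the global array in place
```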

@datnamer

Would this allow for model/layer inheritance?

@MikeInnes
Member Author

In what sense? This does make it ridiculously easy to extend layer behaviour, e.g.

l1 = Dense(10, 5)
l2 = x -> relu.(l1(x))

and l2 can now be used anywhere l1 could have been. But I'm not sure if you mean something else by "inheritance".

@datnamer

datnamer commented Feb 22, 2019

That does look really nice.

I'm thinking something more along the lines of building a library like this: https://github.com/rusty1s/pytorch_geometric

The code example shows an Edge Convolution layer inheriting from a message passing base class. This could be done by composition, but sometimes inheritance (or traits) works better. Another example would be a user specializing an "abstract transformer" type.

More generally, it would mean differentiating programs written with full use of the type system.

@MikeInnes
Member Author

Yeah, we are of course limited by what Julia can express here, so we still can't do Python-style inheritance exactly. But it'd be easy to do this via composition (which can mean a library layer wrapping a custom one, as well as the other way around), e.g.

mlp = Chain(Dense(...), ...)
model = MessagePassing() do x
  mlp(x) |> softmax
end

(I'm just making something up here, because I don't actually know what MessagePassing needs – you'd probably need to pass two functions in – but hopefully that gives a flavour).
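For reference, `MessagePassing() do x ... end` passes the do-block as the first argument, so a wrapper only needs to store a function. A sketch where `MessagePassing`, `mlp` and `softmax` are hypothetical stand-ins (a real message-passing layer would aggregate over a graph):

```julia
# Hypothetical wrapper layer, not an existing Flux type: it just stores a
# user-supplied function as a field.
struct MessagePassing{F}
    f::F
end
# A real message-passing layer would gather and aggregate messages over a
# graph here; this sketch only applies the wrapped function.
(m::MessagePassing)(x) = m.f(x)

softmax(v) = exp.(v) ./ sum(exp.(v))
mlp(x) = 2 .* x      # stand-in for Chain(Dense(...), ...)

# The do-block is passed as the first argument, so this is equivalent to
# MessagePassing(x -> softmax(mlp(x))).
model = MessagePassing() do x
    mlp(x) |> softmax
end

model([1.0, 2.0])
```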

@datnamer

datnamer commented Feb 22, 2019

That makes sense. What about making MessagePassing an abstract type? I'm not sure if that would make sense in this instance, but say that, more generally, you have functions transform1 and transform2, where transform1 calls transform2 and both take an abstractlayer1. Then a user subtypes abstractlayer1 as layer1, includes some learnable parameters, and maybe overloads one or more of the functions in the lattice.
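The pattern described can be written with an abstract type and multiple dispatch. A minimal sketch, with names adapted from the comment (`AbstractLayer1`, `Layer1`, `transform1`, `transform2`) and behaviour chosen only for illustration:

```julia
# Hypothetical abstract layer with default behaviour for both functions.
abstract type AbstractLayer1 end

transform2(l::AbstractLayer1, x) = x
transform1(l::AbstractLayer1, x) = 2 .* transform2(l, x)   # calls transform2

# A layer that relies entirely on the defaults.
struct PlainLayer <: AbstractLayer1 end

# A user subtype with a learnable parameter that overloads only transform2;
# transform1 picks up the overload through dispatch.
struct Layer1 <: AbstractLayer1
    scale::Vector{Float64}
end
transform2(l::Layer1, x) = l.scale .* x

transform1(PlainLayer(), [1.0, 2.0])         # [2.0, 4.0]
transform1(Layer1([3.0, 3.0]), [1.0, 2.0])   # [6.0, 12.0]
```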

@MikeInnes
Member Author

MikeInnes commented Feb 23, 2019

I think the important point from an AD perspective is that if you can write it down, it will work :) Feel free to open a new issue to discuss specifics of layer design. I'm not personally a big fan of using abstract types to simulate inheritance, but I'm happy to discuss that / how it interacts with AD anyway.

@jessebett
Contributor

Question about this proposal in terms of functional AD: how does x̄ = gradient(f, x) know what x is unless it's defined previously? Compare this to JAX, which has a very nice df = jax.grad(f) that returns a function computing the gradient of f (by default w.r.t. the first argument). This can be changed to other arguments, like dfdy = jax.grad(f, argnums=1).

This is nice because I can define a function, and then give it to jax.grad to get the gradient function, and do this recursively for higher order gradients.

In the proposal, how would this functional gradient look?
df = x->gradient(f,x)
It's not obvious to me that this is doing the same thing. For instance, would this compile to a function that computes the gradient of f wrt x in the same sense that the jax version does?

@MikeInnes
Member Author

Yeah, both ways of writing this are equivalent (and identical in terms of performance etc). You can implement grad as

grad(f) = x -> gradient(f, x)[1]

And implement gradient back on top of this as

gradient(f, x) = (grad(f)(x),)

(modulo tweaks for argnum, multiple arguments etc.)

I don't like argnum because you can do this easily with a lambda, e.g. rather than grad(*, argnum=2)(x, y) you can do gradient(y -> x*y, y).

I also prefer not to provide curried versions of functions where possible, but that's really more of a style thing. If we can come up with a good name for it, I wouldn't oppose having a grad equivalent in Zygote.
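A sketch of the equivalence between the curried `grad` and the tuple-returning `gradient`, and of using a lambda instead of argnum. To keep it self-contained, `gradient` here is a central finite difference on scalar inputs, swapped in for AD purely for illustration:

```julia
# Stand-in for AD, for illustration only: a central finite difference on a
# scalar input, returning a one-element tuple like `gradient` does.
gradient(f, x::Real; h = 1e-6) = ((f(x + h) - f(x - h)) / (2h),)

# Curried form: grad(f) returns a derivative function.
grad(f) = x -> gradient(f, x)[1]

# ...and the tuple-returning form recovered from it.
gradient2(f, x) = (grad(f)(x),)

x, y = 3.0, 5.0
gradient(y -> x * y, y)   # ≈ (3.0,): the lambda selects the argument
grad(sin)(0.0)            # ≈ 1.0
```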

@Roger-luo
Contributor

I've been playing with this interface recently in my own stuff, using

gradient(()->m(input), m)

where m is my model defined with Chain, Dense etc., and gradient gives me a NamedTuple as proposed above. However, in my case I have some post-processing for the gradients (a sort of policy gradient), and it seems a bit inconvenient when the gradients are stored as a structure. But I do believe the explicit structural gradient is pretty natural, given that the gradient is, in some sense, the adjoint of the model parameters.

So I'm wondering if we could make Zygote return a Grad{T} type, where T is the original type, instead of a NamedTuple. Grad{T} could then be used to dispatch the methods that T has, while still being addable to the original. That way we could easily use multiple dispatch to traverse the nodes in the model, e.g.

foo(::Grad{<:Dense}) = nothing  # process a Dense layer's gradient here
foo(x::Grad{<:Chain}) = foreach(foo, x.layers)
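A minimal sketch of how such a `Grad{T}` wrapper could work, assuming field access is forwarded to a NamedTuple of field gradients. `Grad`, `MyDense`, `MyChain`, `sqnorm` and `addgrad!` are hypothetical; this is not what Zygote returns.

```julia
# Hypothetical `Grad{T}`: the gradient of a value of type T, storing one
# entry per field of T in a NamedTuple.
struct Grad{T,NT<:NamedTuple}
    fields::NT
end
Grad{T}(nt::NT) where {T,NT<:NamedTuple} = Grad{T,NT}(nt)

# Forward field access, so g.W reads the gradient of the W field.
Base.getproperty(g::Grad, s::Symbol) = getfield(g, :fields)[s]

# Hypothetical stand-ins for Dense and Chain.
struct MyDense
    W::Matrix{Float64}
    b::Vector{Float64}
end
struct MyChain
    layers::Vector{MyDense}
end

# Dispatch on the wrapped type to walk the gradient like the model,
# e.g. to compute a squared norm during post-processing.
sqnorm(g::Grad{MyDense}) = sum(abs2, g.W) + sum(abs2, g.b)
sqnorm(g::Grad{MyChain}) = sum(sqnorm, g.layers)

# "Adding to the original": apply a scaled step field by field.
function addgrad!(m::MyDense, g::Grad{MyDense}, η)
    m.W .+= η .* g.W
    m.b .+= η .* g.b
    return m
end

gd = Grad{MyDense}((W = [1.0 1.0], b = [2.0]))
gc = Grad{MyChain}((layers = [gd, gd],))
sqnorm(gc)   # 12.0

d = addgrad!(MyDense([1.0 2.0], [0.5]), gd, -0.5)
```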

bors bot added a commit that referenced this issue Sep 11, 2019
669: using Zygote r=MikeInnes a=MikeInnes

Otherwise known as "break all the things". This will be a huge change so I'm beginning to prepare now, even though Zygote is still a couple of months off from being really ready. **Do not try this at home** (yet) – this branch is eventually aimed at beta testers, but isn't even ready for that yet.

The idea is to break as little code as possible, which means supporting the current `Params` API; but I also want to start prototyping the nicer things discussed in #628 and other issues.

Blocking issues:

* [x] Get the tests passing.
* [x] Check tests on GPU.
* [x] Rewrite all the docs.
* [x] Cache invalidation (JuliaLabs/Cassette.jl#6).
* [x] Moving over adjoints (FluxML/Zygote.jl#81).
* [x] General Zygote robustness.

Nice to have:

* [ ] Robust nested AD (may not be a blocker if one can still use Tracker with Flux).
* [x] Zygote support for modules / globals as discussed in #628, along with #637.
* [x] Better train/test mode as in #643.

If you're the kind of person who ignores triangular road signs, you can try this with

```julia
]add Flux#zygote Zygote#master
```

Co-authored-by: Mike J Innes <mike.j.innes@gmail.com>
Co-authored-by: Elliot Saba <staticfloat@gmail.com>
Co-authored-by: thebhatman <manjunathbhat9920@gmail.com>

@ToucheSir
Member

Does anyone know if this hit a dead end or was ruled too complex to implement? I think the "closures are layers" aspect would help immensely for FluxML/FluxML-Community-Call-Minutes#10, while the "structural approach" would make loadparams! and co. a lot nicer to use.

@DhairyaLGandhi
Member

Not a dead end; my preference is to use this interface more and make it stable. We have https://github.com/FluxML/XLA.jl/pull/5/files#diff-2da3a01fb49af8d3ca12681d630be0e89f22536d6cb8322f6f6d239699bfd28f and FluxML/Optimisers.jl#3, which just need the rules to be ported over now. We can move ahead with this fairly swiftly.

@anhinga

anhinga commented Apr 5, 2022

@MikeInnes Thanks for all your work!

In JAX one can compute gradients with respect to nested dictionaries; a simple example is in the README here:

https://github.com/anhinga/jax-pytree-example

I wonder how difficult it would be to do something similar in Zygote.jl.

@ToucheSir
Member

Have you tried doing this? All the pieces appear to be in place, and if you run into issues please file them at Zygote. That said, since your JAX example only uses string keys, it would be far more efficient to use NamedTuples for the same purpose in Julia.

@MikeInnes
Member Author

Yeah, it works well already, e.g.

julia> gradient(x -> x["foo"]["bar"]^2, Dict("foo" => Dict("bar" => 5)))
(Dict("foo" => Dict("bar" => 10)),)
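Since the gradient has the same nested structure as the input, an update step can be written as a recursive map over matching containers. A sketch using nested NamedTuples, as suggested above; `descend` is a hypothetical helper:

```julia
# Hypothetical helper: one gradient-descent step applied leaf by leaf over
# nested NamedTuples that share the same keys.
descend(p::Real, g::Real, η) = p - η * g
descend(p::NamedTuple{K}, g::NamedTuple{K}, η) where {K} =
    NamedTuple{K}(map((a, b) -> descend(a, b, η), values(p), values(g)))

params = (foo = (bar = 5.0,), baz = 1.0)
grads  = (foo = (bar = 10.0,), baz = 2.0)

descend(params, grads, 0.25)   # (foo = (bar = 2.5,), baz = 0.5)
```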

@anhinga

anhinga commented Apr 5, 2022

@ToucheSir @MikeInnes Thanks!

(My mistake was trying to coerce a dictionary into Params.)

8 participants