
Functional AD #86

Closed

MikeInnes opened this issue Oct 17, 2017 · 7 comments

Comments

@MikeInnes
Member

MikeInnes commented Oct 17, 2017

The eventual plan is to build a new compiler-level AD that better exploits Julia's compilation, provides a more functional interface, and supports nested differentiation. A question here is how to support the grad(f, x) style interface while still allowing abstraction and modularity in layers and their weights.

I see this looking something like:

W = randn(5,5)
b = randn(5)
loss(x, y) = mse(W*x .+ b, y)
dW, db = grad(loss, (W, b), x, y)

W and b are treated as implicit arguments to the function; this is nice in that it's essentially the ideal functional interface but without the mess of hundreds of explicit arguments.

Models will implement params, as they do now, and whatever arrays they return will be treated as trainable parameters (dparams = grad(model, params(model), args...)). We'll also have a Freeze layer to treat things as constant, e.g. m = Freeze(Dense(10, 5)); params(m) == []. Freezing parameters is a little more coarse-grained than it is now, but that's a small loss compared to the gains.
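
For concreteness, a minimal runnable sketch of these semantics, with central finite differences standing in for the eventual compiler-level AD (so this grad and mse are toy stand-ins, not the real API):

mse(ŷ, y) = sum(abs2, ŷ .- y) / length(y)

# Toy grad(f, ps, args...): one gradient per implicit parameter,
# approximated by central finite differences (not the real AD).
function grad(f, ps::Tuple, args...; ϵ = 1e-6)
    map(ps) do p
        g = zero(p)
        for i in eachindex(p)
            old = p[i]
            p[i] = old + ϵ; fplus  = f(args...)
            p[i] = old - ϵ; fminus = f(args...)
            p[i] = old
            g[i] = (fplus - fminus) / 2ϵ
        end
        g
    end
end

W = randn(5, 5)
b = randn(5)
loss(x, y) = mse(W*x .+ b, y)

x, y = randn(5), randn(5)
dW, db = grad(loss, (W, b), x, y)   # W and b are picked up as implicit arguments

The point of the interface is that loss is called exactly as the user wrote it; grad only knows about W and b because they were passed alongside it.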

@dfdx

dfdx commented Oct 18, 2017

> W and b are treated as implicit arguments to the function; this is nice in that it's essentially the ideal functional interface but without the mess of hundreds of explicit arguments.

Does this mean that a user will need to make W and b global and define loss in the same context? This doesn't sound very flexible, to be honest. Recently I've been playing around (e.g. in VariationalAE.jl) with models as mutable structs. For your case it would look something like:

m = Linear(W, b)
loss(m::Linear, x, y) = mse(m.W * x .+ m.b, y)
dm = grad(loss, m, x, y)

where dm is another instance of Linear holding derivatives, i.e.:

dW = dm.W
db = dm.b
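
To make the mechanics concrete, here is a self-contained sketch of how a grad that hands back dm could work, again with finite differences standing in for real AD; this Linear type and grad method are purely illustrative.

mutable struct Linear
    W::Matrix{Float64}
    b::Vector{Float64}
end

mse(ŷ, y) = sum(abs2, ŷ .- y) / length(y)
loss(m::Linear, x, y) = mse(m.W * x .+ m.b, y)

# Perturb each array field of m in turn and collect the results into a
# second Linear, so dm mirrors the structure of m.
function grad(f, m::Linear, args...; ϵ = 1e-6)
    dm = Linear(zero(m.W), zero(m.b))
    for field in (:W, :b)
        p, g = getfield(m, field), getfield(dm, field)
        for i in eachindex(p)
            old = p[i]
            p[i] = old + ϵ; fplus  = f(m, args...)
            p[i] = old - ϵ; fminus = f(m, args...)
            p[i] = old
            g[i] = (fplus - fminus) / 2ϵ
        end
    end
    dm
end

m = Linear(randn(3, 3), randn(3))
x, y = randn(3), randn(3)
dm = grad(loss, m, x, y)
dW, db = dm.W, dm.b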

@MikeInnes
Member Author

> Does this mean that a user will need to make W and b global and define loss in the same context?

Er, no? For the most part I'm not expecting anything else to look different, so the MNIST example would stay exactly the same. It's really no different to the current TrackedArray approach in that sense, just without the hacky overloading.

Your API is something we discussed, as it's closer to what Knet currently has. At a minimum it only scales up well if you allow the structure to define the forward pass (e.g. via call overloading). Even then it imposes a bigger burden on user-defined types and small models, and it's harder to figure out how it plays out when you get to really complex models (as one example, higher-order models that take another model as input).
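
(For readers unfamiliar with the pattern, "letting the structure define the forward pass via call overloading" looks roughly like this; the Affine type below is illustrative, not Flux's API.)

# A callable struct: the model object carries its parameters and is
# itself the forward pass.
struct Affine
    W::Matrix{Float64}
    b::Vector{Float64}
end

(m::Affine)(x) = m.W * x .+ m.b    # call overloading

m = Affine(randn(5, 5), randn(5))
ŷ = m(randn(5))                     # forward pass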

@dfdx

dfdx commented Oct 18, 2017

Ah, I re-read the MNIST example. Do I understand correctly that loss() is actually a closure bound to an instance m::Chain? In that case my comment is indeed irrelevant.

@MikeInnes
Member Author

Essentially yes; it's not actually a closure in this case because m is global, but it could be. In the docs there are some examples of closing over parameters, and I expect those to work with Cassette as well.
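
One way to close over parameters, in the spirit of the docs examples mentioned here (an illustrative sketch; the linear constructor below is hypothetical, not the docs' exact code):

# The parameters live in the enclosing scope; the returned function
# captures them, so the model is just a closure.
function linear(in, out)
    W, b = randn(out, in), randn(out)
    x -> W*x .+ b
end

model = linear(10, 5)
ŷ = model(randn(10))    # forward pass through the captured W and b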

@baggepinnen
Contributor

How would the user go about taking the gradient of a model output with respect to a non-parameter, like the input? This is common in creating adversarial examples, linearizing dynamical models, etc.

@MikeInnes
Member Author

I guess that would have to be grad(loss, (W, b, x), x, y). Although that really makes it clear that loss should just be a zero-arg function, like grad(() -> loss(x, y), [W, b, x]).
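
With the toy finite-difference grad from the first sketch above, the zero-argument form reads as follows; dx is the gradient with respect to the input, which is what the adversarial-example and linearization use cases need. (Illustrative only; note the toy grad takes a Tuple rather than the vector literal [W, b, x].)

# x is treated as just another array being differentiated.
dW, db, dx = grad(() -> loss(x, y), (W, b, x))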

@MikeInnes
Member Author

A year on, we can do much cooler things here. Closing in favour of #628.
