Add an `update!(params, opt, losses)` suitable for writing your own training loops #607
Conversation
@@ -6,6 +6,38 @@ function update!(opt, x, x̄)
    update!(x, apply!(opt, x, copy(data(x̄))))
end
This docstring should likely be included in the manual somewhere?
Yes? There is a Documenter command to insert a docstring directly into the manual, right?
My Documenter skills are not up to this, I think.
If you look at the current doc source, there are examples that are easy to copy.
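For reference, Documenter.jl splices a docstring into a manual page with an `@docs` block; a minimal sketch (the exact page and the qualified name used in the Flux manual are assumptions here):

```@docs
Flux.Optimise.update!
```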
I'm open to something like this, but it should use gradients directly rather than loss values; we are increasingly moving away from the latter. I think the easiest way is to mimic …
I'm not sold (yet). What should the usage example look like?

# Inputs: model, xs, ys, opt
ps = params(model)
for (x, y_true) in zip(xs, ys)
    y_pred = model(x)
    losses = (y_pred .- y_true).^2
    update!(ps, opt, losses)
end

Are you saying it should be:

# Inputs: model, xs, ys, opt
ps = params(model)
for (x, y_true) in zip(xs, ys)
    y_pred = model(x)
    losses = (y_pred .- y_true).^2
    grads = gradients(losses, ps)  # does this function exist?
    update!(ps, opt, grads)
end

From a usage perspective I can easily calculate losses. I know there are …
It would look like `update!(opt, ps, gradient(() -> loss(...), ps))`. We can debate the merits of passing loss values around separately, but basically the issue is that they depend on an implementation detail (the fact that you have tracked values). That doesn't work in e.g. Zygote.
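For context, here is a hedged, end-to-end sketch of that gradient-first pattern; the toy model, loss, and data are assumptions for illustration, and exact names may differ across Flux versions:

```julia
using Flux

model = Dense(10, 2)                       # toy model
opt   = Descent(0.1)
ps    = Flux.params(model)
x, y  = rand(Float32, 10), rand(Float32, 2)

loss(x, y) = Flux.mse(model(x), y)

gs = gradient(() -> loss(x, y), ps)        # a Grads object keyed by the parameters
Flux.Optimise.update!(opt, ps, gs)         # apply one optimiser step in place
```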
This makes more sense when you understand the Zygote API. Zygote has a primitive called `forward`. This will calculate the loss along with a backpropagator:

loss, back = Zygote.forward((model) -> Flux.logitcrossentropy(model(x), y), model)
gradients = back(1.0f0)

At this point, the …
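To make the primitive concrete, a self-contained sketch with a plain function instead of a Flux model (note that `Zygote.forward` was later renamed `Zygote.pullback`; the function and inputs here are just placeholders):

```julia
using Zygote

f(w) = sum(abs2, w .* [1.0, 2.0, 3.0])

y, back = Zygote.pullback(f, [0.5, 0.5, 0.5])  # forward value plus a backpropagator
(dw,)   = back(1.0)                            # gradient of y with respect to w
```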
So the example would be: …

Does that actually work? Or does the call to the …
I guess I could grow to like the following example: …

That in turn suggests the following API: …

And the example would become: …
Right, Zygote also supports the implicit-style API where you just feed a parameter set, which is a mid-way point that's compatible with Flux. In that style it would be:

# Inputs: model, xs, ys, opt
θ = params(model)
for (x, y) in zip(xs, ys)
    dθ = gradient(θ) do
        ŷ = model(x)
        sum((y - ŷ).^2)
    end
    update!(opt, θ, dθ)
end

I don't think this is really so much worse than your original example, and it has much clearer semantics in terms of what is being differentiated and when. To understand why that's an issue you need to consider cases like LSTMs. You need a boundary that tells Flux not to differentiate through loop iterations, only through what's in the gradient block.
Can you point me to an example of this? I don't quite grok what you're saying.
It's the same reason that we currently need … To make it concrete, something like:

θ = params(model)
h = rand(...)
for (x, y) in zip(xs, ys)
    h, ŷ = model(x, h)
    loss = sum((y - ŷ).^2)
    update!(opt, θ, loss)
end

This doesn't do what you'd want -- each loop iteration will backprop through every previous iteration via the hidden state h.
So how do we feel about …?
I'd be on board, but it should have a different name to …
How about …? Is it still possible to insert it in the middle of the current …?
Sure, in principle. For now I think we should add the basic version of …
The core idea of the PR is to expose the inside of …
Yes, but all we need for that is …
How would you do overlapping-window truncated BPTT with Zygote? As in, if I have a sequence …
Yeah, good question -- that's a nice example of the flexibility that it gives you. I think you can get the same effect by storing a pullback for each step and chaining them together in the backwards pass (as well as dropping the …).
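A hedged sketch of that idea: store one pullback per step on the way forward, then chain the hidden-state cotangents backwards through them. Parameter gradients and the overlapping-window bookkeeping are omitted; `cell`, the data, and the seed gradient are assumptions, and `Zygote.forward` from the thread is today's `Zygote.pullback`:

```julia
using Zygote

cell(h, x) = tanh.(h .+ x)           # stand-in recurrent step

function chained_bptt(cell, h0, xs, dh_final)
    backs = Any[]                    # one pullback per step
    h = h0
    for x in xs
        h, back = Zygote.pullback(h -> cell(h, x), h)
        push!(backs, back)
    end
    # Backwards pass: feed each step's hidden-state cotangent into the
    # previous step's pullback, walking the stored pullbacks in reverse.
    dh = dh_final
    for back in reverse(backs)
        (dh,) = back(dh)
    end
    return dh                        # gradient w.r.t. the initial state h0
end

chained_bptt(cell, zeros(3), [rand(3) for _ in 1:4], ones(3))
```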
I note that …
After thinking about this more, I still think … It may save only one function call. And calling …
I don't necessarily disagree, but do think the more urgent thing is to nail the low-level foundations, and then build on that. Are you up for PRing the …?
Re the discussion @staticfloat and I were having on Slack, shortened:

I feel like pushing advanced users to write their own loops is a good thing, and giving them a simple `update!(params, opt, losses)` to call from their loop is giving them the tools for the job.