Add an update!(params, opt, losses) suitable for writing your own training loops #607

Closed
wants to merge 3 commits

Conversation

oxinabox
Member

@oxinabox oxinabox commented Feb 8, 2019

re the discussion @staticfloat and I were having on Slack, shortened:

Lyndon White

What is the pattern in Flux to use when
I already have a TrackedVector of losses,
and I just want to use those to update my model.
Do I use train! with dummy arguments?

I think I am happiest handling my early stopping and batching and loss calculation all myself in a loop.
Feeding train! an iterator, a callback, and a loss function feels a bit too constraining, like I have to shoehorn my problem into that form.

Elliot Saba [Today at 10:57 AM]

I would just not use train!() at all in that case. Just call Flux.back!() with your losses directly, and then do the optimizer step.

I feel like pushing advanced users to write their own loops is a good thing.
and giving them a simple update!(params, opt, losses),
to call from their loop, is giving them the tools for the job.

@@ -6,6 +6,38 @@ function update!(opt, x, x̄)
update!(x, apply!(opt, x, copy(data(x̄))))
end


Contributor


This docstring should likely be included in the manual somewhere?

Member Author


Yes? There is a Documenter command to insert a docstring directly in the manual, right?
My Documenter skills are not up to this, I think.

Member


If you look at the current doc source, there are examples that are easy to copy.

@MikeInnes
Member

I'm open to something like this but it should use gradients directly rather than loss values; we are increasingly moving away from back! as an API in favour of the more principled gradient.

I think the easiest way is to mimic update!(opt, x, delta) by providing update!(opt, ::Params, ::Grads).
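
A sketch of what that method could look like, mirroring the per-array update!(opt, x, x̄) shown in the diff above; the nothing-skip for untouched parameters is an assumption, not confirmed API:

```julia
# Sketch: update every tracked parameter from a Grads collection,
# reusing the existing single-array update!(opt, x, x̄).
function update!(opt, ps::Params, gs::Grads)
    for p in ps
        gs[p] === nothing && continue   # skip parameters without a gradient
        update!(opt, p, gs[p])
    end
end
```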

@oxinabox
Member Author

oxinabox commented Feb 8, 2019

I'm not sold (yet).
Tell me more about what you mean?

What should the usage example look like:
Right now I have:

# Inputs: model, xs, ys, opt 
ps = params(model)
for (x,y_true) in zip(xs, ys)
    y_pred = model(x)
    losses = (y_pred - y_true).^2
    update!(ps, opt, losses)
end

are you saying it should be:

# Inputs: model, xs, ys, opt 
ps = params(model)
for (x,y_true) in zip(xs, ys)
    y_pred = model(x)
    losses = (y_pred - y_true).^2
    grads = gradients(losses, params)   # does this function exist?
    update!(ps, opt, grads)
end

From a usage perspective I can easily calculate losses.
Defining losses is easy; even losses that are too annoying to define functions for
are easy enough to express.

I know there are gradient commands that take functions as inputs,
but are there ones that take just TrackedArrays?

@MikeInnes
Member

MikeInnes commented Feb 8, 2019

It would look like

update!(opt, ps, gradient(() -> loss(...), ps))

We can debate the merits of passing loss values around separately, but basically the issue is that they depend on an implementation detail (the fact that you have tracked values). That doesn't work in e.g. Zygote. gradient lets you do all the same things in a way that simply generalises the mathematical idea, expresses nested derivatives naturally, doesn't have weird non-local dataflow etc.

@staticfloat
Contributor

This makes more sense when you understand the Zygote API. Zygote has a primitive called forward() that you call as:

y, back = Zygote.forward(function_to_analyze, input_data)

This will calculate y = function_to_analyze(input_data), but it will also return a function back() such that back(sensitivity) will return the gradient imposed upon input_data. How then do we apply this to Flux? We set input_data to model, and function_to_analyze to (model) -> model(x), where x is a value captured through the closure. Is it a little bizarre? Yes. But it's also much more explicit and useful than the back!(loss) abstraction:

loss, back = Zygote.forward((model) -> Flux.logitcrossentropy(model(x), y), model)
gradients = back(1.0f0)

At this point, the gradients are an object with a similar structure to model (e.g. if model[1].weight exists, then gradients[1].weight will exist and I can update my model via model[1].weight -= eta * gradients[1].weight).
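
Concretely, that manual update could look something like the following sketch, assuming the structure-mirroring gradients described above; `eta` (the learning rate), `x`, and `y` are stand-ins assumed to be in scope:

```julia
# Sketch of a hand-written SGD step using the Zygote-style API above.
loss, back = Zygote.forward(m -> Flux.logitcrossentropy(m(x), y), model)
grads = back(1.0f0)          # seed the backward pass with a sensitivity of 1

for i in 1:length(model)     # `grads` mirrors the structure of `model`
    model[i].weight .-= eta .* grads[i].weight
    model[i].bias   .-= eta .* grads[i].bias
end
```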

@oxinabox
Member Author

oxinabox commented Feb 8, 2019

So the example would be:

# Inputs: model, xs, ys, opt 
ps = params(model)
for (x,y_true) in zip(xs, ys)
    y_pred = model(x)
    losses = (y_pred - y_true).^2
    update!(opt, ps, gradient(() -> losses, ps))
end

Does that actually work? Or does the call to the model need to be inside the closure that the gradient is taken over too?

@oxinabox
Member Author

oxinabox commented Feb 8, 2019

I guess I could grow to like code like the following example:

# Inputs: model, xs, ys, opt 
ps = params(model)
for (x,y_true) in zip(xs, ys)
    grad = gradient(ps) do
        y_pred = model(x)
        sum((y_pred - y_true).^2)
    end
    update!(opt, ps, grad)
end

That in turn suggests the following API

"""
...
getloss is a closure that takes no inputs, and returns a loss to be used for updating the model.
"""
function update!(getloss, ps, opt)
    grad = gradient(getloss, ps)
    update!(opt, ps, grad)
end

And the example would become:

# Inputs: model, xs, ys, opt 
ps = params(model)
for (x,y_true) in zip(xs, ys)
   update!(ps, opt) do
       # Calculate and return the loss here
       y_pred = model(x)
       sum((y_pred - y_true).^2)
   end
end

@MikeInnes
Member

MikeInnes commented Feb 8, 2019

Right, Zygote also supports the implicit-style API where you just feed a parameter set, which is a mid-way point that's compatible with Flux. In that style it would be

# Inputs: model, xs, ys, opt 
θ = params(model)
for (x,y) in zip(xs, ys)
  dθ = gradient(θ) do
    ŷ = model(x)
    sum((y - ŷ).^2)
  end
  update!(opt, θ, dθ)
end

I don't think this is really so much worse than your original example, and it has much clearer semantics in terms of what is being differentiated and when.

To understand why that's an issue you need to consider cases like LSTMs. You need a boundary that tells Flux not to differentiate through loop iterations, only through what's in gradient (which could also be a loop, but only over a single batch rather than an entire epoch).

@staticfloat
Contributor

You need a boundary that tells Flux not to differentiate through loop iterations, only through what's in gradient

Can you point me to an example of this? I don't quite grok what you're saying.

@MikeInnes
Member

MikeInnes commented Feb 8, 2019

It's the same reason that we currently need truncate!. If you train an RNN then by default Flux tracks over the entire epoch. Zygote doesn't have this problem because gradients are clearly scoped, so it'll do what you meant.

To make it concrete, something like:

θ = params(model)
h = rand(...)
for (x,y) in zip(xs, ys)
  h, ŷ = model(x, h)
  loss = sum((y - ŷ).^2)
  update!(opt, θ, loss)
end

This doesn't do what you'd want -- each loop iteration will backprop through n previous iterations. It's possible to understand why and fix it if you have some intuition for how the AD actually works -- i.e. the way it tracks operations and builds a graph -- but it's not ideal to expect users to have that.
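
For contrast, the gradient-scoped version of the same loop keeps the recurrence inside the closure, so each iteration only backprops through its own forward pass (a sketch; `hidden_size` is a hypothetical dimension):

```julia
θ = params(model)
h = zeros(Float32, hidden_size)   # hypothetical initial hidden state
for (x, y) in zip(xs, ys)
    dθ = gradient(θ) do
        h, ŷ = model(x, h)        # the recurrence happens inside the gradient scope
        sum((y - ŷ).^2)
    end
    update!(opt, θ, dθ)           # gradients never leak across loop iterations
end
```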

@oxinabox
Member Author

oxinabox commented Feb 8, 2019

So how do we feel about
update!(getloss,ps, opt)
Where getloss is a zero argument closure?

@MikeInnes
Member

I'd be on board but it should have a different name to update!, since it's quite a different function.

@oxinabox
Member Author

oxinabox commented Feb 8, 2019

How about train_step!?

Is it still possible to insert it in the middle of the current train! loop?

@MikeInnes
Member

Sure, in principle. For now I think we should add the basic version of update! and build stuff on top separately.

@oxinabox
Member Author

oxinabox commented Feb 8, 2019

The core idea of the PR is to expose the inside of train! to the user so they can write their own.
If this method isn't good enough to be used as the core of train!,
then is it good enough for writing your own training loops?

@MikeInnes
Member

Yes, but all we need for that is update!(opt, xs, dxs). train_step! is just a (very minor) convenience over that, and I think it needs discussing whether it's worth the extra API to save a single function call.
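
For reference, the convenience being discussed really is a one-liner on top of the low-level API (a sketch; the name train_step! is still up for debate, and `model`, `x`, `y` are stand-ins):

```julia
# Hypothetical convenience wrapper over update!(opt, ps, grads):
train_step!(getloss, ps, opt) = update!(opt, ps, gradient(getloss, ps))

# which would be called as:
train_step!(ps, opt) do
    ŷ = model(x)
    sum((y - ŷ).^2)
end
```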

@jekbradbury
Contributor

How would you do overlapping-window truncated BPTT with Zygote? As in, if I have a sequence A B C D E and I would like to propagate gradients that come from C back through A and D back through B but not through A? (This is done in Tracker-style frameworks by truncating the left edge of the graph after every step.)

@MikeInnes
Member

MikeInnes commented Feb 8, 2019

Yeah good question -- that's a nice example of the flexibility that it gives you. I think you can get the same effect by storing a pullback for each step and chaining them together in the backwards pass (as well as dropping the nth each time). It's probably a little more manual but easy enough to abstract over.
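
A schematic sketch of that pullback-chaining idea with a window of two steps. This is not established Zygote API: the pullback calling convention, `hidden_size`, and `accumulate!` are all assumptions here:

```julia
window = 2
backs = Any[]                          # pullbacks for the last `window` steps
h = zeros(Float32, hidden_size)        # hypothetical initial hidden state
grads = nothing                        # hypothetical gradient accumulator
for (x, y) in zip(xs, ys)
    (h, ŷ), back = Zygote.forward((m, h) -> m(x, h), model, h)
    push!(backs, back)
    length(backs) > window && popfirst!(backs)   # drop the oldest pullback

    # Backprop this step's loss through at most `window` stored steps:
    Δ = (nothing, 2 .* (ŷ .- y))       # sensitivity on (h, ŷ) for an L2 loss
    for b in reverse(backs)
        dm, dh = b(Δ)
        accumulate!(grads, dm)         # hypothetical accumulation into `grads`
        Δ = (dh, nothing)              # only the hidden state links to earlier steps
    end
end
```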

@oxinabox
Member Author

oxinabox commented Feb 9, 2019

I note that
train_step!(getloss, ps, opt) (with getloss being a zero-argument closure)
would have pleasing symmetry, in name and arguments,
to train!(loss, ps, data, opt) (with loss being a closure taking arguments as provided by iterating data).

@oxinabox
Member Author

oxinabox commented Feb 26, 2019

After thinking about this more I still think train_step! is a good idea.

It may save only one function call.
But that function call is a gradients call.
And thinking about gradients explicitly is not required.
It is nicer to think just about loss and know the AD will sort it out.
Mentioning gradients to me switches over to the headspace where I am wondering if I should mess with them.

And calling gradient then update feels more low-level, like you are digging deep in the guts.

@MikeInnes
Member

I don't necessarily disagree but do think the more urgent thing is to nail the low-level foundations, and then build on that. Are you up for PRing the update! change?

@MikeInnes
Member

I think this is moot after #651 (but thanks a lot for pushing the design along @oxinabox). We should bikeshed train_step in another issue and work something nice out.

@MikeInnes MikeInnes closed this Mar 6, 2019