
New optimisers interface #379

Merged: 25 commits into FluxML:master on Oct 31, 2018

Conversation

@DhairyaLGandhi (Member) commented Aug 28, 2018

Rebase of #283

# Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)

end
end
Member:

Watch out for spurious whitespace changes here. Might be an editor settings issue.

@MikeInnes changed the title from "Rebase branch newoptimizer to master https://github.com/FluxML/Flux.jl/pull/283" to "New optimisers interface" on Aug 28, 2018
src/Flux.jl Outdated
export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
export Descent, ADAM, Momentum, Nesterov,
RMSProp, update!
Member:

I'd prefer not to export this; it's much too generic a name.

Member Author:

Got it.

@@ -385,6 +385,10 @@ end

using Requires

Base.Broadcast.cat_nested(t::Base.Broadcast.Broadcasted, rest...) = (cat_nested(t.args...)..., cat_nested(rest...)...)
Member:

Can you explain what this is for? It seems oddly general, rather than doing anything Flux or TrackedArray-specific, to be here.

Member Author:

These were to help flatten a nested broadcast, taken from the broadcast API. Will remove them, as they aren't necessary.

test/optimise.jl (resolved)
DescentWeightDecay(η = 1, γ = 0) = DescentWeightDecay(η, γ)
function update!(o::DescentWeightDecay, x, Δ)
η, γ = o.eta, o.gamma
@. x = x - η * (Δ + γ * x)
Member:

Seems like something went wrong here; we shouldn't be modifying x in the optimiser since the delta will get applied later on. Also, can't this be Compose(WeightDecay(), Descent())?
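
For illustration, a rough sketch of that composition idea, using hypothetical type names (MyChain stands in for Compose/Optimiser) rather than the PR's code: each rule only transforms and returns the gradient Δ, a chain folds the rules over Δ in order, and the caller applies the final delta.

struct MyDescent; eta::Float64; end
struct MyWeightDecay; wd::Float64; end
struct MyChain; opts::Vector{Any}; end                # stands in for Compose/Optimiser

update!(o::MyDescent, x, Δ)     = o.eta .* Δ          # scale the step
update!(o::MyWeightDecay, x, Δ) = Δ .+ o.wd .* x      # fold the decay term into Δ
update!(o::MyChain, x, Δ)       = foldl((d, rule) -> update!(rule, x, d), o.opts; init = Δ)

x, Δ = rand(3), rand(3)
opt = MyChain([MyWeightDecay(1e-4), MyDescent(0.1)])  # ~ Compose(WeightDecay(), Descent())
x .-= update!(opt, x, Δ)                              # the delta is applied by the caller, not the rule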

@. p.x = p.x - η * (p.Δ + γ * p.x)
@. p.Δ = 0
end
Descent(η = 0.1) = Descent(η)
Member:

Note that this method causes stack overflow.
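
For context, a minimal illustration of why this overflows (assuming a single Float64 eta field; the real struct may differ): the optional-argument definition lowers to a one-argument method Descent(η) that calls itself, replacing the converting default constructor.

mutable struct Descent
  eta::Float64
end

# Descent(η = 0.1) = Descent(η)   # lowers to Descent(η) = Descent(η), so e.g.
#                                 # Descent(1) recurses until the stack overflows

Descent() = Descent(0.1)          # one way to keep the default learning rate without the self-call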

@MikeInnes (Member):

This really is nearly there now, thanks for the hard work @dhairyagandhi96. The one last thing we need is updates for the train! function so that it works with the new optimiser setup.

Also, would be great to get docstrings for the new decay optimisers, and tweak anything in the main docs that is out of date. We should document the optimiser setup in more detail, but that can wait for now.

opt = opt
else
if opt isa ADAMW
opt = Optimiser(opt, DescentWeightDecay(1, decay))
Member:

Equivalent to WeightDecay(decay)?

# Train function
function train!(loss::Function, data, opt; cb = () -> ())
depwarn("train!(loss, data, opt; cb) is deprecated; use train!(model, data, loss, opt; cb) instead", :train)
train!(opt.ps, loss, data, opt.opt; cb = cb)
Member:

Ideally this would be compatible with any opt, not just the update rule defined above (and then the tests would not need to avoid () -> ()).

return Δ
end

# TODO: decay
Member:

I guess this is done now? :)

function update!(o::InvDecay, x, Δ)
γ, n = o.gamma, o.n
Δ .*= 1 / (1 + γ * n)
o.n += 1
Member:

Remember that the same optimiser object will be used in an update! with many different parameters. Right now n is counting the overall number of parameters in the model, not the number of times an individual parameter was seen. Any parameter-specific state should go in a dictionary.
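
A rough sketch of that suggestion (field names and the default value are assumptions, not necessarily what was merged): key the counter on the parameter object itself in an IdDict, so each parameter carries its own step count.

mutable struct InvDecay
  gamma::Float64
  state::IdDict{Any,Int}
end
InvDecay(γ = 0.001) = InvDecay(γ, IdDict{Any,Int}())

function update!(o::InvDecay, x, Δ)
  n = get(o.state, x, 1)           # steps seen for this particular parameter
  Δ .*= 1 / (1 + o.gamma * n)
  o.state[x] = n + 1
  return Δ
end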


WeightDecay(η = 1) = WeightDecay(η, 0)
function update!(o::WeightDecay, x, Δ)
η, wd = o.eta, o.wd
Member:

eta is not used here, I assume it can be removed


function update!(o::ExpDecay, x, Δ)
γ = o.gamma
@. Δ += γ * x
Member:

Right now this looks equivalent to weight decay
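
For contrast, a sketch of what decaying the step size itself could look like (hypothetical fields and default, reusing the per-parameter counter idea from above), as opposed to the γ * x term, which acts on the weights:

mutable struct ExpDecay
  decay::Float64
  state::IdDict{Any,Int}
end
ExpDecay(decay = 0.99) = ExpDecay(decay, IdDict{Any,Int}())

function update!(o::ExpDecay, x, Δ)
  n = get(o.state, x, 0)
  Δ .*= o.decay^n                  # the step shrinks geometrically with the step count
  o.state[x] = n + 1
  return Δ
end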

function update!(opt, xs)
for x in xs
x, Δ = data(x), grad(x)
Δ = update!(opt, x, Δ)
Member:

You should pass x without unwrapping it here, since that will be used as the key in any dictionaries.
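
A tiny self-contained illustration of why the key matters (the named tuple stands in for a tracked parameter; none of this is Flux code): state stored under the wrapper object is not found when looked up with the unwrapped data, so update! should always see the same object.

state = IdDict{Any,Int}()
p = rand(3)                       # stands in for a parameter's raw data
w = (data = p, grad = zero(p))    # stands in for the tracked wrapper around it

state[w] = 1                      # per-parameter state keyed on the wrapper...
haskey(state, w)                  # true
haskey(state, w.data)             # false: looking it up via the unwrapped data misses it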

@@ -35,7 +45,7 @@ function stop()
end

"""
train!(loss, data, opt)
train!(model, loss, data, opt)
Member:

I think we should have this be train!(loss, params, opt, data) and not try to call params for the user.
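
To make that call shape concrete, a toy, fully self-contained sketch of train!(loss, params, opt, data); the rule, the faked gradient, and all names are illustrative, not Flux's implementation. The point is only that the caller passes the parameter collection explicitly instead of train! calling params on a model.

descent(eta) = (p, g) -> eta .* g     # toy rule standing in for the PR's update!

function train!(loss, ps, opt, data; cb = () -> ())
  for d in data
    loss(d...)                        # forward pass; a real version would also
                                      # run the backward pass (back!) here
    for p in ps
      g = fill!(similar(p), 1)        # stand-in for grad(p), just so this runs
      p .-= opt(p, g)                 # the rule returns a delta; train! applies it
    end
    cb()
  end
end

W = rand(2, 3)
train!(x -> sum(W * x), [W], descent(0.1), [(rand(3),) for _ in 1:3])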

cb = runall(cb)
opt = runall(opt)
@progress for d in data
try
l = loss(d...)
@interrupts back!(l)
opt()
foreach(x -> x.data .-= update!(opt, x.data, x.grad), ps)
Member:

Wouldn't this be the time to use the update! you defined above?

for t = 1: 10^5
l = loss(rand(10))
back!(l)
delta = Optimise.update!(opt, w′.data, w′.grad)
Member:

also use the shorter update! method?

@CarloLucibello (Member):

Can we also choose names more consistently?

Now we have,

export Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
ADAMW, InvDecay, ExpDecay, WeightDecay, DescentWeightDecay

while I would go for:

export Descent, Adam, Momentum, Nesterov, RmsProp,
AdaGrad, AdaMax, AdaDelta, AmsGrad,  Nadam,
Adamw, InvDecay, ExpDecay, WeightDecay, DescentWeightDecay

@MikeInnes (Member):

It's a good idea, perhaps something we should do in a separate PR to avoid making this one more complex (we could also save the Descent change for that PR, but it doesn't really matter).

@MikeInnes MikeInnes merged commit 43c5f90 into FluxML:master Oct 31, 2018
@MikeInnes (Member):

This turned into a bit of a marathon but it's a great PR at this point. Thanks for all the great work @dhairyagandhi96!
