New optimisers interface #379
Conversation
src/optimise/Optimise.jl
Outdated
# Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)

end
end
Watch out for spurious whitespace changes here. Might be an editor settings issue.
src/Flux.jl
Outdated
export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
export Descent, ADAM, Momentum, Nesterov,
  RMSProp, update!
I'd prefer not to export this; it's much too generic a name.
Got it.
src/tracker/array.jl
Outdated
@@ -385,6 +385,10 @@ end

using Requires

Base.Broadcast.cat_nested(t::Base.Broadcast.Broadcasted, rest...) = (cat_nested(t.args...)..., cat_nested(rest...)...)
Can you explain what this is for? It seems oddly general, rather than doing anything Flux or TrackedArray-specific, to be here.
These were to help flatten a nested broadcast, taken from the broadcast API. Will remove them, as they aren't necessary.
src/optimise/optimisers.jl
Outdated
DescentWeightDecay(η = 1, γ = 0) = DescentWeightDecay(η, γ)
function update!(o::DescentWeightDecay, x, Δ)
  η, γ = o.eta, o.gamma
  @. x = x - η * (Δ + γ * x)
Seems like something went wrong here; we shouldn't be modifying x in the optimiser since the delta will get applied later on. Also, can't this be Compose(WeightDecay(), Descent())?
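To illustrate the suggestion, here's a minimal, hypothetical sketch (not the PR's final code; Optimiser here stands in for the Compose mentioned above): each rule only rewrites the gradient, and the parameter is touched exactly once, outside the rules.

```julia
# Hypothetical sketch: rules rewrite Δ; nothing writes to x inside a rule.
struct Descent;     eta::Float64    end
struct WeightDecay; wd::Float64     end
struct Optimiser;   os::Vector{Any} end
Optimiser(os...) = Optimiser(Any[os...])

update!(o::Descent, x, Δ)     = (Δ .*= o.eta)
update!(o::WeightDecay, x, Δ) = (Δ .+= o.wd .* x)

function update!(o::Optimiser, x, Δ)
  for rule in o.os          # apply each rule's gradient transform in order
    Δ = update!(rule, x, Δ)
  end
  return Δ
end

x, Δ = randn(3), randn(3)
x .-= update!(Optimiser(WeightDecay(1e-4), Descent(0.1)), x, Δ)  # the only place x changes
```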
src/optimise/optimisers.jl
Outdated
  @. p.x = p.x - η * (p.Δ + γ * p.x)
  @. p.Δ = 0
end
Descent(η = 0.1) = Descent(η)
Note that this method causes stack overflow.
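One reading of the overflow (assuming the struct declares a concrete field such as eta::Float64, which is not shown in the quoted diff): `Descent(η = 0.1) = Descent(η)` shadows the generated catch-all constructor, so any call that doesn't hit the exact field type dispatches back to itself forever. A sketch of a fix that converts before re-dispatching:

```julia
# Sketch only. The broken form:
#   Descent(η = 0.1) = Descent(η)      # Descent(1) recurses into itself
# Converting first lands the call on the inner Descent(::Float64) constructor:
struct Descent
  eta::Float64
end
Descent(η = 0.1) = Descent(Float64(η))

Descent()    # Descent(0.1)
Descent(1)   # fine: Float64(1) hits the inner constructor, no recursion
```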
This really is nearly there now, thanks for the hard work @dhairyagandhi96. The one last thing we need is updates for the Also, would be great to get docstrings for the new decay optimisers, and tweak anything in the main docs that is out of date. We should document the optimiser setup in more detail, but that can wait for now.
src/optimise/deprecations.jl
Outdated
  opt = opt
else
  if opt isa ADAMW
    opt = Optimiser(opt, DescentWeightDecay(1, decay))
Equivalent to WeightDecay(decay)?
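If so, the deprecation branch could presumably wrap the existing rule with a pure weight-decay term; a one-line sketch, reusing this PR's Optimiser and WeightDecay names:

```julia
# Sketch of the suggested simplification in the ADAMW deprecation path.
opt = opt isa ADAMW ? Optimiser(opt, WeightDecay(decay)) : opt
```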
src/optimise/deprecations.jl
Outdated
# Train function
function train!(loss::Function, data, opt; cb = () -> ())
  depwarn("train!(loss, data, opt; cb) is deprecated; use train!(model, data, loss, opt; cb) instead", :train)
  train!(opt.ps, loss, data, opt.opt; cb = cb)
Ideally this would be compatible with any opt, not just the updaterule defined above (and then the tests would not need to avoid ()->()).
src/optimise/optimisers.jl
Outdated
  return Δ
end

# TODO: decay
I guess this is done now? :)
src/optimise/optimisers.jl
Outdated
function update!(o::InvDecay, x, Δ)
  γ, n = o.gamma, o.n
  Δ .*= 1 / (1 + γ * n)
  o.n += 1
Remember that the same optimiser object will be used in an update! with many different parameters. Right now n is counting the overall number of parameters in the model, not the number of times an individual parameter was seen. Any parameter-specific state should go in a dictionary.
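A minimal sketch of the per-parameter bookkeeping being suggested (the IdDict choice and defaults are assumptions, not the PR's code): the step count is keyed by the parameter, so two parameters no longer share one counter.

```julia
# Sketch, assuming the gradient-rewriting update!(o, x, Δ) convention of this PR.
struct InvDecay
  gamma::Float64
  state::IdDict{Any,Int}   # per-parameter step counts
end
InvDecay(γ = 0.001) = InvDecay(γ, IdDict{Any,Int}())

function update!(o::InvDecay, x, Δ)
  n = get(o.state, x, 1)           # this parameter's own count
  Δ .*= 1 / (1 + o.gamma * n)
  o.state[x] = n + 1
  return Δ
end

o = InvDecay()
w1, w2 = randn(2), randn(2)
update!(o, w1, ones(2)); update!(o, w1, ones(2)); update!(o, w2, ones(2))
(o.state[w1], o.state[w2])   # (3, 2): each parameter tracks its own count
```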
src/optimise/optimisers.jl
Outdated

WeightDecay(η = 1) = WeightDecay(η, 0)
function update!(o::WeightDecay, x, Δ)
  η, wd = o.eta, o.wd
eta is not used here, I assume it can be removed
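A short sketch of the trimmed version (assuming the gradient-rewriting convention used elsewhere in this PR): only the decay coefficient is kept.

```julia
# Sketch: WeightDecay just adds wd * x to the gradient, so eta isn't needed.
struct WeightDecay
  wd::Float64
end
WeightDecay(wd = 0) = WeightDecay(Float64(wd))

update!(o::WeightDecay, x, Δ) = (Δ .+= o.wd .* x)

w, g = randn(3), zeros(3)
update!(WeightDecay(1e-4), w, g)   # g is now 1e-4 .* w
```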
src/optimise/optimisers.jl
Outdated

function update!(o::ExpDecay, x, Δ)
  γ = o.gamma
  @. Δ += γ * x
Right now this looks equivalent to weight decay
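Purely as an illustration of how ExpDecay could differ from weight decay, here is one possible reading: scale the gradient by a rate that shrinks geometrically with the number of updates each parameter has seen. Names, defaults, and the per-parameter IdDict are assumptions, not the PR's code.

```julia
# Hypothetical sketch of a step-based exponential learning-rate decay.
struct ExpDecay
  eta::Float64              # initial scale
  decay::Float64            # multiplicative decay per update
  state::IdDict{Any,Int}    # per-parameter update counts
end
ExpDecay(η = 0.1, decay = 0.99) = ExpDecay(η, decay, IdDict{Any,Int}())

function update!(o::ExpDecay, x, Δ)
  n = get(o.state, x, 0)
  Δ .*= o.eta * o.decay^n
  o.state[x] = n + 1
  return Δ
end
```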
src/optimise/train.jl
Outdated
function update!(opt, xs)
  for x in xs
    x, Δ = data(x), grad(x)
    Δ = update!(opt, x, Δ)
You should pass x without unwrapping it here, since that will be used as the key in any dictionaries.
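A toy sketch of the corrected shape (Param, data, and grad are stand-ins, not the tracker's API): the wrapped parameter is handed to the rule so stateful rules can key a dictionary on it, and only the final subtraction unwraps it.

```julia
# Minimal stand-ins just to make the sketch runnable.
struct Param
  data::Vector{Float64}
  grad::Vector{Float64}
end
data(p::Param) = p.data
grad(p::Param) = p.grad

struct Descent; eta::Float64 end
update!(o::Descent, p, Δ) = Δ .* o.eta     # `p` could key an IdDict of state

function update!(opt, xs)
  for x in xs
    Δ = update!(opt, x, grad(x))   # pass x itself, not data(x)
    data(x) .-= Δ
  end
end

p = Param(randn(3), ones(3))
update!(Descent(0.1), [p])
```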
@@ -35,7 +45,7 @@ function stop()
end

"""
    train!(loss, data, opt)
    train!(model, loss, data, opt)
I think we should have this be train!(loss, params, opt, data) and not try to call params for the user.
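A sketch of the call-site shape being suggested (model, loss, and dataset are placeholders; this is not the final API): the caller collects the parameters and passes them in explicitly.

```julia
# Illustrative only: train! no longer calls params(model) on the user's behalf.
ps  = params(model)
opt = Descent(0.1)
train!(loss, ps, opt, dataset; cb = () -> ())
```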
src/optimise/train.jl
Outdated
cb = runall(cb)
opt = runall(opt)
@progress for d in data
  try
    l = loss(d...)
    @interrupts back!(l)
    opt()
    foreach(x -> x.data .-= update!(opt, x.data, x.grad), ps)
Wouldn't this be the time to use the update! you defined above?
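A self-contained toy sketch of that suggestion (no tracker, no callbacks; Param and the hand-written gradient are stand-ins, not Flux code): the loop body delegates to the update!(opt, ps) method instead of repeating the per-parameter arithmetic inline.

```julia
struct Param
  data::Vector{Float64}
  grad::Vector{Float64}
end
struct Descent; eta::Float64 end

update!(o::Descent, p::Param, Δ) = Δ .* o.eta

function update!(opt, ps)                 # the method suggested for reuse
  for p in ps
    p.data .-= update!(opt, p, p.grad)
    p.grad .= 0
  end
end

w   = Param([2.0, -3.0], zeros(2))
opt = Descent(0.1)
for d in [randn(2) for _ in 1:100]        # the "data" loop from train!
  w.grad .= 2 .* (w.data .- d)            # stand-in for back!(loss(d...))
  update!(opt, [w])                       # one call replaces the inline foreach
end
```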
for t = 1:10^5
  l = loss(rand(10))
  back!(l)
  delta = Optimise.update!(opt, w′.data, w′.grad)
Also use the shorter update! method?
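A sketch of how the quoted test might read with the shorter form, assuming the update!(opt, xs) method from train.jl handles the data/grad unwrapping and the subtraction:

```julia
# Sketch only; loss, back!, opt, and w′ come from the existing test above.
for t = 1:10^5
  l = loss(rand(10))
  back!(l)
  Optimise.update!(opt, [w′])   # replaces the explicit data/grad bookkeeping
end
```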
Can we also choose names more consistently? Now we have,
while I would go for:
It's a good idea, perhaps something we should do in a separate PR to avoid making this one more complex (we could also save the
This turned into a bit of a marathon but it's a great PR at this point. Thanks for all the great work @dhairyagandhi96!
Rebase of #283