Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move all optimizers to Optimisers.jl #9

Merged
merged 15 commits into from
Feb 8, 2021
6 changes: 4 additions & 2 deletions src/Optimisers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ using Functors: functor, fmap, isleaf
include("interface.jl")
include("rules.jl")

export Descent, Momentum, Nesterov, RMSProp,
ADAM
export init, update!
darsnack marked this conversation as resolved.
Show resolved Hide resolved
export Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM, OADAM, AdaBelief,
WeightDecay, Sequence

end # module
25 changes: 13 additions & 12 deletions src/interface.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,31 @@
function patch(x, x̄)
return x .- x̄
function patch!(x, x̄)
@. x .-= x̄
darsnack marked this conversation as resolved.
Show resolved Hide resolved
return x
end

function state(o, x)
function init(o, x)
darsnack marked this conversation as resolved.
Show resolved Hide resolved
if isleaf(x)
return init(o, x)
else
x, _ = functor(x)
map(x -> state(o, x), x)
return map(x -> init(o, x), x)
end
end

function _update(o, x, x̄, st)
x̄, st = apply(o, x, x̄, st)
return patch(x, x̄), st
function _update!(o, x, x̄, st)
x̄, st = apply!(o, x, x̄, st)
return patch!(x, x̄), st
end

function update(o, x::T, x̄, state) where T
function update!(o, x::T, x̄, state) where T
if x̄ === nothing
return x, state
elseif isleaf(x)
_update(o, x, x̄, state)
_update!(o, x, x̄, state)
else
x̄, _ = functor(typeof(x), x̄)
x, re = functor(typeof(x), x)
xstate = map((x, x̄, state) -> update(o, x, x̄, state), x, x̄, state)
re(map(first, xstate)), map(x -> x[2], xstate)
x, restructure = functor(typeof(x), x)
xstate = map((x, x̄, state) -> update!(o, x, x̄, state), x, x̄, state)
restructure(map(first, xstate)), map(x -> x[2], xstate)
end
end
Loading