Skip to content

Commit

Permalink
Merge pull request #9 from darsnack/darsnack/initial-impl
Browse files Browse the repository at this point in the history
Move all optimizers to Optimisers.jl
  • Loading branch information
DhairyaLGandhi authored Feb 8, 2021
2 parents 051270b + ce489fb commit 638ace1
Show file tree
Hide file tree
Showing 4 changed files with 501 additions and 77 deletions.
5 changes: 3 additions & 2 deletions src/Optimisers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ using Functors: functor, fmap, isleaf
include("interface.jl")
include("rules.jl")

export Descent, Momentum, Nesterov, RMSProp,
ADAM
export Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM, OADAM, AdaBelief,
WeightDecay, OptimiserChain

end # module
12 changes: 5 additions & 7 deletions src/interface.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
function patch(x, x̄)
return x .-
end
patch(x, x̄) = x .-

function state(o, x)
if isleaf(x)
return init(o, x)
else
x, _ = functor(x)
map(x -> state(o, x), x)
return map(x -> state(o, x), x)
end
end

Expand All @@ -20,11 +18,11 @@ function update(o, x::T, x̄, state) where T
if=== nothing
return x, state
elseif isleaf(x)
_update(o, x, x̄, state)
return _update(o, x, x̄, state)
else
x̄, _ = functor(typeof(x), x̄)
x, re = functor(typeof(x), x)
x, restructure = functor(typeof(x), x)
xstate = map((x, x̄, state) -> update(o, x, x̄, state), x, x̄, state)
re(map(first, xstate)), map(x -> x[2], xstate)
return restructure(map(first, xstate)), map(x -> x[2], xstate)
end
end
Loading

0 comments on commit 638ace1

Please sign in to comment.