Skip to content

Commit

Permalink
compose tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
MikeInnes committed Oct 5, 2018
1 parent 9bc9771 commit 0f2019e
Showing 1 changed file with 6 additions and 20 deletions.
26 changes: 6 additions & 20 deletions src/optimise/optimisers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ function update!(o::NADAM, x, Δ)
end

"""
ADAMW((η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
ADAMW((η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
"""
Expand All @@ -231,36 +231,22 @@ ADAMW(η = 0.001, β = (0.9, 0.999), η_decay = 1, γ_decay = 0) = Compose(ADAM(
# Compose optimizers

"""
`Compose(Compose(...), ...)`
Compose(a, b, c...)
Compose optimizers to support inbuilt or custom gradient updates while fitting the loss.
Example:\n\n
`Compose(ADAM(), Compose(RMSProp(0.001), ExpDecay(0.02)))`
Combine several optimisers into one; each optimiser produces a modified gradient
that will be fed into the next, and this is finally applied to the parameter as
usual.
"""
mutable struct Compose
os::Vector{Any}
Compose(o...) = Compose(Any[o...])
end

Compose(o...) = Compose(flattenCompose(o...))

@forward Compose.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex!
@forward Compose.os Base.iterate

Base.getindex(c::Compose, i::AbstractArray) = Compose(c.os[i]...)

function flattenCompose(o...)
res = []
for opt in o
if opt isa Compose
push!(res, flattenCompose(opt.os...)...)
else
push!(res, opt)
end
end
return res
end

function update!(o::Compose, x, Δ)
for opt in o.os
Δ = update!(opt, x, Δ)
Expand Down

0 comments on commit 0f2019e

Please sign in to comment.