From 85aaae99b42f5bd6e4eaedee444b4aa4781f4fe6 Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Fri, 12 Feb 2021 09:24:09 -0600
Subject: [PATCH 1/5] Add ParameterSchedulers.jl

---
 Manifest.toml | 7 +++++++
 Project.toml  | 1 +
 2 files changed, 8 insertions(+)

diff --git a/Manifest.toml b/Manifest.toml
index 7210fa33bf..8cd785e521 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -256,6 +256,13 @@ git-tree-sha1 = "d45739abcfc03b51f6a42712894a593f74c80a23"
 uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
 version = "1.3.3"
 
+[[ParameterSchedulers]]
+git-tree-sha1 = "ab80539e1061e586a49300813039f5c11d3ba8e8"
+repo-rev = "darsnack/rm-optim"
+repo-url = "https://github.com/darsnack/ParameterSchedulers.jl.git"
+uuid = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
+version = "0.2.0"
+
 [[Pkg]]
 deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
 uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

diff --git a/Project.toml b/Project.toml
index b51f143252..39c3c3afa9 100644
--- a/Project.toml
+++ b/Project.toml
@@ -14,6 +14,7 @@ Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
+ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

From fbaeb06db0a8b6374997d53f7eff79cb937c3275 Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Fri, 12 Feb 2021 09:39:41 -0600
Subject: [PATCH 2/5] Add initial implementation

---
 src/optimise/Optimise.jl   |  9 ++++++++
 src/optimise/schedulers.jl | 42 ++++++++++++++++++++++++++++++++++++++
 test/optimise.jl           | 13 ++++++++++++
 3 files changed, 64 insertions(+)
 create mode 100644 src/optimise/schedulers.jl

diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index e2485a05d0..854175132d 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -9,6 +9,15 @@ export train!, update!,
   ClipValue, ClipNorm
 
 include("optimisers.jl")
+
+module Schedule
+  using ..Optimise
+  using ParameterSchedulers
+  import ParameterSchedulers: AbstractSchedule
+
+  include("schedulers.jl")
+end
+
 include("train.jl")
 
 end

diff --git a/src/optimise/schedulers.jl b/src/optimise/schedulers.jl
new file mode 100644
index 0000000000..6fdd5e794d
--- /dev/null
+++ b/src/optimise/schedulers.jl
@@ -0,0 +1,42 @@
+"""
+    Scheduler{T<:ParameterSchedulers.AbstractSchedule, O}
+    Scheduler(schedule::ParameterSchedulers.AbstractSchedule, opt)
+
+Wrap a `schedule` and `opt` together with a `Scheduler`.
+Each time `Optimise.update!` is called, the scheduler sets the
+learning rate (`opt.eta`) to the current schedule value,
+then advances the schedule by one iteration.
+
+A `Scheduler` can be used anywhere a Flux optimizer is used.
+
+!!! warning
+    `opt` is limited to optimizers with a learning rate
+    (i.e. the field `opt.eta` exists)
+
+# Examples
+```jldoctest
+julia> opt = Momentum();
+
+julia> schedule = Schedule.Exp(λ = 0.01, γ = 0.5);
+
+julia> sopt = Schedule.Scheduler(schedule, opt)
+```
+"""
+mutable struct Scheduler{T<:AbstractSchedule, O}
+  schedule::T
+  optimiser::O
+  iter::IdDict{Any, Int}
+
+  Scheduler{T, O}(schedule::T, optimiser::O) where {T, O} =
+    new{T, O}(schedule, optimiser, IdDict())
+end
+
+function Optimise.apply!(opt::Scheduler, x, dx)
+  # set param
+  t = get!(opt.iter, x, 1)
+  opt.optimiser.eta = opt.schedule[t]
+  opt.iter[x] += 1
+
+  # do normal apply
+  return Optimise.apply!(opt.optimiser, x, dx)
+end
\ No newline at end of file

diff --git a/test/optimise.jl b/test/optimise.jl
index 04cbf6f6c0..27b7897a8f 100644
--- a/test/optimise.jl
+++ b/test/optimise.jl
@@ -44,6 +44,19 @@ end
   end
 end
 
+@testset "Scheduler" begin
+  schedule = Schedule.Exp(λ = 0.1, γ = 0.5)
+  opt = Descent()
+  scheduler = Schedule.Scheduler(schedule, opt)
+  m = Chain(Dense(10, 5), Dense(5, 2, tanh))
+  ps = params(m)
+  for t in 1:10
+    gs = gradient(() -> sum(m(rand(10))), ps)
+    Optimise.update!(scheduler, ps, gs)
+    @test opt.eta ≈ schedule[t]
+  end
+end
+
 @testset "Training Loop" begin
   i = 0
   l = 1
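[Review note: the mechanism added in PATCH 2/5 is small enough to sanity-check outside Flux. Below is a minimal, dependency-free sketch of what `Scheduler`'s `apply!` does; `ExpSchedule`, `ToyDescent`, and `toy_apply!` are hypothetical stand-ins, not Flux or ParameterSchedulers.jl API, and the `λ * γ^(t - 1)` form of the exponential schedule is an assumption, not something shown in this diff.]

```julia
# Stand-in for ParameterSchedulers.Exp: indexing at iteration t is assumed
# to give λ * γ^(t - 1).
struct ExpSchedule
    λ::Float64  # initial learning rate
    γ::Float64  # decay factor per iteration
end
Base.getindex(s::ExpSchedule, t::Integer) = s.λ * s.γ^(t - 1)

# Stand-in for a Flux optimiser that exposes a learning-rate field `eta`.
mutable struct ToyDescent
    eta::Float64
end

# Mirror of the patch's Scheduler apply!: look up this parameter's step
# count, overwrite the learning rate from the schedule, advance the
# counter, then apply plain gradient descent.
function toy_apply!(schedule, opt::ToyDescent, iter::IdDict{Any, Int}, x, dx)
    t = get!(iter, x, 1)
    opt.eta = schedule[t]
    iter[x] += 1
    return opt.eta .* dx
end

schedule = ExpSchedule(0.1, 0.5)
opt = ToyDescent(0.0)
iter = IdDict{Any, Int}()
x, dx = randn(3), randn(3)
for t in 1:3
    toy_apply!(schedule, opt, iter, x, dx)
    @assert opt.eta ≈ 0.1 * 0.5^(t - 1)  # 0.1, then 0.05, then 0.025
end
```

This mirrors the `@test opt.eta ≈ schedule[t]` assertion in the new test set above, just without the Flux machinery.]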
From d3881becf825f5839c9ab3b836db690d7502d2c3 Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Fri, 12 Feb 2021 09:51:49 -0600
Subject: [PATCH 3/5] Fix export errors and docstring

---
 src/Flux.jl                |  2 ++
 src/optimise/schedulers.jl | 17 ++++++++++++++-------
 2 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/src/Flux.jl b/src/Flux.jl
index 5e6776d601..179a68d80c 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -24,6 +24,8 @@ include("optimise/Optimise.jl")
 using .Optimise
 using .Optimise: @epochs
 using .Optimise: skip
+using .Optimise: Schedule
+export Schedule
 export Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, OADAM,
   ADAMW, RADAM, AdaBelief, InvDecay, ExpDecay,

diff --git a/src/optimise/schedulers.jl b/src/optimise/schedulers.jl
index 6fdd5e794d..64a5223b59 100644
--- a/src/optimise/schedulers.jl
+++ b/src/optimise/schedulers.jl
@@ -17,19 +17,22 @@ A `Scheduler` can be used anywhere a Flux optimizer is used.
 ```jldoctest
 julia> opt = Momentum();
 
-julia> schedule = Schedule.Exp(λ = 0.01, γ = 0.5);
-
-julia> sopt = Schedule.Scheduler(schedule, opt)
+julia> schedule = Schedule.Exp(λ = 0.01, γ = 0.5)
+ParameterSchedulers.Exp{Float64}(0.01, 0.5)
+
+julia> scheduler = Schedule.Scheduler(schedule, opt)
+Scheduler(ParameterSchedulers.Exp{Float64}(0.01, 0.5), Momentum(0.01, 0.9, IdDict{Any,Any}()))
 ```
 """
-mutable struct Scheduler{T<:AbstractSchedule, O}
+struct Scheduler{T<:AbstractSchedule, O}
   schedule::T
   optimiser::O
   iter::IdDict{Any, Int}
-
-  Scheduler{T, O}(schedule::T, optimiser::O) where {T, O} =
-    new{T, O}(schedule, optimiser, IdDict())
 end
+Scheduler(schedule, optimiser) = Scheduler(schedule, optimiser, IdDict{Any, Int}())
+
+Base.show(io::IO, o::Scheduler) =
+  print(io, "Scheduler(", o.schedule, ", ", o.optimiser, ")")
 
 function Optimise.apply!(opt::Scheduler, x, dx)
   # set param
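[Review note: for intuition about the corrected doctest, here are the values an exponential decay would feed into `opt.eta`, assuming the `λ * γ^(t - 1)` form (an assumption; the formula is not shown in this diff). A quick check in plain Julia:]

```julia
# First few values for λ = 0.01, γ = 0.5, assuming value(t) = λ * γ^(t - 1):
# the learning rate halves every iteration.
λ, γ = 0.01, 0.5
[λ * γ^(t - 1) for t in 1:4]  # [0.01, 0.005, 0.0025, 0.00125]
```

The custom `Base.show` added here also keeps the `iter::IdDict` bookkeeping out of the printed form, which is what makes the doctest's `Scheduler(...)` output line stable.]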
From c3071776e3d01fbbc49c364bf8d43590ca59a240 Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Mon, 15 Feb 2021 10:44:17 -0600
Subject: [PATCH 4/5] Remove Scheduler

---
 src/optimise/Optimise.jl   |  7 +-----
 src/optimise/schedulers.jl | 45 --------------------------------------
 2 files changed, 1 insertion(+), 51 deletions(-)
 delete mode 100644 src/optimise/schedulers.jl

diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index 854175132d..dd18f2ba1f 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -9,15 +9,10 @@ export train!, update!,
   ClipValue, ClipNorm
 
 include("optimisers.jl")
+include("train.jl")
 
 module Schedule
-  using ..Optimise
   using ParameterSchedulers
-  import ParameterSchedulers: AbstractSchedule
-
-  include("schedulers.jl")
 end
 
-include("train.jl")
-
 end

diff --git a/src/optimise/schedulers.jl b/src/optimise/schedulers.jl
deleted file mode 100644
index 64a5223b59..0000000000
--- a/src/optimise/schedulers.jl
+++ /dev/null
@@ -1,45 +0,0 @@
-"""
-    Scheduler{T<:ParameterSchedulers.AbstractSchedule, O}
-    Scheduler(schedule::ParameterSchedulers.AbstractSchedule, opt)
-
-Wrap a `schedule` and `opt` together with a `Scheduler`.
-Each time `Optimise.update!` is called, the scheduler sets the
-learning rate (`opt.eta`) to the current schedule value,
-then advances the schedule by one iteration.
-
-A `Scheduler` can be used anywhere a Flux optimizer is used.
-
-!!! warning
-    `opt` is limited to optimizers with a learning rate
-    (i.e. the field `opt.eta` exists)
-
-# Examples
-```jldoctest
-julia> opt = Momentum();
-
-julia> schedule = Schedule.Exp(λ = 0.01, γ = 0.5)
-ParameterSchedulers.Exp{Float64}(0.01, 0.5)
-
-julia> scheduler = Schedule.Scheduler(schedule, opt)
-Scheduler(ParameterSchedulers.Exp{Float64}(0.01, 0.5), Momentum(0.01, 0.9, IdDict{Any,Any}()))
-```
-"""
-struct Scheduler{T<:AbstractSchedule, O}
-  schedule::T
-  optimiser::O
-  iter::IdDict{Any, Int}
-end
-Scheduler(schedule, optimiser) = Scheduler(schedule, optimiser, IdDict{Any, Int}())
-
-Base.show(io::IO, o::Scheduler) =
-  print(io, "Scheduler(", o.schedule, ", ", o.optimiser, ")")
-
-function Optimise.apply!(opt::Scheduler, x, dx)
-  # set param
-  t = get!(opt.iter, x, 1)
-  opt.optimiser.eta = opt.schedule[t]
-  opt.iter[x] += 1
-
-  # do normal apply
-  return Optimise.apply!(opt.optimiser, x, dx)
-end
\ No newline at end of file
From 82786a8d76c3aaca46a26f0a7568ef0619b50a9a Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Mon, 15 Feb 2021 11:48:13 -0600
Subject: [PATCH 5/5] Bring next! into Optimise scope

Co-authored-by: Carlo Lucibello
---
 src/optimise/Optimise.jl | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index dd18f2ba1f..d552a5b388 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -12,7 +12,8 @@ include("optimisers.jl")
 include("train.jl")
 
 module Schedule
-  using ParameterSchedulers
+using ParameterSchedulers
+using ParameterSchedulers: next!
 end
 
 end
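[Review note: with the `Scheduler` wrapper dropped in PATCH 4/5, the `Schedule` submodule only re-exports ParameterSchedulers.jl, and PATCH 5/5's `next!` import suggests driving a schedule by hand in the training loop. A sketch of that pattern follows; `StatefulExp` and the local `next!` are hypothetical stand-ins, since the real stateful schedule type never appears in this diff.]

```julia
# Hypothetical stateful schedule: next! returns the current value and
# advances an internal counter. This only illustrates the pattern; it is
# not the ParameterSchedulers.jl API.
mutable struct StatefulExp
    λ::Float64
    γ::Float64
    t::Int
end
StatefulExp(λ, γ) = StatefulExp(λ, γ, 1)

function next!(s::StatefulExp)
    val = s.λ * s.γ^(s.t - 1)
    s.t += 1
    return val
end

mutable struct ToyOpt
    eta::Float64  # learning-rate field, as on Flux's Descent/Momentum
end

opt = ToyOpt(0.1)
schedule = StatefulExp(0.1, 0.5)
for epoch in 1:3
    opt.eta = next!(schedule)  # set this epoch's learning rate
    # ... run the usual train!(loss, ps, data, opt) pass here ...
end
```

Compared with the per-`apply!` `Scheduler` from PATCH 2/5, this moves the scheduling decision out of the optimizer and into user code, which is presumably the point of the removal.]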