From f86dabc0c97e9dcdb8e225b356b730f4a079b98a Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Mon, 8 Jun 2020 15:02:03 +0530 Subject: [PATCH 1/7] add ADAM --- src/rules.jl | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 51153c98..39a3aca6 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -1,24 +1,39 @@ +abstract type AbstractOptimiser end + +(opt::AbstractOptimiser)(x, x̂, state) = update(opt, x, x̂, state) +(opt::AbstractOptimiser)(m, m̂) = update(opt, m, m̂, state(opt, m))[1] + """ Descent(η) Classic gradient descent optimiser with learning rate `η`. For each parameter `p` and its gradient `p̄`, this runs `p -= η*p̄`. """ -mutable struct Descent +mutable struct Descent <: AbstractOptimiser eta::Float64 end init(o::Descent, x) = nothing -function apply(o::Descent, x, x̄, state) +function apply(o::Descent, x, x̄, st) η = convert(eltype(x̄), o.eta) - x̄ .* η, state + x̄ .* η, st end -function (o::Descent)(m, m̄) - update(o, m, m̄, state(o, m))[1] +mutable struct ADAM{T,K} <: AbstractOptimiser + eta::T + beta::Tuple{K,K} end -function (o::Descent)(m, m̄, st) - update(o, m, m̄, st) +const ϵ = 1e-8 +init(o::ADAM, x) = IdDict() + +function apply(o::ADAM, x, Δ, st) + η, β = o.eta, o.beta + mt, vt, βp = get!(st, x, (zero(x), zero(x), β)) + @. mt = β[1] * mt + (1 - β[1]) * Δ + @. vt = β[2] * vt + (1 - β[2]) * Δ^2 + @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η + st[x] = (mt, vt, βp .* β) + return Δ, st end From 2104909656d4014f60c22ba9a2f947a04d8cfae6 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 17 Jun 2020 15:04:02 +0530 Subject: [PATCH 2/7] rm abstract stuff --- src/rules.jl | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 39a3aca6..f809d5f8 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -1,15 +1,10 @@ -abstract type AbstractOptimiser end - -(opt::AbstractOptimiser)(x, x̂, state) = update(opt, x, x̂, state) -(opt::AbstractOptimiser)(m, m̂) = update(opt, m, m̂, state(opt, m))[1] - """ Descent(η) Classic gradient descent optimiser with learning rate `η`. For each parameter `p` and its gradient `p̄`, this runs `p -= η*p̄`. """ -mutable struct Descent <: AbstractOptimiser +mutable struct Descent eta::Float64 end @@ -20,7 +15,15 @@ function apply(o::Descent, x, x̄, st) x̄ .* η, st end -mutable struct ADAM{T,K} <: AbstractOptimiser +function (o::Descent)(m, m̄) + update(o, m, m̄, state(o, m))[1] +end + +function (o::Descent)(m, m̄, state) + update(o, m, m̄, state) +end + +mutable struct ADAM{T,K} eta::T beta::Tuple{K,K} end From e813cfd24c5ed7d5de8ffb47b7aa1e3b4adcea29 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 17 Jun 2020 16:30:39 +0530 Subject: [PATCH 3/7] rm IdDict --- src/rules.jl | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index f809d5f8..71316cfd 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -10,17 +10,17 @@ end init(o::Descent, x) = nothing -function apply(o::Descent, x, x̄, st) +function apply(o::Descent, x, x̄, state) η = convert(eltype(x̄), o.eta) - x̄ .* η, st + x̄ .* η, state end function (o::Descent)(m, m̄) update(o, m, m̄, state(o, m))[1] end -function (o::Descent)(m, m̄, state) - update(o, m, m̄, state) +function (o::Descent)(m, m̄, st) + update(o, m, m̄, st) end mutable struct ADAM{T,K} @@ -29,14 +29,23 @@ mutable struct ADAM{T,K} end const ϵ = 1e-8 -init(o::ADAM, x) = IdDict() + +function (o::ADAM)(m, m̄) + op = update(o, m, m̄, state(o, m))[1] +end + +function (o::ADAM)(m, m̄, state) + update(o, m, m̄, state) +end + +init(o::ADAM, x::AbstractArray) = (zero(x), zero(x), o.beta) +init(o::ADAM, x) = nothing function apply(o::ADAM, x, Δ, st) η, β = o.eta, o.beta - mt, vt, βp = get!(st, x, (zero(x), zero(x), β)) + mt, vt, βp = st @. mt = β[1] * mt + (1 - β[1]) * Δ @. vt = β[2] * vt + (1 - β[2]) * Δ^2 @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η - st[x] = (mt, vt, βp .* β) - return Δ, st + return Δ, (mt, vt, βp .* β) end From bcebfaca78b1411637c41e9e2381b7aba1c976a9 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 17 Jun 2020 17:18:27 +0530 Subject: [PATCH 4/7] immutable adam structure --- src/rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index 71316cfd..57e7ba33 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -23,7 +23,7 @@ function (o::Descent)(m, m̄, st) update(o, m, m̄, st) end -mutable struct ADAM{T,K} +struct ADAM{T,K} eta::T beta::Tuple{K,K} end From 3231216466fa896d78592d870867f41ceb357b6b Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Sat, 27 Jun 2020 10:43:32 +0530 Subject: [PATCH 5/7] rm in place updates --- src/rules.jl | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 57e7ba33..7a114baa 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -30,10 +30,6 @@ end const ϵ = 1e-8 -function (o::ADAM)(m, m̄) - op = update(o, m, m̄, state(o, m))[1] -end - function (o::ADAM)(m, m̄, state) update(o, m, m̄, state) end @@ -44,8 +40,8 @@ init(o::ADAM, x) = nothing function apply(o::ADAM, x, Δ, st) η, β = o.eta, o.beta mt, vt, βp = st - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. vt = β[2] * vt + (1 - β[2]) * Δ^2 - @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η + mt = β[1] .* mt .+ (1 .- β[1]) .* Δ + vt = β[2] .* vt .+ (1 .- β[2]) .* Δ^2 + Δ = mt ./ (1 .- βp[1]) ./ (√(vt ./ (1 .- βp[2])) .+ ϵ) .* η return Δ, (mt, vt, βp .* β) end From 918aed7f3f27472eee11efaa25d4a2d30968836d Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 5 Nov 2020 12:38:43 +0530 Subject: [PATCH 6/7] correct broadcasting --- src/rules.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 7a114baa..d685d173 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -41,7 +41,7 @@ function apply(o::ADAM, x, Δ, st) η, β = o.eta, o.beta mt, vt, βp = st mt = β[1] .* mt .+ (1 .- β[1]) .* Δ - vt = β[2] .* vt .+ (1 .- β[2]) .* Δ^2 - Δ = mt ./ (1 .- βp[1]) ./ (√(vt ./ (1 .- βp[2])) .+ ϵ) .* η + vt = β[2] .* vt .+ (1 .- β[2]) .* Δ .^ 2 + Δ = mt ./ (1 .- βp[1]) ./ (.√(vt ./ (1 .- βp[2])) .+ ϵ) .* η return Δ, (mt, vt, βp .* β) end From f0780c1b8155a34b14ba816721cbd42507edcce1 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 5 Nov 2020 23:26:22 +0530 Subject: [PATCH 7/7] use Float32 by default --- src/rules.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index d685d173..ac7b1b20 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -28,7 +28,7 @@ struct ADAM{T,K} beta::Tuple{K,K} end -const ϵ = 1e-8 +const ϵ = 1f-8 function (o::ADAM)(m, m̄, state) update(o, m, m̄, state) @@ -40,8 +40,8 @@ init(o::ADAM, x) = nothing function apply(o::ADAM, x, Δ, st) η, β = o.eta, o.beta mt, vt, βp = st - mt = β[1] .* mt .+ (1 .- β[1]) .* Δ - vt = β[2] .* vt .+ (1 .- β[2]) .* Δ .^ 2 - Δ = mt ./ (1 .- βp[1]) ./ (.√(vt ./ (1 .- βp[2])) .+ ϵ) .* η + mt = β[1] .* mt .+ (1f0 .- β[1]) .* Δ + vt = β[2] .* vt .+ (1f0 .- β[2]) .* Δ .^ 2 + Δ = mt ./ (1 .- βp[1]) ./ (.√(vt ./ (1f0 .- βp[2])) .+ ϵ) .* η return Δ, (mt, vt, βp .* β) end