Refactor code to make it Optimization.jl compatible #196

Merged (8 commits, Dec 18, 2022)
13 changes: 9 additions & 4 deletions Project.toml
@@ -8,14 +8,12 @@ Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
Dierckx = "39dd38d3-220a-591b-8e3c-4c3a8c710a94"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LsqFit = "2fda8390-95c7-5789-9bda-21331edee243"
PenaltyFunctions = "06bb1623-fdd5-5ca2-a01c-88eae3ea319e"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Review comment (Member):

It might be good to keep SciMLSensitivity as a dependency: without it, users who pick AutoZygote just get a weird error telling them to add it, and I don't think doing that will be too uncommon.

That said, a better solution may be to make SciMLSensitivity a weak dependency of DiffEqBase that is loaded and used whenever Zygote, ReverseDiff, or Tracker are used. @oscardssmith is that possible?

Reply:
yeah that would work
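
The thread above is about what a user hits once SciMLSensitivity is no longer a hard dependency: reverse-mode adtypes differentiate through the internal `solve`, which needs the sensitivity machinery loaded. A minimal sketch of that situation under the refactored API (the model, data, and solver choices are illustrative, not taken from this PR):

```julia
using DiffEqParamEstim, OrdinaryDiffEq, Optimization, OptimizationOptimJL
using SciMLSensitivity  # without this, AutoZygote errors and asks for it

f(u, p, t) = p[1] .* u                        # toy exponential-growth model
prob = ODEProblem(f, [1.0], (0.0, 1.0), [1.5])
ts = collect(range(0.0, 1.0; length = 20))
data = reshape(exp.(1.5 .* ts), 1, :)         # synthetic data, true p = [1.5]

obj = build_loss_objective(prob, Tsit5(), L2Loss(ts, data),
                           Optimization.AutoZygote())
optprob = OptimizationProblem(obj, [1.0])     # initial parameter guess
sol = solve(optprob, BFGS())                  # sol.u should be ≈ [1.5]
```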


[compat]
Calculus = "0.5"
@@ -34,15 +32,22 @@ julia = "1.6"
[extras]
BlackBoxOptim = "a134a8b2-14d6-55f6-9291-3336d3ab0209"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationBBO = "3e6eede4-6085-4f62-9a71-46d9bc1eb92b"
OptimizationNLopt = "4e6fcdb7-1186-4e1f-a706-475e75c168bb"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
ParameterizedFunctions = "65888b18-ceab-5e60-b2b9-181511a3b968"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "BlackBoxOptim", "DelayDiffEq", "LeastSquaresOptim", "NLopt", "Optim", "OrdinaryDiffEq", "ParameterizedFunctions", "Random", "StochasticDiffEq", "SteadyStateDiffEq"]
test = ["Test", "BlackBoxOptim", "DelayDiffEq", "ForwardDiff", "NLopt", "Optim", "Optimization", "OptimizationBBO", "OptimizationNLopt", "OptimizationOptimJL", "OrdinaryDiffEq", "ParameterizedFunctions", "Random", "SciMLSensitivity", "StochasticDiffEq", "SteadyStateDiffEq", "Sundials", "Zygote"]
9 changes: 3 additions & 6 deletions src/DiffEqParamEstim.jl
@@ -1,8 +1,7 @@
module DiffEqParamEstim
using DiffEqBase, LsqFit, PenaltyFunctions,
RecursiveArrayTools, ForwardDiff, Calculus, Distributions,
LinearAlgebra, SciMLSensitivity, Dierckx,
SciMLBase
using DiffEqBase, PenaltyFunctions,
RecursiveArrayTools, Distributions,
LinearAlgebra, Dierckx, SciMLBase

import PreallocationTools
STANDARD_PROB_GENERATOR(prob, p) = remake(prob; u0 = eltype(p).(prob.u0), p = p)
@@ -23,9 +22,7 @@ STANDARD_MS_PROB_GENERATOR = function (prob, p, k)
end

include("cost_functions.jl")
include("lm_fit.jl")
include("build_loss_objective.jl")
include("build_lsoptim_objective.jl")
include("kernels.jl")
include("two_stage_method.jl")
include("multiple_shooting_objective.jl")
77 changes: 5 additions & 72 deletions src/build_loss_objective.jl
@@ -1,45 +1,12 @@
export DiffEqObjective, build_loss_objective

struct DiffEqObjective{F, F2} <: Function
cost_function::F
cost_function2::F2
end

function diffeq_sen_full(f, u0, tspan, p, t, alg; kwargs...)
prob = ODEForwardSensitivityProblem(f, u0, tspan, p)
sol = solve(prob, alg; kwargs...)(t)
nvar = length(u0)
sol[1:nvar, :], [sol[(i * nvar + 1):(i * nvar + nvar), :] for i in 1:length(p)]
end

function diffeq_sen_l2!(res, df, u0, tspan, p, t, data, alg; kwargs...)
prob = ODEProblem(df, u0, tspan, p)
sol = solve(prob, alg, saveat = t; kwargs...)
function dgdu_discrete(out, u, p, t, i)
@. out = 2 * (data[:, i] - u)
end
fill!(res, false)
res .-= adjoint_sensitivities(sol, alg; t, dgdu_discrete, kwargs...)[2][1, :]
end

(f::DiffEqObjective)(x) = f.cost_function(x)
(f::DiffEqObjective)(x, y) = f.cost_function2(x, y)
export build_loss_objective

function build_loss_objective(prob::SciMLBase.AbstractSciMLProblem, alg, loss,
adtype = SciMLBase.NoAD(),
regularization = nothing, args...;
priors = nothing, mpg_autodiff = false,
verbose_opt = false, verbose_steps = 100,
priors = nothing,
prob_generator = STANDARD_PROB_GENERATOR,
autodiff_prototype = mpg_autodiff ? zero(prob.p) : nothing,
autodiff_chunk = mpg_autodiff ?
ForwardDiff.Chunk(autodiff_prototype) :
nothing, flsa_gradient = false,
adjsa_gradient = false,
kwargs...)
if verbose_opt
count = 0 # keep track of # function evaluations
end
cost_function = function (p)
cost_function = function (p, _)
tmp_prob = prob_generator(prob, p)
if typeof(loss) <: Union{L2Loss, LogLikeLoss}
sol = solve(tmp_prob, alg, args...; saveat = loss.t, save_everystep = false,
@@ -56,42 +23,8 @@ function build_loss_objective(prob::SciMLBase.AbstractSciMLProblem, alg, loss,
if regularization !== nothing
loss_val += regularization(p)
end

if verbose_opt
count::Int += 1
if mod(count, verbose_steps) == 0
println("Iteration: $count")
println("Current Cost: $loss_val")
println("Parameters: $p")
end
end
loss_val
end

if mpg_autodiff
gcfg = ForwardDiff.GradientConfig(cost_function, autodiff_prototype, autodiff_chunk)
g! = (x, out) -> ForwardDiff.gradient!(out, cost_function, x, gcfg)
elseif flsa_gradient
if typeof(loss) <: L2Loss
function g!(x, out)
sol_, sens = diffeq_sen_full(prob.f, prob.u0, prob.tspan, x, loss.t, alg)
l2lossgradient!(out, sol_, loss.data, sens, length(prob.p))
end
else
throw("LSA gradient only for L2Loss")
end
elseif adjsa_gradient
g! = (x, out) -> diffeq_sen_l2!(out, prob.f, prob.u0, prob.tspan, x, loss.t,
loss.data, alg)
else
g! = (x, out) -> Calculus.finite_difference!(cost_function, x, out, :central)
end

cost_function2 = function (p, grad)
if length(grad) > 0
g!(p, grad)
end
cost_function(p)
end
DiffEqObjective(cost_function, cost_function2)
return OptimizationFunction(cost_function, adtype)
end
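
Taken together, the hunk above replaces the NLopt-style `DiffEqObjective` (a `(p)`/`(p, grad)` callable pair with hand-rolled ForwardDiff, finite-difference, and sensitivity gradients) with a single `OptimizationFunction` whose gradient choice is delegated to the `adtype`. A hedged end-to-end sketch of the new calling convention (problem and data invented for illustration):

```julia
using DiffEqParamEstim, OrdinaryDiffEq, Optimization, OptimizationOptimJL

# In-place logistic model; the true parameter is p = [1.2].
f!(du, u, p, t) = (du[1] = p[1] * u[1] * (1 - u[1]))
prob = ODEProblem(f!, [0.1], (0.0, 5.0), [1.2])
ts = collect(range(0.0, 5.0; length = 40))
data = Array(solve(prob, Tsit5(); saveat = ts))  # fit against a known trajectory

# The fourth positional argument is the new adtype; remaining kwargs are
# forwarded to the inner `solve` call.
obj = build_loss_objective(prob, Tsit5(), L2Loss(ts, data),
                           Optimization.AutoForwardDiff())
optprob = OptimizationProblem(obj, [0.8])        # start away from the truth
sol = solve(optprob, BFGS())                     # should recover ≈ [1.2]
```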
17 changes: 0 additions & 17 deletions src/build_lsoptim_objective.jl

This file was deleted.

15 changes: 0 additions & 15 deletions src/lm_fit.jl

This file was deleted.

44 changes: 6 additions & 38 deletions src/multiple_shooting_objective.jl
@@ -22,21 +22,12 @@ struct Merged_Solution{T1, T2, T3}
end

function multiple_shooting_objective(prob::DiffEqBase.DEProblem, alg, loss,
adtype = SciMLBase.NoAD(),
regularization = nothing; priors = nothing,
mpg_autodiff = false, discontinuity_weight = 1.0,
verbose_opt = false, verbose_steps = 100,
discontinuity_weight = 1.0,
prob_generator = STANDARD_MS_PROB_GENERATOR,
autodiff_prototype = mpg_autodiff ?
zeros(init_N_params) : nothing,
autodiff_chunk = mpg_autodiff ?
ForwardDiff.Chunk(autodiff_prototype) :
nothing,
kwargs...)
if verbose_opt
count = 0 # keep track of # function evaluations
end

cost_function = function (p)
cost_function = function (p, _)
t0, tf = prob.tspan
P, N = length(prob.p), length(prob.u0)
K = Int((length(p) - P) / N)
@@ -68,9 +59,10 @@ function multiple_shooting_objective(prob::DiffEqBase.DEProblem, alg, loss,
if priors !== nothing
loss_val += prior_loss(priors, p[(end - length(priors)):end])
end
if regularization !== nothing
if !isnothing(regularization)
loss_val += regularization(p)
end

for k in 2:K
if typeof(discontinuity_weight) <: Real
loss_val += discontinuity_weight *
Expand All @@ -80,32 +72,8 @@ function multiple_shooting_objective(prob::DiffEqBase.DEProblem, alg, loss,
(sol[k][1] - sol[k - 1][end]) .^ 2)
end
end
if verbose_opt
count::Int += 1
if mod(count, verbose_steps) == 0
println("Iteration: $count")
println("Current Cost: $loss_val")
println("Parameters: $p")
end
end
loss_val
end

if mpg_autodiff
gcfg = ForwardDiff.GradientConfig(cost_function, autodiff_prototype,
autodiff_chunk)
g! = (x, out) -> ForwardDiff.gradient!(out, cost_function, x, gcfg)
else
g! = (x, out) -> Calculus.finite_difference!(cost_function, x, out,
:central)
end

cost_function2 = function (p, grad)
if length(grad) > 0
g!(p, grad)
end
cost_function(p)
end

DiffEqObjective(cost_function, cost_function2)
return OptimizationFunction(cost_function, adtype)
end
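
`multiple_shooting_objective` now follows the same pattern. Judging from the retained cost function (`K = Int((length(p) - P) / N)`), the decision vector stacks the K per-segment initial states ahead of the P model parameters. A sketch under that assumption, reusing the logistic toy problem from the previous example; the segment count and bounds are arbitrary:

```julia
using OptimizationBBO  # global search copes with the enlarged decision vector

ms_obj = multiple_shooting_objective(prob, Tsit5(), L2Loss(ts, data),
                                     Optimization.AutoForwardDiff();
                                     discontinuity_weight = 1.0)
K, N, P = 4, length(prob.u0), length(prob.p)
p0 = [fill(0.1, K * N); 0.8]                  # segment states first, then parameters
optprob = OptimizationProblem(ms_obj, p0;
                              lb = zeros(length(p0)), ub = fill(2.0, length(p0)))
sol = solve(optprob, BBO_adaptive_de_rand_1_bin_radiuslimited(); maxiters = 10_000)
```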
63 changes: 14 additions & 49 deletions src/two_stage_method.jl
@@ -1,14 +1,12 @@
export TwoStageCost, two_stage_method
export TwoStageCost, two_stage_objective

struct TwoStageCost{F, F2, D} <: Function
struct TwoStageCost{F, D} <: Function
cost_function::F
cost_function2::F2
estimated_solution::D
estimated_derivative::D
end

(f::TwoStageCost)(p) = f.cost_function(p)
(f::TwoStageCost)(p, g) = f.cost_function2(p, g)
(f::TwoStageCost)(p, _p = nothing) = f.cost_function(p, _p)

decide_kernel(kernel::CollocationKernel) = kernel
function decide_kernel(kernel::Symbol)
@@ -71,8 +69,9 @@ function construct_estimated_solution_and_derivative!(data, kernel, tpoints)
estimated_solution = reduce(hcat, transpose.(last.(x)))
estimated_derivative, estimated_solution
end

function construct_iip_cost_function(f, du, preview_est_sol, preview_est_deriv, tpoints)
function (p)
function (p, _)
_du = PreallocationTools.get_tmp(du, p)
vecdu = vec(_du)
cost = zero(first(p))
@@ -87,7 +86,7 @@ end
end

function construct_oop_cost_function(f, du, preview_est_sol, preview_est_deriv, tpoints)
function (p)
function (p, _)
cost = zero(first(p))
for i in 1:length(preview_est_sol)
est_sol = preview_est_sol[i]
@@ -98,61 +97,27 @@ function construct_oop_cost_function(f, du, preview_est_sol, preview_est_deriv,
end
end

get_chunksize(cs) = cs
get_chunksize(cs::Type{Val{CS}}) where {CS} = CS

function two_stage_method(prob::DiffEqBase.DEProblem, tpoints, data;
kernel = EpanechnikovKernel(),
loss_func = L2Loss, mpg_autodiff = false,
verbose = false, verbose_steps = 100,
autodiff_chunk = length(prob.p))
function two_stage_objective(prob::DiffEqBase.DEProblem, tpoints, data,
adtype = SciMLBase.NoAD();
kernel = EpanechnikovKernel())
f = prob.f
kernel_function = decide_kernel(kernel)
estimated_derivative, estimated_solution = construct_estimated_solution_and_derivative!(data,
kernel_function,
tpoints)

# Step - 2

du = PreallocationTools.dualcache(similar(prob.u0), autodiff_chunk)
preview_est_sol = [@view estimated_solution[:, i]
for i in 1:size(estimated_solution, 2)]
preview_est_deriv = [@view estimated_derivative[:, i]
for i in 1:size(estimated_solution, 2)]
if DiffEqBase.isinplace(prob)
cost_function = construct_iip_cost_function(f, du, preview_est_sol,
preview_est_deriv, tpoints)
else
cost_function = construct_oop_cost_function(f, du, preview_est_sol,
preview_est_deriv, tpoints)
end

if mpg_autodiff
gcfg = ForwardDiff.GradientConfig(cost_function, prob.p,
ForwardDiff.Chunk{get_chunksize(autodiff_chunk)}())
g! = (x, out) -> ForwardDiff.gradient!(out, cost_function, x, gcfg)
cost_function = if isinplace(prob)
construct_iip_cost_function(f, prob.u0, preview_est_sol, preview_est_deriv, tpoints)
else
g! = (x, out) -> Calculus.finite_difference!(cost_function, x, out, :central)
end
if verbose
count = 0 # keep track of # function evaluations
end
cost_function2 = function (p, grad)
if length(grad) > 0
g!(p, grad)
end
loss_val = cost_function(p)
if verbose
count::Int += 1
if mod(count, verbose_steps) == 0
println("Iteration: $count")
println("Current Cost: $loss_val")
println("Parameters: $p")
end
end
loss_val
construct_oop_cost_function(f, prob.u0, preview_est_sol, preview_est_deriv, tpoints)
end

return TwoStageCost(cost_function, cost_function2, estimated_solution,
estimated_derivative)
return OptimizationFunction(TwoStageCost(cost_function, estimated_solution,
estimated_derivative), adtype)
end
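
The renamed `two_stage_objective` keeps the collocation structure: the data are smoothed with the chosen kernel, derivatives are estimated from the smoother, and the cost penalizes the mismatch between `f` and those estimated derivatives, so no ODE is solved inside the loss. A self-contained sketch (an out-of-place toy model, invented for illustration):

```julia
using DiffEqParamEstim, OrdinaryDiffEq, Optimization, OptimizationOptimJL

g(u, p, t) = p[1] .* u                           # toy out-of-place model
gprob = ODEProblem(g, [1.0], (0.0, 1.0), [1.5])
gts = collect(range(0.0, 1.0; length = 25))
gdata = reshape(exp.(1.5 .* gts), 1, :)          # nvar × ntime data matrix

ts_obj = two_stage_objective(gprob, gts, gdata, Optimization.AutoForwardDiff();
                             kernel = EpanechnikovKernel())
optprob = OptimizationProblem(ts_obj, [1.0])
sol = solve(optprob, BFGS())                     # cheap: the loss never calls solve
```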