Wrap cost functions in OptimizationFunction and remove fitting codes
Vaibhavdixit02 committed Dec 13, 2022
1 parent 53c4979 commit 66594b3
Showing 6 changed files with 24 additions and 163 deletions.
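
Editorial note for orientation: `SciMLBase.OptimizationFunction` expects objectives with the two-argument signature `f(u, p)`, where `u` is the optimization variable and the second slot carries fixed hyperparameters, which is why each cost closure in this commit gains a second, unused argument. A minimal sketch of that convention (not part of the commit; `rosenbrock` is illustrative):

```julia
using SciMLBase

# Optimization.jl objectives take (u, p); the closures in this commit ignore p.
rosenbrock(u, _p) = (1.0 - u[1])^2 + 100.0 * (u[2] - u[1]^2)^2
optf = OptimizationFunction(rosenbrock)  # no AD backend given, as in this commit
```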
src/DiffEqParamEstim.jl (9 changes: 3 additions & 6 deletions)
@@ -1,8 +1,7 @@
module DiffEqParamEstim
-using DiffEqBase, LsqFit, PenaltyFunctions,
-      RecursiveArrayTools, ForwardDiff, Calculus, Distributions,
-      LinearAlgebra, SciMLSensitivity, Dierckx,
-      SciMLBase
+using DiffEqBase, PenaltyFunctions,
+      RecursiveArrayTools, Distributions,
+      LinearAlgebra, Dierckx, SciMLBase

import PreallocationTools
STANDARD_PROB_GENERATOR(prob, p) = remake(prob; u0 = eltype(p).(prob.u0), p = p)
@@ -23,9 +22,7 @@ STANDARD_MS_PROB_GENERATOR = function (prob, p, k)
end

include("cost_functions.jl")
include("lm_fit.jl")
include("build_loss_objective.jl")
include("build_lsoptim_objective.jl")
include("kernels.jl")
include("two_stage_method.jl")
include("multiple_shooting_objective.jl")
src/build_loss_objective.jl (71 changes: 5 additions & 66 deletions)
@@ -1,41 +1,12 @@
-export DiffEqObjective, build_loss_objective
-
-struct DiffEqObjective{F <: Function, F2 <: Union{Function, Nothing}}
-    cost_function::F
-    cost_function2::F2
-end
-
-function diffeq_sen_full(f, u0, tspan, p, t, alg; kwargs...)
-    prob = ODEForwardSensitivityProblem(f, u0, tspan, p)
-    sol = solve(prob, alg; kwargs...)(t)
-    nvar = length(u0)
-    sol[1:nvar, :], [sol[(i * nvar + 1):(i * nvar + nvar), :] for i in 1:length(p)]
-end
-
-function diffeq_sen_l2!(res, df, u0, tspan, p, t, data, alg; kwargs...)
-    prob = ODEProblem(df, u0, tspan, p)
-    sol = solve(prob, alg, saveat = t; kwargs...)
-    function dgdu_discrete(out, u, p, t, i)
-        @. out = 2 * (data[:, i] - u)
-    end
-    fill!(res, false)
-    res .-= adjoint_sensitivities(sol, alg; t, dgdu_discrete, kwargs...)[2][1, :]
-end
-
-(f::DiffEqObjective)(x) = f.cost_function(x)
+export build_loss_objective

function build_loss_objective(prob::SciMLBase.AbstractSciMLProblem, alg, loss,
                              regularization = nothing, args...;
-                              priors = nothing, autodiff = false,
-                              verbose_opt = false, verbose_steps = 100,
+                              priors = nothing,
                              prob_generator = STANDARD_PROB_GENERATOR,
-                              flsa_gradient = false,
-                              adjsa_gradient = false,
                              kwargs...)
-    if verbose_opt
-        count = 0 # keep track of # function evaluations
-    end
-    cost_function = function (p)
+
+    cost_function = function (p, _)
        tmp_prob = prob_generator(prob, p)
        if typeof(loss) <: Union{L2Loss, LogLikeLoss}
            sol = solve(tmp_prob, alg, args...; saveat = loss.t, save_everystep = false,
@@ -53,40 +24,8 @@ function build_loss_objective(prob::SciMLBase.AbstractSciMLProblem, alg, loss,
            loss_val += regularization(p)
        end

-        if verbose_opt
-            count::Int += 1
-            if mod(count, verbose_steps) == 0
-                println("Iteration: $count")
-                println("Current Cost: $loss_val")
-                println("Parameters: $p")
-            end
-        end
        loss_val
    end

-    if flsa_gradient
-        if typeof(loss) <: L2Loss
-            function g!(x, out)
-                sol_, sens = diffeq_sen_full(prob.f, prob.u0, prob.tspan, x, loss.t, alg)
-                l2lossgradient!(out, sol_, loss.data, sens, length(prob.p))
-            end
-        else
-            throw("LSA gradient only for L2Loss")
-        end
-    elseif adjsa_gradient
-        g! = (x, out) -> diffeq_sen_l2!(out, prob.f, prob.u0, prob.tspan, x, loss.t,
-                                        loss.data, alg)
-    end
-
-    if flsa_gradient || adjsa_gradient
-        cost_function2 = function (p, grad)
-            if length(grad) > 0
-                g!(p, grad)
-            end
-            cost_function(p)
-        end
-    else
-        cost_function2 = nothing
-    end
-    DiffEqObjective(cost_function, cost_function2)
+    OptimizationFunction(cost_function)
end
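
With the objective wrapped in an `OptimizationFunction`, it plugs directly into Optimization.jl instead of the removed `DiffEqObjective` glue. A minimal end-to-end sketch, not part of this commit — the Lotka-Volterra setup, the `Optimization`/`OptimizationOptimJL` packages, and the `NelderMead()` choice are illustrative assumptions:

```julia
using OrdinaryDiffEq, DiffEqParamEstim, Optimization, OptimizationOptimJL

# Lotka-Volterra with one unknown parameter, p[1].
function lotka!(du, u, p, t)
    du[1] = p[1] * u[1] - u[1] * u[2]
    du[2] = -3.0 * u[2] + u[1] * u[2]
end
prob = ODEProblem(lotka!, [1.0, 1.0], (0.0, 10.0), [1.5])

t = collect(range(0.0, 10.0; length = 200))
data = Array(solve(prob, Tsit5(); saveat = t))  # synthetic measurements

obj = build_loss_objective(prob, Tsit5(), L2Loss(t, data))
optprob = OptimizationProblem(obj, [1.3])  # initial guess for p[1]
sol = solve(optprob, NelderMead())
```

Because the wrapper is built without an AD backend, only derivative-free optimizers apply directly; gradient-based methods would need an AD type threaded through.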
src/build_lsoptim_objective.jl (17 changes: 0 additions & 17 deletions)

This file was deleted.

src/lm_fit.jl (15 changes: 0 additions & 15 deletions)

This file was deleted.

src/multiple_shooting_objective.jl (16 changes: 2 additions & 14 deletions)
@@ -24,14 +24,10 @@ end
function multiple_shooting_objective(prob::DiffEqBase.DEProblem, alg, loss,
                                     regularization = nothing; priors = nothing,
                                     discontinuity_weight = 1.0,
-                                     verbose_opt = false, verbose_steps = 100,
                                     prob_generator = STANDARD_MS_PROB_GENERATOR,
                                     kwargs...)
-    if verbose_opt
-        count = 0 # keep track of # function evaluations
-    end

-    cost_function = function (p)
+    cost_function = function (p, _)
        t0, tf = prob.tspan
        P, N = length(prob.p), length(prob.u0)
        K = Int((length(p) - P) / N)
@@ -75,16 +71,8 @@ function multiple_shooting_objective(prob::DiffEqBase.DEProblem, alg, loss,
                                (sol[k][1] - sol[k - 1][end]) .^ 2)
            end
        end
-        if verbose_opt
-            count::Int += 1
-            if mod(count, verbose_steps) == 0
-                println("Iteration: $count")
-                println("Current Cost: $loss_val")
-                println("Parameters: $p")
-            end
-        end
        loss_val
    end

-    DiffEqObjective(cost_function, nothing)
+    return OptimizationFunction(cost_function)
end
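
The closure above recovers the segment count from the decision vector via `K = Int((length(p) - P) / N)`, so the caller passes one stacked vector. A sketch of that layout, reusing the setup from the previous sketch — the ordering (segment initial states first, model parameters last) and `K = 4` are assumptions for illustration:

```julia
N = length(prob.u0)  # states per shooting segment
P = length(prob.p)   # number of model parameters
K = 4                # number of segments, the user's choice

# Decision vector: K stacked state guesses, then the parameter guess.
p0 = vcat(repeat(prob.u0, K), [1.3])  # length K * N + P, so K is recoverable

ms_obj = multiple_shooting_objective(prob, Tsit5(), L2Loss(t, data))
ms_prob = OptimizationProblem(ms_obj, p0)
ms_sol = solve(ms_prob, NelderMead())
```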
src/two_stage_method.jl (59 changes: 14 additions & 45 deletions)
@@ -6,7 +6,7 @@ struct TwoStageCost{F, D} <: Function
    estimated_derivative::D
end

-(f::TwoStageCost)(p) = f.cost_function(p)
+(f::TwoStageCost)(p, _p = nothing) = f.cost_function(p, _p)

decide_kernel(kernel::CollocationKernel) = kernel
function decide_kernel(kernel::Symbol)
@@ -69,62 +69,31 @@ function construct_estimated_solution_and_derivative!(data, kernel, tpoints)
    estimated_solution = reduce(hcat, transpose.(last.(x)))
    estimated_derivative, estimated_solution
end
-function construct_iip_cost_function(f, du, preview_est_sol, preview_est_deriv, tpoints)
-    function (p)
-        _du = PreallocationTools.get_tmp(du, p)
-        vecdu = vec(_du)
-        cost = zero(first(p))
-        for i in 1:length(preview_est_sol)
-            est_sol = preview_est_sol[i]
-            f(_du, est_sol, p, tpoints[i])
-            vecdu .= vec(preview_est_deriv[i]) .- vec(_du)
-            cost += sum(abs2, vecdu)
-        end
-        cost
-    end
-end
-
-function construct_oop_cost_function(f, du, preview_est_sol, preview_est_deriv, tpoints)
-    function (p)
-        cost = zero(first(p))
-        for i in 1:length(preview_est_sol)
-            est_sol = preview_est_sol[i]
-            _du = f(est_sol, p, tpoints[i])
-            cost += sum(abs2, vec(preview_est_deriv[i]) .- vec(_du))
-        end
-        cost
-    end
-end
-
-get_chunksize(cs) = cs
-get_chunksize(cs::Type{Val{CS}}) where {CS} = CS

-function two_stage_method(prob::DiffEqBase.DEProblem, tpoints, data;
-                          kernel = EpanechnikovKernel(),
-                          loss_func = L2Loss, mpg_autodiff = false,
-                          verbose = false, verbose_steps = 100,
-                          autodiff_chunk = length(prob.p))
+function two_stage_objective(prob::DiffEqBase.DEProblem, tpoints, data;
+                             kernel = EpanechnikovKernel())
    f = prob.f
    kernel_function = decide_kernel(kernel)
    estimated_derivative, estimated_solution = construct_estimated_solution_and_derivative!(data,
                                                                                            kernel_function,
                                                                                            tpoints)

    # Step - 2

-    du = PreallocationTools.dualcache(similar(prob.u0), autodiff_chunk)
    preview_est_sol = [@view estimated_solution[:, i]
                       for i in 1:size(estimated_solution, 2)]
    preview_est_deriv = [@view estimated_derivative[:, i]
                         for i in 1:size(estimated_solution, 2)]
-    if DiffEqBase.isinplace(prob)
-        cost_function = construct_iip_cost_function(f, du, preview_est_sol,
-                                                    preview_est_deriv, tpoints)
-    else
-        cost_function = construct_oop_cost_function(f, du, preview_est_sol,
-                                                    preview_est_deriv, tpoints)
-    end

+    function cost_function(p, _)
+        cost = zero(first(p))
+        for i in 1:length(preview_est_sol)
+            est_sol = preview_est_sol[i]
+            _du = f(est_sol, p, tpoints[i])
+            cost += sum(abs2, vec(preview_est_deriv[i]) .- vec(_du))
+        end
+        cost
+    end

-    return TwoStageCost(cost_function, estimated_solution,
-                        estimated_derivative)
+    return OptimizationFunction(TwoStageCost(cost_function, estimated_solution,
+                                             estimated_derivative))
end
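
The renamed `two_stage_objective` (formerly `two_stage_method`) returns the collocation cost in the same wrapper, so it is consumed identically; note that the loss never solves the ODE but compares `f` against kernel-smoothed data derivatives, and as written it calls `f(u, p, t)` out-of-place. A short sketch under the same assumptions as the earlier ones, using an out-of-place problem definition for that reason:

```julia
# Out-of-place Lotka-Volterra, matching the oop call in the cost function above.
lotka(u, p, t) = [p[1] * u[1] - u[1] * u[2], -3.0 * u[2] + u[1] * u[2]]
oop_prob = ODEProblem(lotka, [1.0, 1.0], (0.0, 10.0), [1.5])

ts_obj = two_stage_objective(oop_prob, t, data)  # kernel defaults to EpanechnikovKernel()
ts_prob = OptimizationProblem(ts_obj, [1.3])
ts_sol = solve(ts_prob, NelderMead())
```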
