Refactor code to make it Optimization.jl compatible #196

Merged · 8 commits · Dec 18, 2022
29 changes: 12 additions & 17 deletions src/build_loss_objective.jl
@@ -1,6 +1,6 @@
export DiffEqObjective, build_loss_objective

-struct DiffEqObjective{F, F2} <: Function
+struct DiffEqObjective{F <: Function, F2 <: Union{Function, Nothing}}
cost_function::F
cost_function2::F2
end
@@ -23,17 +23,13 @@ function diffeq_sen_l2!(res, df, u0, tspan, p, t, data, alg; kwargs...)
end

(f::DiffEqObjective)(x) = f.cost_function(x)
(f::DiffEqObjective)(x, y) = f.cost_function2(x, y)

function build_loss_objective(prob::SciMLBase.AbstractSciMLProblem, alg, loss,
regularization = nothing, args...;
-priors = nothing, mpg_autodiff = false,
+priors = nothing, autodiff = false,
verbose_opt = false, verbose_steps = 100,
prob_generator = STANDARD_PROB_GENERATOR,
-autodiff_prototype = mpg_autodiff ? zero(prob.p) : nothing,
-autodiff_chunk = mpg_autodiff ?
-ForwardDiff.Chunk(autodiff_prototype) :
-nothing, flsa_gradient = false,
+flsa_gradient = false,
adjsa_gradient = false,
kwargs...)
if verbose_opt
@@ -68,10 +64,7 @@ function build_loss_objective(prob::SciMLBase.AbstractSciMLProblem, alg, loss,
loss_val
end

-if mpg_autodiff
-gcfg = ForwardDiff.GradientConfig(cost_function, autodiff_prototype, autodiff_chunk)
-g! = (x, out) -> ForwardDiff.gradient!(out, cost_function, x, gcfg)
-elseif flsa_gradient
+if flsa_gradient
Member: this is also an autodiff option
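For context on that remark: the flsa_gradient branch below assembles the L2 gradient from forward local sensitivities rather than differentiating through the objective itself. A minimal sketch of the chain rule it applies, assuming sens[i] holds the state-by-parameter sensitivity matrix ∂sol(tᵢ)/∂p (all names here are illustrative, not the package's internals):

```julia
# L(p) = Σᵢ ‖sol(tᵢ) - dataᵢ‖²  ⇒  ∂L/∂p = 2 Σᵢ Sᵢ' (sol(tᵢ) - dataᵢ)
function l2_sensitivity_gradient!(out, sol_vals, sens, data)
    fill!(out, zero(eltype(out)))
    for i in axes(sol_vals, 2)
        out .+= 2 .* (sens[i]' * (sol_vals[:, i] .- data[:, i]))
    end
    return out
end
```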

if typeof(loss) <: L2Loss
function g!(x, out)
sol_, sens = diffeq_sen_full(prob.f, prob.u0, prob.tspan, x, loss.t, alg)
@@ -83,15 +76,17 @@ function build_loss_objective(prob::SciMLBase.AbstractSciMLProblem, alg, loss,
elseif adjsa_gradient
g! = (x, out) -> diffeq_sen_l2!(out, prob.f, prob.u0, prob.tspan, x, loss.t,
loss.data, alg)
-else
-g! = (x, out) -> Calculus.finite_difference!(cost_function, x, out, :central)
end

-cost_function2 = function (p, grad)
-if length(grad) > 0
-g!(p, grad)
+if flsa_gradient || adjsa_gradient
+cost_function2 = function (p, grad)
+if length(grad) > 0
+g!(p, grad)
+end
+cost_function(p)
+end
-cost_function(p)
+else
+cost_function2 = nothing
Member: the cost_function2 can just in general be removed.

Member (Author): Yeah, even I thought that, but I wasn't sure if it would be too breaking.

Member: Go for broke: this should get a major release with this change anyway. Just make these functions that return objective functions into functions that return an OptimizationProblem or something of the sort.

Member (Author): Okay, cool.

end
DiffEqObjective(cost_function, cost_function2)
end
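With the mpg_autodiff and Calculus finite-difference branches gone, gradient construction moves to the caller's AD backend via Optimization.jl. A minimal usage sketch; prob, t, data, p0, and the Tsit5() solver choice are assumptions for illustration, while build_loss_objective and L2Loss are the package's own API:

```julia
using OrdinaryDiffEq, Optimization, OptimizationOptimJL

obj = build_loss_objective(prob, Tsit5(), L2Loss(t, data))

# Optimization.jl expects f(u, p); its AD backend supplies the gradient
# that the removed in-house branches used to build.
optf = OptimizationFunction((p, _) -> obj(p), Optimization.AutoForwardDiff())
optprob = OptimizationProblem(optf, p0)
sol = solve(optprob, BFGS())
```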
25 changes: 2 additions & 23 deletions src/multiple_shooting_objective.jl
@@ -23,14 +23,9 @@ end

function multiple_shooting_objective(prob::DiffEqBase.DEProblem, alg, loss,
regularization = nothing; priors = nothing,
-mpg_autodiff = false, discontinuity_weight = 1.0,
+discontinuity_weight = 1.0,
verbose_opt = false, verbose_steps = 100,
prob_generator = STANDARD_MS_PROB_GENERATOR,
-autodiff_prototype = mpg_autodiff ?
-zeros(init_N_params) : nothing,
-autodiff_chunk = mpg_autodiff ?
-ForwardDiff.Chunk(autodiff_prototype) :
-nothing,
kwargs...)
if verbose_opt
count = 0 # keep track of # function evaluations
@@ -91,21 +86,5 @@ function multiple_shooting_objective(prob::DiffEqBase.DEProblem, alg, loss,
loss_val
end

-if mpg_autodiff
-gcfg = ForwardDiff.GradientConfig(cost_function, autodiff_prototype,
-autodiff_chunk)
-g! = (x, out) -> ForwardDiff.gradient!(out, cost_function, x, gcfg)
-else
-g! = (x, out) -> Calculus.finite_difference!(cost_function, x, out,
-:central)
-end
-
-cost_function2 = function (p, grad)
-if length(grad) > 0
-g!(p, grad)
-end
-cost_function(p)
-end
-
-DiffEqObjective(cost_function, cost_function2)
+DiffEqObjective(cost_function, nothing)
end
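The review suggestion above, taken to its conclusion, would have this function return an OptimizationProblem rather than a bare objective. A hypothetical sketch of that direction; build_ms_optimization_problem and its signature are assumptions, not part of this PR:

```julia
using Optimization

function build_ms_optimization_problem(prob, alg, loss, p0;
                                       adtype = Optimization.AutoForwardDiff(),
                                       kwargs...)
    obj = multiple_shooting_objective(prob, alg, loss; kwargs...)
    optf = OptimizationFunction((p, _) -> obj(p), adtype)
    return OptimizationProblem(optf, p0)
end
```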
32 changes: 2 additions & 30 deletions src/two_stage_method.jl
@@ -1,14 +1,12 @@
export TwoStageCost, two_stage_method

-struct TwoStageCost{F, F2, D} <: Function
+struct TwoStageCost{F, D} <: Function
cost_function::F
-cost_function2::F2
estimated_solution::D
estimated_derivative::D
end

(f::TwoStageCost)(p) = f.cost_function(p)
-(f::TwoStageCost)(p, g) = f.cost_function2(p, g)

decide_kernel(kernel::CollocationKernel) = kernel
function decide_kernel(kernel::Symbol)
@@ -127,32 +125,6 @@ function two_stage_method(prob::DiffEqBase.DEProblem, tpoints, data;
preview_est_deriv, tpoints)
end

-if mpg_autodiff
-gcfg = ForwardDiff.GradientConfig(cost_function, prob.p,
-ForwardDiff.Chunk{get_chunksize(autodiff_chunk)}())
-g! = (x, out) -> ForwardDiff.gradient!(out, cost_function, x, gcfg)
-else
-g! = (x, out) -> Calculus.finite_difference!(cost_function, x, out, :central)
-end
-if verbose
-count = 0 # keep track of # function evaluations
-end
-cost_function2 = function (p, grad)
-if length(grad) > 0
-g!(p, grad)
-end
-loss_val = cost_function(p)
-if verbose
-count::Int += 1
-if mod(count, verbose_steps) == 0
-println("Iteration: $count")
-println("Current Cost: $loss_val")
-println("Parameters: $p")
-end
-end
-loss_val
-end
-
-return TwoStageCost(cost_function, cost_function2, estimated_solution,
+return TwoStageCost(cost_function, estimated_solution,
estimated_derivative)
end
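As with the other objectives, TwoStageCost is now gradient-free and composes with any Optimization.jl backend; AutoFiniteDiff roughly plays the role of the removed Calculus.finite_difference! central-difference fallback. A short usage sketch, with prob, tpoints, data, and p_guess assumed for illustration:

```julia
using Optimization, OptimizationOptimJL

cost = two_stage_method(prob, tpoints, data)  # TwoStageCost; callable as cost(p)

optf = OptimizationFunction((p, _) -> cost(p), Optimization.AutoFiniteDiff())
sol = solve(OptimizationProblem(optf, p_guess), LBFGS())
```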