Skip to content

Commit

Permalink
Remove AD parts from objective function creations
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Dec 12, 2022
1 parent cb2083b commit 53c4979
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 70 deletions.
29 changes: 12 additions & 17 deletions src/build_loss_objective.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
export DiffEqObjective, build_loss_objective

struct DiffEqObjective{F, F2} <: Function
struct DiffEqObjective{F <: Function, F2 <: Union{Function, Nothing}}
cost_function::F
cost_function2::F2
end
Expand All @@ -23,17 +23,13 @@ function diffeq_sen_l2!(res, df, u0, tspan, p, t, data, alg; kwargs...)
end

(f::DiffEqObjective)(x) = f.cost_function(x)
(f::DiffEqObjective)(x, y) = f.cost_function2(x, y)

function build_loss_objective(prob::SciMLBase.AbstractSciMLProblem, alg, loss,
regularization = nothing, args...;
priors = nothing, mpg_autodiff = false,
priors = nothing, autodiff = false,
verbose_opt = false, verbose_steps = 100,
prob_generator = STANDARD_PROB_GENERATOR,
autodiff_prototype = mpg_autodiff ? zero(prob.p) : nothing,
autodiff_chunk = mpg_autodiff ?
ForwardDiff.Chunk(autodiff_prototype) :
nothing, flsa_gradient = false,
flsa_gradient = false,
adjsa_gradient = false,
kwargs...)
if verbose_opt
Expand Down Expand Up @@ -68,10 +64,7 @@ function build_loss_objective(prob::SciMLBase.AbstractSciMLProblem, alg, loss,
loss_val
end

if mpg_autodiff
gcfg = ForwardDiff.GradientConfig(cost_function, autodiff_prototype, autodiff_chunk)
g! = (x, out) -> ForwardDiff.gradient!(out, cost_function, x, gcfg)
elseif flsa_gradient
if flsa_gradient
if typeof(loss) <: L2Loss
function g!(x, out)
sol_, sens = diffeq_sen_full(prob.f, prob.u0, prob.tspan, x, loss.t, alg)
Expand All @@ -83,15 +76,17 @@ function build_loss_objective(prob::SciMLBase.AbstractSciMLProblem, alg, loss,
elseif adjsa_gradient
g! = (x, out) -> diffeq_sen_l2!(out, prob.f, prob.u0, prob.tspan, x, loss.t,
loss.data, alg)
else
g! = (x, out) -> Calculus.finite_difference!(cost_function, x, out, :central)
end

cost_function2 = function (p, grad)
if length(grad) > 0
g!(p, grad)
if flsa_gradient || adjsa_gradient
cost_function2 = function (p, grad)
if length(grad) > 0
g!(p, grad)
end
cost_function(p)
end
cost_function(p)
else
cost_function2 = nothing
end
DiffEqObjective(cost_function, cost_function2)
end
25 changes: 2 additions & 23 deletions src/multiple_shooting_objective.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,9 @@ end

function multiple_shooting_objective(prob::DiffEqBase.DEProblem, alg, loss,
regularization = nothing; priors = nothing,
mpg_autodiff = false, discontinuity_weight = 1.0,
discontinuity_weight = 1.0,
verbose_opt = false, verbose_steps = 100,
prob_generator = STANDARD_MS_PROB_GENERATOR,
autodiff_prototype = mpg_autodiff ?
zeros(init_N_params) : nothing,
autodiff_chunk = mpg_autodiff ?
ForwardDiff.Chunk(autodiff_prototype) :
nothing,
kwargs...)
if verbose_opt
count = 0 # keep track of # function evaluations
Expand Down Expand Up @@ -91,21 +86,5 @@ function multiple_shooting_objective(prob::DiffEqBase.DEProblem, alg, loss,
loss_val
end

if mpg_autodiff
gcfg = ForwardDiff.GradientConfig(cost_function, autodiff_prototype,
autodiff_chunk)
g! = (x, out) -> ForwardDiff.gradient!(out, cost_function, x, gcfg)
else
g! = (x, out) -> Calculus.finite_difference!(cost_function, x, out,
:central)
end

cost_function2 = function (p, grad)
if length(grad) > 0
g!(p, grad)
end
cost_function(p)
end

DiffEqObjective(cost_function, cost_function2)
DiffEqObjective(cost_function, nothing)
end
32 changes: 2 additions & 30 deletions src/two_stage_method.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
export TwoStageCost, two_stage_method

struct TwoStageCost{F, F2, D} <: Function
struct TwoStageCost{F, D} <: Function
cost_function::F
cost_function2::F2
estimated_solution::D
estimated_derivative::D
end

(f::TwoStageCost)(p) = f.cost_function(p)
(f::TwoStageCost)(p, g) = f.cost_function2(p, g)

decide_kernel(kernel::CollocationKernel) = kernel
function decide_kernel(kernel::Symbol)
Expand Down Expand Up @@ -127,32 +125,6 @@ function two_stage_method(prob::DiffEqBase.DEProblem, tpoints, data;
preview_est_deriv, tpoints)
end

if mpg_autodiff
gcfg = ForwardDiff.GradientConfig(cost_function, prob.p,
ForwardDiff.Chunk{get_chunksize(autodiff_chunk)}())
g! = (x, out) -> ForwardDiff.gradient!(out, cost_function, x, gcfg)
else
g! = (x, out) -> Calculus.finite_difference!(cost_function, x, out, :central)
end
if verbose
count = 0 # keep track of # function evaluations
end
cost_function2 = function (p, grad)
if length(grad) > 0
g!(p, grad)
end
loss_val = cost_function(p)
if verbose
count::Int += 1
if mod(count, verbose_steps) == 0
println("Iteration: $count")
println("Current Cost: $loss_val")
println("Parameters: $p")
end
end
loss_val
end

return TwoStageCost(cost_function, cost_function2, estimated_solution,
return TwoStageCost(cost_function, estimated_solution,
estimated_derivative)
end

0 comments on commit 53c4979

Please sign in to comment.