Skip to content

Commit

Permalink
Merge pull request #162 from SciML/prealloc
Browse files Browse the repository at this point in the history
Directly use PreallocationTools.jl
  • Loading branch information
ChrisRackauckas authored Nov 4, 2021
2 parents d812632 + 602cd0d commit fb1bd24
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LsqFit = "2fda8390-95c7-5789-9bda-21331edee243"
PenaltyFunctions = "06bb1623-fdd5-5ca2-a01c-88eae3ea319e"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"

Expand All @@ -25,6 +26,7 @@ Distributions = "0.21, 0.22, 0.23, 0.24"
ForwardDiff = "0.10"
LsqFit = "0.8, 0.9, 0.10, 0.11, 0.12"
PenaltyFunctions = "0.1, 0.2"
PreallocationTools = "0.2"
RecursiveArrayTools = "1.0, 2.0"
SciMLBase = "1"
julia = "1"
Expand Down
1 change: 1 addition & 0 deletions src/DiffEqParamEstim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using DiffEqBase, LsqFit, PenaltyFunctions,
LinearAlgebra, DiffEqSensitivity, Dierckx,
SciMLBase

import PreallocationTools
STANDARD_PROB_GENERATOR(prob,p) = remake(prob;u0=eltype(p).(prob.u0),p=p)
STANDARD_PROB_GENERATOR(prob::EnsembleProblem,p) = EnsembleProblem(
remake(prob.prob;u0=eltype(p).(prob.prob.u0),p=p),
Expand Down
7 changes: 4 additions & 3 deletions src/two_stage_method.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ function construct_estimated_solution_and_derivative!(data,kernel,tpoints)
end
function construct_iip_cost_function(f,du,preview_est_sol,preview_est_deriv,tpoints)
function (p)
_du = DiffEqBase.get_tmp(du,p)
_du = PreallocationTools.get_tmp(du,p)
vecdu = vec(_du)
cost = zero(first(p))
for i in 1:length(preview_est_sol)
Expand All @@ -100,19 +100,20 @@ function construct_oop_cost_function(f,du,preview_est_sol,preview_est_deriv,tpoi
end
end

get_chunksize(cs) = cs
get_chunksize(cs::Type{Val{CS}}) where CS = CS

function two_stage_method(prob::DiffEqBase.DEProblem,tpoints,data;kernel= EpanechnikovKernel(),
loss_func = L2Loss,mpg_autodiff = false,
verbose = false,verbose_steps = 100,
autodiff_chunk = Val{ForwardDiff.pickchunksize(length(prob.p))})
autodiff_chunk = length(prob.p))
f = prob.f
kernel_function = decide_kernel(kernel)
estimated_derivative,estimated_solution = construct_estimated_solution_and_derivative!(data,kernel_function,tpoints)

# Step - 2

du = DiffEqBase.dualcache(similar(prob.u0), autodiff_chunk)
du = PreallocationTools.dualcache(similar(prob.u0), autodiff_chunk)
preview_est_sol = [@view estimated_solution[:,i] for i in 1:size(estimated_solution,2)]
preview_est_deriv = [@view estimated_derivative[:,i] for i in 1:size(estimated_solution,2)]
if DiffEqBase.isinplace(prob)
Expand Down

0 comments on commit fb1bd24

Please sign in to comment.