Skip to content

Commit

Permalink
Define prepare_alg dispatch for OrdinaryDiffEq AD algorithms
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Aug 11, 2021
1 parent f47db7b commit 69327ee
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,19 @@ get_chunksize(alg::DAEAlgorithm{CS,AD}) where {CS,AD} = Val(CS)
get_chunksize(alg::ExponentialAlgorithm) = Val(alg.chunksize)
# get_chunksize(alg::CompositeAlgorithm) = get_chunksize(alg.algs[alg.current_alg])

function DiffEqBase.prepare_alg(alg::Union{OrdinaryDiffEqAdaptiveImplicitAlgorithm{0,AD,FDT},
OrdinaryDiffEqImplicitAlgorithm{0,AD,FDT},
DAEAlgorithm{0,AD,FDT}},u0,p,prob) where {AD,FDT}
# If chunksize is zero, pick chunksize right at the start of solve and
# then do function barrier to infer the full solve
x = if prob.f.colorvec === nothing
length(u0)
else
maximum(prob.f.color)
end
remake(alg,chunk_size=ForwardDiff.pickchunksize(x))
end

alg_autodiff(alg::OrdinaryDiffEqAlgorithm) = error("This algorithm does not have an autodifferentiation option defined.")
alg_autodiff(alg::OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS,AD}) where {CS,AD} = AD
alg_autodiff(alg::DAEAlgorithm{CS,AD}) where {CS,AD} = AD
Expand Down

0 comments on commit 69327ee

Please sign in to comment.