Skip to content

Commit

Permalink
add nlfunc to ODEFunction
Browse files Browse the repository at this point in the history
  • Loading branch information
oscardssmith committed Sep 25, 2024
1 parent 06864fd commit 45f4520
Showing 1 changed file with 47 additions and 16 deletions.
63 changes: 47 additions & 16 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,9 @@ the usage of `f`. These include:
based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be
internally computed on demand when required. The cost of this operation is highly dependent
on the sparsity pattern.
- `nlfunc`: a `NonlinearFunction`
- `nl_state_compres`: maps u->nlfunc_u

Check warning on line 293 in src/scimlfunctions.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"compres" should be "compress" or "compares".
- `nl_state_decompres`: maps nlfunc_u->u

Check warning on line 294 in src/scimlfunctions.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"decompres" should be "decompress".
## iip: In-Place vs Out-Of-Place
Expand Down Expand Up @@ -401,8 +404,8 @@ automatically symbolically generating the Jacobian and more from the
numerically-defined functions.
"""
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
O, TCV,
SYS, IProb, IProbMap} <: AbstractODEFunction{iip}
O, TCV, SYS, IProb, IProbMap,
NLF<:Union{Nothing, NonlinearFunction}, NLSC, NLISC} <: AbstractODEFunction{iip}
f::F
mass_matrix::TMM
analytic::Ta
Expand All @@ -421,6 +424,9 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
sys::SYS
initializeprob::IProb
initializeprobmap::IProbMap
nlfunc::NLF
nl_state_compres::NLSC

Check warning on line 428 in src/scimlfunctions.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"compres" should be "compress" or "compares".
nl_state_decompres::NLISC

Check warning on line 429 in src/scimlfunctions.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"decompres" should be "decompress".
end

@doc doc"""
Expand Down Expand Up @@ -517,8 +523,8 @@ information on generating the SplitFunction from this symbolic engine.
"""
struct SplitFunction{
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt,
TPJ, O,
TCV, SYS, IProb, IProbMap} <: AbstractODEFunction{iip}
TPJ, O, TCV, SYS, IProb, IProbMap,
NLF<:Union{Nothing, NonlinearFunction}, NLSC, NLISC} <: AbstractODEFunction{iip}
f1::F1
f2::F2
mass_matrix::TMM
Expand All @@ -538,6 +544,9 @@ struct SplitFunction{
sys::SYS
initializeprob::IProb
initializeprobmap::IProbMap
nlfunc::NLF
nl_state_compres::NLSC

Check warning on line 548 in src/scimlfunctions.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"compres" should be "compress" or "compares".
nl_state_decompres::NLISC

Check warning on line 549 in src/scimlfunctions.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"decompres" should be "decompress".
end

@doc doc"""
Expand Down Expand Up @@ -2416,6 +2425,9 @@ function ODEFunction{iip, specialize}(f;
sys = __has_sys(f) ? f.sys : nothing,
initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing,
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing
nlfunc = __has_nlfunc(f) ? f.nlfunc : nothing
nl_state_compres = __has_nl_state_compres(f) ? f.nl_state_compres : identity

Check warning on line 2429 in src/scimlfunctions.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"compres" should be "compress" or "compares".

Check warning on line 2429 in src/scimlfunctions.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"compres" should be "compress" or "compares".

Check warning on line 2429 in src/scimlfunctions.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"compres" should be "compress" or "compares".
nl_state_decompres = __has_nl_state_decompres(f) ? f.nl_state_decompres : identity

Check warning on line 2430 in src/scimlfunctions.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"decompres" should be "decompress".
) where {iip,
specialize
}
Expand Down Expand Up @@ -2471,12 +2483,13 @@ function ODEFunction{iip, specialize}(f;
Any, Any, Any, Any,
Any, Any, Any, typeof(jac_prototype),
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
Any,
typeof(_colorvec),
typeof(sys), Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
Any,typeof(_colorvec),
typeof(sys), Any, Any,
Union{Nothing, NonlinearFunction}, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initializeprob, initializeprobmap)
observed, _colorvec, sys, initializeprob, initializeprobmap,
nlfunc, nl_state_compres, nl_state_decompres)
elseif specialize === false
ODEFunction{iip, FunctionWrapperSpecialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
Expand All @@ -2486,10 +2499,14 @@ function ODEFunction{iip, specialize}(f;
typeof(observed),
typeof(_colorvec),
typeof(sys), typeof(initializeprob),
typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac,
typeof(initializeprobmap),
typeof(nlfunc),
typeof(nl_state_compres),
typeof(nl_state_decompres)}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initializeprob, initializeprobmap)
observed, _colorvec, sys, initializeprob, initializeprobmap,
nlfunc, nl_state_compres, nl_state_decompres)
else
ODEFunction{iip, specialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
Expand All @@ -2499,10 +2516,14 @@ function ODEFunction{iip, specialize}(f;
typeof(observed),
typeof(_colorvec),
typeof(sys), typeof(initializeprob),
typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac,
typeof(initializeprobmap),
typeof(nlfunc),
typeof(nl_state_compres),
typeof(nl_state_decompres)}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initializeprob, initializeprobmap)
observed, _colorvec, sys, initializeprob, initializeprobmap,
nlfunc, nl_state_compres, nl_state_decompres)
end
end

Expand All @@ -2519,10 +2540,13 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
Any, Any, Any, Any, typeof(f.jac_prototype),
typeof(f.sparsity), Any, Any, Any,
Any, typeof(f.colorvec),
typeof(f.sys), Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
typeof(f.sys), Any, Any
Union{Nothing, NonlinearFunction}, Any, Any
}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
f.Wfact_t, f.W_prototype, f.paramjac,
f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap)
f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap,
f.nlfunc, f.nl_state_compres, f.nl_state_decompres)
else
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
typeof(f.analytic), typeof(f.tgrad),
Expand All @@ -2531,11 +2555,15 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
typeof(f.paramjac),
typeof(f.observed), typeof(f.colorvec),
typeof(f.sys), typeof(f.initializeprob),
typeof(f.initializeprobmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
typeof(f.initializeprobmap),
typeof(nlfunc),
typeof(nl_state_compres),
typeof(nl_state_decompres)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
f.Wfact_t, f.W_prototype, f.paramjac,
f.observed, f.colorvec, f.sys, f.initializeprob,
f.initializeprobmap)
f.initializeprobmap,
f.nlfunc, f.nl_state_compres, f.nl_state_decompres)
end
end

Expand Down Expand Up @@ -4336,6 +4364,9 @@ __has_analytic_full(f) = isdefined(f, :analytic_full)
__has_resid_prototype(f) = isdefined(f, :resid_prototype)
__has_initializeprob(f) = isdefined(f, :initializeprob)
__has_initializeprobmap(f) = isdefined(f, :initializeprobmap)
__has_nl_state_compres(f) = isdefined(f, :nl_state_compres)
__has_nl_state_decompres(f) = isdefined(f, :nl_state_decompres)


# compatibility
has_invW(f::AbstractSciMLFunction) = false
Expand Down

0 comments on commit 45f4520

Please sign in to comment.