Skip to content
Draft
2 changes: 1 addition & 1 deletion docs/src/API/codegen.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,5 @@ ModelingToolkit.calculate_A_b
All code generation eventually calls `build_function_wrapper`.

```@docs
build_function_wrapper
ModelingToolkit.build_function_wrapper
```
6 changes: 4 additions & 2 deletions src/systems/diffeqs/basic_transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1037,9 +1037,11 @@ function respecialize(sys::AbstractSystem, mapping; all = false)
"""

if iscall(k)
op = operation(k)
op = operation(k)::BasicSymbolic
@assert !iscall(op)
op = SymbolicUtils.Sym{SymbolicUtils.FnType{Tuple{Any}, T}}(nameof(op))
args = arguments(k)
new_p = SymbolicUtils.term(op, args...; type = T)
new_p = op(args...)
else
new_p = SymbolicUtils.Sym{T}(getname(k))
end
Expand Down
10 changes: 8 additions & 2 deletions src/systems/parameter_buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -763,8 +763,14 @@ function __remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = tru
oldbuf.discrete, newbuf.discrete)
@set! newbuf.constant = narrow_buffer_type_and_fallback_undefs.(
oldbuf.constant, newbuf.constant)
@set! newbuf.nonnumeric = narrow_buffer_type_and_fallback_undefs.(
oldbuf.nonnumeric, newbuf.nonnumeric)
for (oldv, newv) in zip(oldbuf.nonnumeric, newbuf.nonnumeric)
for i in eachindex(oldv)
isassigned(newv, i) && continue
newv[i] = oldv[i]
end
end
@set! newbuf.nonnumeric = Tuple(
typeof(oldv)(newv) for (oldv, newv) in zip(oldbuf.nonnumeric, newbuf.nonnumeric))
if !ArrayInterface.ismutable(oldbuf)
@set! newbuf.tunable = similar_type(oldbuf.tunable, eltype(newbuf.tunable))(newbuf.tunable)
@set! newbuf.initials = similar_type(oldbuf.initials, eltype(newbuf.initials))(newbuf.initials)
Expand Down
73 changes: 46 additions & 27 deletions src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -701,9 +701,10 @@ function.
Note that the getter ONLY works for problem-like objects, since it generates an observed
function. It does NOT work for solutions.
"""
Base.@nospecializeinfer function concrete_getu(indp, syms::AbstractVector)
Base.@nospecializeinfer function concrete_getu(indp, syms; eval_expression, eval_module)
@nospecialize
obsfn = build_explicit_observed_function(indp, syms; wrap_delays = false)
obsfn = build_explicit_observed_function(
indp, syms; wrap_delays = false, eval_expression, eval_module)
return ObservedWrapper{is_time_dependent(indp)}(obsfn)
end

Expand Down Expand Up @@ -757,7 +758,8 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns
- `p_constructor`: The `p_constructor` argument to `process_SciMLProblem`.
"""
function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::AbstractSystem;
initials = false, unwrap_initials = false, p_constructor = identity)
initials = false, unwrap_initials = false, p_constructor = identity,
eval_expression = false, eval_module = @__MODULE__)
_p_constructor = p_constructor
p_constructor = PConstructorApplicator(p_constructor)
# if we call `getu` on this (and it were able to handle empty tuples) we get the
Expand All @@ -773,7 +775,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
tunable_getter = if isempty(tunable_syms)
Returns(SizedVector{0, Float64}())
else
p_constructor ∘ concrete_getu(srcsys, tunable_syms)
p_constructor ∘ concrete_getu(srcsys, tunable_syms; eval_expression, eval_module)
end
initials_getter = if initials && !isempty(syms[2])
initsyms = Vector{Any}(syms[2])
Expand All @@ -792,7 +794,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
end
end
end
p_constructor ∘ concrete_getu(srcsys, initsyms)
p_constructor ∘ concrete_getu(srcsys, initsyms; eval_expression, eval_module)
else
Returns(SizedVector{0, Float64}())
end
Expand All @@ -810,7 +812,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
# tuple of `BlockedArray`s
Base.Fix2(Broadcast.BroadcastFunction(BlockedArray), blockarrsizes) ∘
Base.Fix1(broadcast, p_constructor) ∘
getu(srcsys, syms[3])
concrete_getu(srcsys, syms[3]; eval_expression, eval_module)
end
const_getter = if syms[4] == ()
Returns(())
Expand All @@ -826,7 +828,8 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
end)
# nonnumerics retain the assigned buffer type without narrowing
Base.Fix1(broadcast, _p_constructor) ∘
Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) ∘ getu(srcsys, syms[5])
Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) ∘
concrete_getu(srcsys, syms[5]; eval_expression, eval_module)
end
getters = (
tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter)
Expand All @@ -853,14 +856,19 @@ Construct a `ReconstructInitializeprob` which reconstructs the `u0` and `p` of `
with values from `srcsys`.
"""
function ReconstructInitializeprob(
srcsys::AbstractSystem, dstsys::AbstractSystem; u0_constructor = identity, p_constructor = identity)
srcsys::AbstractSystem, dstsys::AbstractSystem; u0_constructor = identity, p_constructor = identity,
eval_expression = false, eval_module = @__MODULE__)
@assert is_initializesystem(dstsys)
ugetter = u0_constructor ∘ getu(srcsys, unknowns(dstsys))
ugetter = u0_constructor ∘
concrete_getu(srcsys, unknowns(dstsys); eval_expression, eval_module)
if is_split(dstsys)
pgetter = get_mtkparameters_reconstructor(srcsys, dstsys; p_constructor)
pgetter = get_mtkparameters_reconstructor(
srcsys, dstsys; p_constructor, eval_expression, eval_module)
else
syms = parameters(dstsys)
pgetter = let inner = concrete_getu(srcsys, syms), p_constructor = p_constructor
pgetter = let inner = concrete_getu(srcsys, syms; eval_expression, eval_module),
p_constructor = p_constructor

function _getter2(valp, initprob)
p_constructor(inner(valp))
end
Expand Down Expand Up @@ -924,18 +932,20 @@ Given `sys` and its corresponding initialization system `initsys`, return the
`initializeprobpmap` function in `OverrideInitData` for the systems.
"""
function construct_initializeprobpmap(
sys::AbstractSystem, initsys::AbstractSystem; p_constructor = identity)
sys::AbstractSystem, initsys::AbstractSystem; p_constructor = identity, eval_expression, eval_module)
@assert is_initializesystem(initsys)
if is_split(sys)
return let getter = get_mtkparameters_reconstructor(
initsys, sys; initials = true, unwrap_initials = true, p_constructor)
initsys, sys; initials = true, unwrap_initials = true, p_constructor,
eval_expression, eval_module)
function initprobpmap_split(prob, initsol)
getter(initsol, prob)
end
end
else
return let getter = getu(initsys, parameters(sys; initial_parameters = true)),
p_constructor = p_constructor
return let getter = concrete_getu(
initsys, parameters(sys; initial_parameters = true);
eval_expression, eval_module), p_constructor = p_constructor

function initprobpmap_nosplit(prob, initsol)
return p_constructor(getter(initsol))
Expand Down Expand Up @@ -1039,14 +1049,14 @@ struct GetUpdatedU0{GG, GIU}
get_initial_unknowns::GIU
end

function GetUpdatedU0(sys::AbstractSystem, initsys::AbstractSystem, op::AbstractDict)
function GetUpdatedU0(sys::AbstractSystem, initprob::SciMLBase.AbstractNonlinearProblem, op::AbstractDict)
dvs = unknowns(sys)
eqs = equations(sys)
guessvars = trues(length(dvs))
for (i, var) in enumerate(dvs)
guessvars[i] = !isequal(get(op, var, nothing), Initial(var))
end
get_guessvars = getu(initsys, dvs[guessvars])
get_guessvars = getu(initprob, dvs[guessvars])
get_initial_unknowns = getu(sys, Initial.(dvs))
return GetUpdatedU0(guessvars, get_guessvars, get_initial_unknowns)
end
Expand Down Expand Up @@ -1108,7 +1118,7 @@ function maybe_build_initialization_problem(
guesses, missing_unknowns; implicit_dae = false,
time_dependent_init = is_time_dependent(sys), u0_constructor = identity,
p_constructor = identity, floatT = Float64, initialization_eqs = [],
use_scc = true, kwargs...)
use_scc = true, eval_expression = false, eval_module = @__MODULE__, kwargs...)
guesses = merge(ModelingToolkit.guesses(sys), todict(guesses))

if t === nothing && is_time_dependent(sys)
Expand All @@ -1117,7 +1127,7 @@ function maybe_build_initialization_problem(

initializeprob = ModelingToolkit.InitializationProblem{iip}(
sys, t, op; guesses, time_dependent_init, initialization_eqs,
use_scc, u0_constructor, p_constructor, kwargs...)
use_scc, u0_constructor, p_constructor, eval_expression, eval_module, kwargs...)
if state_values(initializeprob) !== nothing
_u0 = state_values(initializeprob)
if ArrayInterface.ismutable(_u0)
Expand Down Expand Up @@ -1145,15 +1155,16 @@ function maybe_build_initialization_problem(
initializeprob = remake(initializeprob; p = initp)

get_initial_unknowns = if time_dependent_init
GetUpdatedU0(sys, initializeprob.f.sys, op)
GetUpdatedU0(sys, initializeprob, op)
else
nothing
end
meta = InitializationMetadata(
copy(op), copy(guesses), Vector{Equation}(initialization_eqs),
use_scc, time_dependent_init,
ReconstructInitializeprob(
sys, initializeprob.f.sys; u0_constructor, p_constructor),
sys, initializeprob.f.sys; u0_constructor,
p_constructor, eval_expression, eval_module),
get_initial_unknowns, SetInitialUnknowns(sys))

if time_dependent_init
Expand All @@ -1172,10 +1183,9 @@ function maybe_build_initialization_problem(
initializeprobpmap = nothing
else
initializeprobpmap = construct_initializeprobpmap(
sys, initializeprob.f.sys; p_constructor)
sys, initializeprob.f.sys; p_constructor, eval_expression, eval_module)
end

reqd_syms = parameter_symbols(initializeprob)
# we still want the `initialization_data` because it helps with `remake`
if initializeprobmap === nothing && initializeprobpmap === nothing
update_initializeprob! = nothing
Expand All @@ -1186,7 +1196,9 @@ function maybe_build_initialization_problem(
filter!(punknowns) do p
is_parameter_solvable(p, op, defs, guesses) && get(op, p, missing) === missing
end
pvals = getu(initializeprob, punknowns)(initializeprob)
# See comment below for why `getu` is not used here.
_pgetter = build_explicit_observed_function(initializeprob.f.sys, punknowns)
pvals = _pgetter(state_values(initializeprob), parameter_values(initializeprob))
for (p, pval) in zip(punknowns, pvals)
p = unwrap(p)
op[p] = pval
Expand All @@ -1198,7 +1210,13 @@ function maybe_build_initialization_problem(
end

if time_dependent_init
uvals = getu(initializeprob, collect(missing_unknowns))(initializeprob)
# We can't use `getu` here because that goes to `SII.observed`, which goes to
# `ObservedFunctionCache` which uses `eval_expression` and `eval_module`. If
# `eval_expression == true`, this then runs into world-age issues. Building an
# RGF here is fine since it is always discarded. We can't use `eval_module` for
# the RGF since the user may not have run RGF's init.
_ugetter = build_explicit_observed_function(initializeprob.f.sys, collect(missing_unknowns))
uvals = _ugetter(state_values(initializeprob), parameter_values(initializeprob))
for (v, val) in zip(missing_unknowns, uvals)
op[v] = val
end
Expand Down Expand Up @@ -1461,7 +1479,7 @@ function process_SciMLProblem(
if is_time_dependent(sys) && t0 === nothing
t0 = zero(floatT)
end
initialization_data = SciMLBase.remake_initialization_data(
initialization_data = @invokelatest SciMLBase.remake_initialization_data(
sys, kwargs, u0, t0, p, u0, p)
kwargs = merge(kwargs, (; initialization_data))
end
Expand Down Expand Up @@ -1773,7 +1791,8 @@ Construct SciMLProblem `T` with positional arguments `args` and keywords `kwargs
"""
function maybe_codegen_scimlproblem(::Type{Val{false}}, T, args::NamedTuple; kwargs...)
# Call `remake` so it runs initialization if it is trivial
remake(T(args...; kwargs...))
# Use `@invokelatest` to avoid world-age issues with `eval_expression = true`
@invokelatest remake(T(args...; kwargs...))
end

"""
Expand Down
11 changes: 6 additions & 5 deletions test/basic_transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -340,11 +340,12 @@ foofn(x) = 4

@testset "`respecialize`" begin
@parameters p::AbstractFoo p2(t)::AbstractFoo = p q[1:2]::AbstractFoo r
rp,
rp2 = let
only(@parameters p::Bar),
SymbolicUtils.term(operation(p2), arguments(p2)...; type = Baz)
end
rp = only(let p = nothing
@parameters p::Bar
end)
rp2 = only(let p2 = nothing
@parameters p2(t)::Baz
end)
@variables x(t) = 1.0
@named sys1 = System([D(x) ~ foofn(p) + foofn(p2) + x], t, [x], [p, p2, q, r])

Expand Down
13 changes: 12 additions & 1 deletion test/mtkparameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ ps = MTKParameters(
(BlockedArray([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [3, 3]),
BlockedArray(falses(1), [1, 0])),
(), (), ())
@test SciMLBase.get_saveable_values(sys, ps, 1).x isa Tuple{Vector{Float64}, Vector{Bool}}
@test SciMLBase.get_saveable_values(sys, ps, 1).x isa Tuple{Vector{Float64}, BitVector}
tsidx1 = 1
tsidx2 = 2
@test length(ps.discrete[1][Block(tsidx1)]) == 3
Expand All @@ -368,3 +368,14 @@ with_updated_parameter_timeseries_values(
sys, ps, tsidx1 => ModelingToolkit.NestedGetIndex(([10.0, 11.0, 12.0], [false])))
@test ps.discrete[1][Block(tsidx1)] == [10.0, 11.0, 12.0]
@test ps.discrete[2][Block(tsidx1)][] == false

@testset "Avoid specialization of nonnumeric parameters on `remake_buffer`" begin
@variables x(t)
@parameters p::Any
@named sys = System(D(x) ~ x, t, [x], [p])
sys = complete(sys)
ps = MTKParameters(sys, [p => 1.0])
@test ps.nonnumeric isa Tuple{Vector{Any}}
ps2 = remake_buffer(sys, ps, [p], [:a])
@test ps2.nonnumeric isa Tuple{Vector{Any}}
end
3 changes: 3 additions & 0 deletions test/precompile_test.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Test
using ModelingToolkit
using OrdinaryDiffEqDefault

using Distributed

Expand Down Expand Up @@ -38,3 +39,5 @@ ODEPrecompileTest.f_eval_bad(u, p, 0.1)
@test parentmodule(typeof(ODEPrecompileTest.f_eval_good.f.f_oop)) ==
ODEPrecompileTest
@test ODEPrecompileTest.f_eval_good(u, p, 0.1) == [4, 0, -16]

@test_nowarn solve(ODEPrecompileTest.prob_eval)
20 changes: 20 additions & 0 deletions test/precompile_test/ODEPrecompileTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,24 @@ const f_eval_bad = system(; eval_expression = true, eval_module = @__MODULE__)
# Change the module the eval'd function is eval'd into to be the containing module,
# which should make it be in the package image
const f_eval_good = system(; eval_expression = true, eval_module = @__MODULE__)

function problem(; kwargs...)
# Define some variables
@independent_variables t
@parameters σ ρ β
@variables x(t) y(t) z(t)
D = Differential(t)

# Define a differential equation
eqs = [D(x) ~ σ * (y - x),
D(y) ~ x * (ρ - z) - y,
D(z) ~ x * y - β * z]

@named de = System(eqs, t)
de = complete(de)
return ODEProblem(de, [x => 1, y => 0, z => 0, σ => 10, ρ => 28, β => 8/3], (0.0, 5.0); kwargs...)
end

const prob_eval = problem(; eval_expression = true, eval_module = @__MODULE__)

end
Loading