diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index dcff371ecf..d8f026021b 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -726,6 +726,32 @@ end end end +""" +$(TYPEDSIGNATURES) + +Remake the system `sys` with every field replaced by the value in `kwargs`. + +```julia +@variables x(t) y(t) +@named sysx = ODESystem([x ~ 0], t) +sysy = remake(sysx, eqs = [y ~ 0]) +``` + +WARNING: intended for internal use; does not perform any sanity checks. +""" +# TODO: optionally re-call constructor to sanity check? +# TODO: use SciMLBase's generic remake()? doesn't work out of the box, though +# TODO: should register new tag? move Threads.atomic_add!(SYSTEM_COUNT, ...) to a separate function? +# TODO: recreate the struct once with all new fields, instead of once for every field +function remake(sys::AbstractSystem; kwargs...) + for (field, value) in kwargs + # like `Setfield.@set! sys.field = value`, but with `field` replaced by an arbitrarily named symbol + # (e.g. https://discourse.julialang.org/t/accessing-struct-via-symbol/58809/4) + sys = Setfield.set(sys, Setfield.PropertyLens{field}(), value) + end + return sys +end + rename(x, name) = @set x.name = name function Base.propertynames(sys::AbstractSystem; private = false) @@ -2595,30 +2621,19 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem; name::Symbol = nam end end - # collect fields common to all system types - eqs = union(get_eqs(basesys), get_eqs(sys)) - sts = union(get_unknowns(basesys), get_unknowns(sys)) - ps = union(get_ps(basesys), get_ps(sys)) - dep_ps = union_nothing(parameter_dependencies(basesys), parameter_dependencies(sys)) - obs = union(get_observed(basesys), get_observed(sys)) - cevs = union(get_continuous_events(basesys), get_continuous_events(sys)) - devs = union(get_discrete_events(basesys), get_discrete_events(sys)) - defs = merge(get_defaults(basesys), get_defaults(sys)) # prefer `sys` - meta = union_nothing(get_metadata(basesys), get_metadata(sys)) - syss = union(get_systems(basesys), get_systems(sys)) - args = length(ivs) == 0 ? (eqs, sts, ps) : (eqs, ivs[1], sts, ps) - kwargs = (parameter_dependencies = dep_ps, observed = obs, continuous_events = cevs, - discrete_events = devs, defaults = defs, systems = syss, metadata = meta, - name = name, gui_metadata = gui_metadata) + # gracefully fields, being nice if one or both are nothing + ext(x, y) = y # prefer sys (y) over basesys (x) + ext(x, y::Nothing) = x + ext(x::Nothing, y) = ext(y, x) + ext(x::Nothing, y::Nothing) = nothing + ext(x::AbstractDict, y::AbstractDict) = merge(x, y) # prefer sys (y) over basesys (x) + ext(x::AbstractVector, y::AbstractVector) = union(x, y) + ext(field::Symbol) = ext(getfield(sys, field), getfield(basesys, field)) # TODO: use get_...? - # collect fields specific to some system types - if basesys isa ODESystem - ieqs = union(get_initialization_eqs(basesys), get_initialization_eqs(sys)) - guesses = merge(get_guesses(basesys), get_guesses(sys)) # prefer `sys` - kwargs = merge(kwargs, (initialization_eqs = ieqs, guesses = guesses)) - end - - return T(args...; kwargs...) + # both systems were individually sanity-checked upon construction, + # so it should be fine to merge their fields without further checking + kwargs = Dict(field => ext(field) for field in fieldnames(T)) + return remake(sys; kwargs..., name, gui_metadata) end function Base.:(&)(sys::AbstractSystem, basesys::AbstractSystem; name::Symbol = nameof(sys)) @@ -2634,10 +2649,9 @@ system's name. See also [`extend`](@ref). """ function compose(sys::AbstractSystem, systems::AbstractArray; name = nameof(sys)) - nsys = length(systems) - nsys == 0 && return sys - @set! sys.name = name - @set! sys.systems = [get_systems(sys); systems] + if !isempty(systems) + sys = remake(sys; name = name, systems = [get_systems(sys); systems]) + end return sys end function compose(syss...; name = nameof(first(syss))) diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index fb4fedc920..ac767f8b97 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -135,8 +135,7 @@ function alias_elimination!(state::TearingState; kwargs...) state.structure.eq_to_diff = new_eq_to_diff state.structure.var_to_diff = new_var_to_diff - sys = state.sys - @set! sys.eqs = eqs + sys = remake(state.sys; eqs) state.sys = sys return invalidate_cache!(sys), mm end diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index 227b4624bf..0f0b300f99 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -485,9 +485,8 @@ function expand_connections(sys::AbstractSystem, find = nothing, replace = nothi ceqs, instream_csets = generate_connection_equations_and_stream_connections(csets) _sys = expand_instream(instream_csets, sys; debug = debug, tol = tol) sys = flatten(sys, true) - @set! sys.eqs = [equations(_sys); ceqs] d_defs = domain_defaults(sys, domain_csets) - @set! sys.defaults = merge(get_defaults(sys), d_defs) + remake(sys; eqs = [equations(_sys); ceqs], defaults = merge(get_defaults(sys), d_defs)) end function unnamespace(root, namespace) diff --git a/src/systems/diffeqs/first_order_transform.jl b/src/systems/diffeqs/first_order_transform.jl index b1a51f3346..14ff5c5239 100644 --- a/src/systems/diffeqs/first_order_transform.jl +++ b/src/systems/diffeqs/first_order_transform.jl @@ -7,9 +7,7 @@ form by defining new variables which represent the N-1 derivatives. function ode_order_lowering(sys::ODESystem) iv = get_iv(sys) eqs_lowered, new_vars = ode_order_lowering(equations(sys), iv, unknowns(sys)) - @set! sys.eqs = eqs_lowered - @set! sys.unknowns = new_vars - return sys + return remake(sys; eqs = eqs_lowered, unknowns = new_vars) end function dae_order_lowering(sys::ODESystem) diff --git a/src/utils.jl b/src/utils.jl index 9aa893e321..52a09bb449 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,14 +1,3 @@ -""" - union_nothing(x::Union{T1, Nothing}, y::Union{T2, Nothing}) where {T1, T2} - -Unite x and y gracefully when they could be nothing. If neither is nothing, x and y are united normally. If one is nothing, the other is returned unmodified. If both are nothing, nothing is returned. -""" -function union_nothing(x::Union{T1, Nothing}, y::Union{T2, Nothing}) where {T1, T2} - isnothing(x) && return y # y can be nothing or something - isnothing(y) && return x # x can be nothing or something - return union(x, y) # both x and y are something and can be united normally -end - get_iv(D::Differential) = D.x function make_operation(@nospecialize(op), args)