Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update inference of variables and default differential from @equations macro #1175

Merged
merged 18 commits into from
Jan 17, 2025
23 changes: 23 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,29 @@
(at the time the release is made). If you need a dependency version increased,
please open an issue and we can update it and make a new Catalyst release once
testing against the newer dependency version is complete.
- New formula for inferring variables from equations (declared using the `@equations` options) in the DSL. The order of inference of species/variables/parameters is now:
(1) Every symbol explicitly declared using `@species`, `@variables`, and `@parameters` are assigned to the correct category.
(2) Every symbol used as a reaction reactant is inferred as a species.
(3) Every symbol not declared in (1) or (2) that occurs in an expression provided after `@equations` is inferred as a variable.
(4) Every symbol not declared in (1), (2), or (3) that occurs either as a reaction rate or stoichiometric coefficient is inferred to be a parameter.
E.g. in
```julia
@reaction_network begin
@equations V1 + S ~ V2^2
(p + S + V1), S --> 0
end
```
`S` is inferred as a species, `V1` and `V2` as variables, and `p` as a parameter. The previous special cases for the `@observables`, `@compounds`, and `@differentials` options still hold. Finally, the `@require_declaration` options (described in more detail below) can now be used to require everything to be explicitly declared.
- New formula for determining whether the default differentials have been used within an `@equations` option. Now, if any expression `D(...)` is encountered (where `...` can be anything), this is inferred as usage of the default differential D. E.g. in the following equations `D` is inferred as a differential with respect to the default independent variable:
```julia
@reaction_network begin
@equations D(V) + V ~ 1
end
@reaction_network begin
@equations D(D(V)) ~ 1
end
```
Comment on lines +23 to +31
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if someone uses D in a reaction as an implicit species?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aside from that question this is fine to merge I think.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using D in any other context (e.g. as a species, variable or parameter) should disable inference of D as a differential. Will clarify this

Please note that this cannot be used at the same time as `D` is used to represent a species, variable, or parameter.
- Array symbolics support is more consistent with ModelingToolkit v9. Parameter
arrays are no longer scalarized by Catalyst, while species and variables
arrays still are (as in ModelingToolkit). As such, parameter arrays should now
Expand Down
73 changes: 42 additions & 31 deletions src/dsl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ struct UndeclaredSymbolicError <: Exception
msg::String
end

function Base.showerror(io::IO, err::UndeclaredSymbolicError)
function Base.showerror(io::IO, err::UndeclaredSymbolicError)
print(io, "UndeclaredSymbolicError: ")
print(io, err.msg)
end
Expand Down Expand Up @@ -328,11 +328,6 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))
parameters_declared = extract_syms(options, :parameters)
variables_declared = extract_syms(options, :variables)

# Reads equations.
vars_extracted, add_default_diff, equations = read_equations_options(
options, variables_declared; requiredec)
variables = vcat(variables_declared, vars_extracted)

isaacsas marked this conversation as resolved.
Show resolved Hide resolved
# Handle independent variables
if haskey(options, :ivs)
ivs = Tuple(extract_syms(options, :ivs))
Expand All @@ -352,23 +347,32 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))
combinatoric_ratelaws = true
end

# Reads observables.
observed_vars, observed_eqs, obs_syms = read_observed_options(
options, [species_declared; variables], all_ivs; requiredec)

isaacsas marked this conversation as resolved.
Show resolved Hide resolved
# Collect species and parameters, including ones inferred from the reactions.
declared_syms = Set(Iterators.flatten((parameters_declared, species_declared,
variables)))
isaacsas marked this conversation as resolved.
Show resolved Hide resolved
variables_declared)))
species_extracted, parameters_extracted = extract_species_and_parameters!(
reactions, declared_syms; requiredec)

# Reads equations (and infers potential variables).
# Excludes any parameters already extracted (if they also was a variable).
declared_syms = union(declared_syms, species_extracted)
vars_extracted, add_default_diff, equations = read_equations_options(
options, declared_syms; requiredec)
variables = vcat(variables_declared, vars_extracted)
parameters_extracted = setdiff(parameters_extracted, vars_extracted)
isaacsas marked this conversation as resolved.
Show resolved Hide resolved

# Creates the finalised parameter and species lists.
species = vcat(species_declared, species_extracted)
parameters = vcat(parameters_declared, parameters_extracted)

# Create differential expression.
diffexpr = create_differential_expr(
options, add_default_diff, [species; parameters; variables], tiv)

# Reads observables.
observed_vars, observed_eqs, obs_syms = read_observed_options(
options, [species_declared; variables], all_ivs)

# Checks for input errors.
(sum(length.([reaction_lines, option_lines])) != length(ex.args)) &&
error("@reaction_network input contain $(length(ex.args) - sum(length.([reaction_lines,option_lines]))) malformed lines.")
Expand Down Expand Up @@ -701,7 +705,7 @@ end
# `vars_extracted`: A vector with extracted variables (lhs in pure differential equations only).
# `dtexpr`: If a differential equation is defined, the default derivative (D ~ Differential(t)) must be defined.
# `equations`: a vector with the equations provided.
function read_equations_options(options, variables_declared; requiredec = false)
function read_equations_options(options, syms_declared; requiredec = false)
# Prepares the equations. First, extracts equations from provided option (converting to block form if required).
# Next, uses MTK's `parse_equations!` function to split input into a vector with the equations.
eqs_input = haskey(options, :equations) ? options[:equations].args[3] : :(begin end)
Expand All @@ -713,34 +717,41 @@ function read_equations_options(options, variables_declared; requiredec = false)
# Loops through all equations, checks for lhs of the form `D(X) ~ ...`.
# When this is the case, the variable X and differential D are extracted (for automatic declaration).
# Also performs simple error checks.
vars_extracted = Vector{Symbol}()
vars_extracted = OrderedSet{Union{Symbol, Expr}}()
isaacsas marked this conversation as resolved.
Show resolved Hide resolved
add_default_diff = false
for eq in equations
if (eq.head != :call) || (eq.args[1] != :~)
error("Malformed equation: \"$eq\". Equation's left hand and right hand sides should be separated by a \"~\".")
end

# Checks if the equation have the format D(X) ~ ... (where X is a symbol). This means that the
# default differential has been used. X is added as a declared variable to the system, and
# we make a note that a differential D = Differential(iv) should be made as well.
lhs = eq.args[2]
# if lhs: is an expression. Is a function call. The function's name is D. Calls a single symbol.
if (lhs isa Expr) && (lhs.head == :call) && (lhs.args[1] == :D) &&
(lhs.args[2] isa Symbol)
diff_var = lhs.args[2]
if in(diff_var, forbidden_symbols_error)
error("A forbidden symbol ($(diff_var)) was used as an variable in this differential equation: $eq")
elseif (!in(diff_var, variables_declared)) && requiredec
throw(UndeclaredSymbolicError(
"Unrecognized symbol $(diff_var) was used as a variable in an equation: \"$eq\". Since the @require_declaration flag is set, all variables in equations must be explicitly declared via @variables, @species, or @parameters."))
else
add_default_diff = true
in(diff_var, variables_declared) || push!(vars_extracted, diff_var)
end
# If the default differential (`D`) is used, record that it should be decalred later on.

if !in(eq, excluded_syms) && find_D_call(eq)
requiredec && throw(UndeclaredSymbolicError(
"Unrecognized symbol D was used as a differential in an equation: \"$eq\". Since the @require_declaration flag is set, all differentials in equations must be explicitly declared using the @differentials option."))
add_default_diff = true
excluded_syms = push!(excluded_syms, :D)
end

# Any undecalred symbolic variables encountered should be extracted as variables.
add_syms_from_expr!(vars_extracted, eq, excluded_syms)
(!isempty(vars_extracted) && requiredec) && throw(UndeclaredSymbolicError(
"Unrecognized symbolic variables $(join(vars_extracted, ", ")) detected in equation expression: \"$(string(eq))\". Since the flag @require_declaration is declared, all symbolic variables must be explicitly declared with the @species, @variables, and @parameters options."))
isaacsas marked this conversation as resolved.
Show resolved Hide resolved
end

return vars_extracted, add_default_diff, equations
return collect(vars_extracted), add_default_diff, equations
end

# Searches an expresion `expr` and returns true if it have any subexpression `D(...)` (where `...` can be anything).
# Used to determine whether the default differential D has been used in any equation provided to `@equations`.
function find_D_call(expr)
return if Base.isexpr(expr, :call) && expr.args[1] == :D
true
elseif expr isa Expr
any(find_D_call, expr.args)
else
false
end
end

# Creates an expression declaring differentials. Here, `tiv` is the time independent variables,
Expand Down
32 changes: 16 additions & 16 deletions src/reactionsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ Base.@kwdef mutable struct NetworkProperties{I <: Integer, V <: BasicSymbolic{Re
stronglinkageclasses::Vector{Vector{Int}} = Vector{Vector{Int}}(undef, 0)
terminallinkageclasses::Vector{Vector{Int}} = Vector{Vector{Int}}(undef, 0)

checkedrobust::Bool = false
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to fix this thing. I have checked an re-checked and there is no actual difference (more than space removal). I think I put in, and then removed, a debug statement in this file (and the space removal happened). Hence changes appeared here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you must have some VSCode setting enabled.

checkedrobust::Bool = false
robustspecies::Vector{Int} = Vector{Int}(undef, 0)
deficiency::Int = -1
deficiency::Int = -1
end
#! format: on

Expand Down Expand Up @@ -215,11 +215,11 @@ end

### ReactionSystem Structure ###

"""
"""
WARNING!!!

The following variable is used to check that code that should be updated when the `ReactionSystem`
fields are updated has in fact been updated. Do not just blindly update this without first checking
The following variable is used to check that code that should be updated when the `ReactionSystem`
fields are updated has in fact been updated. Do not just blindly update this without first checking
all such code and updating it appropriately (e.g. serialization). Please use a search for
`reactionsystem_fields` throughout the package to ensure all places which should be updated, are updated.
"""
Expand Down Expand Up @@ -318,7 +318,7 @@ struct ReactionSystem{V <: NetworkProperties} <:
"""
discrete_events::Vector{MT.SymbolicDiscreteCallback}
"""
Metadata for the system, to be used by downstream packages.
Metadata for the system, to be used by downstream packages.
"""
metadata::Any
"""
Expand Down Expand Up @@ -480,10 +480,10 @@ function ReactionSystem(iv; kwargs...)
ReactionSystem(Reaction[], iv, [], []; kwargs...)
end

# Called internally (whether DSL-based or programmatic model creation is used).
# Called internally (whether DSL-based or programmatic model creation is used).
# Creates a sorted reactions + equations vector, also ensuring reaction is first in this vector.
# Extracts potential species, variables, and parameters from the input (if not provided as part of
# the model creation) and creates the corresponding vectors.
# Extracts potential species, variables, and parameters from the input (if not provided as part of
# the model creation) and creates the corresponding vectors.
# While species are ordered before variables in the unknowns vector, this ordering is not imposed here,
# but carried out at a later stage.
function make_ReactionSystem_internal(rxs_and_eqs::Vector, iv, us_in, ps_in;
Expand All @@ -495,7 +495,7 @@ function make_ReactionSystem_internal(rxs_and_eqs::Vector, iv, us_in, ps_in;
any(in(obs_vars), us_in) &&
error("Found an observable in the list of unknowns. This is not allowed.")

# Creates a combined iv vector (iv and sivs). This is used later in the function (so that
# Creates a combined iv vector (iv and sivs). This is used later in the function (so that
# independent variables can be excluded when encountered quantities are added to `us` and `ps`).
t = value(iv)
ivs = Set([t])
Expand Down Expand Up @@ -560,7 +560,7 @@ function make_ReactionSystem_internal(rxs_and_eqs::Vector, iv, us_in, ps_in;
end
psv = collect(new_ps)

# Passes the processed input into the next `ReactionSystem` call.
# Passes the processed input into the next `ReactionSystem` call.
ReactionSystem(fulleqs, t, usv, psv; spatial_ivs, continuous_events,
discrete_events, observed, kwargs...)
end
Expand Down Expand Up @@ -1062,8 +1062,8 @@ end

### General `ReactionSystem`-specific Functions ###

# Checks if the `ReactionSystem` structure have been updated without also updating the
# `reactionsystem_fields` constant. If this is the case, returns `false`. This is used in
# Checks if the `ReactionSystem` structure have been updated without also updating the
# `reactionsystem_fields` constant. If this is the case, returns `false`. This is used in
# certain functionalities which would break if the `ReactionSystem` structure is updated without
# also updating these functionalities.
function reactionsystem_uptodate_check()
Expand Down Expand Up @@ -1241,7 +1241,7 @@ end
### `ReactionSystem` Remaking ###

"""
remake_ReactionSystem_internal(rs::ReactionSystem;
remake_ReactionSystem_internal(rs::ReactionSystem;
default_reaction_metadata::Vector{Pair{Symbol, T}} = Vector{Pair{Symbol, Any}}()) where {T}

Takes a `ReactionSystem` and remakes it, returning a modified `ReactionSystem`. Modifications depend
Expand Down Expand Up @@ -1274,7 +1274,7 @@ function set_default_metadata(rs::ReactionSystem; default_reaction_metadata = []
# Currently, `noise_scaling` is the only relevant metadata supported this way.
drm_dict = Dict(default_reaction_metadata)
if haskey(drm_dict, :noise_scaling)
# Finds parameters, species, and variables in the noise scaling term.
# Finds parameters, species, and variables in the noise scaling term.
ns_expr = drm_dict[:noise_scaling]
ns_syms = [Symbolics.unwrap(sym) for sym in get_variables(ns_expr)]
ns_ps = Iterators.filter(ModelingToolkit.isparameter, ns_syms)
Expand Down Expand Up @@ -1414,7 +1414,7 @@ function ModelingToolkit.compose(sys::ReactionSystem, systems::AbstractArray; na
MT.collect_scoped_vars!(newunknowns, newparams, ssys, iv)
end

if !isempty(newunknowns)
if !isempty(newunknowns)
@set! sys.unknowns = union(get_unknowns(sys), newunknowns)
sort!(get_unknowns(sys), by = !isspecies)
@set! sys.species = filter(isspecies, get_unknowns(sys))
Expand Down
Loading
Loading