From a58ebed369641452ee0b29e0780d36f20ff2e9dc Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Mon, 4 Nov 2024 16:17:47 +0100 Subject: [PATCH] add ForeachConnectedSubsystem for effects modifying downstream Subsystems --- src/GraphDynamics.jl | 4 +- src/graph_solve.jl | 104 ++++++++++++++++++++++++++++++++++++++----- src/utils.jl | 2 +- 3 files changed, 97 insertions(+), 13 deletions(-) diff --git a/src/GraphDynamics.jl b/src/GraphDynamics.jl index 47dfe62..0e2122d 100644 --- a/src/GraphDynamics.jl +++ b/src/GraphDynamics.jl @@ -32,7 +32,8 @@ end isstochastic, - event_times + event_times, + ForeachConnectedSubsystem ) export @@ -231,6 +232,7 @@ add methods to this function if a subsystem or connection type has a discrete ev event_times(::Any) = () abstract type ConnectionRule end +Base.zero(::T) where {T <: ConnectionRule} = zero(T) struct NotConnected <: ConnectionRule end (::NotConnected)(l, r) = zero(promote_type(eltype(l), eltype(r))) struct ConnectionMatrix{N, CR, Tup <: NTuple{N, NTuple{N, Union{NotConnected, AbstractMatrix{CR}}}}} diff --git a/src/graph_solve.jl b/src/graph_solve.jl index 4a7cf5a..ede925b 100644 --- a/src/graph_solve.jl +++ b/src/graph_solve.jl @@ -257,11 +257,12 @@ function _continuous_affect!(integrator, sview = @view states_partitioned[i][j] pview = @view params_partitioned[i][j] sys = Subsystem(states_partitioned[i][j], params_partitioned[i][j]) + F = ForeachConnectedSubsystem{i}(j, states_partitioned, params_partitioned, connection_matrices) if continuous_events_require_inputs(sys) input = calculate_inputs(Val(i), j, states_partitioned, params_partitioned, connection_matrices) - apply_continuous_event!(integrator, sview, pview, sys, input) + apply_continuous_event!(integrator, sview, pview, sys, F, input) else - apply_continuous_event!(integrator, sview, pview, sys) + apply_continuous_event!(integrator, sview, pview, sys, F) end end offset += N @@ -326,34 +327,38 @@ end t) where {Len, NConn} quote @nexprs $Len i -> begin + # First we apply events to the states if has_discrete_events(eltype(states_partitioned[i])) - for j ∈ eachindex(states_partitioned[i]) - sys_dst = Subsystem(states_partitioned[i][j], params_partitioned[i][j]) - sview_dst = @view states_partitioned[i][j] - pview_dst = @view params_partitioned[i][j] - if discrete_event_condition(sys_dst, t) - if discrete_events_require_inputs(sys_dst) + @inbounds for j ∈ eachindex(states_partitioned[i]) + sys = Subsystem(states_partitioned[i][j], params_partitioned[i][j]) + sview = @view states_partitioned[i][j] + pview = @view params_partitioned[i][j] + if discrete_event_condition(sys, t) + # println("helllllo") + F = ForeachConnectedSubsystem{i}(j, states_partitioned, params_partitioned, connection_matrices) + if discrete_events_require_inputs(sys) input = calculate_inputs(Val(i), j, states_partitioned, params_partitioned, connection_matrices) - apply_discrete_event!(integrator, sview_dst, pview_dst, sys_dst, input) + apply_discrete_event!(integrator, sview, pview, sys, F, input) else - apply_discrete_event!(integrator, sview_dst, pview_dst, sys_dst) + apply_discrete_event!(integrator, sview, pview, sys, F) end end end end + # Then we do the connection events @nexprs $NConn nc -> begin @nexprs $Len k -> begin f = _discrete_connection_affect!(Val(i), Val(k), Val(nc), t, states_partitioned, params_partitioned, connection_matrices, integrator) foreach(f, eachindex(states_partitioned[i])) - end end end end end + function _discrete_connection_affect!(::Val{i}, ::Val{k}, ::Val{nc}, t, states_partitioned::NTuple{Len, Any}, params_partitioned::NTuple{Len, Any}, @@ -397,3 +402,80 @@ function _discrete_connection_affect!(::Val{i}, ::Val{k}, ::Val{nc}, t, end end end + + +#----------------------------------------------------------------------- + +""" + ForeachConnectedSubsystem + +This is a callable struct which takes in a function, and then calls that function on each subsystem which has a connection leading to it +from some previously specified subsystem. + +That is, writing +```julia +F = ForeachConnectedSubsystem{k}(l, states_partitioned, params_partitioned, connection_matrices) + +F() do conn, sys_dst, states_view_dst, params_view_dst + [...] +end +``` +is like a type stable version of writing +``` +for i in eachindex(states_partitioned) + for nc in eachindex(connection_matrices) + M = connection_matrices[nc][i, k] + for j in eachindex(states_partitioned[k]) + conn = M[l, j] + if !iszero(conn) + states_view_dst = @view states_partitioned[i][j] + params_view_dst = @view params_partitioned[i][j] + sys_dst = Subsystem(states_view_dst[], params_view_dst[]) + [...] # <------- User code here + ends + end + end +end +``` +""" +struct ForeachConnectedSubsystem{k, Len, NConn, S, P, CMs} + l::Int + states_partitioned::S + params_partitioned::P + connection_matrices::CMs + function ForeachConnectedSubsystem{k}(l, + states_partitioned::NTuple{Len, Any}, + params_partitioned::NTuple{Len, Any}, + connection_matrices::ConnectionMatrices{NConn}) where {k, Len, NConn} + S = typeof(states_partitioned) + P = typeof(params_partitioned) + CMs = typeof(connection_matrices) + new{k, Len, NConn, S, P, CMs}(l, states_partitioned, params_partitioned, connection_matrices) + end +end + +@generated function ((;l, + states_partitioned, + params_partitioned, + connection_matrices)::ForeachConnectedSubsystem{k, Len, NConn})(f::F) where {k, Len, NConn, F} + quote + @nexprs $Len i -> begin + @nexprs $NConn nc -> begin + M = connection_matrices[nc][k, i] + if M isa NotConnected + nothing + else + for j ∈ eachindex(states_partitioned[i]) + @inbounds conn = M[l, j] + if !iszero(conn) + @inbounds states_view_dst = @view states_partitioned[i][j] + @inbounds params_view_dst = @view params_partitioned[i][j] + sys_dst = Subsystem(states_view_dst[], params_view_dst[]) + f(conn, sys_dst, states_view_dst, params_view_dst) + end + end + end + end + end + end +end diff --git a/src/utils.jl b/src/utils.jl index 4b79c6c..2d295fa 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -3,7 +3,7 @@ valueof(x) = x # this just makes it so that I can easily replace all uses of `@inbounds ex` with just `ex`. macro inbounds(ex) - # ex + #esc(ex) esc(:($Base.@inbounds $ex)) end