Skip to content

Commit

Permalink
Add generated, performant snake_case for NodeID. (#1982)
Browse files Browse the repository at this point in the history
Fixes #1981

This precalculates all snakecases for NodeType and compiles it. @visr
could you benchmark this?

---------

Co-authored-by: Martijn Visser <mgvisser@gmail.com>
  • Loading branch information
evetion and visr authored Dec 21, 2024
1 parent 0ecc373 commit de9b527
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 12 deletions.
9 changes: 5 additions & 4 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,15 +190,15 @@ function update_concentrations!(u, t, integrator)::Nothing
# of the basins after processing inflows only
cumulative_in .= 0.0

mass .+= concentration[1, :, :] .* vertical_flux.drainage * dt
@views mass .+= concentration[1, :, :] .* vertical_flux.drainage * dt
basin.concentration_data.cumulative_in .= vertical_flux.drainage * dt

# Precipitation depends on fixed area
for node_id in basin.node_id
fixed_area = basin_areas(basin, node_id.idx)[end]
added_precipitation = fixed_area * vertical_flux.precipitation[node_id.idx] * dt

mass[node_id.idx, :] .+= concentration[2, node_id.idx, :] .* added_precipitation
@views mass[node_id.idx, :] .+=
concentration[2, node_id.idx, :] .* added_precipitation
cumulative_in[node_id.idx] += added_precipitation
end

Expand All @@ -212,7 +212,8 @@ function update_concentrations!(u, t, integrator)::Nothing
if active
outflow_id = edge[1].edge[2]
volume = integral(flow_rate, tprev, t)
mass[outflow_id.idx, :] .+= flow_boundary.concentration[id.idx, :] .* volume
@views mass[outflow_id.idx, :] .+=
flow_boundary.concentration[id.idx, :] .* volume
cumulative_in[outflow_id.idx] += volume
end
end
Expand Down
10 changes: 5 additions & 5 deletions core/src/concentration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ function mass_updates_user_demand!(integrator::DEIntegrator)::Nothing
(; basin, user_demand) = integrator.p
(; concentration_state, mass) = basin.concentration_data

for (inflow_edge, outflow_edge) in
zip(user_demand.inflow_edge, user_demand.outflow_edge)
@views for (inflow_edge, outflow_edge) in
zip(user_demand.inflow_edge, user_demand.outflow_edge)
from_node = inflow_edge.edge[1]
to_node = outflow_edge.edge[2]
userdemand_idx = outflow_edge.edge[1].idx
Expand Down Expand Up @@ -41,7 +41,7 @@ function mass_inflows_basin!(integrator::DEIntegrator)::Nothing
for (inflow_edge, outflow_edge) in zip(state_inflow_edge, state_outflow_edge)
from_node = inflow_edge.edge[1]
to_node = outflow_edge.edge[2]
if from_node.type == NodeType.Basin
@views if from_node.type == NodeType.Basin
flow = flow_update_on_edge(integrator, inflow_edge.edge)
if flow < 0
cumulative_in[from_node.idx] -= flow
Expand All @@ -67,7 +67,7 @@ function mass_inflows_basin!(integrator::DEIntegrator)::Nothing
flow = flow_update_on_edge(integrator, outflow_edge.edge)
if flow > 0
cumulative_in[to_node.idx] += flow
if from_node.type == NodeType.Basin
@views if from_node.type == NodeType.Basin
mass[to_node.idx, :] .+= concentration_state[from_node.idx, :] .* flow
elseif from_node.type == NodeType.LevelBoundary
mass[to_node.idx, :] .+=
Expand Down Expand Up @@ -95,7 +95,7 @@ function mass_outflows_basin!(integrator::DEIntegrator)::Nothing
(; state_inflow_edge, state_outflow_edge, basin) = integrator.p
(; mass, concentration_state) = basin.concentration_data

for (inflow_edge, outflow_edge) in zip(state_inflow_edge, state_outflow_edge)
@views for (inflow_edge, outflow_edge) in zip(state_inflow_edge, state_outflow_edge)
from_node = inflow_edge.edge[1]
to_node = outflow_edge.edge[2]
if from_node.type == NodeType.Basin
Expand Down
12 changes: 12 additions & 0 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ const SolverStats = @NamedTuple{
5 Drainage = 6 Precipitation = 7
Base.to_index(id::Substance.T) = Int(id) # used to index into concentration matrices

@generated function config.snake_case(nt::NodeType.T)
ex = quote end
for (sym, _) in EnumX.symbol_map(NodeType.T)
sc = QuoteNode(config.snake_case(sym))
t = NodeType.T(sym)
push!(ex.args, :(nt === $t && return $sc))
end
push!(ex.args, :(return :nothing)) # type stability
ex
end

# Support creating a NodeType enum instance from a symbol or string
function NodeType.T(s::Symbol)::NodeType.T
symbol_map = EnumX.symbol_map(NodeType.T)
Expand Down Expand Up @@ -86,6 +97,7 @@ Base.convert(::Type{Int32}, id::NodeID) = id.value
Base.broadcastable(id::NodeID) = Ref(id)
Base.:(==)(id_1::NodeID, id_2::NodeID) = id_1.type == id_2.type && id_1.value == id_2.value
Base.show(io::IO, id::NodeID) = print(io, id.type, " #", id.value)
config.snake_case(id::NodeID) = config.snake_case(id.type)

function Base.isless(id_1::NodeID, id_2::NodeID)::Bool
if id_1.type != id_2.type
Expand Down
6 changes: 3 additions & 3 deletions core/src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ function get_variable_ref(
PreallocationRef(cache(1), flow_idx; from_du = true)
end
else
node = getfield(p, snake_case(Symbol(node_id.type)))
node = getfield(p, snake_case(node_id))
PreallocationRef(node.flow_rate, node_id.idx)
end
else
Expand Down Expand Up @@ -814,7 +814,7 @@ function collect_control_mappings!(p)::Nothing

for node_type in instances(NodeType.T)
node_type == NodeType.Terminal && continue
node = getfield(p, Symbol(snake_case(string(node_type))))
node = getfield(p, snake_case(node_type))
if hasfield(typeof(node), :control_mapping)
control_mappings[node_type] = node.control_mapping
end
Expand Down Expand Up @@ -1096,7 +1096,7 @@ function get_state_index(
component_name = if id.type == NodeType.UserDemand
inflow ? :user_demand_inflow : :user_demand_outflow
else
snake_case(Symbol(id.type))
snake_case(id)
end
for (comp, range) in pairs(NT)
if comp == component_name
Expand Down

0 comments on commit de9b527

Please sign in to comment.