Skip to content

Annotate use by callees in MatchedSystemStructure #48

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions Manifest.toml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ Cthulhu = {rev = "master", url = "https://github.com/JuliaDebug/Cthulhu.jl.git"}
DifferentiationInterface = {rev = "main", subdir = "DifferentiationInterface", url = "https://github.com/Keno/DifferentiationInterface.jl"}
Diffractor = {rev = "main", url = "https://github.com/JuliaDiff/Diffractor.jl.git"}
SimpleNonlinearSolve = {rev = "master", subdir = "lib/SimpleNonlinearSolve", url = "https://github.com/SciML/NonlinearSolve.jl.git"}
StateSelection = {rev = "main", url = "https://github.com/JuliaComputing/StateSelection.jl.git"}
StateSelection = {rev = "main", url = "https://github.com/JuliaComputing/StateSelection.jl"}

[compat]
Accessors = "0.1.36"
Expand Down
28 changes: 24 additions & 4 deletions src/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,34 @@ end
code_ad_by_type(@nospecialize(tt::Type); kwargs...) =
_code_ad_by_type(tt; kwargs...).inferred.ir

function code_structure_by_type(@nospecialize(tt::Type); world::UInt = Base.tls_world_age(), result = false, kwargs...)
function code_structure_by_type(@nospecialize(tt::Type); world::UInt = Base.tls_world_age(), result = false, matched = false, mode = DAE, kwargs...)
ci = _code_ad_by_type(tt; world, kwargs...)
# Perform or lookup DAECompiler specific analysis for this system.
_result = structural_analysis!(ci, world)
isa(_result, UncompilableIPOResult) && throw(_result.error)
return result ? _result : _result.ir
end
!matched && return result ? _result : _result.ir
result = _result

structure = make_structure_from_ipo(result)

tstate = TransformationState(result, structure, copy(result.total_incidence))
err = StateSelection.check_consistency(tstate, nothing)
err !== nothing && throw(err)

ret = top_level_state_selection!(tstate)
isa(ret, UncompilableIPOResult) && throw(ret.error)

(diff_key, init_key) = ret
key = in(mode, (DAE, DAENoInit, ODE, ODENoInit)) ? diff_key : init_key

var_eq_matching = matching_for_key(tstate, key)
return StateSelection.MatchedSystemStructure(result, structure, var_eq_matching)
end

"""
@code_structure ssrm()
@code_structure world = UInt(1) [other_parameters...] ssrm()
@code_structure result = true ssrm() # returns DAEIPOResult
@code_structure matched = true ssrm() # returns MatchedSystemStructure

Return the IR after structural analysis of the passed function call.

Expand All @@ -48,6 +64,10 @@ then goes through structural analysis, and the resulting IR is returned.
Parameters:
- `world::UInt = Base.get_world_counter()`: the world in which to operate.
- `force_inline_all::Bool = false`: if `true`, make inlining heuristics choose to always inline where possible.
- `result::Bool = false`: if `true`, return the full [`DAEIPOResult`](@ref) instead of just the `IRCode`.
- `matched::Bool = false`: if `true`, return the `MatchedSystemStructure` after top-level state selection
for visualization purposes.
- `mode::GenerationMode = DAE`: specifies the generation mode to use for the compilation pipeline. Only used when `matched` is `true`.

!!! warning
This will cache analysis results. You might want to invalidate with `DAECompiler.refresh()` between calls to `@code_structure`.
Expand Down
2 changes: 1 addition & 1 deletion src/transform/codegen/rhs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ function rhs_finish!(

assgn = var_assignment[varnum]
if assgn == nothing
display(StateSelection.MatchedSystemStructure(structure, var_eq_matching))
display(StateSelection.MatchedSystemStructure(result, structure, var_eq_matching))
@sshow varnum
error("Variable left over in IR that doesn't have an assignment")
end
Expand Down
16 changes: 16 additions & 0 deletions src/transform/state_selection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,22 @@ StateSelection.BipartiteGraphs.overview_label(::Type{WrongEquation}) = ('✕', "
const IPOMatches = Union{Unassigned, SelectedState, StateInvariant, AlgebraicState, FullyLinear, WrongEquation, InOut}
const IPOMatching = StateSelection.Matching{IPOMatches}

struct CalleeInfo
callees::Union{Nothing, Vector{StructuralSSARef}}
end
CalleeInfo(callee::StructuralSSARef) = CalleeInfo([callee])
StateSelection.BipartiteGraphs.overview_label(::Type{CalleeInfo}) = ('%', "Used in callee (referenced by structural SSA)", 178)

function Base.show(io::IO, (; callees)::CalleeInfo)
callees === nothing && return
first = true
for callee in callees
!first && print(io, ", ")
printstyled(io, "%", callee.id; color = 178)
first = false
end
end

function top_level_state_selection!(tstate)
(; result, structure) = tstate

Expand Down
26 changes: 24 additions & 2 deletions src/transform/tearing/schedule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,28 @@ end
Compiler.optimizer_lattice(::DummyOptInterp) = Compiler.PartialsLattice(EqStructureLattice())
Compiler.get_inference_world(interp::DummyOptInterp) = interp.world

function StateSelection.SSAUses(result::DAEIPOResult)
eq_callees = Union{Nothing, StructuralSSARef}[]
var_callees_dict = Dict{Int,Vector{StructuralSSARef}}()
for value in result.eq_callee_mapping
if value === nothing
push!(eq_callees, nothing)
continue
end
callee = only(unique(first.(value)))
push!(eq_callees, callee)
for (_, var) in value
push!(get!(Vector{StructuralSSARef}, var_callees_dict, var), callee)
end
end
var_callees = [get(var_callees_dict, i, nothing) for i in 1:maximum(keys(var_callees_dict); init = 0)]
return StateSelection.SSAUses(CalleeInfo.(eq_callees), CalleeInfo.(var_callees))
end

function StateSelection.MatchedSystemStructure(result::DAEIPOResult, structure, var_eq_matching)
StateSelection.MatchedSystemStructure(structure, var_eq_matching, StateSelection.SSAUses(result))
end

function tearing_schedule!(state::TransformationState, ci::CodeInstance, key::TornCacheKey, world::UInt, settings::Settings)
result_ci = find_matching_ci(ci->isa(ci.owner, SICMSpec) && ci.owner.key == key, ci.def, world)
if result_ci !== nothing
Expand All @@ -656,7 +678,7 @@ function tearing_schedule!(state::TransformationState, ci::CodeInstance, key::To

var_eq_matching = matching_for_key(state, key)

mss = StateSelection.MatchedSystemStructure(structure, var_eq_matching)
mss = StateSelection.MatchedSystemStructure(result, structure, var_eq_matching)
(eq_orders, callee_schedules) = compute_eq_schedule(key, total_incidence, result, mss)

ir = index_lowering_ad!(state, key)
Expand Down Expand Up @@ -793,7 +815,7 @@ function tearing_schedule!(state::TransformationState, ci::CodeInstance, key::To
display(mss)
cstructure = make_structure_from_ipo(callee_result)
cvar_eq_matching = matching_for_key(callee_result, callee_key, cstructure)
display(StateSelection.MatchedSystemStructure(cstructure, cvar_eq_matching))
display(StateSelection.MatchedSystemStructure(callee_result, cstructure, cvar_eq_matching))
@sshow eq_orders
@sshow callee_result.total_incidence[callee_eq]
@sshow total_incidence[caller_eq]
Expand Down
7 changes: 6 additions & 1 deletion test/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ using DAECompiler
using DAECompiler.Intrinsics
using Compiler: IRCode

@noinline callee!(x, y) = always!(x + y)

function ssrm()
x₁ = continuous()
x₂ = continuous()
Expand All @@ -20,7 +22,7 @@ function ssrm()
always!((ẍ₁+ẍ₂)-(ẋ₁+ẋ₂)+x₄)
always!((ẍ₁+ẍ₂)+x₃)
always!(x₂+ẍ₃+ẋ₄)
always!(x₃+ẋ₄)
callee!(x₃,ẋ₄)
end

ir = code_structure_by_type(Tuple{typeof(ssrm)})
Expand All @@ -33,5 +35,8 @@ ir = @code_structure world = Base.get_world_counter() ssrm()
@test isa(ir, IRCode)
result = @code_structure result = true ssrm()
@test isa(result, DAECompiler.DAEIPOResult)
result = @code_structure matched = true ssrm()
@test isa(result, DAECompiler.StateSelection.MatchedSystemStructure)
@test contains(sprint(show, result), r"%\d+")

end # module