diff --git a/Manifest.toml b/Manifest.toml index fb96d4e..73582b2 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -1927,7 +1927,6 @@ weakdeps = ["SparseArrays", "StaticArraysCore"] [[deps.SciMLSensitivity]] deps = ["ADTypes", "Accessors", "Adapt", "ArrayInterface", "ChainRulesCore", "DiffEqBase", "DiffEqCallbacks", "DiffEqNoiseProcess", "Distributions", "FastBroadcast", "FiniteDiff", "ForwardDiff", "FunctionProperties", "FunctionWrappersWrappers", "Functors", "GPUArraysCore", "LinearAlgebra", "LinearSolve", "Markdown", "OrdinaryDiffEqCore", "PreallocationTools", "QuadGK", "Random", "RandomNumbers", "RecursiveArrayTools", "Reexport", "ReverseDiff", "SciMLBase", "SciMLJacobianOperators", "SciMLOperators", "SciMLStructures", "StaticArrays", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tracker", "Zygote"] git-tree-sha1 = "a8a6e5e1bfb889c392d67245911fd4e1ad0ba562" -path = "/home/keno/.julia/dev/SciMLSensitivity" repo-rev = "kf/mindep4" repo-url = "https://github.com/CedarEDA/SciMLSensitivity.jl" uuid = "1ed8b502-d754-442c-8d5d-10ac956f44a1" @@ -2040,9 +2039,9 @@ weakdeps = ["ChainRulesCore"] [[deps.StateSelection]] deps = ["DocStringExtensions", "FindFirstFunctions", "Graphs", "LinearAlgebra", "Setfield", "SparseArrays", "UnPack"] -git-tree-sha1 = "bdb37d184a0a529c307875a6bc066f561e0c25f6" +git-tree-sha1 = "e77065f876d178339a04947200593b12122724f0" repo-rev = "main" -repo-url = "https://github.com/JuliaComputing/StateSelection.jl.git" +repo-url = "https://github.com/JuliaComputing/StateSelection.jl" uuid = "64909d44-ed92-46a8-bbd9-f047dfbdc84b" version = "0.2.1" diff --git a/Project.toml b/Project.toml index c1d8752..6125446 100644 --- a/Project.toml +++ b/Project.toml @@ -50,7 +50,7 @@ Cthulhu = {rev = "master", url = "https://github.com/JuliaDebug/Cthulhu.jl.git"} DifferentiationInterface = {rev = "main", subdir = "DifferentiationInterface", url = "https://github.com/Keno/DifferentiationInterface.jl"} Diffractor = {rev = "main", url = "https://github.com/JuliaDiff/Diffractor.jl.git"} SimpleNonlinearSolve = {rev = "master", subdir = "lib/SimpleNonlinearSolve", url = "https://github.com/SciML/NonlinearSolve.jl.git"} -StateSelection = {rev = "main", url = "https://github.com/JuliaComputing/StateSelection.jl.git"} +StateSelection = {rev = "main", url = "https://github.com/JuliaComputing/StateSelection.jl"} [compat] Accessors = "0.1.36" diff --git a/src/reflection.jl b/src/reflection.jl index 7bb595f..457a9bd 100644 --- a/src/reflection.jl +++ b/src/reflection.jl @@ -26,18 +26,34 @@ end code_ad_by_type(@nospecialize(tt::Type); kwargs...) = _code_ad_by_type(tt; kwargs...).inferred.ir -function code_structure_by_type(@nospecialize(tt::Type); world::UInt = Base.tls_world_age(), result = false, kwargs...) +function code_structure_by_type(@nospecialize(tt::Type); world::UInt = Base.tls_world_age(), result = false, matched = false, mode = DAE, kwargs...) ci = _code_ad_by_type(tt; world, kwargs...) - # Perform or lookup DAECompiler specific analysis for this system. _result = structural_analysis!(ci, world) isa(_result, UncompilableIPOResult) && throw(_result.error) - return result ? _result : _result.ir -end + !matched && return result ? _result : _result.ir + result = _result + + structure = make_structure_from_ipo(result) + + tstate = TransformationState(result, structure, copy(result.total_incidence)) + err = StateSelection.check_consistency(tstate, nothing) + err !== nothing && throw(err) + ret = top_level_state_selection!(tstate) + isa(ret, UncompilableIPOResult) && throw(ret.error) + + (diff_key, init_key) = ret + key = in(mode, (DAE, DAENoInit, ODE, ODENoInit)) ? diff_key : init_key + + var_eq_matching = matching_for_key(tstate, key) + return StateSelection.MatchedSystemStructure(result, structure, var_eq_matching) +end """ @code_structure ssrm() @code_structure world = UInt(1) [other_parameters...] ssrm() + @code_structure result = true ssrm() # returns DAEIPOResult + @code_structure matched = true ssrm() # returns MatchedSystemStructure Return the IR after structural analysis of the passed function call. @@ -48,6 +64,10 @@ then goes through structural analysis, and the resulting IR is returned. Parameters: - `world::UInt = Base.get_world_counter()`: the world in which to operate. - `force_inline_all::Bool = false`: if `true`, make inlining heuristics choose to always inline where possible. +- `result::Bool = false`: if `true`, return the full [`DAEIPOResult`](@ref) instead of just the `IRCode`. +- `matched::Bool = false`: if `true`, return the `MatchedSystemStructure` after top-level state selection + for visualization purposes. +- `mode::GenerationMode = DAE`: specifies the generation mode to use for the compilation pipeline. Only used when `matched` is `true`. !!! warning This will cache analysis results. You might want to invalidate with `DAECompiler.refresh()` between calls to `@code_structure`. diff --git a/src/transform/codegen/rhs.jl b/src/transform/codegen/rhs.jl index 5ab11e0..f62d721 100644 --- a/src/transform/codegen/rhs.jl +++ b/src/transform/codegen/rhs.jl @@ -187,7 +187,7 @@ function rhs_finish!( assgn = var_assignment[varnum] if assgn == nothing - display(StateSelection.MatchedSystemStructure(structure, var_eq_matching)) + display(StateSelection.MatchedSystemStructure(result, structure, var_eq_matching)) @sshow varnum error("Variable left over in IR that doesn't have an assignment") end diff --git a/src/transform/state_selection.jl b/src/transform/state_selection.jl index 8ec1246..6d638a8 100644 --- a/src/transform/state_selection.jl +++ b/src/transform/state_selection.jl @@ -149,6 +149,22 @@ StateSelection.BipartiteGraphs.overview_label(::Type{WrongEquation}) = ('✕', " const IPOMatches = Union{Unassigned, SelectedState, StateInvariant, AlgebraicState, FullyLinear, WrongEquation, InOut} const IPOMatching = StateSelection.Matching{IPOMatches} +struct CalleeInfo + callees::Union{Nothing, Vector{StructuralSSARef}} +end +CalleeInfo(callee::StructuralSSARef) = CalleeInfo([callee]) +StateSelection.BipartiteGraphs.overview_label(::Type{CalleeInfo}) = ('%', "Used in callee (referenced by structural SSA)", 178) + +function Base.show(io::IO, (; callees)::CalleeInfo) + callees === nothing && return + first = true + for callee in callees + !first && print(io, ", ") + printstyled(io, "%", callee.id; color = 178) + first = false + end +end + function top_level_state_selection!(tstate) (; result, structure) = tstate diff --git a/src/transform/tearing/schedule.jl b/src/transform/tearing/schedule.jl index 0361a47..eff885c 100644 --- a/src/transform/tearing/schedule.jl +++ b/src/transform/tearing/schedule.jl @@ -645,6 +645,28 @@ end Compiler.optimizer_lattice(::DummyOptInterp) = Compiler.PartialsLattice(EqStructureLattice()) Compiler.get_inference_world(interp::DummyOptInterp) = interp.world +function StateSelection.SSAUses(result::DAEIPOResult) + eq_callees = Union{Nothing, StructuralSSARef}[] + var_callees_dict = Dict{Int,Vector{StructuralSSARef}}() + for value in result.eq_callee_mapping + if value === nothing + push!(eq_callees, nothing) + continue + end + callee = only(unique(first.(value))) + push!(eq_callees, callee) + for (_, var) in value + push!(get!(Vector{StructuralSSARef}, var_callees_dict, var), callee) + end + end + var_callees = [get(var_callees_dict, i, nothing) for i in 1:maximum(keys(var_callees_dict); init = 0)] + return StateSelection.SSAUses(CalleeInfo.(eq_callees), CalleeInfo.(var_callees)) +end + +function StateSelection.MatchedSystemStructure(result::DAEIPOResult, structure, var_eq_matching) + StateSelection.MatchedSystemStructure(structure, var_eq_matching, StateSelection.SSAUses(result)) +end + function tearing_schedule!(state::TransformationState, ci::CodeInstance, key::TornCacheKey, world::UInt, settings::Settings) result_ci = find_matching_ci(ci->isa(ci.owner, SICMSpec) && ci.owner.key == key, ci.def, world) if result_ci !== nothing @@ -656,7 +678,7 @@ function tearing_schedule!(state::TransformationState, ci::CodeInstance, key::To var_eq_matching = matching_for_key(state, key) - mss = StateSelection.MatchedSystemStructure(structure, var_eq_matching) + mss = StateSelection.MatchedSystemStructure(result, structure, var_eq_matching) (eq_orders, callee_schedules) = compute_eq_schedule(key, total_incidence, result, mss) ir = index_lowering_ad!(state, key) @@ -793,7 +815,7 @@ function tearing_schedule!(state::TransformationState, ci::CodeInstance, key::To display(mss) cstructure = make_structure_from_ipo(callee_result) cvar_eq_matching = matching_for_key(callee_result, callee_key, cstructure) - display(StateSelection.MatchedSystemStructure(cstructure, cvar_eq_matching)) + display(StateSelection.MatchedSystemStructure(callee_result, cstructure, cvar_eq_matching)) @sshow eq_orders @sshow callee_result.total_incidence[callee_eq] @sshow total_incidence[caller_eq] diff --git a/test/reflection.jl b/test/reflection.jl index de7b424..79acd47 100644 --- a/test/reflection.jl +++ b/test/reflection.jl @@ -5,6 +5,8 @@ using DAECompiler using DAECompiler.Intrinsics using Compiler: IRCode +@noinline callee!(x, y) = always!(x + y) + function ssrm() x₁ = continuous() x₂ = continuous() @@ -20,7 +22,7 @@ function ssrm() always!((ẍ₁+ẍ₂)-(ẋ₁+ẋ₂)+x₄) always!((ẍ₁+ẍ₂)+x₃) always!(x₂+ẍ₃+ẋ₄) - always!(x₃+ẋ₄) + callee!(x₃,ẋ₄) end ir = code_structure_by_type(Tuple{typeof(ssrm)}) @@ -33,5 +35,8 @@ ir = @code_structure world = Base.get_world_counter() ssrm() @test isa(ir, IRCode) result = @code_structure result = true ssrm() @test isa(result, DAECompiler.DAEIPOResult) +result = @code_structure matched = true ssrm() +@test isa(result, DAECompiler.StateSelection.MatchedSystemStructure) +@test contains(sprint(show, result), r"%\d+") end # module