From bb2a1d1c843e80e92c3e80da6d2e0a36fdf08ff0 Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Sun, 14 Jul 2024 20:27:21 +0200 Subject: [PATCH] inference: compute edges more precisely in post-inference Start computing edges from stmt_info later (after CodeInstance is able to have been allocated for recursion) instead of immediately. --- base/compiler/abstractinterpretation.jl | 6 +- base/compiler/stmtinfo.jl | 11 +-- base/compiler/tfuncs.jl | 2 +- base/compiler/typeinfer.jl | 92 +++++++++++++++++++++++++ base/compiler/types.jl | 1 - 5 files changed, 103 insertions(+), 9 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 4c3286c7e2737..c886eddde2d3e 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -286,7 +286,7 @@ function find_union_split_method_matches(interp::AbstractInterpreter, argtypes:: if matches === nothing return FailedMethodMatch("For one of the union split cases, too many methods matched") end - push!(infos, MethodMatchInfo(matches)) + push!(infos, MethodMatchInfo(matches, sig_n, mt)) for m in matches push!(applicable, m) push!(applicable_argtypes, arg_n) @@ -323,7 +323,7 @@ function find_simple_method_matches(interp::AbstractInterpreter, @nospecialize(a # (assume this will always be true, so we don't compute / update valid age in this case) return FailedMethodMatch("Too many methods matched") end - info = MethodMatchInfo(matches) + info = MethodMatchInfo(matches, atype, mt) fullmatch = any(match::MethodMatch->match.fully_covers, matches) return MethodMatches( matches.matches, info, matches.valid_worlds, mt, fullmatch) @@ -2053,7 +2053,7 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn end end rt = from_interprocedural!(interp, rt, sv, arginfo, sig) - info = InvokeCallInfo(match, const_result) + info = InvokeCallInfo(match, const_result, lookupsig) edge !== nothing && add_invoke_backedge!(sv, lookupsig, edge) return CallMeta(rt, Any, effects, info) end diff --git a/base/compiler/stmtinfo.jl b/base/compiler/stmtinfo.jl index 25f5bb894eaa9..550ba0b229ff4 100644 --- a/base/compiler/stmtinfo.jl +++ b/base/compiler/stmtinfo.jl @@ -20,13 +20,15 @@ struct NoCallInfo <: CallInfo end """ info::MethodMatchInfo <: CallInfo -Captures the result of a `:jl_matching_methods` lookup for the given call (`info.results`). -This info may then be used by the optimizer to inline the matches, without having -to re-consult the method table. This info is illegal on any statement that is -not a call to a generic function. +Captures the essential arguments and result of a `:jl_matching_methods` lookup +for the given call (`info.results`). This info may then be used by the +optimizer, without having to re-consult the method table. +This info is illegal on any statement that is not a call to a generic function. """ struct MethodMatchInfo <: CallInfo results::MethodLookupResult + atype # ::Type + mt::MethodTable end nsplit_impl(info::MethodMatchInfo) = 1 getsplit_impl(info::MethodMatchInfo, idx::Int) = (@assert idx == 1; info.results) @@ -165,6 +167,7 @@ Optionally keeps `info.result::InferenceResult` that keeps constant information. struct InvokeCallInfo <: CallInfo match::MethodMatch result::Union{Nothing,ConstResult} + atype # ::Type end """ diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index 28e883d83312c..61b7858e641cb 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -2905,7 +2905,7 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, s else call = abstract_call(interp, ArgInfo(nothing, argtypes_vec), si, sv, #=max_methods=#-1) end - info = verbose_stmt_info(interp) ? MethodResultPure(ReturnTypeCallInfo(call.info)) : MethodResultPure() + info = MethodResultPure(ReturnTypeCallInfo(call.info)) rt = widenslotwrapper(call.rt) if isa(rt, Const) # output was computed to be constant diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index 90fd8466b0394..59771944f1257 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -527,6 +527,12 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter) append!(s_edges, edges) empty!(edges) end + #s_edges_new = compute_edges(me) + #println(Any[z isa MethodTable ? z.name : z for z in s_edges_new]) + #if length(s_edges) != length(s_edges_new) || !all(i -> isassigned(s_edges, i) ? isassigned(s_edges_new, i) && s_edges_new[i] === s_edges[i] : !isassigned(s_edges_new, i), length(s_edges)) + # println(sizehint!(s_edges, length(s_edges))) + # println(sizehint!(s_edges_new, length(s_edges_new))) + #end if me.src.edges !== nothing && me.src.edges !== Core.svec() append!(s_edges, me.src.edges::Vector) end @@ -644,6 +650,92 @@ function store_backedges(caller::MethodInstance, edges::Vector{Any}) return nothing end +add_edges!(edges::Vector{Any}, info::MethodResultPure) = add_edges!(edges, info.info) +add_edges!(edges::Vector{Any}, info::ConstCallInfo) = add_edges!(edges, info.call) +add_edges!(edges::Vector{Any}, info::OpaqueClosureCreateInfo) = nothing # TODO(jwn) +add_edges!(edges::Vector{Any}, info::ReturnTypeCallInfo) = add_edges!(edges, info.info) +function add_edges!(edges::Vector{Any}, info::ApplyCallInfo) + add_edges!(edges, info.call) + for arg in info.arginfo + arg === nothing && continue + for edge in arg.each + add_edges!(edges, edge.info) + end + end +end +add_edges!(edges::Vector{Any}, info::ModifyOpInfo) = add_edges!(edges, info.call) +add_edges!(edges::Vector{Any}, info::UnionSplitInfo) = for split in info.matches; add_edges!(edges, split); end +add_edges!(edges::Vector{Any}, info::UnionSplitApplyCallInfo) = for split in info.infos; add_edges!(edges, split); end +add_edges!(edges::Vector{Any}, info::FinalizerInfo) = nothing +add_edges!(edges::Vector{Any}, info::NoCallInfo) = nothing +function add_edges!(edges::Vector{Any}, info::MethodMatchInfo) + matches = info.results.matches + #if length(matches) == 1 && !info.results.ambig && (matches[end]::Core.MethodMatch).fully_covers + # push!(edges, specialize_method(matches[1])) + #elseif isempty(matches) || info.results.ambig || !(matches[end]::Core.MethodMatch).fully_covers + #else + # push!(edges, length(matches)) + # for m in matches + # push!(edges, specialize_method(m)) + # end + #end + if isempty(matches) || !(matches[end]::Core.MethodMatch).fully_covers + exists = false + for i in 1:length(edges) + if edges[i] === info.mt && edges[i + 1] == info.atype + exists = true + break + end + end + if !exists + push!(edges, info.mt) + push!(edges, info.atype) + end + end + for m in matches + mi = specialize_method(m) + exists = false + for i in 1:length(edges) + if edges[i] === mi && !(i > 1 && edges[i - 1] isa Type) + exists = true + break + end + end + exists || push!(edges, mi) + end +end +function add_edges!(edges::Vector{Any}, info::InvokeCallInfo) + #push!(edges, 1) + mi = specialize_method(info.match) + exists = false + for i in 2:length(edges) + if edges[i] === mi && edges[i - 1] isa Type && edges[i - 1] == info.atype + exists = true + break + end + end + if !exists + push!(edges, info.atype) + push!(edges, mi) + end + nothing +end + +function compute_edges(sv::InferenceState) + edges = [] + for i in 1:length(sv.stmt_info) + info = sv.stmt_info[i] + #rt = sv.ssavaluetypes[i] + #effects = EFFECTS_TOTAL # sv.stmt_effects[i] + #if rt === Any && effects === Effects() + # continue + #end + add_edges!(edges, info) + end + return edges +end + + function record_slot_assign!(sv::InferenceState) # look at all assignments to slots # and union the set of types stored there diff --git a/base/compiler/types.jl b/base/compiler/types.jl index 7021601bf87cf..34310b3db0aba 100644 --- a/base/compiler/types.jl +++ b/base/compiler/types.jl @@ -424,7 +424,6 @@ function add_remark! end may_optimize(::AbstractInterpreter) = true may_compress(::AbstractInterpreter) = true may_discard_trees(::AbstractInterpreter) = true -verbose_stmt_info(::AbstractInterpreter) = false """ method_table(interp::AbstractInterpreter) -> MethodTableView