Skip to content

Commit

Permalink
inference: compute edges more precisely in post-inference
Browse files Browse the repository at this point in the history
Start computing edges from stmt_info later (after CodeInstance is able
to have been allocated for recursion) instead of immediately.
  • Loading branch information
vtjnash committed Jul 16, 2024
1 parent 813b5fc commit bb2a1d1
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 9 deletions.
6 changes: 3 additions & 3 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ function find_union_split_method_matches(interp::AbstractInterpreter, argtypes::
if matches === nothing
return FailedMethodMatch("For one of the union split cases, too many methods matched")
end
push!(infos, MethodMatchInfo(matches))
push!(infos, MethodMatchInfo(matches, sig_n, mt))
for m in matches
push!(applicable, m)
push!(applicable_argtypes, arg_n)
Expand Down Expand Up @@ -323,7 +323,7 @@ function find_simple_method_matches(interp::AbstractInterpreter, @nospecialize(a
# (assume this will always be true, so we don't compute / update valid age in this case)
return FailedMethodMatch("Too many methods matched")
end
info = MethodMatchInfo(matches)
info = MethodMatchInfo(matches, atype, mt)
fullmatch = any(match::MethodMatch->match.fully_covers, matches)
return MethodMatches(
matches.matches, info, matches.valid_worlds, mt, fullmatch)
Expand Down Expand Up @@ -2053,7 +2053,7 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
end
end
rt = from_interprocedural!(interp, rt, sv, arginfo, sig)
info = InvokeCallInfo(match, const_result)
info = InvokeCallInfo(match, const_result, lookupsig)
edge !== nothing && add_invoke_backedge!(sv, lookupsig, edge)
return CallMeta(rt, Any, effects, info)
end
Expand Down
11 changes: 7 additions & 4 deletions base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ struct NoCallInfo <: CallInfo end
"""
info::MethodMatchInfo <: CallInfo
Captures the result of a `:jl_matching_methods` lookup for the given call (`info.results`).
This info may then be used by the optimizer to inline the matches, without having
to re-consult the method table. This info is illegal on any statement that is
not a call to a generic function.
Captures the essential arguments and result of a `:jl_matching_methods` lookup
for the given call (`info.results`). This info may then be used by the
optimizer, without having to re-consult the method table.
This info is illegal on any statement that is not a call to a generic function.
"""
struct MethodMatchInfo <: CallInfo
results::MethodLookupResult
atype # ::Type
mt::MethodTable
end
nsplit_impl(info::MethodMatchInfo) = 1
getsplit_impl(info::MethodMatchInfo, idx::Int) = (@assert idx == 1; info.results)
Expand Down Expand Up @@ -165,6 +167,7 @@ Optionally keeps `info.result::InferenceResult` that keeps constant information.
struct InvokeCallInfo <: CallInfo
match::MethodMatch
result::Union{Nothing,ConstResult}
atype # ::Type
end

"""
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2905,7 +2905,7 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, s
else
call = abstract_call(interp, ArgInfo(nothing, argtypes_vec), si, sv, #=max_methods=#-1)
end
info = verbose_stmt_info(interp) ? MethodResultPure(ReturnTypeCallInfo(call.info)) : MethodResultPure()
info = MethodResultPure(ReturnTypeCallInfo(call.info))
rt = widenslotwrapper(call.rt)
if isa(rt, Const)
# output was computed to be constant
Expand Down
92 changes: 92 additions & 0 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,12 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter)
append!(s_edges, edges)
empty!(edges)
end
#s_edges_new = compute_edges(me)
#println(Any[z isa MethodTable ? z.name : z for z in s_edges_new])
#if length(s_edges) != length(s_edges_new) || !all(i -> isassigned(s_edges, i) ? isassigned(s_edges_new, i) && s_edges_new[i] === s_edges[i] : !isassigned(s_edges_new, i), length(s_edges))
# println(sizehint!(s_edges, length(s_edges)))
# println(sizehint!(s_edges_new, length(s_edges_new)))
#end
if me.src.edges !== nothing && me.src.edges !== Core.svec()
append!(s_edges, me.src.edges::Vector)
end
Expand Down Expand Up @@ -644,6 +650,92 @@ function store_backedges(caller::MethodInstance, edges::Vector{Any})
return nothing
end

add_edges!(edges::Vector{Any}, info::MethodResultPure) = add_edges!(edges, info.info)
add_edges!(edges::Vector{Any}, info::ConstCallInfo) = add_edges!(edges, info.call)
add_edges!(edges::Vector{Any}, info::OpaqueClosureCreateInfo) = nothing # TODO(jwn)
add_edges!(edges::Vector{Any}, info::ReturnTypeCallInfo) = add_edges!(edges, info.info)
function add_edges!(edges::Vector{Any}, info::ApplyCallInfo)
add_edges!(edges, info.call)
for arg in info.arginfo
arg === nothing && continue
for edge in arg.each
add_edges!(edges, edge.info)
end
end
end
add_edges!(edges::Vector{Any}, info::ModifyOpInfo) = add_edges!(edges, info.call)
add_edges!(edges::Vector{Any}, info::UnionSplitInfo) = for split in info.matches; add_edges!(edges, split); end
add_edges!(edges::Vector{Any}, info::UnionSplitApplyCallInfo) = for split in info.infos; add_edges!(edges, split); end
add_edges!(edges::Vector{Any}, info::FinalizerInfo) = nothing
add_edges!(edges::Vector{Any}, info::NoCallInfo) = nothing
function add_edges!(edges::Vector{Any}, info::MethodMatchInfo)
matches = info.results.matches
#if length(matches) == 1 && !info.results.ambig && (matches[end]::Core.MethodMatch).fully_covers
# push!(edges, specialize_method(matches[1]))
#elseif isempty(matches) || info.results.ambig || !(matches[end]::Core.MethodMatch).fully_covers
#else
# push!(edges, length(matches))
# for m in matches
# push!(edges, specialize_method(m))
# end
#end
if isempty(matches) || !(matches[end]::Core.MethodMatch).fully_covers
exists = false
for i in 1:length(edges)
if edges[i] === info.mt && edges[i + 1] == info.atype
exists = true
break
end
end
if !exists
push!(edges, info.mt)
push!(edges, info.atype)
end
end
for m in matches
mi = specialize_method(m)
exists = false
for i in 1:length(edges)
if edges[i] === mi && !(i > 1 && edges[i - 1] isa Type)
exists = true
break
end
end
exists || push!(edges, mi)
end
end
function add_edges!(edges::Vector{Any}, info::InvokeCallInfo)
#push!(edges, 1)
mi = specialize_method(info.match)
exists = false
for i in 2:length(edges)
if edges[i] === mi && edges[i - 1] isa Type && edges[i - 1] == info.atype
exists = true
break
end
end
if !exists
push!(edges, info.atype)
push!(edges, mi)
end
nothing
end

function compute_edges(sv::InferenceState)
edges = []
for i in 1:length(sv.stmt_info)
info = sv.stmt_info[i]
#rt = sv.ssavaluetypes[i]
#effects = EFFECTS_TOTAL # sv.stmt_effects[i]
#if rt === Any && effects === Effects()
# continue
#end
add_edges!(edges, info)
end
return edges
end


function record_slot_assign!(sv::InferenceState)
# look at all assignments to slots
# and union the set of types stored there
Expand Down
1 change: 0 additions & 1 deletion base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,6 @@ function add_remark! end
may_optimize(::AbstractInterpreter) = true
may_compress(::AbstractInterpreter) = true
may_discard_trees(::AbstractInterpreter) = true
verbose_stmt_info(::AbstractInterpreter) = false

"""
method_table(interp::AbstractInterpreter) -> MethodTableView
Expand Down

0 comments on commit bb2a1d1

Please sign in to comment.