Skip to content

Commit

Permalink
compute edges post-inference, from info available there
Browse files Browse the repository at this point in the history
  • Loading branch information
vtjnash committed Jul 16, 2024
1 parent bb2a1d1 commit d2c2e51
Show file tree
Hide file tree
Showing 10 changed files with 84 additions and 234 deletions.
56 changes: 5 additions & 51 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
rettype = excttype = Any
all_effects = Effects()
elseif isa(matches, MethodMatches) ? (!matches.fullmatch || any_ambig(matches)) :
(!all(matches.fullmatches) || any_ambig(matches))
(!matches.fullmatch || any_ambig(matches))
# Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
all_effects = Effects(all_effects; nothrow=false)
excttype = tmerge(𝕃ₚ, excttype, MethodError)
Expand Down Expand Up @@ -213,7 +213,6 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# and avoid keeping track of a more complex result type.
rettype = Any
end
add_call_backedges!(interp, rettype, all_effects, edges, matches, atype, sv)
if isa(sv, InferenceState)
# TODO (#48913) implement a proper recursion handling for irinterp:
# This works just because currently the `:terminate` condition guarantees that
Expand Down Expand Up @@ -249,8 +248,7 @@ struct UnionSplitMethodMatches
applicable_argtypes::Vector{Vector{Any}}
info::UnionSplitInfo
valid_worlds::WorldRange
mts::Vector{MethodTable}
fullmatches::Vector{Bool}
fullmatch::Bool
end
any_ambig(m::UnionSplitMethodMatches) = any(any_ambig, m.info.matches)

Expand All @@ -274,8 +272,7 @@ function find_union_split_method_matches(interp::AbstractInterpreter, argtypes::
applicable = Any[]
applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match
valid_worlds = WorldRange()
mts = MethodTable[]
fullmatches = Bool[]
fullmatch = true
for i in 1:length(split_argtypes)
arg_n = split_argtypes[i]::Vector{Any}
sig_n = argtypes_to_type(arg_n)
Expand All @@ -292,23 +289,11 @@ function find_union_split_method_matches(interp::AbstractInterpreter, argtypes::
push!(applicable_argtypes, arg_n)
end
valid_worlds = intersect(valid_worlds, matches.valid_worlds)
thisfullmatch = any(match::MethodMatch->match.fully_covers, matches)
found = false
for (i, mt′) in enumerate(mts)
if mt′ === mt
fullmatches[i] &= thisfullmatch
found = true
break
end
end
if !found
push!(mts, mt)
push!(fullmatches, thisfullmatch)
end
fullmatch = fullmatch && any(match::MethodMatch->match.fully_covers, matches)
end
info = UnionSplitInfo(infos)
return UnionSplitMethodMatches(
applicable, applicable_argtypes, info, valid_worlds, mts, fullmatches)
applicable, applicable_argtypes, info, valid_worlds, fullmatch)
end

function find_simple_method_matches(interp::AbstractInterpreter, @nospecialize(atype), max_methods::Int)
Expand Down Expand Up @@ -492,34 +477,6 @@ function conditional_argtype(𝕃ᵢ::AbstractLattice, @nospecialize(rt), @nospe
end
end

function add_call_backedges!(interp::AbstractInterpreter, @nospecialize(rettype), all_effects::Effects,
edges::Vector{MethodInstance}, matches::Union{MethodMatches,UnionSplitMethodMatches}, @nospecialize(atype),
sv::AbsIntState)
# don't bother to add backedges when both type and effects information are already
# maximized to the top since a new method couldn't refine or widen them anyway
if rettype === Any
# ignore the `:nonoverlayed` property if `interp` doesn't use overlayed method table
# since it will never be tainted anyway
if !isoverlayed(method_table(interp))
all_effects = Effects(all_effects; nonoverlayed=ALWAYS_FALSE)
end
all_effects === Effects() && return nothing
end
for edge in edges
add_backedge!(sv, edge)
end
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
if isa(matches, MethodMatches)
matches.fullmatch || add_mt_backedge!(sv, matches.mt, atype)
else
for (thisfullmatch, mt) in zip(matches.fullmatches, matches.mts)
thisfullmatch || add_mt_backedge!(sv, mt, atype)
end
end
return nothing
end

const RECURSION_UNUSED_MSG = "Bounded recursion detected with unused result. Annotated return type may be wider than true result."
const RECURSION_MSG = "Bounded recursion detected. Call was widened to force convergence."
const RECURSION_MSG_HARDLIMIT = "Bounded recursion detected under hardlimit. Call was widened to force convergence."
Expand Down Expand Up @@ -2054,7 +2011,6 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
end
rt = from_interprocedural!(interp, rt, sv, arginfo, sig)
info = InvokeCallInfo(match, const_result, lookupsig)
edge !== nothing && add_invoke_backedge!(sv, lookupsig, edge)
return CallMeta(rt, Any, effects, info)
end

Expand Down Expand Up @@ -2232,7 +2188,6 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
end
rt = from_interprocedural!(interp, rt, sv, arginfo, match.spec_types)
info = OpaqueClosureCallInfo(match, const_result)
edge !== nothing && add_backedge!(sv, edge)
return CallMeta(rt, Any, effects, info)
end

Expand Down Expand Up @@ -3225,7 +3180,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)

for currpc in bbstart:bbend
frame.currpc = currpc
empty_backedges!(frame, currpc)
stmt = frame.src.code[currpc]
# If we're at the end of the basic block ...
if currpc == bbend
Expand Down
49 changes: 3 additions & 46 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ mutable struct InferenceState
# TODO: Could keep this sparsely by doing structural liveness analysis ahead of time.
bb_vartables::Vector{Union{Nothing,VarTable}} # nothing if not analyzed yet
ssavaluetypes::Vector{Any}
stmt_edges::Vector{Vector{Any}}
edges::Vector{Any}
stmt_info::Vector{CallInfo}

#= intermediate states for interprocedural abstract interpretation =#
Expand Down Expand Up @@ -298,7 +298,7 @@ mutable struct InferenceState
nssavalues = src.ssavaluetypes::Int
ssavalue_uses = find_ssavalue_uses(code, nssavalues)
nstmts = length(code)
stmt_edges = Vector{Vector{Any}}(undef, nstmts)
edges = []
stmt_info = CallInfo[ NoCallInfo() for i = 1:nstmts ]

nslots = length(src.slotflags)
Expand Down Expand Up @@ -350,7 +350,7 @@ mutable struct InferenceState

this = new(
mi, world, mod, sptypes, slottypes, src, cfg, method_info,
currbb, currpc, ip, handler_info, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
currbb, currpc, ip, handler_info, ssavalue_uses, bb_vartables, ssavaluetypes, edges, stmt_info,
pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent,
result, unreachable, valid_worlds, bestguess, exc_bestguess, ipo_effects,
restrict_abstract_call_sites, cache_mode, insert_coverage,
Expand Down Expand Up @@ -813,26 +813,9 @@ function add_cycle_backedge!(caller::InferenceState, frame::InferenceState)
update_valid_age!(caller, frame.valid_worlds)
backedge = (caller, caller.currpc)
contains_is(frame.cycle_backedges, backedge) || push!(frame.cycle_backedges, backedge)
add_backedge!(caller, frame.linfo)
return frame
end

function get_stmt_edges!(caller::InferenceState, currpc::Int=caller.currpc)
stmt_edges = caller.stmt_edges
if !isassigned(stmt_edges, currpc)
return stmt_edges[currpc] = Any[]
else
return stmt_edges[currpc]
end
end

function empty_backedges!(frame::InferenceState, currpc::Int=frame.currpc)
if isassigned(frame.stmt_edges, currpc)
empty!(frame.stmt_edges[currpc])
end
return nothing
end

function print_callstack(sv::InferenceState)
print("=================== Callstack: ==================\n")
idx = 0
Expand Down Expand Up @@ -1017,32 +1000,6 @@ function iterate(unw::AbsIntStackUnwind, (sv, cyclei)::Tuple{AbsIntState, Int})
return (parent, (parent, cyclei))
end

# temporarily accumulate our edges to later add as backedges in the callee
function add_backedge!(caller::InferenceState, mi::MethodInstance)
isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance
return push!(get_stmt_edges!(caller), mi)
end
function add_backedge!(irsv::IRInterpretationState, mi::MethodInstance)
return push!(irsv.edges, mi)
end

function add_invoke_backedge!(caller::InferenceState, @nospecialize(invokesig::Type), mi::MethodInstance)
isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance
return push!(get_stmt_edges!(caller), invokesig, mi)
end
function add_invoke_backedge!(irsv::IRInterpretationState, @nospecialize(invokesig::Type), mi::MethodInstance)
return push!(irsv.edges, invokesig, mi)
end

# used to temporarily accumulate our no method errors to later add as backedges in the callee method table
function add_mt_backedge!(caller::InferenceState, mt::MethodTable, @nospecialize(typ))
isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance
return push!(get_stmt_edges!(caller), mt, typ)
end
function add_mt_backedge!(irsv::IRInterpretationState, mt::MethodTable, @nospecialize(typ))
return push!(irsv.edges, mt, typ)
end

get_curr_ssaflag(sv::InferenceState) = sv.src.ssaflags[sv.currpc]
get_curr_ssaflag(sv::IRInterpretationState) = sv.ir.stmts[sv.curridx][:flag]

Expand Down
4 changes: 2 additions & 2 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,7 @@ struct InliningState{Interp<:AbstractInterpreter}
interp::Interp
end
function InliningState(sv::InferenceState, interp::AbstractInterpreter)
edges = sv.stmt_edges[1]
return InliningState(edges, sv.world, interp)
return InliningState(sv.edges, sv.world, interp)
end
function InliningState(interp::AbstractInterpreter)
return InliningState(Any[], get_inference_world(interp), interp)
Expand Down Expand Up @@ -215,6 +214,7 @@ include("compiler/ssair/irinterp.jl")
function ir_to_codeinf!(opt::OptimizationState)
(; linfo, src) = opt
src = ir_to_codeinf!(src, opt.ir::IRCode)
src.edges = opt.inlining.edges
opt.ir = nothing
maybe_validate_code(linfo, src, "optimized")
return src
Expand Down
20 changes: 10 additions & 10 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@ struct InliningEdgeTracker
new(state.edges, invokesig)
end

function add_inlining_backedge!((; edges, invokesig)::InliningEdgeTracker, mi::MethodInstance)
function add_inlining_edge!((; edges, invokesig)::InliningEdgeTracker, mi::MethodInstance)
if invokesig === nothing
push!(edges, mi)
add_one_edge!(edges, mi)
else # invoke backedge
push!(edges, invoke_signature(invokesig), mi)
add_invoke_edge!(edges, invoke_signature(invokesig), mi)
end
return nothing
end
Expand Down Expand Up @@ -794,8 +794,8 @@ function compileable_specialization(mi::MethodInstance, effects::Effects,
return nothing
end
end
add_inlining_backedge!(et, mi) # to the dispatch lookup
mi_invoke !== mi && push!(et.edges, method.sig, mi_invoke) # add_inlining_backedge to the invoke call, if that is different
add_inlining_edge!(et, mi) # to the dispatch lookup
mi_invoke !== mi && add_invoke_edge!(et.edges, method.sig, mi_invoke) # add_inlining_edge to the invoke call, if that is different
return InvokeCase(mi_invoke, effects, info)
end

Expand Down Expand Up @@ -850,7 +850,7 @@ function resolve_todo(mi::MethodInstance, result::Union{Nothing,InferenceResult,
inferred_result = get_cached_result(state, mi)
end
if inferred_result isa ConstantCase
add_inlining_backedge!(et, mi)
add_inlining_edge!(et, mi)
return inferred_result
end
if inferred_result isa InferredResult
Expand All @@ -874,7 +874,7 @@ function resolve_todo(mi::MethodInstance, result::Union{Nothing,InferenceResult,
return compileable_specialization(mi, effects, et, info;
compilesig_invokes=OptimizationParams(state.interp).compilesig_invokes)

add_inlining_backedge!(et, mi)
add_inlining_edge!(et, mi)
ir = inferred_result isa CodeInstance ? retrieve_ir_for_inlining(inferred_result, src) :
retrieve_ir_for_inlining(mi, src, preserve_local_sources)
return InliningTodo(mi, ir, effects)
Expand All @@ -891,7 +891,7 @@ function resolve_todo(mi::MethodInstance, @nospecialize(info::CallInfo), flag::U

cached_result = get_cached_result(state, mi)
if cached_result isa ConstantCase
add_inlining_backedge!(et, mi)
add_inlining_edge!(et, mi)
return cached_result
end
if cached_result isa InferredResult
Expand All @@ -908,7 +908,7 @@ function resolve_todo(mi::MethodInstance, @nospecialize(info::CallInfo), flag::U
src_inlining_policy(state.interp, src, info, flag) || return nothing
ir = cached_result isa CodeInstance ? retrieve_ir_for_inlining(cached_result, src) :
retrieve_ir_for_inlining(mi, src, preserve_local_sources)
add_inlining_backedge!(et, mi)
add_inlining_edge!(et, mi)
return InliningTodo(mi, ir, effects)
end

Expand Down Expand Up @@ -1456,7 +1456,7 @@ function semiconcrete_result_item(result::SemiConcreteResult,
return compileable_specialization(mi, result.effects, et, info;
compilesig_invokes=OptimizationParams(state.interp).compilesig_invokes)

add_inlining_backedge!(et, mi)
add_inlining_edge!(et, mi)
preserve_local_sources = OptimizationParams(state.interp).preserve_local_sources
ir = retrieve_ir_for_inlining(mi, result.ir, preserve_local_sources)
return InliningTodo(mi, ir, result.effects)
Expand Down
6 changes: 0 additions & 6 deletions base/compiler/ssair/irinterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -434,12 +434,6 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR
(nothrow | noub) || break
end

if last(irsv.valid_worlds) >= get_world_counter()
# if we aren't cached, we don't need this edge
# but our caller might, so let's just make it anyways
store_backedges(frame_instance(irsv), irsv.edges)
end

return Pair{Any,Tuple{Bool,Bool}}(maybe_singleton_const(ultimate_rt), (nothrow, noub))
end

Expand Down
4 changes: 2 additions & 2 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1498,7 +1498,7 @@ function try_inline_finalizer!(ir::IRCode, argexprs::Vector{Any}, idx::Int,
if code isa CodeInstance
if use_const_api(code)
# No code in the function - Nothing to do
add_inlining_backedge!(et, mi)
add_inlining_edge!(et, mi)
return true
end
src = @atomic :monotonic code.inferred
Expand All @@ -1513,7 +1513,7 @@ function try_inline_finalizer!(ir::IRCode, argexprs::Vector{Any}, idx::Int,
length(src.cfg.blocks) == 1 || return false

# Ok, we're committed to inlining the finalizer
add_inlining_backedge!(et, mi)
add_inlining_edge!(et, mi)

# TODO: Should there be a special line number node for inlined finalizers?
inline_at = ir[SSAValue(idx)][:line]
Expand Down
25 changes: 4 additions & 21 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2944,30 +2944,14 @@ function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any},
else
(; valid_worlds, applicable) = matches
update_valid_age!(sv, valid_worlds)

# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
if isa(matches, MethodMatches)
matches.fullmatch || add_mt_backedge!(sv, matches.mt, atype)
else
for (thisfullmatch, mt) in zip(matches.fullmatches, matches.mts)
thisfullmatch || add_mt_backedge!(sv, mt, atype)
end
end
add_edges!(sv.edges, matches.info)

napplicable = length(applicable)
if napplicable == 0
rt = Const(false) # never any matches
else
rt = Const(true) # has applicable matches
for i in 1:napplicable
match = applicable[i]::MethodMatch
edge = specialize_method(match)::MethodInstance
add_backedge!(sv, edge)
end

if isa(matches, MethodMatches) ? (!matches.fullmatch || any_ambig(matches)) :
(!all(matches.fullmatches) || any_ambig(matches))
if !matches.fullmatch || any_ambig(matches)
# Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
rt = Bool
end
Expand Down Expand Up @@ -3007,11 +2991,10 @@ function _hasmethod_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, sv
update_valid_age!(sv, valid_worlds)
if match === nothing
rt = Const(false)
add_mt_backedge!(sv, mt, types) # this should actually be an invoke-type backedge
add_edges!(sv.edges, MethodMatchInfo(MethodLookupResult(Any[], valid_worlds, true), types, mt)) # XXX: this should actually be an invoke-type backedge
else
rt = Const(true)
edge = specialize_method(match)::MethodInstance
add_invoke_backedge!(sv, types, edge)
add_edges!(sv.edges, InvokeCallInfo(match, nothing, types))
end
return CallMeta(rt, Any, EFFECTS_TOTAL, NoCallInfo())
end
Expand Down
Loading

0 comments on commit d2c2e51

Please sign in to comment.