Skip to content

Commit

Permalink
avoid source deserialization by using locally available inferred source
Browse files Browse the repository at this point in the history
Currently the inlining algorithm is allowed to use inferred source of
const-prop'ed call that is always locally available (since const-prop'
result isn't cached globally). For non const-prop'ed and globally cached
calls, however, it undergoes a more expensive process, making a
round-trip through serialized inferred source. We can actually bypass
this expensive deserialization when inferred source for globally-cached
result is available locally, i.e. when it has been inferred in the same
inference shot.

Note that it would be more efficient to propagate `IRCode` object
directly and skip inflation from `CodeInfo` to `IRCode` as experimented
in #47137, but currently the round-trip through `CodeInfo`-representation
is necessary because it often leads to better CFG simplification and
`cfg_simplify!` seems to be still expensive.
  • Loading branch information
aviatesk committed Oct 31, 2023
1 parent 6084a62 commit 35630bd
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 21 deletions.
29 changes: 21 additions & 8 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
splitsigs = switchtupleunion(sig)
for sig_n in splitsigs
result = abstract_call_method(interp, method, sig_n, svec(), multiple_matches, si, sv)
(; rt, edge, effects) = result
(; rt, edge, effects, inferred_src) = result
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
this_arginfo = ArgInfo(fargs, this_argtypes)
const_call_result = abstract_call_method_with_const_args(interp,
Expand Down Expand Up @@ -90,7 +90,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
this_rt = widenwrappedconditional(this_rt)
else
result = abstract_call_method(interp, method, sig, match.sparams, multiple_matches, si, sv)
(; rt, edge, effects) = result
(; rt, edge, effects, inferred_src) = result
this_conditional = ignorelimited(rt)
this_rt = widenwrappedconditional(rt)
# try constant propagation with argtypes for this match
Expand Down Expand Up @@ -119,6 +119,11 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
const_results = fill!(Vector{Union{Nothing,ConstResult}}(undef, napplicable), nothing)
end
const_results[i] = const_result
elseif inferred_src !== nothing
if const_results === nothing
const_results = fill!(Vector{Union{Nothing,ConstResult}}(undef, napplicable), nothing)
end
const_results[i] = InferredResult(inferred_src)
end
edge === nothing || push!(edges, edge)
end
Expand Down Expand Up @@ -621,7 +626,7 @@ function abstract_call_method(interp::AbstractInterpreter,
sparams = recomputed[2]::SimpleVector
end

(; rt, edge, effects) = typeinf_edge(interp, method, sig, sparams, sv)
(; rt, edge, effects, inferred_src) = typeinf_edge(interp, method, sig, sparams, sv)

if edge === nothing
edgecycle = edgelimited = true
Expand All @@ -645,7 +650,7 @@ function abstract_call_method(interp::AbstractInterpreter,
end
end

return MethodCallResult(rt, edgecycle, edgelimited, edge, effects)
return MethodCallResult(rt, edgecycle, edgelimited, edge, effects, inferred_src)
end

function edge_matches_sv(interp::AbstractInterpreter, frame::AbsIntState,
Expand Down Expand Up @@ -748,12 +753,14 @@ struct MethodCallResult
edgelimited::Bool
edge::Union{Nothing,MethodInstance}
effects::Effects
inferred_src::Union{Nothing,CodeInfo}
function MethodCallResult(@nospecialize(rt),
edgecycle::Bool,
edgelimited::Bool,
edge::Union{Nothing,MethodInstance},
effects::Effects)
return new(rt, edgecycle, edgelimited, edge, effects)
effects::Effects,
inferred_src::Union{Nothing,CodeInfo}=nothing)
return new(rt, edgecycle, edgelimited, edge, effects, inferred_src)
end
end

Expand Down Expand Up @@ -1945,7 +1952,7 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
tienv = ccall(:jl_type_intersection_with_env, Any, (Any, Any), nargtype, method.sig)::SimpleVector
ti = tienv[1]; env = tienv[2]::SimpleVector
result = abstract_call_method(interp, method, ti, env, false, si, sv)
(; rt, edge, effects) = result
(; rt, edge, effects, inferred_src) = result
match = MethodMatch(ti, env, method, argtype <: method.sig)
res = nothing
sig = match.spec_types
Expand All @@ -1968,6 +1975,9 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
(; rt, effects, const_result, edge) = const_call_result
end
end
if const_result === nothing && inferred_src !== nothing
const_result = InferredResult(inferred_src)
end
rt = from_interprocedural!(interp, rt, sv, arginfo, sig)
info = InvokeCallInfo(match, const_result)
edge !== nothing && add_invoke_backedge!(sv, lookupsig, edge)
Expand Down Expand Up @@ -2091,7 +2101,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
closure::PartialOpaque, arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, check::Bool=true)
sig = argtypes_to_type(arginfo.argtypes)
result = abstract_call_method(interp, closure.source::Method, sig, Core.svec(), false, si, sv)
(; rt, edge, effects) = result
(; rt, edge, effects, inferred_src) = result
tt = closure.typ
sigT = (unwrap_unionall(tt)::DataType).parameters[1]
match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt))
Expand All @@ -2115,6 +2125,9 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
effects = Effects(effects; nothrow=false)
end
end
if const_result === nothing && inferred_src !== nothing
const_result = InferredResult(inferred_src)
end
rt = from_interprocedural!(interp, rt, sv, arginfo, match.spec_types)
info = OpaqueClosureCallInfo(match, const_result)
edge !== nothing && add_backedge!(sv, edge)
Expand Down
30 changes: 20 additions & 10 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -864,8 +864,9 @@ end

# the general resolver for usual and const-prop'ed calls
function resolve_todo(mi::MethodInstance, result::Union{MethodMatch,InferenceResult},
argtypes::Vector{Any}, @nospecialize(info::CallInfo), flag::UInt32,
state::InliningState; invokesig::Union{Nothing,Vector{Any}}=nothing)
argtypes::Vector{Any}, @nospecialize(info::CallInfo), flag::UInt32, state::InliningState;
invokesig::Union{Nothing,Vector{Any}}=nothing,
inferred_result::Union{Nothing,InferredResult}=nothing)
et = InliningEdgeTracker(state, invokesig)

if isa(result, InferenceResult)
Expand Down Expand Up @@ -900,6 +901,9 @@ function resolve_todo(mi::MethodInstance, result::Union{MethodMatch,InferenceRes
compilesig_invokes=OptimizationParams(state.interp).compilesig_invokes)

add_inlining_backedge!(et, mi)
if src isa String && inferred_result !== nothing
src = inferred_result.inferred_src
end
return InliningTodo(mi, retrieve_ir_for_inlining(mi, src), effects)
end

Expand Down Expand Up @@ -944,7 +948,8 @@ end

function analyze_method!(match::MethodMatch, argtypes::Vector{Any},
@nospecialize(info::CallInfo), flag::UInt32, state::InliningState;
allow_typevars::Bool, invokesig::Union{Nothing,Vector{Any}}=nothing)
allow_typevars::Bool, invokesig::Union{Nothing,Vector{Any}}=nothing,
inferred_result::Union{Nothing,InferredResult}=nothing)
method = match.method
spec_types = match.spec_types

Expand Down Expand Up @@ -973,7 +978,7 @@ function analyze_method!(match::MethodMatch, argtypes::Vector{Any},
# Get the specialization for this method signature
# (later we will decide what to do with it)
mi = specialize_method(match)
return resolve_todo(mi, match, argtypes, info, flag, state; invokesig)
return resolve_todo(mi, match, argtypes, info, flag, state; invokesig, inferred_result)
end

function retrieve_ir_for_inlining(mi::MethodInstance, src::String)
Expand Down Expand Up @@ -1197,7 +1202,11 @@ function handle_invoke_call!(todo::Vector{Pair{Int,Any}},
return nothing
end
end
item = analyze_method!(match, argtypes, info, flag, state; allow_typevars=false, invokesig)
if !(result === nothing || isa(result, InferredResult))
# TODO handle SemiConcreteResult
result = nothing
end
item = analyze_method!(match, argtypes, info, flag, state; allow_typevars=false, invokesig, inferred_result=result)
end
handle_single_case!(todo, ir, idx, stmt, item, true)
return nothing
Expand Down Expand Up @@ -1336,8 +1345,8 @@ function handle_any_const_result!(cases::Vector{InliningCase},
if isa(result, ConstPropResult)
return handle_const_prop_result!(cases, result, argtypes, info, flag, state; allow_abstract, allow_typevars)
else
@assert result === nothing
return handle_match!(cases, match, argtypes, info, flag, state; allow_abstract, allow_typevars)
@assert result === nothing || result isa InferredResult
return handle_match!(cases, match, argtypes, info, flag, state; allow_abstract, allow_typevars, inferred_result = result)
end
end

Expand Down Expand Up @@ -1468,14 +1477,14 @@ end
function handle_match!(cases::Vector{InliningCase},
match::MethodMatch, argtypes::Vector{Any}, @nospecialize(info::CallInfo), flag::UInt32,
state::InliningState;
allow_abstract::Bool, allow_typevars::Bool)
allow_abstract::Bool, allow_typevars::Bool, inferred_result::Union{Nothing,InferredResult})
spec_types = match.spec_types
allow_abstract || isdispatchtuple(spec_types) || return false
# We may see duplicated dispatch signatures here when a signature gets widened
# during abstract interpretation: for the purpose of inlining, we can just skip
# processing this dispatch candidate (unless unmatched type parameters are present)
!allow_typevars && any(case::InliningCase->case.sig === spec_types, cases) && return true
item = analyze_method!(match, argtypes, info, flag, state; allow_typevars)
item = analyze_method!(match, argtypes, info, flag, state; allow_typevars, inferred_result)
item === nothing && return false
push!(cases, InliningCase(spec_types, item))
return true
Expand Down Expand Up @@ -1580,7 +1589,8 @@ function handle_opaque_closure_call!(todo::Vector{Pair{Int,Any}},
if isa(result, SemiConcreteResult)
item = semiconcrete_result_item(result, info, flag, state)
else
item = analyze_method!(info.match, sig.argtypes, info, flag, state; allow_typevars=false)
@assert result === nothing || result isa InferredResult
item = analyze_method!(info.match, sig.argtypes, info, flag, state; allow_typevars=false, inferred_result=result)
end
end
handle_single_case!(todo, ir, idx, stmt, item)
Expand Down
4 changes: 4 additions & 0 deletions base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ struct SemiConcreteResult <: ConstResult
effects::Effects
end

struct InferredResult <: ConstResult
inferred_src::CodeInfo
end

"""
info::ConstCallInfo <: CallInfo
Expand Down
9 changes: 6 additions & 3 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -798,10 +798,12 @@ struct EdgeCallResult
rt #::Type
edge::Union{Nothing,MethodInstance}
effects::Effects
inferred_src::Union{Nothing,CodeInfo}
function EdgeCallResult(@nospecialize(rt),
edge::Union{Nothing,MethodInstance},
effects::Effects)
return new(rt, edge, effects)
effects::Effects,
inferred_src::Union{Nothing,CodeInfo} = nothing)
return new(rt, edge, effects, inferred_src)
end
end

Expand Down Expand Up @@ -855,7 +857,8 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
isinferred = is_inferred(frame)
edge = isinferred ? mi : nothing
effects = isinferred ? frame.ipo_effects : adjust_effects(Effects(), method) # effects are adjusted already within `finish` for ipo_effects
return EdgeCallResult(frame.bestguess, edge, effects)
inferred_src = isinferred && is_inlineable(frame.src) ? frame.src : nothing
return EdgeCallResult(frame.bestguess, edge, effects, inferred_src)
elseif frame === true
# unresolvable cycle
return EdgeCallResult(Any, nothing, Effects())
Expand Down

0 comments on commit 35630bd

Please sign in to comment.