diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 82aca46f3e7edd..caf86db2243e40 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -59,7 +59,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), splitsigs = switchtupleunion(sig) for sig_n in splitsigs result = abstract_call_method(interp, method, sig_n, svec(), multiple_matches, si, sv) - (; rt, edge, effects) = result + (; rt, edge, effects, inferred_src) = result this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i] this_arginfo = ArgInfo(fargs, this_argtypes) const_call_result = abstract_call_method_with_const_args(interp, @@ -90,7 +90,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), this_rt = widenwrappedconditional(this_rt) else result = abstract_call_method(interp, method, sig, match.sparams, multiple_matches, si, sv) - (; rt, edge, effects) = result + (; rt, edge, effects, inferred_src) = result this_conditional = ignorelimited(rt) this_rt = widenwrappedconditional(rt) # try constant propagation with argtypes for this match @@ -119,6 +119,11 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), const_results = fill!(Vector{Union{Nothing,ConstResult}}(undef, napplicable), nothing) end const_results[i] = const_result + elseif inferred_src !== nothing + if const_results === nothing + const_results = fill!(Vector{Union{Nothing,ConstResult}}(undef, napplicable), nothing) + end + const_results[i] = InferredResult(inferred_src) end edge === nothing || push!(edges, edge) end @@ -621,7 +626,7 @@ function abstract_call_method(interp::AbstractInterpreter, sparams = recomputed[2]::SimpleVector end - (; rt, edge, effects) = typeinf_edge(interp, method, sig, sparams, sv) + (; rt, edge, effects, inferred_src) = typeinf_edge(interp, method, sig, sparams, sv) if edge === nothing edgecycle = edgelimited = true @@ -645,7 +650,7 @@ function abstract_call_method(interp::AbstractInterpreter, end end - return MethodCallResult(rt, edgecycle, edgelimited, edge, effects) + return MethodCallResult(rt, edgecycle, edgelimited, edge, effects, inferred_src) end function edge_matches_sv(interp::AbstractInterpreter, frame::AbsIntState, @@ -748,12 +753,14 @@ struct MethodCallResult edgelimited::Bool edge::Union{Nothing,MethodInstance} effects::Effects + inferred_src::Union{Nothing,CodeInfo} function MethodCallResult(@nospecialize(rt), edgecycle::Bool, edgelimited::Bool, edge::Union{Nothing,MethodInstance}, - effects::Effects) - return new(rt, edgecycle, edgelimited, edge, effects) + effects::Effects, + inferred_src::Union{Nothing,CodeInfo}=nothing) + return new(rt, edgecycle, edgelimited, edge, effects, inferred_src) end end @@ -1945,7 +1952,7 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn tienv = ccall(:jl_type_intersection_with_env, Any, (Any, Any), nargtype, method.sig)::SimpleVector ti = tienv[1]; env = tienv[2]::SimpleVector result = abstract_call_method(interp, method, ti, env, false, si, sv) - (; rt, edge, effects) = result + (; rt, edge, effects, inferred_src) = result match = MethodMatch(ti, env, method, argtype <: method.sig) res = nothing sig = match.spec_types @@ -1968,6 +1975,9 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn (; rt, effects, const_result, edge) = const_call_result end end + if const_result === nothing && inferred_src !== nothing + const_result = InferredResult(inferred_src) + end rt = from_interprocedural!(interp, rt, sv, arginfo, sig) info = InvokeCallInfo(match, const_result) edge !== nothing && add_invoke_backedge!(sv, lookupsig, edge) @@ -2091,7 +2101,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::PartialOpaque, arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, check::Bool=true) sig = argtypes_to_type(arginfo.argtypes) result = abstract_call_method(interp, closure.source::Method, sig, Core.svec(), false, si, sv) - (; rt, edge, effects) = result + (; rt, edge, effects, inferred_src) = result tt = closure.typ sigT = (unwrap_unionall(tt)::DataType).parameters[1] match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt)) @@ -2115,6 +2125,9 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter, effects = Effects(effects; nothrow=false) end end + if const_result === nothing && inferred_src !== nothing + const_result = InferredResult(inferred_src) + end rt = from_interprocedural!(interp, rt, sv, arginfo, match.spec_types) info = OpaqueClosureCallInfo(match, const_result) edge !== nothing && add_backedge!(sv, edge) diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 75eb4677762362..45333452aab568 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -864,8 +864,9 @@ end # the general resolver for usual and const-prop'ed calls function resolve_todo(mi::MethodInstance, result::Union{MethodMatch,InferenceResult}, - argtypes::Vector{Any}, @nospecialize(info::CallInfo), flag::UInt32, - state::InliningState; invokesig::Union{Nothing,Vector{Any}}=nothing) + argtypes::Vector{Any}, @nospecialize(info::CallInfo), flag::UInt32, state::InliningState; + invokesig::Union{Nothing,Vector{Any}}=nothing, + inferred_result::Union{Nothing,InferredResult}=nothing) et = InliningEdgeTracker(state, invokesig) if isa(result, InferenceResult) @@ -900,6 +901,9 @@ function resolve_todo(mi::MethodInstance, result::Union{MethodMatch,InferenceRes compilesig_invokes=OptimizationParams(state.interp).compilesig_invokes) add_inlining_backedge!(et, mi) + if src isa String && inferred_result !== nothing + src = inferred_result.inferred_src + end return InliningTodo(mi, retrieve_ir_for_inlining(mi, src), effects) end @@ -944,7 +948,8 @@ end function analyze_method!(match::MethodMatch, argtypes::Vector{Any}, @nospecialize(info::CallInfo), flag::UInt32, state::InliningState; - allow_typevars::Bool, invokesig::Union{Nothing,Vector{Any}}=nothing) + allow_typevars::Bool, invokesig::Union{Nothing,Vector{Any}}=nothing, + inferred_result::Union{Nothing,InferredResult}=nothing) method = match.method spec_types = match.spec_types @@ -973,7 +978,7 @@ function analyze_method!(match::MethodMatch, argtypes::Vector{Any}, # Get the specialization for this method signature # (later we will decide what to do with it) mi = specialize_method(match) - return resolve_todo(mi, match, argtypes, info, flag, state; invokesig) + return resolve_todo(mi, match, argtypes, info, flag, state; invokesig, inferred_result) end function retrieve_ir_for_inlining(mi::MethodInstance, src::String) @@ -1197,7 +1202,11 @@ function handle_invoke_call!(todo::Vector{Pair{Int,Any}}, return nothing end end - item = analyze_method!(match, argtypes, info, flag, state; allow_typevars=false, invokesig) + if !(result === nothing || isa(result, InferredResult)) + # TODO handle SemiConcreteResult + result = nothing + end + item = analyze_method!(match, argtypes, info, flag, state; allow_typevars=false, invokesig, inferred_result=result) end handle_single_case!(todo, ir, idx, stmt, item, true) return nothing @@ -1336,8 +1345,8 @@ function handle_any_const_result!(cases::Vector{InliningCase}, if isa(result, ConstPropResult) return handle_const_prop_result!(cases, result, argtypes, info, flag, state; allow_abstract, allow_typevars) else - @assert result === nothing - return handle_match!(cases, match, argtypes, info, flag, state; allow_abstract, allow_typevars) + @assert result === nothing || result isa InferredResult + return handle_match!(cases, match, argtypes, info, flag, state; allow_abstract, allow_typevars, inferred_result = result) end end @@ -1468,14 +1477,14 @@ end function handle_match!(cases::Vector{InliningCase}, match::MethodMatch, argtypes::Vector{Any}, @nospecialize(info::CallInfo), flag::UInt32, state::InliningState; - allow_abstract::Bool, allow_typevars::Bool) + allow_abstract::Bool, allow_typevars::Bool, inferred_result::Union{Nothing,InferredResult}) spec_types = match.spec_types allow_abstract || isdispatchtuple(spec_types) || return false # We may see duplicated dispatch signatures here when a signature gets widened # during abstract interpretation: for the purpose of inlining, we can just skip # processing this dispatch candidate (unless unmatched type parameters are present) !allow_typevars && any(case::InliningCase->case.sig === spec_types, cases) && return true - item = analyze_method!(match, argtypes, info, flag, state; allow_typevars) + item = analyze_method!(match, argtypes, info, flag, state; allow_typevars, inferred_result) item === nothing && return false push!(cases, InliningCase(spec_types, item)) return true @@ -1580,7 +1589,8 @@ function handle_opaque_closure_call!(todo::Vector{Pair{Int,Any}}, if isa(result, SemiConcreteResult) item = semiconcrete_result_item(result, info, flag, state) else - item = analyze_method!(info.match, sig.argtypes, info, flag, state; allow_typevars=false) + @assert result === nothing || result isa InferredResult + item = analyze_method!(info.match, sig.argtypes, info, flag, state; allow_typevars=false, inferred_result=result) end end handle_single_case!(todo, ir, idx, stmt, item) diff --git a/base/compiler/stmtinfo.jl b/base/compiler/stmtinfo.jl index 9f55d56181838e..c36f8b05836cfb 100644 --- a/base/compiler/stmtinfo.jl +++ b/base/compiler/stmtinfo.jl @@ -76,6 +76,10 @@ struct SemiConcreteResult <: ConstResult effects::Effects end +struct InferredResult <: ConstResult + inferred_src::CodeInfo +end + """ info::ConstCallInfo <: CallInfo diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index d1f635d3be704d..4543786bffe1ba 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -798,10 +798,12 @@ struct EdgeCallResult rt #::Type edge::Union{Nothing,MethodInstance} effects::Effects + inferred_src::Union{Nothing,CodeInfo} function EdgeCallResult(@nospecialize(rt), edge::Union{Nothing,MethodInstance}, - effects::Effects) - return new(rt, edge, effects) + effects::Effects, + inferred_src::Union{Nothing,CodeInfo} = nothing) + return new(rt, edge, effects, inferred_src) end end @@ -855,7 +857,8 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize isinferred = is_inferred(frame) edge = isinferred ? mi : nothing effects = isinferred ? frame.ipo_effects : adjust_effects(Effects(), method) # effects are adjusted already within `finish` for ipo_effects - return EdgeCallResult(frame.bestguess, edge, effects) + inferred_src = isinferred && is_inlineable(frame.src) ? frame.src : nothing + return EdgeCallResult(frame.bestguess, edge, effects, inferred_src) elseif frame === true # unresolvable cycle return EdgeCallResult(Any, nothing, Effects())