From 5e1b19b468f45e202a4f27831cc69b65d456febf Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Wed, 1 Nov 2023 01:59:36 +0900 Subject: [PATCH] revisit #47137: avoid round-trip of locally-cached inferred source Built on top of #51958, with the improved performance of `cfg_simplify!`, let's give another try on JuliaLang/julia#47137. Tha aim is to retain locally cached inferred source as `IRCode`, eliminating the need for the inlining algorithm to round-trip it through `CodeInfo` representation. --- base/compiler/abstractinterpretation.jl | 62 ++++++++++++++----------- base/compiler/optimize.jl | 6 +-- base/compiler/typeinfer.jl | 16 +++++-- base/compiler/types.jl | 3 +- 4 files changed, 52 insertions(+), 35 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 21ff1c951bf85..5adee06c3f008 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -1176,44 +1176,54 @@ function semi_concrete_eval_call(interp::AbstractInterpreter, return nothing end +const_prop_result(inf_result::InferenceResult) = + ConstCallResults(inf_result.result, ConstPropResult(inf_result), inf_result.ipo_effects, inf_result.linfo) + function const_prop_call(interp::AbstractInterpreter, mi::MethodInstance, result::MethodCallResult, arginfo::ArgInfo, sv::AbsIntState, concrete_eval_result::Union{Nothing,ConstCallResults}=nothing) inf_cache = get_inference_cache(interp) ๐•ƒแตข = typeinf_lattice(interp) inf_result = cache_lookup(๐•ƒแตข, mi, arginfo.argtypes, inf_cache) - if inf_result === nothing - # fresh constant prop' - argtypes = has_conditional(๐•ƒแตข, sv) ? ConditionalArgtypes(arginfo, sv) : SimpleArgtypes(arginfo.argtypes) - inf_result = InferenceResult(mi, argtypes, typeinf_lattice(interp)) - if !any(inf_result.overridden_by_const) - add_remark!(interp, sv, "[constprop] Could not handle constant info in matching_cache_argtypes") - return nothing - end - frame = InferenceState(inf_result, #=cache_mode=#:local, interp) - if frame === nothing - add_remark!(interp, sv, "[constprop] Could not retrieve the source") - return nothing # this is probably a bad generated function (unsound), but just ignore it - end - frame.parent = sv - if !typeinf(interp, frame) - add_remark!(interp, sv, "[constprop] Fresh constant inference hit a cycle") - return nothing - end - @assert inf_result.result !== nothing - if concrete_eval_result !== nothing - # override return type and effects with concrete evaluation result if available - inf_result.result = concrete_eval_result.rt - inf_result.ipo_effects = concrete_eval_result.effects - end - else + cache_mode = CACHE_MODE_LOCAL + if inf_result isa InferenceResult # found the cache for this constant prop' if inf_result.result === nothing add_remark!(interp, sv, "[constprop] Found cached constant inference in a cycle") return nothing end + if inf_result.src === nothing && is_stmt_inline(get_curr_ssaflag(sv)) + cache_mode = CACHE_MODE_VOLATILE + else + return const_prop_result(inf_result) + end + elseif is_stmt_inline(get_curr_ssaflag(sv)) + cache_mode = CACHE_MODE_VOLATILE + end + # fresh constant prop' + argtypes = has_conditional(๐•ƒแตข, sv) ? ConditionalArgtypes(arginfo, sv) : SimpleArgtypes(arginfo.argtypes) + inf_result = InferenceResult(mi, argtypes, typeinf_lattice(interp)) + if !any(inf_result.overridden_by_const) + add_remark!(interp, sv, "[constprop] Could not handle constant info in matching_cache_argtypes") + return nothing + end + frame = InferenceState(inf_result, cache_mode, interp) + if frame === nothing + add_remark!(interp, sv, "[constprop] Could not retrieve the source") + return nothing # this is probably a bad generated function (unsound), but just ignore it + end + frame.parent = sv + if !typeinf(interp, frame) + add_remark!(interp, sv, "[constprop] Fresh constant inference hit a cycle") + return nothing + end + @assert inf_result.result !== nothing + if concrete_eval_result !== nothing + # override return type and effects with concrete evaluation result if available + inf_result.result = concrete_eval_result.rt + inf_result.ipo_effects = concrete_eval_result.effects end - return ConstCallResults(inf_result.result, ConstPropResult(inf_result), inf_result.ipo_effects, mi) + return const_prop_result(inf_result) end # TODO implement MustAlias forwarding diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index cbeb447e0c9aa..fa6a66ff1967f 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -81,6 +81,7 @@ function inlining_policy(interp::AbstractInterpreter, src_inlineable = is_stmt_inline(stmt_flag) || is_inlineable(src) return src_inlineable ? src : nothing elseif isa(src, IRCode) + # n.b. the inlineability was computed within `finish!` return src elseif isa(src, SemiConcreteResult) if is_declared_noinline(mi.def::Method) @@ -182,10 +183,9 @@ include("compiler/ssair/passes.jl") include("compiler/ssair/irinterp.jl") function ir_to_codeinf!(opt::OptimizationState) - (; linfo, src) = opt - src = ir_to_codeinf!(src, opt.ir::IRCode) + src = ir_to_codeinf!(opt.src, opt.ir::IRCode) opt.ir = nothing - validate_code_in_debug_mode(linfo, src, "optimized") + validate_code_in_debug_mode(opt.linfo, src, "optimized") return src end diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index 0c678651693b9..35d3a827463bf 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -227,8 +227,16 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState) store_backedges(result, caller.stmt_edges[1]) end opt = result.src - if opt isa OptimizationState && result.must_be_codeinf - result.src = opt = ir_to_codeinf!(opt) + if opt isa OptimizationState + if !iszero(caller.cache_mode & CACHE_MODE_GLOBAL) + result.src = opt = ir_to_codeinf!(opt) + elseif !iszero(caller.cache_mode & CACHE_MODE_VOLATILE) + result.src = opt = cfg_simplify!(opt.ir::IRCode) + elseif !iszero(caller.cache_mode & CACHE_MODE_LOCAL) && is_inlineable(opt.src) + result.src = opt = cfg_simplify!(opt.ir::IRCode) + else + result.src = opt = nothing + end end if opt isa CodeInfo opt.min_world = first(valid_worlds) @@ -236,8 +244,8 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState) caller.src = opt else # In this case caller.src is invalid for clients (such as typeinf_ext) to use - # but that is what !must_be_codeinf permits - # This is hopefully unreachable when must_be_codeinf is true + # but that is what cache_mode != :global permits + # This is hopefully unreachable when cache_mode != :global end return nothing end diff --git a/base/compiler/types.jl b/base/compiler/types.jl index ba05131d4e3a3..5c7445aec0386 100644 --- a/base/compiler/types.jl +++ b/base/compiler/types.jl @@ -76,13 +76,12 @@ mutable struct InferenceResult ipo_effects::Effects # if inference is finished effects::Effects # if optimization is finished argescapes # ::ArgEscapeCache if optimized, nothing otherwise - must_be_codeinf::Bool # if this must come out as CodeInfo or leaving it as IRCode is ok function InferenceResult(linfo::MethodInstance, cache_argtypes::Vector{Any}, overridden_by_const::BitVector) # def = linfo.def # nargs = def isa Method ? Int(def.nargs) : 0 # @assert length(cache_argtypes) == nargs return new(linfo, cache_argtypes, overridden_by_const, nothing, nothing, - WorldRange(), Effects(), Effects(), nothing, true) + WorldRange(), Effects(), Effects(), nothing) end end InferenceResult(linfo::MethodInstance, ๐•ƒ::AbstractLattice=fallback_lattice) =