From 2bd31a0b2ac23165d3dcda569e04bdd32283c3bd Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Thu, 17 Sep 2020 15:49:54 -0400 Subject: [PATCH] Refactor inlining to allow re-use in more sophisticated inlining passes (#37027) The inlining transform basically has three parts: 1. Analysis (What needs to be inlined and are we allowed to do that?) 2. Policy (Should we inline this?) 3. Mechanism (Stuff the bits from one function into the other) At the moment, we already separate this out into two passes: Analysis/Policy (assemble_inline_todo!) and Mechanism (batch_inline!). For our needs in base, the policy bits are quite simple (how large is the optimized version of this function), but that policy is insufficient for some more sophisticated inlining needs I have in an external compiler pass (where I want to interleave inlining with different transforms as well as potentially run inlining multiple times). To facilitate such use cases, this commit optionally splits out the policy part, but lets the analysis and mechanism parts be re-used by a more sophisticated inlining pass. It also refactors the optimization state to more clearly delineate the different independent parts (edge tracking, inference caches, method table), as well as making the different parts optional (where not required). We were already essentially supporting optimization without edge tracking (for testing purposes), so this is just a bit more explicit about it (which is useful for me, since the different inlining passes in my pipeline may need different settings). For base itself, nothing should functionally change, though hopefully things are factored a bit cleaner. 
--- base/compiler/optimize.jl | 82 ++++--- base/compiler/ssair/driver.jl | 2 +- base/compiler/ssair/inlining.jl | 397 +++++++++++++++++--------------- base/compiler/ssair/ir.jl | 20 +- base/compiler/typeinfer.jl | 4 +- base/essentials.jl | 2 + test/compiler/inference.jl | 2 - 7 files changed, 283 insertions(+), 226 deletions(-) diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index 87e0c4c30f3b4..d53b8193e639a 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -4,21 +4,45 @@ # OptimizationState # ##################### -mutable struct OptimizationState +struct EdgeTracker + edges::Vector{Any} + valid_worlds::RefValue{WorldRange} + EdgeTracker(edges::Vector{Any}, range::WorldRange) = + new(edges, RefValue{WorldRange}(range)) +end +EdgeTracker() = EdgeTracker(Any[], 0:typemax(UInt)) + +intersect!(et::EdgeTracker, range::WorldRange) = + et.valid_worlds[] = intersect(et.valid_worlds[], range) + +push!(et::EdgeTracker, mi::MethodInstance) = push!(et.edges, mi) +function push!(et::EdgeTracker, ci::CodeInstance) + intersect!(et, WorldRange(min_world(ci), max_world(ci))) + push!(et, ci.def) +end + +struct InferenceCaches{T, S} + inf_cache::T + mi_cache::S +end + +struct InliningState{S <: Union{EdgeTracker, Nothing}, T <: Union{InferenceCaches, Nothing}, V <: Union{Nothing, MethodTableView}} params::OptimizationParams + et::S + caches::T + method_table::V +end + +mutable struct OptimizationState linfo::MethodInstance - calledges::Vector{Any} src::CodeInfo stmt_info::Vector{Any} mod::Module nargs::Int - world::UInt - valid_worlds::WorldRange sptypes::Vector{Any} # static parameters slottypes::Vector{Any} const_api::Bool - # TODO: This will be eliminated once optimization no longer needs to do method lookups - interp::AbstractInterpreter + inlining::InliningState function OptimizationState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter) s_edges = frame.stmt_edges[1] if s_edges === nothing @@ -26,12 
+50,16 @@ mutable struct OptimizationState frame.stmt_edges[1] = s_edges end src = frame.src - return new(params, frame.linfo, - s_edges::Vector{Any}, + inlining = InliningState(params, + EdgeTracker(s_edges::Vector{Any}, frame.valid_worlds), + InferenceCaches( + get_inference_cache(interp), + WorldView(code_cache(interp), frame.world)), + method_table(interp)) + return new(frame.linfo, src, frame.stmt_info, frame.mod, frame.nargs, - frame.world, frame.valid_worlds, frame.sptypes, frame.slottypes, false, - interp) + inlining) end function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, interp::AbstractInterpreter) # prepare src for running optimization passes @@ -45,7 +73,6 @@ mutable struct OptimizationState if slottypes === nothing slottypes = Any[ Any for i = 1:nslots ] end - s_edges = [] stmt_info = Any[nothing for i = 1:nssavalues] # cache some useful state computations toplevel = !isa(linfo.def, Method) @@ -57,12 +84,18 @@ mutable struct OptimizationState inmodule = linfo.def::Module nargs = 0 end - return new(params, linfo, - s_edges::Vector{Any}, + # Allow using the global MI cache, but don't track edges. 
+ # This method is mostly used for unit testing the optimizer + inlining = InliningState(params, + nothing, + InferenceCaches( + get_inference_cache(interp), + WorldView(code_cache(interp), get_world_counter())), + method_table(interp)) + return new(linfo, src, stmt_info, inmodule, nargs, - get_world_counter(), WorldRange(UInt(1), get_world_counter()), sptypes_from_meth_instance(linfo), slottypes, false, - interp) + inlining) end end @@ -106,25 +139,6 @@ const TOP_TUPLE = GlobalRef(Core, :tuple) _topmod(sv::OptimizationState) = _topmod(sv.mod) -function update_valid_age!(sv::OptimizationState, valid_worlds::WorldRange) - sv.valid_worlds = intersect(sv.valid_worlds, valid_worlds) - @assert(sv.world in sv.valid_worlds, "invalid age range update") - nothing -end - -function add_backedge!(li::MethodInstance, caller::OptimizationState) - #TODO: deprecate this? - isa(caller.linfo.def, Method) || return # don't add backedges to toplevel exprs - push!(caller.calledges, li) - nothing -end - -function add_backedge!(li::CodeInstance, caller::OptimizationState) - update_valid_age!(caller, WorldRange(min_world(li), max_world(li))) - add_backedge!(li.def, caller) - nothing -end - function isinlineable(m::Method, me::OptimizationState, params::OptimizationParams, union_penalties::Bool, bonus::Int=0) # compute the cost (size) of inlining this code inlineable = false diff --git a/base/compiler/ssair/driver.jl b/base/compiler/ssair/driver.jl index cb8e9ce01d902..465102e82e155 100644 --- a/base/compiler/ssair/driver.jl +++ b/base/compiler/ssair/driver.jl @@ -124,7 +124,7 @@ function run_passes(ci::CodeInfo, nargs::Int, sv::OptimizationState) #@Base.show ("after_construct", ir) # TODO: Domsorting can produce an updated domtree - no need to recompute here @timeit "compact 1" ir = compact!(ir) - @timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv) + @timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds) #@timeit "verify 2" 
verify_ir(ir) ir = compact!(ir) #@Base.show ("before_sroa", ir) diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 655073c7d3664..4a8e5f5e0a622 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -19,61 +19,56 @@ struct Signature end with_atype(sig::Signature) = Signature(sig.f, sig.ft, sig.atypes, argtypes_to_type(sig.atypes)) -struct InliningTodo - idx::Int # The statement to replace - # Properties of the call - these determine how arguments - # need to be rewritten. - isva::Bool - isinvoke::Bool - na::Int - method::Method # The method being inlined - sparams::Vector{Any} # The static parameters we computed for this call site - metharg # ::Type +struct ResolvedInliningSpec # The LineTable and IR of the inlinee ir::IRCode # If the function being inlined is a single basic block we can use a # simpler inlining algorithm. This flag determines whether that's allowed linear_inline_eligible::Bool end -isinvoke(inl::InliningTodo) = inl.isinvoke -struct ConstantCase - val::Any - method::Method - sparams::Vector{Any} - metharg::Any - ConstantCase(val, method::Method, sparams::Vector{Any}, metharg) = - new(val, method, sparams, metharg) +""" + Represents a callsite that our analysis has determined is legal to inline, + but did not resolve during the analysis step to allow the outer inlining + pass to apply its own inlining policy decisions. 
+""" +struct DelayedInliningSpec + match::MethodMatch + atypes::Vector{Any} + stmttype::Any end -struct DynamicCase - method::Method - sparams::Vector{Any} - metharg::Any - DynamicCase(method::Method, sparams::Vector{Any}, metharg) = - new(method, sparams, metharg) +struct InliningTodo + # The MethodInstance to be inlined + mi::MethodInstance + spec::Union{ResolvedInliningSpec, DelayedInliningSpec} +end + +InliningTodo(mi::MethodInstance, match::MethodMatch, atypes::Vector{Any}, @nospecialize(stmttype)) = InliningTodo(mi, DelayedInliningSpec(match, atypes, stmttype)) + +struct ConstantCase + val::Any + ConstantCase(val) = new(val) end struct UnionSplit - idx::Int # The statement to replace fully_covered::Bool atype # ::Type cases::Vector{Pair{Any, Any}} bbs::Vector{Int} - UnionSplit(idx::Int, fully_covered::Bool, atype, cases::Vector{Pair{Any, Any}}) = - new(idx, fully_covered, atype, cases, Int[]) + UnionSplit(fully_covered::Bool, atype, cases::Vector{Pair{Any, Any}}) = + new(fully_covered, atype, cases, Int[]) end -isinvoke(inl::UnionSplit) = false @specialize -function ssa_inlining_pass!(ir::IRCode, linetable::Vector{LineInfoNode}, sv::OptimizationState) +function ssa_inlining_pass!(ir::IRCode, linetable::Vector{LineInfoNode}, state::InliningState, propagate_inbounds::Bool) # Go through the function, performing simple ininlingin (e.g. replacing call by constants # and analyzing legality of inlining). 
- @timeit "analysis" todo = assemble_inline_todo!(ir, sv) + @timeit "analysis" todo = assemble_inline_todo!(ir, state) isempty(todo) && return ir # Do the actual inlining for every call we identified - @timeit "execution" ir = batch_inline!(todo, ir, linetable, sv.src.propagate_inbounds) + @timeit "execution" ir = batch_inline!(todo, ir, linetable, propagate_inbounds) return ir end @@ -117,12 +112,12 @@ function inline_into_block!(state::CFGInliningState, block::Int) return end -function cfg_inline_item!(item::InliningTodo, state::CFGInliningState, from_unionsplit::Bool=false) - inlinee_cfg = item.ir.cfg +function cfg_inline_item!(idx::Int, spec::ResolvedInliningSpec, state::CFGInliningState, from_unionsplit::Bool=false) + inlinee_cfg = spec.ir.cfg # Figure out if we need to split the BB need_split_before = false need_split = true - block = block_for_inst(state.cfg, item.idx) + block = block_for_inst(state.cfg, idx) inline_into_block!(state, block) if !isempty(inlinee_cfg.blocks[1].preds) @@ -195,7 +190,7 @@ function cfg_inline_item!(item::InliningTodo, state::CFGInliningState, from_unio for (old_block, new_block) in enumerate(bb_rename_range) if (length(state.new_cfg_blocks[new_block].succs) == 0) terminator_idx = last(inlinee_cfg.blocks[old_block].stmts) - terminator = item.ir[SSAValue(terminator_idx)] + terminator = spec.ir[SSAValue(terminator_idx)] if isa(terminator, ReturnNode) && isdefined(terminator, :val) any_edges = true push!(state.new_cfg_blocks[new_block].succs, post_bb_id) @@ -211,8 +206,8 @@ function cfg_inline_item!(item::InliningTodo, state::CFGInliningState, from_unio end end -function cfg_inline_unionsplit!(item::UnionSplit, state::CFGInliningState) - block = block_for_inst(state.cfg, item.idx) +function cfg_inline_unionsplit!(idx::Int, item::UnionSplit, state::CFGInliningState) + block = block_for_inst(state.cfg, idx) inline_into_block!(state, block) from_bbs = Int[] delete!(state.split_targets, length(state.new_cfg_blocks)) @@ -221,12 +216,15 @@ 
function cfg_inline_unionsplit!(item::UnionSplit, state::CFGInliningState) for (i, (_, case)) in enumerate(item.cases) # The condition gets sunk into the previous block # Add a block for the union-split body - push!(state.new_cfg_blocks, BasicBlock(StmtRange(item.idx, item.idx))) + push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx))) cond_bb = length(state.new_cfg_blocks)-1 push!(state.new_cfg_blocks[end].preds, cond_bb) push!(state.new_cfg_blocks[cond_bb].succs, cond_bb+1) - if isa(case, InliningTodo) && !case.linear_inline_eligible - cfg_inline_item!(case, state, true) + if isa(case, InliningTodo) + spec = case.spec::ResolvedInliningSpec + if !spec.linear_inline_eligible + cfg_inline_item!(idx, spec, state, true) + end end bb = length(state.new_cfg_blocks) push!(from_bbs, bb) @@ -234,7 +232,7 @@ function cfg_inline_unionsplit!(item::UnionSplit, state::CFGInliningState) # in case of subtyping errors - This is probably unnecessary. if true # i != length(item.cases) || !item.fully_covered # This block will have the next condition or the final else case - push!(state.new_cfg_blocks, BasicBlock(StmtRange(item.idx, item.idx))) + push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx))) push!(state.new_cfg_blocks[cond_bb].succs, length(state.new_cfg_blocks)) push!(state.new_cfg_blocks[end].preds, cond_bb) push!(item.bbs, length(state.new_cfg_blocks)) @@ -245,7 +243,7 @@ function cfg_inline_unionsplit!(item::UnionSplit, state::CFGInliningState) push!(from_bbs, length(state.new_cfg_blocks)) end # This block will be the block everyone returns to - push!(state.new_cfg_blocks, BasicBlock(StmtRange(item.idx, item.idx), from_bbs, orig_succs)) + push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx), from_bbs, orig_succs)) join_bb = length(state.new_cfg_blocks) push!(state.split_targets, join_bb) push!(item.bbs, join_bb) @@ -301,18 +299,21 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector linetable::Vector{LineInfoNode}, 
item::InliningTodo, boundscheck::Symbol, todo_bbs::Vector{Tuple{Int, Int}}) # Ok, do the inlining here - inline_cfg = item.ir.cfg + spec = item.spec::ResolvedInliningSpec + inline_cfg = spec.ir.cfg stmt = compact.result[idx][:inst] linetable_offset::Int32 = length(linetable) # Append the linetable of the inlined function to our line table inlined_at = Int(compact.result[idx][:line]) - for entry in item.ir.linetable + for entry in spec.ir.linetable push!(linetable, LineInfoNode(entry.module, entry.method, entry.file, entry.line, (entry.inlined_at > 0 ? entry.inlined_at + linetable_offset : inlined_at))) end - if item.isva - vararg = mk_tuplecall!(compact, argexprs[item.na:end], compact.result[idx][:line]) - argexprs = Any[argexprs[1:(item.na - 1)]..., vararg] + nargs_def = item.mi.def.nargs + isva = nargs_def > 0 && item.mi.def.isva + if isva + vararg = mk_tuplecall!(compact, argexprs[nargs_def:end], compact.result[idx][:line]) + argexprs = Any[argexprs[1:(nargs_def - 1)]..., vararg] end flag = compact.result[idx][:flag] boundscheck_idx = boundscheck @@ -325,16 +326,16 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector # temporarily re-open in again. local return_value # Special case inlining that maintains the current basic block if there's only one BB in the target - if item.linear_inline_eligible - terminator = item.ir[SSAValue(last(inline_cfg.blocks[1].stmts))] + if spec.linear_inline_eligible + terminator = spec.ir[SSAValue(last(inline_cfg.blocks[1].stmts))] #compact[idx] = nothing - inline_compact = IncrementalCompact(compact, item.ir, compact.result_idx) + inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx) for ((_, idx′), stmt′) in inline_compact # This dance is done to maintain accurate usage counts in the # face of rename_arguments! mutating in place - should figure out # something better eventually. 
inline_compact[idx′] = nothing - stmt′ = ssa_substitute!(idx′, stmt′, argexprs, item.method.sig, item.sparams, linetable_offset, boundscheck_idx, compact) + stmt′ = ssa_substitute!(idx′, stmt′, argexprs, item.mi.def.sig, item.mi.sparam_vals, linetable_offset, boundscheck_idx, compact) if isa(stmt′, ReturnNode) isa(stmt′.val, SSAValue) && (compact.used_ssas[stmt′.val.id] += 1) return_value = SSAValue(idx′) @@ -352,16 +353,16 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector else bb_offset, post_bb_id = popfirst!(todo_bbs) # This implements the need_split_before flag above - need_split_before = !isempty(item.ir.cfg.blocks[1].preds) + need_split_before = !isempty(spec.ir.cfg.blocks[1].preds) if need_split_before finish_current_bb!(compact, 0) end pn = PhiNode() #compact[idx] = nothing - inline_compact = IncrementalCompact(compact, item.ir, compact.result_idx) + inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx) for ((_, idx′), stmt′) in inline_compact inline_compact[idx′] = nothing - stmt′ = ssa_substitute!(idx′, stmt′, argexprs, item.method.sig, item.sparams, linetable_offset, boundscheck_idx, compact) + stmt′ = ssa_substitute!(idx′, stmt′, argexprs, item.mi.def.sig, item.mi.sparam_vals, linetable_offset, boundscheck_idx, compact) if isa(stmt′, ReturnNode) if isdefined(stmt′, :val) val = stmt′.val @@ -495,17 +496,18 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, nothing end -function batch_inline!(todo::Vector{Any}, ir::IRCode, linetable::Vector{LineInfoNode}, propagate_inbounds::Bool) +function batch_inline!(todo::Vector{Pair{Int, Any}}, ir::IRCode, linetable::Vector{LineInfoNode}, propagate_inbounds::Bool) # Compute the new CFG first (modulo statement ranges, which will be computed below) state = CFGInliningState(ir) - for item in todo + for (idx, item) in todo if isa(item, UnionSplit) - cfg_inline_unionsplit!(item::UnionSplit, state) + cfg_inline_unionsplit!(idx, item::UnionSplit, 
state) else item = item::InliningTodo + spec = item.spec::ResolvedInliningSpec # A linear inline does not modify the CFG - item.linear_inline_eligible && continue - cfg_inline_item!(item, state) + spec.linear_inline_eligible && continue + cfg_inline_item!(idx, spec, state, false) end end finish_cfg_inline!(state) @@ -519,15 +521,15 @@ function batch_inline!(todo::Vector{Any}, ir::IRCode, linetable::Vector{LineInfo compact.result_bbs = state.new_cfg_blocks # This needs to be a minimum and is more of a size hint nn = 0 - for item in todo + for (_, item) in todo if isa(item, InliningTodo) - nn += (length(item.ir.stmts) + length(item.ir.new_nodes)) + spec = item.spec::ResolvedInliningSpec + nn += (length(spec.ir.stmts) + length(spec.ir.new_nodes)) end end nnewnodes = length(compact.result) + nn resize!(compact, nnewnodes) - item = popfirst!(todo) - inline_idx = item.idx + (inline_idx, item) = popfirst!(todo) for ((old_idx, idx), stmt) in compact if old_idx == inline_idx argexprs = copy(stmt.args) @@ -543,11 +545,6 @@ function batch_inline!(todo::Vector{Any}, ir::IRCode, linetable::Vector{LineInfo argexprs[aidx] = insert_node_here!(compact, aexpr, compact_exprtype(compact, aexpr), compact.result[idx][:line]) end end - if isinvoke(item) - argexprs = rewrite_invoke_exprargs!(argexprs) do node, typ - insert_node_here!(compact, node, typ, compact.result[idx][:line]) - end - end if isa(item, InliningTodo) compact.ssa_rename[old_idx] = ir_inline_item!(compact, idx, argexprs, linetable, item, boundscheck, state.todo_bbs) elseif isa(item, UnionSplit) @@ -556,8 +553,7 @@ function batch_inline!(todo::Vector{Any}, ir::IRCode, linetable::Vector{LineInfo compact[idx] = nothing refinish && finish_current_bb!(compact, 0) if !isempty(todo) - item = popfirst!(todo) - inline_idx = item.idx + (inline_idx, item) = popfirst!(todo) else inline_idx = -1 end @@ -578,7 +574,11 @@ function batch_inline!(todo::Vector{Any}, ir::IRCode, linetable::Vector{LineInfo end # This assumes the caller has 
verified that all arguments to the _apply call are Tuples. -function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Any}, idx::Int, argexprs::Vector{Any}, atypes::Vector{Any}, arginfos::Vector{Any}, arg_start::Int, sv::OptimizationState) +function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, + argexprs::Vector{Any}, atypes::Vector{Any}, arginfos::Vector{Any}, + arg_start::Int, et::Union{EdgeTracker, Nothing}, caches::Union{InferenceCaches, Nothing}, + params::OptimizationParams) + new_argexprs = Any[argexprs[arg_start]] new_atypes = Any[atypes[arg_start]] # loop over original arguments and flatten any known iterators @@ -636,7 +636,7 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Any}, idx::Int, argexp MethodMatchInfo[call.info] : call.info.matches # See if we can inline this call to `iterate` analyze_single_call!(ir, todo, state1.id, new_stmt, - new_sig, call.rt, info, sv) + new_sig, call.rt, info, et, caches, params) end if i != length(thisarginfo.each) valT = getfield_tfunc(call.rt, Const(1)) @@ -654,7 +654,7 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Any}, idx::Int, argexp return new_argexprs, new_atypes end -function rewrite_invoke_exprargs!(inserter, argexprs::Vector{Any}) +function rewrite_invoke_exprargs!(argexprs::Vector{Any}) argexpr0 = argexprs[2] argexprs = argexprs[4:end] pushfirst!(argexprs, argexpr0) @@ -670,15 +670,57 @@ function singleton_type(@nospecialize(ft)) return nothing end -function compileable_specialization(match::MethodMatch, sv::OptimizationState) +function compileable_specialization(et::Union{EdgeTracker, Nothing}, match::MethodMatch) mi = specialize_method(match, false, true) - mi !== nothing && add_backedge!(mi::MethodInstance, sv) + mi !== nothing && et !== nothing && push!(et, mi::MethodInstance) return mi end -function analyze_method!(idx::Int, atypes::Vector{Any}, match::MethodMatch, - stmt::Expr, sv::OptimizationState, - isinvoke::Bool, @nospecialize(stmttyp)) 
+function resolve_todo(todo::InliningTodo, et::Union{EdgeTracker, Nothing}, caches::InferenceCaches) + spec = todo.spec::DelayedInliningSpec + isconst, src = find_inferred(todo.mi, spec.atypes, caches, spec.stmttype) + + if isconst + push!(et, todo.mi) + return ConstantCase(src) + end + + if src === nothing + return compileable_specialization(et, spec.match) + end + + if isa(src, CodeInfo) || isa(src, Vector{UInt8}) + src_inferred = ccall(:jl_ir_flag_inferred, Bool, (Any,), src) + src_inlineable = ccall(:jl_ir_flag_inlineable, Bool, (Any,), src) + + if !(src_inferred && src_inlineable) + return compileable_specialization(et, spec.match) + end + elseif isa(src, IRCode) + src = copy(src) + end + + et !== nothing && push!(et, todo.mi) + return InliningTodo(todo.mi, src) +end + +function resolve_todo(todo::UnionSplit, et::Union{EdgeTracker, Nothing}, caches::InferenceCaches) + UnionSplit(todo.fully_covered, todo.atype, + Pair{Any,Any}[sig=>resolve_todo(item, et, caches) for (sig, item) in todo.cases]) +end + +function resolve_todo!(todo::Vector{Pair{Int, Any}}, et::Union{EdgeTracker, Nothing}, caches::InferenceCaches) + for i = 1:length(todo) + idx, item = todo[i] + todo[i] = idx=>resolve_todo(item, et, caches) + end + todo +end + +function analyze_method!(match::MethodMatch, atypes::Vector{Any}, + et::Union{EdgeTracker, Nothing}, + caches::Union{InferenceCaches, Nothing}, + params::OptimizationParams, @nospecialize(stmttyp)) method = match.method methsig = method.sig @@ -699,58 +741,35 @@ function analyze_method!(idx::Int, atypes::Vector{Any}, match::MethodMatch, isa(match.sparams[i], TypeVar) && return nothing end - if !sv.params.inlining - return compileable_specialization(match, sv) + if !params.inlining + return compileable_specialization(et, match) end # See if there exists a specialization for this method signature mi = specialize_method(match, true) # Union{Nothing, MethodInstance} if !isa(mi, MethodInstance) - return compileable_specialization(match, sv) - end 
- - isconst, src = find_inferred(mi, atypes, sv, stmttyp) - if isconst - add_backedge!(mi, sv) - return ConstantCase(src, method, Any[match.sparams...], match.spec_types) - end - if src === nothing - return compileable_specialization(match, sv) + return compileable_specialization(et, match) end - src_inferred = ccall(:jl_ir_flag_inferred, Bool, (Any,), src) - src_inlineable = ccall(:jl_ir_flag_inlineable, Bool, (Any,), src) - - if !(src_inferred && src_inlineable) - return compileable_specialization(match, sv) - end + todo = InliningTodo(mi, match, atypes, stmttyp) + # If we don't have caches here, delay resolving this MethodInstance + # until the batch inlining step (or an external post-processing pass) + caches === nothing && return todo + return resolve_todo(todo, et, caches) +end - # At this point we're committed to performing the inlining, add the backedge - add_backedge!(mi, sv) +function InliningTodo(mi::MethodInstance, ir::IRCode) + return InliningTodo(mi, ResolvedInliningSpec(ir, linear_inline_eligible(ir))) +end +function InliningTodo(mi::MethodInstance, src::Union{CodeInfo, Array{UInt8, 1}}) if !isa(src, CodeInfo) - src = ccall(:jl_uncompress_ir, Any, (Any, Ptr{Cvoid}, Any), method, C_NULL, src::Vector{UInt8})::CodeInfo + src = ccall(:jl_uncompress_ir, Any, (Any, Ptr{Cvoid}, Any), mi.def, C_NULL, src::Vector{UInt8})::CodeInfo end @timeit "inline IR inflation" begin - ir2 = inflate_ir(src, mi) - # #optional: prepare inlining linetable with method instance information - # inline_linetable = ir2.linetable - # for i = 1:length(inline_linetable) - # entry = inline_linetable[i] - # if entry.inlined_at === 0 && entry.method === method - # entry = LineInfoNode(entry.module, mi, entry.file, entry.line, entry.inlined_at) - # inline_linetable[i] = entry - # end - # end - end - #verify_ir(ir2) - - return InliningTodo(idx, - na > 0 && method.isva, - isinvoke, na, - method, Any[match.sparams...], match.spec_types, - ir2, linear_inline_eligible(ir2)) + return 
InliningTodo(mi, inflate_ir(src, mi)::IRCode) + end end # Neither the product iterator not CartesianIndices are available @@ -798,21 +817,22 @@ function iterate(split::UnionSplitSignature, state::Vector{Int}...) return (sig, state) end -function handle_single_case!(ir::IRCode, stmt::Expr, idx::Int, @nospecialize(case), isinvoke::Bool, todo::Vector{Any}) +function handle_single_case!(ir::IRCode, stmt::Expr, idx::Int, @nospecialize(case), isinvoke::Bool, todo::Vector{Pair{Int, Any}}) if isa(case, ConstantCase) ir[SSAValue(idx)] = case.val elseif isa(case, MethodInstance) if isinvoke - stmt.args = rewrite_invoke_exprargs!( - (node, typ)->insert_node!(ir, idx, typ, node), - stmt.args) + stmt.args = rewrite_invoke_exprargs!(stmt.args) end stmt.head = :invoke pushfirst!(stmt.args, case) elseif case === nothing # Do, well, nothing else - push!(todo, case::InliningTodo) + if isinvoke + stmt.args = rewrite_invoke_exprargs!(stmt.args) + end + push!(todo, idx=>(case::InliningTodo)) end nothing end @@ -886,8 +906,8 @@ function call_sig(ir::IRCode, stmt::Expr) Signature(f, ft, atypes) end -function inline_apply!(ir::IRCode, todo::Vector{Any}, idx::Int, sig::Signature, - params::OptimizationParams, sv::OptimizationState) +function inline_apply!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, sig::Signature, + et, caches, params::OptimizationParams) stmt = ir.stmts[idx][:inst] while sig.f === Core._apply || sig.f === Core._apply_iterate info = ir.stmts[idx][:info] @@ -944,7 +964,7 @@ function inline_apply!(ir::IRCode, todo::Vector{Any}, idx::Int, sig::Signature, end # Independent of whether we can inline, the above analysis allows us to rewrite # this apply call to a regular call - stmt.args, atypes = rewrite_apply_exprargs!(ir, todo, idx, stmt.args, atypes, infos, arg_start, sv) + stmt.args, atypes = rewrite_apply_exprargs!(ir, todo, idx, stmt.args, atypes, infos, arg_start, et, caches, params) ir.stmts[idx][:info] = new_info has_free_typevars(ft) && return nothing f = 
singleton_type(ft) @@ -960,7 +980,7 @@ is_builtin(s::Signature) = isa(s.f, Builtin) || s.ft ⊑ Builtin -function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, invoke_data::InvokeData, sv::OptimizationState, todo::Vector{Any}) +function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, invoke_data::InvokeData, state::InliningState, todo::Vector{Pair{Int, Any}}) stmt = ir.stmts[idx][:inst] calltype = ir.stmts[idx][:type] method = invoke_data.entry @@ -968,16 +988,16 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, invoke_data::Invok sig.atype, method.sig)::SimpleVector methsp = methsp::SimpleVector match = MethodMatch(metharg, methsp, method, true) - result = analyze_method!(idx, sig.atypes, match, stmt, sv, true, calltype) + result = analyze_method!(match, sig.atypes, state.et, state.caches, state.params, calltype) handle_single_case!(ir, stmt, idx, result, true, todo) - update_valid_age!(sv, WorldRange(invoke_data.min_valid, invoke_data.max_valid)) + intersect!(state.et, WorldRange(invoke_data.min_valid, invoke_data.max_valid)) return nothing end # Handles all analysis and inlining of intrinsics and builtins. In particular, # this method does not access the method table or otherwise process generic # functions. 
-function process_simple!(ir::IRCode, todo, idx::Int, params::OptimizationParams, world::UInt, sv) +function process_simple!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, state::InliningState) stmt = ir.stmts[idx][:inst] stmt isa Expr || return nothing if stmt.head === :splatnew @@ -991,12 +1011,12 @@ function process_simple!(ir::IRCode, todo, idx::Int, params::OptimizationParams, sig === nothing && return nothing # Handle _apply - sig = inline_apply!(ir, todo, idx, sig, params, sv) + sig = inline_apply!(ir, todo, idx, sig, state.et, state.caches, state.params) sig === nothing && return nothing # Check if we match any of the early inliners calltype = ir.stmts[idx][:type] - res = early_inline_special_case(ir, sig, stmt, params, calltype) + res = early_inline_special_case(ir, sig, stmt, state.params, calltype) if res !== nothing ir.stmts[idx][:inst] = res return nothing @@ -1005,7 +1025,7 @@ function process_simple!(ir::IRCode, todo, idx::Int, params::OptimizationParams, # Handle invoke invoke_data = nothing if sig.f === Core.invoke && length(sig.atypes) >= 3 - res = compute_invoke_data(sig.atypes, world) + res = compute_invoke_data(sig.atypes, state.method_table) res === nothing && return nothing (sig, invoke_data) = res elseif is_builtin(sig) @@ -1020,7 +1040,7 @@ function process_simple!(ir::IRCode, todo, idx::Int, params::OptimizationParams, (invoke_data === nothing || sig.atype <: invoke_data.types0) || return nothing # Special case inliners for regular functions - if late_inline_special_case!(ir, sig, idx, stmt, params) || is_return_type(sig.f) + if late_inline_special_case!(ir, sig, idx, stmt, state.params) || is_return_type(sig.f) return nothing end return (sig, invoke_data) @@ -1029,20 +1049,21 @@ end # This is not currently called in the regular course, but may be needed # if we ever want to re-run inlining again later in the pass pipeline after # additional type information was discovered. 
-function recompute_method_matches(@nospecialize(atype), sv::OptimizationState) +function recompute_method_matches(@nospecialize(atype), params::OptimizationParams, et::EdgeTracker, method_table::MethodTableView) # Regular case: Retrieve matching methods from cache (or compute them) # World age does not need to be taken into account in the cache # because it is forwarded from type inference through `sv.params` # in the case that the cache is nonempty, so it should be unchanged # The max number of methods should be the same as in inference most # of the time, and should not affect correctness otherwise. - results = findall(atype, InternalMethodTable(sv.world); limit=sv.params.MAX_METHODS) - results !== missing && update_valid_age!(sv, results.valid_worlds) + results = findall(atype, method_table; limit=params.MAX_METHODS) + results !== missing && intersect!(et, results.valid_worlds) MethodMatchInfo(results) end -function analyze_single_call!(ir::IRCode, todo::Vector{Any}, idx::Int, @nospecialize(stmt), - sig::Signature, @nospecialize(calltype), infos::Vector{MethodMatchInfo}, sv::OptimizationState) +function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, @nospecialize(stmt), + sig::Signature, @nospecialize(calltype), infos::Vector{MethodMatchInfo}, + et, caches, params) cases = Pair{Any, Any}[] signature_union = Union{} only_method = nothing # keep track of whether there is one matching method @@ -1075,8 +1096,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Any}, idx::Int, @nospecia fully_covered = false continue end - case = analyze_method!(idx, sig.atypes, match, - stmt, sv, false, calltype) + case = analyze_method!(match, sig.atypes, et, caches, params, calltype) if case === nothing fully_covered = false continue @@ -1102,7 +1122,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Any}, idx::Int, @nospecia match = meth[1] end fully_covered = true - case = analyze_method!(idx, sig.atypes, match, stmt, sv, false, 
calltype) + case = analyze_method!(match, sig.atypes, et, caches, params, calltype) case === nothing && return push!(cases, Pair{Any,Any}(match.spec_types, case)) end @@ -1118,19 +1138,19 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Any}, idx::Int, @nospecia return end length(cases) == 0 && return - push!(todo, UnionSplit(idx, fully_covered, sig.atype, cases)) + push!(todo, idx=>UnionSplit(fully_covered, sig.atype, cases)) return nothing end -function assemble_inline_todo!(ir::IRCode, sv::OptimizationState) +function assemble_inline_todo!(ir::IRCode, state::InliningState) # todo = (inline_idx, (isva, isinvoke, na), method, spvals, inline_linetable, inline_ir, lie) - todo = Any[] - if sv.params.unoptimize_throw_blocks + todo = Pair{Int, Any}[] + if state.params.unoptimize_throw_blocks skip = find_throw_blocks(ir.stmts.inst, RefValue(ir)) end for idx in 1:length(ir.stmts) - sv.params.unoptimize_throw_blocks && idx in skip && continue - r = process_simple!(ir, todo, idx, sv.params, sv.world, sv) + state.params.unoptimize_throw_blocks && idx in skip && continue + r = process_simple!(ir, todo, idx, state) r === nothing && continue stmt = ir.stmts[idx][:inst] @@ -1153,28 +1173,34 @@ function assemble_inline_todo!(ir::IRCode, sv::OptimizationState) # Ok, now figure out what method to call if invoke_data !== nothing - inline_invoke!(ir, idx, sig, invoke_data, sv, todo) + inline_invoke!(ir, idx, sig, invoke_data, state, todo) continue end nu = countunionsplit(sig.atypes) - if nu == 1 || nu > sv.params.MAX_UNION_SPLITTING + if nu == 1 || nu > state.params.MAX_UNION_SPLITTING if !isa(info, MethodMatchInfo) - info = recompute_method_matches(sig.atype, sv) + if state.method_table === nothing + continue + end + info = recompute_method_matches(sig.atype, state.params, state.et, state.method_table) end infos = MethodMatchInfo[info] else if !isa(info, UnionSplitInfo) + if state.method_table === nothing + continue + end infos = MethodMatchInfo[] for union_sig in 
UnionSplitSignature(sig.atypes) - push!(infos, recompute_method_matches(argtypes_to_type(union_sig), sv)) + push!(infos, recompute_method_matches(argtypes_to_type(union_sig), state.params, state.et, state.method_table)) end else infos = info.matches end end - analyze_single_call!(ir, todo, idx, stmt, sig, calltype, infos, sv) + analyze_single_call!(ir, todo, idx, stmt, sig, calltype, infos, state.et, state.caches, state.params) end todo end @@ -1193,7 +1219,7 @@ function linear_inline_eligible(ir::IRCode) return true end -function compute_invoke_data(@nospecialize(atypes), world::UInt) +function compute_invoke_data(@nospecialize(atypes), method_table) ft = widenconst(atypes[2]) if !isdispatchelem(ft) || has_free_typevars(ft) || (ft <: Builtin) # TODO: this can be rather aggressive at preventing inlining of closures @@ -1208,8 +1234,13 @@ function compute_invoke_data(@nospecialize(atypes), world::UInt) if !(isa(unwrap_unionall(invoke_tt), DataType) && invoke_tt <: Tuple) return nothing end + if method_table === nothing + # TODO: These should be forwarded in stmt_info, just like regular + # method lookup results + return nothing + end invoke_types = rewrap_unionall(Tuple{ft, unwrap_unionall(invoke_tt).parameters...}, invoke_tt) - invoke_entry = findsup(invoke_types, InternalMethodTable(world)) + invoke_entry = findsup(invoke_types, method_table) invoke_entry === nothing && return nothing method, valid_worlds = invoke_entry invoke_data = InvokeData(method, invoke_types, first(valid_worlds), last(valid_worlds)) @@ -1301,7 +1332,7 @@ function late_inline_special_case!(ir::IRCode, sig::Signature, idx::Int, stmt::E end function ssa_substitute!(idx::Int, @nospecialize(val), arg_replacements::Vector{Any}, - @nospecialize(spsig), spvals::Vector{Any}, + @nospecialize(spsig), spvals::SimpleVector, linetable_offset::Int32, boundscheck::Symbol, compact::IncrementalCompact) compact.result[idx][:flag] &= ~IR_FLAG_INBOUNDS compact.result[idx][:line] += linetable_offset @@ -1309,7 
+1340,7 @@ function ssa_substitute!(idx::Int, @nospecialize(val), arg_replacements::Vector{ end function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any}, - @nospecialize(spsig), spvals::Vector{Any}, boundscheck::Symbol) + @nospecialize(spsig), spvals::SimpleVector, boundscheck::Symbol) if isa(val, Argument) return arg_replacements[val.n] end @@ -1355,41 +1386,45 @@ function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any}, return urs[] end -function find_inferred(mi::MethodInstance, atypes::Vector{Any}, sv::OptimizationState, @nospecialize(rettype)) - # see if the method has a InferenceResult in the current cache - # or an existing inferred code info store in `.inferred` - haveconst = false - for i in 1:length(atypes) - if has_nontrivial_const_info(atypes[i]) - # have new information from argtypes that wasn't available from the signature - haveconst = true - break - end - end - if haveconst || improvable_via_constant_propagation(rettype) - inf_result = cache_lookup(mi, atypes, get_inference_cache(sv.interp)) # Union{Nothing, InferenceResult} - else - inf_result = nothing - end - #XXX: update_valid_age!(min_valid[1], max_valid[1], sv) - if isa(inf_result, InferenceResult) - let inferred_src = inf_result.src - if isa(inferred_src, CodeInfo) - return svec(false, inferred_src) +function find_inferred(mi::MethodInstance, atypes::Vector{Any}, caches::InferenceCaches, @nospecialize(rettype)) + if caches.inf_cache !== nothing + # see if the method has a InferenceResult in the current cache + # or an existing inferred code info store in `.inferred` + haveconst = false + for i in 1:length(atypes) + if has_nontrivial_const_info(atypes[i]) + # have new information from argtypes that wasn't available from the signature + haveconst = true + break end - if isa(inferred_src, Const) && is_inlineable_constant(inferred_src.val) - return svec(true, quoted(inferred_src.val),) + end + if haveconst || improvable_via_constant_propagation(rettype) + 
inf_result = cache_lookup(mi, atypes, caches.inf_cache) # Union{Nothing, InferenceResult} + else + inf_result = nothing + end + #XXX: update_valid_age!(min_valid[1], max_valid[1], sv) + if isa(inf_result, InferenceResult) + let inferred_src = inf_result.src + if isa(inferred_src, CodeInfo) + return svec(false, inferred_src) + end + if isa(inferred_src, Const) && is_inlineable_constant(inferred_src.val) + return svec(true, quoted(inferred_src.val),) + end end end end - linfo = get(WorldView(code_cache(sv.interp), sv.world), mi, nothing) + linfo = get(caches.mi_cache, mi, nothing) if linfo isa CodeInstance if invoke_api(linfo) == 2 # in this case function can be inlined to a constant return svec(true, quoted(linfo.rettype_const)) end return svec(false, linfo.inferred) + else + # `linfo` may be `nothing` or an IRCode here + return svec(false, linfo) end - return svec(false, nothing) end diff --git a/base/compiler/ssair/ir.jl b/base/compiler/ssair/ir.jl index ae58fcf1f22d4..e7003473e1cbd 100644 --- a/base/compiler/ssair/ir.jl +++ b/base/compiler/ssair/ir.jl @@ -35,6 +35,7 @@ struct CFG index::Vector{Int} # map from instruction => basic-block number # TODO: make this O(1) instead of O(log(n_blocks))? 
end +copy(c::CFG) = CFG(BasicBlock[copy(b) for b in c.blocks], copy(c.index)) function block_for_inst(index::Vector{Int}, inst::Int) return searchsortedfirst(index, inst, lt=(<=)) @@ -180,13 +181,14 @@ function add!(is::InstructionStream) resize!(is, ninst) return ninst end -#function copy(is::InstructionStream) # unused -# return InstructionStream( -# copy_exprargs(is.insts), -# copy(is.types), -# copy(is.lines), -# copy(is.flags)) -#end +function copy(is::InstructionStream) + return InstructionStream( + copy_exprargs(is.inst), + copy(is.type), + copy(is.info), + copy(is.line), + copy(is.flag)) +end function resize!(stmts::InstructionStream, len) old_length = length(stmts) resize!(stmts.inst, len) @@ -248,6 +250,7 @@ function add!(new::NewNodeStream, pos::Int, attach_after::Bool) push!(new.info, NewNodeInfo(pos, attach_after)) return Instruction(new.stmts) end +copy(nns::NewNodeStream) = NewNodeStream(copy(nns.stmts), copy(nns.info)) struct IRCode stmts::InstructionStream @@ -264,6 +267,9 @@ struct IRCode function IRCode(ir::IRCode, stmts::InstructionStream, cfg::CFG, new_nodes::NewNodeStream) return new(stmts, ir.argtypes, ir.sptypes, ir.linetable, cfg, new_nodes, ir.meta) end + global copy + copy(ir::IRCode) = new(copy(ir.stmts), copy(ir.argtypes), copy(ir.sptypes), + copy(ir.linetable), copy(ir.cfg), copy(ir.new_nodes), copy(ir.meta)) end function getindex(x::IRCode, s::SSAValue) diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index 1f53aa1b552bd..2cd89d0442fdb 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -49,7 +49,9 @@ function typeinf(interp::AbstractInterpreter, frame::InferenceState) caller.src = nothing end end - valid_worlds = intersect(valid_worlds, opt.valid_worlds) + # As a hack the et reuses frame_edges[1] to push any optimization + # edges into, so we don't need to handle them specially here + valid_worlds = intersect(valid_worlds, opt.inlining.et.valid_worlds[]) end end end diff --git 
a/base/essentials.jl b/base/essentials.jl index 8673765f0b8c5..fb360ea6482db 100644 --- a/base/essentials.jl +++ b/base/essentials.jl @@ -617,6 +617,8 @@ map(f, v::SimpleVector) = Any[ f(v[i]) for i = 1:length(v) ] getindex(v::SimpleVector, I::AbstractArray) = Core.svec(Any[ v[i] for i in I ]...) +unsafe_convert(::Type{Ptr{Any}}, sv::SimpleVector) = convert(Ptr{Any},pointer_from_objref(sv)) + sizeof(Ptr) + """ isassigned(array, i) -> Bool diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 9ce543c88a26b..50eb5e9734b35 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -1492,8 +1492,6 @@ let linfo = get_linfo(Base.convert, Tuple{Type{Int64}, Int32}), @test opt.src.ssavaluetypes isa Vector{Any} @test !opt.src.inferred @test opt.mod === Base - @test opt.valid_worlds.max_world === Core.Compiler.get_world_counter() - @test opt.valid_worlds.min_world === Core.Compiler.min_world(opt.src) === UInt(1) @test opt.nargs == 3 end