Skip to content

Commit

Permalink
Refactor inlining to allow re-use in more sophisticated inlining pass…
Browse files Browse the repository at this point in the history
…es (#37027)

The inlining transform basically has three parts:
1. Analysis (What needs to be inlined and are we allowed to do that?)
2. Policy (Should we inline this?)
3. Mechanism (Stuff the bits from one function into the other)

At the moment, we already separate this out into two passes:
Analysis/Policy (assemble_inline_todo!) and Mechanism (batch_inline!).
For our needs in base, the policy bits are quite simple (how large
is the optimized version of this function), but that policy is
insufficient for some more sophisticated inlining needs I have
in an external compiler pass (where I want to interleave inlining
with different transforms as well as potentially run inlining multiple
times). To facilitate such use cases, this commit optionally splits
out the policy part, but lets the analysis and mechanism parts be
re-used by a more sophisticated inlining pass. It also refactors
the optimization state to more clearly delineate the different
independent parts (edge tracking, inference caches, method table),
as well as making the different parts optional (where not required).
We were already essentially supporting optimization without edge
tracking (for testing purposes), so this is just a bit more
explicit about it (which is useful for me, since the different
inlining passes in my pipeline may need different settings).

For base itself, nothing should functionally change, though
hopefully things are factored a bit cleaner.
  • Loading branch information
Keno authored Sep 17, 2020
1 parent 840e2fc commit 2bd31a0
Show file tree
Hide file tree
Showing 7 changed files with 283 additions and 226 deletions.
82 changes: 48 additions & 34 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,62 @@
# OptimizationState #
#####################

mutable struct OptimizationState
# Collects the edges recorded during optimization together with the world
# range that those edges have been narrowed to.
struct EdgeTracker
    # Recorded edges; `push!(::EdgeTracker, ::MethodInstance)` appends here.
    edges::Vector{Any}
    # Stored behind a RefValue so the range can be replaced in place even
    # though the struct itself is immutable (see `intersect!`).
    valid_worlds::RefValue{WorldRange}

    # Wrap the initial world range in a fresh RefValue.
    function EdgeTracker(edges::Vector{Any}, range::WorldRange)
        return new(edges, RefValue{WorldRange}(range))
    end
end
# Construct an EdgeTracker with no recorded edges and a world range running
# from 0 to typemax(UInt).
# The range is built explicitly as a `WorldRange`: the inner constructor only
# accepts `range::WorldRange`, and a `0:typemax(UInt)` literal is a `UnitRange`,
# which method dispatch will not convert to match that signature.
EdgeTracker() = EdgeTracker(Any[], WorldRange(UInt(0), typemax(UInt)))

# Narrow the tracker's valid world range to its intersection with `range`,
# store the result back into the tracker, and return the narrowed range.
function intersect!(et::EdgeTracker, range::WorldRange)
    narrowed = intersect(et.valid_worlds[], range)
    et.valid_worlds[] = narrowed
    return narrowed
end

# Record `mi` as an edge by appending it to the tracker's edge list.
# Returns whatever the underlying `push!` on the edge vector returns.
function push!(et::EdgeTracker, mi::MethodInstance)
    return push!(et.edges, mi)
end
# Record a CodeInstance as an edge: first narrow the tracker's valid world
# range to [min_world(ci), max_world(ci)], then record the MethodInstance
# stored in `ci.def`.
function push!(et::EdgeTracker, ci::CodeInstance)
    # The world bounds must be read from `ci`; this method has no `li` binding
    # (that name was left over from the `add_backedge!` this replaces).
    intersect!(et, WorldRange(min_world(ci), max_world(ci)))
    return push!(et, ci.def)
end

# Pair of caches made available to inlining. Both field types are type
# parameters, so any cache implementation can be stored.
# The OptimizationState constructors in this file fill `inf_cache` from
# `get_inference_cache(interp)` and `mi_cache` from a `WorldView` over
# `code_cache(interp)`.
struct InferenceCaches{InfCache, MICache}
    inf_cache::InfCache
    mi_cache::MICache
end

# State consumed by the inlining pass. Every component except `params` may be
# `nothing`, which lets a caller opt out of that piece (for example, the
# unit-testing OptimizationState constructor passes `nothing` for `et` so that
# no edges are tracked).
struct InliningState{
    Tracker <: Union{EdgeTracker, Nothing},
    Caches <: Union{InferenceCaches, Nothing},
    Table <: Union{Nothing, MethodTableView},
}
    params::OptimizationParams
    # Edge tracker, or `nothing` when edges are not tracked.
    et::Tracker
    # Inference caches, or `nothing` when none are supplied.
    caches::Caches
    # Method table view, or `nothing` when none is supplied.
    method_table::Table
end

mutable struct OptimizationState
linfo::MethodInstance
calledges::Vector{Any}
src::CodeInfo
stmt_info::Vector{Any}
mod::Module
nargs::Int
world::UInt
valid_worlds::WorldRange
sptypes::Vector{Any} # static parameters
slottypes::Vector{Any}
const_api::Bool
# TODO: This will be eliminated once optimization no longer needs to do method lookups
interp::AbstractInterpreter
inlining::InliningState
function OptimizationState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
s_edges = frame.stmt_edges[1]
if s_edges === nothing
s_edges = []
frame.stmt_edges[1] = s_edges
end
src = frame.src
return new(params, frame.linfo,
s_edges::Vector{Any},
inlining = InliningState(params,
EdgeTracker(s_edges::Vector{Any}, frame.valid_worlds),
InferenceCaches(
get_inference_cache(interp),
WorldView(code_cache(interp), frame.world)),
method_table(interp))
return new(frame.linfo,
src, frame.stmt_info, frame.mod, frame.nargs,
frame.world, frame.valid_worlds,
frame.sptypes, frame.slottypes, false,
interp)
inlining)
end
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, interp::AbstractInterpreter)
# prepare src for running optimization passes
Expand All @@ -45,7 +73,6 @@ mutable struct OptimizationState
if slottypes === nothing
slottypes = Any[ Any for i = 1:nslots ]
end
s_edges = []
stmt_info = Any[nothing for i = 1:nssavalues]
# cache some useful state computations
toplevel = !isa(linfo.def, Method)
Expand All @@ -57,12 +84,18 @@ mutable struct OptimizationState
inmodule = linfo.def::Module
nargs = 0
end
return new(params, linfo,
s_edges::Vector{Any},
# Allow using the global MI cache, but don't track edges.
# This method is mostly used for unit testing the optimizer
inlining = InliningState(params,
nothing,
InferenceCaches(
get_inference_cache(interp),
WorldView(code_cache(interp), get_world_counter())),
method_table(interp))
return new(linfo,
src, stmt_info, inmodule, nargs,
get_world_counter(), WorldRange(UInt(1), get_world_counter()),
sptypes_from_meth_instance(linfo), slottypes, false,
interp)
inlining)
end
end

Expand Down Expand Up @@ -106,25 +139,6 @@ const TOP_TUPLE = GlobalRef(Core, :tuple)

_topmod(sv::OptimizationState) = _topmod(sv.mod)

# Narrow `sv.valid_worlds` to its intersection with `valid_worlds` and check
# that the world being optimized for still lies inside the narrowed range.
# Always returns `nothing`.
function update_valid_age!(sv::OptimizationState, valid_worlds::WorldRange)
    narrowed = intersect(sv.valid_worlds, valid_worlds)
    sv.valid_worlds = narrowed
    @assert(sv.world in sv.valid_worlds, "invalid age range update")
    return nothing
end

# Record `li` in the caller's call-edge list, unless the caller is a toplevel
# expression (i.e. its `linfo.def` is not a `Method`). Always returns `nothing`.
function add_backedge!(li::MethodInstance, caller::OptimizationState)
    #TODO: deprecate this?
    if isa(caller.linfo.def, Method)
        push!(caller.calledges, li)
    end
    # toplevel exprs fall through without adding a backedge
    return nothing
end

# Record a CodeInstance as a backedge: narrow the caller's valid world range to
# [min_world(li), max_world(li)], then add the MethodInstance in `li.def`.
# Always returns `nothing`.
function add_backedge!(li::CodeInstance, caller::OptimizationState)
    worlds = WorldRange(min_world(li), max_world(li))
    update_valid_age!(caller, worlds)
    add_backedge!(li.def, caller)
    return nothing
end

function isinlineable(m::Method, me::OptimizationState, params::OptimizationParams, union_penalties::Bool, bonus::Int=0)
# compute the cost (size) of inlining this code
inlineable = false
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ function run_passes(ci::CodeInfo, nargs::Int, sv::OptimizationState)
#@Base.show ("after_construct", ir)
# TODO: Domsorting can produce an updated domtree - no need to recompute here
@timeit "compact 1" ir = compact!(ir)
@timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv)
@timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds)
#@timeit "verify 2" verify_ir(ir)
ir = compact!(ir)
#@Base.show ("before_sroa", ir)
Expand Down
Loading

0 comments on commit 2bd31a0

Please sign in to comment.