Skip to content

Commit

Permalink
inference: stop re-converging worlds after optimization (#38820)
Browse files Browse the repository at this point in the history
The validity did not change, so we should not need to update it. This
also ensures we copy over all result information earlier, so we can
destroy the InferenceState slightly sooner, and slightly cleaner data flow.

(cherry picked from commit 8c01444)
  • Loading branch information
vtjnash authored and staticfloat committed Dec 22, 2022
1 parent 951d1b3 commit 47130c5
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 45 deletions.
3 changes: 1 addition & 2 deletions base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,10 @@ using .Sort
# compiler #
############

include("compiler/cicache.jl")
include("compiler/types.jl")
include("compiler/utilities.jl")
include("compiler/validation.jl")

include("compiler/cicache.jl")
include("compiler/methodtable.jl")

include("compiler/inferenceresult.jl")
Expand Down
11 changes: 3 additions & 8 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,15 @@ mutable struct OptimizationState
const_api::Bool
inlining::InliningState
function OptimizationState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
s_edges = frame.stmt_edges[1]
if s_edges === nothing
s_edges = []
frame.stmt_edges[1] = s_edges
end
src = frame.src
s_edges = frame.stmt_edges[1]::Vector{Any}
inlining = InliningState(params,
EdgeTracker(s_edges::Vector{Any}, frame.valid_worlds),
EdgeTracker(s_edges, frame.valid_worlds),
InferenceCaches(
get_inference_cache(interp),
WorldView(code_cache(interp), frame.world)),
method_table(interp))
return new(frame.linfo,
src, frame.stmt_info, frame.mod, frame.nargs,
frame.src, frame.stmt_info, frame.mod, frame.nargs,
frame.sptypes, frame.slottypes, false,
inlining)
end
Expand Down
83 changes: 49 additions & 34 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,21 +217,29 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
# with no active ip's, frame is done
frames = frame.callers_in_cycle
isempty(frames) && push!(frames, frame)
valid_worlds = WorldRange()
for caller in frames
@assert !(caller.dont_work_on_me)
caller.dont_work_on_me = true
# might might not fully intersect these earlier, so do that now
valid_worlds = intersect(caller.valid_worlds, valid_worlds)
end
for caller in frames
caller.valid_worlds = valid_worlds
finish(caller, interp)
# finalize and record the linfo result
caller.inferred = true
end
# collect results for the new expanded frame
results = Tuple{InferenceResult, Bool}[ ( frames[i].result,
frames[i].cached || frames[i].parent !== nothing ) for i in 1:length(frames) ]
# empty!(frames)
valid_worlds = frame.valid_worlds
results = Tuple{InferenceResult, Vector{Any}, Bool}[
( frames[i].result,
frames[i].stmt_edges[1],
frames[i].cached || frames[i].parent !== nothing )
for i in 1:length(frames) ]
empty!(frames)
cached = frame.cached
if cached || frame.parent !== nothing
for (caller, doopt) in results
for (caller, _, doopt) in results
opt = caller.src
if opt isa OptimizationState
run_optimizer = doopt && may_optimize(interp)
Expand All @@ -253,31 +261,24 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
caller.src = nothing
end
end
# As a hack the et reuses frame_edges[1] to push any optimization
# edges into, so we don't need to handle them specially here
valid_worlds = intersect(valid_worlds, opt.inlining.et.valid_worlds[])
caller.valid_worlds = opt.inlining.et.valid_worlds[]
end
end
end
if last(valid_worlds) == get_world_counter()
valid_worlds = WorldRange(first(valid_worlds), typemax(UInt))
end
for caller in frames
for (caller, edges, doopt) in results
valid_worlds = caller.valid_worlds
if last(valid_worlds) == get_world_counter()
valid_worlds = WorldRange(first(valid_worlds), typemax(UInt))
end
caller.valid_worlds = valid_worlds
caller.src.min_world = first(valid_worlds)
caller.src.max_world = last(valid_worlds)
if cached
cache_result!(interp, caller.result, valid_worlds)
cache_result!(interp, caller)
end
if last(valid_worlds) == typemax(UInt)
if doopt && last(valid_worlds) == typemax(UInt)
# if we aren't cached, we don't need this edge
# but our caller might, so let's just make it anyways
for caller in frames
store_backedges(caller)
end
store_backedges(caller, edges)
end
# finalize and record the linfo result
caller.inferred = true
end
return true
end
Expand Down Expand Up @@ -343,14 +344,16 @@ function maybe_compress_codeinfo(interp::AbstractInterpreter, linfo::MethodInsta
end

function transform_result_for_cache(interp::AbstractInterpreter, linfo::MethodInstance,
@nospecialize(inferred_result))
valid_worlds::WorldRange, @nospecialize(inferred_result))
local const_flags::Int32
# If we decided not to optimize, drop the OptimizationState now.
# External interpreters can override as necessary to cache additional information
if inferred_result isa OptimizationState
inferred_result = inferred_result.src
end
if inferred_result isa CodeInfo
inferred_result.min_world = first(valid_worlds)
inferred_result.max_world = last(valid_worlds)
inferred_result = maybe_compress_codeinfo(interp, linfo, inferred_result)
end
# The global cache can only handle objects that codegen understands
Expand All @@ -360,7 +363,8 @@ function transform_result_for_cache(interp::AbstractInterpreter, linfo::MethodIn
return inferred_result
end

function cache_result!(interp::AbstractInterpreter, result::InferenceResult, valid_worlds::WorldRange)
function cache_result!(interp::AbstractInterpreter, result::InferenceResult)
valid_worlds = result.valid_worlds
# check if the existing linfo metadata is also sufficient to describe the current inference result
# to decide if it is worth caching this
already_inferred = already_inferred_quick_test(interp, result.linfo)
Expand All @@ -370,7 +374,7 @@ function cache_result!(interp::AbstractInterpreter, result::InferenceResult, val

# TODO: also don't store inferred code if we've previously decided to interpret this function
if !already_inferred
inferred_result = transform_result_for_cache(interp, result.linfo, result.src)
inferred_result = transform_result_for_cache(interp, result.linfo, valid_worlds, result.src)
code_cache(interp)[result.linfo] = CodeInstance(result, inferred_result, valid_worlds)
end
unlock_mi_inference(interp, result.linfo)
Expand All @@ -381,6 +385,21 @@ end
# update the MethodInstance
function finish(me::InferenceState, interp::AbstractInterpreter)
# prepare to run optimization passes on fulltree
s_edges = me.stmt_edges[1]
if s_edges === nothing
s_edges = []
me.stmt_edges[1] = s_edges
end
for edges in me.stmt_edges
edges === nothing && continue
edges === s_edges && continue
append!(s_edges, edges)
empty!(edges)
end
if me.src.edges !== nothing
append!(s_edges, me.src.edges)
me.src.edges = nothing
end
if me.limited && me.cached && me.parent !== nothing
# a top parent will be cached still, but not this intermediate work
# we can throw everything else away now
Expand All @@ -392,6 +411,7 @@ function finish(me::InferenceState, interp::AbstractInterpreter)
type_annotate!(me)
me.result.src = OptimizationState(me, OptimizationParams(interp), interp)
end
me.result.valid_worlds = me.valid_worlds
me.result.result = me.bestguess
nothing
end
Expand All @@ -404,20 +424,15 @@ function finish(src::CodeInfo, interp::AbstractInterpreter)
end

# record the backedges
function store_backedges(frame::InferenceState)
function store_backedges(frame::InferenceResult, edges::Vector{Any})
toplevel = !isa(frame.linfo.def, Method)
if !toplevel && (frame.cached || frame.parent !== nothing)
caller = frame.result.linfo
for edges in frame.stmt_edges
store_backedges(caller, edges)
end
store_backedges(caller, frame.src.edges)
frame.src.edges = nothing
if !toplevel
store_backedges(frame.linfo, edges)
end
nothing
end

store_backedges(caller, edges::Nothing) = nothing
function store_backedges(caller, edges::Vector)
function store_backedges(caller::MethodInstance, edges::Vector)
i = 1
while i <= length(edges)
to = edges[i]
Expand Down
3 changes: 2 additions & 1 deletion base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ mutable struct InferenceResult
overridden_by_const::BitVector
result # ::Type, or InferenceState if WIP
src #::Union{CodeInfo, OptimizationState, Nothing} # if inferred copy is available
valid_worlds::WorldRange # if inference and optimization is finished
function InferenceResult(linfo::MethodInstance, given_argtypes = nothing)
argtypes, overridden_by_const = matching_cache_argtypes(linfo, given_argtypes)
return new(linfo, argtypes, overridden_by_const, Any, nothing)
return new(linfo, argtypes, overridden_by_const, Any, nothing, WorldRange())
end
end

Expand Down

0 comments on commit 47130c5

Please sign in to comment.