Skip to content

Commit

Permalink
inference: make Limited tracking part of the type lattice (JuliaLang#…
Browse files Browse the repository at this point in the history
…39116)

This helps refine our knowledge of the `[limited]` flag setting, which
previously would always exclude a result from the cache when hitting a
cycle. However, we really only need to exclude a result if the result
might be dependent on that flag setting. That makes this formally part
of the lattice, though can be annoying to work with yet another wrapper,
so we try to add/remove it late/early to propagate it when necessary.
  • Loading branch information
vtjnash authored and antoine-levitt committed May 9, 2021
1 parent 5ab0a43 commit 7d5544d
Show file tree
Hide file tree
Showing 6 changed files with 235 additions and 106 deletions.
59 changes: 51 additions & 8 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ const _REF_NAME = Ref.body.name
# logic #
#########

# see if the inference result might affect the final answer
call_result_unused(frame::InferenceState, pc::LineNum=frame.currpc) =
isexpr(frame.src.code[frame.currpc], :call) && isempty(frame.ssavalue_uses[pc])
# See if the inference result of the current statement's result value might affect
# the final answer for the method (aside from optimization potential and exceptions).
# To do that, we need to check both for slot assignment and SSA usage.
call_result_unused(frame::InferenceState) =
isexpr(frame.src.code[frame.currpc], :call) && isempty(frame.ssavalue_uses[frame.currpc])

# check if this return type is improvable (i.e. whether it's possible that with
# more information, we might get a more precise type)
Expand Down Expand Up @@ -192,6 +194,16 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
end
#print("=> ", rettype, "\n")
if rettype isa LimitedAccuracy
union!(sv.pclimitations, rettype.causes)
rettype = rettype.typ
end
if !isempty(sv.pclimitations) # remove self, if present
delete!(sv.pclimitations, sv)
for caller in sv.callers_in_cycle
delete!(sv.pclimitations, caller)
end
end
return CallMeta(rettype, info)
end

Expand Down Expand Up @@ -313,7 +325,6 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp
inf_result = InferenceResult(mi, argtypes)
frame = InferenceState(inf_result, #=cache=#false, interp)
frame === nothing && return Any # this is probably a bad generated function (unsound), but just ignore it
frame.limited = true
frame.parent = sv
push!(inf_cache, inf_result)
typeinf(interp, frame) || return Any
Expand Down Expand Up @@ -394,7 +405,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
parent = parent::InferenceState
parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match
parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing}
if (parent.cached || parent.limited) && parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
if (parent.cached || parent.parent !== nothing) && parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
topmost = infstate
edgecycle = true
end
Expand Down Expand Up @@ -443,7 +454,8 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
# (non-typically, this means that we lose the ability to detect a guaranteed StackOverflow in some cases)
return Any, true, nothing
end
poison_callstack(sv, topmost::InferenceState, true)
topmost = topmost::InferenceState
poison_callstack(sv, topmost.parent === nothing ? topmost : topmost.parent)
sig = newsig
sparams = svec()
end
Expand Down Expand Up @@ -1124,7 +1136,12 @@ function abstract_eval_value(interp::AbstractInterpreter, @nospecialize(e), vtyp
if isa(e, Expr)
return abstract_eval_value_expr(interp, e, vtypes, sv)
else
return abstract_eval_special_value(interp, e, vtypes, sv)
typ = abstract_eval_special_value(interp, e, vtypes, sv)
if typ isa LimitedAccuracy
union!(sv.pclimitations, typ.causes)
typ = typ.typ
end
return typ
end
end

Expand Down Expand Up @@ -1247,13 +1264,21 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
end
end
else
return abstract_eval_value_expr(interp, e, vtypes, sv)
t = abstract_eval_value_expr(interp, e, vtypes, sv)
end
@assert !isa(t, TypeVar)
if isa(t, DataType) && isdefined(t, :instance)
# replace singleton types with their equivalent Const object
t = Const(t.instance)
end
if !isempty(sv.pclimitations)
if t isa Const || t === Union{}
empty!(sv.pclimitations)
else
t = LimitedAccuracy(t, sv.pclimitations)
sv.pclimitations = IdSet{InferenceState}()
end
end
return t
end

Expand Down Expand Up @@ -1308,10 +1333,18 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
elseif isa(stmt, GotoIfNot)
condt = abstract_eval_value(interp, stmt.cond, s[pc], frame)
if condt === Bottom
empty!(frame.pclimitations)
break
end
condval = maybe_extract_const_bool(condt)
l = stmt.dest::Int
if !isempty(frame.pclimitations)
# we can't model the possible effect of control
# dependencies on the return value, so we propagate it
# directly to all the return values (unless we error first)
condval isa Bool || union!(frame.limitations, frame.pclimitations)
empty!(frame.pclimitations)
end
# constant conditions
if condval === true
elseif condval === false
Expand Down Expand Up @@ -1346,6 +1379,14 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
# and is valid inter-procedurally
rt = widenconst(rt)
end
# copy limitations to return value
if !isempty(frame.pclimitations)
union!(frame.limitations, frame.pclimitations)
empty!(frame.pclimitations)
end
if !isempty(frame.limitations)
rt = LimitedAccuracy(rt, copy(frame.limitations))
end
if tchanged(rt, frame.bestguess)
# new (wider) return type for frame
frame.bestguess = tmerge(frame.bestguess, rt)
Expand Down Expand Up @@ -1420,6 +1461,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
end
end

@assert isempty(frame.pclimitations) "unhandled LimitedAccuracy"

if t === nothing
# mark other reached expressions as `Any` to indicate they don't throw
frame.src.ssavaluetypes[pc] = Any
Expand Down
30 changes: 4 additions & 26 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ mutable struct InferenceState
slottypes::Vector{Any}
mod::Module
currpc::LineNum
pclimitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
limitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on return

# info on the state of inference and the linfo
src::CodeInfo
Expand Down Expand Up @@ -39,7 +41,6 @@ mutable struct InferenceState

# TODO: move these to InferenceResult / Params?
cached::Bool
limited::Bool
inferred::Bool
dont_work_on_me::Bool

Expand Down Expand Up @@ -105,6 +106,7 @@ mutable struct InferenceState
frame = new(
InferenceParams(interp), result, linfo,
sp, slottypes, inmodule, 0,
IdSet{InferenceState}(), IdSet{InferenceState}(),
src, get_world_counter(interp), valid_worlds,
nargs, s_types, s_edges, stmt_info,
Union{}, W, 1, n,
Expand All @@ -113,7 +115,7 @@ mutable struct InferenceState
Vector{Tuple{InferenceState,LineNum}}(), # cycle_backedges
Vector{InferenceState}(), # callers_in_cycle
#=parent=#nothing,
cached, false, false, false,
cached, false, false,
CachedMethodTable(method_table(interp)),
interp)
result.result = frame
Expand Down Expand Up @@ -265,37 +267,13 @@ function add_mt_backedge!(mt::Core.MethodTable, @nospecialize(typ), caller::Infe
nothing
end

function poison_callstack(infstate::InferenceState, topmost::InferenceState, poison_topmost::Bool)
poison_topmost && (topmost = topmost.parent)
while !(infstate === topmost)
if call_result_unused(infstate)
# If we won't propagate the result any further (since it's typically unused),
# it's OK that we keep and cache the "limited" result in the parents
# (non-typically, this means that we lose the ability to detect a guaranteed StackOverflow in some cases)
# TODO: we might be able to halt progress much more strongly here,
# since now we know we won't be able to keep anything much that we learned.
# We were mainly only here to compute the calling convention return type,
# but in most situations now, we are unlikely to be able to use that information.
break
end
infstate.limited = true
for infstate_cycle in infstate.callers_in_cycle
infstate_cycle.limited = true
end
infstate = infstate.parent
infstate === nothing && return
end
end

function print_callstack(sv::InferenceState)
while sv !== nothing
print(sv.linfo)
sv.limited && print(" [limited]")
!sv.cached && print(" [uncached]")
println()
for cycle in sv.callers_in_cycle
print(' ', cycle.linfo)
cycle.limited && print(" [limited]")
println()
end
sv = sv.parent
Expand Down
4 changes: 4 additions & 0 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1620,10 +1620,14 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, s
# output was computed to be constant
return Const(typeof(rt.val))
else
inaccurate = nothing
rt isa LimitedAccuracy && (inaccurate = rt.causes; rt = rt.typ)
rt = widenconst(rt)
if hasuniquerep(rt) || rt === Bottom
# output type was known for certain
return Const(rt)
elseif inaccurate !== nothing
return LimitedAccuracy(Type{<:rt}, inaccurate)
elseif (isa(tt, Const) || isconstType(tt)) &&
(isa(aft, Const) || isconstType(aft))
# input arguments were known for certain
Expand Down
Loading

0 comments on commit 7d5544d

Please sign in to comment.