Skip to content

Commit

Permalink
inference: avoid inferring unreachable code methods (#51317)
Browse files Browse the repository at this point in the history
  • Loading branch information
vtjnash authored Sep 27, 2023
1 parent 8de80bd commit 0a82b71
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 94 deletions.
57 changes: 35 additions & 22 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,10 @@ function abstract_call_method(interp::AbstractInterpreter,
return MethodCallResult(Any, false, false, nothing, Effects())
end
sigtuple = unwrap_unionall(sig)
sigtuple isa DataType || return MethodCallResult(Any, false, false, nothing, Effects())
sigtuple isa DataType ||
return MethodCallResult(Any, false, false, nothing, Effects())
all(@nospecialize(x) -> valid_as_lattice(unwrapva(x), true), sigtuple.parameters) ||
return MethodCallResult(Union{}, false, false, nothing, EFFECTS_THROWS) # catch bad type intersections early

if is_nospecializeinfer(method)
sig = get_nospecializeinfer_sig(method, sig, sparams)
Expand Down Expand Up @@ -1365,25 +1368,35 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
end
if isa(tti, Union)
utis = uniontypes(tti)
if any(@nospecialize(t) -> !isa(t, DataType) || !(t <: Tuple) || !isknownlength(t), utis)
return AbstractIterationResult(Any[Vararg{Any}], nothing, Effects())
end
ltp = length((utis[1]::DataType).parameters)
for t in utis
if length((t::DataType).parameters) != ltp
return AbstractIterationResult(Any[Vararg{Any}], nothing)
# refine the Union to remove elements that are not valid tags for objects
filter!(@nospecialize(x) -> valid_as_lattice(x, true), utis)
if length(utis) == 0
return AbstractIterationResult(Any[], nothing) # oops, this statement was actually unreachable
elseif length(utis) == 1
tti = utis[1]
tti0 = rewrap_unionall(tti, tti0)
else
if any(@nospecialize(t) -> !isa(t, DataType) || !(t <: Tuple) || !isknownlength(t), utis)
return AbstractIterationResult(Any[Vararg{Any}], nothing, Effects())
end
end
result = Any[ Union{} for _ in 1:ltp ]
for t in utis
tps = (t::DataType).parameters
_all(valid_as_lattice, tps) || continue
for j in 1:ltp
result[j] = tmerge(result[j], rewrap_unionall(tps[j], tti0))
ltp = length((utis[1]::DataType).parameters)
for t in utis
if length((t::DataType).parameters) != ltp
return AbstractIterationResult(Any[Vararg{Any}], nothing)
end
end
result = Any[ Union{} for _ in 1:ltp ]
for t in utis
tps = (t::DataType).parameters
for j in 1:ltp
@assert valid_as_lattice(tps[j], true)
result[j] = tmerge(result[j], rewrap_unionall(tps[j], tti0))
end
end
return AbstractIterationResult(result, nothing)
end
return AbstractIterationResult(result, nothing)
elseif tti0 <: Tuple
end
if tti0 <: Tuple
if isa(tti0, DataType)
return AbstractIterationResult(Any[ p for p in tti0.parameters ], nothing)
elseif !isa(tti, DataType)
Expand Down Expand Up @@ -1647,7 +1660,7 @@ end
return isa_condition(xt, ty, max_union_splitting)
end
@inline function isa_condition(@nospecialize(xt), @nospecialize(ty), max_union_splitting::Int)
tty_ub, isexact_tty = instanceof_tfunc(ty)
tty_ub, isexact_tty = instanceof_tfunc(ty, true)
tty = widenconst(xt)
if isexact_tty && !isa(tty_ub, TypeVar)
tty_lb = tty_ub # TODO: this would be wrong if !isexact_tty, but instanceof_tfunc doesn't preserve this info
Expand All @@ -1657,7 +1670,7 @@ end
# `typeintersect` may be unable narrow down `Type`-type
thentype = tty_ub
end
valid_as_lattice(thentype) || (thentype = Bottom)
valid_as_lattice(thentype, true) || (thentype = Bottom)
elsetype = typesubtract(tty, tty_lb, max_union_splitting)
return ConditionalTypes(thentype, elsetype)
end
Expand Down Expand Up @@ -1903,7 +1916,7 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
ft′ = argtype_by_index(argtypes, 2)
ft = widenconst(ft′)
ft === Bottom && return CallMeta(Bottom, EFFECTS_THROWS, NoCallInfo())
(types, isexact, isconcrete, istype) = instanceof_tfunc(argtype_by_index(argtypes, 3))
(types, isexact, isconcrete, istype) = instanceof_tfunc(argtype_by_index(argtypes, 3), false)
isexact || return CallMeta(Any, Effects(), NoCallInfo())
unwrapped = unwrap_unionall(types)
if types === Bottom || !(unwrapped isa DataType) || unwrapped.name !== Tuple.name
Expand Down Expand Up @@ -2322,7 +2335,7 @@ function abstract_eval_statement_expr(interp::AbstractInterpreter, e::Expr, vtyp
(; rt, effects) = abstract_eval_call(interp, e, vtypes, sv)
t = rt
elseif ehead === :new
t, isexact = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv))
t, isexact = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv), true)
ut = unwrap_unionall(t)
consistent = noub = ALWAYS_FALSE
nothrow = false
Expand Down Expand Up @@ -2387,7 +2400,7 @@ function abstract_eval_statement_expr(interp::AbstractInterpreter, e::Expr, vtyp
end
effects = Effects(EFFECTS_TOTAL; consistent, nothrow, noub)
elseif ehead === :splatnew
t, isexact = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv))
t, isexact = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv), true)
nothrow = false # TODO: More precision
if length(e.args) == 2 && isconcretedispatch(t) && !ismutabletype(t)
at = abstract_eval_value(interp, e.args[2], vtypes, sv)
Expand Down
8 changes: 5 additions & 3 deletions base/compiler/abstractlattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,16 +98,18 @@ is_valid_lattice_norec(::InferenceLattice, @nospecialize(elem)) = isa(elem, Limi
"""
tmeet(𝕃::AbstractLattice, a, b::Type)
Compute the lattice meet of lattice elements `a` and `b` over the lattice `𝕃`.
If `𝕃` is `JLTypeLattice`, this is equivalent to type intersection.
Compute the lattice meet of lattice elements `a` and `b` over the lattice `𝕃`,
dropping any results that will not be inhabited at runtime.
If `𝕃` is `JLTypeLattice`, this is equivalent to type intersection plus the
elimination of results that have no concrete subtypes.
Note that currently `b` is restricted to being a type
(interpreted as a lattice element in the `JLTypeLattice` sub-lattice of `𝕃`).
"""
function tmeet end

function tmeet(::JLTypeLattice, @nospecialize(a::Type), @nospecialize(b::Type))
ti = typeintersect(a, b)
valid_as_lattice(ti) || return Bottom
valid_as_lattice(ti, true) || return Bottom
return ti
end

Expand Down
4 changes: 2 additions & 2 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ function new_expr_effect_flags(𝕃ₒ::AbstractLattice, args::Vector{Any}, src:
Targ = args[1]
atyp = argextype(Targ, src)
# `Expr(:new)` of unknown type could raise arbitrary TypeError.
typ, isexact = instanceof_tfunc(atyp)
typ, isexact = instanceof_tfunc(atyp, true)
if !isexact
atyp = unwrap_unionall(widenconst(atyp))
if isType(atyp) && isTypeDataType(atyp.parameters[1])
Expand Down Expand Up @@ -335,7 +335,7 @@ function stmt_effect_flags(𝕃ₒ::AbstractLattice, @nospecialize(stmt), @nospe
elseif head === :new_opaque_closure
length(args) < 4 && return (false, false, false)
typ = argextype(args[1], src)
typ, isexact = instanceof_tfunc(typ)
typ, isexact = instanceof_tfunc(typ, true)
isexact || return (false, false, false)
(𝕃ₒ, typ, Tuple) || return (false, false, false)
rt_lb = argextype(args[2], src)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1200,7 +1200,7 @@ function handle_invoke_call!(todo::Vector{Pair{Int,Any}},
end

function invoke_signature(argtypes::Vector{Any})
ft, argtyps = widenconst(argtypes[2]), instanceof_tfunc(widenconst(argtypes[3]))[1]
ft, argtyps = widenconst(argtypes[2]), instanceof_tfunc(widenconst(argtypes[3]), false)[1]
return rewrap_unionall(Tuple{ft, unwrap_unionall(argtyps).parameters...}, argtyps)
end

Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1756,7 +1756,7 @@ function adce_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
else
if is_known_call(stmt, typeassert, compact) && length(stmt.args) == 3
# nullify safe `typeassert` calls
ty, isexact = instanceof_tfunc(argextype(stmt.args[3], compact))
ty, isexact = instanceof_tfunc(argextype(stmt.args[3], compact), true)
if isexact && (𝕃ₒ, argextype(stmt.args[2], compact), ty)
compact[idx] = nothing
continue
Expand Down
Loading

0 comments on commit 0a82b71

Please sign in to comment.