Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inference: avoid inferring unreachable code methods #51317

Merged
merged 2 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 35 additions & 22 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,10 @@ function abstract_call_method(interp::AbstractInterpreter,
return MethodCallResult(Any, false, false, nothing, Effects())
end
sigtuple = unwrap_unionall(sig)
sigtuple isa DataType || return MethodCallResult(Any, false, false, nothing, Effects())
sigtuple isa DataType ||
return MethodCallResult(Any, false, false, nothing, Effects())
all(@nospecialize(x) -> valid_as_lattice(unwrapva(x), true), sigtuple.parameters) ||
return MethodCallResult(Union{}, false, false, nothing, EFFECTS_THROWS) # catch bad type intersections early

if is_nospecializeinfer(method)
sig = get_nospecializeinfer_sig(method, sig, sparams)
Expand Down Expand Up @@ -1365,25 +1368,35 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
end
if isa(tti, Union)
utis = uniontypes(tti)
if any(@nospecialize(t) -> !isa(t, DataType) || !(t <: Tuple) || !isknownlength(t), utis)
return AbstractIterationResult(Any[Vararg{Any}], nothing, Effects())
end
ltp = length((utis[1]::DataType).parameters)
for t in utis
if length((t::DataType).parameters) != ltp
return AbstractIterationResult(Any[Vararg{Any}], nothing)
# refine the Union to remove elements that are not valid tags for objects
filter!(@nospecialize(x) -> valid_as_lattice(x, true), utis)
if length(utis) == 0
return AbstractIterationResult(Any[], nothing) # oops, this statement was actually unreachable
elseif length(utis) == 1
tti = utis[1]
tti0 = rewrap_unionall(tti, tti0)
else
if any(@nospecialize(t) -> !isa(t, DataType) || !(t <: Tuple) || !isknownlength(t), utis)
return AbstractIterationResult(Any[Vararg{Any}], nothing, Effects())
end
end
result = Any[ Union{} for _ in 1:ltp ]
for t in utis
tps = (t::DataType).parameters
_all(valid_as_lattice, tps) || continue
for j in 1:ltp
result[j] = tmerge(result[j], rewrap_unionall(tps[j], tti0))
ltp = length((utis[1]::DataType).parameters)
for t in utis
if length((t::DataType).parameters) != ltp
return AbstractIterationResult(Any[Vararg{Any}], nothing)
end
end
result = Any[ Union{} for _ in 1:ltp ]
for t in utis
tps = (t::DataType).parameters
for j in 1:ltp
@assert valid_as_lattice(tps[j], true)
result[j] = tmerge(result[j], rewrap_unionall(tps[j], tti0))
end
end
return AbstractIterationResult(result, nothing)
end
return AbstractIterationResult(result, nothing)
elseif tti0 <: Tuple
end
if tti0 <: Tuple
if isa(tti0, DataType)
return AbstractIterationResult(Any[ p for p in tti0.parameters ], nothing)
elseif !isa(tti, DataType)
Expand Down Expand Up @@ -1647,7 +1660,7 @@ end
return isa_condition(xt, ty, max_union_splitting)
end
@inline function isa_condition(@nospecialize(xt), @nospecialize(ty), max_union_splitting::Int)
tty_ub, isexact_tty = instanceof_tfunc(ty)
tty_ub, isexact_tty = instanceof_tfunc(ty, true)
tty = widenconst(xt)
if isexact_tty && !isa(tty_ub, TypeVar)
tty_lb = tty_ub # TODO: this would be wrong if !isexact_tty, but instanceof_tfunc doesn't preserve this info
Expand All @@ -1657,7 +1670,7 @@ end
# `typeintersect` may be unable narrow down `Type`-type
thentype = tty_ub
end
valid_as_lattice(thentype) || (thentype = Bottom)
valid_as_lattice(thentype, true) || (thentype = Bottom)
elsetype = typesubtract(tty, tty_lb, max_union_splitting)
return ConditionalTypes(thentype, elsetype)
end
Expand Down Expand Up @@ -1903,7 +1916,7 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
ft′ = argtype_by_index(argtypes, 2)
ft = widenconst(ft′)
ft === Bottom && return CallMeta(Bottom, EFFECTS_THROWS, NoCallInfo())
(types, isexact, isconcrete, istype) = instanceof_tfunc(argtype_by_index(argtypes, 3))
(types, isexact, isconcrete, istype) = instanceof_tfunc(argtype_by_index(argtypes, 3), false)
isexact || return CallMeta(Any, Effects(), NoCallInfo())
unwrapped = unwrap_unionall(types)
if types === Bottom || !(unwrapped isa DataType) || unwrapped.name !== Tuple.name
Expand Down Expand Up @@ -2322,7 +2335,7 @@ function abstract_eval_statement_expr(interp::AbstractInterpreter, e::Expr, vtyp
(; rt, effects) = abstract_eval_call(interp, e, vtypes, sv)
t = rt
elseif ehead === :new
t, isexact = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv))
t, isexact = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv), true)
ut = unwrap_unionall(t)
consistent = noub = ALWAYS_FALSE
nothrow = false
Expand Down Expand Up @@ -2387,7 +2400,7 @@ function abstract_eval_statement_expr(interp::AbstractInterpreter, e::Expr, vtyp
end
effects = Effects(EFFECTS_TOTAL; consistent, nothrow, noub)
elseif ehead === :splatnew
t, isexact = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv))
t, isexact = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv), true)
nothrow = false # TODO: More precision
if length(e.args) == 2 && isconcretedispatch(t) && !ismutabletype(t)
at = abstract_eval_value(interp, e.args[2], vtypes, sv)
Expand Down
8 changes: 5 additions & 3 deletions base/compiler/abstractlattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,16 +98,18 @@ is_valid_lattice_norec(::InferenceLattice, @nospecialize(elem)) = isa(elem, Limi
"""
tmeet(𝕃::AbstractLattice, a, b::Type)

Compute the lattice meet of lattice elements `a` and `b` over the lattice `𝕃`.
If `𝕃` is `JLTypeLattice`, this is equivalent to type intersection.
Compute the lattice meet of lattice elements `a` and `b` over the lattice `𝕃`,
dropping any results that will not be inhabited at runtime.
If `𝕃` is `JLTypeLattice`, this is equivalent to type intersection plus the
elimination of results that have no concrete subtypes.
Note that currently `b` is restricted to being a type
(interpreted as a lattice element in the `JLTypeLattice` sub-lattice of `𝕃`).
"""
function tmeet end

function tmeet(::JLTypeLattice, @nospecialize(a::Type), @nospecialize(b::Type))
ti = typeintersect(a, b)
valid_as_lattice(ti) || return Bottom
valid_as_lattice(ti, true) || return Bottom
return ti
end

Expand Down
4 changes: 2 additions & 2 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ function new_expr_effect_flags(𝕃ₒ::AbstractLattice, args::Vector{Any}, src:
Targ = args[1]
atyp = argextype(Targ, src)
# `Expr(:new)` of unknown type could raise arbitrary TypeError.
typ, isexact = instanceof_tfunc(atyp)
typ, isexact = instanceof_tfunc(atyp, true)
if !isexact
atyp = unwrap_unionall(widenconst(atyp))
if isType(atyp) && isTypeDataType(atyp.parameters[1])
Expand Down Expand Up @@ -335,7 +335,7 @@ function stmt_effect_flags(𝕃ₒ::AbstractLattice, @nospecialize(stmt), @nospe
elseif head === :new_opaque_closure
length(args) < 4 && return (false, false, false)
typ = argextype(args[1], src)
typ, isexact = instanceof_tfunc(typ)
typ, isexact = instanceof_tfunc(typ, true)
isexact || return (false, false, false)
⊑(𝕃ₒ, typ, Tuple) || return (false, false, false)
rt_lb = argextype(args[2], src)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1200,7 +1200,7 @@ function handle_invoke_call!(todo::Vector{Pair{Int,Any}},
end

function invoke_signature(argtypes::Vector{Any})
ft, argtyps = widenconst(argtypes[2]), instanceof_tfunc(widenconst(argtypes[3]))[1]
ft, argtyps = widenconst(argtypes[2]), instanceof_tfunc(widenconst(argtypes[3]), false)[1]
return rewrap_unionall(Tuple{ft, unwrap_unionall(argtyps).parameters...}, argtyps)
end

Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1756,7 +1756,7 @@ function adce_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
else
if is_known_call(stmt, typeassert, compact) && length(stmt.args) == 3
# nullify safe `typeassert` calls
ty, isexact = instanceof_tfunc(argextype(stmt.args[3], compact))
ty, isexact = instanceof_tfunc(argextype(stmt.args[3], compact), true)
if isexact && ⊑(𝕃ₒ, argextype(stmt.args[2], compact), ty)
compact[idx] = nothing
continue
Expand Down
Loading