Skip to content

Commit

Permalink
inference: apply a limit to permitting typesubtract for tuples (from #…
Browse files Browse the repository at this point in the history
  • Loading branch information
vtjnash committed Sep 24, 2020
1 parent 811b3a3 commit 5a56ecd
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 16 deletions.
8 changes: 4 additions & 4 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
while valtype !== Any
stateordonet = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], sv).rt
stateordonet = widenconst(stateordonet)
nounion = typesubtract(stateordonet, Nothing)
nounion = typesubtract(stateordonet, Nothing, 0)
if !isa(nounion, DataType) || !(nounion <: Tuple) || isvatuple(nounion) || length(nounion.parameters) != 2
valtype = Any
break
Expand Down Expand Up @@ -814,7 +814,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U
tty_lb = tty_ub # TODO: this would be wrong if !isexact_tty, but instanceof_tfunc doesn't preserve this info
if !has_free_typevars(tty_lb) && !has_free_typevars(tty_ub)
ifty = typeintersect(aty, tty_ub)
elty = typesubtract(aty, tty_lb)
elty = typesubtract(aty, tty_lb, InferenceParams(interp).MAX_UNION_SPLITTING)
return Conditional(a, ifty, elty)
end
end
Expand All @@ -831,7 +831,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U
elseif rt === Const(true)
bty = Union{}
elseif bty isa Type && isdefined(typeof(aty.val), :instance) # can only widen a if it is a singleton
bty = typesubtract(bty, typeof(aty.val))
bty = typesubtract(bty, typeof(aty.val), InferenceParams(interp).MAX_UNION_SPLITTING)
end
return Conditional(b, aty, bty)
end
Expand All @@ -841,7 +841,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U
elseif rt === Const(true)
aty = Union{}
elseif aty isa Type && isdefined(typeof(bty.val), :instance) # same for b
aty = typesubtract(aty, typeof(bty.val))
aty = typesubtract(aty, typeof(bty.val), InferenceParams(interp).MAX_UNION_SPLITTING)
end
return Conditional(a, bty, aty)
end
Expand Down
20 changes: 11 additions & 9 deletions base/compiler/typeutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,22 @@ end

# return an upper-bound on type `a` with type `b` removed
# such that `return <: a` && `Union{return, b} == Union{a, b}`
function typesubtract(@nospecialize(a), @nospecialize(b))
function typesubtract(@nospecialize(a), @nospecialize(b), MAX_UNION_SPLITTING::Int)
if a <: b && isnotbrokensubtype(a, b)
return Bottom
end
if isa(a, Union)
return Union{typesubtract(a.a, b),
typesubtract(a.b, b)}
ua = unwrap_unionall(a)
if isa(ua, Union)
return Union{typesubtract(rewrap_unionall(ua.a, a), b, MAX_UNION_SPLITTING),
typesubtract(rewrap_unionall(ua.b, a), b, MAX_UNION_SPLITTING)}
elseif a isa DataType
if b isa DataType
if a.name === b.name === Tuple.name && length(a.types) == length(b.types)
ub = unwrap_unionall(b)
if ub isa DataType
if a.name === ub.name === Tuple.name &&
length(a.parameters) == length(ub.parameters) &&
1 < unionsplitcost(a.parameters) <= MAX_UNION_SPLITTING
ta = switchtupleunion(a)
if length(ta) > 1
return typesubtract(Union{ta...}, b)
end
return typesubtract(Union{ta...}, b, 0)
end
end
end
Expand Down
4 changes: 2 additions & 2 deletions stdlib/Test/src/Test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import Distributed: myid
using Random
using Random: AbstractRNG, default_rng
using InteractiveUtils: gen_call_with_extracted_types
using Core.Compiler: typesubtract
using Base: typesplit

const DISPLAY_FAILED = (
:isequal,
Expand Down Expand Up @@ -1393,7 +1393,7 @@ function _inferred(ex, mod, allow = :(Union{}))
end)
@assert length(inftypes) == 1
rettype = result isa Type ? Type{result} : typeof(result)
rettype <: allow || rettype == typesubtract(inftypes[1], allow) || error("return type $rettype does not match inferred return type $(inftypes[1])")
rettype <: allow || rettype == typesplit(inftypes[1], allow) || error("return type $rettype does not match inferred return type $(inftypes[1])")
result
end
end)
Expand Down
3 changes: 2 additions & 1 deletion test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2630,7 +2630,8 @@ end

f() = _foldl_iter(step, (Missing[],), [0.0], 1)
end
@test Core.Compiler.typesubtract(Tuple{Union{Int,Char}}, Tuple{Char}) == Tuple{Int}
@test Core.Compiler.typesubtract(Tuple{Union{Int,Char}}, Tuple{Char}, 1) == Tuple{Union{Int,Char}}
@test Core.Compiler.typesubtract(Tuple{Union{Int,Char}}, Tuple{Char}, 2) == Tuple{Int}
@test Base.return_types(Issue35566.f) == [Val{:expected}]

# constant prop through keyword arguments
Expand Down

0 comments on commit 5a56ecd

Please sign in to comment.