Skip to content

Commit

Permalink
Merge pull request #37714 from JuliaLang/jn/35600-again
Browse files Browse the repository at this point in the history
Improve typesubtract for tuples (repeat #35600)
  • Loading branch information
vtjnash authored Sep 24, 2020
2 parents b55e250 + 47d1f62 commit 3b55dae
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 16 deletions.
8 changes: 4 additions & 4 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
while valtype !== Any
stateordonet = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], sv).rt
stateordonet = widenconst(stateordonet)
nounion = typesubtract(stateordonet, Nothing)
nounion = typesubtract(stateordonet, Nothing, 0)
if !isa(nounion, DataType) || !(nounion <: Tuple) || isvatuple(nounion) || length(nounion.parameters) != 2
valtype = Any
break
Expand Down Expand Up @@ -814,7 +814,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U
tty_lb = tty_ub # TODO: this would be wrong if !isexact_tty, but instanceof_tfunc doesn't preserve this info
if !has_free_typevars(tty_lb) && !has_free_typevars(tty_ub)
ifty = typeintersect(aty, tty_ub)
elty = typesubtract(aty, tty_lb)
elty = typesubtract(aty, tty_lb, InferenceParams(interp).MAX_UNION_SPLITTING)
return Conditional(a, ifty, elty)
end
end
Expand All @@ -831,7 +831,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U
elseif rt === Const(true)
bty = Union{}
elseif bty isa Type && isdefined(typeof(aty.val), :instance) # can only widen a if it is a singleton
bty = typesubtract(bty, typeof(aty.val))
bty = typesubtract(bty, typeof(aty.val), InferenceParams(interp).MAX_UNION_SPLITTING)
end
return Conditional(b, aty, bty)
end
Expand All @@ -841,7 +841,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U
elseif rt === Const(true)
aty = Union{}
elseif aty isa Type && isdefined(typeof(bty.val), :instance) # same for b
aty = typesubtract(aty, typeof(bty.val))
aty = typesubtract(aty, typeof(bty.val), InferenceParams(interp).MAX_UNION_SPLITTING)
end
return Conditional(a, bty, aty)
end
Expand Down
40 changes: 35 additions & 5 deletions base/compiler/typeutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ end
# some of these queries, this check can be used to somewhat protect against making incorrect
# decisions based on incorrect subtyping. Note that this check, itself, is broken for
# certain combinations of `a` and `b` where one/both isa/are `Union`/`UnionAll` type(s)s.
isnotbrokensubtype(@nospecialize(a), @nospecialize(b)) = (!iskindtype(b) || !isType(a) || hasuniquerep(a.parameters[1]))
isnotbrokensubtype(@nospecialize(a), @nospecialize(b)) = (!iskindtype(b) || !isType(a) || hasuniquerep(a.parameters[1]) || b <: a)

argtypes_to_type(argtypes::Array{Any,1}) = Tuple{anymap(widenconst, argtypes)...}

Expand All @@ -63,13 +63,43 @@ end

# return an upper-bound on type `a` with type `b` removed
# such that `return <: a` && `Union{return, b} == Union{a, b}`
function typesubtract(@nospecialize(a), @nospecialize(b))
function typesubtract(@nospecialize(a), @nospecialize(b), MAX_UNION_SPLITTING::Int)
if a <: b && isnotbrokensubtype(a, b)
return Bottom
end
if isa(a, Union)
return Union{typesubtract(a.a, b),
typesubtract(a.b, b)}
ua = unwrap_unionall(a)
if isa(ua, Union)
return Union{typesubtract(rewrap_unionall(ua.a, a), b, MAX_UNION_SPLITTING),
typesubtract(rewrap_unionall(ua.b, a), b, MAX_UNION_SPLITTING)}
elseif a isa DataType
ub = unwrap_unionall(b)
if ub isa DataType
if a.name === ub.name === Tuple.name &&
length(a.parameters) == length(ub.parameters)
if 1 < unionsplitcost(a.parameters) <= MAX_UNION_SPLITTING
ta = switchtupleunion(a)
return typesubtract(Union{ta...}, b, 0)
elseif b isa DataType
# if exactly one element is not bottom after calling typesubtract
# then the result is all of the elements as normal except that one
notbottom = fill(false, length(a.parameters))
for i = 1:length(notbottom)
ap = a.parameters[i]
bp = b.parameters[i]
notbottom[i] = !(ap <: bp && isnotbrokensubtype(ap, bp))
end
let i = findfirst(notbottom)
if i !== nothing && findnext(notbottom, i + 1) === nothing
ta = collect(a.parameters)
ap = a.parameters[i]
bp = b.parameters[i]
ta[i] = typesubtract(ap, bp, min(2, MAX_UNION_SPLITTING))
return Tuple{ta...}
end
end
end
end
end
end
return a # TODO: improve this bound?
end
Expand Down
2 changes: 1 addition & 1 deletion base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1266,7 +1266,7 @@ else
# fixpoint.
approx_iter_type(itrT::Type) = _approx_iter_type(itrT, Base._return_type(iterate, Tuple{itrT}))
# Not actually called, just passed to return type to avoid
# having to typesubtract
# having to typesplit on Nothing
function doiterate(itr, valstate::Union{Nothing, Tuple{Any, Any}})
valstate === nothing && return nothing
val, st = valstate
Expand Down
2 changes: 1 addition & 1 deletion base/missing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ Any
!!! compat "Julia 1.3"
This function is exported as of Julia 1.3.
"""
nonmissingtype(::Type{T}) where {T} = Core.Compiler.typesubtract(T, Missing)
nonmissingtype(::Type{T}) where {T} = typesplit(T, Missing)

function nonmissingtype_checked(T::Type)
R = nonmissingtype(T)
Expand Down
19 changes: 18 additions & 1 deletion base/promotion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,23 @@ function typejoin(@nospecialize(a), @nospecialize(b))
return Any
end

# return an upper-bound on type `a` with type `b` removed
# such that `return <: a` && `Union{return, b} == Union{a, b}`
# WARNING: this is wrong for some objects for which subtyping is broken
# (Core.Compiler.isnotbrokensubtype), use only simple types for `b`
function typesplit(@nospecialize(a), @nospecialize(b))
@_pure_meta
if a <: b
return Bottom
end
if isa(a, Union)
return Union{typesplit(a.a, b),
typesplit(a.b, b)}
end
return a
end


"""
promote_typejoin(T, S)
Expand All @@ -132,7 +149,7 @@ function promote_typejoin(@nospecialize(a), @nospecialize(b))
c = typejoin(_promote_typesubtract(a), _promote_typesubtract(b))
return Union{a, b, c}::Type
end
_promote_typesubtract(@nospecialize(a)) = Core.Compiler.typesubtract(a, Union{Nothing, Missing})
_promote_typesubtract(@nospecialize(a)) = typesplit(a, Union{Nothing, Missing})


# Returns length, isfixed
Expand Down
2 changes: 1 addition & 1 deletion base/set.jl
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ promote_valuetype(x::Pair{K, V}, y::Pair...) where {K, V} =
# Subtract singleton types which are going to be replaced
function subtract_singletontype(::Type{T}, x::Pair{K}) where {T, K}
if issingletontype(K)
Core.Compiler.typesubtract(T, K)
typesplit(T, K)
else
T
end
Expand Down
2 changes: 1 addition & 1 deletion base/some.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Some(::Type{T}) where {T} = Some{Type{T}}(T)

promote_rule(::Type{Some{T}}, ::Type{Some{S}}) where {T, S<:T} = Some{T}

nonnothingtype(::Type{T}) where {T} = Core.Compiler.typesubtract(T, Nothing)
nonnothingtype(::Type{T}) where {T} = typesplit(T, Nothing)
promote_rule(T::Type{Nothing}, S::Type) = Union{S, Nothing}
function promote_rule(T::Type{>:Nothing}, S::Type)
R = nonnothingtype(T)
Expand Down
4 changes: 2 additions & 2 deletions stdlib/Test/src/Test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import Distributed: myid
using Random
using Random: AbstractRNG, default_rng
using InteractiveUtils: gen_call_with_extracted_types
using Core.Compiler: typesubtract
using Base: typesplit

const DISPLAY_FAILED = (
:isequal,
Expand Down Expand Up @@ -1393,7 +1393,7 @@ function _inferred(ex, mod, allow = :(Union{}))
end)
@assert length(inftypes) == 1
rettype = result isa Type ? Type{result} : typeof(result)
rettype <: allow || rettype == typesubtract(inftypes[1], allow) || error("return type $rettype does not match inferred return type $(inftypes[1])")
rettype <: allow || rettype == typesplit(inftypes[1], allow) || error("return type $rettype does not match inferred return type $(inftypes[1])")
result
end
end)
Expand Down
54 changes: 54 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2592,6 +2592,60 @@ end

@test map(>:, [Int], [Int]) == [true]

# issue 35566
module Issue35566
function step(acc, x)
xs, = acc
y = x > 0.0 ? x : missing
if y isa eltype(xs)
ys = push!(xs, y)
else
ys = vcat(xs, [y])
end
return (ys,)
end

function probe(y)
if y isa Tuple{Vector{Missing}}
return Val(:missing)
else
return Val(:expected)
end
end

function _foldl_iter(rf, val::T, iter, state) where {T}
while true
ret = iterate(iter, state)
ret === nothing && break
x, state = ret
y = rf(val, x)
if y isa T
val = y
else
return probe(y)
end
end
return Val(:expected)
end

f() = _foldl_iter(step, (Missing[],), [0.0], 1)
end
@test Core.Compiler.typesubtract(Tuple{Union{Int,Char}}, Tuple{Char}, 0) == Tuple{Int}
@test Core.Compiler.typesubtract(Tuple{Union{Int,Char}}, Tuple{Char}, 1) == Tuple{Int}
@test Core.Compiler.typesubtract(Tuple{Union{Int,Char}}, Tuple{Char}, 2) == Tuple{Int}
@test Core.Compiler.typesubtract(NTuple{3, Union{Int, Char}}, Tuple{Char, Any, Any}, 0) ==
Tuple{Int, Union{Char, Int}, Union{Char, Int}}
@test Core.Compiler.typesubtract(NTuple{3, Union{Int, Char}}, Tuple{Char, Any, Any}, 10) ==
Union{Tuple{Int, Char, Char}, Tuple{Int, Char, Int}, Tuple{Int, Int, Char}, Tuple{Int, Int, Int}}
@test Core.Compiler.typesubtract(NTuple{3, Union{Int, Char}}, NTuple{3, Char}, 0) ==
NTuple{3, Union{Int, Char}}
@test Core.Compiler.typesubtract(NTuple{3, Union{Int, Char}}, NTuple{3, Char}, 10) ==
Union{Tuple{Char, Char, Int}, Tuple{Char, Int, Char}, Tuple{Char, Int, Int}, Tuple{Int, Char, Char},
Tuple{Int, Char, Int}, Tuple{Int, Int, Char}, Tuple{Int, Int, Int}}


@test Base.return_types(Issue35566.f) == [Val{:expected}]

# constant prop through keyword arguments
_unstable_kw(;x=1,y=2) = x == 1 ? 0 : ""
_use_unstable_kw_1() = _unstable_kw(x = 2)
Expand Down

0 comments on commit 3b55dae

Please sign in to comment.