diff --git a/base/compiler/typeutils.jl b/base/compiler/typeutils.jl index 546f4e90bd478..ebf148f428383 100644 --- a/base/compiler/typeutils.jl +++ b/base/compiler/typeutils.jl @@ -75,10 +75,29 @@ function typesubtract(@nospecialize(a), @nospecialize(b), MAX_UNION_SPLITTING::I ub = unwrap_unionall(b) if ub isa DataType if a.name === ub.name === Tuple.name && - length(a.parameters) == length(ub.parameters) && - 1 < unionsplitcost(a.parameters) <= MAX_UNION_SPLITTING - ta = switchtupleunion(a) - return typesubtract(Union{ta...}, b, 0) + length(a.parameters) == length(ub.parameters) + if 1 < unionsplitcost(a.parameters) <= MAX_UNION_SPLITTING + ta = switchtupleunion(a) + return typesubtract(Union{ta...}, b, 0) + elseif b isa DataType + # if exactly one element is not bottom after calling typesubtract + # then the result is all of the elements as normal except that one + notbottom = fill(false, length(a.parameters)) + for i = 1:length(notbottom) + ap = a.parameters[i] + bp = b.parameters[i] + notbottom[i] = !(ap <: bp && isnotbrokensubtype(ap, bp)) + end + let i = findfirst(notbottom) + if i !== nothing && findnext(notbottom, i + 1) === nothing + ta = collect(a.parameters) + ap = a.parameters[i] + bp = b.parameters[i] + ta[i] = typesubtract(ap, bp, min(2, MAX_UNION_SPLITTING)) + return Tuple{ta...} + end + end + end end end end diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 586af7551ff45..9f45de723acb2 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -2630,8 +2630,20 @@ end f() = _foldl_iter(step, (Missing[],), [0.0], 1) end -@test Core.Compiler.typesubtract(Tuple{Union{Int,Char}}, Tuple{Char}, 1) == Tuple{Union{Int,Char}} +@test Core.Compiler.typesubtract(Tuple{Union{Int,Char}}, Tuple{Char}, 0) == Tuple{Int} +@test Core.Compiler.typesubtract(Tuple{Union{Int,Char}}, Tuple{Char}, 1) == Tuple{Int} @test Core.Compiler.typesubtract(Tuple{Union{Int,Char}}, Tuple{Char}, 2) == Tuple{Int} +@test Core.Compiler.typesubtract(NTuple{3, Union{Int, Char}}, Tuple{Char, Any, Any}, 0) == + Tuple{Int, Union{Char, Int}, Union{Char, Int}} +@test Core.Compiler.typesubtract(NTuple{3, Union{Int, Char}}, Tuple{Char, Any, Any}, 10) == + Union{Tuple{Int, Char, Char}, Tuple{Int, Char, Int}, Tuple{Int, Int, Char}, Tuple{Int, Int, Int}} +@test Core.Compiler.typesubtract(NTuple{3, Union{Int, Char}}, NTuple{3, Char}, 0) == + NTuple{3, Union{Int, Char}} +@test Core.Compiler.typesubtract(NTuple{3, Union{Int, Char}}, NTuple{3, Char}, 10) == + Union{Tuple{Char, Char, Int}, Tuple{Char, Int, Char}, Tuple{Char, Int, Int}, Tuple{Int, Char, Char}, + Tuple{Int, Char, Int}, Tuple{Int, Int, Char}, Tuple{Int, Int, Int}} + + @test Base.return_types(Issue35566.f) == [Val{:expected}] # constant prop through keyword arguments