From 06a688227149f0cfe0a37ceb190939c5002c0a54 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Wed, 1 Sep 2021 10:20:44 +0200 Subject: [PATCH] Make return type of map inferrable with heterogeneous arrays (#42046) Inference is not able to detect the element type automatically, but we can do it manually since we know promote_typejoin is used for widening. This is similar to the approach used for `broadcast` at #30485. (cherry picked from commit 49e3aecd5966a2af0b064c0314cd61c1338abc00) --- base/array.jl | 15 ++++++++----- base/broadcast.jl | 46 +-------------------------------------- base/promotion.jl | 44 +++++++++++++++++++++++++++++++++++++ test/broadcast.jl | 4 ---- test/generic_map_tests.jl | 22 +++++++++++++++++++ test/sets.jl | 1 + 6 files changed, 78 insertions(+), 54 deletions(-) diff --git a/base/array.jl b/base/array.jl index d629064777ce0d..7e30927a71b9c4 100644 --- a/base/array.jl +++ b/base/array.jl @@ -679,10 +679,11 @@ if isdefined(Core, :Compiler) I = esc(itr) return quote if $I isa Generator && ($I).f isa Type - ($I).f + T = ($I).f else - Core.Compiler.return_type(_iterator_upper_bound, Tuple{typeof($I)}) + T = Core.Compiler.return_type(_iterator_upper_bound, Tuple{typeof($I)}) end + promote_typejoin_union(T) end end else @@ -690,7 +691,7 @@ else I = esc(itr) return quote if $I isa Generator && ($I).f isa Type - ($I).f + promote_typejoin_union($I.f) else Any end @@ -715,8 +716,12 @@ function collect(itr::Generator) return _array_for(et, itr.iter, isz) end v1, st = y - arr = _array_for(typeof(v1), itr.iter, isz, shape) - return collect_to_with_first!(arr, v1, itr, st) + dest = _array_for(typeof(v1), itr.iter, isz, shape) + # The typeassert gives inference a helping hand on the element type and dimensionality + # (work-around for #28382) + et′ = et <: Type ? Type : et + RT = dest isa AbstractArray ? AbstractArray{<:et′, ndims(dest)} : Any + collect_to_with_first!(dest, v1, itr, st)::RT end end diff --git a/base/broadcast.jl b/base/broadcast.jl index b34a73041708b0..90479189ffee4d 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -8,7 +8,7 @@ Module containing the broadcasting implementation. module Broadcast using .Base.Cartesian -using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, @pure, +using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, promote_typejoin_union, @pure, _msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias import .Base: copy, copyto!, axes export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, BroadcastFunction @@ -713,50 +713,6 @@ eltypes(t::Tuple{Any}) = Tuple{_broadcast_getindex_eltype(t[1])} eltypes(t::Tuple{Any,Any}) = Tuple{_broadcast_getindex_eltype(t[1]), _broadcast_getindex_eltype(t[2])} eltypes(t::Tuple) = Tuple{_broadcast_getindex_eltype(t[1]), eltypes(tail(t)).types...} -function promote_typejoin_union(::Type{T}) where T - if T === Union{} - return Union{} - elseif T isa UnionAll - return Any # TODO: compute more precise bounds - elseif T isa Union - return promote_typejoin(promote_typejoin_union(T.a), promote_typejoin_union(T.b)) - elseif T <: Tuple - return typejoin_union_tuple(T) - else - return T - end -end - -@pure function typejoin_union_tuple(T::Type) - u = Base.unwrap_unionall(T) - u isa Union && return typejoin( - typejoin_union_tuple(Base.rewrap_unionall(u.a, T)), - typejoin_union_tuple(Base.rewrap_unionall(u.b, T))) - p = (u::DataType).parameters - lr = length(p)::Int - if lr == 0 - return Tuple{} - end - c = Vector{Any}(undef, lr) - for i = 1:lr - pi = p[i] - U = Core.Compiler.unwrapva(pi) - if U === Union{} - ci = Union{} - elseif U isa Union - ci = typejoin(U.a, U.b) - else - ci = U - end - if i == lr && Core.Compiler.isvarargtype(pi) - c[i] = isdefined(pi, :N) ? Vararg{ci, pi.N} : Vararg{ci} - else - c[i] = ci - end - end - return Base.rewrap_unionall(Tuple{c...}, T) -end - # Inferred eltype of result of broadcast(f, args...) combine_eltypes(f, args::Tuple) = promote_typejoin_union(Base._return_type(f, eltypes(args))) diff --git a/base/promotion.jl b/base/promotion.jl index 43add6cbf5f628..ef29b273f60002 100644 --- a/base/promotion.jl +++ b/base/promotion.jl @@ -161,6 +161,50 @@ function promote_typejoin(@nospecialize(a), @nospecialize(b)) end _promote_typesubtract(@nospecialize(a)) = typesplit(a, Union{Nothing, Missing}) +function promote_typejoin_union(::Type{T}) where T + if T === Union{} + return Union{} + elseif T isa UnionAll + return Any # TODO: compute more precise bounds + elseif T isa Union + return promote_typejoin(promote_typejoin_union(T.a), promote_typejoin_union(T.b)) + elseif T <: Tuple + return typejoin_union_tuple(T) + else + return T + end +end + +function typejoin_union_tuple(T::Type) + @_pure_meta + u = Base.unwrap_unionall(T) + u isa Union && return typejoin( + typejoin_union_tuple(Base.rewrap_unionall(u.a, T)), + typejoin_union_tuple(Base.rewrap_unionall(u.b, T))) + p = (u::DataType).parameters + lr = length(p)::Int + if lr == 0 + return Tuple{} + end + c = Vector{Any}(undef, lr) + for i = 1:lr + pi = p[i] + U = Core.Compiler.unwrapva(pi) + if U === Union{} + ci = Union{} + elseif U isa Union + ci = typejoin(U.a, U.b) + else + ci = U + end + if i == lr && Core.Compiler.isvarargtype(pi) + c[i] = isdefined(pi, :N) ? Vararg{ci, pi.N} : Vararg{ci} + else + c[i] = ci + end + end + return Base.rewrap_unionall(Tuple{c...}, T) +end # Returns length, isfixed function full_va_len(p) diff --git a/test/broadcast.jl b/test/broadcast.jl index 66c215aee92934..329bcc602206b4 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -991,10 +991,6 @@ end @test Core.Compiler.return_type(broadcast, Tuple{typeof(+), Vector{Int}, Vector{Union{Float64, Missing}}}) == Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}} - @test isequal([1, 2] + [3.0, missing], [4.0, missing]) - @test Core.Compiler.return_type(+, Tuple{Vector{Int}, - Vector{Union{Float64, Missing}}}) == - Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}} @test Core.Compiler.return_type(+, Tuple{Vector{Int}, Vector{Union{Float64, Missing}}}) == Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}} diff --git a/test/generic_map_tests.jl b/test/generic_map_tests.jl index 8e77533362fe39..abd9a31946a9ad 100644 --- a/test/generic_map_tests.jl +++ b/test/generic_map_tests.jl @@ -53,6 +53,28 @@ function generic_map_tests(mapf, inplace_mapf=nothing) @test A == map(x->x*x*x, Float64[1:10...]) @test A === B end + + # Issue #28382: inferrability of map with Union eltype + @test isequal(map(+, [1, 2], [3.0, missing]), [4.0, missing]) + @test Core.Compiler.return_type(map, Tuple{typeof(+), Vector{Int}, + Vector{Union{Float64, Missing}}}) == + Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}} + @test isequal(map(tuple, [1, 2], [3.0, missing]), [(1, 3.0), (2, missing)]) + @test Core.Compiler.return_type(map, Tuple{typeof(tuple), Vector{Int}, + Vector{Union{Float64, Missing}}}) == + Vector{<:Tuple{Int, Any}} + # Check that corner cases do not throw an error + @test isequal(map(x -> x === 1 ? nothing : x, [1, 2, missing]), + [nothing, 2, missing]) + @test isequal(map(x -> x === 1 ? nothing : x, Any[1, 2, 3.0, missing]), + [nothing, 2, 3, missing]) + @test map((x,y)->(x==1 ? 1.0 : x, y), [1, 2, 3], ["a", "b", "c"]) == + [(1.0, "a"), (2, "b"), (3, "c")] + @test map(typeof, [iszero, isdigit]) == [typeof(iszero), typeof(isdigit)] + @test map(typeof, [iszero, iszero]) == [typeof(iszero), typeof(iszero)] + @test isequal(map(identity, Vector{<:Union{Int, Missing}}[[1, 2],[missing, 1]]), + [[1, 2],[missing, 1]]) + @test map(x -> x < 0 ? false : x, Int[]) isa Vector{Integer} end function testmap_equivalence(mapf, f, c...) diff --git a/test/sets.jl b/test/sets.jl index 46854dae957c6d..5de38e96b9e319 100644 --- a/test/sets.jl +++ b/test/sets.jl @@ -22,6 +22,7 @@ using Dates @test isa(Set(sin(x) for x = 1:3), Set{Float64}) @test isa(Set(f17741(x) for x = 1:3), Set{Int}) @test isa(Set(f17741(x) for x = -1:1), Set{Integer}) + @test isa(Set(f17741(x) for x = 1:0), Set{Integer}) end let s1 = Set(["foo", "bar"]), s2 = Set(s1) @test s1 == s2