Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make return type of broadcast inferrable with heterogeneous arrays #30485

Merged
merged 14 commits into from
Oct 27, 2020
54 changes: 51 additions & 3 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Module containing the broadcasting implementation.
module Broadcast

using .Base.Cartesian
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin,
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, @pure,
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias
import .Base: copy, copyto!, axes
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, broadcast_preserving_zero_d
Expand Down Expand Up @@ -691,8 +691,52 @@ eltypes(t::Tuple{Any}) = Tuple{_broadcast_getindex_eltype(t[1])}
eltypes(t::Tuple{Any,Any}) = Tuple{_broadcast_getindex_eltype(t[1]), _broadcast_getindex_eltype(t[2])}
eltypes(t::Tuple) = Tuple{_broadcast_getindex_eltype(t[1]), eltypes(tail(t)).types...}

function promote_typejoin_union(::Type{T}) where T
if T === Union{}
return Union{}
nalimilan marked this conversation as resolved.
Show resolved Hide resolved
elseif T isa Union
return promote_typejoin(promote_typejoin_union(T.a), promote_typejoin_union(T.b))
elseif T <: Tuple
return typejoin_union_tuple(T)
else
return T
end
end

@pure function typejoin_union_tuple(T::Type)
u = Base.unwrap_unionall(T)
u isa Union && return typejoin(
typejoin_union_tuple(Base.rewrap_unionall(u.a, T)),
typejoin_union_tuple(Base.rewrap_unionall(u.b, T)))
p = (u::DataType).parameters
lr = length(p)::Int
if lr == 0
return Tuple{}
end
c = Vector{Any}(undef, lr)
for i = 1:lr
pi = p[i]
U = Core.Compiler.unwrapva(pi)
if U === Union{}
ci = Union{}
elseif U isa Union
ci = typejoin(Base.rewrap_unionall(U.a, T), Base.rewrap_unionall(U.b, T))
nalimilan marked this conversation as resolved.
Show resolved Hide resolved
else
ci = U
end
if i == lr && Core.Compiler.isvarargtype(pi)
N = (Base.unwrap_unionall(pi)::DataType).parameters[2]
c[i] = Vararg{ci, N}
nalimilan marked this conversation as resolved.
Show resolved Hide resolved
else
c[i] = ci
end
end
return Base.rewrap_unionall(Tuple{c...}, T)
end

# Inferred eltype of result of broadcast(f, args...)
combine_eltypes(f, args::Tuple) = Base._return_type(f, eltypes(args))
combine_eltypes(f, args::Tuple) =
promote_typejoin_union(Base._return_type(f, eltypes(args)))

## Broadcasting core

Expand Down Expand Up @@ -877,7 +921,11 @@ const NonleafHandlingStyles = Union{DefaultArrayStyle,ArrayConflict}
dest = similar(bc′, typeof(val))
@inbounds dest[I] = val
# Now handle the remaining values
return copyto_nonleaf!(dest, bc′, iter, state, 1)
# The typeassert gives inference a helping hand on the element type and dimensionality
# (work-around for #28382)
ElType′ = ElType <: Type ? DataType : ElType
nalimilan marked this conversation as resolved.
Show resolved Hide resolved
RT = dest isa AbstractArray ? AbstractArray{<:ElType′, ndims(dest)} : Any
return copyto_nonleaf!(dest, bc′, iter, state, 1)::RT
end

## general `copyto!` methods
Expand Down
40 changes: 39 additions & 1 deletion test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ end
let f17314 = x -> x < 0 ? false : x
@test eltype(broadcast(f17314, 1:3)) === Int
@test eltype(broadcast(f17314, -1:1)) === Integer
@test eltype(broadcast(f17314, Int[])) == Union{Bool,Int}
@test eltype(broadcast(f17314, Int[])) === Integer
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a minor change which I think can be considered as a bug fix, in the sense that before this PR the element type when the input is empty will never be observed when the array isn't empty (we can only ever observe Int, Bool or Integer). I could change the PR to preserve the existing behavior if we want (e.g. for backports).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh that's interesting. Yeah, here's the current behaviors:

julia> eltype(broadcast(f17314, Int[]))
Union{Bool, Int64}

julia> eltype(broadcast(f17314, Int[1]))
Int64

julia> eltype(broadcast(f17314, Int[-1]))
Bool

julia> eltype(broadcast(f17314, Int[1,-1]))
Integer

The reason for using inference here in the first place is to preserve the performance in the non-empty case. Adding a fourth possible return type defeats such a purpose, so I'm in support of this change.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly - we either need to change the last to match the first, or the first to match the last.

This PR seems the least breaking, and suitable for v1.x. If we ever wanted to consider the other way around maybe that should be a v2.0 change?

end
let io = IOBuffer()
broadcast(x->print(io,x), 1:5) # broadcast with side effects
Expand Down Expand Up @@ -944,3 +944,41 @@ p = rand(4,4); r = rand(2,4);
p0 = copy(p)
@views @. p[1:2, :] += r
@test p[1:2, :] ≈ p0[1:2, :] + r

@testset "Issue #28382: inferrability of broadcast with Union eltype" begin
@test isequal([1, 2] .+ [3.0, missing], [4.0, missing])
@test_broken Core.Compiler.return_type(broadcast, Tuple{typeof(+), Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Vector{<:Union{Float64, Missing}}
@test Core.Compiler.return_type(broadcast, Tuple{typeof(+), Vector{Int},
Vector{Union{Float64, Missing}}}) ==
AbstractVector{<:Union{Float64, Missing}}
@test isequal([1, 2] + [3.0, missing], [4.0, missing])
@test_broken Core.Compiler.return_type(+, Tuple{Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Vector{<:Union{Float64, Missing}}
@test Core.Compiler.return_type(+, Tuple{Vector{Int},
Vector{Union{Float64, Missing}}}) ==
AbstractVector{<:Union{Float64, Missing}}
@test_broken Core.Compiler.return_type(+, Tuple{Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Vector{<:Union{Float64, Missing}}
@test isequal(tuple.([1, 2], [3.0, missing]), [(1, 3.0), (2, missing)])
@test_broken Core.Compiler.return_type(broadcast, Tuple{typeof(tuple), Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Vector{<:Tuple{Int, Any}}
@test Core.Compiler.return_type(broadcast, Tuple{typeof(tuple), Vector{Int},
Vector{Union{Float64, Missing}}}) ==
AbstractVector{<:Tuple{Int, Any}}
Comment on lines +956 to +978
Copy link
Member

@simeonschaub simeonschaub Feb 12, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#39618 seems to fix these tests. @nalimilan Could that mean that this workaround might not even be needed anymore? Or does it maybe just fix this particular example?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah then it's great! Actually all the work I did in this PR had not effect for now due to this inference failure (a regression due to changes to reduce compile time introduced since 1.5). So if you can replace these @test_broken with @test then it's perfect!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, cool. Thanks for checking!

# Check that corner cases do not throw an error
@test isequal(broadcast(x -> x === 1 ? nothing : x, [1, 2, missing]),
[nothing, 2, missing])
@test isequal(broadcast(x -> x === 1 ? nothing : x, Any[1, 2, 3.0, missing]),
[nothing, 2, 3, missing])
@test broadcast((x,y)->(x==1 ? 1.0 : x, y), [1 2 3], ["a", "b", "c"]) ==
[(1.0, "a") (2, "a") (3, "a")
(1.0, "b") (2, "b") (3, "b")
(1.0, "c") (2, "c") (3, "c")]
@test typeof.([iszero, isdigit]) == [typeof(iszero), typeof(isdigit)]
@test typeof.([iszero, iszero]) == [typeof(iszero), typeof(iszero)]
end