From 76d506e2bea863ee73aa83861551912528ef476e Mon Sep 17 00:00:00 2001 From: Chris Foster Date: Thu, 26 Sep 2019 22:58:38 +1000 Subject: [PATCH] Avoid eltype degrading to Union{} for empty map/broadcast This reverts to using Core.Compiler.return_type for map/broadcast, but only in the very restricted case that the output container is completely empty. This is consistent with the way that return_type is used in Base for collect and broadcast for empty collections only. --- src/broadcast.jl | 14 +++++++++----- src/mapreduce.jl | 22 ++++++++++++++++++---- test/broadcast.jl | 10 +++++++--- test/linalg.jl | 6 +++++- test/lu.jl | 6 +----- test/mapreduce.jl | 5 +++++ 6 files changed, 45 insertions(+), 18 deletions(-) diff --git a/src/broadcast.jl b/src/broadcast.jl index db74f89e..dd4be227 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -97,11 +97,15 @@ scalar_getindex(x::Ref) = x[] scalar_getindex(x::Tuple{<: Any}) = x[1] @generated function _broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize - first_staticarray = 0 - for i = 1:length(a) - if a[i] <: StaticArray - first_staticarray = a[i] - break + first_staticarray = a[findfirst(ai -> ai <: StaticArray, a)] + + if prod(newsize) == 0 + # Use inference to get eltype in empty case (see also comments in _map) + eltys = [:(eltype(a[$i])) for i ∈ 1:length(a)] + return quote + @_inline_meta + T = Core.Compiler.return_type(f, Tuple{$(eltys...)}) + @inbounds return similar_type($first_staticarray, T, Size(newsize))() end end diff --git a/src/mapreduce.jl b/src/mapreduce.jl index 7806bba3..6b5c80f1 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -18,17 +18,31 @@ end end @generated function _map(f, a::AbstractArray...) - i = findfirst(ai -> ai <: StaticArray, a) - if i === nothing + first_staticarray = findfirst(ai -> ai <: StaticArray, a) + if first_staticarray === nothing return :(throw(ArgumentError("No StaticArray found in argument list"))) end # Passing the Size as an argument to _map leads to inference issues when # recursively mapping over nested StaticArrays (see issue #593). Calling - # Size in the generator here is valid because a[i] is known to be a + # Size in the generator here is valid because a[first_staticarray] is known to be a # StaticArray for which the default Size method is correct. If wrapped # StaticArrays (with a custom Size method) are to be supported, this will # no longer be valid. - S = Size(a[i]) + S = Size(a[first_staticarray]) + + if prod(S) == 0 + # In the empty case only, use inference to try figuring out a sensible + # eltype, as is done in Base.collect and broadcast. + # See https://github.com/JuliaArrays/StaticArrays.jl/issues/528 + eltys = [:(eltype(a[$i])) for i ∈ 1:length(a)] + return quote + @_inline_meta + S = same_size(a...) + T = Core.Compiler.return_type(f, Tuple{$(eltys...)}) + @inbounds return similar_type(a[$first_staticarray], T, S)() + end + end + exprs = Vector{Expr}(undef, prod(S)) for i ∈ 1:prod(S) tmp = [:(a[$j][$i]) for j ∈ 1:length(a)] diff --git a/test/broadcast.jl b/test/broadcast.jl index a6209d5a..eebce977 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -109,9 +109,6 @@ end @test @inferred(v2 .- v1) === SVector(0, 2) @test @inferred(v1 .^ v2) === SVector(1, 16) @test @inferred(v2 .^ v1) === SVector(1, 16) - # Issue #199: broadcast with empty SArray - @test @inferred(SVector(1) .+ SVector{0,Int}()) === SVector{0,Union{}}() - @test @inferred(SVector{0,Int}() .+ SVector(1)) === SVector{0,Union{}}() # Issue #200: broadcast with Adjoint @test @inferred(v1 .+ v2') === @SMatrix [2 5; 3 6] @test @inferred(v1 .+ transpose(v2)) === @SMatrix [2 5; 3 6] @@ -142,6 +139,13 @@ end @test @inferred(zeros(SVector{0}) .+ zeros(SMatrix{0,2})) === zeros(SMatrix{0,2}) m = zeros(MMatrix{0,2}) @test @inferred(broadcast!(+, m, m, zeros(SVector{0}))) == zeros(SMatrix{0,2}) + # Issue #199: broadcast with empty SArray + @test @inferred(SVector(1) .+ SVector{0,Int}()) === SVector{0,Int}() + @test @inferred(SVector{0,Int}() .+ SVector(1.0)) === SVector{0,Float64}() + # Issue #528 + @test @inferred(isapprox(SMatrix{3,0,Float64}(), SMatrix{3,0,Float64}())) + @test @inferred(broadcast(length, SVector{0,String}())) === SVector{0,Int}() + @test @inferred(broadcast(join, SVector{0,String}(), SVector{0,String}(), SVector{0,String}())) === SVector{0,String}() end @testset "Mutating broadcast!" begin diff --git a/test/linalg.jl b/test/linalg.jl index ad7288ad..6b070827 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -2,7 +2,7 @@ using StaticArrays, Test, LinearAlgebra @testset "Linear algebra" begin - @testset "SVector as a (mathematical) vector space" begin + @testset "SArray as a (mathematical) vector space" begin c = 2 v1 = @SVector [2,4,6,8] v2 = @SVector [4,3,2,1] @@ -14,6 +14,10 @@ using StaticArrays, Test, LinearAlgebra @test @inferred(v1 + v2) === @SVector [6, 7, 8, 9] @test @inferred(v1 - v2) === @SVector [-2, 1, 4, 7] + # #528 eltype with empty addition + zm = zeros(SMatrix{3, 0, Float64}) + @test @inferred(zm + zm) === zm + # TODO Decide what to do about this stuff: #v3 = [2,4,6,8] #v4 = [4,3,2,1] diff --git a/test/lu.jl b/test/lu.jl index a1375698..6c009cac 100644 --- a/test/lu.jl +++ b/test/lu.jl @@ -36,9 +36,5 @@ using StaticArrays, Test, LinearAlgebra # decomposition is correct l_u = l*u - if length(l_u) > 0 # Union{} element type breaks norm - @test l*u ≈ a[p,:] - else - @test_broken l*u ≈ a[p,:] - end + @test l*u ≈ a[p,:] end diff --git a/test/mapreduce.jl b/test/mapreduce.jl index e76e6f82..084a3ee7 100644 --- a/test/mapreduce.jl +++ b/test/mapreduce.jl @@ -24,6 +24,11 @@ using Statistics: mean v3 = @SVector [1, 2, 3, 4] map!(+, mv3, v1, v2, v3) @test mv3 == @MVector [7, 9, 11, 13] + + # Output eltype for empty cases #528 + @test @inferred(map(/, SVector{0,Int}(), SVector{0,Int}())) === SVector{0,Float64}() + @test @inferred(map(+, SVector{0,Int}(), SVector{0,Float32}())) === SVector{0,Float32}() + @test @inferred(map(length, SVector{0,String}())) === SVector{0,Int}() end @testset "[map]reduce and [map]reducedim" begin