From 0964c22c6443f8c29320ad9449dfbd28dd458cad Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Thu, 23 Jul 2015 22:31:24 -0500 Subject: [PATCH] Add promote_op, which uses inference to determine types (fixes #8027) promote_op(F, R, S) computes the output type of F(r, s), where r::R and s::S. This fixes the problem that arises from relying on promote_type(R, S) when the result type depends on F. --- base/arraymath.jl | 22 ++++++++++++---------- base/bitarray.jl | 37 ++++++++++++++++++++----------------- base/complex.jl | 2 +- base/functors.jl | 5 +++++ base/promotion.jl | 7 +++++++ 5 files changed, 45 insertions(+), 28 deletions(-) diff --git a/base/arraymath.jl b/base/arraymath.jl index efd6b3b768d99..1b2506c185543 100644 --- a/base/arraymath.jl +++ b/base/arraymath.jl @@ -38,10 +38,10 @@ end ## Binary arithmetic operators ## -promote_array_type{Scalar, Arry}(::Type{Scalar}, ::Type{Arry}) = promote_type(Scalar, Arry) -promote_array_type{S<:Real, A<:FloatingPoint}(::Type{S}, ::Type{A}) = A -promote_array_type{S<:Integer, A<:Integer}(::Type{S}, ::Type{A}) = A -promote_array_type{S<:Integer}(::Type{S}, ::Type{Bool}) = S +promote_array_type{Scalar, Arry}(F, ::Type{Scalar}, ::Type{Arry}) = promote_op(F, Scalar, Arry) +promote_array_type{S<:Real, A<:FloatingPoint}(F, ::Type{S}, ::Type{A}) = A +promote_array_type{S<:Integer, A<:Integer}(F, ::Type{S}, ::Type{A}) = A +promote_array_type{S<:Integer}(F, ::Type{S}, ::Type{Bool}) = S # Handle operations that return different types ./(x::Number, Y::AbstractArray) = @@ -58,9 +58,10 @@ promote_array_type{S<:Integer}(::Type{S}, ::Type{Bool}) = S reshape([ x ^ y for x in X ], size(X)) for f in (:+, :-, :div, :mod, :&, :|, :$) + F = GenericNFunc{f,2}() @eval begin function ($f){S,T}(A::Range{S}, B::Range{T}) - F = similar(A, promote_type(S,T), promote_shape(size(A),size(B))) + F = similar(A, promote_op($F,S,T), promote_shape(size(A),size(B))) i = 1 for (a,b) in zip(A,B) @inbounds F[i] = ($f)(a, b) @@ -69,7 +70,7 @@ for f in (:+, :-, :div, :mod, :&, :|, :$) return F end function ($f){S,T}(A::AbstractArray{S}, B::Range{T}) - F = similar(A, promote_type(S,T), promote_shape(size(A),size(B))) + F = similar(A, promote_op($F,S,T), promote_shape(size(A),size(B))) i = 1 for b in B @inbounds F[i] = ($f)(A[i], b) @@ -78,7 +79,7 @@ for f in (:+, :-, :div, :mod, :&, :|, :$) return F end function ($f){S,T}(A::Range{S}, B::AbstractArray{T}) - F = similar(B, promote_type(S,T), promote_shape(size(A),size(B))) + F = similar(B, promote_op($F,S,T), promote_shape(size(A),size(B))) i = 1 for a in A @inbounds F[i] = ($f)(a, B[i]) @@ -87,7 +88,7 @@ for f in (:+, :-, :div, :mod, :&, :|, :$) return F end function ($f){S,T}(A::AbstractArray{S}, B::AbstractArray{T}) - F = similar(A, promote_type(S,T), promote_shape(size(A),size(B))) + F = similar(A, promote_op($F,S,T), promote_shape(size(A),size(B))) for i in eachindex(A,B) @inbounds F[i] = ($f)(A[i], B[i]) end @@ -96,16 +97,17 @@ for f in (:+, :-, :div, :mod, :&, :|, :$) end end for f in (:.+, :.-, :.*, :.%, :.<<, :.>>, :div, :mod, :rem, :&, :|, :$) + F = GenericNFunc{f,2}() @eval begin function ($f){T}(A::Number, B::AbstractArray{T}) - F = similar(B, promote_array_type(typeof(A),T)) + F = similar(B, promote_array_type($F,typeof(A),T)) for i in eachindex(B) @inbounds F[i] = ($f)(A, B[i]) end return F end function ($f){T}(A::AbstractArray{T}, B::Number) - F = similar(A, promote_array_type(typeof(B),T)) + F = similar(A, promote_array_type($F,typeof(B),T)) for i in eachindex(A) @inbounds F[i] = ($f)(A[i], B) end diff --git a/base/bitarray.jl b/base/bitarray.jl index b22e8a97a8942..d20a7a10c58de 100644 --- a/base/bitarray.jl +++ b/base/bitarray.jl @@ -855,21 +855,23 @@ for f in (:+, :-) return r end end -for f in (:.+, :.-), - (arg1, arg2, T, fargs) in ((:(B::BitArray), :(x::Bool) , Int , :(b, x)), - (:(B::BitArray), :(x::Number) , :(promote_array_type(typeof(x), Bool)), :(b, x)), - (:(x::Bool) , :(B::BitArray), Int , :(x, b)), - (:(x::Number) , :(B::BitArray), :(promote_array_type(typeof(x), Bool)), :(x, b))) - @eval function ($f)($arg1, $arg2) - r = Array($T, size(B)) - bi = start(B) - ri = 1 - while !done(B, bi) - b, bi = next(B, bi) - @inbounds r[ri] = ($f)($fargs...) - ri += 1 +for f in (:.+, :.-) + fq = Expr(:quote, f) + for (arg1, arg2, T, fargs) in ((:(B::BitArray), :(x::Bool) , Int , :(b, x)), + (:(B::BitArray), :(x::Number) , :(promote_array_type(GenericNFunc{$fq,2}(),typeof(x), Bool)), :(b, x)), + (:(x::Bool) , :(B::BitArray), Int , :(x, b)), + (:(x::Number) , :(B::BitArray), :(promote_array_type(GenericNFunc{$fq,2}(),typeof(x), Bool)), :(x, b))) + @eval function ($f)($arg1, $arg2) + r = Array($T, size(B)) + bi = start(B) + ri = 1 + while !done(B, bi) + b, bi = next(B, bi) + @inbounds r[ri] = ($f)($fargs...) + ri += 1 + end + return r end - return r end end @@ -897,7 +899,7 @@ function div(x::Bool, B::BitArray) end function div(x::Number, B::BitArray) all(B) || throw(DivideError()) - pt = promote_array_type(typeof(x), Bool) + pt = promote_array_type(GenericNFunc{:div,2}(),typeof(x), Bool) y = div(x, true) reshape(pt[ y for i = 1:length(B) ], size(B)) end @@ -918,15 +920,16 @@ function mod(x::Bool, B::BitArray) end function mod(x::Number, B::BitArray) all(B) || throw(DivideError()) - pt = promote_array_type(typeof(x), Bool) + pt = promote_array_type(GenericNFunc{:mod,2}(), typeof(x), Bool) y = mod(x, true) reshape(pt[ y for i = 1:length(B) ], size(B)) end for f in (:div, :mod) + F = GenericNFunc{f,2}() @eval begin function ($f)(B::BitArray, x::Number) - F = Array(promote_array_type(typeof(x), Bool), size(B)) + F = Array(promote_array_type($F, typeof(x), Bool), size(B)) for i = 1:length(F) F[i] = ($f)(B[i], x) end diff --git a/base/complex.jl b/base/complex.jl index 339a40d7849d0..eda1a60f82894 100644 --- a/base/complex.jl +++ b/base/complex.jl @@ -747,7 +747,7 @@ big{T<:FloatingPoint,N}(A::AbstractArray{Complex{T},N}) = convert(AbstractArray{ ## promotion to complex ## -promote_array_type{S<:Union{Complex, Real}, AT<:FloatingPoint}(::Type{S}, ::Type{Complex{AT}}) = Complex{AT} +promote_array_type{S<:Union{Complex, Real}, AT<:FloatingPoint}(F, ::Type{S}, ::Type{Complex{AT}}) = Complex{AT} function complex{S<:Real,T<:Real}(A::Array{S}, B::Array{T}) if size(A) != size(B); throw(DimensionMismatch()); end diff --git a/base/functors.jl b/base/functors.jl index 8f239ea661c40..54478a427dee7 100644 --- a/base/functors.jl +++ b/base/functors.jl @@ -67,6 +67,11 @@ call(::LessFun, x, y) = x < y immutable MoreFun <: Func{2} end call(::MoreFun, x, y) = x > y +immutable GenericNFunc{F,N} <: Func{N} end +@generated function call{F}(::GenericNFunc{F,2}, x, y) + :($F(x, y)) +end + # a fallback unspecialized function object that allows code using # function objects to not care whether they were able to specialize on # the function value or not diff --git a/base/promotion.jl b/base/promotion.jl index 56c82062f54ea..2729acdfaa934 100644 --- a/base/promotion.jl +++ b/base/promotion.jl @@ -199,6 +199,13 @@ checked_add(x::Integer, y::Integer) = checked_add(promote(x,y)...) checked_sub(x::Integer, y::Integer) = checked_sub(promote(x,y)...) checked_mul(x::Integer, y::Integer) = checked_mul(promote(x,y)...) +@generated function promote_op{R,S}(F, ::Type{R}, ::Type{S}) + ret = Base.return_types(F(), Tuple{R,S}) + length(ret) == 1 || error("Strange result from Base.return_types: ", ret) + T = ret[1] + :($T) +end + ## catch-alls to prevent infinite recursion when definitions are missing ## no_op_err(name, T) = error(name," not defined for ",T)