Skip to content

Commit

Permalink
Add promote_op, which uses inference to determine types (fixes #8027)
Browse files Browse the repository at this point in the history
promote_op(F, R, S) computes the output type of F(r, s), where
r::R and s::S. This fixes the problem that arises from relying
on promote_type(R, S) when the result type depends on F.
  • Loading branch information
timholy committed Jul 29, 2015
1 parent 9b14a7c commit 0964c22
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 28 deletions.
22 changes: 12 additions & 10 deletions base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ end

## Binary arithmetic operators ##

promote_array_type{Scalar, Arry}(::Type{Scalar}, ::Type{Arry}) = promote_type(Scalar, Arry)
promote_array_type{S<:Real, A<:FloatingPoint}(::Type{S}, ::Type{A}) = A
promote_array_type{S<:Integer, A<:Integer}(::Type{S}, ::Type{A}) = A
promote_array_type{S<:Integer}(::Type{S}, ::Type{Bool}) = S
promote_array_type{Scalar, Arry}(F, ::Type{Scalar}, ::Type{Arry}) = promote_op(F, Scalar, Arry)
promote_array_type{S<:Real, A<:FloatingPoint}(F, ::Type{S}, ::Type{A}) = A
promote_array_type{S<:Integer, A<:Integer}(F, ::Type{S}, ::Type{A}) = A
promote_array_type{S<:Integer}(F, ::Type{S}, ::Type{Bool}) = S

# Handle operations that return different types
./(x::Number, Y::AbstractArray) =
Expand All @@ -58,9 +58,10 @@ promote_array_type{S<:Integer}(::Type{S}, ::Type{Bool}) = S
reshape([ x ^ y for x in X ], size(X))

for f in (:+, :-, :div, :mod, :&, :|, :$)
F = GenericNFunc{f,2}()
@eval begin
function ($f){S,T}(A::Range{S}, B::Range{T})
F = similar(A, promote_type(S,T), promote_shape(size(A),size(B)))
F = similar(A, promote_op($F,S,T), promote_shape(size(A),size(B)))
i = 1
for (a,b) in zip(A,B)
@inbounds F[i] = ($f)(a, b)
Expand All @@ -69,7 +70,7 @@ for f in (:+, :-, :div, :mod, :&, :|, :$)
return F
end
function ($f){S,T}(A::AbstractArray{S}, B::Range{T})
F = similar(A, promote_type(S,T), promote_shape(size(A),size(B)))
F = similar(A, promote_op($F,S,T), promote_shape(size(A),size(B)))
i = 1
for b in B
@inbounds F[i] = ($f)(A[i], b)
Expand All @@ -78,7 +79,7 @@ for f in (:+, :-, :div, :mod, :&, :|, :$)
return F
end
function ($f){S,T}(A::Range{S}, B::AbstractArray{T})
F = similar(B, promote_type(S,T), promote_shape(size(A),size(B)))
F = similar(B, promote_op($F,S,T), promote_shape(size(A),size(B)))
i = 1
for a in A
@inbounds F[i] = ($f)(a, B[i])
Expand All @@ -87,7 +88,7 @@ for f in (:+, :-, :div, :mod, :&, :|, :$)
return F
end
function ($f){S,T}(A::AbstractArray{S}, B::AbstractArray{T})
F = similar(A, promote_type(S,T), promote_shape(size(A),size(B)))
F = similar(A, promote_op($F,S,T), promote_shape(size(A),size(B)))
for i in eachindex(A,B)
@inbounds F[i] = ($f)(A[i], B[i])
end
Expand All @@ -96,16 +97,17 @@ for f in (:+, :-, :div, :mod, :&, :|, :$)
end
end
for f in (:.+, :.-, :.*, :.%, :.<<, :.>>, :div, :mod, :rem, :&, :|, :$)
F = GenericNFunc{f,2}()
@eval begin
function ($f){T}(A::Number, B::AbstractArray{T})
F = similar(B, promote_array_type(typeof(A),T))
F = similar(B, promote_array_type($F,typeof(A),T))
for i in eachindex(B)
@inbounds F[i] = ($f)(A, B[i])
end
return F
end
function ($f){T}(A::AbstractArray{T}, B::Number)
F = similar(A, promote_array_type(typeof(B),T))
F = similar(A, promote_array_type($F,typeof(B),T))
for i in eachindex(A)
@inbounds F[i] = ($f)(A[i], B)
end
Expand Down
37 changes: 20 additions & 17 deletions base/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -855,21 +855,23 @@ for f in (:+, :-)
return r
end
end
for f in (:.+, :.-),
(arg1, arg2, T, fargs) in ((:(B::BitArray), :(x::Bool) , Int , :(b, x)),
(:(B::BitArray), :(x::Number) , :(promote_array_type(typeof(x), Bool)), :(b, x)),
(:(x::Bool) , :(B::BitArray), Int , :(x, b)),
(:(x::Number) , :(B::BitArray), :(promote_array_type(typeof(x), Bool)), :(x, b)))
@eval function ($f)($arg1, $arg2)
r = Array($T, size(B))
bi = start(B)
ri = 1
while !done(B, bi)
b, bi = next(B, bi)
@inbounds r[ri] = ($f)($fargs...)
ri += 1
for f in (:.+, :.-)
fq = Expr(:quote, f)
for (arg1, arg2, T, fargs) in ((:(B::BitArray), :(x::Bool) , Int , :(b, x)),
(:(B::BitArray), :(x::Number) , :(promote_array_type(GenericNFunc{$fq,2}(),typeof(x), Bool)), :(b, x)),
(:(x::Bool) , :(B::BitArray), Int , :(x, b)),
(:(x::Number) , :(B::BitArray), :(promote_array_type(GenericNFunc{$fq,2}(),typeof(x), Bool)), :(x, b)))
@eval function ($f)($arg1, $arg2)
r = Array($T, size(B))
bi = start(B)
ri = 1
while !done(B, bi)
b, bi = next(B, bi)
@inbounds r[ri] = ($f)($fargs...)
ri += 1
end
return r
end
return r
end
end

Expand Down Expand Up @@ -897,7 +899,7 @@ function div(x::Bool, B::BitArray)
end
function div(x::Number, B::BitArray)
all(B) || throw(DivideError())
pt = promote_array_type(typeof(x), Bool)
pt = promote_array_type(GenericNFunc{:div,2}(),typeof(x), Bool)
y = div(x, true)
reshape(pt[ y for i = 1:length(B) ], size(B))
end
Expand All @@ -918,15 +920,16 @@ function mod(x::Bool, B::BitArray)
end
function mod(x::Number, B::BitArray)
all(B) || throw(DivideError())
pt = promote_array_type(typeof(x), Bool)
pt = promote_array_type(GenericNFunc{:mod,2}(), typeof(x), Bool)
y = mod(x, true)
reshape(pt[ y for i = 1:length(B) ], size(B))
end

for f in (:div, :mod)
F = GenericNFunc{f,2}()
@eval begin
function ($f)(B::BitArray, x::Number)
F = Array(promote_array_type(typeof(x), Bool), size(B))
F = Array(promote_array_type($F, typeof(x), Bool), size(B))
for i = 1:length(F)
F[i] = ($f)(B[i], x)
end
Expand Down
2 changes: 1 addition & 1 deletion base/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ big{T<:FloatingPoint,N}(A::AbstractArray{Complex{T},N}) = convert(AbstractArray{

## promotion to complex ##

promote_array_type{S<:Union{Complex, Real}, AT<:FloatingPoint}(::Type{S}, ::Type{Complex{AT}}) = Complex{AT}
promote_array_type{S<:Union{Complex, Real}, AT<:FloatingPoint}(F, ::Type{S}, ::Type{Complex{AT}}) = Complex{AT}

function complex{S<:Real,T<:Real}(A::Array{S}, B::Array{T})
if size(A) != size(B); throw(DimensionMismatch()); end
Expand Down
5 changes: 5 additions & 0 deletions base/functors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ call(::LessFun, x, y) = x < y
immutable MoreFun <: Func{2} end
call(::MoreFun, x, y) = x > y

immutable GenericNFunc{F,N} <: Func{N} end
@generated function call{F}(::GenericNFunc{F,2}, x, y)
:($F(x, y))
end

# a fallback unspecialized function object that allows code using
# function objects to not care whether they were able to specialize on
# the function value or not
Expand Down
7 changes: 7 additions & 0 deletions base/promotion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,13 @@ checked_add(x::Integer, y::Integer) = checked_add(promote(x,y)...)
checked_sub(x::Integer, y::Integer) = checked_sub(promote(x,y)...)
checked_mul(x::Integer, y::Integer) = checked_mul(promote(x,y)...)

@generated function promote_op{R,S}(F, ::Type{R}, ::Type{S})
ret = Base.return_types(F(), Tuple{R,S})
length(ret) == 1 || error("Strange result from Base.return_types: ", ret)
T = ret[1]
:($T)
end

## catch-alls to prevent infinite recursion when definitions are missing ##

no_op_err(name, T) = error(name," not defined for ",T)
Expand Down

0 comments on commit 0964c22

Please sign in to comment.