Skip to content

Commit

Permalink
Update broadcast and its callers to the new indices behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Jul 7, 2016
1 parent 9f8be53 commit 8852f94
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 57 deletions.
44 changes: 23 additions & 21 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
module Broadcast

using Base.Cartesian
using Base: promote_op, promote_eltype, promote_eltype_op, @get!, _msk_end, unsafe_bitgetindex, linearindices, allocate_for, tail, dimlength
using Base: promote_op, promote_eltype, promote_eltype_op, @get!, _msk_end, unsafe_bitgetindex, linearindices, to_shape, allocate_for, tail, dimlength, OneTo
import Base: .+, .-, .*, ./, .\, .//, .==, .<, .!=, .<=, , .%, .<<, .>>, .^
export broadcast, broadcast!, bitbroadcast
export broadcast_getindex, broadcast_setindex!
Expand Down Expand Up @@ -39,6 +39,8 @@ _bcsm(a::Number, b::Number) = a == b || b == 1
## Check that all arguments are broadcast compatible with shape
## Check that all arguments are broadcast compatible with shape
# comparing one input against a shape
check_broadcast_shape(::Tuple{}) = nothing
check_broadcast_shape(::Tuple{}, A::Union{AbstractArray,Number}) = check_broadcast_shape((), indices(A))
check_broadcast_shape(shp) = nothing
check_broadcast_shape(shp, A) = check_broadcast_shape(shp, indices(A))
check_broadcast_shape(::Tuple{}, ::Tuple{}) = nothing
Expand Down Expand Up @@ -217,13 +219,13 @@ end

@inline bitbroadcast(f, As...) = broadcast!(f, allocate_for(BitArray, As, broadcast_shape(As...)), As...)

broadcast_getindex(src::AbstractArray, I::AbstractArray...) = broadcast_getindex!(Array{eltype(src)}(broadcast_shape(I...)), src, I...)
broadcast_getindex(src::AbstractArray, I::AbstractArray...) = broadcast_getindex!(Array{eltype(src)}(to_shape(broadcast_shape(I...))), src, I...)
@generated function broadcast_getindex!(dest::AbstractArray, src::AbstractArray, I::AbstractArray...)
N = length(I)
Isplat = Expr[:(I[$d]) for d = 1:N]
quote
@nexprs $N d->(I_d = I[d])
check_broadcast_shape(size(dest), $(Isplat...)) # unnecessary if this function is never called directly
check_broadcast_shape(indices(dest), $(Isplat...)) # unnecessary if this function is never called directly
checkbounds(src, $(Isplat...))
@nloops $N i dest d->(@nexprs $N k->(j_d_k = size(I_k, d) == 1 ? 1 : i_d)) begin
@nexprs $N k->(@inbounds J_k = @nref $N I_k d->j_d_k)
Expand All @@ -240,22 +242,22 @@ end
@nexprs $N d->(I_d = I[d])
checkbounds(A, $(Isplat...))
shape = broadcast_shape($(Isplat...))
@nextract $N shape d->(length(shape) < d ? 1 : shape[d])
@nextract $N shape d->(length(shape) < d ? OneTo(1) : shape[d])
if !isa(x, AbstractArray)
@nloops $N i d->(1:shape_d) d->(@nexprs $N k->(j_d_k = size(I_k, d) == 1 ? 1 : i_d)) begin
xA = convert(eltype(A), x)
@nloops $N i d->shape_d d->(@nexprs $N k->(j_d_k = size(I_k, d) == 1 ? 1 : i_d)) begin
@nexprs $N k->(@inbounds J_k = @nref $N I_k d->j_d_k)
@inbounds (@nref $N A J) = x
@inbounds (@nref $N A J) = xA
end
else
X = x
# To call setindex_shape_check, we need to create fake 1-d indexes of the proper size
@nexprs $N d->(fakeI_d = 1:shape_d)
@ncall $N Base.setindex_shape_check X shape
k = 1
@nloops $N i d->(1:shape_d) d->(@nexprs $N k->(j_d_k = size(I_k, d) == 1 ? 1 : i_d)) begin
@nexprs $N k->(@inbounds J_k = @nref $N I_k d->j_d_k)
@inbounds (@nref $N A J) = X[k]
k += 1
@nexprs $N d->(shapelen_d = dimlength(shape_d))
@ncall $N Base.setindex_shape_check X shapelen
Xstate = start(X)
@inbounds @nloops $N i d->shape_d d->(@nexprs $N k->(j_d_k = size(I_k, d) == 1 ? 1 : i_d)) begin
@nexprs $N k->(J_k = @nref $N I_k d->j_d_k)
x_el, Xstate = next(X, Xstate)
(@nref $N A J) = x_el
end
end
A
Expand All @@ -271,22 +273,22 @@ end

eltype_plus(As::AbstractArray...) = promote_eltype_op(+, As...)

.+(As::AbstractArray...) = broadcast!(+, Array{eltype_plus(As...)}(broadcast_shape(As...)), As...)
.+(As::AbstractArray...) = broadcast!(+, Array{eltype_plus(As...)}(to_shape(broadcast_shape(As...))), As...)

function .-(A::AbstractArray, B::AbstractArray)
broadcast!(-, Array{promote_op(-, eltype(A), eltype(B))}(broadcast_shape(A,B)), A, B)
broadcast!(-, Array{promote_op(-, eltype(A), eltype(B))}(to_shape(broadcast_shape(A,B))), A, B)
end

eltype_mul(As::AbstractArray...) = promote_eltype_op(*, As...)

.*(As::AbstractArray...) = broadcast!(*, Array{eltype_mul(As...)}(broadcast_shape(As...)), As...)
.*(As::AbstractArray...) = broadcast!(*, Array{eltype_mul(As...)}(to_shape(broadcast_shape(As...))), As...)

function ./(A::AbstractArray, B::AbstractArray)
broadcast!(/, Array{promote_op(/, eltype(A), eltype(B))}(broadcast_shape(A, B)), A, B)
broadcast!(/, Array{promote_op(/, eltype(A), eltype(B))}(to_shape(broadcast_shape(A, B))), A, B)
end

function .\(A::AbstractArray, B::AbstractArray)
broadcast!(\, Array{promote_op(\, eltype(A), eltype(B))}(broadcast_shape(A, B)), A, B)
broadcast!(\, Array{promote_op(\, eltype(A), eltype(B))}(to_shape(broadcast_shape(A, B))), A, B)
end

typealias RatIntT{T<:Integer} Union{Type{Rational{T}},Type{T}}
Expand All @@ -296,11 +298,11 @@ type_rdiv{T<:Integer,S<:Integer}(::RatIntT{T}, ::RatIntT{S}) =
type_rdiv{T<:Integer,S<:Integer}(::CRatIntT{T}, ::CRatIntT{S}) =
Complex{Rational{promote_type(T,S)}}
function .//(A::AbstractArray, B::AbstractArray)
broadcast!(//, Array{type_rdiv(eltype(A), eltype(B))}(broadcast_shape(A, B)), A, B)
broadcast!(//, Array{type_rdiv(eltype(A), eltype(B))}(to_shape(broadcast_shape(A, B))), A, B)
end

function .^(A::AbstractArray, B::AbstractArray)
broadcast!(^, Array{promote_op(^, eltype(A), eltype(B))}(broadcast_shape(A, B)), A, B)
broadcast!(^, Array{promote_op(^, eltype(A), eltype(B))}(to_shape(broadcast_shape(A, B))), A, B)
end

# ## element-wise comparison operators returning BitArray ##
Expand Down
6 changes: 6 additions & 0 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,12 @@ function first(::Colon)
1
end

# Not exported, but may be useful just in case
function Broadcast.check_broadcast_shape(sz::Dims, As::Union{AbstractArray,Number}...)
depwarn("check_broadcast_shape(size(A), B...) should be replaced with check_broadcast_shape(indices(A), B...)", :check_broadcast_shape)
Broadcast.check_broadcast_shape(map(OneTo, sz), As...)
end

@deprecate slice view
@deprecate sub view

Expand Down
10 changes: 5 additions & 5 deletions base/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ end
# for permutations that leave array elements in the same linear order.
# those are the permutations that preserve the order of the non-singleton
# dimensions.
function setindex_shape_check(X::AbstractArray, I...)
function setindex_shape_check(X::AbstractArray, I::Integer...)
li = ndims(X)
lj = length(I)
i = j = 1
Expand Down Expand Up @@ -440,16 +440,16 @@ end
setindex_shape_check(X::AbstractArray) =
(length(X)==1 || throw_setindex_mismatch(X,()))

setindex_shape_check(X::AbstractArray, i) =
setindex_shape_check(X::AbstractArray, i::Integer) =
(length(X)==i || throw_setindex_mismatch(X, (i,)))

setindex_shape_check{T}(X::AbstractArray{T,1}, i) =
setindex_shape_check{T}(X::AbstractArray{T,1}, i::Integer) =
(length(X)==i || throw_setindex_mismatch(X, (i,)))

setindex_shape_check{T}(X::AbstractArray{T,1}, i, j) =
setindex_shape_check{T}(X::AbstractArray{T,1}, i::Integer, j::Integer) =
(length(X)==i*j || throw_setindex_mismatch(X, (i,j)))

function setindex_shape_check{T}(X::AbstractArray{T,2}, i, j)
function setindex_shape_check{T}(X::AbstractArray{T,2}, i::Integer, j::Integer)
if length(X) != i*j
throw_setindex_mismatch(X, (i,j))
end
Expand Down
2 changes: 1 addition & 1 deletion base/sparse/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

module SparseArrays

using Base: ReshapedArray, setindex_shape_check
using Base: ReshapedArray, setindex_shape_check, to_shape
using Base.Sort: Forward
using Base.LinAlg: AbstractTriangular, PosDefException

Expand Down
26 changes: 13 additions & 13 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1413,8 +1413,8 @@ end
function gen_broadcast_body_sparse(f::Function, is_first_sparse::Bool)
F = Expr(:quote, f)
quote
Base.Broadcast.check_broadcast_shape(size(B), A_1)
Base.Broadcast.check_broadcast_shape(size(B), A_2)
Base.Broadcast.check_broadcast_shape(indices(B), A_1)
Base.Broadcast.check_broadcast_shape(indices(B), A_2)

colptrB = B.colptr; rowvalB = B.rowval; nzvalB = B.nzval
colptr1 = A_1.colptr; rowval1 = A_1.rowval; nzval1 = A_1.nzval
Expand Down Expand Up @@ -1577,8 +1577,8 @@ function gen_broadcast_body_zpreserving(f::Function, is_first_sparse::Bool)
op2 = :(val1)
end
quote
Base.Broadcast.check_broadcast_shape(size(B), $A1)
Base.Broadcast.check_broadcast_shape(size(B), $A2)
Base.Broadcast.check_broadcast_shape(indices(B), $A1)
Base.Broadcast.check_broadcast_shape(indices(B), $A2)

nnzB = isempty(B) ? 0 :
nnz($A1) * div(B.n, ($A1).n) * div(B.m, ($A1).m)
Expand Down Expand Up @@ -1647,16 +1647,16 @@ end


broadcast{Tv1,Ti1,Tv2,Ti2}(f::Function, A_1::SparseMatrixCSC{Tv1,Ti1}, A_2::SparseMatrixCSC{Tv2,Ti2}) =
broadcast!(f, spzeros(promote_type(Tv1, Tv2), promote_type(Ti1, Ti2), broadcast_shape(A_1, A_2)...), A_1, A_2)
broadcast!(f, spzeros(promote_type(Tv1, Tv2), promote_type(Ti1, Ti2), to_shape(broadcast_shape(A_1, A_2))...), A_1, A_2)

broadcast_zpreserving!(args...) = broadcast!(args...)
broadcast_zpreserving(args...) = broadcast(args...)
broadcast_zpreserving{Tv1,Ti1,Tv2,Ti2}(f::Function, A_1::SparseMatrixCSC{Tv1,Ti1}, A_2::SparseMatrixCSC{Tv2,Ti2}) =
broadcast_zpreserving!(f, spzeros(promote_type(Tv1, Tv2), promote_type(Ti1, Ti2), broadcast_shape(A_1, A_2)...), A_1, A_2)
broadcast_zpreserving!(f, spzeros(promote_type(Tv1, Tv2), promote_type(Ti1, Ti2), to_shape(broadcast_shape(A_1, A_2))...), A_1, A_2)
broadcast_zpreserving{Tv,Ti}(f::Function, A_1::SparseMatrixCSC{Tv,Ti}, A_2::Union{Array,BitArray,Number}) =
broadcast_zpreserving!(f, spzeros(promote_eltype(A_1, A_2), Ti, broadcast_shape(A_1, A_2)...), A_1, A_2)
broadcast_zpreserving!(f, spzeros(promote_eltype(A_1, A_2), Ti, to_shape(broadcast_shape(A_1, A_2))...), A_1, A_2)
broadcast_zpreserving{Tv,Ti}(f::Function, A_1::Union{Array,BitArray,Number}, A_2::SparseMatrixCSC{Tv,Ti}) =
broadcast_zpreserving!(f, spzeros(promote_eltype(A_1, A_2), Ti, broadcast_shape(A_1, A_2)...), A_1, A_2)
broadcast_zpreserving!(f, spzeros(promote_eltype(A_1, A_2), Ti, to_shape(broadcast_shape(A_1, A_2))...), A_1, A_2)


## Binary arithmetic and boolean operators
Expand All @@ -1676,7 +1676,7 @@ for (op, pro) in ((+, :eltype_plus),
throw(DimensionMismatch(""))
end
Tv = ($pro)(A_1, A_2)
B = spzeros(Tv, promote_type(Ti1, Ti2), broadcast_shape(A_1, A_2)...)
B = spzeros(Tv, promote_type(Ti1, Ti2), to_shape(broadcast_shape(A_1, A_2))...)
$body
B
end
Expand Down Expand Up @@ -1718,15 +1718,15 @@ end # macro
(.^)(A::Array, B::SparseMatrixCSC) = (.^)(A, full(B))

.+{Tv1,Ti1,Tv2,Ti2}(A_1::SparseMatrixCSC{Tv1,Ti1}, A_2::SparseMatrixCSC{Tv2,Ti2}) =
broadcast!(+, spzeros(eltype_plus(A_1, A_2), promote_type(Ti1, Ti2), broadcast_shape(A_1, A_2)...), A_1, A_2)
broadcast!(+, spzeros(eltype_plus(A_1, A_2), promote_type(Ti1, Ti2), to_shape(broadcast_shape(A_1, A_2))...), A_1, A_2)

function .-{Tva,Tia,Tvb,Tib}(A::SparseMatrixCSC{Tva,Tia}, B::SparseMatrixCSC{Tvb,Tib})
broadcast!(-, spzeros(eltype_plus(A, B), promote_type(Tia, Tib), broadcast_shape(A, B)...), A, B)
broadcast!(-, spzeros(eltype_plus(A, B), promote_type(Tia, Tib), to_shape(broadcast_shape(A, B))...), A, B)
end

## element-wise comparison operators returning SparseMatrixCSC ##
.<{Tv1,Ti1,Tv2,Ti2}(A_1::SparseMatrixCSC{Tv1,Ti1}, A_2::SparseMatrixCSC{Tv2,Ti2}) = broadcast!(<, spzeros( Bool, promote_type(Ti1, Ti2), broadcast_shape(A_1, A_2)...), A_1, A_2)
.!={Tv1,Ti1,Tv2,Ti2}(A_1::SparseMatrixCSC{Tv1,Ti1}, A_2::SparseMatrixCSC{Tv2,Ti2}) = broadcast!(!=, spzeros( Bool, promote_type(Ti1, Ti2), broadcast_shape(A_1, A_2)...), A_1, A_2)
.<{Tv1,Ti1,Tv2,Ti2}(A_1::SparseMatrixCSC{Tv1,Ti1}, A_2::SparseMatrixCSC{Tv2,Ti2}) = broadcast!(<, spzeros( Bool, promote_type(Ti1, Ti2), to_shape(broadcast_shape(A_1, A_2))...), A_1, A_2)
.!={Tv1,Ti1,Tv2,Ti2}(A_1::SparseMatrixCSC{Tv1,Ti1}, A_2::SparseMatrixCSC{Tv2,Ti2}) = broadcast!(!=, spzeros( Bool, promote_type(Ti1, Ti2), to_shape(broadcast_shape(A_1, A_2))...), A_1, A_2)

## full equality
function ==(A1::SparseMatrixCSC, A2::SparseMatrixCSC)
Expand Down
33 changes: 16 additions & 17 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
module TestBroadcastInternals

using Base.Broadcast: broadcast_shape, check_broadcast_shape, newindex, _bcs, _bcsm
using Base.Test
using Base: Test, OneTo

@test @inferred(_bcs((), (3,5), (3,5))) == (3,5)
@test @inferred(_bcs((), (3,1), (3,5))) == (3,5)
Expand All @@ -18,29 +18,28 @@ using Base.Test
@test_throws DimensionMismatch _bcs((), (-1:1, 2:6), (-1:1, 2:5))
@test_throws DimensionMismatch _bcs((), (-1:1, 2:5), (2, 2:5))

@test @inferred(broadcast_shape(zeros(3,4), zeros(3,4))) == (3,4)
@test @inferred(broadcast_shape(zeros(3,4), zeros(3))) == (3,4)
@test @inferred(broadcast_shape(zeros(3), zeros(3,4))) == (3,4)
@test @inferred(broadcast_shape(zeros(3), zeros(1,4), zeros(1))) == (3,4)

check_broadcast_shape((3,5), zeros(3,5))
check_broadcast_shape((3,5), zeros(3,1))
check_broadcast_shape((3,5), zeros(3))
check_broadcast_shape((3,5), zeros(3,5), zeros(3))
check_broadcast_shape((3,5), zeros(3,5), 1)
check_broadcast_shape((3,5), 5, 2)
@test_throws DimensionMismatch check_broadcast_shape((3,5), zeros(2,5))
@test_throws DimensionMismatch check_broadcast_shape((3,5), zeros(3,4))
@test_throws DimensionMismatch check_broadcast_shape((3,5), zeros(3,4,2))
@test_throws DimensionMismatch check_broadcast_shape((3,5), zeros(3,5), zeros(2))
@test @inferred(broadcast_shape(zeros(3,4), zeros(3,4))) == (OneTo(3),OneTo(4))
@test @inferred(broadcast_shape(zeros(3,4), zeros(3))) == (OneTo(3),OneTo(4))
@test @inferred(broadcast_shape(zeros(3), zeros(3,4))) == (OneTo(3),OneTo(4))
@test @inferred(broadcast_shape(zeros(3), zeros(1,4), zeros(1))) == (OneTo(3),OneTo(4))

check_broadcast_shape((OneTo(3),OneTo(5)), zeros(3,5))
check_broadcast_shape((OneTo(3),OneTo(5)), zeros(3,1))
check_broadcast_shape((OneTo(3),OneTo(5)), zeros(3))
check_broadcast_shape((OneTo(3),OneTo(5)), zeros(3,5), zeros(3))
check_broadcast_shape((OneTo(3),OneTo(5)), zeros(3,5), 1)
check_broadcast_shape((OneTo(3),OneTo(5)), 5, 2)
@test_throws DimensionMismatch check_broadcast_shape((OneTo(3),OneTo(5)), zeros(2,5))
@test_throws DimensionMismatch check_broadcast_shape((OneTo(3),OneTo(5)), zeros(3,4))
@test_throws DimensionMismatch check_broadcast_shape((OneTo(3),OneTo(5)), zeros(3,4,2))
@test_throws DimensionMismatch check_broadcast_shape((OneTo(3),OneTo(5)), zeros(3,5), zeros(2))

check_broadcast_shape((-1:1, 6:9), (-1:1, 6:9))
check_broadcast_shape((-1:1, 6:9), (-1:1, 1))
check_broadcast_shape((-1:1, 6:9), (1, 6:9))
@test_throws DimensionMismatch check_broadcast_shape((-1:1, 6:9), (-1, 6:9))
@test_throws DimensionMismatch check_broadcast_shape((-1:1, 6:9), (-1:1, 6))
check_broadcast_shape((-1:1, 6:9), 1)
check_broadcast_shape((-1:1, 6:9), zeros(1,1))

ci(x) = CartesianIndex(x)
@test @inferred(newindex(ci((2,2)), (true, true))) == ci((2,2))
Expand Down

0 comments on commit 8852f94

Please sign in to comment.