Skip to content

Commit

Permalink
Make Ref behave as a scalar wrapper for broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloferz committed Oct 16, 2016
1 parent 6c9f5af commit 09a9f03
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
5 changes: 4 additions & 1 deletion base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
module Broadcast

using Base.Cartesian
using Base: promote_eltype_op, @get!, _msk_end, unsafe_bitgetindex, linearindices, tail, OneTo, to_shape
using Base: promote_eltype_op, @get!, _msk_end, unsafe_bitgetindex, linearindices, tail, OneTo, to_shape, RefValue
import Base: .+, .-, .*, ./, .\, .//, .==, .<, .!=, .<=, , .%, .<<, .>>, .^
import Base: broadcast
export broadcast!, bitbroadcast, dotview
Expand All @@ -30,6 +30,7 @@ end
containertype(x) = containertype(typeof(x))
containertype(::Type) = Any
containertype{T<:Tuple}(::Type{T}) = Tuple
containertype{T<:RefValue}(::Type{T}) = Array
containertype{T<:AbstractArray}(::Type{T}) = Array
containertype(ct1, ct2) = promote_containertype(containertype(ct1), containertype(ct2))
@inline containertype(ct1, ct2, cts...) = promote_containertype(containertype(ct1), containertype(ct2, cts...))
Expand All @@ -48,6 +49,7 @@ broadcast_indices(A) = broadcast_indices(containertype(A), A)
broadcast_indices(::Type{Any}, A) = ()
broadcast_indices(::Type{Tuple}, A) = (OneTo(length(A)),)
broadcast_indices(::Type{Array}, A) = indices(A)
broadcast_indices(::Type{Array}, A::RefValue) = ()
@inline broadcast_indices(A, B...) = broadcast_shape((), broadcast_indices(A), map(broadcast_indices, B)...)
# shape (i.e., tuple-of-indices) inputs
broadcast_shape(shape::Tuple) = shape
Expand Down Expand Up @@ -127,6 +129,7 @@ dumpbitcache(Bc::Vector{UInt64}, bind::Int, C::Vector{Bool}) =
Base.copy_to_bitarray_chunks!(Bc, ((bind - 1) << 6) + 1, C, 1, min(bitcache_size, (length(Bc)-bind+1) << 6))

@inline _broadcast_getindex(A, I) = _broadcast_getindex(containertype(A), A, I)
@inline _broadcast_getindex(::Type{Array}, A::RefValue, I) = A[]
@inline _broadcast_getindex(::Type{Any}, A, I) = A
@inline _broadcast_getindex(::Any, A, I) = A[I]

Expand Down
6 changes: 6 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,3 +337,9 @@ end
@test broadcast(+, 1.0, (0, -2.0)) == (1.0,-1.0)
@test broadcast(+, 1.0, (0, -2.0), [1]) == [2.0, 0.0]
@test broadcast(*, ["Hello"], ", ", ["World"], "!") == ["Hello, World!"]

# Issue #18379
@test (+).(1, Ref(2)) == fill(3)
@test (+).(Ref(1), Ref(2)) == fill(3)
@test (+).([[0,2], [1,3]], [1,-1]) == [[1,3], [0,2]]
@test (+).([[0,2], [1,3]], Base.RefValue([1,-1])) == [[1,1], [2,2]]

0 comments on commit 09a9f03

Please sign in to comment.