Skip to content

Commit

Permalink
Make Ref behave as a scalar wrapper for broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloferz committed Oct 22, 2016
1 parent 0d8a738 commit 246ff49
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 0 deletions.
3 changes: 3 additions & 0 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ end
# logic for deciding the resulting container type
containertype(x) = containertype(typeof(x))
containertype(::Type) = Any
containertype{T<:Ref}(::Type{T}) = Array
containertype{T<:Tuple}(::Type{T}) = Tuple
containertype{T<:AbstractArray}(::Type{T}) = Array
containertype(ct1, ct2) = promote_containertype(containertype(ct1), containertype(ct2))
Expand All @@ -49,6 +50,7 @@ broadcast_indices(A) = broadcast_indices(containertype(A), A)
broadcast_indices(::Type{Any}, A) = ()
broadcast_indices(::Type{Tuple}, A) = (OneTo(length(A)),)
broadcast_indices(::Type{Array}, A) = indices(A)
broadcast_indices(::Type{Array}, A::Ref) = ()
@inline broadcast_indices(A, B...) = broadcast_shape((), broadcast_indices(A), map(broadcast_indices, B)...)
# shape (i.e., tuple-of-indices) inputs
broadcast_shape(shape::Tuple) = shape
Expand Down Expand Up @@ -121,6 +123,7 @@ map_newindexer(shape, ::Tuple{}) = (), ()
end

@inline _broadcast_getindex(A, I) = _broadcast_getindex(containertype(A), A, I)
@inline _broadcast_getindex(::Type{Array}, A::Ref, I) = A[]
@inline _broadcast_getindex(::Type{Any}, A, I) = A
@inline _broadcast_getindex(::Any, A, I) = A[I]

Expand Down
6 changes: 6 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,3 +337,9 @@ end
@test broadcast(+, 1.0, (0, -2.0)) == (1.0,-1.0)
@test broadcast(+, 1.0, (0, -2.0), [1]) == [2.0, 0.0]
@test broadcast(*, ["Hello"], ", ", ["World"], "!") == ["Hello, World!"]

# Ref as 0-dimensional array for broadcast
@test (+).(1, Ref(2)) == fill(3)
@test (+).(Ref(1), Ref(2)) == fill(3)
@test (+).([[0,2], [1,3]], [1,-1]) == [[1,3], [0,2]]
@test (+).([[0,2], [1,3]], Ref{Vector{Int}}([-1,1])) == [[1,1], [2,2]]

0 comments on commit 246ff49

Please sign in to comment.