From 09a9f03fe3bec14f1550cb9916332c69d384eda7 Mon Sep 17 00:00:00 2001 From: pabloferz Date: Sat, 15 Oct 2016 14:28:10 -0500 Subject: [PATCH] Make Ref behave as a scalar wrapper for broadcast --- base/broadcast.jl | 5 ++++- test/broadcast.jl | 6 ++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/base/broadcast.jl b/base/broadcast.jl index 917a9be9f3a5a6..e757690735f904 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -3,7 +3,7 @@ module Broadcast using Base.Cartesian -using Base: promote_eltype_op, @get!, _msk_end, unsafe_bitgetindex, linearindices, tail, OneTo, to_shape +using Base: promote_eltype_op, @get!, _msk_end, unsafe_bitgetindex, linearindices, tail, OneTo, to_shape, RefValue import Base: .+, .-, .*, ./, .\, .//, .==, .<, .!=, .<=, .รท, .%, .<<, .>>, .^ import Base: broadcast export broadcast!, bitbroadcast, dotview @@ -30,6 +30,7 @@ end containertype(x) = containertype(typeof(x)) containertype(::Type) = Any containertype{T<:Tuple}(::Type{T}) = Tuple +containertype{T<:RefValue}(::Type{T}) = Array containertype{T<:AbstractArray}(::Type{T}) = Array containertype(ct1, ct2) = promote_containertype(containertype(ct1), containertype(ct2)) @inline containertype(ct1, ct2, cts...) = promote_containertype(containertype(ct1), containertype(ct2, cts...)) @@ -48,6 +49,7 @@ broadcast_indices(A) = broadcast_indices(containertype(A), A) broadcast_indices(::Type{Any}, A) = () broadcast_indices(::Type{Tuple}, A) = (OneTo(length(A)),) broadcast_indices(::Type{Array}, A) = indices(A) +broadcast_indices(::Type{Array}, A::RefValue) = () @inline broadcast_indices(A, B...) = broadcast_shape((), broadcast_indices(A), map(broadcast_indices, B)...) # shape (i.e., tuple-of-indices) inputs broadcast_shape(shape::Tuple) = shape @@ -127,6 +129,7 @@ dumpbitcache(Bc::Vector{UInt64}, bind::Int, C::Vector{Bool}) = Base.copy_to_bitarray_chunks!(Bc, ((bind - 1) << 6) + 1, C, 1, min(bitcache_size, (length(Bc)-bind+1) << 6)) @inline _broadcast_getindex(A, I) = _broadcast_getindex(containertype(A), A, I) +@inline _broadcast_getindex(::Type{Array}, A::RefValue, I) = A[] @inline _broadcast_getindex(::Type{Any}, A, I) = A @inline _broadcast_getindex(::Any, A, I) = A[I] diff --git a/test/broadcast.jl b/test/broadcast.jl index 7dca2095bff7e2..b38909127332c2 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -337,3 +337,9 @@ end @test broadcast(+, 1.0, (0, -2.0)) == (1.0,-1.0) @test broadcast(+, 1.0, (0, -2.0), [1]) == [2.0, 0.0] @test broadcast(*, ["Hello"], ", ", ["World"], "!") == ["Hello, World!"] + +# Issue #18379 +@test (+).(1, Ref(2)) == fill(3) +@test (+).(Ref(1), Ref(2)) == fill(3) +@test (+).([[0,2], [1,3]], [1,-1]) == [[1,3], [0,2]] +@test (+).([[0,2], [1,3]], Base.RefValue([1,-1])) == [[1,1], [2,2]]