-
Notifications
You must be signed in to change notification settings - Fork 5
/
broadcast.jl
169 lines (152 loc) · 6.06 KB
/
broadcast.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import Base.Broadcast: BroadcastStyle
using Base.Broadcast: AbstractArrayStyle, Broadcasted, DefaultArrayStyle, _broadcast_getindex
# combine_sizes moved from StaticArrays after https://github.com/JuliaArrays/StaticArrays.jl/pull/1008
# see also https://github.com/JuliaArrays/HybridArrays.jl/issues/50
# Combine the `Size`s of all broadcast arguments into the `Size` of the broadcast result.
#
# The result has as many dimensions as the highest-dimensional input. For each dimension:
# - `Dynamic()` inputs never constrain the result;
# - a static extent replaces a `Dynamic()` or a static `1` accumulated so far;
# - two different static extents, neither of which is `1`, raise a `DimensionMismatch`
#   (thrown while generating code, i.e. on first call for a given combination of types).
# A dimension for which every input is `Dynamic()` stays `Dynamic()`.
@generated function combine_sizes(s::Tuple{Vararg{Size}})
    sizes = [p.parameters[1] for p ∈ s.parameters]
    rank = isempty(sizes) ? 0 : maximum(length, sizes)
    merged = StaticArrays.StaticDimension[Dynamic() for _ in 1:rank]
    for sz in sizes, (d, extent) in enumerate(sz)
        extent isa Dynamic && continue
        if merged[d] isa Dynamic || merged[d] == 1
            merged[d] = extent
        elseif merged[d] ≠ extent && extent ≠ 1
            throw(DimensionMismatch("Tried to broadcast on inputs sized $sizes"))
        end
    end
    return quote
        Base.@_inline_meta
        Size($(Tuple(merged)))
    end
end
# Map a position `newindex` of the broadcast result onto the linear index of an argument of
# static size `oldsize`. Dimensions where the argument has extent 1 are pinned to index 1
# (broadcast extrusion). Only the first `length(oldsize)` entries of `newindex` are consulted,
# so trailing result dimensions the argument lacks are ignored.
function broadcasted_index(oldsize, newindex)
    picked = Int[oldsize[k] == 1 ? 1 : newindex[k] for k in eachindex(oldsize)]
    return LinearIndices(oldsize)[picked...]
end
# Read the value of a scalar broadcast argument: `Ref`s are dereferenced, anything else is
# passed through unchanged.
function scalar_getindex(x)
    return x
end
scalar_getindex(x::Ref) = getindex(x)
# Broadcast style used for broadcasts involving HybridArrays, derived from
# AbstractArrayStyle; `N` is the array dimensionality.
struct HybridArrayStyle{N} <: AbstractArrayStyle{N}
end
# A constructor that changes the style parameter N (array dimension) is also required.
(::Type{HybridArrayStyle{M}})(::Val{N}) where {M,N} = HybridArrayStyle{N}()
# Each HybridArray type broadcasts with the hybrid style of its own dimensionality.
BroadcastStyle(::Type{<:HybridArray{<:Tuple,<:Any,N}}) where {N} = HybridArrayStyle{N}()
# Precedence rules.
# These are binary `BroadcastStyle` rules: both arguments are style *instances*, and `M`/`N`
# are their array dimensionalities. (Dispatching on the array type `HybridArray{M}` here would
# bind `M` to the size-tuple type and never match a pair of styles.)
#
# Mixed with a generic array of nonzero dimension, the broadcast falls back to the generic
# style of the larger dimensionality ...
BroadcastStyle(::HybridArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} =
    DefaultArrayStyle(Val(max(M, N)))
# ... while 0-dimensional participants (e.g. scalars) keep the hybrid style.
BroadcastStyle(::HybridArrayStyle{M}, ::DefaultArrayStyle{0}) where {M} =
    HybridArrayStyle{M}()
# Mixed with StaticArrays, the hybrid style wins, widened to the larger dimensionality.
BroadcastStyle(::HybridArrayStyle{M}, ::StaticArrays.StaticArrayStyle{N}) where {M,N} =
    HybridArrayStyle{max(M, N)}()
BroadcastStyle(::HybridArrayStyle{M}, ::StaticArrays.StaticArrayStyle{0}) where {M} =
    HybridArrayStyle{M}()
# copy overload: out-of-place broadcast for the hybrid style.
# When every extent of the result is known statically, the unrolled `_broadcast` kernel builds
# the result; otherwise the generic broadcast is used and wrapped back into a HybridArray.
@inline function Base.copy(B::Broadcasted{HybridArrayStyle{M}}) where M
    flattened = Broadcast.flatten(B)
    sizes_of_args = StaticArrays.broadcast_sizes(flattened.args...)
    target_size = combine_sizes(sizes_of_args)
    if Length(target_size) !== Length{StaticArrays.Dynamic()}()
        return _broadcast(flattened.f, target_size, sizes_of_args, flattened.args...)
    end
    # destination dimension cannot be determined statically; fall back to generic broadcast
    generic = copy(convert(Broadcasted{DefaultArrayStyle{M}}, B))
    return HybridArray{StaticArrays.size_tuple(target_size)}(generic)
end
# copyto! overloads: both the untyped and the `AbstractArray` destination methods forward to
# `_copyto!` (the `AbstractArray` one presumably exists to be at least as specific as Base's
# `copyto!(::AbstractArray, ::Broadcasted)` — NOTE(review): confirm).
@inline Base.copyto!(dest, B::Broadcasted{<:HybridArrayStyle}) = _copyto!(dest, B)
@inline Base.copyto!(dest::AbstractArray, B::Broadcasted{<:HybridArrayStyle}) = _copyto!(dest, B)
# In-place broadcast into `dest`. The destination's own `Size` takes part in the size
# combination, so a statically sized destination enables the unrolled `_s_broadcast!` kernel;
# otherwise the generic `copyto!` path is used.
@inline function _copyto!(dest, B::Broadcasted{HybridArrayStyle{M}}) where M
    flattened = Broadcast.flatten(B)
    sizes_of_args = StaticArrays.broadcast_sizes(flattened.args...)
    target_size = combine_sizes((Size(dest), sizes_of_args...))
    if Length(target_size) !== Length{StaticArrays.Dynamic()}()
        return _s_broadcast!(flattened.f, target_size, dest, sizes_of_args, flattened.args...)
    end
    # destination dimension cannot be determined statically; fall back to generic broadcast!
    return copyto!(dest, convert(Broadcasted{DefaultArrayStyle{M}}, B))
end
# Code-generation helpers for `_s_broadcast!`: return the expression that reads argument `i`
# (of the generated function's varargs `a`) at result position `newindex`, given that
# argument's static size `oldsize`.
#
# 0-dimensional sizes and sizes with any `Dynamic()` extent are read through Base's
# `_broadcast_getindex`, which drops missing dimensions and extrudes singleton ones at run
# time. Fully static sizes get a linear index precomputed at generation time.
broadcast_getindex(::Tuple{}, i::Int, I::CartesianIndex) = :(_broadcast_getindex(a[$i], $I))
broadcast_getindex(::Tuple{Dynamic}, i::Int, I::CartesianIndex) = :(_broadcast_getindex(a[$i], $I))
function broadcast_getindex(oldsize::Tuple, i::Int, newindex::CartesianIndex)
    if any(d -> d isa Dynamic, oldsize)
        # `LinearIndices` needs integer extents, so no index can be precomputed when some
        # extent is only known at run time (e.g. `(2, Dynamic())`); resolve it at run time.
        return :(_broadcast_getindex(a[$i], $newindex))
    end
    # Resolve broadcasting (singleton extrusion, dropped trailing dims) now and emit a
    # constant linear index.
    li = LinearIndices(oldsize)
    ind = _broadcast_getindex(li, newindex)
    return :(a[$i][$ind])
end
# In-place kernel behind `_copyto!`: fully unrolls `dest[k] = f(args...)` for every linear
# index `k` of the statically sized result `newsize`, visiting positions in `CartesianIndices`
# order. `s` carries the `Size` of each argument in `a`, in the same order; the per-argument
# read expressions come from `broadcast_getindex`.
@generated function _s_broadcast!(f, ::Size{newsize}, dest::AbstractArray, s::Tuple{Vararg{Size}}, a...) where {newsize}
    argsizes = [T.parameters[1] for T in s.parameters]
    assignments = Expr[]
    for (k, cart) in enumerate(CartesianIndices(newsize))
        rhs = [broadcast_getindex(sz, n, cart) for (n, sz) in enumerate(argsizes)]
        push!(assignments, :(dest[$k] = f($(rhs...))))
    end
    body = Expr(:block, assignments...)
    return quote
        Base.@_inline_meta
        @inbounds $body
        return dest
    end
end
# Out-of-place kernel behind `Base.copy`: unrolls `f` over every position of the statically
# sized result `newsize` and builds the result from the tuple of elements with
# `similar_type`.
#
# `s` holds one `Size` per argument in `a`, in the same order. How each argument is read:
# - non-arrays: unwrapped with `scalar_getindex`;
# - arrays with a fully static size: a linear index precomputed by `broadcasted_index`
#   (singleton dimensions pinned to 1);
# - arrays with any `Dynamic()` extent: through Base's `_broadcast_getindex`, which drops
#   missing dimensions and extrudes singleton ones at run time. Indexing such an argument
#   with the full result index would read out of bounds under `@inbounds` whenever it has
#   fewer dimensions than the result or a run-time extent of 1.
@generated function _broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
    # The result container type comes from the first StaticArray argument, else from the
    # first HybridArray argument (`0` if there is neither).
    pos = findfirst(T -> T <: StaticArray, a)
    if pos === nothing
        pos = findfirst(T -> T <: HybridArray, a)
    end
    first_staticarray = pos === nothing ? 0 : a[pos]

    sizes = [sz.parameters[1] for sz ∈ s.parameters]

    # Expression reading argument `i` at result position `I::CartesianIndex`.
    function make_expr(i, I)
        if !(a[i] <: AbstractArray)
            return :(scalar_getindex(a[$i]))
        elseif hasdynamic(Tuple{sizes[i]...})
            return :(_broadcast_getindex(a[$i], $I))
        else
            return :(a[$i][$(broadcasted_index(sizes[i], Tuple(I)))])
        end
    end

    # `CartesianIndices` visits positions first-index-fastest (the element order the result
    # tuple is built in). It yields nothing when some extent is 0 and exactly one position
    # for a 0-dimensional `newsize`.
    exprs = Expr[]
    for I in CartesianIndices(newsize)
        exprs_vals = [make_expr(i, I) for i in 1:length(sizes)]
        push!(exprs, :(f($(exprs_vals...))))
    end

    return quote
        Base.@_inline_meta
        @inbounds elements = tuple($(exprs...))
        @inbounds return similar_type($first_staticarray, eltype(elements), Size(newsize))(elements)
    end
end