-
-
Notifications
You must be signed in to change notification settings - Fork 122
/
gather.jl
87 lines (67 loc) · 2.58 KB
/
gather.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""
    NNlib.gather!(dst, src, idx)

In-place counterpart of [`gather`](@ref) and the reverse of [`scatter!`](@ref):
copy slices of `src` into `dst`, choosing which slice goes where with the index
array `idx`.

For every `k` in `CartesianIndices(idx)` this performs

    dst[:, ... , k] .= src[:, ... , idx[k]...]

that is, `k` addresses the trailing dimensions of `dst`, while the entry `idx[k]`
(an integer or a tuple of integers) addresses the trailing dimensions of `src`.
When `idx` is an integer vector and both `dst` and `src` are matrices, this
reduces to

    dst[:, k] .= src[:, idx[k]]

with `k` running over `1:length(idx)`.

Entries of `idx` may repeat, so a given `src` column can be copied into zero,
one, or several `dst` columns.

Returns `dst`. See [`gather`](@ref) for the allocating version.
"""
function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
    # Number of leading dimensions copied whole (as `:`); computed by
    # `scatter_dims`, which is defined elsewhere in NNlib.
    nlead = scatter_dims(src, dst, idx)
    whole = ntuple(_ -> Colon(), nlead)
    for k in CartesianIndices(idx)
        # `k` selects the destination slice, `idx[k]` the source slice.
        target = _view(dst, whole, k)
        target .= _view(src, whole, idx[k])
    end
    return dst
end
"""
    NNlib.gather(src, idx) -> dst

Allocating version of [`gather!`](@ref) and the reverse of [`scatter`](@ref):
build a new array `dst` by collecting slices of `src` selected by the index
array `idx`.

For every `k` in `CartesianIndices(idx)`,

    dst[:, ... , k] .= src[:, ... , idx[k]...]

The result has element type `eltype(src)`. Its size is the leading dimensions of
`src` that are not addressed by the entries of `idx`, followed by `size(idx)`.
When `idx` is an integer vector and `src` is a matrix, this reduces to

    dst[:, k] .= src[:, idx[k]]

for `k` in `1:length(idx)`.

Entries of `idx` (integers or integer tuples) may be repeated, so a single `src`
column can end up copied into zero, one, or multiple `dst` columns.

See [`gather!`](@ref) for the in-place version.

# Examples
```jldoctest
julia> NNlib.gather([1,20,300,4000], [2,4,2])
3-element Vector{Int64}:
   20
 4000
   20

julia> NNlib.gather([1 2 3; 4 5 6], [1,3,1,3,1])
2×5 Matrix{Int64}:
 1  3  1  3  1
 4  6  4  6  4
```
"""
function gather(src::AbstractArray{Tsrc, Nsrc},
                idx::AbstractArray{Tidx, Nidx}) where {Tsrc, Nsrc, Nidx, Tidx}
    # Number of trailing `src` dimensions one entry of `idx` addresses
    # (`typelength` is defined elsewhere; presumably 1 for integers and the
    # tuple length for tuple entries).
    nsel = typelength(Tidx)
    # Keep the leading `src` dimensions whole, then append the shape of `idx`.
    outsize = (size(src)[1:(Nsrc - nsel)]..., size(idx)...)
    return gather!(similar(src, Tsrc, outsize), src, idx)
end
# Gradient of a gather with respect to `src`: scatter the cotangent `Δ` back with
# `+` into a zero-filled array of size `src_size` (same array type and eltype as
# `Δ`), so contributions from repeated indices are summed.
function ∇gather_src(Δ, src_size, idx)
    dsrc = similar(Δ, eltype(Δ), src_size)
    fill!(dsrc, 0)
    return scatter!(+, dsrc, Δ, idx)
end
# Reverse-mode rule for the in-place `gather!`. The pullback returns one tangent
# per argument of the call `gather!(dst, src, idx)`, plus one for the function:
# only `src` gets a gradient; the function, `dst` and `idx` get `NoTangent()`.
function rrule(::typeof(gather!), dst::AbstractArray, src::AbstractArray,
               idx::AbstractArray)
    out = gather!(dst, src, idx)
    insize = size(src)
    function gather!_pullback(Δ)
        dsrc = ∇gather_src(unthunk(Δ), insize, idx)
        return (NoTangent(), NoTangent(), dsrc, NoTangent())
    end
    return out, gather!_pullback
end