Add EmbeddingBag #2031

Merged · 23 commits · Apr 18, 2023
1 change: 1 addition & 0 deletions docs/src/models/layers.md
@@ -91,6 +91,7 @@ These layers accept an index, and return a vector (or several indices, and several vectors).

```@docs
Flux.Embedding
Flux.EmbeddingBag
```

## [Dataflow Layers, or Containers](@id man-dataflow-layers)
148 changes: 148 additions & 0 deletions src/layers/basic.jl
@@ -716,3 +716,151 @@ Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in))
function Base.show(io::IO, m::Embedding)
print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
end


"""
_splitat(data::AbstractVector, at::AbstractVector{<:Integer})

Partitions `data` into a vector of views.

Each index `i in at` specifies that a view starts with `data[i]`.
These indices must be strictly increasing, and start at `1`.
The resulting views do not overlap, and are never empty.
The last view always ends with `data[end]`.

### Example
```jldoctest
julia> Flux._splitat(collect('A':'Z'), [1, 3, 4, 13])
4-element Vector{SubArray{Char, 1, Vector{Char}, Tuple{UnitRange{Int64}}, true}}:
['A', 'B']
['C']
['D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L']
['M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
```
"""
function _splitat(data::AbstractVector, at::AbstractVector{<:Integer})
at[begin] == firstindex(data) || throw(ArgumentError("The first element in `at` must be 1."))
at[end] <= lastindex(data) || throw(ArgumentError("The last element in `at` must be at most the length of `data`."))
issorted(at, lt = <=) || throw(ArgumentError("`at` must be monotonically increasing with no duplicates."))
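# Append one-past-the-end as a sentinel, so view `n` runs from `at[n]` to just before the next start.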
iplus = vcat(at, lastindex(data)+1)
return [view(data, iplus[n]:(iplus[n+1]-1)) for n in eachindex(at)]
end

"""
EmbeddingBag(in => out, reduction=mean; init=Flux.randn32)

A lookup table that stores embeddings of dimension `out` for a vocabulary of size `in`.
Differs from [`Embedding`](@ref) in that, instead of acting on a single vocabulary index,
it always acts on a vector of indices, which it calls a "bag".
The individual embedding vectors are reduced to one, using `mean` or some other function.

Instead of acting on one "bag", such as `x::Vector{Int}`, the layer can also act on several:

* Acting on a vector of "bags", it produces a matrix whose columns are the reduced vectors.
More generally on `x::Array{Vector{Int}}`, its output is of size `(out, size(x)...)`.

* Any higher-rank array of integers is interpreted as a collection of "bags" each along the first dimension.
Thus the output is `mapslices(e, x; dims=1)` when `e::EmbeddingBag` and `x::Array{Int,N}`.
This method is more efficient, but requires that all "bags" have the same length.

* A vector of "bags" may also be produced by splitting a vector of indices at specified points.
For this case the layer takes two inputs, both vectors of integers. See details below.

The "bag" may equivalently be represented as a `OneHotMatrix`. A collection of these,
or one higher-rank `OneHotArray`, again produce a stack of embeddings. See details below.

# Examples
```jldoctest
julia> vocab_size = 26; # embed into 3 dimensions, with non-random vectors:

julia> eb = EmbeddingBag(vocab_size => 3, init=Flux.identity_init(gain=100))
EmbeddingBag(26 => 3) # 78 parameters

julia> eb([2]) # one bag of 1 item
3-element Vector{Float32}:
0.0
100.0
0.0

julia> eb([3,3,1]) # one bag of 3 items, one mean embedding
3-element Vector{Float32}:
33.333332
0.0
66.666664

julia> eb([[3,1,3], [2,1]]) # two bags
3×2 Matrix{Float32}:
33.3333 50.0
0.0 50.0
66.6667 0.0

julia> eb([1 1 1 1; 1 2 3 4]) # 4 bags each of 2 items, eachcol([1 1 1 1; 1 2 3 4])
3×4 Matrix{Float32}:
100.0 50.0 50.0 50.0
0.0 50.0 0.0 0.0
0.0 0.0 50.0 0.0

julia> eb(rand(1:26, 10, 5, 5)) |> size # 25 bags each of 10 items
(3, 5, 5)
```

Another way to specify "many bags of many items" is to provide a vector `data` (each in `1:in`)
and a vector `at` stating where to split that up into "bags".
The first bag starts with `data[at[1]]`, the second at `data[at[2]]`, and so on,
with no overlaps and nothing left out (thus it requires `at[1]==1`).

```jldoctest
julia> data = [11, 1, 12, 2, 13, 3, 14];

julia> Flux._splitat(data, [1, 4]) |> println # internal function, makes data[1:3], data[4:end]
[[11, 1, 12], [2, 13, 3, 14]]

julia> eb(data, [1, 4]) # two bags, of 3 and 4 items
3×2 Matrix{Float32}:
33.3333 0.0
0.0 25.0
0.0 25.0
```

Finally, each bag may also be represented as a [`OneHotMatrix`](@ref OneHotArrays.onehotbatch).

```jldoctest
julia> eb(Flux.onehotbatch("bba", 'a':'z')) # same as [2,2,1], one bag of 3 items
3-element Vector{Float32}:
33.333332
66.666664
0.0

julia> eb([Flux.onehotbatch("bba", 'a':'z'), Flux.onehotbatch("cc", 'a':'z')]) # two bags
3×2 Matrix{Float32}:
33.3333 0.0
66.6667 0.0
0.0 100.0
```
"""
struct EmbeddingBag{F, W<:AbstractMatrix}
weight::W
reduction::F
end

@functor EmbeddingBag

EmbeddingBag((in, out)::Pair{<:Integer, <:Integer}, reduction::Function = mean; init = randn32) = EmbeddingBag(init(out, in), reduction)
EmbeddingBag(weight::AbstractMatrix) = EmbeddingBag(weight, mean)

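# Dispatch: two integer vectors are data plus split points (PyTorch-style `input`/`offsets`);
# an integer array treats each slice along the first dimension as one bag;
# a lone integer is rejected, since a single index is not a bag.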
(m::EmbeddingBag)(data::AbstractVector, at::AbstractVector) = m(_splitat(data, at))
(m::EmbeddingBag)(inds::AbstractArray{<:Integer}) = dropdims(m.reduction(Embedding(m.weight)(inds), dims=2), dims=2)
(m::EmbeddingBag)(ind::Integer) = error("EmbeddingBag expects an array of indices, not just one")

(m::EmbeddingBag)(hot::AbstractArray{Bool}) = dropdims(m.reduction(Embedding(m.weight)(hot), dims=2), dims=2)
(m::EmbeddingBag)(hot::AbstractVector{Bool}) = error("EmbeddingBag not defined for a one-hot vector")
**Contributor (author):** This seems to be too general a type restriction. For example, I could define a `MultiHot <: AbstractVector{Bool}` that succinctly encodes a bag with a fixed number of elements `k` (in fact, this was one of my original use cases for `EmbeddingBag`), where index `i` is included in the bag whenever `v[i]` is `true`.

**Member:** This is a possible encoding, and dispatch on such a type specifically is not forbidden by this method.

So far, I think every other use of one-hot arrays behaves identically if you `collect` it, which is why I think it makes sense to define these methods for `AbstractArray{Bool}`. Another boolean type with a different meaning cannot also have the property that `collect` doesn't change the result.

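A sketch of that idea (the `MultiHot` type and its method are hypothetical, not part of this PR), opting in via the specific dispatch suggested above:

```julia
using Flux

# Hypothetical fixed-size bag, encoded as a boolean vector over the vocabulary.
struct MultiHot <: AbstractVector{Bool}
    inds::Vector{Int}  # vocabulary indices present in the bag
    len::Int           # vocabulary size
end
Base.size(v::MultiHot) = (v.len,)
Base.getindex(v::MultiHot, i::Int) = i in v.inds

# Specific dispatch bypasses the AbstractVector{Bool} error method below:
(m::Flux.EmbeddingBag)(v::MultiHot) = m(v.inds)

eb = Flux.EmbeddingBag(26 => 3)
eb(MultiHot([2, 3], 26)) ≈ eb([2, 3])  # true

# The collect-invariance the member describes, for a one-hot *matrix*:
hot = Flux.onehotbatch("bba", 'a':'z')
eb(hot) ≈ eb(collect(hot))             # true
```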

# These two could be stack(m, bags), but no AD support yet. (Gradient for weight quite inefficient here.)
(m::EmbeddingBag)(bags::AbstractVector{<:AbstractVector}) = reduce(hcat, m.(bags))
(m::EmbeddingBag)(bags::AbstractArray{<:AbstractVector}) = reshape(m(vec(bags)), :, size(bags)...)

(m::EmbeddingBag)(bags::AbstractArray{<:AbstractMatrix{Bool}}) = reshape(reduce(hcat, m.(vec(bags))), :, size(bags)...)

function Base.show(io::IO, m::EmbeddingBag)
print(io, "EmbeddingBag(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
end
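For orientation, a minimal usage sketch (not part of the diff; sizes and data are made up), composing the new layer with a `Dense` head:

```julia
using Flux, Statistics

model = Chain(
    EmbeddingBag(10_000 => 64, mean),  # 10k-token vocabulary → 64-dim mean-pooled bag
    Dense(64 => 2),                    # toy two-class head
    softmax,
)

docs = [rand(1:10_000, 20) for _ in 1:32]  # 32 "documents" of 20 token ids each
size(model(docs))                          # (2, 32): one column of probabilities per bag
```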
2 changes: 1 addition & 1 deletion src/layers/show.jl
@@ -59,7 +59,7 @@ _show_children(p::Parallel) = (p.connection, p.layers...)
_show_children(f::PairwiseFusion) = (f.connection, f.layers...)

for T in [
-  :Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding,
+  :Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding, :EmbeddingBag,
:BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm,
]
@eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
75 changes: 75 additions & 0 deletions test/layers/basic.jl
@@ -338,6 +338,81 @@ import Flux: activations
y3 = m(x3)
@test size(y3) == (embed_size, 3, 4)
end

@testset "EmbeddingBag" begin

# test _splitat
data = [1, 2, 3, 4, 5, 6, 7, 8, 9]
offsets_good = [1, 3, 6]
offsets_each = [1,2,3,4,5,6,7,8,9]
offsets_just_one = [1]
offsets_all_but_last = [1, 9]

@test Flux._splitat(data, offsets_good) == [[1, 2], [3, 4, 5], [6, 7, 8, 9]]
@test Flux._splitat(data, offsets_each) == [[1], [2], [3], [4], [5], [6], [7], [8], [9]]
@test Flux._splitat(data, offsets_just_one) == [[1,2,3,4,5,6,7,8,9]]
@test Flux._splitat(data, offsets_all_but_last) == [[1,2,3,4,5,6,7,8], [9]]

offsets_non_monotonic = [1, 2, 2, 5]
offsets_non_sorted = [1, 5, 2]
offsets_non_one = [2, 3, 5]
offsets_too_large = [1, 5, 11]

@test_throws ArgumentError Flux._splitat(data, offsets_non_monotonic)
@test_throws ArgumentError Flux._splitat(data, offsets_non_sorted)
@test_throws ArgumentError Flux._splitat(data, offsets_non_one)
@test_throws ArgumentError Flux._splitat(data, offsets_too_large)

@testset for reduction in [sum, Statistics.mean, maximum]
vocab_size, embed_size = 10, 4
emb_bag = Flux.EmbeddingBag(vocab_size => embed_size, reduction)
emb = Flux.Embedding(emb_bag.weight)
@test size(emb_bag.weight) == (embed_size, vocab_size)
@test_throws ErrorException emb_bag(2)

# single bag (input as a vector)
x = rand(1:vocab_size, 3)
y = emb_bag(x)
z = vec(reduction(emb(x), dims=2))
@test y isa Vector{Float32}
@test y ≈ z

# PyTorch style `input`/`offset` bagging
@test emb_bag([1,3,2,4,5,7], [1,3,5]) ≈ emb_bag([[1,3], [2,4], [5,7]])
@test emb_bag([1,3,2,4,5,7], [1,3,5]) ≈ emb_bag([1 2 5; 3 4 7])
@test_throws ArgumentError emb_bag([1,2,3,4,5,6], [2, 4])
@test_throws ArgumentError emb_bag([1,2,3,4,5,6], [1, 12])

# docstring example
@test emb_bag([1,2,3,4,5,6,7,8,9,10], [1,5,6,8]) ≈ emb_bag([[1,2,3,4], [5], [6,7], [8,9,10]])

# multiple bags (input as a vector of vectors)
x = [rand(1:vocab_size, 3) for _ in 1:4]
y = emb_bag(x)
z = reduce(hcat, reduction.(emb.(x), dims=2))
@test y isa Matrix{Float32}
@test y ≈ z

# multiple bags (input as a matrix)
x = rand(1:vocab_size, (3, 5))
xvec = collect(eachcol(x))
y = emb_bag(x)
z = reduce(hcat, reduction.(emb.(xvec), dims=2))
@test y ≈ emb_bag(xvec)
@test y ≈ z

# a one-hot matrix is a bag, but a one-hot vector is not.
@test_throws ErrorException emb_bag(Flux.OneHotVector(3, vocab_size))

i2 = rand(1:vocab_size, 3)
x2 = Flux.OneHotMatrix(i2, vocab_size)
y2 = emb_bag(x2)
z2 = emb(i2)
@test y2 isa Vector{Float32}
@test y2 ≈ vec(reduction(z2, dims=2))
@test_throws DimensionMismatch emb_bag(Flux.OneHotMatrix(1:5, 1000))
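
# Sketch (not in the original PR): a gradient smoke test, assuming Zygote can
# differentiate the vector-of-bags method through `reduce(hcat, ...)`.
g = Flux.gradient(m -> sum(m([[1, 2], [3]])), emb_bag)[1]
@test size(g.weight) == size(emb_bag.weight)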
end
end
end

@testset "second derivatives" begin