Add EmbeddingBag #2031

Merged · 23 commits · Apr 18, 2023
Changes from 19 commits
1 change: 1 addition & 0 deletions docs/src/models/layers.md
@@ -61,6 +61,7 @@ Parallel
Flux.Bilinear
Flux.Scale
Flux.Embedding
Flux.EmbeddingBag
```

## Normalisation & Regularisation
105 changes: 105 additions & 0 deletions src/layers/basic.jl
@@ -692,3 +692,108 @@ end
function Base.show(io::IO, m::Embedding)
print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
end


"""
_splitat(data::AbstractVector, offsets::AbstractVector{Int})

Splits a vector of data into a vector of vectors based on offsets. Each offset
specifies the next sub-vector's starting index in the `data` vector. In other words,
the `data` vector is chunked into vectors from `offsets[1]` to `offsets[2]` (not including the element at `offsets[2]`), `offsets[2]` to `offsets[3]`, etc.
The last offset specifies a bag that contains everything to the right of it.

The `offsets` vector must begin with `1` and be monotonically increasing. The last element of `offsets` must be at most `length(data)`.
"""
function _splitat(data::AbstractVector, offsets::AbstractVector{Int})
offsets[firstindex(offsets)] == 1 || throw(ArgumentError("`offsets` must begin with 1."))
offsets[end] <= length(data) || throw(ArgumentError("The last element in `offsets` must be at most the length of `data`."))
issorted(offsets, lt = <=) || throw(ArgumentError("`offsets` must be monotonically increasing with no duplicates."))
return [data[offsets[i]:(i+1 > lastindex(offsets) ? end : offsets[i+1]-1)] for i in eachindex(offsets)]
end
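
For reference, a quick REPL sketch of the splitting behaviour, assuming the definitions above (not part of the diff):

```julia
julia> Flux._splitat([10, 20, 30, 40, 50], [1, 3, 4])
3-element Vector{Vector{Int64}}:
 [10, 20]
 [30]
 [40, 50]
```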

"""
EmbeddingBag(in => out, reduction=mean; init=Flux.randn32)

A lookup table that stores embeddings of dimension `out` for a vocabulary of size
`in`. Similar to [`Embedding`](@ref), but it can take multiple inputs in a "bag", and then reduces each bag's embeddings to a single embedding according to `reduction`.
Typically, `reduction` is `mean`, `sum`, or `maximum`.

This layer is often used to store word embeddings and retrieve them using indices.
The inputs can take several forms:
- A scalar := single bag with a single item
- A vector := single bag with multiple items
- A matrix := multiple bags with multiple items (each column is a bag)
- A vector of vectors := multiple bags with multiple items (each inner vector is a bag)
- A "data" vector and an "offsets" vector := Explained below.

The `data`/`offsets` input type is similar to PyTorch's implementation. `data` should be
a vector of class indices, and `offsets` should be a vector where each element gives the starting index of a bag in the `data` vector. The first element of `offsets` must be `1`, and `offsets` must be monotonically increasing with no duplicates.

This format is useful for dealing with flattened representations of "ragged" tensors, for example a flat vector of class labels that needs to be grouped in a non-uniform way. Under the hood, however, it is just syntactic sugar for the vector-of-vectors input style.

For example, the `data`/`offsets` pair `[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]`/`[1, 5, 6, 8]`
is equivalent to the bags `[[1, 2, 3, 4], [5], [6, 7], [8, 9, 10]]`, since the first bag starts at index `1` and goes up to index `5`, non-inclusive. The next bag starts at index `5` and goes up to index `6`, non-inclusive, etc.

# Examples

```jldoctest
julia> vocab_size, embed_size = 10, 8;

julia> model = Flux.EmbeddingBag(vocab_size => embed_size)
EmbeddingBag(10 => 8) # 80 parameters

julia> model(5) |> summary # a single bag of one item
"8-element Vector{Float32}"

julia> model([1, 2, 2, 4]) |> summary # one bag with several items
"8-element Vector{Float32}"

julia> model([1 2 3; 4 5 6]) |> summary # 3 bags, each with 2 items
"8×3 Matrix{Float32}"

julia> model([[1, 2], [3], [4], [5, 6, 7]]) |> summary # 4 bags with different numbers of items
"8×4 Matrix{Float32}"

julia> data = [1, 4, 5, 2, 3];

julia> offsets = [1, 3, 4]; # 3 bags of sizes [2, 1, 2]

julia> model(data, offsets) |> summary
"8×3 Matrix{Float32}"

julia> model(Flux.OneHotVector(2, vocab_size)) |> summary # single bag with one item
"8-element Vector{Float32}"

julia> model(Flux.OneHotMatrix([2, 3, 5, 7], vocab_size)) |> summary # 4 bags, each with one item
"8×4 Matrix{Float32}"
```
"""
struct EmbeddingBag{F, W<:AbstractMatrix}
weight::W
reduction::F
end

@functor EmbeddingBag

EmbeddingBag((in, out)::Pair{<:Integer, <:Integer}, reduction::Function = mean; init = randn32) = EmbeddingBag(init(out, in), reduction)
EmbeddingBag(weight) = EmbeddingBag(weight, mean)
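
For illustration, the one-argument constructor wraps an existing (for example, pre-trained) weight matrix directly; a sketch, with an arbitrary random matrix standing in for real weights:

```julia
W = rand(Float32, 8, 10)               # embed_size × vocab_size
layer = Flux.EmbeddingBag(W)           # reduction defaults to `mean`
layer_sum = Flux.EmbeddingBag(W, sum)  # or pass a reduction explicitly
```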

function (m::EmbeddingBag)(data::AbstractVector, offsets::AbstractVector)
return m(_splitat(data, offsets))
end
(m::EmbeddingBag)(inds::AbstractArray{<:Integer}) = dropdims(m.reduction(Embedding(m.weight)(inds), dims=2), dims=2)
(m::EmbeddingBag)(ind::Integer) = error("EmbeddingBag expects an array of indices, not just one")

(m::EmbeddingBag)(hot::AbstractArray{Bool}) = dropdims(m.reduction(Embedding(m.weight)(hot), dims=2), dims=2)
(m::EmbeddingBag)(hot::AbstractVector{Bool}) = error("EmbeddingBag not defined for a one-hot vector")
Contributor (author): This seems to be too general a type restriction. For example, I could define a `MultiHot <: AbstractVector{Bool}` that succinctly encodes a bag with a fixed number of elements `k` (in fact, this was one of my original use cases for `EmbeddingBag`): if index `i` is `true`, it should be included in the bag.

Member: This is a possible encoding, and dispatching on such a type specifically is not forbidden by this method.

So far, I think every other use of one-hot arrays behaves identically if you `collect` it. This is why I think it makes sense to define these methods for `AbstractArray{Bool}`. Another boolean type with a different meaning cannot also have the property that `collect` doesn't change the result.
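
For concreteness, the `MultiHot` encoding discussed in this exchange might be sketched as follows (a hypothetical type, not part of this PR):

```julia
# Hypothetical encoding from the discussion above: a Bool vector whose
# `true` indices form the bag. Illustrative only.
struct MultiHot <: AbstractVector{Bool}
    inds::Vector{Int}   # the active (true) indices
    len::Int            # vocabulary size
end
Base.size(v::MultiHot) = (v.len,)
Base.getindex(v::MultiHot, i::Int) = i in v.inds

# Dispatching on it specifically remains possible, as the reply notes:
# (m::Flux.EmbeddingBag)(v::MultiHot) = m(v.inds)
```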


# These two could be stack(m, bags), but no AD support yet. (Gradient for weight quite inefficient here.)
(m::EmbeddingBag)(bags::AbstractVector{<:AbstractVector}) = reduce(hcat, m.(bags))
(m::EmbeddingBag)(bags::AbstractArray{<:AbstractVector}) = reshape(m(vec(bags)), :, size(bags)...)

(m::EmbeddingBag)(bags::AbstractArray{<:AbstractMatrix{Bool}}) = reshape(reduce(hcat, m.(vec(bags))), :, size(bags)...)
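
As a sketch of the `stack` form the comment above refers to (assuming Julia 1.9's `stack(f, xs)`), the results agree; per the comment, missing AD support is the only blocker:

```julia
m = Flux.EmbeddingBag(10 => 4)
bags = [[1, 2], [3, 4, 5]]
stack(m, bags) ≈ reduce(hcat, m.(bags))  # true; only the hcat form is AD-friendly here
```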

function Base.show(io::IO, m::EmbeddingBag)
print(io, "EmbeddingBag(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
end
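
Taken together, a minimal end-to-end sketch of the new layer (assuming this PR's definitions; the sizes and loss are arbitrary):

```julia
using Flux, Statistics

model = Flux.EmbeddingBag(26 => 4, mean)    # vocabulary of 26, embeddings of size 4
bags = [[1, 2, 3], [4, 5]]                  # two bags of token indices
loss(m, bs) = sum(abs2, m(bs))              # toy loss over the 4×2 output
grads = Flux.gradient(m -> loss(m, bags), model)
```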
2 changes: 1 addition & 1 deletion src/layers/show.jl
@@ -57,7 +57,7 @@ _show_children(p::Parallel) = (p.connection, p.layers...)
_show_children(f::PairwiseFusion) = (f.connection, f.layers...)

for T in [
-  :Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding,
+  :Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding, :EmbeddingBag,
:BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm,
]
@eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
75 changes: 75 additions & 0 deletions test/layers/basic.jl
@@ -311,6 +311,81 @@ import Flux: activations
@test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3]
@test_throws DimensionMismatch m(OneHotVector(3, 1000))
end

@testset "EmbeddingBag" begin

# test _splitat
data = [1, 2, 3, 4, 5, 6, 7, 8, 9]
offsets_good = [1, 3, 6]
offsets_each = [1,2,3,4,5,6,7,8,9]
offsets_just_one = [1]
offsets_all_but_last = [1, 9]

@test Flux._splitat(data, offsets_good) == [[1, 2], [3, 4, 5], [6, 7, 8, 9]]
@test Flux._splitat(data, offsets_each) == [[1], [2], [3], [4], [5], [6], [7], [8], [9]]
@test Flux._splitat(data, offsets_just_one) == [[1,2,3,4,5,6,7,8,9]]
@test Flux._splitat(data, offsets_all_but_last) == [[1,2,3,4,5,6,7,8], [9]]

offsets_non_monotonic = [1, 2, 2, 5]
offsets_non_sorted = [1, 5, 2]
offsets_non_one = [2, 3, 5]
offsets_too_large = [1, 5, 11]

@test_throws ArgumentError Flux._splitat(data, offsets_non_monotonic)
@test_throws ArgumentError Flux._splitat(data, offsets_non_sorted)
@test_throws ArgumentError Flux._splitat(data, offsets_non_one)
@test_throws ArgumentError Flux._splitat(data, offsets_too_large)

@testset for reduction in [sum, Statistics.mean, maximum]
vocab_size, embed_size = 10, 4
emb_bag = Flux.EmbeddingBag(vocab_size => embed_size, reduction)
emb = Flux.Embedding(emb_bag.weight)
@test size(emb_bag.weight) == (embed_size, vocab_size)
@test_throws ErrorException emb_bag(2)

# single bag (input as a vector)
x = rand(1:vocab_size, 3)
y = emb_bag(x)
z = vec(reduction(emb(x), dims=2))
@test y isa Vector{Float32}
@test y ≈ z

# PyTorch style `input`/`offset` bagging
@test emb_bag([1,3,2,4,5,7], [1,3,5]) ≈ emb_bag([[1,3], [2,4], [5,7]])
@test emb_bag([1,3,2,4,5,7], [1,3,5]) ≈ emb_bag([1 2 5; 3 4 7])
@test_throws ArgumentError emb_bag([1,2,3,4,5,6], [2, 4])
@test_throws ArgumentError emb_bag([1,2,3,4,5,6], [1, 12])

# docstring example
@test emb_bag([1,2,3,4,5,6,7,8,9,10], [1,5,6,8]) ≈ emb_bag([[1,2,3,4], [5], [6,7], [8,9,10]])

# multiple bags (input as a vector of vectors)
x = [rand(1:vocab_size, 3) for _ in 1:4]
y = emb_bag(x)
z = reduce(hcat, reduction.(emb.(x), dims=2))
@test y isa Matrix{Float32}
@test y ≈ z

# multiple bags (input as a matrix)
x = rand(1:vocab_size, (3, 5))
xvec = collect(eachcol(x))
y = emb_bag(x)
z = reduce(hcat, reduction.(emb.(xvec), dims=2))
@test y ≈ emb_bag(xvec)
@test y ≈ z

# a one-hot matrix is a bag, but a one-hot vector is not.
@test_throws ErrorException emb_bag(Flux.OneHotVector(3, vocab_size))

i2 = rand(1:vocab_size, 3)
x2 = Flux.OneHotMatrix(i2, vocab_size)
y2 = emb_bag(x2)
z2 = emb(i2)
@test y2 isa Vector{Float32}
@test y2 ≈ vec(reduction(z2, dims=2))
@test_throws DimensionMismatch emb_bag(Flux.OneHotMatrix(1:5, 1000))
end
end
end

@testset "second derivatives" begin