Add EmbeddingBag #2031

Merged
merged 23 commits into from
Apr 18, 2023
1 change: 1 addition & 0 deletions docs/src/models/layers.md
@@ -61,6 +61,7 @@ Parallel
Flux.Bilinear
Flux.Scale
Flux.Embedding
Flux.EmbeddingBag
```

## Normalisation & Regularisation
81 changes: 81 additions & 0 deletions src/layers/basic.jl
@@ -692,3 +692,84 @@ end
function Base.show(io::IO, m::Embedding)
print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
end

"""
EmbeddingBag(in => out, reduction=Statistics.mean; init=randn32)

A lookup table that stores embeddings of dimension `out` for a vocabulary of size
`in`. Similar to [`Embedding`](@ref) but can take multiple inputs in a "bag". The
embeddings of these are then reduced to a single embedding based on `reduction`.
Typically, `reduction` is `Statistics.mean`, `sum`, or `maximum`.

This layer is often used to store word embeddings and retrieve them using indices.
The inputs can take several forms:
- A scalar := single bag with a single item
- A vector := single bag with multiple items
- A matrix := multiple bags with multiple items (each column is a bag)
- A vector of vectors := multiple bags with multiple items (each vector is a bag)
- An input vector and offset vector := explained below

The `input`/`offsets` form is similar to PyTorch's implementation. `input` should be
a vector of class indices and `offsets` should be a vector of offsets into `input`,
one per bag. The first element of `offsets` must be `0`, and `offsets` should be
monotonically increasing, but the second condition is not checked.

For example, the `input`/`offsets` pair `[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]`/`[0, 4, 5, 7]`
is equivalent to the bags `[[1, 2, 3, 4], [5], [6, 7], [8, 9, 10]]`.

# Examples
```jldoctest
julia> vocab_size, embed_size = 1000, 4;

julia> model = Flux.EmbeddingBag(vocab_size => embed_size)
EmbeddingBag(1000 => 4) # 4_000 parameters

julia> bags = [[1, 200, 25, 789], [2, 5, 10, 999]];

julia> bags_mtx = [1 2; 200 5; 25 10; 789 999];

julia> model(bags) |> summary
"4×2 Matrix{Float32}"

julia> model(bags) ≈ model(bags_mtx)
true
```
"""
struct EmbeddingBag{F, W}
weight::W
reduction::F
end

@functor EmbeddingBag

EmbeddingBag((in, out)::Pair{<:Integer, <:Integer}, reduction::Function = Statistics.mean; init = randn32) = EmbeddingBag(init(out, in), reduction)
EmbeddingBag(weight) = EmbeddingBag(weight, Statistics.mean)

function (m::EmbeddingBag)(inputs::AbstractVector, offsets::AbstractVector)
offsets[1] == 0 || throw(ArgumentError("`offsets` must begin with 0."))
out = zeros(eltype(m.weight), size(m.weight, 1), length(offsets))
start = firstindex(inputs)
for i in eachindex(offsets[1:end-1])
out[:, i] = m(inputs[start:offsets[i+1]])
start = offsets[i+1]+1
end
out[:, end] = m(inputs[offsets[end]+1:end])
out
end
(m::EmbeddingBag)(idx::Integer) = m.weight[:, idx]
(m::EmbeddingBag)(bag::AbstractVector) = vec(m.reduction(NNlib.gather(m.weight, bag), dims=2))
(m::EmbeddingBag)(bags::AbstractVector{<:AbstractVector}) = reduce(hcat, m.(bags))
(m::EmbeddingBag)(bags::AbstractMatrix) = reduce(hcat, m.(eachcol(bags)))
Member
After reading the PyTorch docstring, it seems the main advantage of this layer is memory efficiency. So, shouldn't these methods use mapreduce instead of a broadcast to get the same benefit?

Contributor Author
Unfortunately, mapreduce(f, hcat, collection) is not optimized. But yes, I agree. I will add a todo for when specialized mapreduce functions are added. See: https://discourse.julialang.org/t/different-performance-between-reduce-map-and-mapreduce/85149 and JuliaLang/julia#31137.

julia> (m::EmbeddingBag)(bags::AbstractVector{<:AbstractVector}) = reduce(hcat, m.(bags))
julia> (m::EmbeddingBag)(bags::AbstractMatrix) = reduce(hcat, m.(eachcol(bags)))

julia> test(m::EmbeddingBag, bags::AbstractVector{<:AbstractVector})  = mapreduce(m, hcat, bags)
julia> test(m::EmbeddingBag, bags::AbstractMatrix) = mapreduce(m, hcat, eachcol(bags))
julia> e = Flux.EmbeddingBag(100=>64)
julia> bags = [[rand(1:100) for _ in 1:3] for _ in 1:1000]
julia> @btime e(bags);
  709.630 μs (14004 allocations: 2.16 MiB)

julia> @btime test(e, bags);
  14.700 ms (15935 allocations: 124.18 MiB)

Member
> Unfortunately, mapreduce(f, hcat, collection) is not optimized

If this is the hurdle, then stack(f, collection) might be the solution, assuming f returns vectors. Needs using Compat, which is certainly already loaded downstream.
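
For instance, a minimal sketch of that alternative (illustrative only, not part of this diff), assuming `stack(f, xs)` from Julia 1.9 or Compat, which applies `f` and stacks the returned vectors as columns:

```julia
# Hypothetical replacement for reduce(hcat, m.(bags)): stack applies m to each bag
# and concatenates the resulting length-`out` vectors as columns, avoiding the
# repeated hcat chain. Requires Julia 1.9 or `using Compat`.
(m::EmbeddingBag)(bags::AbstractVector{<:AbstractVector}) = stack(m, bags)
(m::EmbeddingBag)(bags::AbstractMatrix) = stack(m, eachcol(bags))
```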

Member
The really big memory cost is going to be the gradient of gather. For every column / vector, ∇gather_src is going to allocate like a copy of the weights.

https://github.com/FluxML/NNlib.jl/blob/6f74fad0a2a24e3594fc5229cc515fa25e80f877/src/gather.jl#L80

One could write a more efficient combined rule for this. Or add some thunks to the one in NNlib & wait for AD to learn to exploit them.
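
For illustration, a rough sketch of what such a combined rule might look like (hypothetical, not part of this diff; names like `bagged_mean` are made up here). The idea is to accumulate the whole pullback into a single weight-sized buffer instead of one `∇gather_src`-style copy per bag:

```julia
using ChainRulesCore, NNlib

# Forward pass: mean-reduced embedding for every bag, one column per bag.
function bagged_mean(weight::AbstractMatrix, bags::AbstractVector{<:AbstractVector})
    out = similar(weight, size(weight, 1), length(bags))
    for (i, bag) in enumerate(bags)
        out[:, i] = vec(sum(NNlib.gather(weight, bag); dims = 2)) ./ length(bag)
    end
    return out
end

function ChainRulesCore.rrule(::typeof(bagged_mean), weight, bags)
    y = bagged_mean(weight, bags)
    function bagged_mean_pullback(ȳ)
        dW = zero(weight)  # single weight-sized buffer shared by all bags
        for (i, bag) in enumerate(bags)
            for j in bag
                @views dW[:, j] .+= ȳ[:, i] ./ length(bag)
            end
        end
        return (NoTangent(), dW, NoTangent())
    end
    return y, bagged_mean_pullback
end
```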

Contributor Author
This can be done after this PR, right?

Member
Yes. I just mean these concerns will dwarf the hcat cost. (Even on the forward pass, the thing you make to call mean on it will also be much larger.)


function (m::EmbeddingBag)(x::OneHotVector{T,L}) where {T,L}
size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
return m(onecold(x))
end
function (m::EmbeddingBag)(x::OneHotMatrix{T,L}) where {T,L}
size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
return m(LinearAlgebra.Transpose(onecold(x)))
end

function Base.show(io::IO, m::EmbeddingBag)
print(io, "EmbeddingBag(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
end
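
For reference, a short usage sketch of the `input`/`offsets` form defined above, using the values from the docstring example (assumes the layer exactly as written in this diff):

```julia
using Flux, Statistics

model = Flux.EmbeddingBag(1000 => 4, Statistics.mean)

inputs  = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
offsets = [0, 4, 5, 7]        # bags: [1,2,3,4], [5], [6,7], [8,9,10]

out = model(inputs, offsets)  # 4×4 Matrix{Float32}, one column per bag
out ≈ model([[1, 2, 3, 4], [5], [6, 7], [8, 9, 10]])  # true
```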
2 changes: 1 addition & 1 deletion src/layers/show.jl
@@ -57,7 +57,7 @@ _show_children(p::Parallel) = (p.connection, p.layers...)
_show_children(f::PairwiseFusion) = (f.connection, f.layers...)

for T in [
:Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding,
:Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding, :EmbeddingBag,
:BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm,
]
@eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
57 changes: 57 additions & 0 deletions test/layers/basic.jl
@@ -311,6 +311,63 @@ import Flux: activations
@test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3]
@test_throws DimensionMismatch m(OneHotVector(3, 1000))
end

@testset "EmbeddingBag" begin
for reduction in [sum, Statistics.mean, maximum]
vocab_size, embed_size = 10, 4
emb_bag = Flux.EmbeddingBag(vocab_size => embed_size, reduction)
emb = Flux.Embedding(emb_bag.weight)
@test size(emb_bag.weight) == (embed_size, vocab_size)

# scalar bag
@test emb_bag(2) ≈ emb_bag.weight[:,2]
@test emb_bag(3) ≈ emb(3)

# single bag (input as a vector)
x = rand(1:vocab_size, 3)
y = emb_bag(x)
z = vec(reduction(emb(x), dims=2))
@test y isa Vector{Float32}
@test y ≈ z

# PyTorch style `input`/`offset` bagging
@test emb_bag([1,3,2,4,5,7], [0,2,4]) ≈ emb_bag([[1,3], [2,4], [5,7]])
@test emb_bag([1,3,2,4,5,7], [0,2,4]) ≈ emb_bag([1 2 5; 3 4 7])
@test_throws ArgumentError emb_bag([1,2,3,4,5,6], [2,4])
@test_throws BoundsError emb_bag([1,2,3,4,5,6], [0,12])

# docstring example
@test emb_bag([1,2,3,4,5,6,7,8,9,10], [0,4,5,7]) ≈ emb_bag([[1,2,3,4], [5], [6,7], [8,9,10]])

# multiple bags (input as a vector of vectors)
x = [rand(1:vocab_size, 3) for _ in 1:4]
y = emb_bag(x)
z = reduce(hcat, reduction.(emb.(x), dims=2))
@test y isa Matrix{Float32}
@test y ≈ z

# multiple bags (input as a matrix)
x = rand(1:vocab_size, (3, 5))
xvec = collect(eachcol(x))
y = emb_bag(x)
z = reduce(hcat, reduction.(emb.(xvec), dims=2))
@test y ≈ emb_bag(xvec)
@test y ≈ z

# one hot bags. should be identical to Embedding, since the bags
# are of size 1.
@test emb_bag(Flux.OneHotVector(3, vocab_size)) ≈ emb_bag.weight[:,3]
@test emb_bag(Flux.OneHotVector(4, vocab_size)) ≈ emb(Flux.OneHotVector(4, vocab_size))
@test_throws DimensionMismatch emb_bag(Flux.OneHotVector(3, 1000))

x2 = Flux.OneHotMatrix(rand(1:vocab_size, 3), vocab_size)
y2 = emb_bag(x2)
z2 = emb(x2)
@test y2 isa Matrix{Float32}
@test y2 ≈ z2
@test_throws DimensionMismatch emb_bag(Flux.OneHotMatrix(1:5, 1000))
end
end
end

@testset "second derivatives" begin