From 5834f64224c427034d7472ee16a9f250711cf609 Mon Sep 17 00:00:00 2001
From: Marco
Date: Mon, 17 Apr 2023 21:06:26 -0500
Subject: [PATCH] Add `EmbeddingBag` (#2031)

* embedding bag

* doc fix

* Apply suggestions from code review

Co-authored-by: Carlo Lucibello

* Remove references to `Statistics`

Statistics is imported by Flux so we can just call `mean` rather than `Statistics.mean`.

* non mutating bag and onehot changes

* better docs and todo

* input/offset docs

* doctest

* Apply suggestions from code review

Co-authored-by: Kyle Daruwalla
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>

* reduce docs

* broadcast to map

* remove extra doc example line

* add _splitat

* rename input/offset

* minor docs

* Apply suggestions from code review

* Update test/layers/basic.jl

* Update test/layers/basic.jl

* Update test/layers/basic.jl

* typo

* docstring

* Apply suggestions from code review

---------

Co-authored-by: Carlo Lucibello
Co-authored-by: Kyle Daruwalla
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
---
 docs/src/models/layers.md |   1 +
 src/layers/basic.jl       | 148 ++++++++++++++++++++++++++++++++++++++
 src/layers/show.jl        |   2 +-
 test/layers/basic.jl      |  75 +++++++++++++++++++
 4 files changed, 225 insertions(+), 1 deletion(-)

diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md
index 6f5f3978cb..2799dd3bf0 100644
--- a/docs/src/models/layers.md
+++ b/docs/src/models/layers.md
@@ -91,6 +91,7 @@ These layers accept an index, and return a vector (or several indices, and sever
 
 ```@docs
 Flux.Embedding
+Flux.EmbeddingBag
 ```
 
 ## [Dataflow Layers, or Containers](@id man-dataflow-layers)
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 17c77daa08..9f54d9c344 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -716,3 +716,151 @@ Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(ini
 function Base.show(io::IO, m::Embedding)
   print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
 end
+
+
+"""
+    _splitat(data::AbstractVector, at::AbstractVector{<:Integer})
+
+Partitions `data` into a vector of views.
+
+Each index `i in at` specifies that a view starts with `data[i]`.
+These indices must be strictly increasing, and start at `1`.
+The resulting views do not overlap, and are never empty.
+The last view always ends with `data[end]`.
+
+### Example
+```jldoctest
+julia> Flux._splitat(collect('A':'Z'), [1, 3, 4, 13])
+4-element Vector{SubArray{Char, 1, Vector{Char}, Tuple{UnitRange{Int64}}, true}}:
+ ['A', 'B']
+ ['C']
+ ['D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L']
+ ['M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
+```
+"""
+function _splitat(data::AbstractVector, at::AbstractVector{<:Integer})
+  at[begin] == firstindex(data) || throw(ArgumentError("The first element in `at` must be 1."))
+  at[end] <= lastindex(data) || throw(ArgumentError("The last element in `at` must be at most the length of `data`."))
+  issorted(at, lt = <=) || throw(ArgumentError("`at` must be monotonically increasing with no duplicates."))
+  iplus = vcat(at, lastindex(data)+1)
+  return [view(data, iplus[n]:(iplus[n+1]-1)) for n in eachindex(at)]
+end
+
+"""
+    EmbeddingBag(in => out, reduction=mean; init=Flux.randn32)
+
+A lookup table that stores embeddings of dimension `out` for a vocabulary of size `in`.
+Differs from [`Embedding`](@ref) in that, instead of acting on a single vocabulary index,
+it always acts on a vector of indices, which it calls a "bag".
+Their individual embedding vectors are reduced to one, using `mean` or some other function.
+
+Instead of acting on one "bag", such as `x::Vector{Int}`, the layer can also act on several:
+
+* Acting on a vector of "bags", it produces a matrix whose columns are the reduced vectors.
+  More generally on `x::Array{Vector{Int}}`, its output is of size `(out, size(x)...)`.
+
+* Any higher-rank array of integers is interpreted as a collection of "bags", each along the first dimension.
+  Thus the output is `mapslices(e, x; dims=1)` when `e::EmbeddingBag` and `x::Array{Int,N}`.
+  This method is more efficient, but requires that all "bags" have the same length.
+
+* A vector of "bags" may also be produced by splitting a vector of indices at specified points.
+  For this case the layer takes two inputs, both vectors of integers. See details below.
+
+The "bag" may equivalently be represented as a `OneHotMatrix`. A collection of these,
+or one higher-rank `OneHotArray`, again produces a stack of embeddings. See details below.
+
+# Examples
+```jldoctest ebag
+julia> vocab_size = 26;  # embed into 3 dimensions, with non-random vectors:
+
+julia> eb = EmbeddingBag(vocab_size => 3, init=Flux.identity_init(gain=100))
+EmbeddingBag(26 => 3)  # 78 parameters
+
+julia> eb([2])  # one bag of 1 item
+3-element Vector{Float32}:
+   0.0
+ 100.0
+   0.0
+
+julia> eb([3,3,1])  # one bag of 3 items, one mean embedding
+3-element Vector{Float32}:
+ 33.333332
+  0.0
+ 66.666664
+
+julia> eb([[3,1,3], [2,1]])  # two bags
+3×2 Matrix{Float32}:
+ 33.3333  50.0
+  0.0     50.0
+ 66.6667   0.0
+
+julia> eb([1 1 1 1; 1 2 3 4])  # 4 bags, each of 2 items: eachcol([1 1 1 1; 1 2 3 4])
+3×4 Matrix{Float32}:
+ 100.0  50.0  50.0  50.0
+   0.0  50.0   0.0   0.0
+   0.0   0.0  50.0   0.0
+
+julia> eb(rand(1:26, 10, 5, 5)) |> size  # 25 bags, each of 10 items
+(3, 5, 5)
+```
+
+Another way to specify "many bags of many items" is to provide a vector `data` (each entry in `1:in`)
+and a vector `at` stating where to split that up into "bags".
+The first bag starts with `data[at[1]]`, the second with `data[at[2]]`, and so on,
+with no overlaps and nothing left out (thus it requires `at[1]==1`).
+
+```jldoctest ebag
+julia> data = [11, 1, 12, 2, 13, 3, 14];
+
+julia> Flux._splitat(data, [1, 4]) |> println  # internal function, makes data[1:3], data[4:end]
+[[11, 1, 12], [2, 13, 3, 14]]
+
+julia> eb(data, [1, 4])  # two bags, of 3 and 4 items
+3×2 Matrix{Float32}:
+ 33.3333   0.0
+  0.0     25.0
+  0.0     25.0
+```
+
+Finally, each bag may also be represented as a [`OneHotMatrix`](@ref OneHotArrays.onehotbatch).
+
+```jldoctest ebag
+julia> eb(Flux.onehotbatch("bba", 'a':'z'))  # same as [2,2,1], one bag of 3 items
+3-element Vector{Float32}:
+ 33.333332
+ 66.666664
+  0.0
+
+julia> eb([Flux.onehotbatch("bba", 'a':'z'), Flux.onehotbatch("cc", 'a':'z')])  # two bags
+3×2 Matrix{Float32}:
+ 33.3333    0.0
+ 66.6667    0.0
+  0.0     100.0
+```
+"""
+struct EmbeddingBag{F, W<:AbstractMatrix}
+  weight::W
+  reduction::F
+end
+
+@functor EmbeddingBag
+
+EmbeddingBag((in, out)::Pair{<:Integer, <:Integer}, reduction::Function = mean; init = randn32) = EmbeddingBag(init(out, in), reduction)
+EmbeddingBag(weight::AbstractMatrix) = EmbeddingBag(weight, mean)
+
+(m::EmbeddingBag)(data::AbstractVector, at::AbstractVector) = m(_splitat(data, at))
+(m::EmbeddingBag)(inds::AbstractArray{<:Integer}) = dropdims(m.reduction(Embedding(m.weight)(inds), dims=2), dims=2)
+(m::EmbeddingBag)(ind::Integer) = error("EmbeddingBag expects an array of indices, not just one")
+
+(m::EmbeddingBag)(hot::AbstractArray{Bool}) = dropdims(m.reduction(Embedding(m.weight)(hot), dims=2), dims=2)
+(m::EmbeddingBag)(hot::AbstractVector{Bool}) = error("EmbeddingBag not defined for a one-hot vector")
+
+# These two could be `stack(m, bags)`, but there is no AD support for that yet.
+# (The gradient for `weight` is quite inefficient here.)
+(m::EmbeddingBag)(bags::AbstractVector{<:AbstractVector}) = reduce(hcat, m.(bags))
+(m::EmbeddingBag)(bags::AbstractArray{<:AbstractVector}) = reshape(m(vec(bags)), :, size(bags)...)
+
+(m::EmbeddingBag)(bags::AbstractArray{<:AbstractMatrix{Bool}}) = reshape(reduce(hcat, m.(vec(bags))), :, size(bags)...)
+
+function Base.show(io::IO, m::EmbeddingBag)
+  print(io, "EmbeddingBag(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
+end
diff --git a/src/layers/show.jl b/src/layers/show.jl
index aa9ccaf86f..0ae14dd9ee 100644
--- a/src/layers/show.jl
+++ b/src/layers/show.jl
@@ -59,7 +59,7 @@ _show_children(p::Parallel) = (p.connection, p.layers...)
 _show_children(f::PairwiseFusion) = (f.connection, f.layers...)
 
 for T in [
-  :Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding,
+  :Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding, :EmbeddingBag,
   :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm,
 ]
 @eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index 45e4750a6f..5215f59aca 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -338,6 +338,81 @@ import Flux: activations
     y3 = m(x3)
     @test size(y3) == (embed_size, 3, 4)
   end
+
+  @testset "EmbeddingBag" begin
+
+    # test _splitat
+    data = [1, 2, 3, 4, 5, 6, 7, 8, 9]
+    offsets_good = [1, 3, 6]
+    offsets_each = [1, 2, 3, 4, 5, 6, 7, 8, 9]
+    offsets_just_one = [1]
+    offsets_all_but_last = [1, 9]
+
+    @test Flux._splitat(data, offsets_good) == [[1, 2], [3, 4, 5], [6, 7, 8, 9]]
+    @test Flux._splitat(data, offsets_each) == [[1], [2], [3], [4], [5], [6], [7], [8], [9]]
+    @test Flux._splitat(data, offsets_just_one) == [[1, 2, 3, 4, 5, 6, 7, 8, 9]]
+    @test Flux._splitat(data, offsets_all_but_last) == [[1, 2, 3, 4, 5, 6, 7, 8], [9]]
+
+    offsets_non_monotonic = [1, 2, 2, 5]
+    offsets_non_sorted = [1, 5, 2]
+    offsets_non_one = [2, 3, 5]
+    offsets_too_large = [1, 5, 11]
+
+    @test_throws ArgumentError Flux._splitat(data, offsets_non_monotonic)
+    @test_throws ArgumentError Flux._splitat(data, offsets_non_sorted)
+    @test_throws ArgumentError Flux._splitat(data, offsets_non_one)
+    @test_throws ArgumentError Flux._splitat(data, offsets_too_large)
+
+    @testset for reduction in [sum, Statistics.mean, maximum]
+      vocab_size, embed_size = 10, 4
+      emb_bag = Flux.EmbeddingBag(vocab_size => embed_size, reduction)
+      emb = Flux.Embedding(emb_bag.weight)
+      @test size(emb_bag.weight) == (embed_size, vocab_size)
+      @test_throws ErrorException emb_bag(2)
+
+      # single bag (input as a vector)
+      x = rand(1:vocab_size, 3)
+      y = emb_bag(x)
+      z = vec(reduction(emb(x), dims=2))
+      @test y isa Vector{Float32}
+      @test y ≈ z
+
+      # PyTorch-style `input`/`offset` bagging
+      @test emb_bag([1,3,2,4,5,7], [1,3,5]) ≈ emb_bag([[1,3], [2,4], [5,7]])
+      @test emb_bag([1,3,2,4,5,7], [1,3,5]) ≈ emb_bag([1 2 5; 3 4 7])
+      @test_throws ArgumentError emb_bag([1,2,3,4,5,6], [2, 4])
+      @test_throws ArgumentError emb_bag([1,2,3,4,5,6], [1, 12])
+
+      # docstring example
+      @test emb_bag([1,2,3,4,5,6,7,8,9,10], [1,5,6,8]) ≈ emb_bag([[1,2,3,4], [5], [6,7], [8,9,10]])
+
+      # multiple bags (input as a vector of vectors)
+      x = [rand(1:vocab_size, 3) for _ in 1:4]
+      y = emb_bag(x)
+      z = reduce(hcat, reduction.(emb.(x), dims=2))
+      @test y isa Matrix{Float32}
+      @test y ≈ z
+
+      # multiple bags (input as a matrix)
+      x = rand(1:vocab_size, (3, 5))
+      xvec = collect(eachcol(x))
+      y = emb_bag(x)
+      z = reduce(hcat, reduction.(emb.(xvec), dims=2))
+      @test y ≈ emb_bag(xvec)
+      @test y ≈ z
+
+      # a one-hot matrix is a bag, but a one-hot vector is not
+      @test_throws ErrorException emb_bag(Flux.OneHotVector(3, vocab_size))
+
+      i2 = rand(1:vocab_size, 3)
+      x2 = Flux.OneHotMatrix(i2, vocab_size)
+      y2 = emb_bag(x2)
+      z2 = emb(i2)
+      @test y2 isa Vector{Float32}
+      @test y2 ≈ vec(reduction(z2, dims=2))
+      @test_throws DimensionMismatch emb_bag(Flux.OneHotMatrix(1:5, 1000))
+    end
+  end
 end
 
 @testset "second derivatives" begin
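
As a sanity check of how the new layer composes with the rest of Flux, here is a minimal training sketch. It is not part of the patch: the model, data, and hyperparameters are invented for illustration, and it assumes only APIs that already exist in Flux at this point (`Chain`, `Dense`, `Adam`, `Flux.onehotbatch`, `Flux.logitcrossentropy`, `Flux.setup`, `Flux.update!`). Gradients flow through the `reduce(hcat, m.(bags))` method above, so a vector of bags can be trained on directly.

```julia
using Flux

vocab_size, embed_size, nclasses = 100, 8, 2

# EmbeddingBag reduces each bag of token indices to one embedding vector
# (the default reduction is `mean`), so each "document" becomes one column.
model = Chain(EmbeddingBag(vocab_size => embed_size),
              Dense(embed_size => nclasses))

bags = [rand(1:vocab_size, 5) for _ in 1:16]   # 16 bags of 5 tokens each
labels = Flux.onehotbatch(rand(1:nclasses, 16), 1:nclasses)

# A vector of bags produces a matrix with one column per bag,
# so the usual batched loss applies directly.
loss(m, x, y) = Flux.logitcrossentropy(m(x), y)

# One explicit-gradient training step:
grads = Flux.gradient(m -> loss(m, bags, labels), model)
opt_state = Flux.setup(Adam(1e-2), model)
Flux.update!(opt_state, model, grads[1])
```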