Add text embedding serving (#214)
Co-authored-by: Jonatan Kłosko <jonatanklosko@gmail.com>
coderrg and jonatanklosko authored Jun 2, 2023
1 parent b9acb4a commit 23de64b
Showing 4 changed files with 238 additions and 1 deletion.
65 changes: 65 additions & 0 deletions lib/bumblebee/text.ex
@@ -305,6 +305,71 @@ defmodule Bumblebee.Text do
defdelegate text_classification(model_info, tokenizer, opts \\ []),
to: Bumblebee.Text.TextClassification

@type text_embedding_input :: String.t()
@type text_embedding_output :: %{embedding: Nx.Tensor.t()}

  @doc """
  Builds serving for text embeddings.

  The serving accepts `t:text_embedding_input/0` and returns
  `t:text_embedding_output/0`. A list of inputs is also supported.

  ## Options

    * `:output_attribute` - the attribute of the model output map to
      retrieve. When the output is a single tensor (rather than a map),
      this option is ignored. Defaults to `:pooled_state`

    * `:output_pool` - pooling to apply on top of the model output, in case
      it is not already a pooled embedding. Supported values: `:mean_pooling`.
      By default no pooling is applied

    * `:embedding_processor` - a post-processing step to apply to the
      embedding. Supported values: `:l2_norm`. By default the output is
      returned as is

    * `:compile` - compiles all computations for predefined input shapes
      during serving initialization. Should be a keyword list with the
      following keys:

        * `:batch_size` - the maximum batch size of the input. Inputs
          are optionally padded to always match this batch size

        * `:sequence_length` - the maximum input sequence length. Input
          sequences are always padded/truncated to match that length

      It is advised to set this option in production and also configure
      a defn compiler using `:defn_options` to maximally reduce inference
      time.

    * `:defn_options` - the options for JIT compilation. Defaults to `[]`

  ## Examples

      {:ok, model_info} = Bumblebee.load_model({:hf, "intfloat/e5-large"})
      {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "intfloat/e5-large"})

      serving = Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer)

      text = "query: Cats are cute."
      Nx.Serving.run(serving, text)
      #=> %{
      #=>   embedding: #Nx.Tensor<
      #=>     f32[1024]
      #=>     EXLA.Backend<host:0, 0.124908262.1234305056.185360>
      #=>     [-0.9789889454841614, -0.9814645051956177, -0.5015208125114441, 0.9867952466011047, 0.9917466640472412, -0.5557178258895874, -0.18618212640285492, 0.797040581703186, ...]
      #=>   >
      #=> }

  """
@spec text_embedding(
Bumblebee.model_info(),
Bumblebee.Tokenizer.t(),
keyword()
) :: Nx.Serving.t()
defdelegate text_embedding(model_info, tokenizer, opts \\ []),
to: Bumblebee.Text.TextEmbedding

@type fill_mask_input :: String.t()
@type fill_mask_output :: %{predictions: list(fill_mask_prediction())}
@type fill_mask_prediction :: %{score: number(), token: String.t()}
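
The new `:compile` and `:defn_options` options are aimed at production use, where shapes are fixed up front. A minimal sketch (the batch size, sequence length, and EXLA compiler choice are illustrative assumptions, not prescribed by this commit):

    serving =
      Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer,
        compile: [batch_size: 8, sequence_length: 128],
        defn_options: [compiler: EXLA]
      )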
116 changes: 116 additions & 0 deletions lib/bumblebee/text/text_embedding.ex
@@ -0,0 +1,116 @@
defmodule Bumblebee.Text.TextEmbedding do
@moduledoc false

alias Bumblebee.Shared

def text_embedding(model_info, tokenizer, opts \\ []) do
%{model: model, params: params, spec: _spec} = model_info

opts =
Keyword.validate!(opts, [
:compile,
output_attribute: :pooled_state,
output_pool: nil,
embedding_processor: nil,
defn_options: []
])

output_attribute = opts[:output_attribute]
output_pool = opts[:output_pool]
embedding_processor = opts[:embedding_processor]
compile = opts[:compile]
defn_options = opts[:defn_options]

batch_size = compile[:batch_size]
sequence_length = compile[:sequence_length]

if compile != nil and (batch_size == nil or sequence_length == nil) do
raise ArgumentError,
"expected :compile to be a keyword list specifying :batch_size and :sequence_length, got: #{inspect(compile)}"
end

{_init_fun, encoder} = Axon.build(model)

embedding_fun = fn params, inputs ->
output = encoder.(params, inputs)

output =
if is_map(output) do
output[output_attribute]
else
output
end

output =
case output_pool do
nil ->
output

:mean_pooling ->
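            # Broadcast the {batch, seq_len} attention mask across the
            # hidden axis, zero out embeddings at padded positions, then
            # average over the sequence axis using each input's true
            # (unpadded) token count.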
input_mask_expanded = Nx.new_axis(inputs["attention_mask"], -1)

output
|> Nx.multiply(input_mask_expanded)
|> Nx.sum(axes: [1])
|> Nx.divide(Nx.sum(input_mask_expanded, axes: [1]))

other ->
raise ArgumentError,
"expected :output_pool to be one of nil or :mean_pooling, got: #{inspect(other)}"
end

output =
case embedding_processor do
nil ->
output

:l2_norm ->
Bumblebee.Utils.Nx.normalize(output)

other ->
raise ArgumentError,
"expected :embedding_processor to be one of nil or :l2_norm, got: #{inspect(other)}"
end

output
end

Nx.Serving.new(
fn defn_options ->
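        # When :compile is set, compile the computation once up front for
        # the fixed input templates below; otherwise JIT-compile lazily on
        # the first input of each new shape.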
embedding_fun =
Shared.compile_or_jit(embedding_fun, defn_options, compile != nil, fn ->
inputs = %{
"input_ids" => Nx.template({batch_size, sequence_length}, :u32),
"attention_mask" => Nx.template({batch_size, sequence_length}, :u32)
}

[params, inputs]
end)

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
embedding_fun.(params, inputs)
end
end,
defn_options
)
|> Nx.Serving.process_options(batch_size: batch_size)
|> Nx.Serving.client_preprocessing(fn input ->
{texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1)
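      # Tokenizing here (on the serving client) keeps the batched
      # computation purely numerical; with :sequence_length set, inputs
      # are padded/truncated to match the compiled template shapes.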

inputs =
Bumblebee.apply_tokenizer(tokenizer, texts,
length: sequence_length,
return_token_type_ids: false
)

{Nx.Batch.concatenate([inputs]), multi?}
end)
|> Nx.Serving.client_postprocessing(fn embeddings, _metadata, multi? ->
for embedding <- Bumblebee.Utils.Nx.batch_to_list(embeddings) do
%{embedding: embedding}
end
|> Shared.normalize_output(multi?)
end)
end
end
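
The new options compose for checkpoints that only expose per-token hidden states. A hedged sketch (the sentence-transformers checkpoint and the 384-dimension output are assumptions for illustration; this commit's tests only exercise E5):

    {:ok, model_info} = Bumblebee.load_model({:hf, "sentence-transformers/all-MiniLM-L6-v2"})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "sentence-transformers/all-MiniLM-L6-v2"})

    serving =
      Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer,
        output_attribute: :hidden_state,
        output_pool: :mean_pooling,
        embedding_processor: :l2_norm
      )

    Nx.Serving.run(serving, "Cats are cute.")
    #=> %{embedding: #Nx.Tensor<f32[384] ...>}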
7 changes: 6 additions & 1 deletion lib/bumblebee/utils/nx.ex
@@ -325,13 +325,18 @@ defmodule Bumblebee.Utils.Nx do
Nx.dot(x, [-1], batch_axes, y, [-1], batch_axes)
end

-  defnp normalize(tensor) do
+  @doc """
+  Applies L2 normalization to the last dimension of the given tensor.
+  """
+  defn normalize(tensor) do
norm =
tensor
|> Nx.pow(2)
|> Nx.sum(axes: [-1], keep_axes: true)
|> Nx.sqrt()
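    # Substituting 1.0 for a zero norm below keeps all-zero vectors
    # unchanged instead of producing NaNs from division by zero.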

norm = Nx.select(norm == 0.0, 1.0, norm)

tensor / norm
end
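
A quick sanity check of `normalize/1` and its zero-norm guard (hand-worked values; `defn` functions are callable as regular code with plain tensors):

    Bumblebee.Utils.Nx.normalize(Nx.tensor([3.0, 4.0]))
    #=> f32[2] [0.6, 0.8]

    Bumblebee.Utils.Nx.normalize(Nx.tensor([0.0, 0.0]))
    #=> f32[2] [0.0, 0.0] (the guard substitutes 1.0 for the zero norm, so no NaN)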

51 changes: 51 additions & 0 deletions test/bumblebee/text/text_embedding_test.exs
@@ -0,0 +1,51 @@
defmodule Bumblebee.Text.TextEmbeddingTest do
use ExUnit.Case, async: false

import Bumblebee.TestHelpers

@moduletag model_test_tags()

describe "integration" do
test "returns E5 embedding for a piece of text" do
{:ok, model_info} = Bumblebee.load_model({:hf, "intfloat/e5-large"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "intfloat/e5-large"})

serving = Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer)

text = "query: Cats are cute."
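      # (The "query: " prefix is deliberate: E5 checkpoints are trained
      # with "query: "/"passage: " prefixes on their inputs.)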

assert %{embedding: %Nx.Tensor{} = embedding} = Nx.Serving.run(serving, text)

assert Nx.shape(embedding) == {1024}

assert_all_close(
embedding[1..3],
Nx.tensor([-0.9815, -0.5015, 0.9868]),
atol: 1.0e-4
)
end

test "returns normalized E5 embedding for a piece of text" do
{:ok, model_info} = Bumblebee.load_model({:hf, "intfloat/e5-large"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "intfloat/e5-large"})

options = [embedding_processor: :l2_norm]

serving = Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer, options)

text = "query: Cats are cute."

assert %{embedding: %Nx.Tensor{} = embedding} = Nx.Serving.run(serving, text)

assert Nx.shape(embedding) == {1024}

assert_all_close(
embedding[1..3],
Nx.tensor([-0.0459, -0.0234, 0.0461]),
atol: 1.0e-4
)

assert_all_close(Nx.sum(Nx.pow(embedding, 2)), Nx.tensor(1.0))
end
end
end
