diff --git a/lib/bumblebee/text.ex b/lib/bumblebee/text.ex index 9afdd102..368a70db 100644 --- a/lib/bumblebee/text.ex +++ b/lib/bumblebee/text.ex @@ -305,6 +305,71 @@ defmodule Bumblebee.Text do defdelegate text_classification(model_info, tokenizer, opts \\ []), to: Bumblebee.Text.TextClassification + @type text_embedding_input :: String.t() + @type text_embedding_output :: %{embedding: Nx.Tensor.t()} + + @doc """ + Builds serving for text embeddings. + + The serving accepts `t:text_embedding_input/0` and returns + `t:text_embedding_output/0`. A list of inputs is also supported. + + ## Options + + * `:output_attribute` - the attribute of the model output map to + retrieve. When the output is a single tensor (rather than a map), + this option is ignored. Defaults to `:pooled_state` + + * `:output_pool` - pooling to apply on top of the model output, in case + it is not already a pooled embedding. Supported values: `:mean_pooling`. By + default no pooling is applied + + * `:embedding_processor` - a post-processing step to apply to the + embedding. Supported values: `:l2_norm`. By default the output is + returned as is + + * `:compile` - compiles all computations for predefined input shapes + during serving initialization. Should be a keyword list with the + following keys: + + * `:batch_size` - the maximum batch size of the input. Inputs + are optionally padded to always match this batch size + + * `:sequence_length` - the maximum input sequence length. Input + sequences are always padded/truncated to match that length + + It is advised to set this option in production and also configure + a defn compiler using `:defn_options` to maximally reduce inference + time. + + * `:defn_options` - the options for JIT compilation. 
Defaults to `[]` + + ## Examples + + {:ok, model_info} = Bumblebee.load_model({:hf, "intfloat/e5-large"}) + {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "intfloat/e5-large"}) + + serving = Bumblebee.Text.text_embedding(model_info, tokenizer) + + text = "query: Cats are cute." + Nx.Serving.run(serving, text) + + #=> %{ + #=> embedding: #Nx.Tensor< + #=> f32[1024] + #=> EXLA.Backend + #=> [-0.9789889454841614, -0.9814645051956177, -0.5015208125114441, 0.9867952466011047, 0.9917466640472412, -0.5557178258895874, -0.18618212640285492, 0.797040581703186, 0.8922086954116821, 0.7599573135375977, -0.16524426639080048, -0.8740050792694092, 0.9433475732803345, 0.7217797636985779, 0.9437620639801025, 0.4694959223270416, 0.40594056248664856, -0.20143413543701172, 0.7144518494606018, -0.8689796924591064, 0.94001305103302, 0.17163503170013428, -0.9896315932273865, 0.4455447494983673, 0.41139301657676697, 0.01911175064742565, -0.11275406181812286, -0.734498143196106, -0.6410953402519226, -0.628239095211029, -0.2570168673992157, 0.475137323141098, -0.7534396052360535, -0.9492156505584717, -0.17271563410758972, 0.9081271886825562, -0.4851466119289398, -0.9440935254096985, -0.20976334810256958, -0.684502899646759, -0.11581139266490936, 0.17509342730045319, 0.05547652021050453, 0.31042391061782837, 0.955132007598877, -0.35595986247062683, 0.016105204820632935, -0.3154579997062683, 0.9630348682403564, ...] 
+ #=> > + #=> } + """ + @spec text_embedding( + Bumblebee.model_info(), + Bumblebee.Tokenizer.t(), + keyword() + ) :: Nx.Serving.t() + defdelegate text_embedding(model_info, tokenizer, opts \\ []), + to: Bumblebee.Text.TextEmbedding + @type fill_mask_input :: String.t() @type fill_mask_output :: %{predictions: list(fill_mask_prediction())} @type fill_mask_prediction :: %{score: number(), token: String.t()} diff --git a/lib/bumblebee/text/text_embedding.ex b/lib/bumblebee/text/text_embedding.ex new file mode 100644 index 00000000..1dd97073 --- /dev/null +++ b/lib/bumblebee/text/text_embedding.ex @@ -0,0 +1,116 @@ +defmodule Bumblebee.Text.TextEmbedding do + @moduledoc false + + alias Bumblebee.Shared + + def text_embedding(model_info, tokenizer, opts \\ []) do + %{model: model, params: params, spec: _spec} = model_info + + opts = + Keyword.validate!(opts, [ + :compile, + output_attribute: :pooled_state, + output_pool: nil, + embedding_processor: nil, + defn_options: [] + ]) + + output_attribute = opts[:output_attribute] + output_pool = opts[:output_pool] + embedding_processor = opts[:embedding_processor] + compile = opts[:compile] + defn_options = opts[:defn_options] + + batch_size = compile[:batch_size] + sequence_length = compile[:sequence_length] + + if compile != nil and (batch_size == nil or sequence_length == nil) do + raise ArgumentError, + "expected :compile to be a keyword list specifying :batch_size and :sequence_length, got: #{inspect(compile)}" + end + + {_init_fun, encoder} = Axon.build(model) + + embedding_fun = fn params, inputs -> + output = encoder.(params, inputs) + + output = + if is_map(output) do + output[output_attribute] + else + output + end + + output = + case output_pool do + nil -> + output + + :mean_pooling -> + input_mask_expanded = Nx.new_axis(inputs["attention_mask"], -1) + + output + |> Nx.multiply(input_mask_expanded) + |> Nx.sum(axes: [1]) + |> Nx.divide(Nx.sum(input_mask_expanded, axes: [1])) + + other -> + raise ArgumentError, + 
"expected :output_pool to be one of nil or :mean_pooling, got: #{inspect(other)}" + end + + output = + case embedding_processor do + nil -> + output + + :l2_norm -> + Bumblebee.Utils.Nx.normalize(output) + + other -> + raise ArgumentError, + "expected :embedding_processor to be one of nil or :l2_norm, got: #{inspect(other)}" + end + + output + end + + Nx.Serving.new( + fn defn_options -> + embedding_fun = + Shared.compile_or_jit(embedding_fun, defn_options, compile != nil, fn -> + inputs = %{ + "input_ids" => Nx.template({batch_size, sequence_length}, :u32), + "attention_mask" => Nx.template({batch_size, sequence_length}, :u32) + } + + [params, inputs] + end) + + fn inputs -> + inputs = Shared.maybe_pad(inputs, batch_size) + embedding_fun.(params, inputs) + end + end, + defn_options + ) + |> Nx.Serving.process_options(batch_size: batch_size) + |> Nx.Serving.client_preprocessing(fn input -> + {texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1) + + inputs = + Bumblebee.apply_tokenizer(tokenizer, texts, + length: sequence_length, + return_token_type_ids: false + ) + + {Nx.Batch.concatenate([inputs]), multi?} + end) + |> Nx.Serving.client_postprocessing(fn embeddings, _metadata, multi? -> + for embedding <- Bumblebee.Utils.Nx.batch_to_list(embeddings) do + %{embedding: embedding} + end + |> Shared.normalize_output(multi?) + end) + end +end diff --git a/lib/bumblebee/utils/nx.ex b/lib/bumblebee/utils/nx.ex index a9fa1882..44da7bdd 100644 --- a/lib/bumblebee/utils/nx.ex +++ b/lib/bumblebee/utils/nx.ex @@ -325,13 +325,18 @@ defmodule Bumblebee.Utils.Nx do Nx.dot(x, [-1], batch_axes, y, [-1], batch_axes) end - defnp normalize(tensor) do + @doc """ + Applies L2 normalization to the last dimension of the given tensor. 
+ """ + defn normalize(tensor) do norm = tensor |> Nx.pow(2) |> Nx.sum(axes: [-1], keep_axes: true) |> Nx.sqrt() + norm = Nx.select(norm == 0.0, 1.0, norm) + tensor / norm end diff --git a/test/bumblebee/text/text_embedding_test.exs b/test/bumblebee/text/text_embedding_test.exs new file mode 100644 index 00000000..c71c1768 --- /dev/null +++ b/test/bumblebee/text/text_embedding_test.exs @@ -0,0 +1,51 @@ +defmodule Bumblebee.Text.TextEmbeddingTest do + use ExUnit.Case, async: false + + import Bumblebee.TestHelpers + + @moduletag model_test_tags() + + describe "integration" do + test "returns E5 embedding for a piece of text" do + {:ok, model_info} = Bumblebee.load_model({:hf, "intfloat/e5-large"}) + {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "intfloat/e5-large"}) + + serving = Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer) + + text = "query: Cats are cute." + + assert %{embedding: %Nx.Tensor{} = embedding} = Nx.Serving.run(serving, text) + + assert Nx.shape(embedding) == {1024} + + assert_all_close( + embedding[1..3], + Nx.tensor([-0.9815, -0.5015, 0.9868]), + atol: 1.0e-4 + ) + end + + test "returns normalized E5 embedding for a piece of text" do + {:ok, model_info} = Bumblebee.load_model({:hf, "intfloat/e5-large"}) + {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "intfloat/e5-large"}) + + options = [embedding_processor: :l2_norm] + + serving = Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer, options) + + text = "query: Cats are cute." + + assert %{embedding: %Nx.Tensor{} = embedding} = Nx.Serving.run(serving, text) + + assert Nx.shape(embedding) == {1024} + + assert_all_close( + embedding[1..3], + Nx.tensor([-0.0459, -0.0234, 0.0461]), + atol: 1.0e-4 + ) + + assert_equal(Nx.sum(Nx.pow(embedding, 2)), Nx.tensor(1.0)) + end + end +end