diff --git a/lib/bumblebee/text.ex b/lib/bumblebee/text.ex
index 1612a422..7e66f5da 100644
--- a/lib/bumblebee/text.ex
+++ b/lib/bumblebee/text.ex
@@ -409,7 +409,7 @@ defmodule Bumblebee.Text do
       {:ok, model_info} = Bumblebee.load_model({:hf, "intfloat/e5-large"})
       {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "intfloat/e5-large"})
 
-      serving = Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer)
+      serving = Bumblebee.Text.text_embedding(model_info, tokenizer)
 
       text = "query: Cats are cute."
       Nx.Serving.run(serving, text)
diff --git a/lib/bumblebee/text/clip_text.ex b/lib/bumblebee/text/clip_text.ex
index 1fc47ce0..fbf7b2f4 100644
--- a/lib/bumblebee/text/clip_text.ex
+++ b/lib/bumblebee/text/clip_text.ex
@@ -35,6 +35,10 @@ defmodule Bumblebee.Text.ClipText do
       doc:
         "the dimensionality of the intermediate layer in the transformer feed-forward network (FFN) in the encoder"
     ],
+    projection_size: [
+      default: 512,
+      doc: "the dimensionality of the projection layer"
+    ],
     activation: [
       default: :quick_gelu,
       doc: "the activation function"
@@ -62,6 +66,10 @@ defmodule Bumblebee.Text.ClipText do
 
     * `:base` - the base text model
 
+    * `:for_embedding` - the base model with a single projection layer
+      on top. The head returns a vector embedded in the joint text-image
+      CLIP space
+
   ## Inputs
 
     * `"input_ids"` - `{batch_size, sequence_length}`
@@ -95,7 +103,7 @@ defmodule Bumblebee.Text.ClipText do
   alias Bumblebee.Layers
 
   @impl true
-  def architectures(), do: [:base]
+  def architectures(), do: [:base, :for_embedding]
 
   @impl true
   def config(spec, opts \\ []) do
@@ -120,6 +128,22 @@ defmodule Bumblebee.Text.ClipText do
     |> Layers.output()
   end
 
+  def model(%__MODULE__{architecture: :for_embedding} = spec) do
+    inputs = inputs()
+
+    outputs = core(inputs, spec)
+
+    embedding =
+      outputs.pooled_state
+      |> Axon.dense(spec.projection_size, use_bias: false, name: "embedding_head.output")
+
+    Layers.output(%{
+      embedding: embedding,
+      hidden_states: outputs.hidden_states,
+      attentions: outputs.attentions
+    })
+  end
+
   defp inputs() do
     shape = {nil, nil}
 
@@ -226,6 +250,7 @@ defmodule Bumblebee.Text.ClipText do
       num_blocks: {"num_hidden_layers", number()},
       num_attention_heads: {"num_attention_heads", number()},
      intermediate_size: {"intermediate_size", number()},
+      projection_size: {"projection_dim", number()},
       activation: {"hidden_act", atom()},
       attention_dropout_rate: {"attention_dropout", number()},
       layer_norm_epsilon: {"layer_norm_eps", number()}
@@ -252,7 +277,8 @@ defmodule Bumblebee.Text.ClipText do
       "encoder.blocks.{n}.ffn.intermediate" => "text_model.encoder.layers.{n}.mlp.fc1",
       "encoder.blocks.{n}.ffn.output" => "text_model.encoder.layers.{n}.mlp.fc2",
       "encoder.blocks.{n}.output_norm" => "text_model.encoder.layers.{n}.layer_norm2",
-      "norm" => "text_model.final_layer_norm"
+      "norm" => "text_model.final_layer_norm",
+      "embedding_head.output" => "text_projection"
     }
   end
 end
diff --git a/lib/bumblebee/text/text_embedding.ex b/lib/bumblebee/text/text_embedding.ex
index f3aefcd3..c1fd93ba 100644
--- a/lib/bumblebee/text/text_embedding.ex
+++ b/lib/bumblebee/text/text_embedding.ex
@@ -38,10 +38,19 @@ defmodule Bumblebee.Text.TextEmbedding do
     output = encoder.(params, inputs)
 
     output =
-      if is_map(output) do
-        output[output_attribute]
-      else
-        output
+      case output do
+        %{^output_attribute => output} ->
+          output
+
+        %{} ->
+          keys = output |> Map.keys() |> Enum.sort()
+
+          raise ArgumentError,
+                "key #{inspect(output_attribute)} not found in the output map," <>
+                  " you may want to set :output_attribute to one of the map keys: #{inspect(keys)}"
+
+        _ ->
+          output
       end
 
     output =
diff --git a/lib/bumblebee/vision.ex b/lib/bumblebee/vision.ex
index ac0ed0b9..da37dfa9 100644
--- a/lib/bumblebee/vision.ex
+++ b/lib/bumblebee/vision.ex
@@ -188,7 +188,7 @@ defmodule Bumblebee.Vision do
           module: Bumblebee.Vision.ClipVision
         )
       {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "openai/clip-vit-base-patch32"})
-      serving = Bumblebee.Vision.ImageEmbedding.image_embedding(clip, featurizer)
+      serving = Bumblebee.Vision.image_embedding(clip, featurizer)
       image = StbImage.read_file!(path)
       Nx.Serving.run(serving, image)
       #=> %{
diff --git a/lib/bumblebee/vision/clip_vision.ex b/lib/bumblebee/vision/clip_vision.ex
index f07035e5..bb7b2d62 100644
--- a/lib/bumblebee/vision/clip_vision.ex
+++ b/lib/bumblebee/vision/clip_vision.ex
@@ -32,6 +32,10 @@ defmodule Bumblebee.Vision.ClipVision do
       docs:
         "the dimensionality of the intermediate layer in the transformer feed-forward network (FFN) in the encoder"
     ],
+    projection_size: [
+      default: 512,
+      doc: "the dimensionality of the projection layer"
+    ],
     activation: [
       default: :quick_gelu,
       doc: "the activation function"
@@ -57,6 +61,10 @@ defmodule Bumblebee.Vision.ClipVision do
 
     * `:base` - the base image model
 
+    * `:for_embedding` - the base model with a single projection layer
+      on top. The head returns a vector embedded in the joint text-image
+      CLIP space
+
   ## Inputs
 
     * `"pixel_values"` - `{batch_size, image_size, image_size, num_channels}`
@@ -78,7 +86,7 @@ defmodule Bumblebee.Vision.ClipVision do
   alias Bumblebee.Layers
 
   @impl true
-  def architectures(), do: [:base]
+  def architectures(), do: [:base, :for_embedding]
 
   @impl true
   def config(spec, opts \\ []) do
@@ -102,6 +110,22 @@ defmodule Bumblebee.Vision.ClipVision do
     |> Layers.output()
   end
 
+  def model(%__MODULE__{architecture: :for_embedding} = spec) do
+    inputs = inputs(spec)
+
+    outputs = core(inputs, spec)
+
+    embedding =
+      outputs.pooled_state
+      |> Axon.dense(spec.projection_size, use_bias: false, name: "projection_head.output")
+
+    Layers.output(%{
+      embedding: embedding,
+      hidden_states: outputs.hidden_states,
+      attentions: outputs.attentions
+    })
+  end
+
   defp inputs(spec) do
     shape = {nil, spec.image_size, spec.image_size, spec.num_channels}
 
@@ -220,6 +244,7 @@ defmodule Bumblebee.Vision.ClipVision do
       num_blocks: {"num_hidden_layers", number()},
       num_attention_heads: {"num_attention_heads", number()},
       intermediate_size: {"intermediate_size", number()},
+      projection_size: {"projection_dim", number()},
       activation: {"hidden_act", atom()},
       attention_dropout_rate: {"attention_dropout", number()},
       layer_norm_epsilon: {"layer_norm_eps", number()}
@@ -253,7 +278,8 @@ defmodule Bumblebee.Vision.ClipVision do
       "encoder.blocks.{n}.ffn.output" => "vision_model.encoder.layers.{n}.mlp.fc2",
       "encoder.blocks.{n}.output_norm" => "vision_model.encoder.layers.{n}.layer_norm2",
       "pre_norm" => "vision_model.pre_layrnorm",
-      "post_norm" => "vision_model.post_layernorm"
+      "post_norm" => "vision_model.post_layernorm",
+      "projection_head.output" => "visual_projection"
     }
   end
 end
diff --git a/lib/bumblebee/vision/image_embedding.ex b/lib/bumblebee/vision/image_embedding.ex
index f607257b..c51fb8bd 100644
--- a/lib/bumblebee/vision/image_embedding.ex
+++ b/lib/bumblebee/vision/image_embedding.ex
@@ -37,10 +37,19 @@ defmodule Bumblebee.Vision.ImageEmbedding do
     output = encoder.(params, inputs)
 
     output =
-      if is_map(output) do
-        output[output_attribute]
-      else
-        output
+      case output do
+        %{^output_attribute => output} ->
+          output
+
+        %{} ->
+          keys = output |> Map.keys() |> Enum.sort()
+
+          raise ArgumentError,
+                "key #{inspect(output_attribute)} not found in the output map," <>
+                  " you may want to set :output_attribute to one of the map keys: #{inspect(keys)}"
+
+        _ ->
+          output
       end
 
     output =
diff --git a/test/bumblebee/text/clip_text_test.exs b/test/bumblebee/text/clip_text_test.exs
index 73030957..fcf985e2 100644
--- a/test/bumblebee/text/clip_text_test.exs
+++ b/test/bumblebee/text/clip_text_test.exs
@@ -43,5 +43,33 @@ defmodule Bumblebee.Text.ClipTextTest do
         atol: 1.0e-4
       )
     end
+
+    test "embedding model" do
+      assert {:ok, %{model: model, params: params, spec: spec}} =
+               Bumblebee.load_model({:hf, "openai/clip-vit-base-patch32"},
+                 module: Bumblebee.Text.ClipText,
+                 architecture: :for_embedding
+               )
+
+      assert %Bumblebee.Text.ClipText{architecture: :for_embedding} = spec
+
+      inputs = %{
+        "input_ids" =>
+          Nx.tensor([
+            [49406, 320, 1125, 539, 320, 2368, 49407]
+          ]),
+        "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1]])
+      }
+
+      outputs = Axon.predict(model, params, inputs)
+
+      assert Nx.shape(outputs.embedding) == {1, 512}
+
+      assert_all_close(
+        outputs.embedding[[.., 1..3]],
+        Nx.tensor([[0.0733, -0.2448, -0.2212]]),
+        atol: 1.0e-4
+      )
+    end
   end
 end
diff --git a/test/bumblebee/vision/clip_vision_test.exs b/test/bumblebee/vision/clip_vision_test.exs
index dccdbfe9..03dd6f6c 100644
--- a/test/bumblebee/vision/clip_vision_test.exs
+++ b/test/bumblebee/vision/clip_vision_test.exs
@@ -37,5 +37,29 @@ defmodule Bumblebee.Vision.ClipVisionTest do
         atol: 1.0e-4
       )
     end
+
+    test "embedding model" do
+      assert {:ok, %{model: model, params: params, spec: spec}} =
+               Bumblebee.load_model({:hf, "openai/clip-vit-base-patch32"},
+                 module: Bumblebee.Vision.ClipVision,
+                 architecture: :for_embedding
+               )
+
+      assert %Bumblebee.Vision.ClipVision{architecture: :for_embedding} = spec
+
+      inputs = %{
+        "pixel_values" => Nx.broadcast(0.5, {1, 224, 224, 3})
+      }
+
+      outputs = Axon.predict(model, params, inputs)
+
+      assert Nx.shape(outputs.embedding) == {1, 512}
+
+      assert_all_close(
+        outputs.embedding[[.., 1..3]],
+        Nx.tensor([[-0.3381, -0.0196, -0.4053]]),
+        atol: 1.0e-4
+      )
+    end
   end
 end
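
Usage note (not part of the diff): a minimal sketch of exercising the new :for_embedding text architecture end to end through the text embedding serving. It assumes the CLIP checkpoint's tokenizer loads via Bumblebee.load_tokenizer, and it uses the serving's :output_attribute option (the one handled in the case expression above) to select the new :embedding output:

    {:ok, clip} =
      Bumblebee.load_model({:hf, "openai/clip-vit-base-patch32"},
        module: Bumblebee.Text.ClipText,
        architecture: :for_embedding
      )

    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-base-patch32"})

    # Point the serving at the :embedding output added by this change;
    # without this it would look for the default output attribute and
    # now raises the descriptive ArgumentError instead of returning nil.
    serving = Bumblebee.Text.text_embedding(clip, tokenizer, output_attribute: :embedding)

    Nx.Serving.run(serving, "a photo of a cat")
    #=> %{embedding: #Nx.Tensor<f32[512] ...>}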
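The vision side mirrors this. A sketch under the same assumptions, reusing the featurizer and the StbImage-based input from the Bumblebee.Vision docs above (path is a placeholder for an image file):

    {:ok, clip} =
      Bumblebee.load_model({:hf, "openai/clip-vit-base-patch32"},
        module: Bumblebee.Vision.ClipVision,
        architecture: :for_embedding
      )

    {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "openai/clip-vit-base-patch32"})

    # Select the :embedding output produced by the new projection head
    serving = Bumblebee.Vision.image_embedding(clip, featurizer, output_attribute: :embedding)

    image = StbImage.read_file!(path)
    Nx.Serving.run(serving, image)
    #=> %{embedding: #Nx.Tensor<f32[512] ...>}

Because both projection heads map into the same joint text-image CLIP space, the two 512-dimensional vectors can be compared directly (e.g. by cosine similarity), which is the point of adding the :for_embedding architecture to both models.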