Add BLIP and text-to-image serving (#181)
jonatanklosko authored Mar 28, 2023 · 1 parent 135141e · commit 48c26db
Showing 21 changed files with 1,464 additions and 33 deletions.
17 changes: 17 additions & 0 deletions lib/bumblebee.ex
@@ -57,6 +57,7 @@ defmodule Bumblebee do
"BertForTokenClassification" => {Bumblebee.Text.Bert, :for_token_classification},
"BertLMHeadModel" => {Bumblebee.Text.Bert, :for_causal_language_modeling},
"BertModel" => {Bumblebee.Text.Bert, :base},
"BlipForConditionalGeneration" => {Bumblebee.Multimodal.Blip, :for_conditional_generation},
# These models are just RoBERTa models, but the config will list them as CamemBERT
"CamembertModel" => {Bumblebee.Text.Roberta, :base},
"CamembertForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling},
@@ -138,6 +139,10 @@ defmodule Bumblebee do
"WhisperFeatureExtractor" => Bumblebee.Audio.WhisperFeaturizer
}

+ @transformers_image_processor_type_to_featurizer %{
+   "BlipImageProcessor" => Bumblebee.Vision.BlipFeaturizer
+ }

@model_type_to_featurizer %{
"convnext" => Bumblebee.Vision.ConvNextFeaturizer,
"deit" => Bumblebee.Vision.DeitFeaturizer,
@@ -150,6 +155,7 @@
"albert" => Bumblebee.Text.AlbertTokenizer,
"bart" => Bumblebee.Text.BartTokenizer,
"bert" => Bumblebee.Text.BertTokenizer,
"blip" => Bumblebee.Text.BertTokenizer,
"distilbert" => Bumblebee.Text.DistilbertTokenizer,
"camembert" => Bumblebee.Text.CamembertTokenizer,
"clip" => Bumblebee.Text.ClipTokenizer,
@@ -519,6 +525,17 @@ defmodule Bumblebee do
end
end

+ defp infer_featurizer_type(%{"image_processor_type" => class_name}, _repository) do
+   case @transformers_image_processor_type_to_featurizer[class_name] do
+     nil ->
+       {:error,
+        "could not match the class name #{inspect(class_name)} to any of the supported featurizers"}
+
+     module ->
+       {:ok, module}
+   end
+ end

defp infer_featurizer_type(_featurizer_data, repository) do
with {:ok, path} <- download(repository, @config_filename),
{:ok, featurizer_data} <- decode_config(path) do
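With these mappings in place, BLIP checkpoints resolve to the right modules on load. A minimal sketch, assuming a standard BLIP checkpoint layout (the repository name is illustrative):

    # "BlipImageProcessor" in the checkpoint config maps to Bumblebee.Vision.BlipFeaturizer
    {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "Salesforce/blip-image-captioning-base"})

    # the "blip" model type reuses the BERT tokenizer, per the mapping above
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Salesforce/blip-image-captioning-base"})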
236 changes: 236 additions & 0 deletions lib/bumblebee/multimodal/blip.ex
@@ -0,0 +1,236 @@
defmodule Bumblebee.Multimodal.Blip do
alias Bumblebee.Shared

options =
[
text_spec: [
default: nil,
doc: "the specification of the text model. See `Bumblebee.Text.BlipText` for details"
],
vision_spec: [
default: nil,
doc:
"the specification of the vision model. See `Bumblebee.Vision.BlipVision` for details"
],
projection_size: [
default: 512,
doc: "the dimensionality of text and vision projection layers"
],
logit_scale_initial_value: [
default: 2.6592,
doc: "the initial value for the scaling layer used to scale similarity logits"
]
] ++
Shared.token_options(
pad_token_id: 0,
bos_token_id: 30522,
# During generation, the SEP token is used as the EOS token
eos_token_id: 102,
sep_token_id: 102
)

@moduledoc """
The BLIP model for text-image similarity.
## Architectures
* `:for_conditional_generation` - BLIP model with a language
modeling head
## Inputs
* `"pixel_values"` - `{batch_size, image_size, image_size, num_channels}`
Featurized image pixel values.
* `"decoder_input_ids"` - `{batch_size, target_sequence_length}`
Indices of decoder input sequence tokens in the vocabulary. If not
present and `"input_ids"` is, it will be generated by shifting
each token in `"input_ids"` to the right once.
* `"decoder_attention_mask"` - `{batch_size, target_sequence_length}`
Mask indicating which decoder tokens to attend to. This is used
to ignore padding tokens, which are added when processing a batch
of sequences with different length.
* `"decoder_position_ids"` - `{batch_size, target_sequence_length}`
Indices of positions of each decoder input sequence tokens in
the position embeddings.
* `"cache"`
A container with cached layer results used to speed up sequential
decoding (autoregression). With cache, certain hidden states are
taken from the cache, rather than recomputed on every decoding
pass. The cache should be treated as opaque and initialized with
`Bumblebee.Text.Generation.init_cache/4`.
## Configuration
#{Shared.options_doc(options)}
## References
* [BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation](https://arxiv.org/abs/2201.12086)
"""

defstruct [architecture: :for_conditional_generation] ++ Shared.option_defaults(options)

@behaviour Bumblebee.ModelSpec
@behaviour Bumblebee.Configurable
@behaviour Bumblebee.Text.Generation

alias Bumblebee.Layers

@impl true
def architectures(), do: [:for_conditional_generation]

@impl true
def config(spec, opts \\ []) do
Shared.put_config_attrs(spec, opts)
end

@impl true
def input_template(%{vision_spec: vision_spec}) do
vision_shape = {1, vision_spec.image_size, vision_spec.image_size, vision_spec.num_channels}

%{
"pixel_values" => Nx.template(vision_shape, :f32),
"decoder_input_ids" => Nx.template({1, 1}, :s64)
}
end

@impl true
def model(%__MODULE__{architecture: :for_conditional_generation} = spec) do
%{vision_spec: vision_spec, text_spec: text_spec} = spec

vision_shape = {nil, vision_spec.image_size, vision_spec.image_size, vision_spec.num_channels}
text_shape = {nil, nil}

inputs =
Bumblebee.Utils.Model.inputs_to_map([
Axon.input("pixel_values", shape: vision_shape),
Axon.input("decoder_input_ids", optional: true, shape: text_shape),
Axon.input("decoder_attention_mask", optional: true, shape: text_shape),
Axon.input("decoder_position_ids", optional: true, shape: text_shape),
Axon.input("cache", optional: true)
])

vision_model =
vision_spec
|> Bumblebee.build_model()
|> Bumblebee.Utils.Axon.prefix_names("vision_model.")
|> Bumblebee.Utils.Axon.plug_inputs(%{
"pixel_values" => inputs["pixel_values"]
})

text_decoder =
text_spec
|> Bumblebee.build_model()
|> Bumblebee.Utils.Axon.prefix_names("text_decoder.")
|> Bumblebee.Utils.Axon.plug_inputs(%{
"input_ids" => inputs["decoder_input_ids"],
"attention_mask" => inputs["decoder_attention_mask"],
"position_ids" => inputs["decoder_position_ids"],
"encoder_hidden_state" => Axon.nx(vision_model, & &1.hidden_state),
"cache" => inputs["cache"]
})

Layers.output(%{
logits: Axon.nx(text_decoder, & &1.logits),
decoder_hidden_states: Axon.nx(text_decoder, & &1.hidden_states),
decoder_attentions: Axon.nx(text_decoder, & &1.attentions),
cross_attentions: Axon.nx(text_decoder, & &1.cross_attentions),
encoder_hidden_state: Axon.nx(vision_model, & &1.hidden_state),
encoder_hidden_states: Axon.nx(vision_model, & &1.hidden_states),
encoder_attentions: Axon.nx(vision_model, & &1.attentions),
cache: Axon.nx(text_decoder, & &1.cache)
})
end

@impl true
def init_cache(
%{vision_spec: vision_spec, text_spec: text_spec},
batch_size,
max_length,
inputs
) do
# the vision encoder outputs one hidden state per image patch, plus
# one for the class token prepended to the patch sequence
num_patches = div(vision_spec.image_size, vision_spec.patch_size) ** 2
encoder_sequence_length = num_patches + 1
encoder_shape = {batch_size, encoder_sequence_length, text_spec.hidden_size}

inputs =
%{
"input_ids" => inputs["decoder_input_ids"],
"attention_mask" => inputs["decoder_attention_mask"],
"position_ids" => inputs["decoder_position_ids"],
"encoder_hidden_state" => Nx.template(encoder_shape, :f32)
}
|> Map.reject(&match?({_, nil}, &1))

text_spec.__struct__.init_cache(text_spec, batch_size, max_length, inputs)
end

defimpl Bumblebee.HuggingFace.Transformers.Config do
def load(spec, data) do
import Shared.Converters

{text_data, data} = Map.pop(data, "text_config", %{})
{vision_data, data} = Map.pop(data, "vision_config", %{})

text_spec =
Bumblebee.Text.BlipText
|> Bumblebee.configure(architecture: :for_causal_language_modeling)
|> Bumblebee.HuggingFace.Transformers.Config.load(text_data)

vision_spec =
Bumblebee.Vision.BlipVision
|> Bumblebee.configure()
|> Bumblebee.HuggingFace.Transformers.Config.load(vision_data)

opts =
convert!(data,
projection_size: {"projection_dim", number()},
logit_scale_initial_value: {"logit_scale_init_value", number()}
) ++ Shared.common_options_from_transformers(data, spec)

opts =
case Keyword.fetch(opts, :sep_token_id) do
{:ok, sep_token_id} -> Keyword.put(opts, :eos_token_id, sep_token_id)
:error -> opts
end

@for.config(spec, opts ++ [text_spec: text_spec, vision_spec: vision_spec])
end
end

defimpl Bumblebee.HuggingFace.Transformers.Model do
alias Bumblebee.HuggingFace.Transformers

def params_mapping(spec) do
text_mapping =
spec.text_spec
|> Transformers.Model.params_mapping()
|> Transformers.Utils.prefix_params_mapping("text_decoder", nil)

vision_mapping =
spec.vision_spec
|> Transformers.Model.params_mapping()
|> Transformers.Utils.prefix_params_mapping("vision_model", nil)

%{
"text_projection" => "text_projection",
"visual_projection" => "visual_projection",
"scale" => %{
"scale" => {[{"scale", "logit_scale"}], fn [scale] -> scale end}
}
}
|> Map.merge(text_mapping)
|> Map.merge(vision_mapping)
end
end
end
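Putting the pieces together, the module is meant to drive image captioning. A rough usage sketch, assuming a captioning serving along the lines of `Bumblebee.Vision.image_to_text` (the serving code is not shown in this excerpt, so the entry point and options are assumptions):

    {:ok, model_info} = Bumblebee.load_model({:hf, "Salesforce/blip-image-captioning-base"})
    {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "Salesforce/blip-image-captioning-base"})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Salesforce/blip-image-captioning-base"})

    # assumed serving entry point; generation runs the text decoder over the
    # vision encoder's hidden state, using the cache set up by init_cache/4
    serving = Bumblebee.Vision.image_to_text(model_info, featurizer, tokenizer)

    image = StbImage.read_file!("example.jpg")
    Nx.Serving.run(serving, image)
    #=> %{results: [%{text: "..."}]}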
4 changes: 1 addition & 3 deletions lib/bumblebee/multimodal/clip.ex
@@ -39,13 +39,11 @@ defmodule Bumblebee.Multimodal.Clip do
padding tokens, which are added when processing a batch of sequences
with different length.
* `"position_ids"` - `{batch_size, sequence_length}`
Indices of positions of each input sequence token in the position
embeddings.
* `"pixel_values"` - `{batch_size, image_size, image_size, num_channels}`
Featurized image pixel values.
@@ -83,7 +81,7 @@ defmodule Bumblebee.Multimodal.Clip do

%{
"input_ids" => Nx.template({1, 1}, :s64),
"pixel_values" => Nx.template(vision_shape, :s64)
"pixel_values" => Nx.template(vision_shape, :f32)
}
end

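The one-line CLIP fix above corrects the input template's element type: pixel values are floating-point, so the template must be tagged `:f32` rather than `:s64`. A quick illustration (shape values are placeholders):

    # a template carries shape and type only, without allocating data
    template = Nx.template({1, 224, 224, 3}, :f32)
    Nx.type(template)
    #=> {:f, 32}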
10 changes: 5 additions & 5 deletions lib/bumblebee/multimodal/layout_lm.ex
@@ -526,15 +526,15 @@ defmodule Bumblebee.Multimodal.LayoutLm do
"encoder.blocks.{n}.self_attention_norm" =>
"layoutlm.encoder.layer.{n}.attention.output.LayerNorm",
"encoder.blocks.{n}.cross_attention.query" =>
"layoutlm.encoder.layer.{n}.attention.self.query",
"layoutlm.encoder.layer.{n}.crossattention.self.query",
"encoder.blocks.{n}.cross_attention.key" =>
"layoutlm.encoder.layer.{n}.attention.self.key",
"layoutlm.encoder.layer.{n}.crossattention.self.key",
"encoder.blocks.{n}.cross_attention.value" =>
"layoutlm.encoder.layer.{n}.attention.self.value",
"layoutlm.encoder.layer.{n}.crossattention.self.value",
"encoder.blocks.{n}.cross_attention.output" =>
"layoutlm.encoder.layer.{n}.attention.output.dense",
"layoutlm.encoder.layer.{n}.crossattention.output.dense",
"encoder.blocks.{n}.cross_attention_norm" =>
"layoutlm.encoder.layer.{n}.attention.output.LayerNorm",
"layoutlm.encoder.layer.{n}.crossattention.output.LayerNorm",
"encoder.blocks.{n}.ffn.intermediate" => "layoutlm.encoder.layer.{n}.intermediate.dense",
"encoder.blocks.{n}.ffn.output" => "layoutlm.encoder.layer.{n}.output.dense",
"encoder.blocks.{n}.output_norm" => "layoutlm.encoder.layer.{n}.output.LayerNorm",
11 changes: 6 additions & 5 deletions lib/bumblebee/text/bert.ex
@@ -636,14 +636,15 @@ defmodule Bumblebee.Text.Bert do
"encoder.blocks.{n}.self_attention_norm" =>
"bert.encoder.layer.{n}.attention.output.LayerNorm",
"encoder.blocks.{n}.cross_attention.query" =>
"bert.encoder.layer.{n}.attention.self.query",
"encoder.blocks.{n}.cross_attention.key" => "bert.encoder.layer.{n}.attention.self.key",
"bert.encoder.layer.{n}.crossattention.self.query",
"encoder.blocks.{n}.cross_attention.key" =>
"bert.encoder.layer.{n}.crossattention.self.key",
"encoder.blocks.{n}.cross_attention.value" =>
"bert.encoder.layer.{n}.attention.self.value",
"bert.encoder.layer.{n}.crossattention.self.value",
"encoder.blocks.{n}.cross_attention.output" =>
"bert.encoder.layer.{n}.attention.output.dense",
"bert.encoder.layer.{n}.crossattention.output.dense",
"encoder.blocks.{n}.cross_attention_norm" =>
"bert.encoder.layer.{n}.attention.output.LayerNorm",
"bert.encoder.layer.{n}.crossattention.output.LayerNorm",
"encoder.blocks.{n}.ffn.intermediate" => "bert.encoder.layer.{n}.intermediate.dense",
"encoder.blocks.{n}.ffn.output" => "bert.encoder.layer.{n}.output.dense",
"encoder.blocks.{n}.output_norm" => "bert.encoder.layer.{n}.output.LayerNorm",
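The LayoutLM and BERT changes fix the same bug: cross-attention layers were mapped to the `attention.*` parameter names, so checkpoints that export separate `crossattention.*` weights would have the self-attention tensors loaded in their place. With the corrected mapping, a BERT-style decoder picks up its cross-attention weights on load; a sketch (the repository name is hypothetical):

    # hypothetical decoder checkpoint that ships crossattention.* weights
    {:ok, %{params: params}} =
      Bumblebee.load_model({:hf, "user/bert-decoder"},
        architecture: :for_causal_language_modeling
      )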
(16 more changed files not shown)
