Compute BLIP image embeddings only once during generation (#283)
jonatanklosko authored Nov 14, 2023
1 parent 4e0e178 commit 83c5af6
Showing 1 changed file with 28 additions and 4 deletions.
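
Before this change, every decoding step re-ran the BLIP vision encoder on the same image, even though its output never changes during generation. The diff below adds an optional `"encoder_hidden_state"` model input: when the caller supplies it, the vision tower is skipped and the supplied hidden state is fed straight to the text decoder's cross-attention.

A minimal sketch of the idea (not Bumblebee's actual generation code; the model repo name and the `pixel_values` / `decoder_input_ids` / `next_input_ids` tensors are assumed placeholders):

    {:ok, %{model: model, params: params}} =
      Bumblebee.load_model({:hf, "Salesforce/blip-image-captioning-base"})

    {_init_fn, predict_fn} = Axon.build(model)

    # First step: full forward pass, including the vision encoder.
    outputs =
      predict_fn.(params, %{
        "pixel_values" => pixel_values,
        "decoder_input_ids" => decoder_input_ids
      })

    # Subsequent steps: feed the cached hidden state back in, so the
    # image encoding process is skipped (per the docstring added below).
    outputs =
      predict_fn.(params, %{
        "pixel_values" => pixel_values,
        "encoder_hidden_state" => outputs.encoder_hidden_state,
        "decoder_input_ids" => next_input_ids
      })

Note that `"pixel_values"` remains a required (non-optional) input in the diff, so it is passed on every step even when the cached hidden state makes the vision branch unused.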
lib/bumblebee/multimodal/blip.ex (28 additions, 4 deletions)
@@ -57,6 +57,13 @@ defmodule Bumblebee.Multimodal.Blip do
       Indices of positions of each decoder input sequence tokens in
       the position embeddings.
 
+    * `"encoder_hidden_state"` - `{batch_size, sequence_length, hidden_size}`
+
+      Last hidden state output from the encoder. This hidden state is
+      used in cross-attention blocks in the decoder. If specified, the
+      model will skip the image encoding process and use this value
+      directly for cross-attentions in the text decoder.
+
     * `"cache"`
 
       A container with cached layer results used to speed up sequential
@@ -107,13 +114,15 @@ defmodule Bumblebee.Multimodal.Blip do
 
     vision_shape = {nil, vision_spec.image_size, vision_spec.image_size, vision_spec.num_channels}
     text_shape = {nil, nil}
+    vision_hidden_shape = {nil, nil, vision_spec.hidden_size}
 
     inputs =
       Bumblebee.Utils.Model.inputs_to_map([
         Axon.input("pixel_values", shape: vision_shape),
         Axon.input("decoder_input_ids", optional: true, shape: text_shape),
         Axon.input("decoder_attention_mask", optional: true, shape: text_shape),
         Axon.input("decoder_position_ids", optional: true, shape: text_shape),
+        Axon.input("encoder_hidden_state", optional: true, shape: vision_hidden_shape),
         Axon.input("cache", optional: true)
       ])
 
@@ -129,6 +138,21 @@ defmodule Bumblebee.Multimodal.Blip do
         "pixel_values" => inputs["pixel_values"]
       })
 
+    vision_model_outputs =
+      Layers.if_present inputs["encoder_hidden_state"] do
+        %{
+          hidden_state: inputs["encoder_hidden_state"],
+          hidden_states: Layers.none(),
+          attentions: Layers.none()
+        }
+      else
+        %{
+          hidden_state: Axon.nx(vision_model, & &1.hidden_state),
+          hidden_states: Axon.nx(vision_model, & &1.hidden_states),
+          attentions: Axon.nx(vision_model, & &1.attentions)
+        }
+      end
+
     text_decoder =
       text_spec
       |> Bumblebee.configure(
@@ -141,7 +165,7 @@ defmodule Bumblebee.Multimodal.Blip do
         "input_ids" => inputs["decoder_input_ids"],
         "attention_mask" => inputs["decoder_attention_mask"],
         "position_ids" => inputs["decoder_position_ids"],
-        "encoder_hidden_state" => Axon.nx(vision_model, & &1.hidden_state),
+        "encoder_hidden_state" => vision_model_outputs.hidden_state,
         "cache" => inputs["cache"]
       })
 
@@ -150,9 +174,9 @@ defmodule Bumblebee.Multimodal.Blip do
       decoder_hidden_states: Axon.nx(text_decoder, & &1.hidden_states),
       decoder_attentions: Axon.nx(text_decoder, & &1.attentions),
       cross_attentions: Axon.nx(text_decoder, & &1.cross_attentions),
-      encoder_hidden_state: Axon.nx(vision_model, & &1.hidden_state),
-      encoder_hidden_states: Axon.nx(vision_model, & &1.hidden_states),
-      encoder_attentions: Axon.nx(vision_model, & &1.attentions),
+      encoder_hidden_state: vision_model_outputs.hidden_state,
+      encoder_hidden_states: vision_model_outputs.hidden_states,
+      encoder_attentions: vision_model_outputs.attentions,
       cache: Axon.nx(text_decoder, & &1.cache)
     })
   end
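A note on the design: the `Layers.if_present/2` branch returns the same container shape in both arms. When the caller supplies `"encoder_hidden_state"`, the per-layer `hidden_states` and `attentions` are not available, so `Layers.none()` stands in for them; either way, downstream code can read `vision_model_outputs.hidden_state` and friends without caring which branch was taken.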
