diff --git a/lib/bumblebee/multimodal/blip.ex b/lib/bumblebee/multimodal/blip.ex
index 92f23cec..13c6b2a1 100644
--- a/lib/bumblebee/multimodal/blip.ex
+++ b/lib/bumblebee/multimodal/blip.ex
@@ -57,6 +57,13 @@ defmodule Bumblebee.Multimodal.Blip do
       Indices of positions of each decoder input sequence tokens in
       the position embeddings.
 
+    * `"encoder_hidden_state"` - `{batch_size, sequence_length, hidden_size}`
+
+      Last hidden state output from the encoder. This hidden state is
+      used in cross-attention blocks in the decoder. If specified, the
+      model will skip the image encoding process and use this value
+      directly for cross-attentions in the text decoder.
+
     * `"cache"`
 
       A container with cached layer results used to speed up sequential
@@ -107,6 +114,7 @@ defmodule Bumblebee.Multimodal.Blip do
     vision_shape =
       {nil, vision_spec.image_size, vision_spec.image_size, vision_spec.num_channels}
     text_shape = {nil, nil}
+    vision_hidden_shape = {nil, nil, vision_spec.hidden_size}
 
     inputs =
       Bumblebee.Utils.Model.inputs_to_map([
@@ -114,6 +122,7 @@ defmodule Bumblebee.Multimodal.Blip do
         Axon.input("decoder_input_ids", optional: true, shape: text_shape),
         Axon.input("decoder_attention_mask", optional: true, shape: text_shape),
         Axon.input("decoder_position_ids", optional: true, shape: text_shape),
+        Axon.input("encoder_hidden_state", optional: true, shape: vision_hidden_shape),
         Axon.input("cache", optional: true)
       ])
 
@@ -129,6 +138,21 @@ defmodule Bumblebee.Multimodal.Blip do
         "pixel_values" => inputs["pixel_values"]
       })
 
+    vision_model_outputs =
+      Layers.if_present inputs["encoder_hidden_state"] do
+        %{
+          hidden_state: inputs["encoder_hidden_state"],
+          hidden_states: Layers.none(),
+          attentions: Layers.none()
+        }
+      else
+        %{
+          hidden_state: Axon.nx(vision_model, & &1.hidden_state),
+          hidden_states: Axon.nx(vision_model, & &1.hidden_states),
+          attentions: Axon.nx(vision_model, & &1.attentions)
+        }
+      end
+
     text_decoder =
       text_spec
       |> Bumblebee.configure(
@@ -141,7 +165,7 @@ defmodule Bumblebee.Multimodal.Blip do
         "input_ids" => inputs["decoder_input_ids"],
         "attention_mask" => inputs["decoder_attention_mask"],
         "position_ids" => inputs["decoder_position_ids"],
-        "encoder_hidden_state" => Axon.nx(vision_model, & &1.hidden_state),
+        "encoder_hidden_state" => vision_model_outputs.hidden_state,
         "cache" => inputs["cache"]
       })
 
@@ -150,9 +174,9 @@ defmodule Bumblebee.Multimodal.Blip do
       decoder_hidden_states: Axon.nx(text_decoder, & &1.hidden_states),
       decoder_attentions: Axon.nx(text_decoder, & &1.attentions),
       cross_attentions: Axon.nx(text_decoder, & &1.cross_attentions),
-      encoder_hidden_state: Axon.nx(vision_model, & &1.hidden_state),
-      encoder_hidden_states: Axon.nx(vision_model, & &1.hidden_states),
-      encoder_attentions: Axon.nx(vision_model, & &1.attentions),
+      encoder_hidden_state: vision_model_outputs.hidden_state,
+      encoder_hidden_states: vision_model_outputs.hidden_states,
+      encoder_attentions: vision_model_outputs.attentions,
       cache: Axon.nx(text_decoder, & &1.cache)
     })
   end
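
For reference, a minimal sketch of how the new input can be exercised. The checkpoint is the BLIP captioning model used in the Bumblebee docs; the image path and prompt are placeholders, and the second call assumes `"pixel_values"` can be omitted once `"encoder_hidden_state"` is supplied (if it remains required in this revision, pass it through unchanged and the encoder result is simply unused).

```elixir
{:ok, %{model: model, params: params}} =
  Bumblebee.load_model({:hf, "Salesforce/blip-image-captioning-base"})

{:ok, featurizer} =
  Bumblebee.load_featurizer({:hf, "Salesforce/blip-image-captioning-base"})

{:ok, tokenizer} =
  Bumblebee.load_tokenizer({:hf, "Salesforce/blip-image-captioning-base"})

# Placeholder image and prompt.
image_inputs = Bumblebee.apply_featurizer(featurizer, StbImage.read_file!("photo.jpg"))
text_inputs = Bumblebee.apply_tokenizer(tokenizer, "a picture of")

# First pass: full forward pass, including the vision encoder.
outputs =
  Axon.predict(
    model,
    params,
    Map.put(image_inputs, "decoder_input_ids", text_inputs["input_ids"])
  )

# Later passes: feed the precomputed hidden state back in. With
# "encoder_hidden_state" present, the decoder cross-attends to it
# directly and the image encoding step is skipped.
_outputs =
  Axon.predict(model, params, %{
    "encoder_hidden_state" => outputs.encoder_hidden_state,
    "decoder_input_ids" => text_inputs["input_ids"]
  })
```

The main payoff is in sequential decoding: the encoder output does not change between decoding steps, so it can be computed once and reused on every step, in the same spirit as the existing `"cache"` input.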