Add Blip2ForImageTextRetrieval for multimodal feature extraction #25612

Closed
Changes from all commits (26 commits):
- 4023732 Add Blip2ForImageTextRetrieval (jpizarrom, Aug 19, 2023)
- 8eb718a add Blip2ModelWithProjection (jpizarrom, Sep 24, 2023)
- 786da89 use gpu on Blip2ForImageTextRetrieval.forward doctest (jpizarrom, Sep 24, 2023)
- 188e3a7 use gpu on Blip2ModelWithProjection.forward doctest (jpizarrom, Sep 24, 2023)
- d1cc037 use float16 on Blip2ForImageTextRetrieval.forward doctest (jpizarrom, Sep 24, 2023)
- 5f72231 add _tied_weights_keys to Blip2ForImageTextRetrieval (jpizarrom, Sep 24, 2023)
- e099caa add temp param to Blip2ForImageTextRetrieval (jpizarrom, Sep 24, 2023)
- a1ab97f add Blip2TextModelWithProjection and Blip2VisionModelWithProjection (jpizarrom, Oct 6, 2023)
- 18d5340 use cuda and float16 in doctest Blip2VisionModelWithProjection (jpizarrom, Oct 6, 2023)
- 0a227d0 rename Blip2ModelWithProjection to Blip2ModelWithoutLM (jpizarrom, Oct 6, 2023)
- 43fb263 add image_text_hidden_size to docstring (jpizarrom, Oct 6, 2023)
- f8b0ed5 remove image_text_hidden_size from BlipConfig (jpizarrom, Oct 6, 2023)
- 401b8b8 use Blip2ModelWithoutLMConfig in convert script (jpizarrom, Oct 6, 2023)
- a0f7142 remove not used text_model_tester (jpizarrom, Oct 6, 2023)
- 46adfd5 restore image_text_hidden_size in BlipConfig (jpizarrom, Oct 6, 2023)
- a2c098e rename Blip2ModelWithoutLMConfig.from_vision_qformer_configs (jpizarrom, Oct 17, 2023)
- 532f5ae remove Blip2ModelWithoutLMConfig (jpizarrom, Oct 26, 2023)
- ce86d4c remove Blip2ModelWithProjection (jpizarrom, Oct 27, 2023)
- 253e067 remove _tied_weights_keys in Blip2ForImageTextRetrieval (jpizarrom, Oct 27, 2023)
- 81aea68 remove unused code: blip2_loss (jpizarrom, Oct 30, 2023)
- 04e2668 remove unused Blip2Output (jpizarrom, Oct 30, 2023)
- 3d2dfbd remove Blip2ModelWithoutLM from check_repo (jpizarrom, Oct 30, 2023)
- b9343ba add qformer_text_input line in the docstring (jpizarrom, Oct 30, 2023)
- 6b65330 add tests for Blip2ForImageTextRetrieval and Blip2VisionModelWithProj… (jpizarrom, Oct 30, 2023)
- afd66ca Merge branch 'main' into add_blip2_image_text_retrieval_model (jpizarrom, Oct 31, 2023)
- 47acd93 add skip on test_training_gradient_checkpointing_use_reentrant (jpizarrom, Oct 31, 2023)
docs/source/en/model_doc/blip-2.md (14 additions, 1 deletion)
@@ -87,4 +87,17 @@ If you're interested in submitting a resource to be included here, please feel f

[[autodoc]] Blip2ForConditionalGeneration
- forward
- generate

## Blip2ForImageTextRetrieval

[[autodoc]] Blip2ForImageTextRetrieval
- forward

## Blip2TextModelWithProjection

[[autodoc]] Blip2TextModelWithProjection

## Blip2VisionModelWithProjection

[[autodoc]] Blip2VisionModelWithProjection
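
For orientation, here is a minimal sketch of how the new retrieval model might be called once a converted checkpoint is published; the repo id `Salesforce/blip2-itm-vit-g` is an assumption (something this PR's conversion script would produce), and the exact output fields are not pinned down here:

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Blip2ForImageTextRetrieval

# Hypothetical checkpoint id; a real one would come out of this PR's conversion script.
checkpoint = "Salesforce/blip2-itm-vit-g"
processor = AutoProcessor.from_pretrained(checkpoint)
model = Blip2ForImageTextRetrieval.from_pretrained(checkpoint)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Score how well a caption matches the image (image-text retrieval).
inputs = processor(images=image, text="two cats lying on a couch", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
```

The doctest commits above move this forward pass to CUDA and float16, so on GPU you would additionally cast the model and inputs accordingly.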
src/transformers/__init__.py (6 additions, 0 deletions)
@@ -1363,10 +1363,13 @@
[
"BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Blip2ForConditionalGeneration",
"Blip2ForImageTextRetrieval",
"Blip2Model",
"Blip2PreTrainedModel",
"Blip2QFormerModel",
"Blip2TextModelWithProjection",
"Blip2VisionModel",
"Blip2VisionModelWithProjection",
]
)
_import_structure["models.bloom"].extend(
@@ -5438,10 +5441,13 @@
from .models.blip_2 import (
BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST,
Blip2ForConditionalGeneration,
Blip2ForImageTextRetrieval,
Blip2Model,
Blip2PreTrainedModel,
Blip2QFormerModel,
Blip2TextModelWithProjection,
Blip2VisionModel,
Blip2VisionModelWithProjection,
)
from .models.bloom import (
BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST,
src/transformers/models/blip_2/__init__.py (6 additions, 0 deletions)
@@ -35,10 +35,13 @@
_import_structure["modeling_blip_2"] = [
"BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Blip2Model",
"Blip2VisionModelWithProjection",
"Blip2QFormerModel",
"Blip2PreTrainedModel",
"Blip2ForConditionalGeneration",
"Blip2ForImageTextRetrieval",
"Blip2VisionModel",
"Blip2TextModelWithProjection",
]

if TYPE_CHECKING:
@@ -59,10 +62,13 @@
from .modeling_blip_2 import (
BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST,
Blip2ForConditionalGeneration,
Blip2ForImageTextRetrieval,
Blip2Model,
Blip2PreTrainedModel,
Blip2QFormerModel,
Blip2TextModelWithProjection,
Blip2VisionModel,
Blip2VisionModelWithProjection,
)

else:
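
With both `__init__.py` changes in place, the new classes resolve from the top-level package. A quick smoke test against this branch:

```python
# Smoke test: the new public exports should import cleanly on this branch.
from transformers import (
    Blip2ForImageTextRetrieval,
    Blip2TextModelWithProjection,
    Blip2VisionModelWithProjection,
)

for cls in (Blip2ForImageTextRetrieval, Blip2TextModelWithProjection, Blip2VisionModelWithProjection):
    print(f"{cls.__name__} -> {cls.__module__}")
```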
src/transformers/models/blip_2/configuration_blip_2.py (36 additions, 2 deletions)
@@ -176,6 +176,8 @@ class Blip2QFormerConfig(PretrainedConfig):
The frequency of adding cross-attention to the Transformer layers.
encoder_hidden_size (`int`, *optional*, defaults to 1408):
The hidden size of the hidden states for cross-attention.
qformer_text_input (`bool`, *optional*, defaults to `False`):
Whether to use BERT-style embeddings.
Review comment (Collaborator, on lines +179 to +180): Could we change this to a name that indicates it's a flag, e.g. `use_qformer_text_input` or `use_text_embeddings`? The current name suggests the value would be a text input itself.


Examples:

@@ -209,6 +211,7 @@
position_embedding_type="absolute",
cross_attention_frequency=2,
encoder_hidden_size=1408,
qformer_text_input=False,
Review comment (Contributor): Can you add a line in the docstring above to explain the aim of this arg? 🙏
Reply (Contributor Author): done

**kwargs,
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -227,6 +230,7 @@
self.position_embedding_type = position_embedding_type
self.cross_attention_frequency = cross_attention_frequency
self.encoder_hidden_size = encoder_hidden_size
self.qformer_text_input = qformer_text_input

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
@@ -266,7 +270,8 @@ class Blip2Config(PretrainedConfig):
Dictionary of configuration options used to initialize any [`PretrainedConfig`].
num_query_tokens (`int`, *optional*, defaults to 32):
The number of query tokens passed through the Transformer.

image_text_hidden_size (`int`, *optional*, defaults to 256):
Dimensionality of the hidden state of the image-text fusion layer.
kwargs (*optional*):
Dictionary of keyword arguments.

@@ -302,7 +307,15 @@

model_type = "blip-2"

def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs):
def __init__(
self,
vision_config=None,
qformer_config=None,
text_config=None,
num_query_tokens=32,
image_text_hidden_size=256,
**kwargs,
):
super().__init__(**kwargs)

if vision_config is None:
@@ -326,6 +339,7 @@ def __init__(self, vision_config=None, qformer_config=None, text_config=None, nu
self.is_encoder_decoder = self.text_config.is_encoder_decoder

self.num_query_tokens = num_query_tokens
self.image_text_hidden_size = image_text_hidden_size
self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
self.initializer_factor = 1.0
@@ -353,3 +367,23 @@ def from_vision_qformer_text_configs(
text_config=text_config.to_dict(),
**kwargs,
)

@classmethod
def from_vision_qformer_configs(
cls,
vision_config: Blip2VisionConfig,
qformer_config: Blip2QFormerConfig,
**kwargs,
):
r"""
Instantiate a [`Blip2Config`] (or a derived class) from a BLIP-2 vision and Q-Former model configurations.

Returns:
[`Blip2Config`]: An instance of a configuration object
"""

return cls(
vision_config=vision_config.to_dict(),
qformer_config=qformer_config.to_dict(),
**kwargs,
)
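
Putting the configuration pieces together, a short sketch of how they compose on this branch (values are illustrative):

```python
from transformers import Blip2Config, Blip2QFormerConfig, Blip2VisionConfig

# Q-Former that also consumes text tokens (BERT-style embeddings), per the new flag.
qformer_config = Blip2QFormerConfig(qformer_text_input=True)
vision_config = Blip2VisionConfig()

# New classmethod from this diff: builds a Blip2Config from vision + Q-Former alone,
# so no language-model config is required for the retrieval models.
config = Blip2Config.from_vision_qformer_configs(
    vision_config=vision_config,
    qformer_config=qformer_config,
    image_text_hidden_size=256,  # dimensionality of the shared image-text embedding space
)
print(config.image_text_hidden_size, config.qformer_config.qformer_text_input)
```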