
Add Blip2ForImageTextRetrieval for multimodal feature extraction #25612

Conversation

@jpizarrom (Contributor) commented Aug 19, 2023

What does this PR do?

Add a Blip2ForImageTextRetrieval model to be able to extract text, image, and multimodal features, similar to the extract_features method in the original implementation; a rough usage sketch follows below.

Fixes part of #25300 and #25245.
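For context, here is a rough usage sketch of the kind of multimodal feature extraction this PR targets. The class name Blip2ForImageTextRetrieval, the use_itm_head argument, the itm_score output, and the jpizarrom/blip2-itm-vit-g checkpoint are taken from the discussion later in this thread; since the PR was not merged in this form, treat this as illustrative rather than a stable API.

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Blip2ForImageTextRetrieval

# Illustrative sketch only: names follow this PR's proposal, not a merged API.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "jpizarrom/blip2-itm-vit-g"

processor = AutoProcessor.from_pretrained(model_name)
model = Blip2ForImageTextRetrieval.from_pretrained(model_name).to(device)

url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg"
image = Image.open(requests.get(url, stream=True).raw)
text = "A woman and her dog sitting in a beach"

inputs = processor(images=image, text=text, return_tensors="pt").to(device)
with torch.no_grad():
    # Image-text matching (ITM) head: 2-way match / no-match logits.
    itm_logits = model(**inputs, use_itm_head=True).itm_score
    # Image-text contrastive (ITC) path: cosine similarity of projected features.
    itc_score = model(**inputs, use_itm_head=False).itm_score

print("match probability:", torch.softmax(itm_logits, dim=1)[:, 1].item())
print("cosine similarity:", itc_score.item())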

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker @NielsRogge

TODOs:

  • Convert original weights from the BLIP-2 ITM checkpoint
  • The new model should return the same feature vectors as the original model
  • Add forward method
  • Add feature-extraction methods (code was previously added in the forward method)
  • Add Blip2TextRetrievalModelTest
  • Refactor to try to move the feature-extractor logic into Blip2ModelWithProjection
  • Use float16 in tests
  • Use float16 in doctests
  • Add Blip2TextModelWithProjection and Blip2VisionModelWithProjection
  • Add text_config=None support in Blip2Config; remove Blip2ModelWithoutLMConfig
  • Change model name from jpizarrom/xxxx to Salesforce/xxx?
  • Remove Blip2TextModelWithProjection

@jpizarrom changed the title from "Add Blip2ForImageTextRetrieval" to "Add Blip2ForImageTextRetrieval multimodal feature extraction" on Aug 19, 2023
@jpizarrom changed the title from "Add Blip2ForImageTextRetrieval multimodal feature extraction" to "Add Blip2ForImageTextRetrieval for multimodal feature extraction" on Aug 19, 2023
@jpizarrom changed the title from "Add Blip2ForImageTextRetrieval for multimodal feature extraction" to "[WIP] Add Blip2ForImageTextRetrieval for multimodal feature extraction" on Aug 19, 2023
@jpizarrom (Contributor, Author) commented Aug 19, 2023

@ArthurZucker (Collaborator) commented:

cc @amyeroberts and @rafaelpadilla!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@amyeroberts (Collaborator) commented:

@jpizarrom Thanks for opening this PR! Let us know when it's ready for review :)

@jpizarrom force-pushed the add_blip2_image_text_retrieval_model branch 2 times, most recently from c471b58 to dae1d4e on August 23, 2023 12:26
@jpizarrom (Contributor, Author) commented:

Hi @amyeroberts,
Could you please help me? :)
I am getting this error in CircleCI:

FAILED tests/utils/test_hub_utils.py::GetFromCacheTests::test_get_file_gated_repo - AssertionError: OSError not raised
FAILED tests/utils/test_hub_utils.py::GetFromCacheTests::test_has_file_gated_repo - AssertionError: OSError not raised

Do you know if it could be related to my PR, some kind of side effect, or is this error not related to my PR?
Thanks

@jpizarrom (Contributor, Author) commented:

> Do you know if it could be related to my PR, some kind of side effect, or is this error not related to my PR?

It looks like the issue is not related to my changes. I tried the following on the main branch:

from huggingface_hub import hf_hub_download

hf_hub_download("hf-internal-testing/dummy-gated-model", "README.md")  # does not fail
hf_hub_download("hf-internal-testing/dummy-gated-model", "otherfile")  # error: Cannot access gated repo for url...

The test at https://github.com/huggingface/transformers/blob/main/tests/utils/test_hub_utils.py#L131 tries to fetch the README.md and expects an exception.

@amyeroberts (Collaborator) commented:

Hi @jpizarrom, yes, there was a recent issue that unfortunately caused some of the hub tests to fail. Rest assured, they are not related to your PR :) For the moment you can ignore these tests; when you rebase on main they should be resolved.

@jpizarrom force-pushed the add_blip2_image_text_retrieval_model branch 2 times, most recently from eb82bf1 to 6a9021a on August 26, 2023 09:22
@jpizarrom (Contributor, Author) commented Aug 26, 2023:

Hi @amyeroberts and @rafaelpadilla, could you please help me? :)

This PR is working. I still need to add more tests, but I would love your feedback on whether it is fine that the get_text_features and get_image_features methods were added to the proposed new class Blip2ForImageTextRetrieval, or whether the logic should instead go into Blip2Model, extending it to also support feature extraction for models without a T5/OPT language model but with text and vision projections.

At the moment there is no Hugging Face model similar to the original Blip2Qformer/blip2 model with language and vision projections; the current Hugging Face Blip2Model seems to correspond more closely to the original Blip2OPT/Blip2T5.

@amyeroberts (Collaborator) commented:

@jpizarrom My suggestion would be to extend Blip2Model, as it already has get_text_features and get_image_features. Similarly, other retrieval models, e.g. ViltForImageTextRetrieval, don't have these methods implemented. I don't believe there's any reason why we couldn't also add these methods to Blip2ForImageTextRetrieval if you think it makes more sense - there's just a maintenance cost, as we can't guarantee that changes to the implementation in one class will be correctly updated in all places; adding tests to guard against this would be ideal.

@jpizarrom force-pushed the add_blip2_image_text_retrieval_model branch from 42bf954 to ca8e50d on September 13, 2023 18:04
@amyeroberts (Collaborator) commented:

@jpizarrom From next week I'm going to be away for a few weeks. If you have any questions, please ask @rafaelpadilla.

@jpizarrom (Contributor, Author) commented Sep 17, 2023:

Hi @rafaelpadilla, I would appreciate your feedback on this PR.
As recommended in #25612 (comment), I started to extend get_text_features and get_image_features in Blip2Model to support models that have a Q-Former with vision_proj and text_proj but no extra language_model (the original Blip2Qformer/blip2; more context in #25612 (comment)). However, my PR adds many if/else branches in Blip2Model to check whether the model has the language_model (OPT/T5) or not.

The CLIP model has two classes for the cases with and without projections, CLIPVisionModel and CLIPVisionModelWithProjection respectively.

What do you think should be the strategy to follow in this PR?

  • Keep what is currently done in this PR: extend Blip2Model to support both types of models, with some refactoring to make it nicer (a minimal sketch of this branching follows the list below).
  • Add the get-features methods to the new class Blip2ForImageTextRetrieval; this way there would be get-features methods in both Blip2Model and Blip2ForImageTextRetrieval.
  • Or add the get-features methods to another new class, Blip2ModelWithProjection.
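For reference, a minimal sketch of the branching described in the first option, written as a free function over a loaded model. The attribute names language_model, qformer, and text_proj follow the discussion above; this is illustrative only, not the actual Blip2Model implementation.

from torch import nn


def get_text_features(model, input_ids, attention_mask=None):
    """Illustrative branching only; attribute names are assumptions from this thread."""
    if getattr(model, "language_model", None) is not None:
        # LM-based checkpoints (OPT/T5): delegate to the language model, as Blip2Model does today.
        return model.language_model(input_ids=input_ids, attention_mask=attention_mask)
    # Q-Former-only checkpoints (original Blip2Qformer): encode the text with the Q-Former
    # and project the first token into the shared image-text embedding space.
    text_outputs = model.qformer(input_ids=input_ids, attention_mask=attention_mask)
    text_embeds = model.text_proj(text_outputs.last_hidden_state[:, 0, :])
    return nn.functional.normalize(text_embeds, dim=-1)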

Thanks

@rafaelpadilla (Contributor) commented:

Hi @jpizarrom,

Thank you for your contribution! :) I have just taken a look at your PR and it seems to me that the best strategy would be your first suggestion. IMO, your if/elses may not be a problem.

@ArthurZucker, what do you think the best strategy would be?

@jpizarrom force-pushed the add_blip2_image_text_retrieval_model branch from d2c874b to 0c20b66 on September 23, 2023 07:46
@jpizarrom changed the title from "[WIP] Add Blip2ForImageTextRetrieval for multimodal feature extraction" to "Add Blip2ForImageTextRetrieval for multimodal feature extraction" on Sep 24, 2023
@jpizarrom marked this pull request as ready for review on September 24, 2023 10:44
@jpizarrom (Contributor, Author) commented Sep 24, 2023:

Hi @rafaelpadilla @ArthurZucker, could you please review this PR?

Blip2ModelWithProjection and Blip2ForImageTextRetrieval were added; more context in #25612 (comment).

@NielsRogge wdyt about this PR?

Thanks

@jpizarrom force-pushed the add_blip2_image_text_retrieval_model branch from 6b502f1 to d2128cc on September 24, 2023 17:02

# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
Contributor Author (@jpizarrom): Shall the LayerNorm and dropout be moved to Blip2QFormerModel, to be more similar to the other Blip2 model implementations?

Contributor: I think that's fine for now.

Collaborator: I'd recommend moving them out.

@rafaelpadilla (Contributor) commented:

I was comparing the structure of the code with CLIP and noticed that:

Here there's only one ModelWithProjection class Blip2ModelWithProjection, which deals with embeddings of both text and image. However, for other models, and particularly for CLIP, there are CLIPTextModelWithProjection and CLIPVisionModelWithProjection.

To keep consistency with other models, would it be possible to break Blip2ModelWithProjection into Blip2TextModelWithProjection and Blip2VisionModelWithProjection?
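For illustration, the CLIP-style split being suggested could look roughly like the sketch below. This is not the PR's actual code: the real classes would wrap the existing BLIP-2 vision encoder and Q-Former modules and load from a Blip2Config, and the constructor arguments here are placeholders.

from torch import nn


class Blip2TextModelWithProjection(nn.Module):
    """Sketch: text branch of the Q-Former plus a projection head."""

    def __init__(self, qformer: nn.Module, hidden_size: int, proj_dim: int):
        super().__init__()
        self.qformer = qformer
        self.text_proj = nn.Linear(hidden_size, proj_dim)

    def forward(self, input_ids, attention_mask=None):
        out = self.qformer(input_ids=input_ids, attention_mask=attention_mask)
        # Project the first token and L2-normalize, as in the original blip2 feature extractor.
        return nn.functional.normalize(self.text_proj(out.last_hidden_state[:, 0, :]), dim=-1)


class Blip2VisionModelWithProjection(nn.Module):
    """Sketch: vision encoder, query tokens through the Q-Former, then a projection head."""

    def __init__(self, vision_model: nn.Module, qformer: nn.Module, hidden_size: int, proj_dim: int):
        super().__init__()
        self.vision_model = vision_model
        self.qformer = qformer
        self.vision_proj = nn.Linear(hidden_size, proj_dim)

    def forward(self, pixel_values):
        image_embeds = self.vision_model(pixel_values).last_hidden_state
        query_out = self.qformer(encoder_hidden_states=image_embeds)
        # One normalized embedding per query token.
        return nn.functional.normalize(self.vision_proj(query_out.last_hidden_state), dim=-1)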

@jpizarrom (Contributor, Author) commented Oct 6, 2023:

@rafaelpadilla thanks for the feedback. It would be possible to break the Blip2ModelWithProjection get_text_features and get_image_features into Blip2TextModelWithProjection and Blip2VisionModelWithProjection to follow the CLIP model structure, but I believe an implementation of Blip2Qformer.forward is still needed; that is why I was trying to implement it in HF as Blip2ModelWithProjection with the methods get_text_features, get_image_features, and forward, following BlipModel.

Maybe Blip2ModelWithoutLM would be a better class name than Blip2ModelWithProjection; Blip2Qformer.forward is used for pretraining stage 1.

@jpizarrom marked this pull request as draft on October 6, 2023 14:15
@@ -61,6 +61,10 @@ If you're interested in submitting a resource to be included here, please feel f

[[autodoc]] Blip2QFormerConfig

## Blip2ModelWithoutLMConfig
Contributor (@NielsRogge), Oct 18, 2023: It would be great not to add this config class. I'm OK with having separate Blip2ForImageTextRetrieval and Blip2TextModelWithProjection classes, but having a separate Blip2ModelWithoutLMConfig is a bit weird. Is there a reason a separate config is needed?

Contributor Author (@jpizarrom), Oct 19, 2023: Thanks for the feedback. The reason is that these new models have only a Q-Former and no other text model.
I will remove Blip2ModelWithoutLMConfig and try to update Blip2Config to support text_config=None, as was also recommended by @younesbelkada in #25612 (comment).

@jpizarrom marked this pull request as draft on October 24, 2023 08:00
@jpizarrom force-pushed the add_blip2_image_text_retrieval_model branch from c0cd3e4 to 253e067 on October 27, 2023 17:33
@jpizarrom marked this pull request as ready for review on October 30, 2023 11:49
@jpizarrom (Contributor, Author) commented Oct 30, 2023:

Hi @younesbelkada, this PR has been updated following your advice and is ready for review. Thanks!

  • Blip2ModelWithoutLMConfig was removed; now all the new models use Blip2Config.
  • The Blip2ModelWithProjection model was removed (it had been added in an earlier commit in this PR); it could be added later in another PR, since such a model could be used for pre-training.

@younesbelkada (Contributor) left a comment:

Thanks for your great contribution! This looks much cleaner now! Thanks also for pushing some converted model weights to the Hub!
Would you be able to run all BLIP-2 slow tests and confirm they pass? RUN_SLOW=1 pytest tests/models/blip_2/
Let's also add logits tests to the testing suite you added!

@@ -209,6 +209,7 @@ def __init__(
        position_embedding_type="absolute",
        cross_attention_frequency=2,
        encoder_hidden_size=1408,
+       qformer_text_input=False,
Contributor: Can you add a line in the docstring above to explain the aim of this arg? 🙏

Contributor Author (@jpizarrom): Done.

Comment on lines +220 to +222
    assert unexpected_keys == ["temp", "qformer.embeddings.position_ids"]
else:
    assert unexpected_keys == ["qformer.embeddings.position_ids"]
Contributor: Usually we want to avoid asserts in the codebase, but since it was already there I would say that's fine.
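For reference, an exception-based variant of that check could look like the sketch below. The keep_temp flag is a hypothetical stand-in for whichever condition the conversion script actually branches on; the PR keeps the existing assert.

def check_unexpected_keys(unexpected_keys: list, keep_temp: bool) -> None:
    # Sketch only: raise instead of assert when the converted checkpoint has surprise keys.
    expected = (
        ["temp", "qformer.embeddings.position_ids"] if keep_temp else ["qformer.embeddings.position_ids"]
    )
    if unexpected_keys != expected:
        raise ValueError(f"Unexpected keys after loading converted weights: {unexpected_keys}")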



@@ -1149,7 +1326,7 @@ def forward(
        if type(encoder_attention_mask) == list:
            encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
        elif encoder_attention_mask is None:
-           encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+           encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device, dtype=torch.long)
Contributor: Can you elaborate on why specifying torch.long is now needed?

Contributor Author (@jpizarrom):

Thanks, I reverted the usage of dtype=torch.long.

I had some issues with fp16 when I started working on this PR and believed that this change fixed them; now I am not able to reproduce the issue even without dtype=torch.long, so it is better to remove it.

for model_name in ["jpizarrom/blip2-itm-vit-g"]:
    model = Blip2VisionModelWithProjection.from_pretrained(model_name)
    self.assertIsNotNone(model)
    self.assertTrue(hasattr(model, "vision_proj"))
Contributor: Let's add logits tests here to make sure future PRs will not break anything - what do you think?

Contributor Author (@jpizarrom):

I added a verification of the output shapes to test_model_from_pretrained and added these new tests (a sketch of such a shape check follows the list below):

  • test_inference_itm
  • test_inference_itm_fp16
  • test_inference_vision_with_projection_fp16
  • test_inference_text_with_projection_fp16
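For reference, a shape-based integration test along these lines might look like the sketch below. The expected logit values would come from a reference run of the original model and are deliberately not hard-coded here, and the (1, 2) ITM output shape is an assumption based on the usage shown later in this thread.

import requests
import torch
from PIL import Image
from transformers import Blip2ForImageTextRetrieval, Blip2Processor
from transformers.testing_utils import slow, torch_device


def prepare_img():
    url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg"
    return Image.open(requests.get(url, stream=True).raw)


@slow
def test_inference_itm_shape():
    # Sketch of a shape check; exact expected logits are not hard-coded here.
    model_name = "jpizarrom/blip2-itm-vit-g"
    processor = Blip2Processor.from_pretrained(model_name)
    model = Blip2ForImageTextRetrieval.from_pretrained(model_name).to(torch_device)

    inputs = processor(
        images=prepare_img(), text="A woman and her dog sitting in a beach", return_tensors="pt"
    ).to(torch_device)

    with torch.no_grad():
        out = model(**inputs, use_itm_head=True)

    # The ITM head returns match / no-match logits per image-text pair.
    assert out.itm_score.shape == (1, 2)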


@slow
def test_model_from_pretrained(self):
    for model_name in ["jpizarrom/blip2-itm-vit-g"]:
Contributor: Let's also make that simpler:

model_name = "jpizarrom/blip2-itm-vit-g"
model = ...

@jpizarrom (Contributor, Author) commented:

> Would you be able to run all blip2 slow tests and confirm they pass? RUN_SLOW=1 pytest tests/models/blip_2/ Let's also add a logits tests in the testing suite you added!

Hi, I have made the recommended changes and answered directly in the review comments.

The slow tests passed: RUN_SLOW=1 pytest tests/models/blip_2/

The new doctests passed too:
pytest --doctest-modules src/transformers/models/blip_2/modeling_blip_2.py::transformers.models.blip_2.modeling_blip_2.Blip2ForImageTextRetrieval.forward

pytest --doctest-modules src/transformers/models/blip_2/modeling_blip_2.py::transformers.models.blip_2.modeling_blip_2.Blip2TextModelWithProjection.forward

pytest --doctest-modules src/transformers/models/blip_2/modeling_blip_2.py::transformers.models.blip_2.modeling_blip_2.Blip2VisionModelWithProjection.forward

@huggingface huggingface deleted a comment from github-actions bot Nov 26, 2023
@ArthurZucker (Collaborator) commented:

cc @amyeroberts if you could review (@younesbelkada is off!)
And sorry @jpizarrom for the wait

@amyeroberts (Collaborator) left a comment:

Thanks for adding!

Mostly comments about the structure and making sure all tests are correctly handled and run.

)
class Blip2ForImageTextRetrieval(Blip2PreTrainedModel):
    main_input_name = "pixel_values"
    config_class = Blip2Config
Collaborator (@amyeroberts): This is the same as in Blip2PreTrainedModel.

Suggested change: remove the line config_class = Blip2Config.



class Blip2VisionModelWithProjection(Blip2PreTrainedModel):
    config_class = Blip2Config
Collaborator (@amyeroberts): Same as the parent class.

Suggested change: remove the line config_class = Blip2Config.

@@ -1521,6 +1696,183 @@ def forward(
)


class Blip2TextModelWithProjection(Blip2PreTrainedModel):
    config_class = Blip2Config
Collaborator (@amyeroberts): Suggested change: remove the line config_class = Blip2Config.

@@ -22,6 +22,7 @@
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
+from torch.nn.functional import normalize
Collaborator (@amyeroberts): nn.functional.normalize should be used instead to be consistent with the rest of the repo, and to make explicit which normalize functionality is being used.

Comment on lines +179 to +180
qformer_text_input (`bool`, *optional*, defaults to `False`):
    Whether to use BERT-style embeddings.
Collaborator (@amyeroberts): Could we change this to a name which indicates it's a flag, e.g. use_qformer_text_input, use_text_embeddings, etc.? The current name suggests the value would be a text input itself.

Comment on lines +1267 to +1272
@unittest.skip(reason="Retain_grad is tested in individual model tests")
def test_retain_grad_hidden_states_attentions(self):
pass

@unittest.skip(reason="Blip2Model does not have input/output embeddings")
def test_model_common_attributes(self):
Collaborator (@amyeroberts): And here (common_attributes only for the vision model AFAICT).

Comment on lines +1341 to +1347
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant(self):
pass

@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
Collaborator (@amyeroberts): Has GC been tested on these models?

Comment on lines +1285 to +1286
# TODO
raise NotImplementedError
Collaborator (@amyeroberts): We shouldn't be raising NotImplementedError in tests.

Comment on lines +1083 to +1084
class Blip2VisionModelWithProjectionTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (Blip2VisionModelWithProjection,) if is_torch_available() else ()
Collaborator (@amyeroberts): Is this class necessary? I'd expect to be able to test the vision models together, e.g. like here for CLIP.

test_head_masking = False

test_resize_embeddings = False
test_attention_outputs = False
Collaborator (@amyeroberts): This should be True - the model takes output_attentions as an input.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions (bot) closed this on Jan 6, 2024
@jpizarrom (Contributor, Author) commented:

Hi @amyeroberts,

I would now like to continue working on the comments received in December. How can I reopen this PR, or should I create a new PR?
Thanks

@gleb-akhmerov commented Feb 11, 2024:

Hey, @jpizarrom, very nice work! I'm also interested in using BLIP 2 for image-text retrieval, specifically finding relevant images for a text query. (I'm just a regular user, not from HF.)

I understand that this PR is WIP, but it seems that it's in its final stages, so I want to give my feedback to help with testing.

When I try to pass multiple images to the model, I get an error. Is this a valid use case, or is the class intended for single image + multiple labels usage?

Code:

import requests
from PIL import Image
from transformers import Blip2ForImageTextRetrieval, Blip2Processor
from transformers.testing_utils import torch_device

def prepare_img():
    url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    return image


model_name = "jpizarrom/blip2-itm-vit-g"
processor = Blip2Processor.from_pretrained(model_name)
model = Blip2ForImageTextRetrieval.from_pretrained(model_name).to(torch_device)

images = [prepare_img(), prepare_img()]
text = "A woman and her dog sitting in a beach"
inputs = processor(images=images, text=text, return_tensors="pt").to(torch_device)

out = model(**inputs)

Error:

Traceback (most recent call last):
  File "/home/user/projects/transofmers-blip2-itm/test_multiple_images.py", line 22, in <module>
    out = model(**inputs)
          ^^^^^^^^^^^^^^^
  File "/home/user/projects/transofmers-blip2-itm/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/projects/transofmers-blip2-itm/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/projects/transofmers-blip2-itm/src/transformers/models/blip_2/modeling_blip_2.py", line 2363, in forward
    attention_mask = torch.cat([query_attention_mask, attention_mask], dim=1)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 2 but got size 1 for tensor number 1 in the list.

@jpizarrom (Contributor, Author) commented Feb 11, 2024:

Hi @gleb-akhmerov,
I think you could use Blip2ForImageTextRetrieval to match each text with each image; the lengths of both lists should be the same.

# %%
import requests
import torch
from PIL import Image
from transformers import Blip2ForImageTextRetrieval, Blip2Processor
from transformers.testing_utils import torch_device

# %%
def prepare_img():
    url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    return image

# %%
model_name = "jpizarrom/blip2-itm-vit-g"
processor = Blip2Processor.from_pretrained(model_name)
model = Blip2ForImageTextRetrieval.from_pretrained(model_name).to(torch_device)

# %%

images = [prepare_img(), prepare_img()]
text = "A woman and her dog sitting in a beach"
text_other = "A woman and her dog in a beach"


# %%
inputs = processor(images=images, text=[text,text_other], return_tensors="pt", padding=True).to(torch_device)

# %%
itm_out = model(**inputs, use_itm_head=True)
itm_scores = torch.nn.functional.softmax(itm_out.itm_score, dim=1)
print(f'The image and text are matched with a probability of {itm_scores[:, 1].tolist()}')

# %%
itc_out = model(**inputs, use_itm_head=False)
print(f'The image feature and text feature has a cosine similarity of {itc_out.itm_score.tolist()}')

Or you can get the image and text projections separately, and then compare all images with the text:

# %%
import requests
import torch
from PIL import Image
from transformers import Blip2TextModelWithProjection, Blip2VisionModelWithProjection, AutoProcessor

# %%
device = "cuda" if torch.cuda.is_available() else "cpu"

# %%
def prepare_img():
    url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    return image

# %%
model_name = "jpizarrom/blip2-itm-vit-g"
processor = AutoProcessor.from_pretrained("jpizarrom/blip2-itm-vit-g")
vision_model = Blip2VisionModelWithProjection.from_pretrained(model_name).to(device) 
text_model = Blip2TextModelWithProjection.from_pretrained(model_name).to(device) 

# %%
images = [prepare_img(), prepare_img()]
text = "A woman and her dog sitting in a beach"

# %%
vision_inputs = processor(images=images, return_tensors="pt").to(device)
vision_out = vision_model(**vision_inputs)
# out

# %%
text_inputs = processor(text=text, return_tensors="pt").to(device)
text_out = text_model(**text_inputs)

# %%
print(vision_out.image_embeds.shape, text_out.text_embeds.shape)

# %%
max_scores, max_classes = (vision_out.image_embeds @ text_out.text_embeds[:,0,:].t()).max(dim=1)

# %%
print(max_scores)

@amyeroberts (Collaborator) commented:

Hi @jpizarrom, I can't reopen this PR as something has happened upstream since closing: either the branch has had a force push or it's been recreated.

You can open a new PR and link to this one for reference.

@snpushpi commented:

Hi @jpizarrom, has this PR been merged into HF Transformers? I cannot import Blip2ForImageTextRetrieval somehow :(

@jpizarrom (Contributor, Author) commented Feb 22, 2024:

Hi @snpushpi, this PR was not finished. I got a working version of Blip2ForImageTextRetrieval in this PR, but I need to work on the tests and apply some of the feedback I received a while ago; then I would like to open a new PR for review.
