Add Blip2ForImageTextRetrieval for multimodal feature extraction #25612
Conversation
cc @amyeroberts and @rafaelpadilla !
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
@jpizarrom Thanks for opening this PR! Let us know when it's ready for review :)
Force-pushed from c471b58 to dae1d4e
Hi @amyeroberts
Do you know if this error could be related to my PR, some kind of side effect, or is it unrelated to my PR?
It looks like the issue is not related to my changes; I just tried using the main branch.
The test https://github.com/huggingface/transformers/blob/main/tests/utils/test_hub_utils.py#L131 is trying to get the README.md and expects an exception.
Hi @jpizarrom, yes, there was a recent issue that unfortunately resulted in some of the hub tests failing. Rest assured they are not related to your PR :) For the moment you can ignore these tests. When you rebase on main they should be resolved.
Force-pushed from eb82bf1 to 6a9021a
Hi @amyeroberts and @rafaelpadilla, could you please help me? :) This PR is working. I still need to add more tests, but I would love to get your feedback on whether it is fine that the methods get_text_features and get_image_features were added to the proposed new class Blip2ForImageTextRetrieval, or whether the logic should instead be added to Blip2Model, extending Blip2Model to also support feature extraction for models without T5/OPT language models but with text and vision projections. At the moment there is no Hugging Face model similar to the original Blip2Qformer/blip2 model with language and visual projections; the current Hugging Face Blip2Model seems to be more related to the original Blip2OPT/Blip2T5.
@jpizarrom My suggestion would be to extend Blip2Model, as it already has get_text_features and get_image_features.
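For readers following the design discussion, here is a minimal sketch of the kind of model being proposed: a retrieval-style BLIP-2 model whose `get_image_features`/`get_text_features` return projected Q-Former embeddings. The wiring, signatures, and dimensions below are illustrative assumptions, not the code from this PR or from `Blip2Model`.

```python
import torch
from torch import nn


class Blip2ImageTextRetrievalSketch(nn.Module):
    """Toy skeleton: vision encoder + Q-Former with text/vision projection heads."""

    def __init__(self, vision_model, qformer, qformer_hidden_size=768, proj_dim=256):
        super().__init__()
        self.vision_model = vision_model  # e.g. the BLIP-2 ViT backbone
        self.qformer = qformer            # a Q-Former that can also take text input
        self.vision_proj = nn.Linear(qformer_hidden_size, proj_dim)
        self.text_proj = nn.Linear(qformer_hidden_size, proj_dim)

    def get_image_features(self, pixel_values):
        # Image tokens are cross-attended by the Q-Former queries,
        # then projected and L2-normalized.
        image_embeds = self.vision_model(pixel_values)
        query_output = self.qformer(encoder_hidden_states=image_embeds)
        return nn.functional.normalize(self.vision_proj(query_output), dim=-1)

    def get_text_features(self, input_ids, attention_mask=None):
        # Text goes through the Q-Former text path; the first token state is projected.
        text_output = self.qformer(input_ids=input_ids, attention_mask=attention_mask)
        return nn.functional.normalize(self.text_proj(text_output[:, 0, :]), dim=-1)
```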
Force-pushed from 42bf954 to ca8e50d
@jpizarrom From next week I'm going to be away for a few weeks. If you have any questions, please ask @rafaelpadilla
Hi @rafaelpadilla, I would appreciate receiving your feedback on this PR. The CLIP model has two classes for the cases with and without projections. What do you think the strategy to follow in this PR should be?
Thanks
Hi @jpizarrom, thank you for your contribution! :) I have just taken a look at your PR and it seems to me that the best strategy would be your first suggestion. IMO, your if/elses may not be a problem. @ArthurZucker, what do you think the best strategy would be?
Force-pushed from d2c874b to 0c20b66
Hi @rafaelpadilla @ArthurZucker, could you please review this PR?
@NielsRogge wdyt about this PR? Thanks
Force-pushed from 6b502f1 to d2128cc
```python
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
```
Shall the LayerNorm and dropout be moved to `Blip2QFormerModel`, to be more similar to the other Blip2 model implementations?
I think that's fine for now
I'd recommend moving them out
I was comparing the structure of the code with CLIP and noticed that here there is only one ModelWithProjection class. To keep consistency with other models, would it be possible to break it into separate `Blip2TextModelWithProjection` and `Blip2VisionModelWithProjection` classes?
@rafaelpadilla thanks for the feedback, it could be possible to break it up that way, maybe.
docs/source/en/model_doc/blip-2.md (outdated)

```diff
@@ -61,6 +61,10 @@ If you're interested in submitting a resource to be included here, please feel f
 [[autodoc]] Blip2QFormerConfig

+## Blip2ModelWithoutLMConfig
```
Would be great to not add this config class. I'm OK with having separate `Blip2ForImageTextRetrieval` and `Blip2TextModelWithProjection` classes, but having a separate `Blip2ModelWithoutLMConfig` is a bit weird. Is there a reason there's a need for a separate config?
Thanks for the feedback. The reason is that these new models only have a Q-Former and no other text model. I will remove `Blip2ModelWithoutLMConfig` and try to update `Blip2Config` to support text_config=None, as was also recommended by @younesbelkada in #25612 (comment)
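As a rough illustration of what "support text_config=None" could mean, here is a minimal sketch under the assumptions discussed above; it is not the actual transformers Blip2Config, and the attribute handling is simplified:

```python
class Blip2ConfigSketch:
    """Hypothetical config sketch: retrieval checkpoints have a vision encoder and a
    Q-Former but no OPT/T5 language model, so text_config may legitimately stay None."""

    def __init__(self, vision_config=None, qformer_config=None, text_config=None, **kwargs):
        self.vision_config = vision_config or {}
        self.qformer_config = qformer_config or {}
        # Keep None instead of instantiating a default language-model config.
        self.text_config = text_config
```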
Force-pushed from c0cd3e4 to 253e067
Hi @younesbelkada, this PR has been updated following your advice, and is ready for a review. Thanks
Thanks for your great contribution! I think it looks much cleaner now! Thanks also for pushing some converted model weights to the Hub!
Would you be able to run all BLIP-2 slow tests and confirm they pass? `RUN_SLOW=1 pytest tests/models/blip_2/`
Let's also add a logits test in the testing suite you added!
```diff
@@ -209,6 +209,7 @@ def __init__(
         position_embedding_type="absolute",
         cross_attention_frequency=2,
         encoder_hidden_size=1408,
+        qformer_text_input=False,
```
Can you add a line in the docstring above to explain the aim of this arg? 🙏
done
```python
    assert unexpected_keys == ["temp", "qformer.embeddings.position_ids"]
else:
    assert unexpected_keys == ["qformer.embeddings.position_ids"]
```
Usually we want to avoid having `assert` in the codebase, but since it was already there that's fine I would say.
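For illustration, an assert-free variant of the check above could look like the sketch below. The `has_itm_head` argument is a placeholder, since the branch condition of the surrounding `if` is not shown in the quoted snippet; only the expected key lists come from it.

```python
def check_unexpected_keys(unexpected_keys, has_itm_head):
    # Raise an explicit error instead of asserting, so the check also runs under `python -O`.
    expected_keys = (
        ["temp", "qformer.embeddings.position_ids"]
        if has_itm_head
        else ["qformer.embeddings.position_ids"]
    )
    if unexpected_keys != expected_keys:
        raise ValueError(f"Unexpected keys found after conversion: {unexpected_keys}")
```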
```python
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
```
I think that's fine for now
```diff
@@ -1149,7 +1326,7 @@ def forward(
         if type(encoder_attention_mask) == list:
             encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
         elif encoder_attention_mask is None:
-            encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+            encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device, dtype=torch.long)
```
can you elaborate on why specifying torch.long is now needed?
Thanks, I reverted the usage of dtype=torch.long. I had some issues with fp16 when I started to work on this PR and believed that this fixed the issue; now I am not able to replicate it, even without dtype=torch.long, so it is better to remove it.
```python
for model_name in ["jpizarrom/blip2-itm-vit-g"]:
    model = Blip2VisionModelWithProjection.from_pretrained(model_name)
    self.assertIsNotNone(model)
    self.assertTrue(hasattr(model, "vision_proj"))
```
let's add a logits test here to make sure future PRs will not break anything - what do you think?
I added a verification of the output shapes to `test_model_from_pretrained` and added these new tests:
- test_inference_itm
- test_inference_itm_fp16
- test_inference_vision_with_projection_fp16
- test_inference_text_with_projection_fp16
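For reference, here is a rough sketch of what such a logits test can look like. It assumes it lives in the slow integration test class alongside the `prepare_img` helper and `torch_device` used elsewhere in this thread; the expected values and the exact output shape are placeholders, not outputs from the real checkpoint.

```python
@slow
def test_inference_itm(self):
    model_name = "jpizarrom/blip2-itm-vit-g"
    processor = Blip2Processor.from_pretrained(model_name)
    model = Blip2ForImageTextRetrieval.from_pretrained(model_name).to(torch_device)

    image = prepare_img()
    text = "A woman and her dog sitting in a beach"
    inputs = processor(images=image, text=text, return_tensors="pt").to(torch_device)

    with torch.no_grad():
        outputs = model(**inputs, use_itm_head=True)

    # Assumes the ITM head returns one (no-match, match) logit pair per example.
    self.assertEqual(outputs.itm_score.shape, (1, 2))
    # Placeholder values: pin a small slice of the logits from a reference run.
    expected_logits = torch.tensor([[0.0, 0.0]], device=torch_device)
    self.assertTrue(torch.allclose(outputs.itm_score, expected_logits, atol=1e-3))
```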
```python
@slow
def test_model_from_pretrained(self):
    for model_name in ["jpizarrom/blip2-itm-vit-g"]:
```
let's also make that simpler:
```python
model_name = "jpizarrom/blip2-itm-vit-g"
model = ...
```
Hi, I have made the recommended changes and answered directly in the comments of your reviews. The slow tests passed, and the new doctest passed too.
cc @amyeroberts if you could review (@younesbelkada is off!)
Thanks for adding!
Mostly comments about the structure and making sure all tests are correctly handled and run.
```python
)
class Blip2ForImageTextRetrieval(Blip2PreTrainedModel):
    main_input_name = "pixel_values"
    config_class = Blip2Config
```
This is the same as `Blip2PreTrainedModel`, so the `config_class = Blip2Config` line is redundant here.
```python
class Blip2VisionModelWithProjection(Blip2PreTrainedModel):
    config_class = Blip2Config
```
Same as the parent class: the `config_class = Blip2Config` line is redundant here too.
```diff
@@ -1521,6 +1696,183 @@ def forward(
 )


+class Blip2TextModelWithProjection(Blip2PreTrainedModel):
+    config_class = Blip2Config
```
Same as above: `config_class = Blip2Config`.
```diff
@@ -22,6 +22,7 @@
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
+from torch.nn.functional import normalize
```
`nn.functional.normalize` should be used instead, to be consistent with the rest of the repo and to make explicit which `normalize` functionality is being used.
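For example, a fully qualified call could look like this small self-contained snippet (the layer sizes here are arbitrary illustration values, not the PR's actual projection dimensions):

```python
import torch
from torch import nn

proj = nn.Linear(768, 256)
features = torch.randn(2, 768)
# Fully qualified call, matching how normalize is referenced elsewhere in the repo.
normalized = nn.functional.normalize(proj(features), dim=-1)
```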
```python
qformer_text_input (`bool`, *optional*, defaults to `False`):
    Whether to use BERT-style embeddings.
```
Could we change this to a name which indicates it's a flag, e.g. `use_qformer_text_input`, `use_text_embeddings`, etc.? The current name suggests that the value would be a text input itself.
@unittest.skip(reason="Retain_grad is tested in individual model tests") | ||
def test_retain_grad_hidden_states_attentions(self): | ||
pass | ||
|
||
@unittest.skip(reason="Blip2Model does not have input/output embeddings") | ||
def test_model_common_attributes(self): |
And here (common_attributes only for the vision model AFAICT)
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" | ||
) | ||
def test_training_gradient_checkpointing_use_reentrant(self): | ||
pass | ||
|
||
@unittest.skip( | ||
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" |
Has GC been tested on these models?
```python
# TODO
raise NotImplementedError
```
We shouldn't be raising not implemented errors for tests
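As a sketch, the usual alternative is to mirror the skip pattern quoted earlier instead of raising; the test name and reason below are hypothetical, chosen only to illustrate the pattern:

```python
import unittest


class ExampleTests(unittest.TestCase):
    @unittest.skip(reason="Training is not yet supported for this model")
    def test_training(self):
        pass
```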
```python
class Blip2VisionModelWithProjectionTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (Blip2VisionModelWithProjection,) if is_torch_available() else ()
```
Is this class necessary? I'd expect to be able to test the vision models together e.g. like here for CLIP
```python
test_head_masking = False

test_resize_embeddings = False
test_attention_outputs = False
```
This should be true - the model takes `output_attentions` as an input.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hi @amyeroberts, I would now like to continue working on the comments received in December. How can I reopen this PR, or should I create a new PR?
Hey, @jpizarrom, very nice work! I'm also interested in using BLIP-2 for image-text retrieval, specifically finding relevant images for a text query. (I'm just a regular user, not from HF.) I understand that this PR is WIP, but it seems that it's in its final stages, so I want to give my feedback to help with testing. When I try to pass multiple images to the model, I get an error. Is this a valid use case, or is the class intended for single image + multiple labels usage?

Code:
```python
import requests
from PIL import Image

from transformers import Blip2ForImageTextRetrieval, Blip2Processor
from transformers.testing_utils import torch_device


def prepare_img():
    url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    return image


model_name = "jpizarrom/blip2-itm-vit-g"
processor = Blip2Processor.from_pretrained(model_name)
model = Blip2ForImageTextRetrieval.from_pretrained(model_name).to(torch_device)

images = [prepare_img(), prepare_img()]
text = "A woman and her dog sitting in a beach"
inputs = processor(images=images, text=text, return_tensors="pt").to(torch_device)
out = model(**inputs)
```

Error:
Hi @gleb-akhmerov,

```python
# %%
import requests
import torch
from PIL import Image

from transformers import Blip2ForImageTextRetrieval, Blip2Processor
from transformers.testing_utils import torch_device


# %%
def prepare_img():
    url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    return image


# %%
model_name = "jpizarrom/blip2-itm-vit-g"
processor = Blip2Processor.from_pretrained(model_name)
model = Blip2ForImageTextRetrieval.from_pretrained(model_name).to(torch_device)

# %%
images = [prepare_img(), prepare_img()]
text = "A woman and her dog sitting in a beach"
text_other = "A woman and her dog in a beach"

# %%
inputs = processor(images=images, text=[text, text_other], return_tensors="pt", padding=True).to(torch_device)

# %%
itm_out = model(**inputs, use_itm_head=True)
itm_scores = torch.nn.functional.softmax(itm_out.itm_score, dim=1)
print(f'The image and text are matched with a probability of {itm_scores[:, 1].tolist()}')

# %%
itc_out = model(**inputs, use_itm_head=False)
print(f'The image feature and text feature has a cosine similarity of {itc_out.itm_score.tolist()}')
```

or you can get image and text projections, then compare all images with the text:

```python
# %%
import requests
import torch
from PIL import Image

from transformers import Blip2TextModelWithProjection, Blip2VisionModelWithProjection, AutoProcessor
from transformers.testing_utils import torch_device

# %%
device = "cuda" if torch.cuda.is_available() else "cpu"


# %%
def prepare_img():
    url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    return image


# %%
model_name = "jpizarrom/blip2-itm-vit-g"
processor = AutoProcessor.from_pretrained("jpizarrom/blip2-itm-vit-g")
vision_model = Blip2VisionModelWithProjection.from_pretrained(model_name).to(device)
text_model = Blip2TextModelWithProjection.from_pretrained(model_name).to(device)

# %%
images = [prepare_img(), prepare_img()]
text = "A woman and her dog sitting in a beach"

# %%
vision_inputs = processor(images=images, return_tensors="pt").to(torch_device)
vision_out = vision_model(**vision_inputs)

# %%
text_inputs = processor(text=text, return_tensors="pt").to(torch_device)
text_out = text_model(**text_inputs)

# %%
print(vision_out.image_embeds.shape, text_out.text_embeds.shape)

# %%
max_scores, max_classes = (vision_out.image_embeds @ text_out.text_embeds[:, 0, :].t()).max(dim=1)

# %%
print(max_scores)
```
Hi @jpizarrom, I can't reopen this PR as something has happened upstream since closing: either the branch has had a force push or it's been recreated. You can open a new PR and link to this one for reference.
Hi @jpizarrom, has this PR been merged into HF transformers? I cannot import Blip2ForImageTextRetrieval somehow :(
Hi @snpushpi, this PR was not finished; I was able to do a working version of …
What does this PR do?
Add a Blip2ForImageTextRetrieval model to be able to extract text, image, and multimodal features, similar to the extract_features method in the original implementation.
Fixes part of #25300 and #25245
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker @NielsRogge
TODOs:
- `Blip2TextRetrievalModelTest`
- Break `Blip2ModelWithProjection` into `Blip2TextModelWithProjection` and `Blip2VisionModelWithProjection`
- Update `Blip2Config`, remove `Blip2ModelWithoutLMConfig`