BLIP2 Model Fails with Version 4.46.3 (Shape Mismatch Error) #34990

Closed

movchan74 opened this issue Nov 28, 2024 · 4 comments

movchan74 commented Nov 28, 2024

System Info

  • transformers version: 4.46.3
  • Platform: Linux-5.15.0-89-generic-x86_64-with-glibc2.35
  • Python version: 3.10.15
  • Huggingface_hub version: 0.26.2
  • Safetensors version: 0.4.5
  • Accelerate version: 0.34.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA GeForce RTX 2080 Ti

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import requests
from PIL import Image
import torch
from transformers import Blip2ForConditionalGeneration, Blip2Processor

# Load the model and processor for BLIP2 from HuggingFace
model_id = "Salesforce/blip2-opt-2.7b"
torch_dtype = torch.float16
load_in_8bit = False

model = Blip2ForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch_dtype, load_in_8bit=load_in_8bit, device_map="cuda"
)
model = torch.compile(model)
model.eval()
processor = Blip2Processor.from_pretrained(model_id)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


img_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

inputs = processor(raw_image, return_tensors="pt").to(device)
generated_ids = model.generate(**inputs)
generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
generated_texts = [generated_text.strip() for generated_text in generated_texts]
print(generated_texts)

Error message:

RuntimeError                              Traceback (most recent call last)
Cell In[2], line 26
     23 raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
     25 inputs = processor(raw_image, return_tensors="pt").to(device)
---> 26 generated_ids = model.generate(**inputs)
     27 generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
     28 generated_texts = [generated_text.strip() for generated_text in generated_texts]

File ~/.cache/pypoetry/virtualenvs/aana-vIr3-B0u-py3.10/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/.cache/pypoetry/virtualenvs/aana-vIr3-B0u-py3.10/lib/python3.10/site-packages/transformers/models/blip_2/modeling_blip_2.py:2316, in Blip2ForConditionalGeneration.generate(self, pixel_values, input_ids, attention_mask, interpolate_pos_encoding, **generate_kwargs)
   2314 if getattr(self.config, "image_token_index", None) is not None:
   2315     special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
-> 2316     inputs_embeds[special_image_mask] = language_model_inputs.flatten()
   2317 else:
   2318     logger.warning_once(
   2319         \"Expanding inputs for image tokens in BLIP-2 should be done in processing. \"
   2320         \"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. \"
   2321         \"Using processors without these attributes in the config is deprecated and will throw an error in v4.47.\"
   2322     )

RuntimeError: shape mismatch: value tensor of shape [81920] cannot be broadcast to indexing result of shape [0]
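
For context on the numbers in the error, here is a minimal, self-contained sketch of the failing assignment, with hypothetical tensor shapes based on blip2-opt-2.7b (32 query tokens, hidden size 2560, so 32 * 2560 = 81920 values). When the processor does not insert any image placeholder tokens into input_ids, the boolean mask selects zero positions and the assignment cannot broadcast:

import torch

# Hypothetical stand-ins for the tensors inside Blip2ForConditionalGeneration.generate();
# the ids and shapes are assumptions for illustration, not the actual model internals.
image_token_index = 50265                         # assumed placeholder token id from the config
input_ids = torch.tensor([[2]])                   # processor emitted only a BOS token, no image tokens
inputs_embeds = torch.zeros(1, 1, 2560)           # text embeddings for those input_ids
language_model_inputs = torch.zeros(1, 32, 2560)  # projected query tokens: 32 * 2560 = 81920 values

special_image_mask = (input_ids == image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
# The mask is all False, so the indexed result has 0 elements and the assignment raises:
# RuntimeError: shape mismatch: value tensor of shape [81920] cannot be broadcast
# to indexing result of shape [0]
inputs_embeds[special_image_mask] = language_model_inputs.flatten()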

Expected behavior

The BLIP2 model does not work with transformers==4.46.3.

model.generate() should return a caption for the image instead of raising a shape mismatch error.

movchan74 added the bug label Nov 28, 2024
movchan74 (Author) commented

It seems like shape mismatch errors are quite a common issue with new releases; I reported a similar error before for the Idefics 2 model: #33752

Are there any tests to ensure the model works fine with a new release?

zucchini-nlp (Member) commented

Hey everyone! The issue was fixed on the main branch and will be included in the next release, expected around next week. For now, please install transformers from source with pip install --upgrade git+https://github.com/huggingface/transformers.git

The error was caused by updated configs on the Hub; we'll make sure to run more tests before updating/merging Hub-related modifications to prevent this type of issue.
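
As a quick sanity check while waiting for the release (a sketch, not an official recipe), you can verify which side of the fix you are on: the failing branch in generate() is only taken when the Hub config defines image_token_index, and only an install that includes the fix handles that config correctly.

import transformers
from transformers import Blip2Config

# 4.46.3 hits the bug; an install from source reports a ".dev0" version string.
print(transformers.__version__)

# The failing branch is only taken when the (recently updated) Hub config defines
# image_token_index, so printing it shows whether the new config is being used.
config = Blip2Config.from_pretrained("Salesforce/blip2-opt-2.7b")
print(getattr(config, "image_token_index", None))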

movchan74 (Author) commented

I can confirm that it works on the main branch.

prasadseemakurthi commented

I can confirm it too; it worked for me both locally and on Google Colab.
