Fix Llava for 0-embeddings #30473
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Very nice handling - thanks for digging into this and adding a test! 🚀
Very clean, thanks for looking into this! cc @ArthurZucker
good call thanks
It seems that the problem still exists for Llava when using the new code of #30473. I upgraded transformers, but the error still occurs. My code is:
Is it because I use quantization? I followed the tutorial from @NielsRogge, and I couldn't run through the code.
@DingYX0731 hey! I couldn't reproduce the error on the latest version of transformers. Can you check if it works for you in Colab? If yes, the problem might be in your local setup/hardware. For example, in #30294 the problem was in using "mps" as the device.
The tutorial code works fine in Colab for me as well (with transformers==4.42.0.dev0, and also with 4.41.0, which is the same version as my local setup). The problem is very likely caused by my local setup, which is Ubuntu 20.04 with an RTX 4090. But I am still confused... I have also tried an older version like transformers==4.37.2, which also works in Colab... but not locally...
Hmm, then it would be hard for me to help you locate the bug. Let's try the following:
import transformers
print(transformers.__version__)
inputs = processor(prompts, images=[image1, image2], padding=True, return_tensors="pt").to(model.device)
print(inputs.input_ids)
print(model.config.image_token_index)
outputs = model(**inputs, use_cache=True)
Sorry to bother you so much @zucchini-nlp. For the old version of transformers (4.37.2):
inputs = processor(prompts, images=[image1, image2], padding=True, return_tensors="pt").to(torch_device)
print(inputs.input_ids)
print(model.config.image_token_index)
the output is:
There appear to be lots of zeros. Do they indicate the existence of special tokens? And the code outputs = model(**inputs, use_cache=True) still has a problem:
For the newer version of transformers (4.41.0), the modified code indeed exists in the new package:
Note that when I upgrade to transformers==4.41.0, such messages occur:
But the llava-torch version is already the latest. Could this lead to the problem?
I just found the problem and solved it!
In transformers version 4.41 I still encounter the error below:
Hey @hxhcreate! What do you mean by
I mean the data text itself wrongly contains several <image> tokens. Could the correct number of images be inferred from the input images themselves, not from the text?
@hxhcreate Ah I see, unfortunately we can't infer that during processing. While it is doable, I think it would cause more errors in the future, and we should rather delegate to users to prepare their text and images correctly. You can preprocess your dataset manually by checking how many images you have and replacing the extra <image> tokens.
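For reference, here is a minimal sketch of that manual preprocessing in Python (the helper name and the example prompt are hypothetical, and it assumes the placeholder string is "<image>", as in the Llava checkpoints):

def trim_extra_image_tokens(text: str, num_images: int, image_token: str = "<image>") -> str:
    # Keep only the first `num_images` occurrences of `image_token` in `text`.
    parts = text.split(image_token)
    if len(parts) - 1 <= num_images:
        return text  # already consistent, nothing to trim
    kept = image_token.join(parts[: num_images + 1])
    dropped = "".join(parts[num_images + 1:])
    return kept + dropped

# Hypothetical example: two placeholders in the text but only one real image.
print(trim_extra_image_tokens("<image><image>\nDescribe the photo.", num_images=1))
# -> "<image>\nDescribe the photo."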
I see, that's easy to do, thanks for your help
What does this PR do?
Fixes #29835. In llava-next the embedding weights of some tokens are rounded to 0 when cast to fp16, which results in an incorrect calculation of image_positions. This PR fixes it by getting the image positions as "anything that was not in text_positions", so that we do not rely on embedding values.
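As an illustration of that idea, here is a minimal standalone sketch (not the actual modeling code; the tensor names and values are made up):

import torch

batch, seq_len, dim = 1, 6, 4
# Hypothetical layout: positions 2 and 3 are image slots, the rest are text tokens,
# and the text positions are known from input_ids.
text_positions = torch.tensor([[0, 1, 4, 5]])

# Scatter text embeddings into an all-zero buffer (image slots stay 0 for now).
final_embedding = torch.zeros(batch, seq_len, dim, dtype=torch.float16)
final_embedding[0, [0, 4, 5]] = 0.1
final_embedding[0, 1] = 1e-9  # e.g. <unk>: its weight rounds to exactly 0 in fp16

# Fragile: "rows that are still all zeros are the image slots" wrongly flags position 1.
mask_by_value = (final_embedding == 0).all(dim=-1)
print(mask_by_value)  # tensor([[False,  True,  True,  True, False, False]])

# Robust (the approach of this PR): image positions are simply everything
# that is not a text position, independent of the embedding values.
image_mask = torch.ones(batch, seq_len, dtype=torch.bool)
image_mask[torch.arange(batch).unsqueeze(-1), text_positions] = False
print(image_mask)  # tensor([[False, False,  True,  True, False, False]])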
All llava tests (+slow) are passing locally and I added one test for the <unk> token in llava-next. <unk> is one of the tokens that gets cast to 0, but there are around 200 such tokens in the llava-mistral-7b version.
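The fp16 rounding itself is easy to reproduce (an illustrative check, not part of the PR's test suite):

import torch

weights = torch.tensor([1e-9, 2.5e-2], dtype=torch.float32)
print(weights.to(torch.float16))
# tensor([0.0000, 0.0250], dtype=torch.float16)
# The tiny weight underflows to exactly 0, which is why a value-based check misclassifies such tokens.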