Fix vipllava for generation #29874
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks! Do you know if our slow CI caught this? Otherwise let's add a test!
Yes, the
I am a bit surprised by the other model working as expected, because other llava-based models do use [-1], which is the head_dim, no?
Yeah, that's because other models index only the first head dim in this line, while VipLlava indexes the first head. I found out why we need this head-indexing hack and tried generating with padded inputs; I did not see any difference between indexing the first head and indexing the first head dimension.
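For illustration, here is a minimal sketch of the two indexing options, assuming the legacy cache layout of (batch, num_heads, seq_len, head_dim) for past_key_values[0][0]; the sizes and variable names below are made up for the example and are not the library code:

```python
import torch

# Assumed legacy cache layout: past_key_values[layer][0] has shape
# (batch, num_heads, seq_len, head_dim).
batch, num_heads, seq_len, head_dim = 2, 4, 6, 8
first_layer_key_cache = torch.randn(batch, num_heads, seq_len, head_dim)

# Pretend the first two positions of sample 0 were left padding, so their
# cached keys are zero across every head and head dimension.
first_layer_key_cache[0, :, :2, :] = 0.0

# Indexing the first head dimension keeps shape (batch, num_heads, seq_len).
per_dim = first_layer_key_cache[:, :, :, 0]
# Indexing the first head keeps shape (batch, seq_len, head_dim).
per_head = first_layer_key_cache[:, 0, :, :]

# Either view exposes the zeroed (non-attended) positions, as long as the
# reduction that follows matches the chosen layout.
print(torch.where(per_dim.abs().sum(-2) == 0))   # reduces over heads
print(torch.where(per_head.abs().sum(-1) == 0))  # reduces over head_dim
```

Both prints recover the same (batch_index, token_index) pairs for the padded positions, which matches the observation above that the two variants behave the same with padded inputs, provided the downstream reduction matches the chosen axis order.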
Made the Llava model code consistent and ran all the tests (+slow). Some of the failing ones have nothing to do with the current changes; I confirmed they were already failing long before this PR.
Thanks
@@ -403,7 +403,7 @@ def test_small_model_integration_test_llama(self):
        inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)

        output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
-       EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the presence of wildlife, such as birds or fish, and avoid disturbing their natural habitats. Lastly, be aware of any local regulations or guidelines for the use of the pier, as some areas may be restricted or prohibited for certain activities."  # fmt: skip
+       EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water, there are a few things to be cautious about. First, be aware of the surroundings and potential hazards, such as slippery surfaces, uneven ground, or any obstacles in the water. Second, be mindful of the weather conditions, as sudden changes in weather can make the dock or pier unsafe to use. Third, be cautious of the water depth and any underwater hazards, such as rocks or debris, that could pose a risk to your safety. Lastly, be respectful of the environment and other visitors, and follow any rules or guidelines posted at the dock or pier."  # fmt: skip
Which hardware did you run this on? It should not be failing! But maybe it's T4 vs A100.
Related question: if llava was not touched in this PR and our daily slow CI is not complaining, why was this test changed? 🤔
Hmm, I was running the test on an A100. If the daily CI is passing, then it probably does not need to change. Anyway, it's interesting to me that some tests depend on the hardware, which means a contributor might change a failing test without knowing whether the change is correct.
Should I just revert the changes then?
nice catch
@@ -441,10 +441,10 @@ def forward(
        if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
            # Retrieve the first layer to inspect the logits and mask out the hidden states
            # that are set to 0
-           first_layer_past_key_value = past_key_values[0][0][:, 0, :, :]
+           first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
This seems like an opportunity for # Copied from (not to be fixed in this PR, but in the future).
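For readers unfamiliar with the convention, here is a hypothetical sketch of where a # Copied from marker would sit; the class, body, and target path below are placeholders rather than the actual VipLlava code, and in real usage the marked body has to match the referenced source exactly so the repository's consistency checks can keep the copies in sync:

```python
import torch
from torch import nn


class ToyProjector(nn.Module):
    """Stand-in module used only to show where the marker goes."""

    def __init__(self, hidden_size: int = 32):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    # Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector.forward
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.linear(hidden_states)
```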
Noted for the next PR :)
LGTM 👍 Let's revert the changes in tests/models/llava/test_modeling_llava.py, then we can merge :)
Reverted the changes and rebased on main; it can be merged now.
* fix vipllava generation
* consistent llava code
* revert llava tests changes
What does this PR do?
When working on this PR, it was found that VipLlava fails when generating with a kv cache because of incorrectly indexing past_kv_length. This PR fixes it by indexing the past length as -2. Other Llava models work correctly.
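As a rough standalone illustration of reading the past length from the second-to-last axis (a sketch assuming the legacy (batch, num_heads, past_len, head_dim) cache layout; the names and sizes are made up, not the library code):

```python
import torch

# Assumed legacy cache layout: past_key_values[layer][0] has shape
# (batch, num_heads, past_len, head_dim).
batch, num_heads, past_len, head_dim = 1, 8, 17, 64
past_key = torch.zeros(batch, num_heads, past_len, head_dim)

# The past sequence length lives on the second-to-last axis, so it is read
# with shape[-2]; shape[-1] would return head_dim (64) instead of 17.
past_kv_length = past_key.shape[-2]
assert past_kv_length == past_len
print(past_kv_length)  # 17
```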