-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Correct llava mask & fix missing setter for vocab_size
#29389
Changes from 4 commits
4252a21
1203982
beb7946
6aad22f
b993bad
eff874c
e967bc0
38b2929
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -340,6 +340,12 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in | |||||||||||
final_attention_mask |= image_to_overwrite | ||||||||||||
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) | ||||||||||||
|
||||||||||||
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. | ||||||||||||
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id) | ||||||||||||
indices_to_mask = new_token_positions[batch_indices, pad_indices] | ||||||||||||
|
||||||||||||
final_embedding[batch_indices, indices_to_mask] = 0 | ||||||||||||
Comment on lines
+348
to
+351
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
should this not work as well? given that we create the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not sure which one will end up being faster There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe |
||||||||||||
|
||||||||||||
Comment on lines
+347
to
+352
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As later There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see why - could you explain a little bit more? The previous code didn't modify |
||||||||||||
if labels is None: | ||||||||||||
final_labels = None | ||||||||||||
|
||||||||||||
|
@@ -444,11 +450,11 @@ def forward( | |||||||||||
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 | ||||||||||||
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) | ||||||||||||
|
||||||||||||
# Get the target length | ||||||||||||
target_seqlen = first_layer_past_key_value.shape[-1] + 1 | ||||||||||||
target_length = input_ids.shape[1] | ||||||||||||
past_length = first_layer_past_key_value.shape[-1] | ||||||||||||
|
||||||||||||
extended_attention_mask = torch.ones( | ||||||||||||
(attention_mask.shape[0], target_seqlen - attention_mask.shape[1]), | ||||||||||||
(attention_mask.shape[0], past_length), | ||||||||||||
dtype=attention_mask.dtype, | ||||||||||||
device=attention_mask.device, | ||||||||||||
) | ||||||||||||
|
@@ -463,7 +469,7 @@ def forward( | |||||||||||
# Zero-out the places where we don't need to attend | ||||||||||||
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 | ||||||||||||
|
||||||||||||
attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1) | ||||||||||||
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) | ||||||||||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 | ||||||||||||
|
||||||||||||
outputs = self.language_model( | ||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -307,10 +307,50 @@ def test_small_model_integration_test_llama_batched_regression(self): | |||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
output = model.generate(**inputs, max_new_tokens=20) | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this serene location, one should be cautious about the weather conditions and potential', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip | ||||||||||||||||||||||||||||||||||||||||||||
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip | ||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test itself was wrong. The input_ids is transformers/tests/models/llava/test_modeling_llava.py Lines 245 to 265 in 00c1d87
https://huggingface.co/llava-hf/llava-1.5-7b-hf/blob/main/config.json#L6 |
||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT) | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
@slow | ||||||||||||||||||||||||||||||||||||||||||||
def test_batched_generation(self): | ||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test does not pass on main (trash is generated instead) |
||||||||||||||||||||||||||||||||||||||||||||
import requests | ||||||||||||||||||||||||||||||||||||||||||||
import torch | ||||||||||||||||||||||||||||||||||||||||||||
from PIL import Image | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
from transformers import AutoProcessor, LlavaForConditionalGeneration | ||||||||||||||||||||||||||||||||||||||||||||
fxmarty marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
with torch.device(torch_device): | ||||||||||||||||||||||||||||||||||||||||||||
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf") | ||||||||||||||||||||||||||||||||||||||||||||
fxmarty marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
prompt1 = "<image>\n<image>\nUSER: What's the the difference of two images?\nASSISTANT:" | ||||||||||||||||||||||||||||||||||||||||||||
prompt2 = "<image>\nUSER: Describe the image.\nASSISTANT:" | ||||||||||||||||||||||||||||||||||||||||||||
prompt3 = "<image>\nUSER: Describe the image.\nASSISTANT:" | ||||||||||||||||||||||||||||||||||||||||||||
url1 = "https://images.unsplash.com/photo-1552053831-71594a27632d?q=80&w=3062&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" | ||||||||||||||||||||||||||||||||||||||||||||
url2 = "https://images.unsplash.com/photo-1617258683320-61900b281ced?q=80&w=3087&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" | ||||||||||||||||||||||||||||||||||||||||||||
image1 = Image.open(requests.get(url1, stream=True).raw) | ||||||||||||||||||||||||||||||||||||||||||||
image2 = Image.open(requests.get(url2, stream=True).raw) | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
inputs = processor( | ||||||||||||||||||||||||||||||||||||||||||||
text=[prompt1, prompt2, prompt3], | ||||||||||||||||||||||||||||||||||||||||||||
images=[image1, image2, image1, image2], | ||||||||||||||||||||||||||||||||||||||||||||
return_tensors="pt", | ||||||||||||||||||||||||||||||||||||||||||||
padding=True, | ||||||||||||||||||||||||||||||||||||||||||||
).to(torch_device) | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
model = model.eval() | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
EXPECTED_OUTPUT = [ | ||||||||||||||||||||||||||||||||||||||||||||
"\n \nUSER: What's the the difference of two images?\nASSISTANT: In the two images, the primary difference is the presence of a small dog holding a flower in one", | ||||||||||||||||||||||||||||||||||||||||||||
"\nUSER: Describe the image.\nASSISTANT: The image features a small, fluffy dog sitting on a sidewalk. The dog is holding", | ||||||||||||||||||||||||||||||||||||||||||||
"\nUSER: Describe the image.\nASSISTANT: The image features a lone, adult llama standing on a grassy hill. The llama", | ||||||||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
generate_ids = model.generate(**inputs, max_new_tokens=20) | ||||||||||||||||||||||||||||||||||||||||||||
outputs = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) | ||||||||||||||||||||||||||||||||||||||||||||
self.assertEqual(outputs, EXPECTED_OUTPUT) | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
@slow | ||||||||||||||||||||||||||||||||||||||||||||
@require_bitsandbytes | ||||||||||||||||||||||||||||||||||||||||||||
def test_llava_index_error_bug(self): | ||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I quite don't understand this comment, why do we need to mask out here because of using the past_key_values?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A bit later in the code (specifically here:
transformers/src/transformers/models/llava/modeling_llava.py
Line 449 in aa17cf9
past_key_values
value to find unattended tokens. This was not correct before because this added step6.
was missing.Overall, this would be worth a larger refactor that would avoid regenerating full masks at every forward step in the generate. This PR is just a hotfix.