Correct llava mask & fix missing setter for vocab_size (#29389)
Changes from all commits: 4252a21, 1203982, beb7946, 6aad22f, b993bad, eff874c, e967bc0, 38b2929.
In `_merge_input_ids_with_image_features` (`src/transformers/models/llava/modeling_llava.py`):

```diff
@@ -344,6 +344,12 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
         final_attention_mask |= image_to_overwrite
         position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
 
+        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
+        batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
+        indices_to_mask = new_token_positions[batch_indices, pad_indices]
+
+        final_embedding[batch_indices, indices_to_mask] = 0
+
         if labels is None:
             final_labels = None
```

**Review thread on the new step 6 comment:**

> I don't quite understand this comment: why do we need to mask out here because of using the `past_key_values`?

> A bit later in the code (specifically here), the `past_key_values` value is used to find unattended tokens. This was not correct before, because this added step 6 was missing. Overall, this would be worth a larger refactor that avoids regenerating full masks at every forward step in `generate`; this PR is just a hotfix.

**Review thread on lines +348 to +351:**

> Suggested change: […]. Should this not work as well, given that we create the […]?

> Not sure which one will end up being faster.

> Maybe […]

**Review thread on lines +347 to +352:**

> As later […]

> I don't see why; could you explain a little bit more? The previous code didn't modify […]
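To make the new step 6 concrete, here is a minimal, self-contained sketch of the masking trick, not the actual modeling code: the tensor shapes, the `pad_token_id` value, the toy `input_ids`, and the identity `new_token_positions` mapping are all assumptions for illustration.

```python
import torch

# Hypothetical sizes and inputs, chosen only for this sketch.
batch, seq_len, dim = 2, 5, 4
pad_token_id = 32001  # assumed pad id for the sketch

final_embedding = torch.randn(batch, seq_len, dim)
input_ids = torch.tensor([[101, 102, 103, 32001, 32001],
                          [101, 102, 103, 104, 105]])
# In the real code, new_token_positions maps each original token slot to its
# slot after image features are merged in; identity keeps the sketch simple.
new_token_positions = torch.arange(seq_len).expand(batch, seq_len)

# Step 6 from the diff: zero out the merged embedding at padding positions.
batch_indices, pad_indices = torch.where(input_ids == pad_token_id)
indices_to_mask = new_token_positions[batch_indices, pad_indices]
final_embedding[batch_indices, indices_to_mask] = 0

# A later forward step can now recover the non-attended (padding) positions by
# spotting all-zero vectors, analogous to the cached-key check
# `first_layer_past_key_value.float().sum(-2) == 0` in the model's forward.
print(final_embedding.abs().sum(-1) == 0)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])
```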
Still in `modeling_llava.py`, in `forward`:

```diff
@@ -449,10 +455,11 @@ def forward(
                 batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
 
                 # Get the target length
-                target_seqlen = first_layer_past_key_value.shape[-1] + 1
+                target_length = input_ids.shape[1]
+                past_length = first_layer_past_key_value.shape[-1]
 
                 extended_attention_mask = torch.ones(
-                    (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
+                    (attention_mask.shape[0], past_length),
                     dtype=attention_mask.dtype,
                     device=attention_mask.device,
                 )
@@ -467,7 +474,7 @@ def forward(
                 # Zero-out the places where we don't need to attend
                 extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
 
-                attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
+                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
                 position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
 
         outputs = self.language_model(
```
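As a hedged illustration of the fixed concatenation order (toy tensors and sizes, not the model's real cache), the sketch below builds a ones-mask for the `past_length` cached positions, zeroes the positions the cache shows were never attended, and only then appends the mask for the current step's tokens, so the mask lines up with the concatenated past-plus-current key/value sequence.

```python
import torch

batch = 2
past_length = 6    # stands in for first_layer_past_key_value.shape[-1]
target_length = 1  # stands in for input_ids.shape[1] at this decoding step
attention_mask = torch.ones(batch, target_length, dtype=torch.long)

# Ones for every cached position...
extended_attention_mask = torch.ones(batch, past_length, dtype=torch.long)

# ...except positions the cache marks as never attended; here we pretend
# position 3 of batch 0 was such a padding slot.
new_batch_index = torch.tensor([0])
new_non_attended_tokens = torch.tensor([3])
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

# The fixed order: cached-part mask first, then the current tokens' mask.
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
# position_ids counts only the attended tokens.
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1

print(attention_mask)
# tensor([[1, 1, 1, 0, 1, 1, 1],
#         [1, 1, 1, 1, 1, 1, 1]])
print(position_ids)
# tensor([[5],
#         [6]])
```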
In `tests/models/llava/test_modeling_llava.py`:
```diff
@@ -27,7 +27,14 @@
     is_torch_available,
     is_vision_available,
 )
-from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device
+from transformers.testing_utils import (
+    require_bitsandbytes,
+    require_torch,
+    require_torch_gpu,
+    require_vision,
+    slow,
+    torch_device,
+)
 
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
```
```diff
@@ -470,10 +477,45 @@ def test_small_model_integration_test_llama_batched_regression(self):
 
         output = model.generate(**inputs, max_new_tokens=20)
 
-        EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this serene location, one should be cautious about the weather conditions and potential', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.']  # fmt: skip
+        EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.']  # fmt: skip
 
         self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
 
+    @slow
+    @require_torch
+    @require_vision
+    def test_batched_generation(self):
+        model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf").to(torch_device)
+
+        processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
+
+        prompt1 = "<image>\n<image>\nUSER: What's the the difference of two images?\nASSISTANT:"
+        prompt2 = "<image>\nUSER: Describe the image.\nASSISTANT:"
+        prompt3 = "<image>\nUSER: Describe the image.\nASSISTANT:"
+        url1 = "https://images.unsplash.com/photo-1552053831-71594a27632d?q=80&w=3062&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
+        url2 = "https://images.unsplash.com/photo-1617258683320-61900b281ced?q=80&w=3087&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
+        image1 = Image.open(requests.get(url1, stream=True).raw)
+        image2 = Image.open(requests.get(url2, stream=True).raw)
+
+        inputs = processor(
+            text=[prompt1, prompt2, prompt3],
+            images=[image1, image2, image1, image2],
+            return_tensors="pt",
+            padding=True,
+        ).to(torch_device)
+
+        model = model.eval()
+
+        EXPECTED_OUTPUT = [
+            "\n \nUSER: What's the the difference of two images?\nASSISTANT: In the two images, the primary difference is the presence of a small dog holding a flower in one",
+            "\nUSER: Describe the image.\nASSISTANT: The image features a small, fluffy dog sitting on a sidewalk. The dog is holding",
+            "\nUSER: Describe the image.\nASSISTANT: The image features a lone, adult llama standing on a grassy hill. The llama",
+        ]
+
+        generate_ids = model.generate(**inputs, max_new_tokens=20)
+        outputs = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+        self.assertEqual(outputs, EXPECTED_OUTPUT)
+
     @slow
     @require_bitsandbytes
     def test_llava_index_error_bug(self):
```

**Review thread on the updated `EXPECTED_DECODED_TEXT`:**

> This test itself was wrong. The `input_ids` is […] (see `tests/models/llava/test_modeling_llava.py`, lines 245 to 265 at 00c1d87, and https://huggingface.co/llava-hf/llava-1.5-7b-hf/blob/main/config.json#L6).

**Review thread on `test_batched_generation`:**

> This test does not pass on `main` (trash is generated instead).
**Review thread on the `vocab_size` setter:**

> Is this good @NielsRogge @amyeroberts?

> Shouldn't we remove the `@property` annotation above? That fixed #29789 for me. Isn't a property for immutable things, whereas the vocabulary size can change? cc @amyeroberts

> Properties aren't just for immutable things (this is the value of a setter). The property enables us to emit a deprecation warning if the property is accessed as an attribute. Removing the property annotation would mean that previous usage of `config.vocab_size` would break, and people would have to call `config.vocab_size()` instead.
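For context on the exchange above, here is a minimal sketch of the property-plus-setter pattern being discussed. The class and the warning text are invented for illustration; this is not the actual `transformers` configuration code.

```python
import warnings

class ToyConfig:
    def __init__(self, vocab_size: int = 32000):
        self._vocab_size = vocab_size

    @property
    def vocab_size(self) -> int:
        # Keeping @property lets attribute reads emit a deprecation warning.
        warnings.warn(
            "`vocab_size` is deprecated (illustrative message only).",
            FutureWarning,
        )
        return self._vocab_size

    @vocab_size.setter
    def vocab_size(self, value: int) -> None:
        # Without this setter, `config.vocab_size = n` raises AttributeError.
        self._vocab_size = value

config = ToyConfig()
config.vocab_size = 32064  # works because the setter exists
print(config._vocab_size)  # 32064
```

Dropping `@property` instead would turn every existing `config.vocab_size` read into a bound-method access, which is exactly the breakage the reply above points out.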