Correct llava mask & fix missing setter for vocab_size
#29389
Conversation
Thanks! Could you add a test / update test values? Or has this no impact on the current tests?
@fxmarty We do that by default (in this function) 🤔

@zucchini-nlp since you're interested in multimodal models: after this PR gets merged, can you inspect why llava (and related models) need this custom attention mask handling at generation time?
@gante, I've been affected by this also while trying to make speculative decoding work with VLMs. I'll briefly outline what I found.

The custom mask is needed in Llava because it concatenates image embeds and text, and the result is then passed on as inputs_embeds. Comparing to other models, I have counted only 3 soft prompt VLMs in the library.

We could fix Llava to call text_model.generate() for consistency. It's up to you to decide if we need any changes; I have no idea how this can affect things library-wide 😄
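For illustration, here is a minimal sketch of the soft-prompt merge described above, with toy shapes and a hypothetical image_token_id; it only conveys the idea and is not the actual `_merge_input_ids_with_image_features` implementation:

```python
import torch

# Toy dimensions, purely illustrative.
hidden_dim, num_patches, image_token_id = 8, 3, 32000

input_ids = torch.tensor([[1, 32000, 5, 6]])              # <s> <image> "what" "is"
text_embeds = torch.randn(1, input_ids.shape[1], hidden_dim)
image_embeds = torch.randn(1, num_patches, hidden_dim)    # vision tower + projector output

# Replace the single <image> placeholder with all of its patch embeddings.
pieces = []
for position, token in enumerate(input_ids[0].tolist()):
    if token == image_token_id:
        pieces.append(image_embeds[0])
    else:
        pieces.append(text_embeds[0, position : position + 1])
inputs_embeds = torch.cat(pieces, dim=0).unsqueeze(0)     # 3 text tokens + 3 patches = 6 positions

# The attention mask has to be expanded the same way, which is why the original
# text-level mask cannot simply be reused later during generation.
attention_mask = torch.ones(1, inputs_embeds.shape[1], dtype=torch.long)
print(inputs_embeds.shape, attention_mask.shape)          # torch.Size([1, 6, 8]) torch.Size([1, 6])
```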
@zucchini-nlp I'd say to go with "call text_model.generate() for consistency", if that is feasible and results in a clean interface.
@@ -307,10 +307,50 @@ def test_small_model_integration_test_llama_batched_regression(self):

         output = model.generate(**inputs, max_new_tokens=20)

-        EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this serene location, one should be cautious about the weather conditions and potential', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.']  # fmt: skip
+        EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.']  # fmt: skip
This test itself was wrong. The input_ids are `tensor([[32001, 32001, 32001, 1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 526, 278, 2712, 306, 881, 367, 274, 1300, 2738, 1048, 746, 306, 6493, 445, 2058, 29973, 1724, 881, 306, 6963, 411, 592, 29973, 13, 22933, 9047, 13566, 29901], [1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 338, 445, 29973, 13, 22933, 9047, 13566, 29901, 7803, 274, 1446, 19214, 373, 263, 6592, 29991, 13, 11889, 29901, 29871, 32000, 29871, 13, 2855, 445, 29973, 13, 22933, 9047, 13566, 29901]], device='cuda:0')`, with 32000 being the image token. The first sequence visibly sees only the first image, so its output should not differ from the one in
transformers/tests/models/llava/test_modeling_llava.py, lines 245 to 265 in 00c1d87:
    def test_small_model_integration_test_llama_batched(self):
        # Let's make sure we test the preprocessing to replace what is used
        model_id = "llava-hf/llava-1.5-7b-hf"

        model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf", load_in_4bit=True)
        processor = AutoProcessor.from_pretrained(model_id)

        prompts = [
            "USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:",
            "USER: <image>\nWhat is this?\nASSISTANT:",
        ]
        image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
        image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

        inputs = processor(prompts, images=[image1, image2], return_tensors="pt", padding=True)

        output = model.generate(**inputs, max_new_tokens=20)

        EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: The image features two cats lying down on a pink couch. One cat is located on']  # fmt: skip

        self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
https://huggingface.co/llava-hf/llava-1.5-7b-hf/blob/main/config.json#L6
         self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)

+    @slow
+    def test_batched_generation(self):
This test does not pass on main (trash is generated instead)
        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
        batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
        indices_to_mask = new_token_positions[batch_indices, pad_indices]

        final_embedding[batch_indices, indices_to_mask] = 0
As `batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)` is used later, this is necessary.
I don't see why - could you explain a little bit more? The previous code didn't modify `first_layer_past_key_value`.
Llava slow tests all pass (apart from …).
The changes look good to me 👍
Thanks so much for the fix and the deep investigation! LGTM since the slow tests pass.
LGTM, though there might be a simpler way to fix this:
        batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
        indices_to_mask = new_token_positions[batch_indices, pad_indices]

        final_embedding[batch_indices, indices_to_mask] = 0
Suggestion:

        batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
        indices_to_mask = new_token_positions[batch_indices, pad_indices]
        final_embedding[batch_indices, indices_to_mask] = 0
        new_token_positions = new_token_positions * (input_ids != self.pad_token_id)

Should this not work as well? Given that we create the new_token_positions with a cumsum, we could also do it before (right after the cumsum).
not sure which one will end up being faster
maybe
`vocab_size`

I added a setter to fix the regression in #29586 (comment). (I wonder if llava is working on the release version? Maybe not.)
    @vocab_size.setter
    def vocab_size(self, value):
        self._vocab_size = value
Is this good @NielsRogge @amyeroberts?
Shouldn't we remove the `@property` decorator above? That fixed #29789 for me. Isn't a property for immutable things (whereas the vocabulary size can change)? cc @amyeroberts
Properties aren't only for immutable things (that's the value of a setter). The property enables us to emit a deprecation warning when it is accessed as an attribute.

Removing the property decorator would mean that previous usage of `config.vocab_size` would break, and people would have to use `config.vocab_size()` instead.
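For reference, a minimal sketch of the pattern under discussion (the class name and warning text here are made up; only the property/setter mechanics matter):

```python
import warnings

class ToyConfig:
    def __init__(self, vocab_size=32000):
        self._vocab_size = vocab_size

    @property
    def vocab_size(self):
        # Reads still work as `config.vocab_size`, and the property gives us a
        # hook to emit a deprecation warning on access.
        warnings.warn("`vocab_size` is deprecated (illustrative message only).", FutureWarning)
        return self._vocab_size

    @vocab_size.setter
    def vocab_size(self, value):
        # Without this setter, `config.vocab_size = n` raises AttributeError,
        # which is the regression the added setter fixes.
        self._vocab_size = value

config = ToyConfig()
config.vocab_size = 32064  # works thanks to the setter
print(config.vocab_size)   # emits the FutureWarning, prints 32064
```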
Thanks for digging into this and fixing!
Overall the changes look OK; I'm just a bit unclear on some of the reasoning for the logic changes.
@@ -344,6 +344,12 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in

         final_attention_mask |= image_to_overwrite
         position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

+        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
I don't quite understand this comment. Why do we need to mask out here because of using the past_key_values?
A bit later in the code (specifically here: `batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)`), we use the past_key_values to find unattended tokens. This was not correct before because this added step 6 was missing.

Overall, this would be worth a larger refactor that avoids regenerating the full mask at every forward step during generation. This PR is just a hotfix.
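To make the reasoning concrete, here is a small self-contained sketch of the detection trick, reduced to a single bias-free key projection (in the real model a zero embedding also stays zero through RMSNorm and rotary embeddings, so the same logic applies to the cached keys); it is a simplification, not the model code:

```python
import torch

torch.manual_seed(0)
hidden_dim, head_dim, seq_len = 8, 4, 6

hidden_states = torch.randn(1, seq_len, hidden_dim)
hidden_states[:, :2] = 0.0   # step 6: embeddings at the two left-padding positions are zeroed

k_proj = torch.nn.Linear(hidden_dim, head_dim, bias=False)  # Llama-style projection, no bias
keys = k_proj(hidden_states)                                # stand-in for the cached first-layer keys

# Same idea as the check in the forward: positions whose cached keys are all zero
# are treated as non-attended. Without step 6, the padding rows would have nonzero
# keys and would wrongly be attended to once the mask is rebuilt from the cache.
batch_index, non_attended_tokens = torch.where(keys.abs().sum(-1) == 0)
print(non_attended_tokens)   # tensor([0, 1])
```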
@amyeroberts We need to set the embeddings at the padding positions to zero for `torch.where(first_layer_past_key_value.float().sum(-2) == 0)` to do its job! Up to now it did not do what was expected.

What I meant in my earlier comment is that this (i.e. the attention_mask reconstruction in the decode) could be avoided altogether if the mask was rather handled in generate, which is currently not the case in the implementation.

Edit: slow tests for llava pass.
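For context, handling the mask in generate would mean the model returns the expanded mask once and generation only appends to it between steps, roughly like the sketch below (a simplification with toy values, not the actual GenerationMixin code):

```python
import torch

# Suppose the first forward pass returned the already-expanded mask
# (text tokens + image patches, with two left-padding positions). Toy values:
attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1, 1, 1]])

# Each decoding step would then only append a 1 for the freshly generated token,
# instead of rebuilding the mask from past_key_values inside the model's forward.
for _ in range(3):
    new_column = attention_mask.new_ones((attention_mask.shape[0], 1))
    attention_mask = torch.cat([attention_mask, new_column], dim=-1)

print(attention_mask)  # tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
```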
* correct llava mask
* fix vipllava as well
* mask out embedding for padding tokens
* add test
* fix style
* add setter
* fix test on suggestion
This PR fixes the llava mask in the generation case.

`torch.cat((attention_mask, extended_mask), dim=-1)` was very wrong, and it is a miracle we were getting meaningful output for batch generation. The issue is that this disregards the custom attention_mask built during the first forward pass, so wrong past key values are used.

It is also not clear to me why this custom masking is handled in the forward and not in GenerationMixin cc @gante (we are doing unnecessary operations at each forward to retrieve the correct attention_mask; it could just be an output of the model and updated in GenerationMixin).
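A toy illustration of the mismatch (made-up lengths and mask layout, only meant to show the coordinate problem): the mask the processor produced covers the prompt with a single image placeholder, while the past key values built in the first forward pass cover the expanded sequence, so padding the short mask with ones drops the zeros of the custom merged mask.

```python
import torch

# Custom mask built by the merge during the first forward pass, over the *expanded*
# sequence (text tokens + image patches + padding slots introduced by the merge).
# Toy layout: the two zeros are positions that must never be attended.
merged_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 1, 1]])

# Mask available at the next decoding step: one slot per original prompt token.
prompt_mask = torch.tensor([[1, 1, 1, 1, 1, 1]])

# Old behaviour: pad the short prompt-level mask with ones up to the cache length.
cache_len = merged_mask.shape[1]
pad = prompt_mask.new_ones((1, cache_len - prompt_mask.shape[1]))
naive_mask = torch.cat([prompt_mask, pad], dim=-1)

print(naive_mask)   # tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]]) -> the two masked slots are attended again
print(merged_mask)  # what the cached keys/values actually correspond to
```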
Fixes #28184