From dc8b789894cec8b087a1e522b1150b37d775208f Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 22 Mar 2024 19:57:08 +0800 Subject: [PATCH] Correct llava mask & fix missing setter for `vocab_size` (#29389) * correct llava mask * fix vipllava as wlel * mask out embedding for padding tokens * add test * fix style * add setter * fix test on suggestion --- .../models/llava/configuration_llava.py | 4 ++ .../models/llava/modeling_llava.py | 13 ++++-- .../models/llava_next/modeling_llava_next.py | 7 ++- .../models/vipllava/modeling_vipllava.py | 14 ++++-- tests/models/llava/test_modeling_llava.py | 46 ++++++++++++++++++- 5 files changed, 74 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/llava/configuration_llava.py b/src/transformers/models/llava/configuration_llava.py index d70175c4738a..56b7974db0ad 100644 --- a/src/transformers/models/llava/configuration_llava.py +++ b/src/transformers/models/llava/configuration_llava.py @@ -147,6 +147,10 @@ def vocab_size(self): ) return self._vocab_size + @vocab_size.setter + def vocab_size(self, value): + self._vocab_size = value + def to_dict(self): output = super().to_dict() output.pop("_vocab_size", None) diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 13b8375647e2..d3fc58eb3642 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -344,6 +344,12 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in final_attention_mask |= image_to_overwrite position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) + # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. + batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id) + indices_to_mask = new_token_positions[batch_indices, pad_indices] + + final_embedding[batch_indices, indices_to_mask] = 0 + if labels is None: final_labels = None @@ -449,10 +455,11 @@ def forward( batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) # Get the target length - target_seqlen = first_layer_past_key_value.shape[-1] + 1 + target_length = input_ids.shape[1] + past_length = first_layer_past_key_value.shape[-1] extended_attention_mask = torch.ones( - (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]), + (attention_mask.shape[0], past_length), dtype=attention_mask.dtype, device=attention_mask.device, ) @@ -467,7 +474,7 @@ def forward( # Zero-out the places where we don't need to attend extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1) + attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 outputs = self.language_model( diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 845269830c53..54ad4d5a5040 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -356,7 +356,6 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): num_images, num_image_patches, embed_dim = image_features.shape batch_size, sequence_length = input_ids.shape - left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) # 1. Create a mask to know where special image tokens are special_image_token_mask = input_ids == self.config.image_token_index @@ -418,6 +417,12 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in final_attention_mask |= image_to_overwrite position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) + # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. + batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id) + indices_to_mask = new_token_positions[batch_indices, pad_indices] + + final_embedding[batch_indices, indices_to_mask] = 0 + if labels is None: final_labels = None diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 48e29179388c..34582a912a6e 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -347,6 +347,12 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in final_attention_mask |= image_to_overwrite position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) + # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. + batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id) + indices_to_mask = new_token_positions[batch_indices, pad_indices] + + final_embedding[batch_indices, indices_to_mask] = 0 + if labels is None: final_labels = None @@ -442,11 +448,11 @@ def forward( # Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-1) == 0) - # Get the target length - target_seqlen = first_layer_past_key_value.shape[-2] + 1 + target_length = input_ids.shape[1] + past_length = first_layer_past_key_value.shape[-1] extended_attention_mask = torch.ones( - (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]), + (attention_mask.shape[0], past_length), dtype=attention_mask.dtype, device=attention_mask.device, ) @@ -461,7 +467,7 @@ def forward( # Zero-out the places where we don't need to attend extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1) + attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 outputs = self.language_model( diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py index 27c99dda1692..856044520a94 100644 --- a/tests/models/llava/test_modeling_llava.py +++ b/tests/models/llava/test_modeling_llava.py @@ -27,7 +27,14 @@ is_torch_available, is_vision_available, ) -from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device +from transformers.testing_utils import ( + require_bitsandbytes, + require_torch, + require_torch_gpu, + require_vision, + slow, + torch_device, +) from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor @@ -470,10 +477,45 @@ def test_small_model_integration_test_llama_batched_regression(self): output = model.generate(**inputs, max_new_tokens=20) - EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this serene location, one should be cautious about the weather conditions and potential', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip + EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT) + @slow + @require_torch + @require_vision + def test_batched_generation(self): + model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf").to(torch_device) + + processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") + + prompt1 = "\n\nUSER: What's the the difference of two images?\nASSISTANT:" + prompt2 = "\nUSER: Describe the image.\nASSISTANT:" + prompt3 = "\nUSER: Describe the image.\nASSISTANT:" + url1 = "https://images.unsplash.com/photo-1552053831-71594a27632d?q=80&w=3062&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" + url2 = "https://images.unsplash.com/photo-1617258683320-61900b281ced?q=80&w=3087&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" + image1 = Image.open(requests.get(url1, stream=True).raw) + image2 = Image.open(requests.get(url2, stream=True).raw) + + inputs = processor( + text=[prompt1, prompt2, prompt3], + images=[image1, image2, image1, image2], + return_tensors="pt", + padding=True, + ).to(torch_device) + + model = model.eval() + + EXPECTED_OUTPUT = [ + "\n \nUSER: What's the the difference of two images?\nASSISTANT: In the two images, the primary difference is the presence of a small dog holding a flower in one", + "\nUSER: Describe the image.\nASSISTANT: The image features a small, fluffy dog sitting on a sidewalk. The dog is holding", + "\nUSER: Describe the image.\nASSISTANT: The image features a lone, adult llama standing on a grassy hill. The llama", + ] + + generate_ids = model.generate(**inputs, max_new_tokens=20) + outputs = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) + self.assertEqual(outputs, EXPECTED_OUTPUT) + @slow @require_bitsandbytes def test_llava_index_error_bug(self):