Correct llava mask & fix missing setter for vocab_size (#29389)
* correct llava mask

* fix vipllava as well

* mask out embedding for padding tokens

* add test

* fix style

* add setter

* fix test on suggestion
fxmarty authored Mar 22, 2024
1 parent aa17cf9 commit 13b2370
Showing 5 changed files with 74 additions and 10 deletions.
4 changes: 4 additions & 0 deletions src/transformers/models/llava/configuration_llava.py
@@ -147,6 +147,10 @@ def vocab_size(self):
        )
        return self._vocab_size

+    @vocab_size.setter
+    def vocab_size(self, value):
+        self._vocab_size = value
+
    def to_dict(self):
        output = super().to_dict()
        output.pop("_vocab_size", None)
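A minimal, self-contained sketch of why the added setter matters, using a hypothetical ToyConfig rather than the real LlavaConfig: with only a @property getter, assigning to config.vocab_size raises AttributeError, which is exactly what code that updates the vocabulary size on the config runs into; the setter makes the assignment work.

# Hypothetical ToyConfig, not the real LlavaConfig; illustrates the getter-only failure.
class ToyConfig:
    def __init__(self, vocab_size=32000):
        self._vocab_size = vocab_size

    @property
    def vocab_size(self):
        return self._vocab_size

    @vocab_size.setter
    def vocab_size(self, value):
        self._vocab_size = value


config = ToyConfig()
config.vocab_size = 32064  # works; with the getter alone this assignment raises AttributeError
print(config.vocab_size)   # 32064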
13 changes: 10 additions & 3 deletions src/transformers/models/llava/modeling_llava.py
@@ -344,6 +344,12 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in
        final_attention_mask |= image_to_overwrite
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

+        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
+        batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
+        indices_to_mask = new_token_positions[batch_indices, pad_indices]
+
+        final_embedding[batch_indices, indices_to_mask] = 0
+
        if labels is None:
            final_labels = None

@@ -449,10 +455,11 @@ def forward(
                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

                # Get the target length
-                target_seqlen = first_layer_past_key_value.shape[-1] + 1
+                target_length = input_ids.shape[1]
+                past_length = first_layer_past_key_value.shape[-1]

                extended_attention_mask = torch.ones(
-                    (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
+                    (attention_mask.shape[0], past_length),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
@@ -467,7 +474,7 @@ def forward(
                # Zero-out the places where we don't need to attend
                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

-                attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
+                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1

                outputs = self.language_model(
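To make the new step 6 concrete, here is a standalone sketch with toy tensors (made-up shapes and values, not the real merging code): padding positions found in input_ids are mapped through new_token_positions to their slots in the merged sequence, and those rows of final_embedding are zeroed, so that at decode time all-zero past_key_value entries reliably mark the non-attended tokens.

import torch

# Toy setup: batch of 2, 4 text tokens each, merged length 4, embedding dim 8.
pad_token_id = 0
input_ids = torch.tensor([[0, 0, 5, 6], [5, 6, 7, 8]])            # sample 0 is left-padded
new_token_positions = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]])  # text-token slot -> merged-sequence slot
final_embedding = torch.randn(2, 4, 8)                            # (batch, merged_seq_len, embed_dim)

batch_indices, pad_indices = torch.where(input_ids == pad_token_id)
indices_to_mask = new_token_positions[batch_indices, pad_indices]
final_embedding[batch_indices, indices_to_mask] = 0               # embeddings at padding slots are now all zero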
7 changes: 6 additions & 1 deletion src/transformers/models/llava_next/modeling_llava_next.py
@@ -356,7 +356,6 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m
    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
        num_images, num_image_patches, embed_dim = image_features.shape
        batch_size, sequence_length = input_ids.shape
-
        left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
        # 1. Create a mask to know where special image tokens are
        special_image_token_mask = input_ids == self.config.image_token_index
@@ -418,6 +417,12 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in
        final_attention_mask |= image_to_overwrite
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

+        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
+        batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
+        indices_to_mask = new_token_positions[batch_indices, pad_indices]
+
+        final_embedding[batch_indices, indices_to_mask] = 0
+
        if labels is None:
            final_labels = None

14 changes: 10 additions & 4 deletions src/transformers/models/vipllava/modeling_vipllava.py
@@ -347,6 +347,12 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in
        final_attention_mask |= image_to_overwrite
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

+        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
+        batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
+        indices_to_mask = new_token_positions[batch_indices, pad_indices]
+
+        final_embedding[batch_indices, indices_to_mask] = 0
+
        if labels is None:
            final_labels = None

@@ -442,11 +448,11 @@ def forward(
                # Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-1) == 0)

-                # Get the target length
-                target_seqlen = first_layer_past_key_value.shape[-2] + 1
+                target_length = input_ids.shape[1]
+                past_length = first_layer_past_key_value.shape[-1]

                extended_attention_mask = torch.ones(
-                    (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
+                    (attention_mask.shape[0], past_length),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
@@ -461,7 +467,7 @@ def forward(
                # Zero-out the places where we don't need to attend
                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

-                attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
+                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1

                outputs = self.language_model(
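A standalone sketch of the corrected decode-time mask construction (toy sizes, not the real forward pass): the extension now spans the past_length cached positions, cache slots whose keys/values are all zero are masked out, and only the last target_length columns of the incoming mask (the newly generated tokens) are appended, replacing the old concatenation that put the extension on the wrong side.

import torch

batch_size, past_length, target_length = 2, 6, 1
# Mask as the generation loop passes it in: one column per past token plus the new token.
attention_mask = torch.ones(batch_size, past_length + target_length, dtype=torch.long)

extended_attention_mask = torch.ones((batch_size, past_length), dtype=attention_mask.dtype)
# Pretend cache slots 0-1 of sample 0 held padding (their keys/values are all zero).
extended_attention_mask[0, :2] = 0

attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1  # tensor([[4], [6]])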
46 changes: 44 additions & 2 deletions tests/models/llava/test_modeling_llava.py
@@ -27,7 +27,14 @@
    is_torch_available,
    is_vision_available,
)
-from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device
+from transformers.testing_utils import (
+    require_bitsandbytes,
+    require_torch,
+    require_torch_gpu,
+    require_vision,
+    slow,
+    torch_device,
+)

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -470,10 +477,45 @@ def test_small_model_integration_test_llama_batched_regression(self):

        output = model.generate(**inputs, max_new_tokens=20)

-        EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this serene location, one should be cautious about the weather conditions and potential', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip
+        EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip

        self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)

+    @slow
+    @require_torch
+    @require_vision
+    def test_batched_generation(self):
+        model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf").to(torch_device)
+
+        processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
+
+        prompt1 = "<image>\n<image>\nUSER: What's the the difference of two images?\nASSISTANT:"
+        prompt2 = "<image>\nUSER: Describe the image.\nASSISTANT:"
+        prompt3 = "<image>\nUSER: Describe the image.\nASSISTANT:"
+        url1 = "https://images.unsplash.com/photo-1552053831-71594a27632d?q=80&w=3062&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
+        url2 = "https://images.unsplash.com/photo-1617258683320-61900b281ced?q=80&w=3087&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
+        image1 = Image.open(requests.get(url1, stream=True).raw)
+        image2 = Image.open(requests.get(url2, stream=True).raw)
+
+        inputs = processor(
+            text=[prompt1, prompt2, prompt3],
+            images=[image1, image2, image1, image2],
+            return_tensors="pt",
+            padding=True,
+        ).to(torch_device)
+
+        model = model.eval()
+
+        EXPECTED_OUTPUT = [
+            "\n \nUSER: What's the the difference of two images?\nASSISTANT: In the two images, the primary difference is the presence of a small dog holding a flower in one",
+            "\nUSER: Describe the image.\nASSISTANT: The image features a small, fluffy dog sitting on a sidewalk. The dog is holding",
+            "\nUSER: Describe the image.\nASSISTANT: The image features a lone, adult llama standing on a grassy hill. The llama",
+        ]
+
+        generate_ids = model.generate(**inputs, max_new_tokens=20)
+        outputs = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+        self.assertEqual(outputs, EXPECTED_OUTPUT)
+
    @slow
    @require_bitsandbytes
    def test_llava_index_error_bug(self):
