Fix _merge_input_ids_with_image_features for llava model #28333

Merged · 9 commits · Jan 10, 2024
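Summary: this PR threads the `labels` tensor through `_merge_input_ids_with_image_features` for both Llava and VipLlava. The merged label tensor is pre-filled with `config.ignore_index`, so image-patch positions (and any slot that is never overwritten) are excluded from the loss, while the text labels are scattered to their post-merge positions. Two slow GPU tests reproduce the original failure and check that the loss computes and backpropagates.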
20 changes: 14 additions & 6 deletions src/transformers/models/llava/modeling_llava.py
@@ -276,9 +276,7 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds

-    def _merge_input_ids_with_image_features(
-        self, image_features, inputs_embeds, input_ids, attention_mask, position_ids
-    ):
+    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
        num_images, num_image_patches, embed_dim = image_features.shape
        batch_size, sequence_length = input_ids.shape
        left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
@@ -307,6 +305,10 @@ def _merge_input_ids_with_image_features(
        final_attention_mask = torch.zeros(
            batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
        )
+        if labels is not None:
+            final_labels = torch.full(
+                (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
+            )
        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
        # set the corresponding tensors into their correct target device.
        target_device = inputs_embeds.device
@@ -321,6 +323,8 @@ def _merge_input_ids_with_image_features(
        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
+        if labels is not None:
+            final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

        # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
        image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
@@ -335,7 +339,11 @@ def _merge_input_ids_with_image_features(
        final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
        final_attention_mask |= image_to_overwrite
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
-        return final_embedding, final_attention_mask, position_ids
+
+        if labels is None:
+            final_labels = None
+
+        return final_embedding, final_attention_mask, final_labels, position_ids

    @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -420,8 +428,8 @@ def forward(
                    )

                image_features = self.multi_modal_projector(selected_image_feature)
-                inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features(
-                    image_features, inputs_embeds, input_ids, attention_mask, position_ids
+                inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+                    image_features, inputs_embeds, input_ids, attention_mask, labels
                )
                if labels is None:
                    labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
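To make the new label handling concrete, here is a minimal standalone sketch of the merge logic on toy tensors. All constants and sizes below are illustrative stand-ins (not the real Llava defaults), and the position arithmetic follows the same `cumsum` scheme the method uses to build `text_to_overwrite`:

```python
import torch

IMAGE_TOKEN = 99      # hypothetical placeholder id, stands in for config.image_token_index
IGNORE_INDEX = -100   # stands in for config.ignore_index
NUM_PATCHES = 4       # each image placeholder expands to this many positions

input_ids = torch.tensor([[1, 5, IMAGE_TOKEN, 7, 8, 2]])
labels = input_ids.clone()

batch_size, seq_len = input_ids.shape
special = input_ids == IMAGE_TOKEN
max_len = seq_len + int(special.sum()) * (NUM_PATCHES - 1)

# Every merged position defaults to IGNORE_INDEX, so image patches never
# contribute to the loss.
final_labels = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=labels.dtype)

# The cumulative count of placeholder expansions up to each position gives
# its new index (a placeholder maps to the last of its patch slots).
new_positions = torch.cumsum(special * (NUM_PATCHES - 1) + 1, dim=-1) - 1
batch_idx, non_image = torch.where(~special)
final_labels[batch_idx, new_positions[batch_idx, non_image]] = labels[batch_idx, non_image]

print(final_labels)
# tensor([[   1,    5, -100, -100, -100, -100,    7,    8,    2]])
```

Every slot of `final_labels` starts as `IGNORE_INDEX`, so the four image-patch positions are masked out of the loss automatically while the text labels land at their shifted positions.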
20 changes: 14 additions & 6 deletions src/transformers/models/vipllava/modeling_vipllava.py
@@ -284,9 +284,7 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds

-    def _merge_input_ids_with_image_features(
-        self, image_features, inputs_embeds, input_ids, attention_mask, position_ids
-    ):
+    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
        num_images, num_image_patches, embed_dim = image_features.shape
        batch_size, sequence_length = input_ids.shape
        left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
@@ -315,6 +313,10 @@ def _merge_input_ids_with_image_features(
        final_attention_mask = torch.zeros(
            batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
        )
+        if labels is not None:
+            final_labels = torch.full(
+                (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
+            )
        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
        # set the corresponding tensors into their correct target device.
        target_device = inputs_embeds.device
@@ -329,6 +331,8 @@ def _merge_input_ids_with_image_features(
        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
+        if labels is not None:
+            final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

        # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
        image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
@@ -343,7 +347,11 @@ def _merge_input_ids_with_image_features(
        final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
        final_attention_mask |= image_to_overwrite
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
-        return final_embedding, final_attention_mask, position_ids
+
+        if labels is None:
+            final_labels = None
+
+        return final_embedding, final_attention_mask, final_labels, position_ids

    @add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -419,8 +427,8 @@ def forward(
                image_features = torch.cat(image_features, dim=-1)

                image_features = self.multi_modal_projector(image_features)
-                inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features(
-                    image_features, inputs_embeds, input_ids, attention_mask, position_ids
+                inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+                    image_features, inputs_embeds, input_ids, attention_mask, labels
                )
                if labels is None:
                    labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
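Note that both models drop the incoming `position_ids` and recompute them from the merged attention mask. A small sketch of that `cumsum` expression from the diff, applied to a toy left-padded mask (illustrative values only):

```python
import torch

# Row 0 is left-padded: the first two slots are masked out.
final_attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                     [1, 1, 1, 1, 1]])

# Running count of unmasked tokens minus one yields 0-based positions for
# real tokens; masked slots are then overwritten with a dummy value of 1.
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(final_attention_mask == 0, 1)

print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```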
40 changes: 39 additions & 1 deletion tests/models/llava/test_modeling_llava.py
@@ -26,7 +26,7 @@
    is_torch_available,
    is_vision_available,
)
-from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
+from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -332,3 +332,41 @@ def test_llava_index_error_bug(self):

        # Make sure that `generate` works
        _ = model.generate(**inputs, max_new_tokens=20)

+    @slow
+    @require_torch_gpu
+    def test_llava_merge_inputs_error_bug(self):
+        # This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore
+        model_id = "llava-hf/llava-1.5-7b-hf"
+        model = LlavaForConditionalGeneration.from_pretrained(
+            model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
+        ).to(torch_device)
+
+        # Simulate some user inputs
+        pixel_values = torch.randn(
+            (2, 3, 336, 336),
+            dtype=torch.float,
+            device=torch_device,
+        )
+        input_ids = torch.tensor(
+            [
+                [32001, 32001, 1, 15043, 7084, 32000, 29871, 13, 7900],
+                [1, 15043, 7084, 29901, 29871, 32000, 29871, 13, 7900],
+            ],
+            dtype=torch.long,
+            device=torch_device,
+        )
+        attention_mask = torch.tensor(
+            [[0, 0, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]],
+            dtype=torch.long,
+            device=torch_device,
+        )
+
+        # Make sure that the loss is properly computed
+        loss = model(
+            pixel_values=pixel_values,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            labels=input_ids,
+        ).loss
+        loss.backward()
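The reproducer is deliberately a mixed batch: the first row of `input_ids` is left-padded with two 32001 pad tokens (attention mask 0) while the second is unpadded, and `labels=input_ids` keeps the pre-merge sequence length. Before this change the merge never expanded the labels to the post-merge length, so computing the loss on such a batch failed; `loss.backward()` additionally checks that the whole training path runs.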
40 changes: 39 additions & 1 deletion tests/models/vipllava/test_modeling_vipllava.py
@@ -26,7 +26,7 @@
    is_torch_available,
    is_vision_available,
)
-from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
+from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -214,3 +214,41 @@ def test_small_model_integration_test(self):

        EXPECTED_OUTPUT = "USER: <image> \nCan you please describe this image?\nASSISTANT: The image features a brown and white cat sitting on"
        self.assertEqual(processor.decode(outputs[0], skip_special_tokens=True), EXPECTED_OUTPUT)

+    @slow
+    @require_torch_gpu
+    def test_vipllava_merge_inputs_error_bug(self):
+        # This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore
+        model_id = "llava-hf/vip-llava-7b-hf"
+        model = VipLlavaForConditionalGeneration.from_pretrained(
+            model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
+        ).to(torch_device)
+
+        # Simulate some user inputs
+        pixel_values = torch.randn(
+            (2, 3, 336, 336),
+            dtype=torch.float,
+            device=torch_device,
+        )
+        input_ids = torch.tensor(
+            [
+                [32001, 32001, 1, 15043, 7084, 32000, 29871, 13, 7900],
+                [1, 15043, 7084, 29901, 29871, 32000, 29871, 13, 7900],
+            ],
+            dtype=torch.long,
+            device=torch_device,
+        )
+        attention_mask = torch.tensor(
+            [[0, 0, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]],
+            dtype=torch.long,
+            device=torch_device,
+        )
+
+        # Make sure that the loss is properly computed
+        loss = model(
+            pixel_values=pixel_values,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            labels=input_ids,
+        ).loss
+        loss.backward()