Skip to content

Commit

Permalink
Fix vipllava for generation (#29874)
Browse files — browse the repository at this point in the history
* fix vipllava generation

* consistent llava code

* revert llava tests changes
(Loading branch information…)
zucchini-nlp authored and Ita Zaporozhets committed May 14, 2024
1 parent eb964d2 commit bb9c64f
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
7 changes: 4 additions & 3 deletions src/transformers/models/llava_next/modeling_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,10 +569,11 @@ def forward(
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

# Get the target length
target_seqlen = first_layer_past_key_value.shape[-1] + 1
target_length = input_ids.shape[1]
past_length = first_layer_past_key_value.shape[-1]

extended_attention_mask = torch.ones(
(attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
Expand All @@ -587,7 +588,7 @@ def forward(
# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1

outputs = self.language_model(
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/vipllava/modeling_vipllava.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,10 +441,10 @@ def forward(
if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, 0, :, :]
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

# Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-1) == 0)
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

target_length = input_ids.shape[1]
past_length = first_layer_past_key_value.shape[-1]
Expand Down
2 changes: 1 addition & 1 deletion tests/models/llava_next/test_modeling_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def test_small_model_integration_test(self):
output = model(**inputs)

expected_slice = torch.tensor(
[[-4.7695, -4.5664, -0.2786], [-10.6172, -10.8906, -2.5234], [-6.7344, -7.2422, -0.6758]],
[[-4.7695, -4.5664, -0.2786], [-10.6250, -10.8906, -2.5254], [-6.7383, -7.2461, -0.6787]],
dtype=torch.float32,
device=torch_device,
)
Expand Down

0 comments on commit bb9c64f

Please sign in to comment.