
Commit 96f01a3

Revert qwen2 breaking changes related to attention refactor (#36162)
* ditto
* add a test
* update
* test needs fa2
* update test and configuration
* test requires fa2
* style
1 parent cb586a3 commit 96f01a3

File tree

4 files changed: +468 -18 lines changed

src/transformers/models/qwen2/configuration_qwen2.py

Lines changed: 1 addition & 1 deletion
@@ -174,7 +174,7 @@ def __init__(
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.use_sliding_window = use_sliding_window
-        self.sliding_window = sliding_window if use_sliding_window else None
+        self.sliding_window = sliding_window  # we check `use_sliding_window` in the modeling code
         self.max_window_layers = max_window_layers

         # for backward compatibility
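A minimal sketch of the behaviour restored here (the value 4096 is only an illustration; assumes this revision of transformers is installed): the config now always keeps the `sliding_window` value, and the `use_sliding_window` flag is consulted in the modeling code instead.

```python
from transformers import Qwen2Config

# Before this commit, sliding_window was forced to None whenever use_sliding_window=False.
# After the revert, the value is preserved and the modeling code checks use_sliding_window.
config = Qwen2Config(use_sliding_window=False, sliding_window=4096)
print(config.sliding_window)  # 4096 (previously: None)
```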

src/transformers/models/qwen2/modeling_qwen2.py

Lines changed: 43 additions & 14 deletions
@@ -10,7 +10,7 @@
 from torch import nn

 from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -616,7 +616,15 @@ def _update_causal_mask(
         output_attentions: bool,
     ):
         if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and (attention_mask == 0.0).any():
+            if attention_mask is not None and past_key_values is not None:
+                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
+                if is_padding_right:
+                    raise ValueError(
+                        "You are attempting to perform batched generation with padding_side='right'"
+                        " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
+                        " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
+                    )
+            if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
             return None
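The reinstated check above rejects right-padded batches under `flash_attention_2`. A hedged usage sketch, with the checkpoint name chosen only as an example:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer.padding_side = "left"  # right padding would trigger the ValueError above during batched generation
batch = tokenizer(["Hello", "A much longer prompt"], padding=True, return_tensors="pt")
```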

@@ -625,21 +633,30 @@ def _update_causal_mask(
         # to infer the attention mask.
         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)
+        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)

         # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+        if (
+            self.config._attn_implementation == "sdpa"
+            and not (using_static_cache or using_sliding_window_cache)
+            and not output_attentions
+        ):
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
                 past_key_values_length=past_seen_tokens,
+                sliding_window=self.config.sliding_window,
                 is_training=self.training,
             ):
                 return None

         dtype, device = input_tensor.dtype, input_tensor.device
+        min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
-        if using_static_cache:
+        # SlidingWindowCache or StaticCache
+        if using_sliding_window_cache or using_static_cache:
             target_length = past_key_values.get_max_cache_shape()
+        # DynamicCache or no cache
         else:
             target_length = (
                 attention_mask.shape[-1]
@@ -656,6 +673,8 @@ def _update_causal_mask(
             device=device,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
+            config=self.config,
+            past_key_values=past_key_values,
         )

         if (
@@ -667,7 +686,6 @@ def _update_causal_mask(
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
             # Details: https://github.com/pytorch/pytorch/issues/110213
-            min_dtype = torch.finfo(dtype).min
             causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

         return causal_mask
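For context, a `SlidingWindowCache` is what `generate()` allocates when `cache_implementation="sliding_window"` is requested, which exercises the restored branch above. A rough sketch (model name and generation settings are illustrative; the sliding-window cache only applies when the config defines a sliding window):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
inputs = tokenizer("Sliding-window attention lets the model", return_tensors="pt")

# With a SlidingWindowCache, _update_causal_mask now takes the branch added above
# (target_length = past_key_values.get_max_cache_shape()).
out = model.generate(**inputs, max_new_tokens=16, cache_implementation="sliding_window")
print(tokenizer.decode(out[0], skip_special_tokens=True))
```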
@@ -681,21 +699,20 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
-        **kwargs,
+        config: Qwen2Config,
+        past_key_values: Cache,
     ):
         """
         Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
         `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

         Args:
             attention_mask (`torch.Tensor`):
-                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
-                `(batch_size, 1, query_length, key_value_length)`.
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
             sequence_length (`int`):
                 The sequence length being processed.
             target_length (`int`):
-                The target length: when generating with static cache, the mask should be as long as the static cache,
-                to account for the 0 padding, the part of the cache that is not filled yet.
+                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
@@ -704,6 +721,10 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
                 Batch size.
+            config (`Qwen2Config`):
+                The model's configuration class
+            past_key_values (`Cache`):
+                The cache class that is being used currently to generate
         """
         if attention_mask is not None and attention_mask.dim() == 4:
             # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
@@ -713,12 +734,21 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
             causal_mask = torch.full(
                 (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
             )
-            if sequence_length != 1:
-                causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            if config.sliding_window is not None:
+                # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
+                # the check is needed to verify is current checkpoint was trained with sliding window or not
+                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
+                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
+                        cache_position.reshape(-1, 1) - config.sliding_window
+                    )
+                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
+            causal_mask *= diagonal_attend_mask
             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
             if attention_mask is not None:
                 causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                if attention_mask.shape[-1] > target_length:
+                    attention_mask = attention_mask[:, :target_length]
                 mask_length = attention_mask.shape[-1]
                 padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                     causal_mask.device
@@ -727,7 +757,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                     padding_mask, min_dtype
                 )
-
         return causal_mask
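A standalone sketch of the restored mask construction, using plain torch with illustrative sizes (6 query positions, target length 6, sliding window 3); it mirrors the `diagonal_attend_mask` / `sliding_attend_mask` logic above without the model plumbing:

```python
import torch

sequence_length, target_length, sliding_window = 6, 6, 3
cache_position = torch.arange(sequence_length)
min_dtype = torch.finfo(torch.float32).min

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype)
# True where a query position must NOT attend: future tokens...
diagonal_attend_mask = torch.arange(target_length) > cache_position.reshape(-1, 1)
# ...plus tokens that have fallen out of the sliding window.
sliding_attend_mask = torch.arange(target_length) <= (cache_position.reshape(-1, 1) - sliding_window)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
print((causal_mask == 0).int())  # 1 = attendable: a lower-triangular band of width 3
```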
src/transformers/models/qwen2/modular_qwen2.py

Lines changed: 2 additions & 2 deletions
@@ -17,10 +17,10 @@
     LlamaForSequenceClassification,
     LlamaForTokenClassification,
     LlamaMLP,
-    LlamaModel,
     apply_rotary_pos_emb,
     eager_attention_forward,
 )
+from ..mistral.modeling_mistral import MistralModel
 from .configuration_qwen2 import Qwen2Config


@@ -114,7 +114,7 @@ def __init__(self, config: Qwen2Config, layer_idx: int):
         )


-class Qwen2Model(LlamaModel):
+class Qwen2Model(MistralModel):
     pass
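As a hedged sanity check (the source-level comparison below is an assumption about how the modular converter resolves this inheritance, not something asserted by the commit), the generated `Qwen2Model._update_causal_mask` should now mirror Mistral's sliding-window-aware implementation rather than Llama's:

```python
import inspect

from transformers.models.mistral.modeling_mistral import MistralModel
from transformers.models.qwen2.modeling_qwen2 import Qwen2Model

qwen2_src = inspect.getsource(Qwen2Model._update_causal_mask)
mistral_src = inspect.getsource(MistralModel._update_causal_mask)
# Expected to match up to the model name used in the flash-attention error message.
print(qwen2_src.replace("Qwen2", "Mistral") == mistral_src)
```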