 from torch import nn
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -616,7 +616,15 @@ def _update_causal_mask(
         output_attentions: bool,
     ):
         if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and (attention_mask == 0.0).any():
+            if attention_mask is not None and past_key_values is not None:
+                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
+                if is_padding_right:
+                    raise ValueError(
+                        "You are attempting to perform batched generation with padding_side='right'"
+                        " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
+                        " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
+                    )
+            if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
             return None
 
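For context on the new guard: when a cache is present, the Flash Attention path cannot reliably handle right-padded batches during generation, so the patch rejects them up front. A minimal sketch (not part of the patch; toy tensors only) of what the `is_padding_right` comparison detects:

```python
import torch

# Toy 2D attention masks (1 = real token, 0 = padding). With right padding the
# last column contains zeros, so its sum no longer equals the batch size, which
# is exactly what the new is_padding_right check tests.
right_padded = torch.tensor([[1, 1, 1, 0],
                             [1, 1, 1, 1]])
left_padded = torch.tensor([[0, 1, 1, 1],
                            [1, 1, 1, 1]])
batch_size = right_padded.size(0)

print(right_padded[:, -1].sum().item() != batch_size)  # True  -> ValueError is raised
print(left_padded[:, -1].sum().item() != batch_size)   # False -> generation proceeds
```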
@@ -625,21 +633,30 @@ def _update_causal_mask(
         # to infer the attention mask.
         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)
+        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
 
         # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+        if (
+            self.config._attn_implementation == "sdpa"
+            and not (using_static_cache or using_sliding_window_cache)
+            and not output_attentions
+        ):
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
                 past_key_values_length=past_seen_tokens,
+                sliding_window=self.config.sliding_window,
                 is_training=self.training,
             ):
                 return None
 
         dtype, device = input_tensor.dtype, input_tensor.device
+        min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
-        if using_static_cache:
+        # SlidingWindowCache or StaticCache
+        if using_sliding_window_cache or using_static_cache:
             target_length = past_key_values.get_max_cache_shape()
+        # DynamicCache or no cache
         else:
             target_length = (
                 attention_mask.shape[-1]
@@ -656,6 +673,8 @@ def _update_causal_mask(
             device=device,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
+            config=self.config,
+            past_key_values=past_key_values,
         )
 
         if (
@@ -667,7 +686,6 @@ def _update_causal_mask(
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
             # Details: https://github.com/pytorch/pytorch/issues/110213
-            min_dtype = torch.finfo(dtype).min
             causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
 
         return causal_mask
@@ -681,21 +699,20 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
-        **kwargs,
+        config: Qwen2Config,
+        past_key_values: Cache,
     ):
         """
         Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
         `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
 
         Args:
             attention_mask (`torch.Tensor`):
-                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
-                `(batch_size, 1, query_length, key_value_length)`.
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
             sequence_length (`int`):
                 The sequence length being processed.
             target_length (`int`):
-                The target length: when generating with static cache, the mask should be as long as the static cache,
-                to account for the 0 padding, the part of the cache that is not filled yet.
+                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
@@ -704,6 +721,10 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
                 Batch size.
+            config (`Qwen2Config`):
+                The model's configuration class
+            past_key_values (`Cache`):
+                The cache class that is being used currently to generate
         """
         if attention_mask is not None and attention_mask.dim() == 4:
             # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
@@ -713,12 +734,21 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
             causal_mask = torch.full(
                 (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
             )
-            if sequence_length != 1:
-                causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            if config.sliding_window is not None:
+                # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
+                # the check is needed to verify if the current checkpoint was trained with sliding window or not
+                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
+                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
+                        cache_position.reshape(-1, 1) - config.sliding_window
+                    )
+                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
+            causal_mask *= diagonal_attend_mask
             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
             if attention_mask is not None:
                 causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                if attention_mask.shape[-1] > target_length:
+                    attention_mask = attention_mask[:, :target_length]
                 mask_length = attention_mask.shape[-1]
                 padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                     causal_mask.device
@@ -727,7 +757,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                     padding_mask, min_dtype
                 )
-
         return causal_mask
 
 
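The core of the change is the sliding-window term added to the mask construction. Below is a small standalone sketch (not part of the patch; the helper name and toy sizes are illustrative) that reproduces the diagonal-plus-sliding-window arithmetic from `_prepare_4d_causal_attention_mask_with_cache_position` on a boolean grid:

```python
import torch

def toy_sliding_window_mask(target_length, cache_position, sliding_window):
    """Reproduce the patch's masking arithmetic; True marks positions that are
    masked out (future tokens, or tokens that fell out of the sliding window)."""
    # Future keys: key index strictly greater than the query's cache position.
    blocked = torch.arange(target_length) > cache_position.reshape(-1, 1)
    if sliding_window is not None:
        # Keys that trail the query by `sliding_window` or more are also blocked,
        # mirroring the `<= cache_position - sliding_window` term added in the diff.
        too_old = torch.arange(target_length) <= (cache_position.reshape(-1, 1) - sliding_window)
        blocked.bitwise_or_(too_old)
    return blocked

# Prefill of 6 tokens with sliding_window=3: every query row attends to at most
# the 3 most recent keys (itself included).
print(toy_sliding_window_mask(6, torch.arange(6), sliding_window=3).int())
# Row i has zeros (= attended) only at key positions i-2, i-1, i.
```

In the real code this boolean grid multiplies a tensor pre-filled with `min_dtype`, so blocked positions keep a large negative bias while attended positions become 0.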