File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed
src/transformers/models/llama4 Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -730,6 +730,7 @@ def forward(
730730 )
731731 return output if return_dict else output .to_tuple ()
732732
733+ @torch .compiler .disable # the operations in this method are not compilable
733734 def _update_causal_mask (
734735 self ,
735736 attention_mask : torch .Tensor ,
@@ -767,7 +768,7 @@ def _update_causal_mask(
767768 )
768769
769770 if past_key_values is not None and past_key_values .is_compileable :
770- target_length = past_key_values .get_max_cache_shape
771+ target_length = past_key_values .get_max_cache_shape ()
771772 else :
772773 target_length = attention_mask .shape [- 1 ] if attention_mask is not None else sequence_length
773774
@@ -780,7 +781,7 @@ def _update_causal_mask(
780781 attention_mask = make_flex_block_causal_mask (
781782 attention_mask ,
782783 query_length = sequence_length ,
783- key_length = past_key_values . get_max_cache_shape () ,
784+ key_length = target_length ,
784785 offsets = None if sequence_length != 1 else (first_cache_position , 0 ),
785786 )
786787 return attention_mask , chunked_attention_mask
You can’t perform that action at this time.
0 commit comments