Commit d8f0695

winglian authored and ArthurZucker committed
more fixes for post-training llama4 (#37329)
* more fixes for post-training llama4
* use target_length instead of guarded past_key_values
1 parent d27c8c3 commit d8f0695

1 file changed: +3 additions, −2 deletions


src/transformers/models/llama4/modeling_llama4.py

Lines changed: 3 additions & 2 deletions
@@ -730,6 +730,7 @@ def forward(
         )
         return output if return_dict else output.to_tuple()
 
+    @torch.compiler.disable # the operations in this method are not compilable
     def _update_causal_mask(
         self,
         attention_mask: torch.Tensor,
@@ -767,7 +768,7 @@ def _update_causal_mask(
         )
 
         if past_key_values is not None and past_key_values.is_compileable:
-            target_length = past_key_values.get_max_cache_shape
+            target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = attention_mask.shape[-1] if attention_mask is not None else sequence_length
 
@@ -780,7 +781,7 @@ def _update_causal_mask(
             attention_mask = make_flex_block_causal_mask(
                 attention_mask,
                 query_length=sequence_length,
-                key_length=past_key_values.get_max_cache_shape(),
+                key_length=target_length,
                 offsets=None if sequence_length != 1 else (first_cache_position, 0),
             )
         return attention_mask, chunked_attention_mask
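For readers landing here from a post-training stack trace, the following is a minimal, self-contained sketch of why the two one-line changes matter. This is not the transformers API: ToyCache and update_causal_mask below are hypothetical stand-ins, and the mask construction is reduced to a plain torch.tril placeholder.

# Minimal sketch, assuming nothing about the real transformers cache classes.
import torch


class ToyCache:
    is_compileable = True

    def get_max_cache_shape(self) -> int:
        return 8  # pretend the static cache was allocated for 8 key/value positions


def update_causal_mask(attention_mask: torch.Tensor, past_key_values=None) -> torch.Tensor:
    sequence_length = attention_mask.shape[-1]

    if past_key_values is not None and past_key_values.is_compileable:
        # Fix 1: without the trailing "()" this line would bind the method object
        # itself rather than an int, and any later shape arithmetic would fail.
        target_length = past_key_values.get_max_cache_shape()
    else:
        target_length = sequence_length

    # Fix 2: deriving the key length from past_key_values.get_max_cache_shape()
    # directly would fail whenever past_key_values is None (a common post-training
    # forward pass); reusing the guarded target_length covers both branches.
    # torch.tril over a (query, key) grid stands in for the real mask construction.
    return torch.tril(torch.ones(sequence_length, target_length, dtype=torch.bool))


print(update_causal_mask(torch.ones(1, 4)).shape)                              # torch.Size([4, 4])
print(update_causal_mask(torch.ones(1, 4), past_key_values=ToyCache()).shape)  # torch.Size([4, 8])

The remaining change, the @torch.compiler.disable decorator on _update_causal_mask, keeps this data-dependent mask logic out of the torch.compile graph, matching the inline comment that the operations in the method are not compilable.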
