
Commit 886421d

pad_input and unpad_input use same implementation as fa2
1 parent c77cefd commit 886421d

File tree

1 file changed: +1 −1 lines changed


src/transformers/modeling_flash_attention_utils.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def _lazy_imports(implementation: Optional[str]):
     if implementation == "flash_attention_2" and is_torch_npu_available():
         from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
         from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
-        from .integrations.npu_flash_attention import pad_input, unpad_input
+        pad_input, unpad_input = _pad_input, _unpad_input

     elif implementation == "flash_attention_2" or (implementation is None and is_fa2 and not is_fa3):
         from flash_attn import flash_attn_func, flash_attn_varlen_func
         from flash_attn.bert_padding import pad_input, unpad_input
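
For context, a minimal sketch of what fa2-style pad/unpad helpers typically do; the names unpad_input_sketch and pad_input_sketch below are illustrative only and are not the actual _pad_input / _unpad_input helpers referenced by the diff. The idea is to pack only the non-padded tokens of a batch into one contiguous tensor (plus cumulative sequence lengths) for varlen attention, then scatter the results back into the padded layout afterwards.

    # Hedged sketch, assuming flash_attn.bert_padding-style semantics.
    import torch
    import torch.nn.functional as F

    def unpad_input_sketch(hidden_states, attention_mask):
        # hidden_states: (batch, seqlen, ...); attention_mask: (batch, seqlen) of 0/1.
        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
        indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_in_batch = int(seqlens_in_batch.max())
        # Cumulative sequence lengths with a leading 0, as expected by varlen kernels.
        cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
        # Gather only the non-padded tokens into one packed tensor.
        unpadded = hidden_states.reshape(-1, *hidden_states.shape[2:])[indices]
        return unpadded, indices, cu_seqlens, max_seqlen_in_batch

    def pad_input_sketch(unpadded, indices, batch, seqlen):
        # Scatter packed tokens back into a zero-initialized (batch, seqlen, ...) tensor.
        output = unpadded.new_zeros(batch * seqlen, *unpadded.shape[1:])
        output[indices] = unpadded
        return output.reshape(batch, seqlen, *unpadded.shape[1:])

Per the commit message, the NPU branch now binds pad_input and unpad_input to the same fa2-style helpers used elsewhere in the file instead of importing separate NPU-specific versions.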
