1 parent c77cefd commit 886421d
src/transformers/modeling_flash_attention_utils.py
@@ -83,7 +83,7 @@ def _lazy_imports(implementation: Optional[str]):
     if implementation == "flash_attention_2" and is_torch_npu_available():
         from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
         from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
-        from .integrations.npu_flash_attention import pad_input, unpad_input
+        pad_input, unpad_input = _pad_input, _unpad_input
     elif implementation == "flash_attention_2" or (implementation is None and is_fa2 and not is_fa3):
         from flash_attn import flash_attn_func, flash_attn_varlen_func
         from flash_attn.bert_padding import pad_input, unpad_input
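
For context, the `+` line binds `pad_input`/`unpad_input` on the NPU branch to module-level `_pad_input`/`_unpad_input` helpers instead of importing them from `.integrations.npu_flash_attention`. The sketch below is a minimal, self-contained illustration of that dispatch pattern, not the actual transformers code: the bodies of `_pad_input`/`_unpad_input` and the names `_lazy_imports_sketch`/`on_npu` are hypothetical stand-ins assumed for the example.

from typing import Optional

import torch


def _unpad_input(hidden_states: torch.Tensor, attention_mask: torch.Tensor):
    # Hypothetical torch-native fallback: keep only the non-padded token rows.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    batch, seqlen, dim = hidden_states.shape
    flat = hidden_states.reshape(batch * seqlen, dim)
    return flat[indices], indices


def _pad_input(unpadded: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int):
    # Hypothetical inverse: scatter the unpadded rows back into a zero-padded tensor.
    dim = unpadded.shape[-1]
    out = unpadded.new_zeros(batch * seqlen, dim)
    out[indices] = unpadded
    return out.reshape(batch, seqlen, dim)


def _lazy_imports_sketch(implementation: Optional[str], on_npu: bool):
    # Return (pad_input, unpad_input) for the requested attention backend.
    if implementation == "flash_attention_2" and on_npu:
        # After this commit: reuse the local helpers rather than importing
        # pad_input/unpad_input from the NPU integration module.
        pad_input, unpad_input = _pad_input, _unpad_input
    else:
        # CUDA path unchanged: use the flash-attn package's implementations.
        from flash_attn.bert_padding import pad_input, unpad_input
    return pad_input, unpad_input

With these stand-ins, `_lazy_imports_sketch("flash_attention_2", on_npu=True)` returns the local fallbacks without touching the NPU integration module, while the non-NPU path still resolves to `flash_attn.bert_padding` as in the unchanged lines of the hunk.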