
Commit eee0f00

[bugfix] fix flash-attention2 unavailable error for Ascend NPU
1 parent e446372 commit eee0f00

File tree

2 files changed (+129, -10 lines)


src/transformers/integrations/npu_flash_attention.py

Lines changed: 121 additions & 6 deletions
@@ -10,16 +10,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import math
import os

import torch
+import torch.nn.functional as F

from ..utils.import_utils import is_torch_npu_available


if is_torch_npu_available():
-    from torch_npu import npu_fusion_attention, npu_rotary_mul
+    import math
+
+    import torch_npu
+    from einops import rearrange, repeat
+    from torch_npu import npu_rotary_mul


# FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.

@@ -48,6 +52,117 @@ def is_npu_fa2_top_left_aligned_causal_mask():
    return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE if is_torch_npu_available() else False


+# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
+class IndexFirstAxis(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, indices):
+        ctx.save_for_backward(indices)
+        assert input.ndim >= 2
+        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
+        second_dim = other_shape.numel()
+        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
+        # return input[indices]
+        return torch.gather(
+            rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
+        ).reshape(-1, *other_shape)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (indices,) = ctx.saved_tensors
+        assert grad_output.ndim >= 2
+        other_shape = grad_output.shape[1:]
+        grad_output = rearrange(grad_output, "b ... -> b (...)")
+        grad_input = torch.zeros(
+            [ctx.first_axis_dim, grad_output.shape[1]],
+            device=grad_output.device,
+            dtype=grad_output.dtype,
+        )
+        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
+        # grad_input[indices] = grad_output
+        grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
+        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
+
+
+index_first_axis = IndexFirstAxis.apply
+
+
+# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
+class IndexPutFirstAxis(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, values, indices, first_axis_dim):
+        ctx.save_for_backward(indices)
+        assert indices.ndim == 1
+        assert values.ndim >= 2
+        output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
+        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
+        output[indices] = values
+        # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (indices,) = ctx.saved_tensors
+        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
+        grad_values = grad_output[indices]
+        # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
+        return grad_values, None, None
+
+
+index_put_first_axis = IndexPutFirstAxis.apply
+
+
+# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
+def pad_input(hidden_states, indices, batch, seqlen):
+    """
+    Arguments:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
+        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
+        batch: int, batch size for the padded sequence.
+        seqlen: int, maximum sequence length for the padded sequence.
+    Return:
+        hidden_states: (batch, seqlen, ...)
+    """
+    # dim = hidden_states.shape[-1]
+    # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
+    # output[indices] = hidden_states
+    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
+    return rearrange(output, "(b s) ... -> b s ...", b=batch)
+
+
+# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
+def unpad_input(hidden_states, attention_mask, unused_mask=None):
+    """
+    Arguments:
+        hidden_states: (batch, seqlen, ...)
+        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
+    Return:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
+        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
+        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
+        max_seqlen_in_batch: int
+        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
+    """
+    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
+    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
+    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
+    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
+    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
+    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
+    # so we write custom forward and backward to make it a bit faster.
+    return (
+        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+        used_seqlens_in_batch,
+    )
+
+
def npu_flash_attn_func(
    q,
    k,

@@ -64,11 +179,11 @@ def npu_flash_attn_func(

    if not causal:
        head_num = q.shape[2]
-        output = npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
+        output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
    else:
        attn_mask_npu = get_attn_mask_npu(q.device)
        head_num = q.shape[2]
-        output = npu_fusion_attention(
+        output = torch_npu.npu_fusion_attention(
            q,
            k,
            v,

@@ -103,7 +218,7 @@ def npu_flash_attn_varlen_func(

    if not causal:
        head_num = q.shape[1]
-        output = npu_fusion_attention(
+        output = torch_npu.npu_fusion_attention(
            q,
            k,
            v,

@@ -119,7 +234,7 @@
    else:
        attn_mask_npu = get_attn_mask_npu(q.device)
        head_num = q.shape[1]
-        output = npu_fusion_attention(
+        output = torch_npu.npu_fusion_attention(
            q,
            k,
            v,
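
The pad_input / unpad_input helpers added above are vendored from flash-attn's bert_padding.py so that the token packing/unpacking step no longer requires the flash-attn package on NPU hosts. Below is a minimal round-trip sketch of how they behave; it is an illustration, not part of the commit, and it assumes an Ascend NPU environment where torch_npu and einops are importable, since the module keeps those imports behind is_torch_npu_available().

# Hedged sketch: pack a padded batch into varlen layout, then restore it.
import torch

from transformers.integrations.npu_flash_attention import pad_input, unpad_input

batch, seqlen, hidden = 2, 4, 8
hidden_states = torch.randn(batch, seqlen, hidden)
# Second sequence has only 2 valid tokens; positions 2 and 3 are padding.
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]], dtype=torch.int32)

# unpad_input packs the valid tokens into (total_nnz, hidden) and returns the
# bookkeeping tensors that the varlen attention path consumes.
packed, indices, cu_seqlens, max_seqlen, seqused = unpad_input(hidden_states, attention_mask)
print(packed.shape)   # torch.Size([6, 8])
print(cu_seqlens)     # tensor([0, 4, 6], dtype=torch.int32)

# pad_input scatters the packed tokens back into the padded (batch, seqlen, hidden)
# layout, zero-filling the padded positions.
restored = pad_input(packed, indices, batch, seqlen)
assert torch.equal(restored[attention_mask.bool()], hidden_states[attention_mask.bool()])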

src/transformers/modeling_flash_attention_utils.py

Lines changed: 8 additions & 4 deletions
@@ -77,16 +77,20 @@ def _lazy_imports(implementation: Optional[str]):
    """
    is_fa2 = is_flash_attn_2_available()
    is_fa3 = is_flash_attn_3_available()
-    if implementation == "flash_attention_2" or (implementation is None and is_fa2 and not is_fa3):
+
+    # Package `flash-attn` is unavailable on Ascend NPU, which will cause ImportError
+    # Flash-Attention2 related apis for Ascend NPU must be imported from `.integrations.npu_flash_attention` module
+    if implementation == "flash_attention_2" and is_torch_npu_available():
+        from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
+        from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
+        from .integrations.npu_flash_attention import pad_input, unpad_input
+    elif implementation == "flash_attention_2" or (implementation is None and is_fa2 and not is_fa3):
        from flash_attn import flash_attn_func, flash_attn_varlen_func
        from flash_attn.bert_padding import pad_input, unpad_input
    else:
        pad_input, unpad_input = _pad_input, _unpad_input
        if implementation == "flash_attention_3" or (implementation is None and is_fa3):
            from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
-        elif is_torch_npu_available():
-            from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
-            from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
        # Kernels fallback
        else:
            flash_attn_func = getattr(implementation, "flash_attn_func", None)
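
With this reordering, explicitly requesting attn_implementation="flash_attention_2" on an Ascend NPU resolves flash_attn_func, flash_attn_varlen_func, pad_input, and unpad_input from transformers.integrations.npu_flash_attention instead of trying to import the flash-attn wheel, which cannot be installed on NPU hosts and previously raised ImportError. The sketch below is a hedged usage illustration, not part of the commit: the checkpoint name is only an example, and it assumes an Ascend NPU host with torch and torch_npu installed.

# Hypothetical end-to-end check on an Ascend NPU host (torch_npu installed).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B",                      # illustrative checkpoint, not from this commit
    attn_implementation="flash_attention_2",  # now dispatches to the NPU fused-attention path
    torch_dtype=torch.float16,
).to("npu")

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
inputs = tokenizer("Hello from an Ascend NPU", return_tensors="pt").to("npu")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(out[0], skip_special_tokens=True))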
