Commit 9977cf1

[Flash Attention] Fix flash attention integration (#40002)
* fix flash attention
* i got a stroke reading that comment
* change dropout kwarg back to before
* rename _fa3... as it's used for multiple variants and should work as fallback instead
* simplify imports and support kwargs for fa
* style
* fix comments order
* small fix
* skip kernels test (causes cuda illegal memories w/o cleanup), fix fa test in general esp for models like bart
* style
* allow fullgraph by preloading on init
* make globals "private"
* ci pls be happy
* change skip conditions based on backend flag (indicating missing mask interface)
* move globals support to a function to prepare kwargs
* style
* generalize supported kwargs
* small change to doc
* fix
* add comments
* style
* revert prep during generate
* style
* revert weird style changes
* add fa kwarg prep during generate with fixes back
* how did this even happen
* how
* add comment
1 parent b6ba595 commit 9977cf1

5 files changed: +526 −424 lines changed

src/transformers/generation/utils.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -678,9 +678,10 @@ def prepare_inputs_for_generation(
         if encoder_attention_mask is not None:
             model_inputs["attention_mask"] = encoder_attention_mask
 
+        # 7. Prepare kwargs for flash attention to avoid recomputations
         if "flash" in self.config._attn_implementation and self._supports_attention_backend:
-            cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k = prepare_fa_kwargs_from_position_ids(
-                position_ids, is_packed_sequence=False
+            (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fa_kwargs_from_position_ids(
+                model_inputs["position_ids"], is_packed_sequence=False
             )
             model_inputs.update(
                 cu_seq_lens_q=cu_seq_lens_q.to(self.device),
@@ -689,12 +690,12 @@ def prepare_inputs_for_generation(
                 max_length_k=max_length_k,
             )
 
-        # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        # 8. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
         for key, value in kwargs.items():
             if key not in model_inputs:
                 model_inputs[key] = value
 
-        # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
+        # 9. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
         model_inputs.pop("labels", None)
         return model_inputs
```

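For context, flash attention's variable-length kernels consume cumulative sequence lengths (`cu_seq_lens`) and a maximum sequence length instead of a 2D padding mask, which is why `generate` now precomputes them and carries them through `model_inputs`. The sketch below shows the general idea of deriving such values from `position_ids`; it is a simplified, hypothetical helper, not the actual `prepare_fa_kwargs_from_position_ids` implementation, which (as the diff shows) returns separate query-side and key-side values.

```python
# Illustrative sketch only (not the transformers helper): derive flash-attention
# varlen bookkeeping from position_ids in a flattened/packed layout, where each
# new sequence restarts its positions at 0.
import torch


def cu_seqlens_from_position_ids(position_ids: torch.Tensor):
    flat = position_ids.flatten()
    # A position equal to 0 marks the start of a new sequence in the flattened batch.
    starts = torch.nonzero(flat == 0, as_tuple=False).flatten()
    end = torch.tensor([flat.numel()], device=flat.device)
    cu_seqlens = torch.cat([starts, end]).to(torch.int32)  # [0, len_0, len_0 + len_1, ...]
    max_length = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
    return cu_seqlens, max_length


# Example: two sequences of lengths 3 and 2 packed together,
# position_ids = [[0, 1, 2, 0, 1]] -> cu_seqlens == tensor([0, 3, 5]), max_length == 3
```
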
src/transformers/integrations/npu_flash_attention.py

Lines changed: 6 additions & 126 deletions
```diff
@@ -10,20 +10,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 import os
 
 import torch
-import torch.nn.functional as F
 
 from ..utils.import_utils import is_torch_npu_available
 
 
 if is_torch_npu_available():
-    import math
-
-    import torch_npu
-    from einops import rearrange, repeat
-    from torch_npu import npu_rotary_mul
+    from torch_npu import npu_fusion_attention, npu_rotary_mul
 
 
 # FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.
@@ -52,117 +48,6 @@ def is_npu_fa2_top_left_aligned_causal_mask():
     return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE if is_torch_npu_available() else False
 
 
-# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
-class IndexFirstAxis(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input, indices):
-        ctx.save_for_backward(indices)
-        assert input.ndim >= 2
-        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
-        second_dim = other_shape.numel()
-        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
-        # return input[indices]
-        return torch.gather(
-            rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
-        ).reshape(-1, *other_shape)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        (indices,) = ctx.saved_tensors
-        assert grad_output.ndim >= 2
-        other_shape = grad_output.shape[1:]
-        grad_output = rearrange(grad_output, "b ... -> b (...)")
-        grad_input = torch.zeros(
-            [ctx.first_axis_dim, grad_output.shape[1]],
-            device=grad_output.device,
-            dtype=grad_output.dtype,
-        )
-        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
-        # grad_input[indices] = grad_output
-        grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
-        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
-
-
-index_first_axis = IndexFirstAxis.apply
-
-
-# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
-class IndexPutFirstAxis(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, values, indices, first_axis_dim):
-        ctx.save_for_backward(indices)
-        assert indices.ndim == 1
-        assert values.ndim >= 2
-        output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
-        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
-        output[indices] = values
-        # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        (indices,) = ctx.saved_tensors
-        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
-        grad_values = grad_output[indices]
-        # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
-        return grad_values, None, None
-
-
-index_put_first_axis = IndexPutFirstAxis.apply
-
-
-# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
-def pad_input(hidden_states, indices, batch, seqlen):
-    """
-    Arguments:
-        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
-        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
-        batch: int, batch size for the padded sequence.
-        seqlen: int, maximum sequence length for the padded sequence.
-    Return:
-        hidden_states: (batch, seqlen, ...)
-    """
-    # dim = hidden_states.shape[-1]
-    # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
-    # output[indices] = hidden_states
-    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
-    return rearrange(output, "(b s) ... -> b s ...", b=batch)
-
-
-# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
-def unpad_input(hidden_states, attention_mask, unused_mask=None):
-    """
-    Arguments:
-        hidden_states: (batch, seqlen, ...)
-        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
-        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
-    Return:
-        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
-        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
-        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
-        max_seqlen_in_batch: int
-        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
-    """
-    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
-    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
-    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
-    max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
-    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
-    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
-    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
-    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
-    # so we write custom forward and backward to make it a bit faster.
-    return (
-        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
-        indices,
-        cu_seqlens,
-        max_seqlen_in_batch,
-        used_seqlens_in_batch,
-    )
-
-
 def npu_flash_attn_func(
     q,
     k,
@@ -179,11 +64,11 @@ def npu_flash_attn_func(
 
     if not causal:
         head_num = q.shape[2]
-        output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
+        output = npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
     else:
         attn_mask_npu = get_attn_mask_npu(q.device)
         head_num = q.shape[2]
-        output = torch_npu.npu_fusion_attention(
+        output = npu_fusion_attention(
             q,
             k,
             v,
@@ -218,7 +103,7 @@ def npu_flash_attn_varlen_func(
 
     if not causal:
         head_num = q.shape[1]
-        output = torch_npu.npu_fusion_attention(
+        output = npu_fusion_attention(
            q,
            k,
            v,
@@ -234,7 +119,7 @@ def npu_flash_attn_varlen_func(
     else:
        attn_mask_npu = get_attn_mask_npu(q.device)
        head_num = q.shape[1]
-        output = torch_npu.npu_fusion_attention(
+        output = npu_fusion_attention(
            q,
            k,
            v,
@@ -267,8 +152,3 @@ def npu_apply_rotary_emb(x, cos, sin, **kwargs):
         sin = sin.unsqueeze(0).unsqueeze(2)
 
     return npu_rotary_mul(x, cos, sin)
-
-
-def get_npu_flash_attn_funcs():
-    # return flash attention related functions used for Ascend NPU in order
-    return npu_flash_attn_func, npu_flash_attn_varlen_func, pad_input, unpad_input, False
```
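
With the vendored copies of flash-attention's `bert_padding.py` helpers removed (along with this module's `einops` dependency), the NPU wrappers now call `npu_fusion_attention` directly. Below is a minimal usage sketch of the non-causal "BSND" path shown in the diff, assuming an Ascend NPU environment with `torch_npu` installed; the shapes, dtype, `keep_prob`, and `scale` values are illustrative assumptions, not part of this commit.

```python
# Minimal sketch of the non-causal "BSND" call shown in the diff; assumes torch_npu
# is installed and an Ascend NPU device is available.
import torch
from torch_npu import npu_fusion_attention

batch, seq_len, num_heads, head_dim = 2, 128, 8, 64
q = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.float16, device="npu")
k = torch.randn_like(q)
v = torch.randn_like(q)

# "BSND" declares the (batch, seq, num_heads, head_dim) layout; keep_prob is
# 1 - dropout probability and scale is the softmax scaling factor.
output = npu_fusion_attention(
    q, k, v, num_heads, "BSND", keep_prob=1.0, scale=head_dim**-0.5
)[0]
```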
