Merged
Changes from all commits
39 commits
0ccc67d fix flash attention (vasqu, Aug 7, 2025)
0f97be0 i got a stroke reading that comment (vasqu, Aug 7, 2025)
0a45416 change dropout kwarg back to before (vasqu, Aug 7, 2025)
92e2075 rename _fa3... as it's used for multiple variants and should work as … (vasqu, Aug 7, 2025)
22574dc Merge branch 'main' into fix-fa-integration (vasqu, Aug 7, 2025)
3581bd6 simplify imports and support kwargs for fa (vasqu, Aug 7, 2025)
2f607ba style (vasqu, Aug 7, 2025)
49ce7ae fix comments order (vasqu, Aug 7, 2025)
5f7d937 small fix (vasqu, Aug 8, 2025)
d21095c skip kernels test (causes cuda illegal memories w/o cleanup), fix fa … (vasqu, Aug 8, 2025)
36bfffb style (vasqu, Aug 8, 2025)
9ec8c45 Merge branch 'main' into fix-fa-integration (vasqu, Aug 8, 2025)
1612b56 allow fullgraph by preloading on init (vasqu, Aug 8, 2025)
ba8fd00 make globals "private" (vasqu, Aug 8, 2025)
5240985 Merge branch 'main' into fix-fa-integration (vasqu, Aug 8, 2025)
3dbf11a ci pls be happy (vasqu, Aug 8, 2025)
d9d8ff7 change skip conditions based on backend flag (indicating missing mask… (vasqu, Aug 11, 2025)
ec0fbf3 move globals support to a function to prepare kwargs (vasqu, Aug 11, 2025)
a6996f5 style (vasqu, Aug 11, 2025)
ce0e586 Merge branch 'main' into fix-fa-integration (vasqu, Aug 11, 2025)
86c9e81 generalize supported kwargs (vasqu, Aug 11, 2025)
3ae14ec small change to doc (vasqu, Aug 11, 2025)
89b8f95 fix (vasqu, Aug 11, 2025)
24512e4 Merge branch 'main' into fix-fa-integration (vasqu, Aug 11, 2025)
ec41ac5 Merge branch 'main' into fix-fa-integration (vasqu, Aug 11, 2025)
74a8987 Merge branch 'main' into fix-fa-integration (vasqu, Aug 12, 2025)
8f91219 add comments (vasqu, Aug 12, 2025)
6a0a3a1 style (vasqu, Aug 12, 2025)
54ed29e Merge branch 'main' into fix-fa-integration (vasqu, Aug 12, 2025)
7016548 revert prep during generate (vasqu, Aug 12, 2025)
d9da331 Merge branch 'main' into fix-fa-integration (vasqu, Aug 12, 2025)
a996bd5 style (vasqu, Aug 12, 2025)
07fafe1 revert weird style changes (vasqu, Aug 12, 2025)
a98fac4 add fa kwarg prep during generate with fixes back (vasqu, Aug 12, 2025)
9971d75 how did this even happen (vasqu, Aug 12, 2025)
5e2d35f how (vasqu, Aug 12, 2025)
85cb1b1 Merge branch 'main' into fix-fa-integration (vasqu, Aug 12, 2025)
4ad364c add comment (vasqu, Aug 12, 2025)
cb89dbe Merge branch 'main' into fix-fa-integration (vasqu, Aug 12, 2025)
9 changes: 5 additions & 4 deletions src/transformers/generation/utils.py
@@ -678,9 +678,10 @@ def prepare_inputs_for_generation(
if encoder_attention_mask is not None:
model_inputs["attention_mask"] = encoder_attention_mask

# 7. Prepare kwargs for flash attention to avoid recomputations
if "flash" in self.config._attn_implementation and self._supports_attention_backend:
cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k = prepare_fa_kwargs_from_position_ids(
position_ids, is_packed_sequence=False
(cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fa_kwargs_from_position_ids(
model_inputs["position_ids"], is_packed_sequence=False
)
model_inputs.update(
cu_seq_lens_q=cu_seq_lens_q.to(self.device),
@@ -689,12 +690,12 @@ def prepare_inputs_for_generation(
max_length_k=max_length_k,
)

# 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
# 8. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
for key, value in kwargs.items():
if key not in model_inputs:
model_inputs[key] = value

# 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
# 9. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
model_inputs.pop("labels", None)
return model_inputs
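For orientation, the block above feeds flash attention varlen kernels with cumulative sequence lengths and max lengths derived from position ids. The sketch below is an illustrative assumption of what that metadata looks like, not the actual implementation of prepare_fa_kwargs_from_position_ids; it only mirrors the nested return shape used in the new call.

import torch

def sketch_fa_kwargs_from_position_ids(position_ids: torch.Tensor):
    # Illustrative sketch only: assume each sequence restarts its position ids at 0,
    # so every 0 marks a sequence boundary in the flattened batch.
    flat = position_ids.flatten()
    starts = torch.nonzero(flat == 0, as_tuple=False).flatten().to(torch.int32)
    cu_seq_lens = torch.cat([starts, torch.tensor([flat.numel()], dtype=torch.int32)])
    max_length = int((cu_seq_lens[1:] - cu_seq_lens[:-1]).max())
    # Simplification: query and key metadata are treated as identical here.
    return (cu_seq_lens, cu_seq_lens), (max_length, max_length)

# Example with two packed sequences of lengths 3 and 2:
pos = torch.tensor([[0, 1, 2, 0, 1]])
(cu_q, cu_k), (max_q, max_k) = sketch_fa_kwargs_from_position_ids(pos)
# cu_q == cu_k == tensor([0, 3, 5], dtype=torch.int32); max_q == max_k == 3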

132 changes: 6 additions & 126 deletions src/transformers/integrations/npu_flash_attention.py
@@ -10,20 +10,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os

import torch
import torch.nn.functional as F

from ..utils.import_utils import is_torch_npu_available


if is_torch_npu_available():
import math

import torch_npu
from einops import rearrange, repeat
from torch_npu import npu_rotary_mul
from torch_npu import npu_fusion_attention, npu_rotary_mul


# FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.
@@ -52,117 +48,6 @@ def is_npu_fa2_top_left_aligned_causal_mask():
return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE if is_torch_npu_available() else False


# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
class IndexFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
# return input[indices]
return torch.gather(
rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
).reshape(-1, *other_shape)

@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
grad_output = rearrange(grad_output, "b ... -> b (...)")
grad_input = torch.zeros(
[ctx.first_axis_dim, grad_output.shape[1]],
device=grad_output.device,
dtype=grad_output.dtype,
)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
# grad_input[indices] = grad_output
grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis = IndexFirstAxis.apply


# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
class IndexPutFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, values, indices, first_axis_dim):
ctx.save_for_backward(indices)
assert indices.ndim == 1
assert values.ndim >= 2
output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
output[indices] = values
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
return output

@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
grad_values = grad_output[indices]
# grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
return grad_values, None, None


index_put_first_axis = IndexPutFirstAxis.apply


# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
def pad_input(hidden_states, indices, batch, seqlen):
"""
Arguments:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
batch: int, batch size for the padded sequence.
seqlen: int, maximum sequence length for the padded sequence.
Return:
hidden_states: (batch, seqlen, ...)
"""
# dim = hidden_states.shape[-1]
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
# output[indices] = hidden_states
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
return rearrange(output, "(b s) ... -> b s ...", b=batch)


# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
def unpad_input(hidden_states, attention_mask, unused_mask=None):
"""
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
"""
all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
used_seqlens_in_batch,
)
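
As a concrete illustration of the padding helpers removed above, the expected outputs of unpad_input for a tiny batch are shown below (assuming the removed definitions and the torch import above are still in scope; the values follow directly from the docstring and code above).

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # 3 and 2 valid tokens
hidden_states = torch.randn(2, 4, 8)                         # (batch, seqlen, dim)
unpadded, indices, cu_seqlens, max_seqlen, seqused = unpad_input(hidden_states, attention_mask)
# unpadded.shape == (5, 8)                  -> only the 5 valid tokens remain
# indices        == tensor([0, 1, 2, 4, 5]) -> flat positions of valid tokens
# cu_seqlens     == tensor([0, 3, 5], dtype=torch.int32)
# max_seqlen     == 3
# seqused        == tensor([3, 2], dtype=torch.int32)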


def npu_flash_attn_func(
q,
k,
@@ -179,11 +64,11 @@ def npu_flash_attn_func(

if not causal:
head_num = q.shape[2]
output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
output = npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
else:
attn_mask_npu = get_attn_mask_npu(q.device)
head_num = q.shape[2]
output = torch_npu.npu_fusion_attention(
output = npu_fusion_attention(
q,
k,
v,
@@ -218,7 +103,7 @@ def npu_flash_attn_varlen_func(

if not causal:
head_num = q.shape[1]
output = torch_npu.npu_fusion_attention(
output = npu_fusion_attention(
q,
k,
v,
@@ -234,7 +119,7 @@
else:
attn_mask_npu = get_attn_mask_npu(q.device)
head_num = q.shape[1]
output = torch_npu.npu_fusion_attention(
output = npu_fusion_attention(
q,
k,
v,
@@ -267,8 +152,3 @@ def npu_apply_rotary_emb(x, cos, sin, **kwargs):
sin = sin.unsqueeze(0).unsqueeze(2)

return npu_rotary_mul(x, cos, sin)


def get_npu_flash_attn_funcs():
# return flash attention related functions used for Ascend NPU in order
return npu_flash_attn_func, npu_flash_attn_varlen_func, pad_input, unpad_input, False
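
For context, a hedged usage sketch of the NPU flash attention entry point above; it assumes torch_npu is installed with an Ascend NPU available, that tensors follow the (batch, seq_len, num_heads, head_dim) layout implied by the "BSND" format, and that only q, k, v and causal are passed.

import torch
from transformers.integrations.npu_flash_attention import npu_flash_attn_func

# Random half-precision tensors in (batch, seq_len, num_heads, head_dim) layout.
q = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="npu")
k = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="npu")
v = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="npu")
out = npu_flash_attn_func(q, k, v, causal=True)  # causal path uses the pre-built NPU mask
# out is expected to have the same shape as q: (2, 128, 8, 64)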