# See the License for the specific language governing permissions and
# limitations under the License.

-import math
import os

import torch
+import torch.nn.functional as F

from ..utils.import_utils import is_torch_npu_available


if is_torch_npu_available():
-    from torch_npu import npu_fusion_attention, npu_rotary_mul
+    import math
+
+    import torch_npu
+    from einops import rearrange, repeat
+    from torch_npu import npu_rotary_mul


# FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.
@@ -48,6 +52,117 @@ def is_npu_fa2_top_left_aligned_causal_mask():
    return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE if is_torch_npu_available() else False


+# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
+class IndexFirstAxis(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, indices):
+        ctx.save_for_backward(indices)
+        assert input.ndim >= 2
+        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
+        second_dim = other_shape.numel()
+        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
+        # return input[indices]
+        return torch.gather(
+            rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
+        ).reshape(-1, *other_shape)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (indices,) = ctx.saved_tensors
+        assert grad_output.ndim >= 2
+        other_shape = grad_output.shape[1:]
+        grad_output = rearrange(grad_output, "b ... -> b (...)")
+        grad_input = torch.zeros(
+            [ctx.first_axis_dim, grad_output.shape[1]],
+            device=grad_output.device,
+            dtype=grad_output.dtype,
+        )
+        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
+        # grad_input[indices] = grad_output
+        grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
+        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
+
+
+index_first_axis = IndexFirstAxis.apply
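+# Usage sketch (illustrative, with hypothetical names x and idx): for x of shape
+# (batch * seqlen, nheads, headdim) and a 1-D LongTensor idx of kept positions,
+#
+#     out = index_first_axis(x, idx)   # equivalent to x[idx], shape (len(idx), nheads, headdim)
+#
+# but the forward pass routes through torch.gather and the backward pass through
+# scatter_, which the comments above note is slightly faster than fancy indexing.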
+
+
+# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
+class IndexPutFirstAxis(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, values, indices, first_axis_dim):
+        ctx.save_for_backward(indices)
+        assert indices.ndim == 1
+        assert values.ndim >= 2
+        output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
+        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
+        output[indices] = values
+        # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (indices,) = ctx.saved_tensors
+        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
+        grad_values = grad_output[indices]
+        # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
+        return grad_values, None, None
+
+
+index_put_first_axis = IndexPutFirstAxis.apply
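+# Usage sketch (illustrative, hypothetical names): index_put_first_axis is the
+# inverse scatter of index_first_axis. Given values of shape (nnz, nheads, headdim),
+# a 1-D idx of target rows, and a first-axis size n,
+#
+#     full = index_put_first_axis(values, idx, n)
+#
+# returns a zero-initialized (n, nheads, headdim) tensor with full[idx] == values.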
+
+
+# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
+def pad_input(hidden_states, indices, batch, seqlen):
+    """
+    Arguments:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
+        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
+        batch: int, batch size for the padded sequence.
+        seqlen: int, maximum sequence length for the padded sequence.
+    Return:
+        hidden_states: (batch, seqlen, ...)
+    """
+    # dim = hidden_states.shape[-1]
+    # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
+    # output[indices] = hidden_states
+    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
+    return rearrange(output, "(b s) ... -> b s ...", b=batch)
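+# Round-trip sketch (illustrative, placeholder names): given the (total_nnz, hidden)
+# tensor and the indices produced by unpad_input below,
+#
+#     padded = pad_input(unpadded, indices, batch, seqlen)
+#
+# rebuilds a (batch, seqlen, hidden) tensor, with zeros at the padded positions.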
+
+
+# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
+def unpad_input(hidden_states, attention_mask, unused_mask=None):
+    """
+    Arguments:
+        hidden_states: (batch, seqlen, ...)
+        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
+    Return:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
+        indices: (total_nnz), the indices of the non-masked tokens in the flattened input sequence.
+        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
+        max_seqlen_in_batch: int
+        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
+    """
+    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
+    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
+    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
+    # bool mask, then call nonzero to get the indices, then index with those. The indices are @dim
+    # times larger than they need to be, wasting memory. It's faster and more memory-efficient to
+    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
+    # so we write custom forward and backward to make it a bit faster.
+    return (
+        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+        used_seqlens_in_batch,
+    )
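+# Usage sketch (illustrative, placeholder names): for hidden_states of shape
+# (batch, seqlen, hidden) and a 0/1 attention_mask of shape (batch, seqlen),
+#
+#     x_unpad, indices, cu_seqlens, max_s, seqused = unpad_input(hidden_states, attention_mask)
+#
+# x_unpad has shape (total_nnz, hidden); cu_seqlens is the int32 prefix sum of the
+# per-sample valid lengths with a leading zero, i.e. the packed (varlen) layout.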
+
+
def npu_flash_attn_func(
    q,
    k,
@@ -64,11 +179,11 @@ def npu_flash_attn_func(

    if not causal:
        head_num = q.shape[2]
-        output = npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
+        output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
    else:
        attn_mask_npu = get_attn_mask_npu(q.device)
        head_num = q.shape[2]
-        output = npu_fusion_attention(
+        output = torch_npu.npu_fusion_attention(
            q,
            k,
            v,
@@ -103,7 +218,7 @@ def npu_flash_attn_varlen_func(

    if not causal:
        head_num = q.shape[1]
-        output = npu_fusion_attention(
+        output = torch_npu.npu_fusion_attention(
            q,
            k,
            v,
@@ -119,7 +234,7 @@ def npu_flash_attn_varlen_func(
    else:
        attn_mask_npu = get_attn_mask_npu(q.device)
        head_num = q.shape[1]
-        output = npu_fusion_attention(
+        output = torch_npu.npu_fusion_attention(
            q,
            k,
            v,