 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 import os
 
 import torch
-import torch.nn.functional as F
 
 from ..utils.import_utils import is_torch_npu_available
 
 
 if is_torch_npu_available():
-    import math
-
-    import torch_npu
-    from einops import rearrange, repeat
-    from torch_npu import npu_rotary_mul
+    from torch_npu import npu_fusion_attention, npu_rotary_mul
 
 
 # FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.
@@ -52,117 +48,6 @@ def is_npu_fa2_top_left_aligned_causal_mask():
     return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE if is_torch_npu_available() else False
 
 
-# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
-class IndexFirstAxis(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input, indices):
-        ctx.save_for_backward(indices)
-        assert input.ndim >= 2
-        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
-        second_dim = other_shape.numel()
-        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
-        # return input[indices]
-        return torch.gather(
-            rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
-        ).reshape(-1, *other_shape)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        (indices,) = ctx.saved_tensors
-        assert grad_output.ndim >= 2
-        other_shape = grad_output.shape[1:]
-        grad_output = rearrange(grad_output, "b ... -> b (...)")
-        grad_input = torch.zeros(
-            [ctx.first_axis_dim, grad_output.shape[1]],
-            device=grad_output.device,
-            dtype=grad_output.dtype,
-        )
-        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
-        # grad_input[indices] = grad_output
-        grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
-        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
-
-
-index_first_axis = IndexFirstAxis.apply
-
-
-# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
-class IndexPutFirstAxis(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, values, indices, first_axis_dim):
-        ctx.save_for_backward(indices)
-        assert indices.ndim == 1
-        assert values.ndim >= 2
-        output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
-        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
-        output[indices] = values
-        # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        (indices,) = ctx.saved_tensors
-        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
-        grad_values = grad_output[indices]
-        # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
-        return grad_values, None, None
-
-
-index_put_first_axis = IndexPutFirstAxis.apply
-
-
-# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
-def pad_input(hidden_states, indices, batch, seqlen):
-    """
-    Arguments:
-        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
-        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
-        batch: int, batch size for the padded sequence.
-        seqlen: int, maximum sequence length for the padded sequence.
-    Return:
-        hidden_states: (batch, seqlen, ...)
-    """
-    # dim = hidden_states.shape[-1]
-    # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
-    # output[indices] = hidden_states
-    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
-    return rearrange(output, "(b s) ... -> b s ...", b=batch)
-
-
-# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
-def unpad_input(hidden_states, attention_mask, unused_mask=None):
-    """
-    Arguments:
-        hidden_states: (batch, seqlen, ...)
-        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
-        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
-    Return:
-        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
-        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
-        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
-        max_seqlen_in_batch: int
-        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
-    """
-    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
-    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
-    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
-    max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
-    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
-    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
-    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
-    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
-    # so we write custom forward and backward to make it a bit faster.
-    return (
-        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
-        indices,
-        cu_seqlens,
-        max_seqlen_in_batch,
-        used_seqlens_in_batch,
-    )
-
-
 def npu_flash_attn_func(
     q,
     k,
@@ -179,11 +64,11 @@ def npu_flash_attn_func(
 
     if not causal:
         head_num = q.shape[2]
-        output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
+        output = npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
     else:
         attn_mask_npu = get_attn_mask_npu(q.device)
         head_num = q.shape[2]
-        output = torch_npu.npu_fusion_attention(
+        output = npu_fusion_attention(
             q,
             k,
             v,
@@ -218,7 +103,7 @@ def npu_flash_attn_varlen_func(
 
     if not causal:
         head_num = q.shape[1]
-        output = torch_npu.npu_fusion_attention(
+        output = npu_fusion_attention(
             q,
             k,
             v,
@@ -234,7 +119,7 @@ def npu_flash_attn_varlen_func(
     else:
         attn_mask_npu = get_attn_mask_npu(q.device)
         head_num = q.shape[1]
-        output = torch_npu.npu_fusion_attention(
+        output = npu_fusion_attention(
             q,
             k,
             v,
@@ -267,8 +152,3 @@ def npu_apply_rotary_emb(x, cos, sin, **kwargs):
         sin = sin.unsqueeze(0).unsqueeze(2)
 
     return npu_rotary_mul(x, cos, sin)
-
-
-def get_npu_flash_attn_funcs():
-    # return flash attention related functions used for Ascend NPU in order
-    return npu_flash_attn_func, npu_flash_attn_varlen_func, pad_input, unpad_input, False
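
Below is a minimal usage sketch for the kept public API; it is not part of the diff. The module path and the dropout_p default are assumptions; only q, k, v, softmax_scale, causal and the "BSND" layout are visible in the hunks above, and running it requires the torch_npu runtime on an Ascend device.

# Illustrative sketch only; names marked as assumed are not confirmed by this diff.
import math

import torch

# Assumed import path for the module shown in this diff.
from transformers.integrations.npu_flash_attention import npu_flash_attn_func

batch, seqlen, num_heads, head_dim = 2, 128, 8, 64
# "BSND" layout: (batch, seqlen, num_heads, head_dim), matching head_num = q.shape[2] above.
q = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.float16, device="npu")
k = torch.randn_like(q)
v = torch.randn_like(q)

# softmax_scale and causal appear in this file; causal=True uses the NPU attention mask path.
out = torch.empty_like(q)
out = npu_flash_attn_func(q, k, v, softmax_scale=1.0 / math.sqrt(head_dim), causal=True)
print(out.shape)  # torch.Size([2, 128, 8, 64])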