@@ -18,7 +18,17 @@
 import math
 from dataclasses import dataclass
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+)

 import torch

@@ -68,7 +78,10 @@

 if _CAN_USE_FLASH_ATTN:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
+    from flash_attn.flash_attn_interface import (
+        _wrapped_flash_attn_backward,
+        _wrapped_flash_attn_forward,
+    )
 else:
     flash_attn_func = None
     flash_attn_varlen_func = None
@@ -77,9 +90,9 @@


 if _CAN_USE_FLASH_ATTN_3:
+    from flash_attn_interface import _flash_attn_forward as flash_attn_3_forward
     from flash_attn_interface import flash_attn_func as flash_attn_3_func
     from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
-    from flash_attn_interface import _flash_attn_forward as flash_attn_3_forward
 else:
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None
@@ -122,7 +135,9 @@


 if _CAN_USE_XLA_ATTN:
-    from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
+    from torch_xla.experimental.custom_kernel import (
+        flash_attention as xla_flash_attention,
+    )
 else:
     xla_flash_attention = None

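These optional-import hunks all follow the same guard pattern: a module-level `_CAN_USE_*` flag (defined earlier in the file, outside this diff, and taking the installed package version into account per the error messages further down) gates the import, and the symbols fall back to `None` so the dispatcher can raise a clear error later. A minimal sketch of the pattern, under the simplifying assumption of a bare `find_spec` check:

    import importlib.util

    # Sketch only: the real flags in attention_dispatch.py also validate minimum versions.
    _CAN_USE_FLASH_ATTN = importlib.util.find_spec("flash_attn") is not None
    _CAN_USE_XLA_ATTN = importlib.util.find_spec("torch_xla") is not None

    if _CAN_USE_FLASH_ATTN:
        from flash_attn import flash_attn_func
    else:
        # The dispatcher checks the flag (or the None sentinel) before using this backend.
        flash_attn_func = None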
@@ -265,13 +280,17 @@ class _HubKernelConfig:
 _HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
     # TODO: temporary revision for now. Remove when merged upstream into `main`.
     AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
-        repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
+        repo_id="kernels-community/flash-attn3",
+        function_attr="flash_attn_func",
+        revision="fake-ops-return-probs",
     )
 }


 @contextlib.contextmanager
-def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
+def attention_backend(
+    backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE,
+):
     """
     Context manager to set the active attention backend.
     """
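For reference, the `attention_backend` context manager reformatted above is used like this (a minimal sketch: `model` and `hidden_states` are placeholders for any diffusers module that routes attention through this dispatcher, and a non-default backend only works if its package is installed):

    from diffusers.models.attention_dispatch import attention_backend

    # Attention calls dispatched through this module use the selected backend inside
    # the context; the previous backend is restored on exit.
    with attention_backend("native"):
        out = model(hidden_states)

    with attention_backend("flash"):  # assumes flash-attn is installed
        out = model(hidden_states)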
@@ -416,7 +435,10 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
416435 f"Flash Attention backend '{ backend .value } ' is not usable because of missing package or the version is too old. Please install `flash-attn>={ _REQUIRED_FLASH_VERSION } `."
417436 )
418437
419- elif backend in [AttentionBackendName ._FLASH_3 , AttentionBackendName ._FLASH_VARLEN_3 ]:
438+ elif backend in [
439+ AttentionBackendName ._FLASH_3 ,
440+ AttentionBackendName ._FLASH_VARLEN_3 ,
441+ ]:
420442 if not _CAN_USE_FLASH_ATTN_3 :
421443 raise RuntimeError (
422444 f"Flash Attention 3 backend '{ backend .value } ' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
@@ -488,7 +510,11 @@ def _prepare_for_flash_attn_or_sage_varlen_without_mask(
     cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
     max_seqlen_q = seqlens_q.max().item()
     max_seqlen_k = seqlens_k.max().item()
-    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+    return (
+        (seqlens_q, seqlens_k),
+        (cu_seqlens_q, cu_seqlens_k),
+        (max_seqlen_q, max_seqlen_k),
+    )


 def _prepare_for_flash_attn_or_sage_varlen_with_mask(
@@ -505,7 +531,11 @@ def _prepare_for_flash_attn_or_sage_varlen_with_mask(
     cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
     max_seqlen_q = seqlens_q.max().item()
     max_seqlen_k = seqlens_k.max().item()
-    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+    return (
+        (seqlens_q, seqlens_k),
+        (cu_seqlens_q, cu_seqlens_k),
+        (max_seqlen_q, max_seqlen_k),
+    )


 def _prepare_for_flash_attn_or_sage_varlen(
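The two `_prepare_for_flash_attn_or_sage_varlen_*` hunks above reformat the same bookkeeping: per-sequence lengths, their cumulative offsets, and the maxima that varlen flash/sage kernels expect. A standalone sketch of that computation with made-up sizes, deriving `seqlens_k` from a hypothetical padding mask the way the `_with_mask` variant does:

    import torch

    # Hypothetical boolean padding mask of shape (batch_size, seq_len_kv).
    attn_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)
    batch_size, seq_len_kv = attn_mask.shape
    seq_len_q = seq_len_kv  # self-attention case

    seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32)
    seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)

    cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32)
    cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32)
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)

    max_seqlen_q = seqlens_q.max().item()  # 4
    max_seqlen_k = seqlens_k.max().item()  # 3
    # cu_seqlens_k == tensor([0, 3, 5], dtype=torch.int32): start offsets in the packed layout.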
@@ -625,7 +655,7 @@ def _wrapped_flash_attn_3(
     window_size = (-1, -1)
     max_seqlen_q = q.shape[2]
     max_seqlen_k = k.shape[2]
-
+
     out, lse, *_ = flash_attn_3_forward(
         q=q,
         k=k,
@@ -764,7 +794,10 @@ def _native_attention_backward_op(

     grad_out_t = grad_out.permute(0, 2, 1, 3)
     grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
-        outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
+        outputs=out,
+        inputs=[query_t, key_t, value_t],
+        grad_outputs=grad_out_t,
+        retain_graph=False,
     )

     grad_query = grad_query_t.permute(0, 2, 1, 3)
@@ -803,18 +836,26 @@ def _cudnn_attention_forward_op(
     value = value.transpose(1, 2).contiguous()
     tensors_to_save += (query, key, value)

-    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
-        torch.ops.aten._scaled_dot_product_cudnn_attention(
-            query=query,
-            key=key,
-            value=value,
-            attn_bias=attn_mask,
-            compute_log_sumexp=return_lse,
-            dropout_p=dropout_p,
-            is_causal=is_causal,
-            return_debug_mask=False,
-            scale=scale,
-        )
+    (
+        out,
+        lse,
+        cum_seq_q,
+        cum_seq_k,
+        max_q,
+        max_k,
+        philox_seed,
+        philox_offset,
+        debug_attn_mask,
+    ) = torch.ops.aten._scaled_dot_product_cudnn_attention(
+        query=query,
+        key=key,
+        value=value,
+        attn_bias=attn_mask,
+        compute_log_sumexp=return_lse,
+        dropout_p=dropout_p,
+        is_causal=is_causal,
+        return_debug_mask=False,
+        scale=scale,
     )

     tensors_to_save += (out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
@@ -941,7 +982,11 @@ def _flash_attention_backward_op(
     **kwargs,
 ):
     query, key, value, out, lse, rng_state = ctx.saved_tensors
-    grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
+    grad_query, grad_key, grad_value = (
+        torch.empty_like(query),
+        torch.empty_like(key),
+        torch.empty_like(value),
+    )

     lse_d = _wrapped_flash_attn_backward(  # noqa: F841
         grad_out,
@@ -1165,7 +1210,19 @@ def backward(


         grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))
-        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+        return (
+            grad_query,
+            grad_key,
+            grad_value,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )


 class TemplatedUlyssesAttention(torch.autograd.Function):
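The return tuples expanded in this hunk (and the matching one just below) follow the `torch.autograd.Function` contract: `backward` must return one value per `forward` input, with `None` in the slots of non-differentiable arguments such as dropout probability, softmax scale, or boolean flags. A minimal standalone example of the same pattern, not the diffusers class itself:

    import torch

    class ScaledMatmul(torch.autograd.Function):
        # forward takes two tensors plus one non-tensor config argument (`scale`).
        @staticmethod
        def forward(ctx, a, b, scale):
            ctx.save_for_backward(a, b)
            ctx.scale = scale
            return (a @ b) * scale

        @staticmethod
        def backward(ctx, grad_out):
            a, b = ctx.saved_tensors
            grad_a = grad_out @ b.transpose(-1, -2) * ctx.scale
            grad_b = a.transpose(-1, -2) @ grad_out * ctx.scale
            # One entry per forward input; `scale` is not a tensor, so its slot is None,
            # mirroring the trailing Nones returned in the hunks above and below.
            return grad_a, grad_b, None

    a = torch.randn(2, 3, requires_grad=True)
    b = torch.randn(3, 4, requires_grad=True)
    ScaledMatmul.apply(a, b, 0.5).sum().backward()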
@@ -1260,7 +1317,19 @@ def backward(
             x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
         )

-        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+        return (
+            grad_query,
+            grad_key,
+            grad_value,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )


 def _templated_context_parallel_attention(
@@ -1608,7 +1677,12 @@ def _native_flex_attention(
         block_mask = attn_mask
     elif is_causal:
         block_mask = flex_attention.create_block_mask(
-            _flex_attention_causal_mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, query.device
+            _flex_attention_causal_mask_mod,
+            batch_size,
+            num_heads,
+            seq_len_q,
+            seq_len_kv,
+            query.device,
         )
     elif torch.is_tensor(attn_mask):
         if attn_mask.ndim == 2:
@@ -1628,6 +1702,7 @@ def mask_mod(batch_idx, head_idx, q_idx, kv_idx):

             def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
                 return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx]
+
     else:
         raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.")

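The `score_mod` in the final hunk turns a floating-point 4D mask into an additive bias for `torch.nn.attention.flex_attention`. A compact standalone sketch of the same idea, independent of the diffusers wrapper (a PyTorch build with flex attention support is assumed, and in practice `flex_attention` is usually wrapped in `torch.compile` for speed):

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    B, H, S, D = 1, 2, 8, 16
    query, key, value = (torch.randn(B, H, S, D) for _ in range(3))

    # Additive bias broadcastable to (B, H, S_q, S_kv); -inf masks out the last half of the keys.
    attn_mask = torch.zeros(B, H, S, S)
    attn_mask[..., S // 2 :] = float("-inf")

    def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
        # Same shape of modifier as in the diff: add the mask entry to the raw attention score.
        return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx]

    out = flex_attention(query, key, value, score_mod=score_mod)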