@@ -508,30 +508,30 @@ def flash_attention(
     fa_version = paddle.base.framework.get_flags(
         ["FLAGS_flash_attn_version"]
     )["FLAGS_flash_attn_version"]
-    assert (
-        in_dynamic_or_pir_mode() or fa_version == 2
-    ), "flash attention 3 only support dynamic or pir mode"
-    assert (
-        dropout == 0.0 or fa_version == 2
-    ), "flash attention 3 does not support dropout"
-    assert (
-        not return_softmax or fa_version == 2
-    ), "flash attention 3 does not support return softmax"
-    assert (
-        fixed_seed_offset is None or fa_version == 2
-    ), "flash attention 3 does not support return softmax"
-    assert (
-        rng_name == "" or fa_version == 2
-    ), "flash attention 3 does not support setting rng_name"
-    assert (
-        training or fa_version == 2
-    ), "flash attention 3 does not support setting training"
-    assert (
-        name is None or fa_version == 2
-    ), "flash attention 3 does not support setting name"
-    assert (
-        softmax_scale is None or fa_version == 3
-    ), "flash attention 2 does not support setting softmax_scale"
+    assert in_dynamic_or_pir_mode() or fa_version == 2, (
+        "flash attention 3 only support dynamic or pir mode"
+    )
+    assert dropout == 0.0 or fa_version == 2, (
+        "flash attention 3 does not support dropout"
+    )
+    assert not return_softmax or fa_version == 2, (
+        "flash attention 3 does not support return softmax"
+    )
+    assert fixed_seed_offset is None or fa_version == 2, (
+        "flash attention 3 does not support setting fixed_seed_offset"
+    )
+    assert rng_name == "" or fa_version == 2, (
+        "flash attention 3 does not support setting rng_name"
+    )
+    assert training or fa_version == 2, (
+        "flash attention 3 does not support setting training"
+    )
+    assert name is None or fa_version == 2, (
+        "flash attention 3 does not support setting name"
+    )
+    assert softmax_scale is None or fa_version == 3, (
+        "flash attention 2 does not support setting softmax_scale"
+    )
     if in_dynamic_or_pir_mode():
         if fa_version == 2:
             (result_attention, result_softmax, _, _) = _C_ops.flash_attn(
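
For reference, a minimal sketch (not part of this change) of the FA3 path these assertions guard: FLAGS_flash_attn_version selects the kernel, and the FA3-only restrictions (no dropout, no return_softmax, default fixed_seed_offset/rng_name/name) must hold. The tensor shape and the (out, softmax) return convention are assumptions carried over from the FA2 docstring, not taken from this diff.

# Hedged sketch: select FA3 via the flag, then call flash_attention with
# arguments that satisfy the assertions above. Assumes a GPU build with FA3
# kernels and the [batch, seqlen, num_heads, head_dim] layout.
import paddle
from paddle.nn.functional.flash_attention import flash_attention

paddle.set_flags({"FLAGS_flash_attn_version": 3})

q = paddle.randn([1, 128, 8, 64]).astype("bfloat16")
# dropout=0.0, return_softmax=False, fixed_seed_offset=None, rng_name="" and
# name=None are the defaults, so none of the FA3-only assertions fire.
out, _ = flash_attention(q, q, q, causal=True)  # (out, softmax) return assumed
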
@@ -1142,9 +1142,9 @@ def flash_attn_varlen_func(
             >>> output = paddle.nn.functional.flash_attention.flash_attention_v3_varlen(q, q, q, cu_seqlens_q, cu_seqlens_q, max_seqlen_q=max_seq_len_q, max_seqlen_k=max_seq_len_q, causal=True)
             >>> # doctest: -SKIP
     """
-    assert (
-        "xpu" not in paddle.get_device()
-    ), "flash_attn_varlen_func is not supported on xpu"
+    assert "xpu" not in paddle.get_device(), (
+        "flash_attn_varlen_func is not supported on xpu"
+    )

     assert not paddle.get_flags(["FLAGS_cudnn_deterministic"])[
         "FLAGS_cudnn_deterministic"
@@ -1157,9 +1157,9 @@ def flash_attn_varlen_func(
         == 3
     ), "FLAGS_flash_attn_version is 2, conflicts with flash_attn_varlen_func"

-    assert (
-        in_dynamic_or_pir_mode()
-    ), "flash_attn_varlen_func only support dynamic or pir mode"
+    assert in_dynamic_or_pir_mode(), (
+        "flash_attn_varlen_func only support dynamic or pir mode"
+    )

     assert qv is None, "flash_attn_varlen_func does not support setting qv"

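
The preconditions asserted above for flash_attn_varlen_func (FA3 selected, cuDNN determinism disabled, dynamic or PIR mode, no qv) can be set up as sketched below; the cu_seqlens construction mirrors the docstring example and uses placeholder lengths.

# Hedged sketch with assumed per-sample lengths: the flag state the assertions
# expect, plus the int32 cumulative-sequence-length tensor the varlen entry
# point consumes (a zero-prefixed running sum of tokens per sample).
import paddle

paddle.set_flags(
    {"FLAGS_flash_attn_version": 3, "FLAGS_cudnn_deterministic": False}
)

seq_lens = paddle.to_tensor([3, 5, 2], dtype="int32")  # tokens per sample
cu_seqlens_q = paddle.concat(
    [paddle.zeros([1], dtype="int32"), paddle.cumsum(seq_lens)]
)  # -> [0, 3, 8, 10]
max_seqlen_q = int(seq_lens.max())  # -> 5
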
@@ -2203,9 +2203,9 @@ def flashmask_attention(
             window_size = (window_size, window_size)
         sq = query.shape[1]
         bsz = query.shape[0]
-        assert (
-            startend_row_indices is None
-        ), "can't use window_size with startend_row_indices"
+        assert startend_row_indices is None, (
+            "can't use window_size with startend_row_indices"
+        )
         if causal:
             startend_row_indices = paddle.arange(
                 window_size[0] + 1, sq + window_size[0] + 1, dtype="int32"
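
Above, an integer window_size is normalised to a (left, right) tuple and then turned into startend_row_indices, which is why passing both is rejected. A hedged usage sketch follows; the call pattern, shapes, and return handling are assumptions, not taken from this diff.

# Hedged sketch: sliding-window attention through flashmask_attention.
# window_size may be an int (normalised to a tuple above) and must not be
# combined with an explicit startend_row_indices.
import paddle
from paddle.nn.functional.flash_attention import flashmask_attention

q = paddle.randn([2, 1024, 8, 64]).astype("bfloat16")
out = flashmask_attention(q, q, q, window_size=128, causal=True)
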
@@ -2246,24 +2246,26 @@ def flashmask_attention(
         )

     else:
-        assert (
-            startend_row_indices.dtype == paddle.int32
-        ), f"startend_row_indices.dtype must be paddle.int32, but got {startend_row_indices.dtype}"
-        assert (
-            len(startend_row_indices.shape) == 4
-        ), f"startend_row_indices rank must be 4,but got {startend_row_indices.shape}"
-
-        assert (
-            startend_row_indices.shape[0] == key.shape[0]
-        ), f"startend_row_indices.shape[0] must be equal to batch_size, but got {startend_row_indices.shape[0]} and {key.shape[0]}"
-
-        assert (
-            startend_row_indices.shape[2] == key.shape[1]
-        ), f"startend_row_indices.shape[2] must be equal to seqlen_k, but got {startend_row_indices.shape[2]} and {key.shape[2]}"
+        assert startend_row_indices.dtype == paddle.int32, (
+            f"startend_row_indices.dtype must be paddle.int32, but got {startend_row_indices.dtype}"
+        )
+        assert len(startend_row_indices.shape) == 4, (
+            f"startend_row_indices rank must be 4,but got {startend_row_indices.shape}"
+        )
+
+        assert startend_row_indices.shape[0] == key.shape[0], (
+            f"startend_row_indices.shape[0] must be equal to batch_size, but got {startend_row_indices.shape[0]} and {key.shape[0]}"
+        )
+
+        assert startend_row_indices.shape[2] == key.shape[1], (
+            f"startend_row_indices.shape[2] must be equal to seqlen_k, but got {startend_row_indices.shape[2]} and {key.shape[1]}"
+        )
         assert startend_row_indices.shape[1] in [
             1,
             key.shape[2],
-        ], "startend_row_indices head_num must be equal to 1(broadcast) or head_num_k."
+        ], (
+            "startend_row_indices head_num must be equal to 1(broadcast) or head_num_k."
+        )

         if causal:
             if startend_row_indices.shape[-1] == 1:
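
The assertions above pin down the mask layout: startend_row_indices must be int32, rank 4, with dim 0 equal to the batch size, dim 2 equal to seqlen_k, and dim 1 either 1 (broadcast over heads) or head_num_k. A sketch with illustrative values:

# Hedged sketch: a startend_row_indices tensor with the layout required above,
# shaped [batch_size, 1, seqlen_k, 1]. Filling it with seqlen_k marks no extra
# rows as masked, i.e. it degenerates to the plain causal mask; real masks
# encode a per-column row index where masking starts.
import paddle

bsz, seqlen_k = 2, 8
startend_row_indices = paddle.full([bsz, 1, seqlen_k, 1], seqlen_k, dtype="int32")
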
@@ -2383,9 +2385,9 @@ def calc_reduced_attention_scores(
             >>> )
             >>> # doctest: -SKIP
     """
-    assert (
-        query.stop_gradient and key.stop_gradient
-    ), 'calc_reduced_attention_scores() is for inference only.'
+    assert query.stop_gradient and key.stop_gradient, (
+        'calc_reduced_attention_scores() is for inference only.'
+    )

     if in_dynamic_or_pir_mode():
         reduced_scores = _C_ops.calc_reduced_attn_scores(
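
calc_reduced_attention_scores only checks stop_gradient, so the inputs simply have to be marked as inference tensors; the full argument list is shown in the docstring example above. A sketch with placeholder shapes:

# Hedged sketch: satisfy the inference-only assertion by keeping
# stop_gradient=True on query and key (the default for freshly created
# tensors; set explicitly here to make the precondition visible).
import paddle

q = paddle.randn([1, 8, 128, 64]).astype("float16")
k = paddle.randn([1, 8, 1024, 64]).astype("float16")
q.stop_gradient = True
k.stop_gradient = True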