
Commit 216c655

[CodeStyle] black -> ruff format migration - part 32 (#74746)
1 parent: b6baf35

24 files changed, +222 -210 lines

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
@@ -89,7 +89,7 @@ repos:
  # | python/paddle/j.+
- # | python/paddle/[k-n].+
+ | python/paddle/[k-n].+
  # | python/paddle/[o-t].+

@@ -145,7 +145,7 @@ repos:
  | python/paddle/j.+
- | python/paddle/[k-n].+
+ # | python/paddle/[k-n].+
  | python/paddle/[o-t].+
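
These two hunks hand the `python/paddle/[k-n].+` path range from one formatter hook to the other: the pattern is uncommented in the first hook's `files` block and commented out in the second, which, per the commit title, moves those files from black to ruff format. pre-commit evaluates each hook's `files` entry as a Python regular expression, so the effect of the toggle can be sanity-checked locally. The snippet below is a minimal sketch; the verbose-regex layout and the neighbouring alternatives are assumptions based on the commented lines visible in the hunks, not the full config.

```python
# Sketch only: approximate the toggled `files` pattern as a verbose regex
# (the real .pre-commit-config.yaml contains many more alternatives).
import re

ruff_format_files = re.compile(
    r"""(?x)
      python/paddle/j.+          # migrated earlier in this series
    | python/paddle/[k-n].+      # toggled on by this commit
    """
)

print(bool(ruff_format_files.search("python/paddle/nn/functional/common.py")))  # True
print(bool(ruff_format_files.search("python/paddle/tensor/math.py")))            # False
```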

python/paddle/nn/functional/activation.py

Lines changed: 12 additions & 12 deletions
@@ -603,9 +603,9 @@ def prelu(
  [-1.25000000, 6. , 7. , -2. ],
  [ 6. , 7. , 8. , 9. ]]]])
  """
- assert (
-     len(weight.shape) == 0 or len(weight.shape) == 1
- ), "The dim count of weight shape should be 0 or 1 in prelu()."
+ assert len(weight.shape) == 0 or len(weight.shape) == 1, (
+     "The dim count of weight shape should be 0 or 1 in prelu()."
+ )

  mode = 'all'
  if len(weight.shape) == 1 and weight.shape[0] > 1:

@@ -626,9 +626,9 @@ def prelu(

  data_format = 'NCHW' if data_format[1] == 'C' else 'NHWC'

- assert (
-     len(x.shape) > 1
- ), "The dim count of x should be equal or larger than 2 in prelu() when weight shape is not [1]."
+ assert len(x.shape) > 1, (
+     "The dim count of x should be equal or larger than 2 in prelu() when weight shape is not [1]."
+ )

  # NOTE(GuoxiaWang): support NHWC data format
  if data_format == 'NHWC':
-     assert (
-         weight.shape[0] == x.shape[-1]
-     ), "The weight size should be equal to x input channel in prelu() when weight shape is not [1]."
+     assert weight.shape[0] == x.shape[-1], (
+         "The weight size should be equal to x input channel in prelu() when weight shape is not [1]."
+     )
  else:
-     assert (
-         weight.shape[0] == x.shape[1]
-     ), "The weight size should be equal to x input channel in prelu() when weight shape is not [1]."
+     assert weight.shape[0] == x.shape[1], (
+         "The weight size should be equal to x input channel in prelu() when weight shape is not [1]."
+     )
  mode = 'channel'

  if in_dynamic_or_pir_mode():
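
The activation.py hunks show the pattern that repeats throughout this commit: black wraps the assert *condition* in parentheses when the statement overflows the line limit, whereas ruff format keeps the condition on the `assert` line and parenthesizes the *message* instead. Below is a minimal, self-contained sketch of the two spellings; the helper function is illustrative only and not part of Paddle.

```python
# Illustrative helper; `shape` stands in for `weight.shape` from the diff.
def check_prelu_weight_rank(shape: tuple[int, ...]) -> None:
    # black-style wrapping (before this commit):
    #     assert (
    #         len(shape) == 0 or len(shape) == 1
    #     ), "The dim count of weight shape should be 0 or 1 in prelu()."
    #
    # ruff-format-style wrapping (after this commit), same runtime behavior:
    assert len(shape) == 0 or len(shape) == 1, (
        "The dim count of weight shape should be 0 or 1 in prelu()."
    )


check_prelu_weight_rank((4,))        # passes: rank-1 weight
# check_prelu_weight_rank((2, 2))    # would raise AssertionError with the message above
```

Both spellings are semantically identical; only the line wrapping differs.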

python/paddle/nn/functional/common.py

Lines changed: 24 additions & 22 deletions
@@ -601,9 +601,9 @@ def _is_list_or_tuple_(data):
  if isinstance(dim_size, (Variable, paddle.pir.Value)):
      contain_var = True
      continue
- assert (
-     dim_size > 0
- ), "Each dimension size given in out_shape must be greater than 0."
+ assert dim_size > 0, (
+     "Each dimension size given in out_shape must be greater than 0."
+ )

  if contain_var:
      new_size_tensor = []

@@ -2068,7 +2068,9 @@ def pad(
  'replicate',
  'constant',
  'circular',
- ], f"mode should be one of constant, reflect, replicate, circular, but got {mode}."
+ ], (
+     f"mode should be one of constant, reflect, replicate, circular, but got {mode}."
+ )

  x_dim = len(x.shape)
  if in_dynamic_mode():

@@ -2162,9 +2164,9 @@ def pad(
  4: ["NCHW", "NHWC"],
  5: ["NCDHW", "NDHWC"],
  }
- assert (
-     data_format in supported_format_map[x_dim]
- ), f"input tensor dimension is {x_dim}, it's data format should be in {supported_format_map[x_dim]} but got {data_format}"
+ assert data_format in supported_format_map[x_dim], (
+     f"input tensor dimension is {x_dim}, it's data format should be in {supported_format_map[x_dim]} but got {data_format}"
+ )

  unsqueezed_dim = []

@@ -2831,40 +2833,40 @@ def fold(
  )

  assert len(x.shape) == 3, "input should be the format of [N, C, L]"
- assert (
-     math.prod(x.shape) >= 0
- ), "The number of elements must greater or equal than zero."
+ assert math.prod(x.shape) >= 0, (
+     "The number of elements must greater or equal than zero."
+ )

  def _is_list_or_tuple_(data):
      return isinstance(data, (list, tuple))

  if isinstance(output_sizes, int):
      output_sizes = [output_sizes, output_sizes]
  else:
-     assert _is_list_or_tuple_(output_sizes) and (
-         len(output_sizes) == 2
-     ), "output_sizes should either be an integer or a list/tuple of two integers"
+     assert _is_list_or_tuple_(output_sizes) and (len(output_sizes) == 2), (
+         "output_sizes should either be an integer or a list/tuple of two integers"
+     )

  if isinstance(kernel_sizes, int):
      kernel_sizes = [kernel_sizes, kernel_sizes]
  else:
-     assert _is_list_or_tuple_(kernel_sizes) and (
-         len(kernel_sizes) == 2
-     ), "kernel_sizes should either be an integer or a list/tuple of two integers"
+     assert _is_list_or_tuple_(kernel_sizes) and (len(kernel_sizes) == 2), (
+         "kernel_sizes should either be an integer or a list/tuple of two integers"
+     )

  if isinstance(strides, int):
      strides = [strides, strides]
  else:
-     assert _is_list_or_tuple_(strides) and (
-         len(strides) == 2
-     ), "strides should either be an integer or a list/tuple of two integers"
+     assert _is_list_or_tuple_(strides) and (len(strides) == 2), (
+         "strides should either be an integer or a list/tuple of two integers"
+     )

  if isinstance(dilations, int):
      dilations = [dilations, dilations]
  else:
-     assert _is_list_or_tuple_(dilations) and (
-         len(dilations) == 2
-     ), "dilations should either be an integer or a list/tuple of two integers"
+     assert _is_list_or_tuple_(dilations) and (len(dilations) == 2), (
+         "dilations should either be an integer or a list/tuple of two integers"
+     )

  if isinstance(paddings, int):
      paddings = [paddings] * 4
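
The common.py hunks apply the same message-wrapping rule to asserts whose condition is a multi-line list literal (the `pad` mode check) or a compound `_is_list_or_tuple_(...) and (...)` test in `fold`. After a migration like this, the result can be checked without invoking black at all. Below is a hedged sketch, assuming `ruff` is installed locally and is reasonably close to the version pinned by the repo's pre-commit config.

```python
# Sketch only: confirm a migrated file is already in ruff's preferred style.
# Exit code 0 with an empty diff means `ruff format` would not change the file.
import subprocess

result = subprocess.run(
    ["ruff", "format", "--diff", "python/paddle/nn/functional/common.py"],
    capture_output=True,
    text=True,
)
print("clean" if result.returncode == 0 else result.stdout)
```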

python/paddle/nn/functional/conv.py

Lines changed: 6 additions & 6 deletions
@@ -272,9 +272,9 @@ def _conv_nd(
  attrs={'axis': -1},
  )
  else:
-     assert len(x_shape) > len(
-         y_shape
-     ), 'The length of pre_bias must greater than the length of bias'
+     assert len(x_shape) > len(y_shape), (
+         'The length of pre_bias must greater than the length of bias'
+     )
      padding = len(x_shape) - len(y_shape) - channel_dim
      bias = reshape(
          bias, [1] * channel_dim + y_shape + [1] * padding

@@ -1336,9 +1336,9 @@ def conv2d_transpose(
  attrs={'axis': -1},
  )
  else:
-     assert len(x_shape) > len(
-         y_shape
-     ), 'The length of pre_bias must greater than the length of bias'
+     assert len(x_shape) > len(y_shape), (
+         'The length of pre_bias must greater than the length of bias'
+     )
      padding = len(x_shape) - len(y_shape) - channel_dim
      bias = reshape(
          bias, [1] * channel_dim + y_shape + [1] * padding

python/paddle/nn/functional/flash_attention.py

Lines changed: 53 additions & 51 deletions
@@ -508,30 +508,30 @@ def flash_attention(
  fa_version = paddle.base.framework.get_flags(
      ["FLAGS_flash_attn_version"]
  )["FLAGS_flash_attn_version"]
- assert (
-     in_dynamic_or_pir_mode() or fa_version == 2
- ), "flash attention 3 only support dynamic or pir mode"
- assert (
-     dropout == 0.0 or fa_version == 2
- ), "flash attention 3 does not support dropout"
- assert (
-     not return_softmax or fa_version == 2
- ), "flash attention 3 does not support return softmax"
- assert (
-     fixed_seed_offset is None or fa_version == 2
- ), "flash attention 3 does not support return softmax"
- assert (
-     rng_name == "" or fa_version == 2
- ), "flash attention 3 does not support setting rng_name"
- assert (
-     training or fa_version == 2
- ), "flash attention 3 does not support setting training"
- assert (
-     name is None or fa_version == 2
- ), "flash attention 3 does not support setting name"
- assert (
-     softmax_scale is None or fa_version == 3
- ), "flash attention 2 does not support setting softmax_scale"
+ assert in_dynamic_or_pir_mode() or fa_version == 2, (
+     "flash attention 3 only support dynamic or pir mode"
+ )
+ assert dropout == 0.0 or fa_version == 2, (
+     "flash attention 3 does not support dropout"
+ )
+ assert not return_softmax or fa_version == 2, (
+     "flash attention 3 does not support return softmax"
+ )
+ assert fixed_seed_offset is None or fa_version == 2, (
+     "flash attention 3 does not support return softmax"
+ )
+ assert rng_name == "" or fa_version == 2, (
+     "flash attention 3 does not support setting rng_name"
+ )
+ assert training or fa_version == 2, (
+     "flash attention 3 does not support setting training"
+ )
+ assert name is None or fa_version == 2, (
+     "flash attention 3 does not support setting name"
+ )
+ assert softmax_scale is None or fa_version == 3, (
+     "flash attention 2 does not support setting softmax_scale"
+ )
  if in_dynamic_or_pir_mode():
      if fa_version == 2:
          (result_attention, result_softmax, _, _) = _C_ops.flash_attn(

@@ -1142,9 +1142,9 @@ def flash_attn_varlen_func(
  >>> output = paddle.nn.functional.flash_attention.flash_attention_v3_varlen(q, q, q, cu_seqlens_q, cu_seqlens_q, max_seqlen_q=max_seq_len_q, max_seqlen_k=max_seq_len_q, causal=True)
  >>> # doctest: -SKIP
  """
- assert (
-     "xpu" not in paddle.get_device()
- ), "flash_attn_varlen_func is not supported on xpu"
+ assert "xpu" not in paddle.get_device(), (
+     "flash_attn_varlen_func is not supported on xpu"
+ )

  assert not paddle.get_flags(["FLAGS_cudnn_deterministic"])[
      "FLAGS_cudnn_deterministic"

@@ -1157,9 +1157,9 @@ def flash_attn_varlen_func(
  == 3
  ), "FLAGS_flash_attn_version is 2, conflicts with flash_attn_varlen_func"

- assert (
-     in_dynamic_or_pir_mode()
- ), "flash_attn_varlen_func only support dynamic or pir mode"
+ assert in_dynamic_or_pir_mode(), (
+     "flash_attn_varlen_func only support dynamic or pir mode"
+ )

  assert qv is None, "flash_attn_varlen_func does not support setting qv"

@@ -2203,9 +2203,9 @@ def flashmask_attention(
  window_size = (window_size, window_size)
  sq = query.shape[1]
  bsz = query.shape[0]
- assert (
-     startend_row_indices is None
- ), "can't use window_size with startend_row_indices"
+ assert startend_row_indices is None, (
+     "can't use window_size with startend_row_indices"
+ )
  if causal:
      startend_row_indices = paddle.arange(
          window_size[0] + 1, sq + window_size[0] + 1, dtype="int32"

@@ -2246,24 +2246,26 @@ def flashmask_attention(
  )

  else:
-     assert (
-         startend_row_indices.dtype == paddle.int32
-     ), f"startend_row_indices.dtype must be paddle.int32, but got {startend_row_indices.dtype}"
-     assert (
-         len(startend_row_indices.shape) == 4
-     ), f"startend_row_indices rank must be 4,but got {startend_row_indices.shape}"
-
-     assert (
-         startend_row_indices.shape[0] == key.shape[0]
-     ), f"startend_row_indices.shape[0] must be equal to batch_size, but got {startend_row_indices.shape[0]} and {key.shape[0]}"
-
-     assert (
-         startend_row_indices.shape[2] == key.shape[1]
-     ), f"startend_row_indices.shape[2] must be equal to seqlen_k, but got {startend_row_indices.shape[2]} and {key.shape[2]}"
+     assert startend_row_indices.dtype == paddle.int32, (
+         f"startend_row_indices.dtype must be paddle.int32, but got {startend_row_indices.dtype}"
+     )
+     assert len(startend_row_indices.shape) == 4, (
+         f"startend_row_indices rank must be 4,but got {startend_row_indices.shape}"
+     )
+
+     assert startend_row_indices.shape[0] == key.shape[0], (
+         f"startend_row_indices.shape[0] must be equal to batch_size, but got {startend_row_indices.shape[0]} and {key.shape[0]}"
+     )
+
+     assert startend_row_indices.shape[2] == key.shape[1], (
+         f"startend_row_indices.shape[2] must be equal to seqlen_k, but got {startend_row_indices.shape[2]} and {key.shape[2]}"
+     )
      assert startend_row_indices.shape[1] in [
          1,
          key.shape[2],
-     ], "startend_row_indices head_num must be equal to 1(broadcast) or head_num_k."
+     ], (
+         "startend_row_indices head_num must be equal to 1(broadcast) or head_num_k."
+     )

  if causal:
      if startend_row_indices.shape[-1] == 1:

@@ -2383,9 +2385,9 @@ def calc_reduced_attention_scores(
  >>> )
  >>> # doctest: -SKIP
  """
- assert (
-     query.stop_gradient and key.stop_gradient
- ), 'calc_reduced_attention_scores() is for inference only.'
+ assert query.stop_gradient and key.stop_gradient, (
+     'calc_reduced_attention_scores() is for inference only.'
+ )

  if in_dynamic_or_pir_mode():
      reduced_scores = _C_ops.calc_reduced_attn_scores(

python/paddle/nn/functional/loss.py

Lines changed: 6 additions & 6 deletions
@@ -94,9 +94,9 @@ def dice_loss(
  """
  assert input.dtype in (paddle.float32, paddle.float64)
  assert label.dtype in (paddle.int32, paddle.int64)
- assert (
-     len(input.shape) >= 2
- ), "The rank of input should be greater than or equal to 2."
+ assert len(input.shape) >= 2, (
+     "The rank of input should be greater than or equal to 2."
+ )
  assert len(input.shape) == len(label.shape), (
      "The rank of input and label should be equal, "
      f"but received input: {len(input.shape)}, label: {len(label.shape)}."

@@ -105,9 +105,9 @@ def dice_loss(
  "The last dimension of label should be 1, "
  f"but received {label.shape[-1]}."
  )
- assert (
-     input.shape[:-1] == label.shape[:-1]
- ), "All dimensions should be equal except the last one."
+ assert input.shape[:-1] == label.shape[:-1], (
+     "All dimensions should be equal except the last one."
+ )

  label = paddle.squeeze(label, [-1])
  label = paddle.nn.functional.one_hot(label, input.shape[-1])

python/paddle/nn/functional/pooling.py

Lines changed: 3 additions & 3 deletions
@@ -704,9 +704,9 @@ def max_pool1d(


  def _unpool_output_size(x, kernel_size, stride, padding, output_size):
-     assert output_size is None or isinstance(
-         output_size, (list, tuple)
-     ), f"Required output_size is None|list|tuple, but received {output_size}"
+     assert output_size is None or isinstance(output_size, (list, tuple)), (
+         f"Required output_size is None|list|tuple, but received {output_size}"
+     )
      input_size = x.shape
      default_size = []
      for d in range(len(kernel_size)):

python/paddle/nn/initializer/bilinear.py

Lines changed: 3 additions & 1 deletion
@@ -96,7 +96,9 @@ def forward(
  """
  assert not (
      isinstance(var, framework.EagerParamBase) and var.is_dist()
- ), "Currently, Bilinear initializer not support lazy init for dist param."
+ ), (
+     "Currently, Bilinear initializer not support lazy init for dist param."
+ )
  block = self._check_block(block)

  if not isinstance(var, (framework.Variable, pir.core.ParameterMeta)):
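
The bilinear.py hunk is the one case in this diff where the condition keeps its original parentheses (it is negated with `not (...)`) and only the trailing message gains a parenthesized wrapper. One way to convince yourself that wrappings like this are purely cosmetic is to compare the ASTs of the old and new spellings. The snippet below is a standalone sketch; the condition is copied from the hunk but never executed, since it only lives inside string literals.

```python
# Sketch: the old and new spellings parse to the same abstract syntax tree,
# so a formatting-only commit like this cannot change behavior.
import ast

OLD = '''
assert not (
    isinstance(var, framework.EagerParamBase) and var.is_dist()
), "Currently, Bilinear initializer not support lazy init for dist param."
'''

NEW = '''
assert not (
    isinstance(var, framework.EagerParamBase) and var.is_dist()
), (
    "Currently, Bilinear initializer not support lazy init for dist param."
)
'''

# ast.dump() ignores parentheses and line breaks, so equal dumps mean
# identical semantics.
assert ast.dump(ast.parse(OLD)) == ast.dump(ast.parse(NEW))
print("old and new spellings are equivalent")
```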
