4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -89,7 +89,7 @@ repos:

# | python/paddle/j.+

# | python/paddle/[k-n].+
| python/paddle/[k-n].+

# | python/paddle/[o-t].+

@@ -145,7 +145,7 @@ repos:

| python/paddle/j.+

| python/paddle/[k-n].+
# | python/paddle/[k-n].+

| python/paddle/[o-t].+

24 changes: 12 additions & 12 deletions python/paddle/nn/functional/activation.py
@@ -603,9 +603,9 @@ def prelu(
[-1.25000000, 6. , 7. , -2. ],
[ 6. , 7. , 8. , 9. ]]]])
"""
assert (
len(weight.shape) == 0 or len(weight.shape) == 1
), "The dim count of weight shape should be 0 or 1 in prelu()."
assert len(weight.shape) == 0 or len(weight.shape) == 1, (
"The dim count of weight shape should be 0 or 1 in prelu()."
)

mode = 'all'
if len(weight.shape) == 1 and weight.shape[0] > 1:
@@ -626,19 +626,19 @@ def prelu(

data_format = 'NCHW' if data_format[1] == 'C' else 'NHWC'

assert (
len(x.shape) > 1
), "The dim count of x should be equal or larger than 2 in prelu() when weight shape is not [1]."
assert len(x.shape) > 1, (
"The dim count of x should be equal or larger than 2 in prelu() when weight shape is not [1]."
)

# NOTE(GuoxiaWang): support NHWC data format
if data_format == 'NHWC':
assert (
weight.shape[0] == x.shape[-1]
), "The weight size should be equal to x input channel in prelu() when weight shape is not [1]."
assert weight.shape[0] == x.shape[-1], (
"The weight size should be equal to x input channel in prelu() when weight shape is not [1]."
)
else:
assert (
weight.shape[0] == x.shape[1]
), "The weight size should be equal to x input channel in prelu() when weight shape is not [1]."
assert weight.shape[0] == x.shape[1], (
"The weight size should be equal to x input channel in prelu() when weight shape is not [1]."
)
mode = 'channel'

if in_dynamic_or_pir_mode():
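Every Python hunk in this PR applies the same purely mechanical change: the assert condition moves up onto the `assert` line and only the message string stays wrapped in parentheses, presumably to match the formatter scope adjusted in `.pre-commit-config.yaml` above. The two spellings are equivalent at runtime. A minimal, self-contained sketch of the before/after pattern, using a plain `weight_shape` list as a stand-in for the Tensor shape checked in `prelu()`:

# Stand-in for the weight tensor's shape; the real code inspects a paddle Tensor.
weight_shape = [1]

# Old formatting (the removed lines): condition wrapped in parentheses,
# message appended after the closing parenthesis.
assert (
    len(weight_shape) == 0 or len(weight_shape) == 1
), "The dim count of weight shape should be 0 or 1 in prelu()."

# New formatting (the added lines): condition on the assert line,
# only the message string is parenthesized. Equivalent at runtime.
assert len(weight_shape) == 0 or len(weight_shape) == 1, (
    "The dim count of weight shape should be 0 or 1 in prelu()."
)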
46 changes: 24 additions & 22 deletions python/paddle/nn/functional/common.py
@@ -601,9 +601,9 @@ def _is_list_or_tuple_(data):
if isinstance(dim_size, (Variable, paddle.pir.Value)):
contain_var = True
continue
assert (
dim_size > 0
), "Each dimension size given in out_shape must be greater than 0."
assert dim_size > 0, (
"Each dimension size given in out_shape must be greater than 0."
)

if contain_var:
new_size_tensor = []
@@ -2068,7 +2068,9 @@ def pad(
'replicate',
'constant',
'circular',
], f"mode should be one of constant, reflect, replicate, circular, but got {mode}."
], (
f"mode should be one of constant, reflect, replicate, circular, but got {mode}."
)

x_dim = len(x.shape)
if in_dynamic_mode():
@@ -2162,9 +2164,9 @@ def pad(
4: ["NCHW", "NHWC"],
5: ["NCDHW", "NDHWC"],
}
assert (
data_format in supported_format_map[x_dim]
), f"input tensor dimension is {x_dim}, it's data format should be in {supported_format_map[x_dim]} but got {data_format}"
assert data_format in supported_format_map[x_dim], (
f"input tensor dimension is {x_dim}, it's data format should be in {supported_format_map[x_dim]} but got {data_format}"
)

unsqueezed_dim = []

@@ -2831,40 +2833,40 @@ def fold(
)

assert len(x.shape) == 3, "input should be the format of [N, C, L]"
assert (
math.prod(x.shape) >= 0
), "The number of elements must greater or equal than zero."
assert math.prod(x.shape) >= 0, (
"The number of elements must greater or equal than zero."
)

def _is_list_or_tuple_(data):
return isinstance(data, (list, tuple))

if isinstance(output_sizes, int):
output_sizes = [output_sizes, output_sizes]
else:
assert _is_list_or_tuple_(output_sizes) and (
len(output_sizes) == 2
), "output_sizes should either be an integer or a list/tuple of two integers"
assert _is_list_or_tuple_(output_sizes) and (len(output_sizes) == 2), (
"output_sizes should either be an integer or a list/tuple of two integers"
)

if isinstance(kernel_sizes, int):
kernel_sizes = [kernel_sizes, kernel_sizes]
else:
assert _is_list_or_tuple_(kernel_sizes) and (
len(kernel_sizes) == 2
), "kernel_sizes should either be an integer or a list/tuple of two integers"
assert _is_list_or_tuple_(kernel_sizes) and (len(kernel_sizes) == 2), (
"kernel_sizes should either be an integer or a list/tuple of two integers"
)

if isinstance(strides, int):
strides = [strides, strides]
else:
assert _is_list_or_tuple_(strides) and (
len(strides) == 2
), "strides should either be an integer or a list/tuple of two integers"
assert _is_list_or_tuple_(strides) and (len(strides) == 2), (
"strides should either be an integer or a list/tuple of two integers"
)

if isinstance(dilations, int):
dilations = [dilations, dilations]
else:
assert _is_list_or_tuple_(dilations) and (
len(dilations) == 2
), "dilations should either be an integer or a list/tuple of two integers"
assert _is_list_or_tuple_(dilations) and (len(dilations) == 2), (
"dilations should either be an integer or a list/tuple of two integers"
)

if isinstance(paddings, int):
paddings = [paddings] * 4
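The reformatted asserts in `fold()` all guard one argument-normalization idiom: `output_sizes`, `kernel_sizes`, `strides`, and `dilations` each accept either a single int (expanded to a pair) or a 2-element list/tuple. A standalone sketch of that idiom follows; the helper name `_to_pair` is invented here for illustration and does not exist in the Paddle sources:

def _to_pair(value, name):
    """Expand an int to [value, value]; otherwise require a 2-element list/tuple."""
    if isinstance(value, int):
        return [value, value]
    assert isinstance(value, (list, tuple)) and len(value) == 2, (
        f"{name} should either be an integer or a list/tuple of two integers"
    )
    return list(value)


print(_to_pair(3, "kernel_sizes"))    # [3, 3]
print(_to_pair((2, 4), "strides"))    # [2, 4]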
12 changes: 6 additions & 6 deletions python/paddle/nn/functional/conv.py
@@ -272,9 +272,9 @@ def _conv_nd(
attrs={'axis': -1},
)
else:
assert len(x_shape) > len(
y_shape
), 'The length of pre_bias must greater than the length of bias'
assert len(x_shape) > len(y_shape), (
'The length of pre_bias must greater than the length of bias'
)
padding = len(x_shape) - len(y_shape) - channel_dim
bias = reshape(
bias, [1] * channel_dim + y_shape + [1] * padding
@@ -1336,9 +1336,9 @@ def conv2d_transpose(
attrs={'axis': -1},
)
else:
assert len(x_shape) > len(
y_shape
), 'The length of pre_bias must greater than the length of bias'
assert len(x_shape) > len(y_shape), (
'The length of pre_bias must greater than the length of bias'
)
padding = len(x_shape) - len(y_shape) - channel_dim
bias = reshape(
bias, [1] * channel_dim + y_shape + [1] * padding
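The assert touched in `_conv_nd()` and `conv2d_transpose()` guards the bias-broadcast reshape visible in the surrounding context lines. A small worked example of that arithmetic, with illustrative shapes that are not taken from the diff:

# NCHW conv2d: pre_bias has rank 4, bias has one value per output channel.
x_shape = [8, 16, 32, 32]   # N, C, H, W  (shape of pre_bias)
y_shape = [16]              # shape of bias
channel_dim = 1             # channel axis for NCHW

# Same arithmetic as the code below the assert:
padding = len(x_shape) - len(y_shape) - channel_dim        # 4 - 1 - 1 = 2
new_bias_shape = [1] * channel_dim + y_shape + [1] * padding

print(new_bias_shape)  # [1, 16, 1, 1], broadcastable against [8, 16, 32, 32]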
104 changes: 53 additions & 51 deletions python/paddle/nn/functional/flash_attention.py
@@ -508,30 +508,30 @@ def flash_attention(
fa_version = paddle.base.framework.get_flags(
["FLAGS_flash_attn_version"]
)["FLAGS_flash_attn_version"]
assert (
in_dynamic_or_pir_mode() or fa_version == 2
), "flash attention 3 only support dynamic or pir mode"
assert (
dropout == 0.0 or fa_version == 2
), "flash attention 3 does not support dropout"
assert (
not return_softmax or fa_version == 2
), "flash attention 3 does not support return softmax"
assert (
fixed_seed_offset is None or fa_version == 2
), "flash attention 3 does not support return softmax"
assert (
rng_name == "" or fa_version == 2
), "flash attention 3 does not support setting rng_name"
assert (
training or fa_version == 2
), "flash attention 3 does not support setting training"
assert (
name is None or fa_version == 2
), "flash attention 3 does not support setting name"
assert (
softmax_scale is None or fa_version == 3
), "flash attention 2 does not support setting softmax_scale"
assert in_dynamic_or_pir_mode() or fa_version == 2, (
"flash attention 3 only support dynamic or pir mode"
)
assert dropout == 0.0 or fa_version == 2, (
"flash attention 3 does not support dropout"
)
assert not return_softmax or fa_version == 2, (
"flash attention 3 does not support return softmax"
)
assert fixed_seed_offset is None or fa_version == 2, (
"flash attention 3 does not support return softmax"
)
assert rng_name == "" or fa_version == 2, (
"flash attention 3 does not support setting rng_name"
)
assert training or fa_version == 2, (
"flash attention 3 does not support setting training"
)
assert name is None or fa_version == 2, (
"flash attention 3 does not support setting name"
)
assert softmax_scale is None or fa_version == 3, (
"flash attention 2 does not support setting softmax_scale"
)
if in_dynamic_or_pir_mode():
if fa_version == 2:
(result_attention, result_softmax, _, _) = _C_ops.flash_attn(
@@ -1142,9 +1142,9 @@ def flash_attn_varlen_func(
>>> output = paddle.nn.functional.flash_attention.flash_attention_v3_varlen(q, q, q, cu_seqlens_q, cu_seqlens_q, max_seqlen_q=max_seq_len_q, max_seqlen_k=max_seq_len_q, causal=True)
>>> # doctest: -SKIP
"""
assert (
"xpu" not in paddle.get_device()
), "flash_attn_varlen_func is not supported on xpu"
assert "xpu" not in paddle.get_device(), (
"flash_attn_varlen_func is not supported on xpu"
)

assert not paddle.get_flags(["FLAGS_cudnn_deterministic"])[
"FLAGS_cudnn_deterministic"
@@ -1157,9 +1157,9 @@ def flash_attn_varlen_func(
== 3
), "FLAGS_flash_attn_version is 2, conflicts with flash_attn_varlen_func"

assert (
in_dynamic_or_pir_mode()
), "flash_attn_varlen_func only support dynamic or pir mode"
assert in_dynamic_or_pir_mode(), (
"flash_attn_varlen_func only support dynamic or pir mode"
)

assert qv is None, "flash_attn_varlen_func does not support setting qv"

@@ -2203,9 +2203,9 @@ def flashmask_attention(
window_size = (window_size, window_size)
sq = query.shape[1]
bsz = query.shape[0]
assert (
startend_row_indices is None
), "can't use window_size with startend_row_indices"
assert startend_row_indices is None, (
"can't use window_size with startend_row_indices"
)
if causal:
startend_row_indices = paddle.arange(
window_size[0] + 1, sq + window_size[0] + 1, dtype="int32"
@@ -2246,24 +2246,26 @@ def flashmask_attention(
)

else:
assert (
startend_row_indices.dtype == paddle.int32
), f"startend_row_indices.dtype must be paddle.int32, but got {startend_row_indices.dtype}"
assert (
len(startend_row_indices.shape) == 4
), f"startend_row_indices rank must be 4,but got {startend_row_indices.shape}"

assert (
startend_row_indices.shape[0] == key.shape[0]
), f"startend_row_indices.shape[0] must be equal to batch_size, but got {startend_row_indices.shape[0]} and {key.shape[0]}"

assert (
startend_row_indices.shape[2] == key.shape[1]
), f"startend_row_indices.shape[2] must be equal to seqlen_k, but got {startend_row_indices.shape[2]} and {key.shape[2]}"
assert startend_row_indices.dtype == paddle.int32, (
f"startend_row_indices.dtype must be paddle.int32, but got {startend_row_indices.dtype}"
)
assert len(startend_row_indices.shape) == 4, (
f"startend_row_indices rank must be 4,but got {startend_row_indices.shape}"
)

assert startend_row_indices.shape[0] == key.shape[0], (
f"startend_row_indices.shape[0] must be equal to batch_size, but got {startend_row_indices.shape[0]} and {key.shape[0]}"
)

assert startend_row_indices.shape[2] == key.shape[1], (
f"startend_row_indices.shape[2] must be equal to seqlen_k, but got {startend_row_indices.shape[2]} and {key.shape[2]}"
)
assert startend_row_indices.shape[1] in [
1,
key.shape[2],
], "startend_row_indices head_num must be equal to 1(broadcast) or head_num_k."
], (
"startend_row_indices head_num must be equal to 1(broadcast) or head_num_k."
)

if causal:
if startend_row_indices.shape[-1] == 1:
@@ -2383,9 +2385,9 @@ def calc_reduced_attention_scores(
>>> )
>>> # doctest: -SKIP
"""
assert (
query.stop_gradient and key.stop_gradient
), 'calc_reduced_attention_scores() is for inference only.'
assert query.stop_gradient and key.stop_gradient, (
'calc_reduced_attention_scores() is for inference only.'
)

if in_dynamic_or_pir_mode():
reduced_scores = _C_ops.calc_reduced_attn_scores(
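The block of asserts at the top of `flash_attention()` follows one gating pattern: each restriction is written as `<allowed> or fa_version == 2`, where `fa_version` is read from the `FLAGS_flash_attn_version` flag, so the check only takes effect when flash attention 3 is selected. A minimal sketch of that pattern with made-up stand-in values:

# Made-up stand-ins for the flag lookup and call arguments in flash_attention().
fa_version = 3
dropout = 0.0
return_softmax = False

# When fa_version == 2 the right-hand side of the `or` makes every check pass,
# so these restrictions only apply to flash attention 3.
assert dropout == 0.0 or fa_version == 2, (
    "flash attention 3 does not support dropout"
)
assert not return_softmax or fa_version == 2, (
    "flash attention 3 does not support return softmax"
)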
12 changes: 6 additions & 6 deletions python/paddle/nn/functional/loss.py
@@ -94,9 +94,9 @@ def dice_loss(
"""
assert input.dtype in (paddle.float32, paddle.float64)
assert label.dtype in (paddle.int32, paddle.int64)
assert (
len(input.shape) >= 2
), "The rank of input should be greater than or equal to 2."
assert len(input.shape) >= 2, (
"The rank of input should be greater than or equal to 2."
)
assert len(input.shape) == len(label.shape), (
"The rank of input and label should be equal, "
f"but received input: {len(input.shape)}, label: {len(label.shape)}."
@@ -105,9 +105,9 @@ def dice_loss(
"The last dimension of label should be 1, "
f"but received {label.shape[-1]}."
)
assert (
input.shape[:-1] == label.shape[:-1]
), "All dimensions should be equal except the last one."
assert input.shape[:-1] == label.shape[:-1], (
"All dimensions should be equal except the last one."
)

label = paddle.squeeze(label, [-1])
label = paddle.nn.functional.one_hot(label, input.shape[-1])
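The reformatted checks in `dice_loss()` spell out its shape contract: `input` has rank at least 2, `label` has the same rank with a trailing dimension of 1, and all leading dimensions match. A short usage sketch that satisfies those checks, assuming paddle is installed and with purely illustrative shapes:

import paddle
import paddle.nn.functional as F

# input: [N, num_classes] float, label: [N, 1] int; same rank, same leading dims.
x = paddle.rand([4, 3], dtype="float32")
label = paddle.randint(0, 3, shape=[4, 1], dtype="int64")

loss = F.dice_loss(x, label)
print(loss)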
6 changes: 3 additions & 3 deletions python/paddle/nn/functional/pooling.py
@@ -704,9 +704,9 @@ def max_pool1d(


def _unpool_output_size(x, kernel_size, stride, padding, output_size):
assert output_size is None or isinstance(
output_size, (list, tuple)
), f"Required output_size is None|list|tuple, but received {output_size}"
assert output_size is None or isinstance(output_size, (list, tuple)), (
f"Required output_size is None|list|tuple, but received {output_size}"
)
input_size = x.shape
default_size = []
for d in range(len(kernel_size)):
4 changes: 3 additions & 1 deletion python/paddle/nn/initializer/bilinear.py
@@ -96,7 +96,9 @@ def forward(
"""
assert not (
isinstance(var, framework.EagerParamBase) and var.is_dist()
), "Currently, Bilinear initializer not support lazy init for dist param."
), (
"Currently, Bilinear initializer not support lazy init for dist param."
)
block = self._check_block(block)

if not isinstance(var, (framework.Variable, pir.core.ParameterMeta)):