
Commit 7cc1980

[CodeStyle] black -> ruff format migration - part 26
1 parent a5d987e


78 files changed: 1,240 additions, 1,210 deletions

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
@@ -77,7 +77,7 @@ repos:
         # | python/paddle/de.+
-        # | python/paddle/distributed/a.+
+        | python/paddle/distributed/a.+
         # | python/paddle/distributed/[b-e].+
@@ -133,7 +133,7 @@ repos:
         | python/paddle/de.+
-        | python/paddle/distributed/a.+
+        # | python/paddle/distributed/a.+
         | python/paddle/distributed/[b-e].+
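
Taken together, the two hunks flip the `| python/paddle/distributed/a.+` branch of the formatter file patterns: it is uncommented in one hook's regex and commented out in the other's, which is how this migration series moves one slice of the tree from black to ruff format at a time. The visible effect in the Python diffs below is a different wrapping of long assert statements. A minimal before/after sketch (illustrative names, not from the codebase):

    node_count = 1

    # black-era wrapping: the condition is parenthesized so it can span lines
    assert (
        node_count == 1
    ), "only a single node is supported"

    # ruff format wrapping: the condition stays inline and the message is
    # parenthesized instead
    assert node_count == 1, (
        "only a single node is supported"
    )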

python/paddle/_paddle_docs.py

Lines changed: 0 additions & 1 deletion
@@ -74,7 +74,6 @@ def _parse_function_signature(
     if func_def.args.defaults and len(func_def.args.defaults) > (
         len(func_def.args.args) - len(func_def.args.defaults)
     ):
-
         idx = count - (
             len(func_def.args.args) - len(func_def.args.defaults)
         )
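
Besides assert rewrapping, this hunk shows the migration deleting the blank line that opened the `if` block. A hedged sketch of the two spellings (hypothetical functions; presumably the formatter's blank-line handling at the top of a block):

    def sum_before(values):
        if values:

            return sum(values)  # blank line after the colon, pre-migration
        return 0


    def sum_after(values):
        if values:
            return sum(values)  # blank line removed, post-migration
        return 0


    assert sum_before([1, 2]) == sum_after([1, 2]) == 3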

python/paddle/distributed/auto_parallel/api.py

Lines changed: 57 additions & 47 deletions
@@ -298,18 +298,18 @@ def shard_tensor(
     stop_gradient = getattr(data, "stop_gradient", True)
 
     if paddle.framework.in_pir_mode():
-        assert isinstance(
-            data, (type(None), pir.Value)
-        ), "input tensor is not pir value."
-        assert (
-            data.is_dense_tensor_type()
-        ), "shard_tensor() input data only supported dense tensor type right."
+        assert isinstance(data, (type(None), pir.Value)), (
+            "input tensor is not pir value."
+        )
+        assert data.is_dense_tensor_type(), (
+            "shard_tensor() input data only supported dense tensor type right."
+        )
         tensor = data
     else:
         if isinstance(data, EagerParamBase) and not data._is_initialized():
-            assert (
-                data._init_func is not None
-            ), "Get an uninitialized param with an unregistered init_func."
+            assert data._init_func is not None, (
+                "Get an uninitialized param with an unregistered init_func."
+            )
             tensor = data
         elif isinstance(data, paddle.Tensor) and dtype is None:
             # if place is not equal, it is handled in paddle.Tensor()
@@ -620,7 +620,9 @@ def forward(
         )
         assert check_placements_equal(
             global_placements, dist_tensor.placements
-        ), f"the global_placements ({global_placements}) is not equal to dist_tensor's placements ({dist_tensor.placements})."
+        ), (
+            f"the global_placements ({global_placements}) is not equal to dist_tensor's placements ({dist_tensor.placements})."
+        )
         local_shape = _cal_local_shape(
             dist_tensor.shape, global_mesh, global_placements
         )
@@ -890,9 +892,9 @@ def reshard(
     elif in_pir_mode():
         return paddle._C_ops.reshard(dist_tensor, mesh, placements)
     else:
-        assert isinstance(
-            dist_tensor, Variable
-        ), f"in dy2static mode, reshard's input should be Variable, but got [{dist_tensor}]"
+        assert isinstance(dist_tensor, Variable), (
+            f"in dy2static mode, reshard's input should be Variable, but got [{dist_tensor}]"
+        )
         sharding_specs = get_shard_spec(mesh, placements, dist_tensor.ndim)
         main_program = default_main_program()
         default_dist_ctx = get_default_distributed_context()
@@ -1113,12 +1115,14 @@ def is_dist_tensor(tensor) -> bool:
 
 class _ShardOptimizer(Optimizer):
     def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1):
-        assert (
-            optimizer is not None
-        ), "The argument `optimizer` cannot be empty."
+        assert optimizer is not None, (
+            "The argument `optimizer` cannot be empty."
+        )
         assert isinstance(
             optimizer, (paddle.optimizer.AdamW, paddle.optimizer.SGD)
-        ), "`paddle.distributed.ShardOptimizer` only supports AdamW and SGD optimizer for now."
+        ), (
+            "`paddle.distributed.ShardOptimizer` only supports AdamW and SGD optimizer for now."
+        )
 
         # self.target_block = (
         #     paddle.base.framework.default_main_program().global_block()
@@ -1146,7 +1150,9 @@ def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1):
         assert isinstance(
            self._shard_fn,
            (_ShardingStage0, ShardingStage1, ShardingStage2, ShardingStage3),
-        ), "shard_fn must be an instance of one of: _ShardingStage0, ShardingStage1, ShardingStage2, ShardingStage3"
+        ), (
+            "shard_fn must be an instance of one of: _ShardingStage0, ShardingStage1, ShardingStage2, ShardingStage3"
+        )
 
         if isinstance(
             self._shard_fn, (ShardingStage1, ShardingStage2, ShardingStage3)
@@ -1219,7 +1225,9 @@ def _set_and_check_sharding_prop_from_param(self):
         else:
             assert (
                 mesh.dim_size(self._sharding_axis) == self._sharding_degree
-            ), "The sharding degree of all parameters must be equal currently."
+            ), (
+                "The sharding degree of all parameters must be equal currently."
+            )
 
     def _shard_accumulator(self, param):
         # Note (luchang): Some models may have parameters whose first dimension is 1,
@@ -1988,9 +1996,9 @@ def shard_master_weight(
         )
         if isinstance(master_weight, pir.Value):
             data_op = master_weight.get_defining_op()
-            assert (
-                data_op.name() == "pd_op.data"
-            ), "The master weight must be a result of data op."
+            assert data_op.name() == "pd_op.data", (
+                "The master weight must be a result of data op."
+            )
             dim_map, partial_status = to_dim_map(
                 placements, len(master_weight.shape)
             )
@@ -3254,9 +3262,9 @@ def state_dict(
             suffix = _get_suffix(param, fused_param)
             if suffix is not None:
                 value = dist_state_dict[param]
-                assert (
-                    value.is_dist()
-                ), f"key {param} value:{value} is not a dist tensor."
+                assert value.is_dist(), (
+                    f"key {param} value:{value} is not a dist tensor."
+                )
                 mesh = value.process_mesh
                 placements = value.placements
                 if "_pow_acc" in suffix:
@@ -3328,12 +3336,12 @@ def build_distributed_tensor(local_tensor, dist_attr):
            )
            if not isinstance(local_tensor, paddle.Tensor):
                local_tensor = paddle.Tensor(local_tensor)
-            assert isinstance(
-                local_tensor, paddle.Tensor
-            ), f"local tensor:{local_tensor} type {type(local_tensor)} is not paddle.Tensor."
-            assert len(local_tensor.shape) == len(
-                dist_attr["dims_mapping"]
-            ), f"local tensor shape {local_tensor.shape} not equal to dims_mapping shape {dist_attr['dims_mapping']}."
+            assert isinstance(local_tensor, paddle.Tensor), (
+                f"local tensor:{local_tensor} type {type(local_tensor)} is not paddle.Tensor."
+            )
+            assert len(local_tensor.shape) == len(dist_attr["dims_mapping"]), (
+                f"local tensor shape {local_tensor.shape} not equal to dims_mapping shape {dist_attr['dims_mapping']}."
+            )
            global_shape = local_tensor.shape
            mesh = ProcessMesh(
                np.array(dist_attr["process_group"]).reshape(
@@ -3343,18 +3351,18 @@ def build_distributed_tensor(local_tensor, dist_attr):
            )
            placements = to_placements(dist_attr["dims_mapping"], mesh)
            dist_tensor = dtensor_from_local(local_tensor, mesh, placements)
-            assert (
-                dist_tensor._local_value().shape == local_tensor.shape
-            ), f"local tensor shape {dist_tensor._local_value().shape} not equal to local_tensor.shape:{local_tensor.shape}"
+            assert dist_tensor._local_value().shape == local_tensor.shape, (
+                f"local tensor shape {dist_tensor._local_value().shape} not equal to local_tensor.shape:{local_tensor.shape}"
+            )
            paddle.assign(local_tensor, dist_tensor._local_value())
            return dist_tensor
 
        global_state_dict = {}
        with paddle.base.dygraph.guard():
            for var_name, tensor in local_state_dict.items():
-                assert (
-                    var_name in dist_attrs
-                ), f"var {var_name} not in dist attrs:{dist_attrs}."
+                assert var_name in dist_attrs, (
+                    f"var {var_name} not in dist attrs:{dist_attrs}."
+                )
                global_state_dict[var_name] = build_distributed_tensor(
                    tensor, dist_attrs[var_name]
                )
@@ -3386,7 +3394,9 @@ def set_state_dict(self, state_dict: dict[str, Tensor]) -> None:
                     k
                 ].process_mesh or check_placements_equal(
                     v.placements, cur_v.placements
-                ), f"process_mesh:{v.process_mesh} != {cur_v.process_mesh} or placements:{v.placements} != {cur_v.placements} not match"
+                ), (
+                    f"process_mesh:{v.process_mesh} != {cur_v.process_mesh} or placements:{v.placements} != {cur_v.placements} not match"
+                )
                 param_name = (
                     self._structured_to_parameter_name[k]
                     if k in self._structured_to_parameter_name
@@ -3472,9 +3482,9 @@ def _get_shard_stage1_optimizer(self):
         ):
             optimizer = optimizer._optimizer
 
-        assert isinstance(
-            optimizer, ShardingOptimizerStage1
-        ), "The optimizer should be ShardingOptimizerStage1 when stage1 tensor fusion is enabled."
+        assert isinstance(optimizer, ShardingOptimizerStage1), (
+            "The optimizer should be ShardingOptimizerStage1 when stage1 tensor fusion is enabled."
+        )
 
         return optimizer
 
@@ -3485,9 +3495,9 @@ def _convert_state_dict_tensor_fusion(self, state_dict, optimizer_function):
             else False
         )
 
-        assert (
-            enable_tensor_fusion
-        ), "Can only convert state_dict when tensor fusion is enabled."
+        assert enable_tensor_fusion, (
+            "Can only convert state_dict when tensor fusion is enabled."
+        )
         optimizer = self._get_shard_stage1_optimizer()
         assert optimizer is not None, "The optimizer should not be None."
 
@@ -3690,9 +3700,9 @@ def to_static(
         # Deduce sharding degree for static
         # Note: Because limitation of architecture, we need to ensure that
         # all parameters are sharded by the same mesh axis
-        assert (
-            sharding_degree is not None
-        ), "Sharding degree can not be None."
+        assert sharding_degree is not None, (
+            "Sharding degree can not be None."
+        )
 
         if isinstance(shard_fn, ShardingStage1):
             strategy.sharding.enable = True
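
Not every assert collapses onto one line. When the condition itself has to stay wrapped, as in the `forward` and `_ShardOptimizer.__init__` hunks above, only the message gains parentheses. A sketch mirroring that shape with hypothetical stand-ins:

    def check_placements_equal(a, b):  # hypothetical stand-in for the real helper
        return list(a) == list(b)

    global_placements = ["Shard(0)"]
    dist_placements = ["Shard(0)"]
    assert check_placements_equal(
        global_placements, dist_placements
    ), (
        f"the global_placements ({global_placements}) is not equal to "
        f"the dist placements ({dist_placements})."
    )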

python/paddle/distributed/auto_parallel/auto_dp_utils.py

Lines changed: 3 additions & 3 deletions
@@ -21,9 +21,9 @@
 
 def _fake_replicate_grad_to_partial(grad, partial_axis):
     new_placements = grad.placements
-    assert (
-        new_placements[partial_axis] == dist.Replicate()
-    ), "when reshard fake replicated grad to partial, the partial axis of grad should be Replicate"
+    assert new_placements[partial_axis] == dist.Replicate(), (
+        "when reshard fake replicated grad to partial, the partial axis of grad should be Replicate"
+    )
 
     new_placements[partial_axis] = dist.Partial(dist.ReduceType.kRedSum)

python/paddle/distributed/auto_parallel/high_level_api.py

Lines changed: 21 additions & 19 deletions
@@ -34,9 +34,9 @@ def __init__(self):
 
 
 def cost_model(matched_programs, device_num, node_num):
     # TODO(jeff41404): multi-node will be supported later
-    assert (
-        node_num == 1
-    ), "we only support single node now, multi-node will be supported later"
+    assert node_num == 1, (
+        "we only support single node now, multi-node will be supported later"
+    )
 
     # TODO(jeff41404): will evaluate the best combination of parallel strategies
     # based on cost_model and return global_mesh, currently using pre-defined parallel strategy
@@ -224,7 +224,9 @@ def record_program_ops_post_hook(layer, inputs, outputs):
     assert (
         layer._op_recorder.start >= 0
         and layer._op_recorder.is_valid is True
-    ), f"{layer._full_name} has not recorded the start of the corresponding ops before"
+    ), (
+        f"{layer._full_name} has not recorded the start of the corresponding ops before"
+    )
     end = len(default_main_program().global_block().ops)
     # some layers, such as rotary_embedding, will not add new ops to program
     # assert end > layer._op_recorder.start, f"{layer._full_name} has not added new ops to the program"
@@ -754,19 +756,19 @@ def to_distributed(
     for pattern_name, matched_patterns in results.items():
         # process one pattern
         pattern_ops_dist_infos = get_pattern(pattern_name).ops_dist_infos
-        assert (
-            pattern_ops_dist_infos is not None
-        ), f"{pattern_name} does not contain ops_dist_infos, cannot reshard, please check"
+        assert pattern_ops_dist_infos is not None, (
+            f"{pattern_name} does not contain ops_dist_infos, cannot reshard, please check"
+        )
         processed_patterns = []
         for matched_pattern in matched_patterns:
             # convert pattern_ops_dist_infos to program_ops_dist_infos
             program_ops_dist_infos = {}
             for pattern_ops_id, op_dist_info in pattern_ops_dist_infos.items():
                 program_ops_id = []
                 for pattern_op_id in pattern_ops_id:
-                    assert (
-                        pattern_op_id in matched_pattern.keys()
-                    ), f"please check ops_dist_infos of {pattern_name}, {pattern_op_id} not in matched_pattern: {matched_pattern.keys()}"
+                    assert pattern_op_id in matched_pattern.keys(), (
+                        f"please check ops_dist_infos of {pattern_name}, {pattern_op_id} not in matched_pattern: {matched_pattern.keys()}"
+                    )
                     program_op_id = matched_pattern[pattern_op_id]
                     program_ops_id.append(program_op_id)
                 program_ops_dist_infos[tuple(program_ops_id)] = op_dist_info
@@ -789,9 +791,9 @@ def to_distributed(
     if with_mp:
         num_hidden_layers = len(matched_programs[DECODER_LAYER_NAME])
         for pattern_name, processed_patterns in matched_programs.items():
-            assert (
-                len(processed_patterns) == num_hidden_layers
-            ), "transformer patterns matched are incomplete"
+            assert len(processed_patterns) == num_hidden_layers, (
+                "transformer patterns matched are incomplete"
+            )
             for idx, processed_pattern in enumerate(processed_patterns):
                 local_mesh = mesh
                 if with_pp:
@@ -801,9 +803,9 @@ def to_distributed(
                     local_mesh = mesh.get_mesh_with_dim("pp", pp_stage_id)
 
                 for program_ops_id, dist_infos in processed_pattern.items():
-                    assert (
-                        program_ops_id in ops_id_to_layer.keys()
-                    ), f"program_ops: {program_ops_id} is not corresponding to a dynamic layer"
+                    assert program_ops_id in ops_id_to_layer.keys(), (
+                        f"program_ops: {program_ops_id} is not corresponding to a dynamic layer"
+                    )
                     dynamic_layer = ops_id_to_layer[program_ops_id]
                     mesh_num_dims = len(local_mesh.shape)
                     sharding_info = dist_infos.get_dist_info(mesh_num_dims)
@@ -832,9 +834,9 @@ def to_distributed(
 
     if decoder_layers is not None:
         num_decoder_blocks = len(decoder_layers)
-        assert (
-            num_decoder_blocks == num_hidden_layers
-        ), f"decoder pattern layers matched are incomplete, num_decoder_blocks: {num_decoder_blocks} should be equal to num_hidden_layers: {num_hidden_layers}"
+        assert num_decoder_blocks == num_hidden_layers, (
+            f"decoder pattern layers matched are incomplete, num_decoder_blocks: {num_decoder_blocks} should be equal to num_hidden_layers: {num_hidden_layers}"
+        )
 
         pp_degree = mesh.get_dim_size("pp")
         num_blocks_per_stage = num_decoder_blocks // pp_degree
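
All of the hunks above are mechanical restyling: neither the asserted condition nor the message text changes, so runtime behavior is identical, and because the comma stays outside the new parentheses this is not the always-true `assert (cond, msg)` tuple pitfall. A quick self-contained check (illustrative, not from the commit):

    def check_old(x):
        assert (
            x == 1
        ), "x must be 1"  # black-era wrapping

    def check_new(x):
        assert x == 1, (
            "x must be 1"
        )  # ruff format wrapping

    for fn in (check_old, check_new):
        try:
            fn(2)
        except AssertionError as err:
            print(type(err).__name__, err)  # both print: AssertionError x must be 1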
