@@ -298,18 +298,18 @@ def shard_tensor(
         stop_gradient = getattr(data, "stop_gradient", True)
 
     if paddle.framework.in_pir_mode():
-        assert isinstance(
-            data, (type(None), pir.Value)
-        ), "input tensor is not pir value."
-        assert (
-            data.is_dense_tensor_type()
-        ), "shard_tensor() input data only supported dense tensor type right."
+        assert isinstance(data, (type(None), pir.Value)), (
+            "input tensor is not pir value."
+        )
+        assert data.is_dense_tensor_type(), (
+            "shard_tensor() input data only supported dense tensor type right."
+        )
         tensor = data
     else:
         if isinstance(data, EagerParamBase) and not data._is_initialized():
-            assert (
-                data._init_func is not None
-            ), "Get an uninitialized param with an unregistered init_func."
+            assert data._init_func is not None, (
+                "Get an uninitialized param with an unregistered init_func."
+            )
             tensor = data
         elif isinstance(data, paddle.Tensor) and dtype is None:
             # if place is not equal, it is handled in paddle.Tensor()
@@ -620,7 +620,9 @@ def forward(
             )
             assert check_placements_equal(
                 global_placements, dist_tensor.placements
-            ), f"the global_placements ({global_placements}) is not equal to dist_tensor's placements ({dist_tensor.placements})."
+            ), (
+                f"the global_placements ({global_placements}) is not equal to dist_tensor's placements ({dist_tensor.placements})."
+            )
             local_shape = _cal_local_shape(
                 dist_tensor.shape, global_mesh, global_placements
             )
@@ -890,9 +892,9 @@ def reshard(
     elif in_pir_mode():
         return paddle._C_ops.reshard(dist_tensor, mesh, placements)
     else:
-        assert isinstance(
-            dist_tensor, Variable
-        ), f"in dy2static mode, reshard's input should be Variable, but got [{dist_tensor}]"
+        assert isinstance(dist_tensor, Variable), (
+            f"in dy2static mode, reshard's input should be Variable, but got [{dist_tensor}]"
+        )
         sharding_specs = get_shard_spec(mesh, placements, dist_tensor.ndim)
         main_program = default_main_program()
         default_dist_ctx = get_default_distributed_context()
@@ -1113,12 +1115,14 @@ def is_dist_tensor(tensor) -> bool:
 
 class _ShardOptimizer(Optimizer):
     def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1):
-        assert (
-            optimizer is not None
-        ), "The argument `optimizer` cannot be empty."
+        assert optimizer is not None, (
+            "The argument `optimizer` cannot be empty."
+        )
         assert isinstance(
             optimizer, (paddle.optimizer.AdamW, paddle.optimizer.SGD)
-        ), "`paddle.distributed.ShardOptimizer` only supports AdamW and SGD optimizer for now."
+        ), (
+            "`paddle.distributed.ShardOptimizer` only supports AdamW and SGD optimizer for now."
+        )
 
         # self.target_block = (
         #     paddle.base.framework.default_main_program().global_block()
@@ -1146,7 +1150,9 @@ def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1):
         assert isinstance(
             self._shard_fn,
             (_ShardingStage0, ShardingStage1, ShardingStage2, ShardingStage3),
-        ), "shard_fn must be an instance of one of: _ShardingStage0, ShardingStage1, ShardingStage2, ShardingStage3"
+        ), (
+            "shard_fn must be an instance of one of: _ShardingStage0, ShardingStage1, ShardingStage2, ShardingStage3"
+        )
 
         if isinstance(
             self._shard_fn, (ShardingStage1, ShardingStage2, ShardingStage3)
@@ -1219,7 +1225,9 @@ def _set_and_check_sharding_prop_from_param(self):
             else:
                 assert (
                     mesh.dim_size(self._sharding_axis) == self._sharding_degree
-                ), "The sharding degree of all parameters must be equal currently."
+                ), (
+                    "The sharding degree of all parameters must be equal currently."
+                )
 
     def _shard_accumulator(self, param):
         # Note (luchang): Some models may have parameters whose first dimension is 1,
@@ -1988,9 +1996,9 @@ def shard_master_weight(
             )
             if isinstance(master_weight, pir.Value):
                 data_op = master_weight.get_defining_op()
-                assert (
-                    data_op.name() == "pd_op.data"
-                ), "The master weight must be a result of data op."
+                assert data_op.name() == "pd_op.data", (
+                    "The master weight must be a result of data op."
+                )
                 dim_map, partial_status = to_dim_map(
                     placements, len(master_weight.shape)
                 )
@@ -3254,9 +3262,9 @@ def state_dict(
             suffix = _get_suffix(param, fused_param)
             if suffix is not None:
                 value = dist_state_dict[param]
-                assert (
-                    value.is_dist()
-                ), f"key {param} value: {value} is not a dist tensor."
+                assert value.is_dist(), (
+                    f"key {param} value: {value} is not a dist tensor."
+                )
                 mesh = value.process_mesh
                 placements = value.placements
                 if "_pow_acc" in suffix:
@@ -3328,12 +3336,12 @@ def build_distributed_tensor(local_tensor, dist_attr):
         )
         if not isinstance(local_tensor, paddle.Tensor):
             local_tensor = paddle.Tensor(local_tensor)
-        assert isinstance(
-            local_tensor, paddle.Tensor
-        ), f"local tensor: {local_tensor} type {type(local_tensor)} is not paddle.Tensor."
-        assert len(local_tensor.shape) == len(
-            dist_attr["dims_mapping"]
-        ), f"local tensor shape {local_tensor.shape} not equal to dims_mapping shape {dist_attr['dims_mapping']}."
+        assert isinstance(local_tensor, paddle.Tensor), (
+            f"local tensor: {local_tensor} type {type(local_tensor)} is not paddle.Tensor."
+        )
+        assert len(local_tensor.shape) == len(dist_attr["dims_mapping"]), (
+            f"local tensor shape {local_tensor.shape} not equal to dims_mapping shape {dist_attr['dims_mapping']}."
+        )
         global_shape = local_tensor.shape
         mesh = ProcessMesh(
             np.array(dist_attr["process_group"]).reshape(
@@ -3343,18 +3351,18 @@ def build_distributed_tensor(local_tensor, dist_attr):
         )
         placements = to_placements(dist_attr["dims_mapping"], mesh)
         dist_tensor = dtensor_from_local(local_tensor, mesh, placements)
-        assert (
-            dist_tensor._local_value().shape == local_tensor.shape
-        ), f"local tensor shape {dist_tensor._local_value().shape} not equal to local_tensor.shape: {local_tensor.shape}"
+        assert dist_tensor._local_value().shape == local_tensor.shape, (
+            f"local tensor shape {dist_tensor._local_value().shape} not equal to local_tensor.shape: {local_tensor.shape}"
+        )
         paddle.assign(local_tensor, dist_tensor._local_value())
         return dist_tensor
 
     global_state_dict = {}
     with paddle.base.dygraph.guard():
         for var_name, tensor in local_state_dict.items():
-            assert (
-                var_name in dist_attrs
-            ), f"var {var_name} not in dist attrs: {dist_attrs}."
+            assert var_name in dist_attrs, (
+                f"var {var_name} not in dist attrs: {dist_attrs}."
+            )
             global_state_dict[var_name] = build_distributed_tensor(
                 tensor, dist_attrs[var_name]
             )
@@ -3386,7 +3394,9 @@ def set_state_dict(self, state_dict: dict[str, Tensor]) -> None:
                 k
             ].process_mesh or check_placements_equal(
                 v.placements, cur_v.placements
-            ), f"process_mesh:{v.process_mesh} != {cur_v.process_mesh} or placements:{v.placements} != {cur_v.placements} not match"
+            ), (
+                f"process_mesh:{v.process_mesh} != {cur_v.process_mesh} or placements:{v.placements} != {cur_v.placements} not match"
+            )
             param_name = (
                 self._structured_to_parameter_name[k]
                 if k in self._structured_to_parameter_name
@@ -3472,9 +3482,9 @@ def _get_shard_stage1_optimizer(self):
34723482 ):
34733483 optimizer = optimizer ._optimizer
34743484
3475- assert isinstance (
3476- optimizer , ShardingOptimizerStage1
3477- ), "The optimizer should be ShardingOptimizerStage1 when stage1 tensor fusion is enabled."
3485+ assert isinstance (optimizer , ShardingOptimizerStage1 ), (
3486+ "The optimizer should be ShardingOptimizerStage1 when stage1 tensor fusion is enabled."
3487+ )
34783488
34793489 return optimizer
34803490
@@ -3485,9 +3495,9 @@ def _convert_state_dict_tensor_fusion(self, state_dict, optimizer_function):
             else False
         )
 
-        assert (
-            enable_tensor_fusion
-        ), "Can only convert state_dict when tensor fusion is enabled."
+        assert enable_tensor_fusion, (
+            "Can only convert state_dict when tensor fusion is enabled."
+        )
         optimizer = self._get_shard_stage1_optimizer()
         assert optimizer is not None, "The optimizer should not be None."
@@ -3690,9 +3700,9 @@ def to_static(
             # Deduce sharding degree for static
             # Note: Because limitation of architecture, we need to ensure that
             # all parameters are sharded by the same mesh axis
-            assert (
-                sharding_degree is not None
-            ), "Sharding degree can not be None."
+            assert sharding_degree is not None, (
+                "Sharding degree can not be None."
+            )
 
             if isinstance(shard_fn, ShardingStage1):
                 strategy.sharding.enable = True
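
Every hunk in this commit makes the same mechanical change to over-long `assert` statements: rather than parenthesizing the condition and letting the message trail the closing parenthesis, the condition is kept inline on the `assert` line and the long message is wrapped in its own parenthesized block. The conditions and message strings themselves are untouched. Below is a minimal, runnable sketch of the before/after layout; `value` is a hypothetical placeholder, not part of the Paddle code above.

# Hypothetical placeholder so both asserts pass when this sketch is run.
value = 1

# Old layout (the "-" lines in the hunks): the condition is wrapped in
# parentheses across several lines and the message trails the closing paren.
assert (
    value is not None
), "The argument `value` cannot be empty."

# New layout (the "+" lines in the hunks): the condition stays inline and the
# message gets its own parenthesized block, so long messages wrap cleanly.
assert value is not None, (
    "The argument `value` cannot be empty."
)

Keeping the condition on the `assert` line makes the checked expression visible at a glance, and the parenthesized-message form matches the layout newer Python auto-formatters produce for long asserts, which is presumably what motivated this sweep.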