@@ -94,14 +94,14 @@ def append_backward_new(
     from paddle.incubate.autograd.primx import Transform, orig2prim
 
     program = default_main_program()
-    assert (
-        program.num_blocks == 1
-    ), "The append_backward_new interface is designed to process only one block."
+    assert program.num_blocks == 1, (
+        "The append_backward_new interface is designed to process only one block."
+    )
     block = program.current_block()
     for el in loss_list:
-        assert (
-            el.block == block
-        ), 'variable in loss_list should be in current block of main program'
+        assert el.block == block, (
+            'variable in loss_list should be in current block of main program'
+        )
 
     orig2prim(block)
     ad = Transform(block)
@@ -280,9 +280,9 @@ def __init__(
         if self._parameter_list:
             if isinstance(self._parameter_list[0], dict):
                 for param_group in self._parameter_list:
-                    assert (
-                        'params' in param_group
-                    ), 'params should be set in parameters if parameter groups are optimized in different options'
+                    assert 'params' in param_group, (
+                        'params should be set in parameters if parameter groups are optimized in different options'
+                    )
                 self._dtype = self._parameter_list[0]['params'][0].dtype
             else:
                 self._dtype = self._parameter_list[0].dtype
@@ -477,9 +477,9 @@ def set_state_dict(self, state_dict: dict[str, Tensor]) -> None:
         if isinstance(self._learning_rate, LRScheduler):
             lr_state_dict = state_dict.get("LR_Scheduler", None)
             if not isinstance(self._learning_rate, LambdaDecay):
-                assert (
-                    lr_state_dict is not None
-                ), "LR_Scheduler state must be included in the state dict except LambdaDecay"
+                assert lr_state_dict is not None, (
+                    "LR_Scheduler state must be included in the state dict except LambdaDecay"
+                )
             if lr_state_dict:
                 self._learning_rate.set_state_dict(lr_state_dict)
 
@@ -495,9 +495,9 @@ def set_state_dict(self, state_dict: dict[str, Tensor]) -> None:
         self._accumulators_holder = state_dict
         for k, v in self._accumulators.items():
             for para_name, var_tmp in v.items():
-                assert (
-                    var_tmp.name in state_dict
-                ), f"optimizer Tensor {var_tmp.name} not found"
+                assert var_tmp.name in state_dict, (
+                    f"optimizer Tensor {var_tmp.name} not found"
+                )
 
                 var = var_tmp.value()
                 tensor = var.get_tensor()
@@ -1112,9 +1112,9 @@ def _add_accumulator(
 
         if framework.in_dygraph_mode():
             if len(self._accumulators_holder) > 0:
-                assert (
-                    var_name in self._accumulators_holder
-                ), f"Optimizer set error, {var_name} should in state dict"
+                assert var_name in self._accumulators_holder, (
+                    f"Optimizer set error, {var_name} should in state dict"
+                )
                 var.set_value(self._accumulators_holder.pop(var_name))
 
                 # load scale value for xpu
@@ -1231,9 +1231,9 @@ def _create_optimization_pass(
         target_block = global_block
         current_block = framework.default_main_program().current_block()
         if current_block.idx != global_block.idx:
-            assert (
-                current_block.backward_block_idx != -1
-            ), "current block is not global_block, but it doesn't have backward block."
+            assert current_block.backward_block_idx != -1, (
+                "current block is not global_block, but it doesn't have backward block."
+            )
             target_block = framework.default_main_program().blocks[
                 current_block.backward_block_idx
             ]
@@ -1669,9 +1669,7 @@ def _apply_optimize(
                 paddle.static.default_main_program(),
                 paddle.static.default_startup_program(),
             ):
-                auto_dp = (
-                    paddle.distributed.auto_parallel.auto_dp_utils.in_auto_dp_mode()
-                )
+                auto_dp = paddle.distributed.auto_parallel.auto_dp_utils.in_auto_dp_mode()
                 if auto_dp:
                     paddle.distributed.auto_parallel.auto_dp_utils._convert_fake_replicate_grad_to_partial(
                         params_grads
@@ -1943,9 +1941,9 @@ def minimize(
                 >>> adam.clear_grad()
 
         """
-        assert isinstance(
-            loss, (Variable, paddle.pir.Value)
-        ), "The loss should be an Tensor."
+        assert isinstance(loss, (Variable, paddle.pir.Value)), (
+            "The loss should be an Tensor."
+        )
 
         parameter_list = parameters if parameters else self._parameter_list
 
@@ -1969,9 +1967,9 @@ def _declarative_step(self):
         params = (
             paddle.static.default_main_program().global_block().all_parameters()
         )
-        assert not isinstance(
-            self._parameter_list[0], dict
-        ), "Only list of parameters is supported while using optimizer in @paddle.jit.static."
+        assert not isinstance(self._parameter_list[0], dict), (
+            "Only list of parameters is supported while using optimizer in @paddle.jit.static."
+        )
         selected_params = {param.name for param in self._parameter_list}
         parameters = [param for param in params if param.trainable]
         parameters = list(
@@ -2141,9 +2139,9 @@ def _is_dtype_fp16_or_bf16(self, dtype):
         :param dtype: instance of core.VarDesc.VarType
         :return: True if dtype is one of fp16 or bf16, False otherwise
         """
-        assert isinstance(
-            dtype, (core.VarDesc.VarType, core.DataType)
-        ), "The dtype should be an instance of core.VarDesc.VarType or core.DataType."
+        assert isinstance(dtype, (core.VarDesc.VarType, core.DataType)), (
+            "The dtype should be an instance of core.VarDesc.VarType or core.DataType."
+        )
         if isinstance(dtype, core.VarDesc.VarType):
             return (
                 dtype == core.VarDesc.VarType.FP16
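Every hunk above applies the same assert layout: the condition stays on the `assert` line and only the message is wrapped in parentheses. A minimal, self-contained sketch of the two shapes is shown below; `value` and the message text are illustrative placeholders, not names from the Paddle sources.

```python
def check_positive(value: float) -> None:
    # Old shape: the condition is wrapped in parentheses and the message
    # trails the closing parenthesis.
    # assert (
    #     value > 0
    # ), "value must be positive"

    # New shape: the condition stays on the `assert` line and only the
    # message is parenthesized.
    assert value > 0, (
        "value must be positive"
    )


if __name__ == "__main__":
    check_positive(1.0)  # passes; check_positive(-1.0) would raise AssertionError
```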