diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1f1db341c82de9..e63b239d430782 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -87,7 +87,7 @@ repos: # | python/paddle/[e-i].+ - # | python/paddle/j.+ + | python/paddle/j.+ # | python/paddle/[k-n].+ @@ -143,7 +143,7 @@ repos: | python/paddle/[e-i].+ - | python/paddle/j.+ + # | python/paddle/j.+ | python/paddle/[k-n].+ diff --git a/python/paddle/jit/dy2static/convert_operators.py b/python/paddle/jit/dy2static/convert_operators.py index 14c9998ae0d5dc..ed2fac98614836 100644 --- a/python/paddle/jit/dy2static/convert_operators.py +++ b/python/paddle/jit/dy2static/convert_operators.py @@ -756,13 +756,17 @@ def convert_var_dtype(var, dtype): 'int32', 'int64', 'uint8', - ], f"The dtype of var {var.name} is {src_dtype}, which is not supported in the cast op." + ], ( + f"The dtype of var {var.name} is {src_dtype}, which is not supported in the cast op." + ) assert dtype in [ 'bool', 'int', 'float', 'complex', - ], f"The casted target dtype is {dtype}, which is not supported in type casting." + ], ( + f"The casted target dtype is {dtype}, which is not supported in type casting." + ) cast_map = { 'bool': 'bool', 'int': 'int32', @@ -776,7 +780,9 @@ def convert_var_dtype(var, dtype): 'int', 'float', 'complex', - ], f"The casted target dtype is {dtype}, which is not supported in type casting." + ], ( + f"The casted target dtype is {dtype}, which is not supported in type casting." + ) return eval(dtype)(var) diff --git a/python/paddle/jit/dy2static/origin_info.py b/python/paddle/jit/dy2static/origin_info.py index ab125265c26460..58c6a5c6c3375e 100644 --- a/python/paddle/jit/dy2static/origin_info.py +++ b/python/paddle/jit/dy2static/origin_info.py @@ -155,9 +155,9 @@ def create_and_update_origin_info_map( static_node = attach_origin_info(static_node, static_func) for t_node, s_node in ast_walk(transformed_node, static_node): - assert type(t_node) == type( - s_node - ), f"The node types should be the same, but received type(t_node) is {type(t_node)}, and type(s_node) is {type(s_node)}." + assert type(t_node) == type(s_node), ( + f"The node types should be the same, but received type(t_node) is {type(t_node)}, and type(s_node) is {type(s_node)}." + ) dygraph_info = getattr(t_node, ORIGIN_INFO, None) static_info = getattr(s_node, ORIGIN_INFO, None) @@ -232,9 +232,9 @@ def _as_list(x): ): continue - assert type(t_node) == type( - s_node - ), f"The node types should be the same, but received type(t_node) is {type(t_node)}, and type(s_node) is {type(s_node)}." + assert type(t_node) == type(s_node), ( + f"The node types should be the same, but received type(t_node) is {type(t_node)}, and type(s_node) is {type(s_node)}." + ) yield t_node, s_node diff --git a/python/paddle/jit/dy2static/pir_partial_program.py b/python/paddle/jit/dy2static/pir_partial_program.py index 3e4b6f0dcfb1d1..0beb55f568e8b8 100644 --- a/python/paddle/jit/dy2static/pir_partial_program.py +++ b/python/paddle/jit/dy2static/pir_partial_program.py @@ -218,15 +218,15 @@ def __init__( forward_range=None, backward_range=None, ): - assert isinstance( - in_out_values, tuple - ), "in_out_values must be tuple with len == 3" - assert ( - len(in_out_values) == 3 - ), "in_out_values must be tuple with len == 3" - assert isinstance( - in_out_values[0], list - ), "in_out_values must be tuple with len == 3" + assert isinstance(in_out_values, tuple), ( + "in_out_values must be tuple with len == 3" + ) + assert len(in_out_values) == 3, ( + "in_out_values must be tuple with len == 3" + ) + assert isinstance(in_out_values[0], list), ( + "in_out_values must be tuple with len == 3" + ) self.program = program self.x_names = self.convert_name(in_out_values[0]) self.param_names = self.convert_name(in_out_values[1]) @@ -310,9 +310,9 @@ def clone(self): ) def split_forward_backward(self): - assert ( - self.has_splited is False - ), "Please ensure only split once! don't call split_forward_backward manually." + assert self.has_splited is False, ( + "Please ensure only split once! don't call split_forward_backward manually." + ) self.has_splited = True self.update_op_range() ( @@ -406,9 +406,9 @@ def _forward_backward_program(self): @cached_property # shouldn't changed when call this once. def program_attr(self): - assert ( - self.finish_pass is False - ), "program_attr() is called by PartialProgramLayer, don't call it manually, use program_name_attr instead." + assert self.finish_pass is False, ( + "program_attr() is called by PartialProgramLayer, don't call it manually, use program_name_attr instead." + ) # can't apply pass after call this function. self.finish_pass = True fwd_map = RunnableProgram._get_name_value_map_from_program( @@ -445,9 +445,9 @@ def program_attr(self): program_attr[f"{k}_names"] = ns # Restore stop_gradient for output values - assert len(program_attr["fo_values"]) == len( - self.out_stop_gradients - ), "Output values and stop gradients length mismatch" + assert len(program_attr["fo_values"]) == len(self.out_stop_gradients), ( + "Output values and stop gradients length mismatch" + ) for v, stop_gradient in zip( program_attr["fo_values"], self.out_stop_gradients ): @@ -474,9 +474,9 @@ def unify_value_names( # Get all values again because some values has been erased. for value in RunnableProgram._get_program_all_values(program): if value.has_name: - assert ( - value._has_only_one_name() - ), f"Expected all values in Program have only one name, but {value} has multiple names: {value._names}" + assert value._has_only_one_name(), ( + f"Expected all values in Program have only one name, but {value} has multiple names: {value._names}" + ) return rename_mapping @staticmethod diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py index a4d7b16abd682f..1cc24931c44cea 100644 --- a/python/paddle/jit/dy2static/program_translator.py +++ b/python/paddle/jit/dy2static/program_translator.py @@ -672,9 +672,9 @@ def rollback(self) -> Callable[_InputT, _RetT]: if self._patched_name is not None else self._dygraph_function.__name__ ) - assert ( - fn_name in self.class_instance._original_funcs - ), f"Not Found function '{fn_name}' in class '{self.class_instance.__class__}'." + assert fn_name in self.class_instance._original_funcs, ( + f"Not Found function '{fn_name}' in class '{self.class_instance.__class__}'." + ) func = self.class_instance._original_funcs[fn_name] setattr(self.class_instance, fn_name, func.__get__(self.class_instance)) return getattr(self.class_instance, fn_name) @@ -1733,9 +1733,9 @@ def get_program(self, item): return self._caches[item_id] def last(self): - assert ( - len(self._caches) >= 1 - ), "No valid cached program in ProgramCache." + assert len(self._caches) >= 1, ( + "No valid cached program in ProgramCache." + ) assert self._recent_key is not None return self._recent_key, self._caches[self._recent_key] diff --git a/python/paddle/jit/dy2static/transformers/base.py b/python/paddle/jit/dy2static/transformers/base.py index f4fe487aa8a88a..6e640972a07645 100644 --- a/python/paddle/jit/dy2static/transformers/base.py +++ b/python/paddle/jit/dy2static/transformers/base.py @@ -184,9 +184,9 @@ class ForNodeVisitor: """ def __init__(self, for_node): - assert isinstance( - for_node, gast.For - ), "Input node for the initialization of ForNodeVisitor is not gast.For node." + assert isinstance(for_node, gast.For), ( + "Input node for the initialization of ForNodeVisitor is not gast.For node." + ) # 1. original for node self.node = for_node @@ -276,14 +276,14 @@ def is_for_enumerate_iter(self): def _args_check(self): if self.is_for_range_iter(): self.args_length = len(self.iter_args) - assert ( - self.args_length >= 1 and self.args_length <= 3 - ), "range() function takes 1 to 3 arguments" + assert self.args_length >= 1 and self.args_length <= 3, ( + "range() function takes 1 to 3 arguments" + ) elif self.is_for_enumerate_iter(): self.args_length = len(self.iter_args) - assert ( - self.args_length >= 1 and self.args_length <= 2 - ), "enumerate() function takes 1 to 2 arguments" + assert self.args_length >= 1 and self.args_length <= 2, ( + "enumerate() function takes 1 to 2 arguments" + ) else: self.args_length = None diff --git a/python/paddle/jit/dy2static/transformers/break_continue_transformer.py b/python/paddle/jit/dy2static/transformers/break_continue_transformer.py index 582e737aa53b30..b9c877da1a8995 100644 --- a/python/paddle/jit/dy2static/transformers/break_continue_transformer.py +++ b/python/paddle/jit/dy2static/transformers/break_continue_transformer.py @@ -31,9 +31,9 @@ class ForToWhileTransformer(BaseTransformer): """ def __init__(self, parent_node, loop_node, condition_node): - assert isinstance( - loop_node, gast.For - ), "loop_node is not gast.For in ForToWhileTransformer" + assert isinstance(loop_node, gast.For), ( + "loop_node is not gast.For in ForToWhileTransformer" + ) self.parent_node = parent_node self.loop_node = loop_node self.condition_node = condition_node @@ -60,9 +60,9 @@ def transform(self): ) def get_for_stmt_nodes(self, node): - assert isinstance( - node, gast.For - ), "Input node is NOT gast.For in get_for_stmt_nodes" + assert isinstance(node, gast.For), ( + "Input node is NOT gast.For in get_for_stmt_nodes" + ) # 1. parse current gast.For node current_for_node_parser = ForNodeVisitor(node) diff --git a/python/paddle/jit/dy2static/transformers/early_return_transformer.py b/python/paddle/jit/dy2static/transformers/early_return_transformer.py index ce8cf9e606878a..d438fe41d1f9bf 100644 --- a/python/paddle/jit/dy2static/transformers/early_return_transformer.py +++ b/python/paddle/jit/dy2static/transformers/early_return_transformer.py @@ -34,9 +34,9 @@ def transform(self): self.visit(self.root) def is_define_return_in_if(self, node): - assert isinstance( - node, gast.If - ), f"Type of input node should be gast.If, but received {type(node)}." + assert isinstance(node, gast.If), ( + f"Type of input node should be gast.If, but received {type(node)}." + ) for child in node.body: if isinstance(child, gast.Return): return True diff --git a/python/paddle/jit/dy2static/transformers/logical_transformer.py b/python/paddle/jit/dy2static/transformers/logical_transformer.py index 1f7cc50db6e6a3..0a49289c9af3f1 100644 --- a/python/paddle/jit/dy2static/transformers/logical_transformer.py +++ b/python/paddle/jit/dy2static/transformers/logical_transformer.py @@ -83,9 +83,9 @@ def _create_bool_op_node(self, nodes, api_type): according to the actual order. In `convert_logical_and(lambda:x>1, lambda:y<1)`, `lambda:y<1` must be run after `lambda:x>1`, If `x>1` is False, `y<1` should NOT be run. ''' - assert ( - len(nodes) > 1 - ), f"The length of BoolOp should be at least 2, but received {len(nodes)}." + assert len(nodes) > 1, ( + f"The length of BoolOp should be at least 2, but received {len(nodes)}." + ) if len(nodes) > 2: # Creates logic_and/logic_or node recursively. pre_logic_node = self._create_bool_op_node(nodes[:2], api_type) diff --git a/python/paddle/jit/dy2static/transformers/loop_transformer.py b/python/paddle/jit/dy2static/transformers/loop_transformer.py index 4f1f9161f0e358..175d199b5ce3fb 100644 --- a/python/paddle/jit/dy2static/transformers/loop_transformer.py +++ b/python/paddle/jit/dy2static/transformers/loop_transformer.py @@ -134,9 +134,9 @@ def __init__(self, root_node): self.visit(root_node) def get_loop_var_names(self, node): - assert isinstance( - node, (gast.While, gast.For) - ), "Input node is not gast loop node" + assert isinstance(node, (gast.While, gast.For)), ( + "Input node is not gast loop node" + ) loop_var_names = set() create_var_names = set() read_context = {type(gast.Load()), type(gast.AugLoad())} diff --git a/python/paddle/jit/dy2static/transformers/name_load_transformer.py b/python/paddle/jit/dy2static/transformers/name_load_transformer.py index 717b1da41ba60e..75f8f4d96c79a2 100644 --- a/python/paddle/jit/dy2static/transformers/name_load_transformer.py +++ b/python/paddle/jit/dy2static/transformers/name_load_transformer.py @@ -98,9 +98,9 @@ class AttributeJstTransformer(BaseTransformer): """ def __init__(self, node): - assert isinstance( - node, gast.AST - ), "Input non-gast.AST node for the initialization of ToTensorTransformer." + assert isinstance(node, gast.AST), ( + "Input non-gast.AST node for the initialization of ToTensorTransformer." + ) self.interested_name = { 'size', } diff --git a/python/paddle/jit/dy2static/transformers/return_transformer.py b/python/paddle/jit/dy2static/transformers/return_transformer.py index 7afbb8c1725b3a..2902c1df196e0f 100644 --- a/python/paddle/jit/dy2static/transformers/return_transformer.py +++ b/python/paddle/jit/dy2static/transformers/return_transformer.py @@ -85,9 +85,9 @@ class ReturnAnalysisVisitor(gast.NodeVisitor): def __init__(self, root_node): self.root = root_node - assert isinstance( - self.root, gast.FunctionDef - ), "Input is not gast.FunctionDef node" + assert isinstance(self.root, gast.FunctionDef), ( + "Input is not gast.FunctionDef node" + ) # the number of return statements self.count_return = 0 @@ -151,9 +151,9 @@ class SingleReturnTransformer(BaseTransformer): def __init__(self, root): self.root = root - assert isinstance( - self.root, gast.FunctionDef - ), "Input is not gast.FunctionDef node" + assert isinstance(self.root, gast.FunctionDef), ( + "Input is not gast.FunctionDef node" + ) self.ancestor_nodes = [] diff --git a/python/paddle/jit/dy2static/transformers/utils.py b/python/paddle/jit/dy2static/transformers/utils.py index f630f0deea5dc7..ff3dbc824e8406 100644 --- a/python/paddle/jit/dy2static/transformers/utils.py +++ b/python/paddle/jit/dy2static/transformers/utils.py @@ -268,16 +268,16 @@ def create_node_for_name(name): def get_attribute_full_name(node): - assert isinstance( - node, gast.Attribute - ), "Input non-Attribute node to get attribute full name" + assert isinstance(node, gast.Attribute), ( + "Input non-Attribute node to get attribute full name" + ) return ast_to_source_code(node).strip() def is_api_in_module(node, module_prefix): - assert isinstance( - node, gast.Call - ), "Input non-Call node for is_api_in_module" + assert isinstance(node, gast.Call), ( + "Input non-Call node for is_api_in_module" + ) # Python can have gast.Call as function, for example: convert_call(func)(x) # We only check the most outside function @@ -385,9 +385,9 @@ def is_global_var(self, name): it means global vars; otherwise, it means local vars. Only valid after FunctionNameLivenessAnalysis visitor. """ - assert self._is_simple_name( - name - ), "is_global_var accept a simple name, but get `{name}`." + assert self._is_simple_name(name), ( + "is_global_var accept a simple name, but get `{name}`." + ) ancestor = self while ancestor is not None: if name in ancestor.globals: @@ -612,9 +612,9 @@ def _get_argument_names(self, node): this node is local to the function and shouldn't be created. """ - assert isinstance( - node, gast.FunctionDef - ), "Input node is not function define node" + assert isinstance(node, gast.FunctionDef), ( + "Input node is not function define node" + ) names = list(node.args.args) names.append(node.args.vararg) names.append(node.args.kwarg) diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py index 5c7240d2a7e9d9..92776366876346 100644 --- a/python/paddle/jit/dy2static/utils.py +++ b/python/paddle/jit/dy2static/utils.py @@ -790,9 +790,9 @@ def get(self, names): if vars is None: return () for n in names: - assert ( - n in self.name2id - ), f"the name `{n}` not in name union set`{self.name2id.keys()}`." + assert n in self.name2id, ( + f"the name `{n}` not in name union set`{self.name2id.keys()}`." + ) return tuple(vars[self.name2id[n]] for n in names) def set(self, names, values): @@ -804,9 +804,9 @@ def set(self, names, values): if vars is None: return for n in names: - assert ( - n in self.name2id - ), f"the name `{n}` not in name union set`{self.name2id.keys()}`." + assert n in self.name2id, ( + f"the name `{n}` not in name union set`{self.name2id.keys()}`." + ) vars = list(vars) indices = [self.name2id[n] for n in names] for i, v in zip(indices, values): diff --git a/python/paddle/jit/sot/infer_meta.py b/python/paddle/jit/sot/infer_meta.py index 539e86e4f39a31..c448eef86473b1 100644 --- a/python/paddle/jit/sot/infer_meta.py +++ b/python/paddle/jit/sot/infer_meta.py @@ -63,9 +63,9 @@ def __init__(self, mesh=None, dims_mapping=None, local_shape=None): @staticmethod def from_tensor(tensor: paddle.Tensor) -> DistInfo: - assert ( - isinstance(tensor, paddle.Tensor) and tensor.is_dist() - ), f"Expect a Tensor, but got a {type(tensor)}." + assert isinstance(tensor, paddle.Tensor) and tensor.is_dist(), ( + f"Expect a Tensor, but got a {type(tensor)}." + ) mesh = tensor.process_mesh sharding_specs = get_shard_spec( @@ -77,9 +77,9 @@ def from_tensor(tensor: paddle.Tensor) -> DistInfo: @staticmethod def from_value(value: paddle.pir.Value) -> DistInfo: - assert ( - isinstance(value, paddle.pir.Value) and value.is_dist() - ), f"Expect a Value, but got a {type(value)}." + assert isinstance(value, paddle.pir.Value) and value.is_dist(), ( + f"Expect a Value, but got a {type(value)}." + ) return DistInfo( value.dist_attr().process_mesh, value.dist_attr().dims_mapping, @@ -149,13 +149,13 @@ def from_tensor( ) -> MetaInfoOrNull: if not tensor._is_dense_tensor_hold_allocation(): return MetaInfoOrNull.null() - assert isinstance( - tensor, paddle.Tensor - ), "Expect a Tensor, but got a Value." + assert isinstance(tensor, paddle.Tensor), ( + "Expect a Tensor, but got a Value." + ) - assert ( - -1 not in tensor.shape - ), "Tensor shape should not contain -1, maybe you pass a Value to from_tensor" + assert -1 not in tensor.shape, ( + "Tensor shape should not contain -1, maybe you pass a Value to from_tensor" + ) user_specified_dynamic_axes = extract_tensor_dynamic_dims(tensor) dynamic_axes = dynamic_axes or [] dynamic_axes = MetaInfoOrNull.mix_axes( @@ -265,9 +265,9 @@ def __init__( spec_name=None, dist_info=None, ): - assert ( - -1 not in shape - ), "NOTE: Shape should not contain -1, consider convert it to SymbolicInt." + assert -1 not in shape, ( + "NOTE: Shape should not contain -1, consider convert it to SymbolicInt." + ) self.name = name self.persistable = persistable self.type = type @@ -430,9 +430,9 @@ def create_var(self, meta_or_null: MetaInfoOrNull): placements = to_placements(meta.dist_info.dims_mapping, mesh) var = paddle._pir_ops.shard_tensor(var, mesh, placements) var.stop_gradient = meta.stop_gradient - assert not isinstance( - var, paddle.Tensor - ), "Expect a Variable, but got a Tensor." + assert not isinstance(var, paddle.Tensor), ( + "Expect a Variable, but got a Tensor." + ) return var def get_variable(self, meta: MetaInfoOrNull, without_cache=False): @@ -513,9 +513,9 @@ def infer_meta(func, *args, **kwargs): def infer_meta_for_layer(layer, *args, **kwargs): - assert isinstance( - layer, paddle.nn.Layer - ), f"Expect a Layer, but got {layer}." + assert isinstance(layer, paddle.nn.Layer), ( + f"Expect a Layer, but got {layer}." + ) layer = paddle.jit.to_static(layer, full_graph=True) args_, kwargs_ = convert_meta_to_input_spec((args, kwargs)) @@ -636,9 +636,9 @@ def value_fn(self, layer, *args, **kwargs): class ConstrainedInputSpec(InputSpec): def __init__(self, dynamic_axes: list[int], *args, **kwargs): - self.ranges: list[tuple[int, int | None, int | None]] = ( - [] - ) # (idx of dim, min, max) + self.ranges: list[ + tuple[int, int | None, int | None] + ] = [] # (idx of dim, min, max) super().__init__(*args, **kwargs) min_non_specialized_number = get_min_non_specialized_number() for i in dynamic_axes: diff --git a/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py b/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py index f3e2bb2385120b..10e11fef30ce1f 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py +++ b/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py @@ -255,9 +255,9 @@ def lookup( ) if not enable_unsafe_cache_fastpath: # TODO(zrr1999): cache_index should be equal to index when enable_strict_guard. - assert ( - cache_index is None or index == cache_index - ), f"cache_index({cache_index}) is not equal to index({index})" + assert cache_index is None or index == cache_index, ( + f"cache_index({cache_index}) is not equal to index({index})" + ) if enable_unsafe_cache_fastpath: if index == 0: diff --git a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py index 29c753815e85aa..c288b7b823d750 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py +++ b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py @@ -376,9 +376,9 @@ def guard_fn(self) -> Guard: guards = OrderedSet(guards) # type: ignore for guard in guards: - assert isinstance( - guard, StringifiedExpression - ), "guard must be StringifiedExpression." + assert isinstance(guard, StringifiedExpression), ( + "guard must be StringifiedExpression." + ) return make_guard(guards) diff --git a/python/paddle/jit/sot/opcode_translator/executor/guard.py b/python/paddle/jit/sot/opcode_translator/executor/guard.py index f93fa6c392ffb8..a8f4066985e258 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/guard.py +++ b/python/paddle/jit/sot/opcode_translator/executor/guard.py @@ -224,9 +224,9 @@ def check_guard( fn: Callable[[CheckGuardInputT], list[StringifiedExpression]], ) -> Callable[[CheckGuardInputT], list[StringifiedExpression]]: def wrapper(self: CheckGuardInputT) -> list[StringifiedExpression]: - assert ( - self.tracker.is_traceable() - ), "Cannot make guard from a non-tracable guard variable." + assert self.tracker.is_traceable(), ( + "Cannot make guard from a non-tracable guard variable." + ) def guard_log(): frame_value_tracer = self.tracker.trace_value_from_frame() @@ -246,9 +246,9 @@ def check_faster_guard( def wrapper( self: CheckGuardInputT, ) -> list[paddle.framework.core.GuardNodeBase]: - assert ( - self.tracker.is_traceable() - ), "Cannot make guard from a non-tracable guard variable." + assert self.tracker.is_traceable(), ( + "Cannot make guard from a non-tracable guard variable." + ) def guard_log(): frame_value_tracer = self.tracker.trace_value_from_frame() diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py index e7976c1d3c1a57..b93928070833a3 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -567,9 +567,9 @@ def pop_call_stack_until_self(self): Pops the call stack until the current executor. """ - assert ( - self in OpcodeExecutorBase.call_stack - ), f"{self} not in call stack" + assert self in OpcodeExecutorBase.call_stack, ( + f"{self} not in call stack" + ) while OpcodeExecutorBase.call_stack.pop() is not self: pass @@ -812,9 +812,9 @@ def _rot_top_n(self, n: int): # a1 a2 a3 ... an <- TOS # the stack changes to # an a1 a2 a3 an-1 <- TOS - assert ( - len(self.stack) >= n - ), f"There are not enough elements on the stack. {n} is needed." + assert len(self.stack) >= n, ( + f"There are not enough elements on the stack. {n} is needed." + ) top = self.stack.pop() self.stack.insert(n - 1, top) @@ -1136,9 +1136,9 @@ def DELETE_SUBSCR(self, instr: Instruction): def BUILD_LIST(self, instr: Instruction): list_size = instr.arg - assert list_size <= len( - self.stack - ), f"OpExecutor want BUILD_LIST with size {list_size}, but current stack do not have enough elems." + assert list_size <= len(self.stack), ( + f"OpExecutor want BUILD_LIST with size {list_size}, but current stack do not have enough elems." + ) val_list = self.stack.pop_n(list_size) self.stack.push( ListVariable( @@ -1148,9 +1148,9 @@ def BUILD_LIST(self, instr: Instruction): def BUILD_TUPLE(self, instr: Instruction): tuple_size = instr.arg - assert tuple_size <= len( - self.stack - ), f"OpExecutor want BUILD_TUPLE with size {tuple_size}, but current stack do not have enough elems." + assert tuple_size <= len(self.stack), ( + f"OpExecutor want BUILD_TUPLE with size {tuple_size}, but current stack do not have enough elems." + ) val_tuple = self.stack.pop_n(tuple_size) self.stack.push( TupleVariable( @@ -1162,9 +1162,9 @@ def BUILD_TUPLE(self, instr: Instruction): def BUILD_STRING(self, instr: Instruction): count = instr.arg - assert count <= len( - self.stack - ), f"OpExecutor want BUILD_STRING with size {count}, but current stack do not have enough elems." + assert count <= len(self.stack), ( + f"OpExecutor want BUILD_STRING with size {count}, but current stack do not have enough elems." + ) str_list = self.stack.pop_n(count) new_str = '' for s in str_list: @@ -1209,9 +1209,9 @@ def build_map( def BUILD_MAP(self, instr: Instruction): map_size = instr.arg - assert map_size * 2 <= len( - self.stack - ), f"OpExecutor want BUILD_MAP with size {map_size} * 2, but current stack do not have enough elems." + assert map_size * 2 <= len(self.stack), ( + f"OpExecutor want BUILD_MAP with size {map_size} * 2, but current stack do not have enough elems." + ) val_for_dict = self.stack.pop_n(map_size * 2) keys = val_for_dict[::2] values = val_for_dict[1::2] @@ -1219,9 +1219,9 @@ def BUILD_MAP(self, instr: Instruction): def BUILD_CONST_KEY_MAP(self, instr: Instruction): map_size = instr.arg - assert map_size + 1 <= len( - self.stack - ), f"OpExecutor want BUILD_CONST_KEY_MAP with size {map_size} + 1, but current stack do not have enough elems." + assert map_size + 1 <= len(self.stack), ( + f"OpExecutor want BUILD_CONST_KEY_MAP with size {map_size} + 1, but current stack do not have enough elems." + ) keys = self.stack.pop().get_wrapped_items() keys = list(keys) if isinstance(keys, tuple) else keys assert len(keys) == map_size @@ -1399,9 +1399,9 @@ def CALL_FUNCTION_EX(self, instr: Instruction): args_variable = self.stack.pop() args_iter = args_variable.get_iter() - assert isinstance( - args_iter, IterVariable - ), f"args_iter should be IterVariable, but got {args_iter}" + assert isinstance(args_iter, IterVariable), ( + f"args_iter should be IterVariable, but got {args_iter}" + ) if not isinstance(args_iter, SequenceIterVariable): raise BreakGraphError( UnsupportedOperationBreak( @@ -1459,9 +1459,9 @@ def COMPARE_OP(self, instr: Instruction): def TO_BOOL(self, instr: Instruction): # we don't do anything in TO_BOOL, we simply check if the bytecode is legal next_instr = self._instructions[self.vframe.lasti] - assert ( - next_instr.opname in NEED_TO_BOOL - ), f"The bytecode is illegal! The opcode following TO_BOOL must be in ['POP_JUMP_IF_TRUE', 'POP_JUMP_IF_FALSE', 'UNARY_NOT'], the next instruction now is {next_instr.opname}" + assert next_instr.opname in NEED_TO_BOOL, ( + f"The bytecode is illegal! The opcode following TO_BOOL must be in ['POP_JUMP_IF_TRUE', 'POP_JUMP_IF_FALSE', 'UNARY_NOT'], the next instruction now is {next_instr.opname}" + ) @call_break_graph_decorator(push_n=1) def IS_OP(self, instr: Instruction): @@ -1556,7 +1556,9 @@ def SET_FUNCTION_ATTRIBUTE(self, instr: Instruction): assert isinstance( origin_func, (UserDefinedGeneratorFunctionVariable, UserDefinedFunctionVariable), - ), f"The object we manipulate must be a function object. But now got {type(origin_func)}" + ), ( + f"The object we manipulate must be a function object. But now got {type(origin_func)}" + ) origin_func_val = origin_func.get_py_value() related_list = [origin_func] closure, related_list, kw_defaults, default_args = ( @@ -1773,9 +1775,9 @@ def UNPACK_EX(self, instr: Instruction): # a, b, *c, d = e front_nums = instr.arg & 0xFF back_nums = instr.arg >> 8 - assert ( - len(sequence) >= front_nums + back_nums - ), f"Want unpack {sequence} to {front_nums + back_nums}, but {len(sequence)} is smaller than {front_nums + back_nums}." + assert len(sequence) >= front_nums + back_nums, ( + f"Want unpack {sequence} to {front_nums + back_nums}, but {len(sequence)} is smaller than {front_nums + back_nums}." + ) for i in range( len(sequence) - 1, len(sequence) - back_nums - 1, -1 @@ -1789,9 +1791,9 @@ def UNPACK_EX(self, instr: Instruction): ) else: # a, b, c, *d = e - assert ( - len(sequence) >= instr.arg - ), f"Want unpack {sequence} to {instr.arg}, but {len(sequence)} is smaller than {instr.arg}." + assert len(sequence) >= instr.arg, ( + f"Want unpack {sequence} to {instr.arg}, but {len(sequence)} is smaller than {instr.arg}." + ) slice_obj = slice(instr.arg, None) slice_var = SliceVariable( @@ -2183,9 +2185,9 @@ def FOR_ITER(self, instr): return Stop(state="BreakGraph") def RETURN_VALUE(self, instr: Instruction): - assert ( - len(self.stack) == 1 - ), f"Stack must have one element, but get {len(self.stack)} elements." + assert len(self.stack) == 1, ( + f"Stack must have one element, but get {len(self.stack)} elements." + ) ret_val = self.stack.pop() return self.compile_return(ret_val) diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py index 870acb9e84c025..40b303a337630b 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py @@ -102,9 +102,9 @@ def inline_call(self) -> VariableBase: return self.return_value def RETURN_VALUE(self, instr: Instruction): - assert ( - len(self.stack) == 1 - ), f"Stack must have one element, but get {len(self.stack)} elements." + assert len(self.stack) == 1, ( + f"Stack must have one element, but get {len(self.stack)} elements." + ) self.return_value = self.stack.pop() return Stop(state="Return") @@ -217,9 +217,9 @@ def FOR_ITER(self, instr: Instruction): return inline_for_iter_impl(self, instr) def RETURN_VALUE(self, instr: Instruction): - assert ( - len(self.stack) == 1 - ), f"Stack must have one element, but get {len(self.stack)} elements." + assert len(self.stack) == 1, ( + f"Stack must have one element, but get {len(self.stack)} elements." + ) self.return_value = self.stack.pop() return Stop(state="Return") diff --git a/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py b/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py index b1fd174e3e95ff..6c97bf0ff49f8b 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py +++ b/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py @@ -1021,9 +1021,9 @@ def set_inputs( self, inputs: list[str], stack_size: int, null_indices: list[int] = [] ): stack_arg_str = self.name + '_stack_{}' - assert all( - idx < stack_size for idx in null_indices - ), "null index out of range" + assert all(idx < stack_size for idx in null_indices), ( + "null index out of range" + ) self.codegen._code_options['co_argcount'] = ( len(inputs) + stack_size - len(null_indices) diff --git a/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py b/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py index 00fc621c6d1e80..a0b18d3bd5d8ce 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py @@ -1200,7 +1200,9 @@ def tensor_mod_dispatcher( "TensorVariable", ), partial( - lambda reverse_magic_name, var, other: other.graph.call_tensor_method( + lambda reverse_magic_name, + var, + other: other.graph.call_tensor_method( reverse_magic_name, other, var ), magic_method.name, diff --git a/python/paddle/jit/sot/opcode_translator/executor/variable_stack.py b/python/paddle/jit/sot/opcode_translator/executor/variable_stack.py index 88f74f8a88992a..bf00ab8f4967e3 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variable_stack.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variable_stack.py @@ -84,20 +84,20 @@ def __getitem__( assert 0 < index <= len(self._data) return self._data[-index] if isinstance(index, slice): - assert ( - index.start is None and index.step is None - ), "slice which has start or step not supported" + assert index.start is None and index.step is None, ( + "slice which has start or step not supported" + ) assert 0 < index.stop <= len(self._data) return self._data[-index.stop :] raise NotImplementedError(f"index type {type(index)} not supported") def __setitem__(self, index: int, value: Any): - assert isinstance( - index, int - ), f"index type {type(index)} not supported" - assert ( - 0 < index <= len(self._data) - ), f"index should be in [1, {len(self._data)}], but get {index}" + assert isinstance(index, int), ( + f"index type {type(index)} not supported" + ) + assert 0 < index <= len(self._data), ( + f"index should be in [1, {len(self._data)}], but get {index}" + ) self.validate_value_func(value) self._data[-index] = value @@ -151,9 +151,9 @@ def insert(self, index: int, val: StackDataT): val: The variable to be inserted. """ - assert ( - 0 <= index <= len(self) - ), f"index should be in [0, {len(self)}], but get {index}" + assert 0 <= index <= len(self), ( + f"index should be in [0, {len(self)}], but get {index}" + ) self.validate_value_func(val) self._data.insert(len(self) - index, val) @@ -179,9 +179,9 @@ def pop_n(self, n: int) -> list[StackDataT]: A list of the popped values. """ - assert ( - len(self) >= n >= 0 - ), f"n should be in [0, {len(self)}], but get {n}" + assert len(self) >= n >= 0, ( + f"n should be in [0, {len(self)}], but get {n}" + ) if n == 0: return [] retval = self._data[-n:] diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py index 0a0d298e119dac..99dc58d7214e37 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py @@ -216,9 +216,9 @@ def bool(self): return ConstantVariable(bool(self), self.graph, DummyTracker([self])) def bool_not(self): - assert isinstance( - self.get_py_value(), bool - ), "Bool_not can only be applied to a bool variable." + assert isinstance(self.get_py_value(), bool), ( + "Bool_not can only be applied to a bool variable." + ) return ConstantVariable( not bool(self.get_py_value()), self.graph, DummyTracker([self]) ) @@ -287,9 +287,9 @@ def wrap_literal(value: Any, graph: FunctionGraph) -> ConstantVariable: """ if isinstance(value, ConstantVariable): return value - assert isinstance( - value, ConstTypes - ), f"value: {value},type: {type(value)}" + assert isinstance(value, ConstTypes), ( + f"value: {value},type: {type(value)}" + ) return ConstantVariable(value, graph, ConstTracker(value)) @@ -985,16 +985,16 @@ def __init__( super().__init__(graph, tracker) self.var_name = self.var_name_generator.next() if isinstance(value_or_meta, MetaInfoOrNull): - assert ( - not value_or_meta.is_null() - ), "MetaInfoOrNull should not be null" + assert not value_or_meta.is_null(), ( + "MetaInfoOrNull should not be null" + ) assert len(value_or_meta.unwrap_unsafe().shape) == 0 self.value = get_symbolic_from_meta(value_or_meta) self.meta = value_or_meta else: - assert isinstance( - value_or_meta, SymbolicInt - ), f"Unsupported type {type(value_or_meta)} for SymbolicVariable" + assert isinstance(value_or_meta, SymbolicInt), ( + f"Unsupported type {type(value_or_meta)} for SymbolicVariable" + ) self.value = value_or_meta self.meta = MetaInfo( [], paddle.int64, True, self.var_name, False, None, None @@ -1018,15 +1018,15 @@ def __init__( def add_constraint(self, constraint: SymbolicConstraint): constraint_node, constraint_extern_vars = constraint for extern_var in constraint_extern_vars.values(): - assert isinstance( - extern_var, SymbolicVariable - ), f"SymbolicVariable.add_constraint() got {extern_var}." - assert ( - extern_var.value.is_backed() - ), "Only backed symbol is supported." - assert ( - extern_var.tracker.is_traceable() - ), "Only traceable symbol is supported." + assert isinstance(extern_var, SymbolicVariable), ( + f"SymbolicVariable.add_constraint() got {extern_var}." + ) + assert extern_var.value.is_backed(), ( + "Only backed symbol is supported." + ) + assert extern_var.tracker.is_traceable(), ( + "Only traceable symbol is supported." + ) self.constraints.append(constraint) def to_constant(self): @@ -1082,9 +1082,9 @@ def get_py_value(self, allow_tensor: bool = False) -> bool | int | float: ) ) value = self.tracker.op(*input_values) - assert isinstance( - value, (bool, int, float) - ), f"SymbolicVariable.get_py_value() should return bool, int or float, but got {type(value)}" + assert isinstance(value, (bool, int, float)), ( + f"SymbolicVariable.get_py_value() should return bool, int or float, but got {type(value)}" + ) return value def get_example_value( @@ -1112,9 +1112,9 @@ def get_example_value( ) ) value = self.tracker.op(*input_values) - assert isinstance( - value, (bool, int, float) - ), f"SymbolicVariable.get_example_value() should return bool, int or float, but got {type(value)}" + assert isinstance(value, (bool, int, float)), ( + f"SymbolicVariable.get_example_value() should return bool, int or float, but got {type(value)}" + ) return value def create_constraint_tree( @@ -1127,9 +1127,9 @@ def create_constraint_tree( extern_vars = {} num_sym = 0 for input in tracker.inputs: - assert isinstance( - input, (ConstantVariable, SymbolicVariable) - ), f"SymbolicVariable.create_constraint_tree() got {input}." + assert isinstance(input, (ConstantVariable, SymbolicVariable)), ( + f"SymbolicVariable.create_constraint_tree() got {input}." + ) if isinstance(input, ConstantVariable): input_nodes.append(ConstantConstraintNode(input.get_py_value())) else: diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py b/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py index 4e92cf3ffad356..e57121cd8572d4 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py @@ -1171,9 +1171,9 @@ def call_function(self, /, *args, **kwargs): vframe, code_var, self.graph ) gen = inline_gen_executor.inline_call() - assert isinstance( - gen, GeneratorVariable - ), f"GeneratorFunction calling result should be GeneratorVariable, but got {type(gen)}" + assert isinstance(gen, GeneratorVariable), ( + f"GeneratorFunction calling result should be GeneratorVariable, but got {type(gen)}" + ) gen.tracker = DummyTracker([self, *args, *kwargs.values()]) return gen return GeneratorVariable( @@ -1266,9 +1266,9 @@ def call_function(self, /, *args, **kwargs): input_py_args = [var.get_py_value() for var in args] input_py_kwargs = {k: v.get_py_value() for k, v in kwargs.items()} new_layer = self.value(*input_py_args, **input_py_kwargs) - assert self.check_no_weight_and_buffers( - new_layer - ), "You have created a layer in to_static function which may have Potential bugs. please create it in __init__/main function." + assert self.check_no_weight_and_buffers(new_layer), ( + "You have created a layer in to_static function which may have Potential bugs. please create it in __init__/main function." + ) return VariableFactory.from_value( new_layer, self.graph, CreateLayerTracker(self, args, kwargs) ) @@ -1372,9 +1372,9 @@ def call_function(self, /, *args, **kwargs): parameters = fn_bind_inputs(self.value, self.graph, *args, **kwargs) fields = self.get_py_value()._fields - assert all( - field in parameters for field in fields - ), f"All fields of namedtuple should be in parameters, but got parameter {parameters} and fields {fields}" + assert all(field in parameters for field in fields), ( + f"All fields of namedtuple should be in parameters, but got parameter {parameters} and fields {fields}" + ) parameters_tuple = tuple(parameters[field] for field in fields) return NamedTupleVariable( diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/container.py b/python/paddle/jit/sot/opcode_translator/executor/variables/container.py index d073c4e1ce9ad0..d7fb89217e50b2 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variables/container.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/container.py @@ -418,9 +418,9 @@ def count(self, value: VariableBase): index_value, value ) eq_bool = BuiltinVariable(bool, self.graph, DanglingTracker())(eq) - assert isinstance( - eq_bool, ConstantVariable - ), "bool should return ConstantVariable" + assert isinstance(eq_bool, ConstantVariable), ( + "bool should return ConstantVariable" + ) if eq.get_py_value() is True: count += 1 continue @@ -442,9 +442,9 @@ def index(self, value: VariableBase): index_value, value ) eq_bool = BuiltinVariable(bool, self.graph, DanglingTracker())(eq) - assert isinstance( - eq_bool, ConstantVariable - ), "bool should return ConstantVariable" + assert isinstance(eq_bool, ConstantVariable), ( + "bool should return ConstantVariable" + ) if eq.get_py_value() is True: return ConstantVariable( res, self.graph, DummyTracker([self, value]) @@ -641,9 +641,9 @@ def count(self, value: VariableBase): index_value, value ) eq_bool = BuiltinVariable(bool, self.graph, DanglingTracker())(eq) - assert isinstance( - eq_bool, ConstantVariable - ), "bool should return ConstantVariable" + assert isinstance(eq_bool, ConstantVariable), ( + "bool should return ConstantVariable" + ) if eq.get_py_value() is True: count += 1 continue @@ -665,9 +665,9 @@ def index(self, value: VariableBase): index_value, value ) eq_bool = BuiltinVariable(bool, self.graph, DanglingTracker())(eq) - assert isinstance( - eq_bool, ConstantVariable - ), "bool should return ConstantVariable" + assert isinstance(eq_bool, ConstantVariable), ( + "bool should return ConstantVariable" + ) if eq.get_py_value() is True: return ConstantVariable( res, self.graph, DummyTracker([self, value]) diff --git a/python/paddle/jit/sot/opcode_translator/executor/virtual_frame.py b/python/paddle/jit/sot/opcode_translator/executor/virtual_frame.py index f0a91713678299..4fa4476056d91c 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/virtual_frame.py +++ b/python/paddle/jit/sot/opcode_translator/executor/virtual_frame.py @@ -51,9 +51,9 @@ def validate_value(value): - assert isinstance( - value, VariableBase - ), f"value: {value}, type should be VariableBase(or derived), but get {type(value)}" + assert isinstance(value, VariableBase), ( + f"value: {value}, type should be VariableBase(or derived), but get {type(value)}" + ) assert not isinstance(value.tracker, DanglingTracker) or isinstance( value, (NullVariable, CellVariable) ), f"dangling variable {value} should not be pushed into stack." diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py index dc6798db58a458..98cf9aa5bc359e 100644 --- a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py +++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py @@ -428,28 +428,28 @@ def modify_vars(instructions: list[Instruction], code_options): 'STORE_FAST', 'DELETE_FAST', ]: - assert ( - instrs.argval in co_varnames - ), f"`{instrs.argval}` not in {co_varnames}" + assert instrs.argval in co_varnames, ( + f"`{instrs.argval}` not in {co_varnames}" + ) instrs.arg = co_varnames.index(instrs.argval) elif instrs.opname == "LOAD_DEREF" or instrs.opname == "STORE_DEREF": if sys.version_info >= (3, 11): namemap = co_varnames + co_freevars - assert ( - instrs.argval in namemap - ), f"`{instrs.argval}` not in {namemap}" + assert instrs.argval in namemap, ( + f"`{instrs.argval}` not in {namemap}" + ) instrs.arg = namemap.index(instrs.argval) elif instrs.opname in [ 'LOAD_FAST_LOAD_FAST', 'STORE_FAST_STORE_FAST', 'STORE_FAST_LOAD_FAST', ]: - assert ( - instrs.argval[0] in co_varnames - ), f"`{instrs.argval[0]}` not in {co_varnames}" - assert ( - instrs.argval[1] in co_varnames - ), f"`{instrs.argval[1]}` not in {co_varnames}" + assert instrs.argval[0] in co_varnames, ( + f"`{instrs.argval[0]}` not in {co_varnames}" + ) + assert instrs.argval[1] in co_varnames, ( + f"`{instrs.argval[1]}` not in {co_varnames}" + ) instrs.arg = ( co_varnames.index(instrs.argval[0]) << 4 ) + co_varnames.index(instrs.argval[1]) diff --git a/python/paddle/jit/sot/symbolic/builder.py b/python/paddle/jit/sot/symbolic/builder.py index a951a1d3f3da09..6eb14604e420e7 100644 --- a/python/paddle/jit/sot/symbolic/builder.py +++ b/python/paddle/jit/sot/symbolic/builder.py @@ -91,12 +91,12 @@ def call_METHOD(self, method_name, inputs, outputs, stacks): """ Call a method of a api. The API here can be python or Paddle """ - assert isinstance( - method_name, str - ), "call_METHOD must method api name. string." - assert isinstance( - inputs[0][0], Symbol - ), "call_METHOD first argument must be Symbol Variable." + assert isinstance(method_name, str), ( + "call_METHOD must method api name. string." + ) + assert isinstance(inputs[0][0], Symbol), ( + "call_METHOD first argument must be Symbol Variable." + ) stmt = MethodStatement( method_name, inputs, diff --git a/python/paddle/jit/sot/symbolic/compile_cache.py b/python/paddle/jit/sot/symbolic/compile_cache.py index 4db0238ba2728f..ab3fa48a6c0fd2 100644 --- a/python/paddle/jit/sot/symbolic/compile_cache.py +++ b/python/paddle/jit/sot/symbolic/compile_cache.py @@ -205,9 +205,9 @@ def update_compile_time_info(self, SIR, partial_program_layer): assert code is not None, f"Cannot find code for SIR: {SIR}" OpcodeExecutorCache().compile_time_stats.setdefault(code, 0) - OpcodeExecutorCache().compile_time_stats[ - code - ] += partial_program_layer._compile_time_counter.get_total_time() + OpcodeExecutorCache().compile_time_stats[code] += ( + partial_program_layer._compile_time_counter.get_total_time() + ) @event_register( lambda self, *args, **kwargs: f"FallbackWrapper: {self.SIR.name}" diff --git a/python/paddle/jit/sot/translate.py b/python/paddle/jit/sot/translate.py index 2cf2ef3616ce74..bb3b539aa65cbd 100644 --- a/python/paddle/jit/sot/translate.py +++ b/python/paddle/jit/sot/translate.py @@ -101,9 +101,9 @@ def callback(frame): def impl(*args: P.args, **kwargs: P.kwargs) -> R: with StepInfoManager().step_guard(fn.__code__), SotStepProfilerGuard(): - assert hasattr( - fn, "__code__" - ), "Target function doesn't have code for simulating." + assert hasattr(fn, "__code__"), ( + "Target function doesn't have code for simulating." + ) InfoCollector().clear_step_info() paddle.framework.core.set_eval_frame(callback) try: diff --git a/python/paddle/jit/sot/utils/envs.py b/python/paddle/jit/sot/utils/envs.py index 5b003ef2723a7d..8c51184366007c 100644 --- a/python/paddle/jit/sot/utils/envs.py +++ b/python/paddle/jit/sot/utils/envs.py @@ -51,12 +51,12 @@ def parse_from_string(self) -> dict[str, list[str]]: def convert_to_string(self, value: dict[str, list[str]]) -> str: assert isinstance(value, dict), "The input must be a dict" - assert all( - isinstance(x, str) for x in value.keys() - ), "Keys must be a string" - assert all( - isinstance(x, list) for x in value.values() - ), "Values must be a list" + assert all(isinstance(x, str) for x in value.keys()), ( + "Keys must be a string" + ) + assert all(isinstance(x, list) for x in value.values()), ( + "Values must be a list" + ) env_list = [] for k, v in value.items():