Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ repos:

# | python/paddle/[e-i].+

# | python/paddle/j.+
| python/paddle/j.+

# | python/paddle/[k-n].+

Expand Down Expand Up @@ -143,7 +143,7 @@ repos:

| python/paddle/[e-i].+

| python/paddle/j.+
# | python/paddle/j.+

| python/paddle/[k-n].+

Expand Down
12 changes: 9 additions & 3 deletions python/paddle/jit/dy2static/convert_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,13 +756,17 @@ def convert_var_dtype(var, dtype):
'int32',
'int64',
'uint8',
], f"The dtype of var {var.name} is {src_dtype}, which is not supported in the cast op."
], (
f"The dtype of var {var.name} is {src_dtype}, which is not supported in the cast op."
)
assert dtype in [
'bool',
'int',
'float',
'complex',
], f"The casted target dtype is {dtype}, which is not supported in type casting."
], (
f"The casted target dtype is {dtype}, which is not supported in type casting."
)
cast_map = {
'bool': 'bool',
'int': 'int32',
Expand All @@ -776,7 +780,9 @@ def convert_var_dtype(var, dtype):
'int',
'float',
'complex',
], f"The casted target dtype is {dtype}, which is not supported in type casting."
], (
f"The casted target dtype is {dtype}, which is not supported in type casting."
)
return eval(dtype)(var)


Expand Down
12 changes: 6 additions & 6 deletions python/paddle/jit/dy2static/origin_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,9 @@ def create_and_update_origin_info_map(
static_node = attach_origin_info(static_node, static_func)

for t_node, s_node in ast_walk(transformed_node, static_node):
assert type(t_node) == type(
s_node
), f"The node types should be the same, but received type(t_node) is {type(t_node)}, and type(s_node) is {type(s_node)}."
assert type(t_node) == type(s_node), (
f"The node types should be the same, but received type(t_node) is {type(t_node)}, and type(s_node) is {type(s_node)}."
)
dygraph_info = getattr(t_node, ORIGIN_INFO, None)
static_info = getattr(s_node, ORIGIN_INFO, None)

Expand Down Expand Up @@ -232,9 +232,9 @@ def _as_list(x):
):
continue

assert type(t_node) == type(
s_node
), f"The node types should be the same, but received type(t_node) is {type(t_node)}, and type(s_node) is {type(s_node)}."
assert type(t_node) == type(s_node), (
f"The node types should be the same, but received type(t_node) is {type(t_node)}, and type(s_node) is {type(s_node)}."
)

yield t_node, s_node

Expand Down
42 changes: 21 additions & 21 deletions python/paddle/jit/dy2static/pir_partial_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,15 +218,15 @@ def __init__(
forward_range=None,
backward_range=None,
):
assert isinstance(
in_out_values, tuple
), "in_out_values must be tuple with len == 3"
assert (
len(in_out_values) == 3
), "in_out_values must be tuple with len == 3"
assert isinstance(
in_out_values[0], list
), "in_out_values must be tuple with len == 3"
assert isinstance(in_out_values, tuple), (
"in_out_values must be tuple with len == 3"
)
assert len(in_out_values) == 3, (
"in_out_values must be tuple with len == 3"
)
assert isinstance(in_out_values[0], list), (
"in_out_values must be tuple with len == 3"
)
self.program = program
self.x_names = self.convert_name(in_out_values[0])
self.param_names = self.convert_name(in_out_values[1])
Expand Down Expand Up @@ -310,9 +310,9 @@ def clone(self):
)

def split_forward_backward(self):
assert (
self.has_splited is False
), "Please ensure only split once! don't call split_forward_backward manually."
assert self.has_splited is False, (
"Please ensure only split once! don't call split_forward_backward manually."
)
self.has_splited = True
self.update_op_range()
(
Expand Down Expand Up @@ -406,9 +406,9 @@ def _forward_backward_program(self):

@cached_property # shouldn't changed when call this once.
def program_attr(self):
assert (
self.finish_pass is False
), "program_attr() is called by PartialProgramLayer, don't call it manually, use program_name_attr instead."
assert self.finish_pass is False, (
"program_attr() is called by PartialProgramLayer, don't call it manually, use program_name_attr instead."
)
# can't apply pass after call this function.
self.finish_pass = True
fwd_map = RunnableProgram._get_name_value_map_from_program(
Expand Down Expand Up @@ -445,9 +445,9 @@ def program_attr(self):
program_attr[f"{k}_names"] = ns

# Restore stop_gradient for output values
assert len(program_attr["fo_values"]) == len(
self.out_stop_gradients
), "Output values and stop gradients length mismatch"
assert len(program_attr["fo_values"]) == len(self.out_stop_gradients), (
"Output values and stop gradients length mismatch"
)
for v, stop_gradient in zip(
program_attr["fo_values"], self.out_stop_gradients
):
Expand All @@ -474,9 +474,9 @@ def unify_value_names(
# Get all values again because some values has been erased.
for value in RunnableProgram._get_program_all_values(program):
if value.has_name:
assert (
value._has_only_one_name()
), f"Expected all values in Program have only one name, but {value} has multiple names: {value._names}"
assert value._has_only_one_name(), (
f"Expected all values in Program have only one name, but {value} has multiple names: {value._names}"
)
return rename_mapping

@staticmethod
Expand Down
12 changes: 6 additions & 6 deletions python/paddle/jit/dy2static/program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,9 +672,9 @@ def rollback(self) -> Callable[_InputT, _RetT]:
if self._patched_name is not None
else self._dygraph_function.__name__
)
assert (
fn_name in self.class_instance._original_funcs
), f"Not Found function '{fn_name}' in class '{self.class_instance.__class__}'."
assert fn_name in self.class_instance._original_funcs, (
f"Not Found function '{fn_name}' in class '{self.class_instance.__class__}'."
)
func = self.class_instance._original_funcs[fn_name]
setattr(self.class_instance, fn_name, func.__get__(self.class_instance))
return getattr(self.class_instance, fn_name)
Expand Down Expand Up @@ -1733,9 +1733,9 @@ def get_program(self, item):
return self._caches[item_id]

def last(self):
assert (
len(self._caches) >= 1
), "No valid cached program in ProgramCache."
assert len(self._caches) >= 1, (
"No valid cached program in ProgramCache."
)
assert self._recent_key is not None
return self._recent_key, self._caches[self._recent_key]

Expand Down
18 changes: 9 additions & 9 deletions python/paddle/jit/dy2static/transformers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,9 @@ class ForNodeVisitor:
"""

def __init__(self, for_node):
assert isinstance(
for_node, gast.For
), "Input node for the initialization of ForNodeVisitor is not gast.For node."
assert isinstance(for_node, gast.For), (
"Input node for the initialization of ForNodeVisitor is not gast.For node."
)
# 1. original for node
self.node = for_node

Expand Down Expand Up @@ -276,14 +276,14 @@ def is_for_enumerate_iter(self):
def _args_check(self):
if self.is_for_range_iter():
self.args_length = len(self.iter_args)
assert (
self.args_length >= 1 and self.args_length <= 3
), "range() function takes 1 to 3 arguments"
assert self.args_length >= 1 and self.args_length <= 3, (
"range() function takes 1 to 3 arguments"
)
elif self.is_for_enumerate_iter():
self.args_length = len(self.iter_args)
assert (
self.args_length >= 1 and self.args_length <= 2
), "enumerate() function takes 1 to 2 arguments"
assert self.args_length >= 1 and self.args_length <= 2, (
"enumerate() function takes 1 to 2 arguments"
)
else:
self.args_length = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ class ForToWhileTransformer(BaseTransformer):
"""

def __init__(self, parent_node, loop_node, condition_node):
assert isinstance(
loop_node, gast.For
), "loop_node is not gast.For in ForToWhileTransformer"
assert isinstance(loop_node, gast.For), (
"loop_node is not gast.For in ForToWhileTransformer"
)
self.parent_node = parent_node
self.loop_node = loop_node
self.condition_node = condition_node
Expand All @@ -60,9 +60,9 @@ def transform(self):
)

def get_for_stmt_nodes(self, node):
assert isinstance(
node, gast.For
), "Input node is NOT gast.For in get_for_stmt_nodes"
assert isinstance(node, gast.For), (
"Input node is NOT gast.For in get_for_stmt_nodes"
)

# 1. parse current gast.For node
current_for_node_parser = ForNodeVisitor(node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ def transform(self):
self.visit(self.root)

def is_define_return_in_if(self, node):
assert isinstance(
node, gast.If
), f"Type of input node should be gast.If, but received {type(node)}."
assert isinstance(node, gast.If), (
f"Type of input node should be gast.If, but received {type(node)}."
)
for child in node.body:
if isinstance(child, gast.Return):
return True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ def _create_bool_op_node(self, nodes, api_type):
according to the actual order. In `convert_logical_and(lambda:x>1, lambda:y<1)`, `lambda:y<1`
must be run after `lambda:x>1`, If `x>1` is False, `y<1` should NOT be run.
'''
assert (
len(nodes) > 1
), f"The length of BoolOp should be at least 2, but received {len(nodes)}."
assert len(nodes) > 1, (
f"The length of BoolOp should be at least 2, but received {len(nodes)}."
)
if len(nodes) > 2:
# Creates logic_and/logic_or node recursively.
pre_logic_node = self._create_bool_op_node(nodes[:2], api_type)
Expand Down
6 changes: 3 additions & 3 deletions python/paddle/jit/dy2static/transformers/loop_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ def __init__(self, root_node):
self.visit(root_node)

def get_loop_var_names(self, node):
assert isinstance(
node, (gast.While, gast.For)
), "Input node is not gast loop node"
assert isinstance(node, (gast.While, gast.For)), (
"Input node is not gast loop node"
)
loop_var_names = set()
create_var_names = set()
read_context = {type(gast.Load()), type(gast.AugLoad())}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ class AttributeJstTransformer(BaseTransformer):
"""

def __init__(self, node):
assert isinstance(
node, gast.AST
), "Input non-gast.AST node for the initialization of ToTensorTransformer."
assert isinstance(node, gast.AST), (
"Input non-gast.AST node for the initialization of ToTensorTransformer."
)
self.interested_name = {
'size',
}
Expand Down
12 changes: 6 additions & 6 deletions python/paddle/jit/dy2static/transformers/return_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ class ReturnAnalysisVisitor(gast.NodeVisitor):

def __init__(self, root_node):
self.root = root_node
assert isinstance(
self.root, gast.FunctionDef
), "Input is not gast.FunctionDef node"
assert isinstance(self.root, gast.FunctionDef), (
"Input is not gast.FunctionDef node"
)

# the number of return statements
self.count_return = 0
Expand Down Expand Up @@ -151,9 +151,9 @@ class SingleReturnTransformer(BaseTransformer):

def __init__(self, root):
self.root = root
assert isinstance(
self.root, gast.FunctionDef
), "Input is not gast.FunctionDef node"
assert isinstance(self.root, gast.FunctionDef), (
"Input is not gast.FunctionDef node"
)

self.ancestor_nodes = []

Expand Down
24 changes: 12 additions & 12 deletions python/paddle/jit/dy2static/transformers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,16 +268,16 @@ def create_node_for_name(name):


def get_attribute_full_name(node):
assert isinstance(
node, gast.Attribute
), "Input non-Attribute node to get attribute full name"
assert isinstance(node, gast.Attribute), (
"Input non-Attribute node to get attribute full name"
)
return ast_to_source_code(node).strip()


def is_api_in_module(node, module_prefix):
assert isinstance(
node, gast.Call
), "Input non-Call node for is_api_in_module"
assert isinstance(node, gast.Call), (
"Input non-Call node for is_api_in_module"
)

# Python can have gast.Call as function, for example: convert_call(func)(x)
# We only check the most outside function
Expand Down Expand Up @@ -385,9 +385,9 @@ def is_global_var(self, name):
it means global vars; otherwise, it means local vars.
Only valid after FunctionNameLivenessAnalysis visitor.
"""
assert self._is_simple_name(
name
), "is_global_var accept a simple name, but get `{name}`."
assert self._is_simple_name(name), (
"is_global_var accept a simple name, but get `{name}`."
)
ancestor = self
while ancestor is not None:
if name in ancestor.globals:
Expand Down Expand Up @@ -612,9 +612,9 @@ def _get_argument_names(self, node):
this node is local to the function and shouldn't
be created.
"""
assert isinstance(
node, gast.FunctionDef
), "Input node is not function define node"
assert isinstance(node, gast.FunctionDef), (
"Input node is not function define node"
)
names = list(node.args.args)
names.append(node.args.vararg)
names.append(node.args.kwarg)
Expand Down
12 changes: 6 additions & 6 deletions python/paddle/jit/dy2static/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,9 +790,9 @@ def get(self, names):
if vars is None:
return ()
for n in names:
assert (
n in self.name2id
), f"the name `{n}` not in name union set`{self.name2id.keys()}`."
assert n in self.name2id, (
f"the name `{n}` not in name union set`{self.name2id.keys()}`."
)
return tuple(vars[self.name2id[n]] for n in names)

def set(self, names, values):
Expand All @@ -804,9 +804,9 @@ def set(self, names, values):
if vars is None:
return
for n in names:
assert (
n in self.name2id
), f"the name `{n}` not in name union set`{self.name2id.keys()}`."
assert n in self.name2id, (
f"the name `{n}` not in name union set`{self.name2id.keys()}`."
)
vars = list(vars)
indices = [self.name2id[n] for n in names]
for i, v in zip(indices, values):
Expand Down
Loading
Loading