Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ repos:

# | test/[m-z].+

# | tools/.+
| tools/.+
)$
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.0
Expand Down Expand Up @@ -163,7 +163,7 @@ repos:

| test/[m-z].+

| tools/.+
# | tools/.+
)$
# For C++ files
- repo: local
Expand Down
6 changes: 3 additions & 3 deletions tools/check_op_benchmark_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ def compare_benchmark_result(
develop_speed = develop_result.get("speed")
pr_speed = pr_result.get("speed")

assert type(develop_speed) == type(
pr_speed
), "The types of comparison results need to be consistent."
assert type(develop_speed) == type(pr_speed), (
"The types of comparison results need to be consistent."
)

if isinstance(develop_speed, dict) and isinstance(pr_speed, dict):
if check_speed_result(case_name, develop_speed, pr_speed, pr_result):
Expand Down
18 changes: 9 additions & 9 deletions tools/check_op_desc.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,17 +300,17 @@ def compare_op_desc(origin_op_desc, new_op_desc):
desc_error_message.setdefault(op_type, {})[ATTRS] = attrs_diff

if ins_version_errors:
version_error_message.setdefault(op_type, {})[
INPUTS
] = ins_version_errors
version_error_message.setdefault(op_type, {})[INPUTS] = (
ins_version_errors
)
if outs_version_errors:
version_error_message.setdefault(op_type, {})[
OUTPUTS
] = outs_version_errors
version_error_message.setdefault(op_type, {})[OUTPUTS] = (
outs_version_errors
)
if attrs_version_errors:
version_error_message.setdefault(op_type, {})[
ATTRS
] = attrs_version_errors
version_error_message.setdefault(op_type, {})[ATTRS] = (
attrs_version_errors
)

return desc_error_message, version_error_message

Expand Down
36 changes: 18 additions & 18 deletions tools/gen_pybind11_stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,9 +525,9 @@ def parse_input_and_attr(
inputs = {'names': [], 'input_info': {}}
attrs = {'names': [], 'attr_info': {}}
args_str = args_config.strip()
assert args_str.startswith('(') and args_str.endswith(
')'
), f"Args declaration should start with '(' and end with ')', please check the args of {api_name} in yaml."
assert args_str.startswith('(') and args_str.endswith(')'), (
f"Args declaration should start with '(' and end with ')', please check the args of {api_name} in yaml."
)
args_str = args_str[1:-1]
pattern = re.compile(r',(?![^{]*\})') # support int[] a={1,3}
args_list = re.split(pattern, args_str.strip())
Expand All @@ -541,12 +541,12 @@ def parse_input_and_attr(
for in_type_symbol, in_type in INPUT_TYPES_MAP.items():
if type_and_name[0] == in_type_symbol:
input_name = type_and_name[1].strip()
assert (
len(input_name) > 0
), f"The input tensor name should not be empty. Please check the args of {api_name} in yaml."
assert (
len(attrs['names']) == 0
), f"The input Tensor should appear before attributes. please check the position of {api_name}:input({input_name}) in yaml"
assert len(input_name) > 0, (
f"The input tensor name should not be empty. Please check the args of {api_name} in yaml."
)
assert len(attrs['names']) == 0, (
f"The input Tensor should appear before attributes. please check the position of {api_name}:input({input_name}) in yaml"
)

if input_name in optional_vars:
in_type = OPTIONAL_TYPES_TRANS[in_type_symbol]
Expand All @@ -562,9 +562,9 @@ def parse_input_and_attr(
for attr_type_symbol, attr_type in ATTR_TYPES_MAP.items():
if type_and_name[0] == attr_type_symbol:
attr_name = item[len(attr_type_symbol) :].strip()
assert (
len(attr_name) > 0
), f"The attribute name should not be empty. Please check the args of {api_name} in yaml."
assert len(attr_name) > 0, (
f"The attribute name should not be empty. Please check the args of {api_name} in yaml."
)
default_value = None
if '=' in attr_name:
attr_infos = attr_name.split('=')
Expand All @@ -589,14 +589,14 @@ def parse_output_item(output_item):
r"(?P<out_type>[a-zA-Z0-9_[\]]+)\s*(?P<name>\([a-zA-Z0-9_@]+\))?\s*(?P<expr>\{[^\}]+\})?",
output_item,
)
assert (
result is not None
), f"{api_name} : the output config parse error."
assert result is not None, (
f"{api_name} : the output config parse error."
)
out_type = result.group('out_type')
assert (
out_type in OUTPUT_TYPE_MAP
), f"{api_name} : Output type error: the output type only support Tensor and Tensor[], \
assert out_type in OUTPUT_TYPE_MAP, (
f"{api_name} : Output type error: the output type only support Tensor and Tensor[], \
but now is {out_type}."
)

out_name = (
'out'
Expand Down
62 changes: 33 additions & 29 deletions tools/gen_ut_cmakelists.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,9 @@ def _process_archs(arch):
for a in arch.split(";"):
if '' == a:
continue
assert a in [
"GPU",
"ROCM",
"XPU",
], f"""Supported arch options are "GPU", "ROCM", and "XPU", but the options is {a}"""
assert a in ["GPU", "ROCM", "XPU"], (
f"""Supported arch options are "GPU", "ROCM", and "XPU", but the options is {a}"""
)
archs += "WITH_" + a.upper() + " OR "
arch = "(" + archs[:-4] + ")"
else:
Expand All @@ -127,11 +125,9 @@ def _process_os(os_):
if len(os_) > 0:
os_ = os_.upper()
for p in os_.split(';'):
assert p in [
"WIN32",
"APPLE",
"LINUX",
], f"""Supported os options are 'WIN32', 'APPLE' and 'LINUX', but the options is {p}"""
assert p in ["WIN32", "APPLE", "LINUX"], (
f"""Supported os options are 'WIN32', 'APPLE' and 'LINUX', but the options is {p}"""
)
os_ = os_.replace(";", " OR ")
os_ = "(" + os_ + ")"
else:
Expand All @@ -146,7 +142,9 @@ def _process_run_serial(run_serial):
"1",
"0",
"",
], f"""the value of run_serial must be one of 0, 1 or empty. But this value is {rs}"""
], (
f"""the value of run_serial must be one of 0, 1 or empty. But this value is {rs}"""
)
if rs == "":
return ""
return rs
Expand Down Expand Up @@ -175,9 +173,9 @@ def _process_name(name, curdir):
)
filepath_prefix = os.path.join(curdir, name)
suffix = [".py", ".sh"]
assert _file_with_extension(
filepath_prefix, suffix
), f""" Please ensure the test file with the prefix '{filepath_prefix}' and one of the suffix {suffix} exists, because you specified a unittest named '{name}'"""
assert _file_with_extension(filepath_prefix, suffix), (
f""" Please ensure the test file with the prefix '{filepath_prefix}' and one of the suffix {suffix} exists, because you specified a unittest named '{name}'"""
)

return name

Expand Down Expand Up @@ -238,7 +236,9 @@ def process_dist_port_num(self, port_num):
re.compile("^[0-9]+$").search(port_num)
and int(port_num) > 0
or port_num.strip() == ""
), f"""port_num must be format as a positive integer or empty, but this port_num is '{port_num}'"""
), (
f"""port_num must be format as a positive integer or empty, but this port_num is '{port_num}'"""
)
port_num = port_num.strip()
if len(port_num) == 0:
return 0
Expand Down Expand Up @@ -272,7 +272,9 @@ def _init_dist_ut_ports_from_cmakefile(self, cmake_file_name):

# match right tests name format, the name must start with 'test_' followed by at least one char of
# '0-9'. 'a-z'. 'A-Z' or '_'
assert re.compile("^test_[0-9a-zA-Z_]+").search(
assert re.compile(
"^test_[0-9a-zA-Z_]+"
).search(
name
), f'''we found a test for initial the latest dist_port but the test name '{name}' seems to be wrong
at line {k - 1}, in file {cmake_file_name}
Expand Down Expand Up @@ -349,9 +351,9 @@ def parse_assigned_dist_ut_ports(self, current_work_dir, depth=0):
if name == self.last_test_name:
found = True
break
assert (
found
), f"no such test named '{self.last_test_name}' in file '{self.last_test_cmake_file}'"
assert found, (
f"no such test named '{self.last_test_name}' in file '{self.last_test_cmake_file}'"
)
if launcher[-2:] == ".sh":
self.process_dist_port_num(num_port)

Expand Down Expand Up @@ -485,9 +487,9 @@ def _parse_line(self, line, curdir):
try:
run_type = _process_run_type(run_type)
except Exception as e:
assert (
run_type.strip() == ""
), f"{e}\nIf use test_runner.py, the run_type can be ''"
assert run_type.strip() == "", (
f"{e}\nIf use test_runner.py, the run_type can be ''"
)
cmd += f'''if({archs} AND {os_})
py_test_modules(
{name}
Expand Down Expand Up @@ -580,7 +582,9 @@ def _gen_cmakelists(self, current_work_dir, depth=0):
assert (
f"{current_work_dir}/CMakeLists.txt"
not in self.modified_or_created_files
), f"the file {current_work_dir}/CMakeLists.txt are modified twice, which may cause some error"
), (
f"the file {current_work_dir}/CMakeLists.txt are modified twice, which may cause some error"
)
self.modified_or_created_files.append(
f"{current_work_dir}/CMakeLists.txt"
)
Expand Down Expand Up @@ -630,15 +634,15 @@ def _gen_cmakelists(self, current_work_dir, depth=0):
)
args = parser.parse_args()

assert not (
len(args.files) == 0 and len(args.dirpaths) == 0
), "You must provide at least one file or dirpath"
assert not (len(args.files) == 0 and len(args.dirpaths) == 0), (
"You must provide at least one file or dirpath"
)
current_work_dirs = []
if len(args.files) >= 1:
for p in args.files:
assert (
os.path.basename(p) == "testslist.csv"
), "you must input file named testslist.csv"
assert os.path.basename(p) == "testslist.csv", (
"you must input file named testslist.csv"
)
current_work_dirs = current_work_dirs + [
os.path.dirname(file) for file in args.files
]
Expand Down
4 changes: 1 addition & 3 deletions tools/test_check_pr_approval.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,7 @@ def setUp(self):
"author_association": "CONTRIBUTOR"
}
]
""".encode(
self.codeset
)
""".encode(self.codeset)

def test_ids(self):
cmd = [sys.executable, 'check_pr_approval.py', '1', '26408901']
Expand Down