Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CodeStyle][Ruff][BUAA][D-[7-13]] Fix ruff RUF015 diagnostic for 6 files in python/paddle/ #67359

Merged
merged 11 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions python/paddle/distributed/transpiler/distribute_transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1826,11 +1826,11 @@ def _update_dist_lookup_table_vars(
for grad in grad_list
if grad.name != grad_var_name(self.table_name)
]
self.table_param_grad = [
self.table_param_grad = next(
param_grad
for param_grad in params_grads
if param_grad[0].name == self.table_name
][0]
)
table_grad_var = self.table_param_grad[1]
if self.sync_mode:
self.trainer_side_table_grad_list = [
Expand Down Expand Up @@ -2132,12 +2132,12 @@ def _create_table_optimize_block(
table_opt_block = pserver_program._create_block(pre_block_idx)
# create table param and grad var in pserver program
# create table optimize block in pserver program
table_opt_op = [
table_opt_op = next(
op
for op in self.optimize_ops
if 'Param' in op.input_names
and op.input("Param")[0] == self.table_name
][0]
)

origin_param_var = self.origin_program.global_block().vars[
self.table_name
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/pir/math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ def to(self, *args, **kwargs):
if len(invalid_keys) != 0:
raise TypeError(
"to() got an unexpected keyword argument "
+ list(invalid_keys)[0]
+ next(iter(invalid_keys))
)

def dtype_first_sig(dtype, blocking=None): ...
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/static/amp/function_overload.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def get(self, *args, **kwargs):
satisfied_function_keys.remove(func_key)
break
if len(satisfied_function_keys) == 1:
key = list(satisfied_function_keys)[0]
key = next(iter(satisfied_function_keys))
elif len(args) >= 3 and isinstance(args[2], float):
key = FunctionType.FP16_ONLY
else:
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,9 +1023,9 @@ def get_paddle_extra_install_requirements():
output = subprocess.check_output(['nvcc', '--version']).decode(
'utf-8'
)
version_line = [
version_line = next(
line for line in output.split('\n') if 'release' in line
][0]
)
version = version_line.split(' ')[-1].split(',')[0]
cuda_major_version = version.split('.')[0]
except Exception as e:
Expand Down
6 changes: 3 additions & 3 deletions test/sot/test_model_switch_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ def setUp(self):

def check_mode(self, is_train):
self.assertEqual(len(self.compile_cache.cache), 1)
mode = list(self.compile_cache.cache.values())[
0
].partial_program.training
mode = next(
iter(self.compile_cache.cache.values())
).partial_program.training
self.assertEqual(mode, is_train)

def get_dygraph_out(self, input):
Expand Down
4 changes: 2 additions & 2 deletions test/xpu/test_tril_triu_op_xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def test_errors1(self):
errmsg = {
"diagonal: TypeError": f"diagonal in {op_type} must be a python Int",
}
expected = list(errmsg.keys())[0]
expected = next(iter(errmsg.keys()))
with self.assertRaisesRegex(
eval(expected.split(':')[-1]), errmsg[expected]
):
Expand All @@ -155,7 +155,7 @@ def test_errors2(self):
errmsg = {
"input: ValueError": f"x shape in {op_type} must be at least 2-D",
}
expected = list(errmsg.keys())[0]
expected = next(iter(errmsg.keys()))
with self.assertRaisesRegex(
eval(expected.split(':')[-1]), errmsg[expected]
):
Expand Down