
Commit 8fff414

gouzil and SigureMo authored
[CodeStyle] black -> ruff format migration - part 15 (#74669)
Co-authored-by: Nyakku Shigure <sigure.qaq@gmail.com>
1 parent ee70af1 commit 8fff414

30 files changed, +103 -81 lines changed
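All of the hunks in this commit follow the same pattern: expressions that black kept on one long line (or split with implicit string concatenation) are rewrapped by ruff format into a parenthesized, one-item-per-line form. A minimal before/after sketch of that pattern, distilled from the first hunk below (the variable value is hypothetical and only for illustration):

schedule_mode_str = "1F1B"  # hypothetical value, for illustration only

# black-era formatting: message split via implicit string concatenation
assert schedule_mode_str in ["F-then-B", "1F1B"], (
    "The schedule mode " "for pipeline must be one of F-then-B or 1F1B"
)

# ruff format: list exploded one item per line, message kept as a single literal
assert schedule_mode_str in [
    "F-then-B",
    "1F1B",
], "The schedule mode for pipeline must be one of F-then-B or 1F1B"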

python/paddle/base/device_worker.py

Lines changed: 5 additions & 3 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Definition of device workers."""
+
 import sys
 
 __all__ = []
@@ -626,9 +627,10 @@ def _gen_worker_desc(self, trainer_desc):
         # then runs Backward phase for all microbatches.
         # 1F1B scheduler, which runs forward phase and backward phase alternatively
         # after startup phase.
-        assert schedule_mode_str in ["F-then-B", "1F1B"], (
-            "The schedule mode " "for pipeline must be one of F-then-B or 1F1B"
-        )
+        assert schedule_mode_str in [
+            "F-then-B",
+            "1F1B",
+        ], "The schedule mode for pipeline must be one of F-then-B or 1F1B"
         schedule_mode = 0 if schedule_mode_str == "F-then-B" else 1
         section_param.schedule_mode = schedule_mode
         cfg = section_param.section_config

python/paddle/cinn/runtime/cinn_jit.py

Lines changed: 5 additions & 2 deletions
@@ -50,13 +50,16 @@ def _make_launcher(self):
         jit_input_args = ', '.join(arg_name for arg_name in self.arg_names)
         lazy_compile = f"""
 import paddle.cinn as cinn
-def {self.fn.__name__}({jit_input_args}, target=cinn.common.DefaultHostTarget()):
+def {self.fn.__name__}({
+    jit_input_args
+}, target=cinn.common.DefaultHostTarget()):
     from paddle.cinn.compiler import compile
     jit_inputs = {', '.join([f'{arg}' for arg in self.arg_names])}
     jit_inputs_signature = {{ i: self._convert_arg_type(arg) \
 for i, arg in enumerate(jit_inputs)}}
     module = compile(self, jit_inputs_signature=jit_inputs_signature, arg_names={
-self.arg_names}, target=target)
+    self.arg_names
+}, target=target)
     module({jit_input_args})
 
 return module

python/paddle/dataset/conll05.py

Lines changed: 11 additions & 1 deletion
@@ -190,7 +190,17 @@ def reader():
             pred_idx = [predicate_dict.get(predicate)] * sen_len
             label_idx = [label_dict.get(w) for w in labels]
 
-            yield word_idx, ctx_n2_idx, ctx_n1_idx, ctx_0_idx, ctx_p1_idx, ctx_p2_idx, pred_idx, mark, label_idx
+            yield (
+                word_idx,
+                ctx_n2_idx,
+                ctx_n1_idx,
+                ctx_0_idx,
+                ctx_p1_idx,
+                ctx_p2_idx,
+                pred_idx,
+                mark,
+                label_idx,
+            )
 
     return reader

python/paddle/dataset/imdb.py

Lines changed: 8 additions & 3 deletions
@@ -49,9 +49,14 @@ def tokenize(pattern):
     while tf is not None:
         if bool(pattern.match(tf.name)):
             # newline and punctuations removal and ad-hoc tokenization.
-            yield tarf.extractfile(tf).read().rstrip(b'\n\r').translate(
-                None, string.punctuation.encode('latin-1')
-            ).lower().split()
+            yield (
+                tarf.extractfile(tf)
+                .read()
+                .rstrip(b'\n\r')
+                .translate(None, string.punctuation.encode('latin-1'))
+                .lower()
+                .split()
+            )
         tf = tarf.next()

python/paddle/dataset/wmt14.py

Lines changed: 1 addition & 2 deletions
@@ -28,8 +28,7 @@
 __all__ = []
 
 URL_DEV_TEST = (
-    'http://www-lium.univ-lemans.fr/~schwenk/'
-    'cslm_joint_paper/data/dev+test.tgz'
+    'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz'
 )
 MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
 # this is a small set of data for test. The original data is too large and

python/paddle/distributed/auto_parallel/pipelining/_backward.py

Lines changed: 2 additions & 1 deletion
@@ -120,7 +120,8 @@ def extract_tensors_with_grads(
     # Deactivate auto mixed precision context in the backward phase
     with paddle.amp.auto_cast(enable=False):
         paddle.autograd.backward(
-            stage_output_tensors, grad_tensors=output_grad_tensors  # type: ignore[arg-type]
+            stage_output_tensors,
+            grad_tensors=output_grad_tensors,
         )
 
     # Extract gradients wrt the input values

python/paddle/distributed/auto_parallel/pipelining/schedules.py

Lines changed: 9 additions & 3 deletions
@@ -470,7 +470,9 @@ def _step_microbatches(
             for work in works.values():
                 work.wait()
 
-            output = self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i])  # type: ignore[index]
+            output = self._stage.forward_one_chunk(
+                i, arg_mbs[i], kwarg_mbs[i]
+            )
 
             ops = self._stage.get_fwd_send_ops(i)
             works = _sorted_batch_p2p(ops, desc="fwd_send")
@@ -577,7 +579,9 @@ def _step_microbatches(
                 recv_work.wait()
 
             # Compute
-            output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])  # type: ignore[index]
+            output = self._stage.forward_one_chunk(
+                fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]
+            )
 
             # Clear previous chunk's forward sends (hopefully they have well
             # finished, otherwise, we are heavily communication bound, in which
@@ -639,7 +643,9 @@ def _step_microbatches(
                 fuse_work.wait()
 
             # Now do the fwd
-            output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])  # type: ignore[index]
+            output = self._stage.forward_one_chunk(
+                fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]
+            )
 
             # Compute loss
             self._maybe_compute_loss(

python/paddle/distributed/auto_parallel/ring_attention.py

Lines changed: 3 additions & 2 deletions
@@ -161,8 +161,9 @@ def update_out_and_lse(
         old_lse[:, old_lse.shape[1] // 2 :, :, :] = second_chunk_lse
         return old_out, old_lse
     else:
-        block_out, block_lse = paddle.cast(block_out, "float32"), paddle.cast(
-            block_lse, "float32"
+        block_out, block_lse = (
+            paddle.cast(block_out, "float32"),
+            paddle.cast(block_lse, "float32"),
         )
         with paddle.amp.auto_cast(enable=False):
             return old_out - (old_out - block_out) * F.sigmoid(

python/paddle/distributed/auto_parallel/static/converter.py

Lines changed: 2 additions & 4 deletions
@@ -61,8 +61,7 @@ def _check_tensor_dict(self, tensors_dict):
     def _check_pre_strategy(self, pre_strategy):
         if not pre_strategy:
             raise ValueError(
-                "'pre_strategy' is None, "
-                "there are not tensors in pre process."
+                "'pre_strategy' is None, there are not tensors in pre process."
             )
         if not isinstance(pre_strategy, dict):
             raise TypeError(
@@ -74,8 +73,7 @@ def _check_pre_strategy(self, pre_strategy):
     def _check_cur_strategy(self, cur_strategy):
         if not cur_strategy:
             warnings.warn(
-                "'cur_strategy' is None, "
-                "there are not tensors in cur process"
+                "'cur_strategy' is None, there are not tensors in cur process"
             )
         if not isinstance(cur_strategy, dict):
             raise TypeError(

python/paddle/distributed/auto_parallel/static/cost/op_runtime_cost.py

Lines changed: 7 additions & 7 deletions
@@ -91,9 +91,10 @@ def _filter_vars_with_zero_in_degree_and_ignore_feed_fetch_vars():
             # ignore communication op from graph, because sometimes we want to profile a sub-graph
             # and these dangling operators will not work (no graph to communicate to/from)
             continue
-        input_var_names, output_var_names = _collect_op_input_var_names(
-            op
-        ), _collect_op_output_var_names(op)
+        input_var_names, output_var_names = (
+            _collect_op_input_var_names(op),
+            _collect_op_output_var_names(op),
+        )
         for var_name in input_var_names + output_var_names:
             if var_name not in var_in_degree:
                 var_in_degree[var_name] = 0
@@ -280,10 +281,9 @@ def measure_program_real_op_cost(
         isinstance(place, supported_place)
         for supported_place in supported_places
     ), f'Current place ({place}) does not support runtime profiling. "place" should be one of the following: {supported_places}.'
-    assert isinstance(run_iters, int) and run_iters >= 1, (
-        'Invalid parameter run_iters set. run_iters '
-        'should be an integer >= 1.'
-    )
+    assert (
+        isinstance(run_iters, int) and run_iters >= 1
+    ), 'Invalid parameter run_iters set. run_iters should be an integer >= 1.'
     if run_iters == 1:
         warnings.warn(
             'run_iters was set to 1, profiling results might be inaccurate due to outliers.'
