4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -75,7 +75,7 @@ repos:

# | python/paddle/[a-c].+

-# | python/paddle/de.+
+| python/paddle/de.+

# | python/paddle/distributed/a.+

@@ -131,7 +131,7 @@ repos:

| python/paddle/[a-c].+

-| python/paddle/de.+
+# | python/paddle/de.+

| python/paddle/distributed/a.+

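Taken together, the two config hunks above move the `python/paddle/de.+` pattern from one hook's file list to the other: it is uncommented in the block at line 75 and commented out in the block at line 131, which is what brings the `python/paddle/decomposition` and `python/paddle/device` files below under the updated formatter. A quick sanity check of what that regex selects — the paths are illustrative stand-ins, and pre-commit's own matching semantics may differ slightly from plain `re.match`:

```python
import re

# The pattern moved between the two hook sections above.
pattern = re.compile(r"python/paddle/de.+")

# Illustrative repository paths, not an exhaustive list.
paths = [
    "python/paddle/decomposition/decomp.py",  # "de..." -> matched
    "python/paddle/device/cuda/graphs.py",    # "de..." -> matched
    "python/paddle/distributed/__init__.py",  # "di...", not "de" -> skipped
]
for p in paths:
    print(p, "->", "matched" if pattern.match(p) else "skipped")
```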
42 changes: 21 additions & 21 deletions python/paddle/decomposition/decomp.py
@@ -182,16 +182,16 @@ def _check_op_results(
f'when replace origin op {op_name} with composite rule, origin out dtype should be equal to new out dtype, '
f'but orig_out dtype={orig_dtype} and new_out dtype={new_dtype}'
)
-assert (
-    -1 not in new_shape
-), f'when replace origin op {op_name} with composite rule, composite out shape has -1.'
+assert -1 not in new_shape, (
+    f'when replace origin op {op_name} with composite rule, composite out shape has -1.'
+)
assert orig_shape == new_shape, (
f'when replace origin op {op_name} with composite rule, origin out shape should be equal to new out shape, '
f'but orig_out shape={orig_shape} and new_out shape={new_shape}'
)
-assert not (orig_out is None) ^ (
-    new_out is None
-), "orig_out and new_out should match."
+assert not (orig_out is None) ^ (new_out is None), (
+    "orig_out and new_out should match."
+)
return


@@ -261,9 +261,9 @@ def _check_op(

bwd_op_input_names = bwd_op.get_input_names()
bwd_inputs = [x.source() for x in bwd_op.operands()]
-assert len(bwd_op_input_names) == len(
-    bwd_inputs
-), "backward op names do not match backward op inputs"
+assert len(bwd_op_input_names) == len(bwd_inputs), (
+    "backward op names do not match backward op inputs"
+)
fwd_op_related_inputs_outputs = []
for idx, name in enumerate(bwd_op_input_names):
if "_grad" not in name:
@@ -417,14 +417,14 @@ def _prepare_grad_outputs(fwd_op, bwd_op):
# check forward outputs and backward inputs
fwd_outputs = fwd_op.results()
fwd_output_names = fwd_op.get_output_names()
-assert len(fwd_output_names) == len(
-    fwd_outputs
-), "forward op output names do not match forward op outputs"
+assert len(fwd_output_names) == len(fwd_outputs), (
+    "forward op output names do not match forward op outputs"
+)
bwd_inputs = [x.source() for x in bwd_op.operands()]
bwd_input_names = bwd_op.get_input_names()
-assert len(bwd_input_names) == len(
-    bwd_inputs
-), "backward op input names do not match backward op inputs"
+assert len(bwd_input_names) == len(bwd_inputs), (
+    "backward op input names do not match backward op inputs"
+)

# cut gradients from backward op's inputs
fwd_inputs = [x.source() for x in fwd_op.operands()]
@@ -541,9 +541,9 @@ def _decomp_bwd_with_vjp(
res.append(grad_input[0])
else:
res.append(pir.fake_value())
-assert len(res) == len(
-    bwd_op.results()
-), "results of original backward op do not match results of decomposed backward op"
+assert len(res) == len(bwd_op.results()), (
+    "results of original backward op do not match results of decomposed backward op"
+)

# step4: upgrade grad_var_to_var
_upgrade_grad_var_to_var(
@@ -735,9 +735,9 @@ def _set_prim_state():


def _reset_prim_state(state):
-assert (
-    len(state) == 3
-), "state should contain fwd_prim_state, bwd_prim_state and pir_api_state"
+assert len(state) == 3, (
+    "state should contain fwd_prim_state, bwd_prim_state and pir_api_state"
+)
core._set_prim_forward_enabled(state[0])
core._set_prim_backward_enabled(state[1])
paddle.framework.set_flags({"FLAGS_enable_pir_api": state[2]})
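Every hunk in this file is the same mechanical, behavior-preserving rewrite — presumably the output of a newer formatter version pinned via `.pre-commit-config.yaml`: instead of parenthesizing the assert condition and letting the message dangle after the closing parenthesis, the condition now stays on the `assert` line and the long message is wrapped in parentheses. A minimal sketch of the before/after shape (`device_id` is a stand-in value so the snippet runs):

```python
device_id = 0  # stand-in value

# Old wrapping: the condition is split and the message dangles.
assert (
    device_id >= 0
), f"The device id must be not less than 0, but got id = {device_id}."

# New wrapping: the condition stays inline; the message is wrapped.
assert device_id >= 0, (
    f"The device id must be not less than 0, but got id = {device_id}."
)
```

The two spellings are semantically identical — the parentheses are pure grouping — so these hunks change formatting only.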
12 changes: 6 additions & 6 deletions python/paddle/decomposition/recompute.py
@@ -243,13 +243,13 @@ def _get_downstream_ops_recursively(cur):
return downstream_unrecomputable_ops

for op in self.ops:
-self.upstream_unrecomputable_ops_map[
-    op
-] |= _get_upstream_ops_recursively(op)
+self.upstream_unrecomputable_ops_map[op] |= (
+    _get_upstream_ops_recursively(op)
+)
for op in reversed(self.ops):
-self.downstream_unrecomputable_ops_map[
-    op
-] |= _get_downstream_ops_recursively(op)
+self.downstream_unrecomputable_ops_map[op] |= (
+    _get_downstream_ops_recursively(op)
+)

def _has_unfusible_op_on_any_path(self, op1, op2):
no_unfusible_op_on_path = (
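The same formatter update also reflows over-long augmented assignments: the old style split the subscripted target across lines, while the new style keeps the target intact and parenthesizes the right-hand side. A self-contained sketch — `upstream_map` and `get_upstream` are stand-ins for the `upstream_unrecomputable_ops_map` / `_get_upstream_ops_recursively` pair above, not Paddle APIs:

```python
# Map each op to the set of ops known to block its recomputation.
upstream_map: dict[str, set[str]] = {"matmul_1": set()}

def get_upstream(op: str) -> set[str]:
    # Stand-in for the recursive upstream walk in the real code.
    return {"producer_of_" + op}

# Old wrapping: the subscript target is split across lines.
upstream_map[
    "matmul_1"
] |= get_upstream("matmul_1")

# New wrapping: the target stays whole; the RHS gets the parentheses.
upstream_map["matmul_1"] |= (
    get_upstream("matmul_1")
)
```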
6 changes: 3 additions & 3 deletions python/paddle/decomposition/register.py
@@ -26,9 +26,9 @@ def __init__(self, name):
def register(self, op_type, rule):
assert isinstance(op_type, str)
assert inspect.isfunction(rule)
-assert (
-    op_type not in self.rules
-), f'name "{op_type}" should not be registered before.'
+assert op_type not in self.rules, (
+    f'name "{op_type}" should not be registered before.'
+)
self.rules[op_type] = rule

def lookup(self, op_type):
18 changes: 9 additions & 9 deletions python/paddle/device/__init__.py
@@ -683,18 +683,18 @@ def extract_device_id(device: _CustomPlaceLike, op_name: str) -> int:
"Please input appropriate device again!"
)

-assert (
-    device_id >= 0
-), f"The device id must be not less than 0, but got id = {device_id}."
+assert device_id >= 0, (
+    f"The device id must be not less than 0, but got id = {device_id}."
+)

if core.is_compiled_with_cuda():
-assert (
-    device_id < device_count()
-), f"The device id {device_id} exceeds gpu card number {device_count()}"
+assert device_id < device_count(), (
+    f"The device id {device_id} exceeds gpu card number {device_count()}"
+)
else:
-assert device_id < core.get_custom_device_count(
-    device_type
-), f"The device id {device_id} exceeds {device_type} device card number {core.get_custom_device_count(device_type)}"
+assert device_id < core.get_custom_device_count(device_type), (
+    f"The device id {device_id} exceeds {device_type} device card number {core.get_custom_device_count(device_type)}"
+)
return device_id


18 changes: 9 additions & 9 deletions python/paddle/device/cuda/__init__.py
@@ -253,18 +253,18 @@ def extract_cuda_device_id(device: _CudaPlaceLike, op_name: str) -> int:
"Please input appropriate device again!"
)

-assert (
-    device_id >= 0
-), f"The device id must be not less than 0, but got id = {device_id}."
+assert device_id >= 0, (
+    f"The device id must be not less than 0, but got id = {device_id}."
+)

if core.is_compiled_with_cuda():
-assert (
-    device_id < device_count()
-), f"The device id {device_id} exceeds gpu card number {device_count()}"
+assert device_id < device_count(), (
+    f"The device id {device_id} exceeds gpu card number {device_count()}"
+)
else:
-assert device_id < core.get_custom_device_count(
-    device_type
-), f"The device id {device_id} exceeds {device_type} device card number {core.get_custom_device_count(device_type)}"
+assert device_id < core.get_custom_device_count(device_type), (
+    f"The device id {device_id} exceeds {device_type} device card number {core.get_custom_device_count(device_type)}"
+)
return device_id


66 changes: 33 additions & 33 deletions python/paddle/device/cuda/graphs.py
@@ -42,9 +42,9 @@ def is_cuda_graph_supported():

class CUDAGraph:
def __init__(self, place=None, mode="thread_local", pool_id=None):
-assert (
-    CoreCUDAGraph is not None
-), "CUDA Graph is only supported on PaddlePaddle compiled with NVIDIA GPU."
+assert CoreCUDAGraph is not None, (
+    "CUDA Graph is only supported on PaddlePaddle compiled with NVIDIA GPU."
+)

self._graph = None
if place is None:
@@ -73,9 +73,9 @@ def print_to_dot_files(self, dirname, flags=None):
if not isinstance(dirname, (str, bytes)):
dirname = dirname.name
os.makedirs(name=dirname, exist_ok=True)
-assert os.path.isdir(
-    dirname
-), f"The dirname {dirname} should be a directory"
+assert os.path.isdir(dirname), (
+    f"The dirname {dirname} should be a directory"
+)
if flags is None:
flags = 2047 # only all information. It can be any integer inside [1, 2048)
self._graph.print_to_dot_files(dirname, flags)
@@ -238,16 +238,16 @@ def get_cuda_graph_sections(program):

for idx, op in enumerate(block.ops):
if op.type == 'conditional_block' or op.type == 'while':
-assert (
-    op._cuda_graph_attr is None
-), "Cuda graph not support conditional block op and while op."
+assert op._cuda_graph_attr is None, (
+    "Cuda graph not support conditional block op and while op."
+)
if op.has_attr('is_test') and op.attr('is_test'):
is_test = True
# find cuda graph sections
if op._cuda_graph_attr is not None:
-assert isinstance(
-    op._cuda_graph_attr, str
-), "cuda_graph_attr should be a str"
+assert isinstance(op._cuda_graph_attr, str), (
+    "cuda_graph_attr should be a str"
+)
cuda_graph_attrs = op._cuda_graph_attr.split(';')
assert len(cuda_graph_attrs) == 3, (
"cuda graph attr should have three fields: "
@@ -256,9 +256,9 @@
local_cuda_graph_id = int(cuda_graph_attrs[2])
if local_cuda_graph_id == current_cuda_graph_id:
if len(internal_section) > 0:
-assert len(internal_section) == len(
-    internal_idx
-), "len of internal section should be equal with len of internal idx"
+assert len(internal_section) == len(internal_idx), (
+    "len of internal section should be equal with len of internal idx"
+)
for internal_op in internal_section:
loss_related = (
int(internal_op.attr(op_role_attr_name))
@@ -283,9 +283,9 @@
internal_section = []
internal_idx = []
# Beside clear the internal section, a new cuda graph section should be recorded
-assert len(current_section) == len(
-    current_idx
-), "num of section's op is not equal with the idx"
+assert len(current_section) == len(current_idx), (
+    "num of section's op is not equal with the idx"
+)
if len(current_section) > 0:
# store previous section
cuda_graph_sections.append(current_section)
@@ -309,9 +309,9 @@
current_cuda_graph_id = (
local_cuda_graph_id # start record a new section
)
-assert len(current_section) == len(
-    current_idx
-), "num of section's op is not equal with num of idx"
+assert len(current_section) == len(current_idx), (
+    "num of section's op is not equal with num of idx"
+)
if len(current_section) > 0:
# store previous section
cuda_graph_sections.append(current_section)
@@ -324,9 +324,9 @@
internal_idx.append(idx)

# handle the last section
-assert len(current_section) == len(
-    current_idx
-), "num of section's op is not equal with num of idx"
+assert len(current_section) == len(current_idx), (
+    "num of section's op is not equal with num of idx"
+)
if len(current_section) > 0:
# store previous section
cuda_graph_sections.append(current_section)
@@ -377,9 +377,9 @@ def replace_cuda_graph_section(
memory_pool_id = int(attrs[1])
break

-assert (
-    mode is not None and memory_pool_id is not None
-), "mode and memory pool id should be specified in cuda graph attr"
+assert mode is not None and memory_pool_id is not None, (
+    "mode and memory pool id should be specified in cuda graph attr"
+)

cuda_graph_var = origin_block.create_var(
name="cuda_graph_" + str(order),
@@ -445,9 +445,9 @@ def cuda_graph_transform(program):
cuda_graph_sections, sections_idx, is_test = get_cuda_graph_sections(
program
)
-assert len(cuda_graph_sections) == len(
-    sections_idx
-), "num of cuda graph sections is not equal with num of idx sections"
+assert len(cuda_graph_sections) == len(sections_idx), (
+    "num of cuda graph sections is not equal with num of idx sections"
+)

# step 2: construct new program for each section and find inputs and outputs of each section.
# The inputs are variables generated outside the section but will be used by this section.
@@ -461,9 +461,9 @@
)
ins_and_outs.append(ins_outs)
section_programs.append(section_program)
-assert len(section_programs) == len(
-    cuda_graph_sections
-), "the num of cuda graph sections should be equal with the num of new program"
+assert len(section_programs) == len(cuda_graph_sections), (
+    "the num of cuda graph sections should be equal with the num of new program"
+)

# step 3: replace the ops in original program with run_program_op.
# Will remove all ops in the section from origin program, and use run_program_op to replace them.
12 changes: 6 additions & 6 deletions python/paddle/device/xpu/__init__.py
@@ -121,12 +121,12 @@ def extract_xpu_device_id(device: _XPUPlaceLike, op_name: str) -> int:
"Please input appropriate device again!"
)

-assert (
-    device_id >= 0
-), f"The device id must be not less than 0, but got id = {device_id}."
-assert (
-    device_id < device_count()
-), f"The device id {device_id} exceeds xpu card number {device_count()}"
+assert device_id >= 0, (
+    f"The device id must be not less than 0, but got id = {device_id}."
+)
+assert device_id < device_count(), (
+    f"The device id {device_id} exceeds xpu card number {device_count()}"
+)
return device_id

