4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -83,7 +83,7 @@ repos:

# | python/paddle/distributed/f.+

# | python/paddle/distributed/[g-z].+
| python/paddle/distributed/[g-z].+

# | python/paddle/[e-i].+

@@ -139,7 +139,7 @@ repos:

| python/paddle/distributed/f.+

| python/paddle/distributed/[g-z].+
# | python/paddle/distributed/[g-z].+

| python/paddle/[e-i].+

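The two .pre-commit-config.yaml hunks above swap which tool's file list covers python/paddle/distributed/[g-z].+ (the pattern is uncommented in the hunk at line 83 and commented out in the hunk at line 139); the remaining hunks in this PR are the resulting reflow of assert statements. A minimal, runnable sketch of that layout change; the variable name and message are illustrative, not taken from the repository:

num_workers = 4

# old layout: the condition is wrapped in parentheses and the message trails the closing paren
assert (
    num_workers > 0
), "num_workers must be positive"

# new layout: the condition stays on the assert line and the message is parenthesized on its own line
assert num_workers > 0, (
    "num_workers must be positive"
)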
12 changes: 6 additions & 6 deletions python/paddle/distributed/launch/controllers/controller.py
@@ -62,9 +62,9 @@ def __init__(self, ctx):
self.join_server = None

def deploy_pod(self):
assert (
len(self.pod.containers) + len(self.pod.init_containers) > 0
), "No container in the pod"
assert len(self.pod.containers) + len(self.pod.init_containers) > 0, (
"No container in the pod"
)

self.ctx.logger.info(f"Run {self.pod}")
if len(self.pod.init_containers) > 0:
@@ -309,9 +309,9 @@ def save_pod_log(self, info):
self.ctx.logger.error(f"save log failed because {e}")

def save_pod_env(self):
assert (
len(self.pod.containers) + len(self.pod.init_containers) > 0
), "No container in the pod"
assert len(self.pod.containers) + len(self.pod.init_containers) > 0, (
"No container in the pod"
)

if not self.ctx.args.log_dir:
return
12 changes: 6 additions & 6 deletions python/paddle/distributed/launch/controllers/ipu_controller.py
@@ -69,19 +69,19 @@ def replace_training_script(self):

num_ipus = int(self.ctx.args.devices)
# The number of replicas for data parallel
assert (
num_ipus % poprun_args.ipus_per_replica
) == 0, f"The number of IPUs:{num_ipus} mod the number of IPUs per replica:{poprun_args.ipus_per_replica} must == 0"
assert (num_ipus % poprun_args.ipus_per_replica) == 0, (
f"The number of IPUs:{num_ipus} mod the number of IPUs per replica:{poprun_args.ipus_per_replica} must == 0"
)
num_replicas = num_ipus // poprun_args.ipus_per_replica
self.ctx.logger.info(f"The number of total replicas is {num_replicas}.")

# The number of processes
num_nodes = len(poprun_args.hosts.split(','))
num_procs = num_nodes * poprun_args.nproc_per_host
self.ctx.logger.info(f"The number of total processes is {num_procs}.")
assert (
num_replicas % num_procs
) == 0, f"The number of replicas:{num_replicas} mod the number of processes:{num_procs} must == 0"
assert (num_replicas % num_procs) == 0, (
f"The number of replicas:{num_replicas} mod the number of processes:{num_procs} must == 0"
)

# hosts and endpoints
hosts = poprun_args.hosts.replace(' ', '').split(',')
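A quick numeric walk-through of the two divisibility checks in the ipu_controller.py hunk above; all values here are assumed for illustration and do not come from the diff:

# assumed example: 16 IPUs, 4 IPUs per replica, 2 hosts, 2 processes per host
num_ipus = 16
ipus_per_replica = 4
hosts = "host-a,host-b"
nproc_per_host = 2

assert num_ipus % ipus_per_replica == 0, (
    f"The number of IPUs:{num_ipus} mod the number of IPUs per replica:{ipus_per_replica} must == 0"
)
num_replicas = num_ipus // ipus_per_replica  # 16 // 4 == 4 data-parallel replicas

num_nodes = len(hosts.split(','))  # 2 nodes
num_procs = num_nodes * nproc_per_host  # 2 * 2 == 4 processes in total

assert num_replicas % num_procs == 0, (
    f"The number of replicas:{num_replicas} mod the number of processes:{num_procs} must == 0"
)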
6 changes: 3 additions & 3 deletions python/paddle/distributed/launch/controllers/rpc.py
@@ -27,9 +27,9 @@ def enable(cls, ctx):
return False

def build_pod(self):
assert (
self.ctx.args.master is not None
), "Master is None, Please set master address!"
assert self.ctx.args.master is not None, (
"Master is None, Please set master address!"
)
self._build_pod_with_master()

def _build_pod_with_master(self):
6 changes: 3 additions & 3 deletions python/paddle/distributed/launch/job/container.py
@@ -94,9 +94,9 @@ def update_env(self, env={}, **kwargs):

def _validate_env(self):
for k, v in self._env.items():
assert isinstance(k, str) and isinstance(
v, str
), f'env {k}:{v} must be str'
assert isinstance(k, str) and isinstance(v, str), (
f'env {k}:{v} must be str'
)

def _get_fd(self, pth):
if not pth:
18 changes: 9 additions & 9 deletions python/paddle/distributed/parallel.py
@@ -391,9 +391,9 @@ def __init__(
) -> None:
super().__init__(layers.full_name() + "_data_parallel")

assert (
in_dynamic_mode()
), "It's not supported to construct DataParallel in static graph mode."
assert in_dynamic_mode(), (
"It's not supported to construct DataParallel in static graph mode."
)

self._layers = layers
self.find_unused_parameters = find_unused_parameters
@@ -756,12 +756,12 @@ def __init__(self):
).split(",")
self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT", "")
self._nrings = int(os.getenv("FLAGS_nccl_nrings", "1"))
assert (
self._nrings > 0
), "nccl_nrings must be an integer greater than 0."
assert (
self._nrings < 9
), "nccl_nrings should be less than 9, which is enough in most scenarios."
assert self._nrings > 0, (
"nccl_nrings must be an integer greater than 0."
)
assert self._nrings < 9, (
"nccl_nrings should be less than 9, which is enough in most scenarios."
)

@property
def rank(self) -> int:
12 changes: 6 additions & 6 deletions python/paddle/distributed/parallel_helper.py
@@ -33,17 +33,17 @@ def _is_parallel_ctx_initialized():

def _set_parallel_ctx(ccl_parallel_context):
global __parallel_ctx__clz__
assert (
__parallel_ctx__clz__ is None
), "ParallelContext can only be initialized once."
assert __parallel_ctx__clz__ is None, (
"ParallelContext can only be initialized once."
)
__parallel_ctx__clz__ = ccl_parallel_context


def _init_parallel_ctx():
global __parallel_ctx__clz__
assert (
__parallel_ctx__clz__ is not None
), "ParallelContext should be initialized."
assert __parallel_ctx__clz__ is not None, (
"ParallelContext should be initialized."
)
__parallel_ctx__clz__.init()


6 changes: 3 additions & 3 deletions python/paddle/distributed/parallel_with_gloo.py
@@ -96,9 +96,9 @@ def gloo_init_parallel_env(
... test_gloo_init_with_multiprocess(2)
"""

assert (
rank_num < 2
) is False, "rank_num should greater than or equal to 2 for parallel environment initialization."
assert (rank_num < 2) is False, (
"rank_num should greater than or equal to 2 for parallel environment initialization."
)

# init gloo context
manager = Manager()
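The condition in the gloo_init_parallel_env hunk above is an indirect spelling of rank_num >= 2; a tiny self-contained illustration with an assumed value:

rank_num = 2  # assumed value; normally supplied by the caller

# the form used in the source: (rank_num < 2) is False
assert (rank_num < 2) is False, (
    "rank_num should greater than or equal to 2 for parallel environment initialization."
)

# equivalent, more direct reading of the same requirement
assert rank_num >= 2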
24 changes: 12 additions & 12 deletions python/paddle/distributed/passes/auto_parallel_amp.py
@@ -340,9 +340,9 @@ def _cast_block(self, block):
out_var = block.var(out_var_name)
in_var = block._find_var_recursive(in_var_name)
for in_var_name in op.input_arg_names:
assert (
in_var.dtype == block.var(in_var_name).dtype
), f"{in_var}, {block.var(in_var_name)}, {op}"
assert in_var.dtype == block.var(in_var_name).dtype, (
f"{in_var}, {block.var(in_var_name)}, {op}"
)
out_var.desc.set_dtype(in_var.dtype)
elif int(op.attr('op_role')) == 257:
pass
@@ -545,9 +545,9 @@ def _keep_fp32_output(op, out_name):
cast_name, in_var_dist_attr
)
else:
assert (
in_var.dtype == dst_dtype
), f"op [{op.type}] expect input [{in_name}] to be dtype [{dst_dtype}] BUT got [{in_var.dtype}]. {op}"
assert in_var.dtype == dst_dtype, (
f"op [{op.type}] expect input [{in_name}] to be dtype [{dst_dtype}] BUT got [{in_var.dtype}]. {op}"
)

for out_name in op.output_names:
if src_dtype == paddle.float32 and _keep_fp32_output(op, out_name):
@@ -1158,13 +1158,13 @@ def _update_loss_scaling(self, grads, found_inf):
e, "x", ['float16', 'float32', 'float64'], 'update_loss_scaling'
)
if e.dtype == paddle.float16:
assert (
self._loss_scaling.dtype == paddle.float32
), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
assert self._loss_scaling.dtype == paddle.float32, (
"The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
)
else:
assert (
self._loss_scaling.dtype == e.dtype
), "The dtype of prev_loss_scaling should be equal to the dtype of x."
assert self._loss_scaling.dtype == e.dtype, (
"The dtype of prev_loss_scaling should be equal to the dtype of x."
)

inputs = {
'X': grads,
6 changes: 3 additions & 3 deletions python/paddle/distributed/passes/auto_parallel_c_embedding.py
@@ -173,9 +173,9 @@ def _update_before_dims_mapping(self, new_op):
results.append(dist_attr_new)
sub_name = op.name().split('.')[1]
if op.num_operands() > 0:
assert (
sub_name != "cast"
), "Need to add support for {sub_name}."
assert sub_name != "cast", (
"Need to add support for {sub_name}."
)
operands.append(dist_attr_new)
next_op = op.operand(0).source().get_defining_op()
stack.append(next_op)
@@ -150,14 +150,14 @@ def _analyze_program(self):
grad_name = op.output_arg_names[0]
if grad_name in self._grad_name_to_group_map:
continue
assert op.has_attr(
"ring_id"
), f"Unexpected: comm op [{op}] has NOT ring id."
assert op.has_attr("ring_id"), (
f"Unexpected: comm op [{op}] has NOT ring id."
)
group = ring_id_to_process_group(op.attr("ring_id"))

assert (
group is not None
), f"Unexpected: data parallel group of [{grad_name}] from op [{op}] is None"
assert group is not None, (
f"Unexpected: data parallel group of [{grad_name}] from op [{op}] is None"
)

self._grad_name_to_group_map[grad_name] = group

@@ -182,9 +182,9 @@ def _analyze_program(self):
for grad_name in scaled_grads:
if grad_name not in self._grad_name_to_group_map:
not_synchronized_grads.append(grad_name)
assert (
len(not_synchronized_grads) == 0
), f"Unexpected: gradients [{not_synchronized_grads}] is scaled BUT NOT synchronized."
assert len(not_synchronized_grads) == 0, (
f"Unexpected: gradients [{not_synchronized_grads}] is scaled BUT NOT synchronized."
)

def is_data_parallel_applied(self):
return len(self._group_to_grad_name_map) > 0
@@ -239,12 +239,12 @@ def _update_opt_rescale_grad(self):
is_optimize_op(op)
and op.type in __rescale_grad_supported_opts__
):
assert op.has_attr(
'rescale_grad'
), f"Unexpected: op [{op}] is supported to have [rescale_grad] attribute."
assert (
len(op.input("Grad")) == 1
), f"Unexpected: op [{op}] is supported to have only one input grad var."
assert op.has_attr('rescale_grad'), (
f"Unexpected: op [{op}] is supported to have [rescale_grad] attribute."
)
assert len(op.input("Grad")) == 1, (
f"Unexpected: op [{op}] is supported to have only one input grad var."
)

grad_name = op.input("Grad")[0]
dp_degree = len(
@@ -255,9 +255,9 @@ def _update_opt_rescale_grad(self):
rescale_grad = float(op.attr('rescale_grad')) / dp_degree
op._set_attr('rescale_grad', rescale_grad)

assert scaled_grads == set(
self._grad_name_to_group_map.keys()
), f"Unexpected: gradients [{set(self._grad_name_to_group_map.keys()) - scaled_grads}] are unscaled."
assert scaled_grads == set(self._grad_name_to_group_map.keys()), (
f"Unexpected: gradients [{set(self._grad_name_to_group_map.keys()) - scaled_grads}] are unscaled."
)

def _could_be_overlap(self):
# NOTE current different nccl comm will use different cuda stream
@@ -478,9 +478,9 @@ def _update_program(self, grad_groups):
# update allreduce & scale op
if group.scale_op_idx != -1:
scale_op = block.ops[group.scale_op_idx]
assert (
scale_op.type == 'scale'
), f"should found scale op but found {scale_op}"
assert scale_op.type == 'scale', (
f"should found scale op but found {scale_op}"
)
scale_op._rename_input(
scale_op.input_arg_names[0], group.coalesce_var.name
)
@@ -524,9 +524,9 @@ def _update_program(self, grad_groups):
+ group.remove_scale_op_indices
)
for idx in sorted(remove_op_indices, reverse=True):
assert (
block.ops[idx].type in remove_op_types
), f"Unexpected: try to remove op {block.ops[idx]}"
assert block.ops[idx].type in remove_op_types, (
f"Unexpected: try to remove op {block.ops[idx]}"
)
block._remove_op(idx, False)

# insert coalesce op
@@ -753,9 +753,9 @@ def add(self, grad_var, ring_id, i):
grad_op_idx -= 1

grad_op = self.ops[grad_op_idx]
assert (
grad_var.name in grad_op.output_arg_names
), f"grad [{grad_var.name}] should be output of {grad_op}"
assert grad_var.name in grad_op.output_arg_names, (
f"grad [{grad_var.name}] should be output of {grad_op}"
)
self.coalesce_op_idx = grad_op_idx

def finalize(self):
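A back-of-the-envelope sketch of the rescale_grad adjustment shown in the _update_opt_rescale_grad hunks above, assuming a data-parallel degree of 8; the numbers are illustrative:

dp_degree = 8  # assumed size of the gradient's data-parallel process group
rescale_grad = 1.0  # assumed original 'rescale_grad' attribute on the optimizer op

# the pass divides the optimizer op's rescale_grad by the data-parallel degree
rescale_grad = float(rescale_grad) / dp_degree

assert rescale_grad == 0.125, (
    "rescale_grad should equal the original value divided by dp_degree"
)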