4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -77,7 +77,7 @@ repos:

# | python/paddle/de.+

# | python/paddle/distributed/a.+
| python/paddle/distributed/a.+

# | python/paddle/distributed/[b-e].+

@@ -133,7 +133,7 @@ repos:

| python/paddle/de.+

| python/paddle/distributed/a.+
# | python/paddle/distributed/a.+

| python/paddle/distributed/[b-e].+

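The two `.pre-commit-config.yaml` hunks above toggle the same path fragment in two places: `| python/paddle/distributed/a.+` is uncommented in the pattern around line 77 and commented out in the pattern around line 133. Read together, this looks like moving files under `python/paddle/distributed/a*` from one hook's verbose-regex `files:` pattern to another's, which is consistent with the assert reformatting those files receive in the rest of this diff. Below is a minimal sketch, not taken from the actual config, of how such a verbose regex selects paths; the pattern body mirrors the visible context, everything else is illustrative.

import re

# Sketch of a pre-commit `files:` verbose regex. Uncommenting the
# `python/paddle/distributed/a.+` alternative opts those paths into the hook;
# commenting it out (as in the second hunk) removes them from the match.
FILES_PATTERN = re.compile(
    r"""(?x)^(
        python/paddle/de.+
      | python/paddle/distributed/a.+
    )$"""
)

# Matches: the hook would run on this file.
print(bool(FILES_PATTERN.match("python/paddle/distributed/auto_parallel/api.py")))  # True
# No match: paths outside the listed prefixes are skipped.
print(bool(FILES_PATTERN.match("python/paddle/distributed/fleet/utils.py")))  # False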
104 changes: 57 additions & 47 deletions python/paddle/distributed/auto_parallel/api.py
@@ -298,18 +298,18 @@ def shard_tensor(
stop_gradient = getattr(data, "stop_gradient", True)

if paddle.framework.in_pir_mode():
assert isinstance(
data, (type(None), pir.Value)
), "input tensor is not pir value."
assert (
data.is_dense_tensor_type()
), "shard_tensor() input data only supported dense tensor type right."
assert isinstance(data, (type(None), pir.Value)), (
"input tensor is not pir value."
)
assert data.is_dense_tensor_type(), (
"shard_tensor() input data only supported dense tensor type right."
)
tensor = data
else:
if isinstance(data, EagerParamBase) and not data._is_initialized():
assert (
data._init_func is not None
), "Get an uninitialized param with an unregistered init_func."
assert data._init_func is not None, (
"Get an uninitialized param with an unregistered init_func."
)
tensor = data
elif isinstance(data, paddle.Tensor) and dtype is None:
# if place is not equal, it is handled in paddle.Tensor()
@@ -620,7 +620,9 @@ def forward(
)
assert check_placements_equal(
global_placements, dist_tensor.placements
), f"the global_placements ({global_placements}) is not equal to dist_tensor's placements ({dist_tensor.placements})."
), (
f"the global_placements ({global_placements}) is not equal to dist_tensor's placements ({dist_tensor.placements})."
)
local_shape = _cal_local_shape(
dist_tensor.shape, global_mesh, global_placements
)
@@ -890,9 +892,9 @@ def reshard(
elif in_pir_mode():
return paddle._C_ops.reshard(dist_tensor, mesh, placements)
else:
assert isinstance(
dist_tensor, Variable
), f"in dy2static mode, reshard's input should be Variable, but got [{dist_tensor}]"
assert isinstance(dist_tensor, Variable), (
f"in dy2static mode, reshard's input should be Variable, but got [{dist_tensor}]"
)
sharding_specs = get_shard_spec(mesh, placements, dist_tensor.ndim)
main_program = default_main_program()
default_dist_ctx = get_default_distributed_context()
@@ -1113,12 +1115,14 @@ def is_dist_tensor(tensor) -> bool:

class _ShardOptimizer(Optimizer):
def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1):
assert (
optimizer is not None
), "The argument `optimizer` cannot be empty."
assert optimizer is not None, (
"The argument `optimizer` cannot be empty."
)
assert isinstance(
optimizer, (paddle.optimizer.AdamW, paddle.optimizer.SGD)
), "`paddle.distributed.ShardOptimizer` only supports AdamW and SGD optimizer for now."
), (
"`paddle.distributed.ShardOptimizer` only supports AdamW and SGD optimizer for now."
)

# self.target_block = (
# paddle.base.framework.default_main_program().global_block()
@@ -1146,7 +1150,9 @@ def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1):
assert isinstance(
self._shard_fn,
(_ShardingStage0, ShardingStage1, ShardingStage2, ShardingStage3),
), "shard_fn must be an instance of one of: _ShardingStage0, ShardingStage1, ShardingStage2, ShardingStage3"
), (
"shard_fn must be an instance of one of: _ShardingStage0, ShardingStage1, ShardingStage2, ShardingStage3"
)

if isinstance(
self._shard_fn, (ShardingStage1, ShardingStage2, ShardingStage3)
@@ -1219,7 +1225,9 @@ def _set_and_check_sharding_prop_from_param(self):
else:
assert (
mesh.dim_size(self._sharding_axis) == self._sharding_degree
), "The sharding degree of all parameters must be equal currently."
), (
"The sharding degree of all parameters must be equal currently."
)

def _shard_accumulator(self, param):
# Note (luchang): Some models may have parameters whose first dimension is 1,
@@ -1988,9 +1996,9 @@ def shard_master_weight(
)
if isinstance(master_weight, pir.Value):
data_op = master_weight.get_defining_op()
assert (
data_op.name() == "pd_op.data"
), "The master weight must be a result of data op."
assert data_op.name() == "pd_op.data", (
"The master weight must be a result of data op."
)
dim_map, partial_status = to_dim_map(
placements, len(master_weight.shape)
)
@@ -3254,9 +3262,9 @@ def state_dict(
suffix = _get_suffix(param, fused_param)
if suffix is not None:
value = dist_state_dict[param]
assert (
value.is_dist()
), f"key {param} value:{value} is not a dist tensor."
assert value.is_dist(), (
f"key {param} value:{value} is not a dist tensor."
)
mesh = value.process_mesh
placements = value.placements
if "_pow_acc" in suffix:
@@ -3328,12 +3336,12 @@ def build_distributed_tensor(local_tensor, dist_attr):
)
if not isinstance(local_tensor, paddle.Tensor):
local_tensor = paddle.Tensor(local_tensor)
assert isinstance(
local_tensor, paddle.Tensor
), f"local tensor:{local_tensor} type {type(local_tensor)} is not paddle.Tensor."
assert len(local_tensor.shape) == len(
dist_attr["dims_mapping"]
), f"local tensor shape {local_tensor.shape} not equal to dims_mapping shape {dist_attr['dims_mapping']}."
assert isinstance(local_tensor, paddle.Tensor), (
f"local tensor:{local_tensor} type {type(local_tensor)} is not paddle.Tensor."
)
assert len(local_tensor.shape) == len(dist_attr["dims_mapping"]), (
f"local tensor shape {local_tensor.shape} not equal to dims_mapping shape {dist_attr['dims_mapping']}."
)
global_shape = local_tensor.shape
mesh = ProcessMesh(
np.array(dist_attr["process_group"]).reshape(
@@ -3343,18 +3351,18 @@ def build_distributed_tensor(local_tensor, dist_attr):
)
placements = to_placements(dist_attr["dims_mapping"], mesh)
dist_tensor = dtensor_from_local(local_tensor, mesh, placements)
assert (
dist_tensor._local_value().shape == local_tensor.shape
), f"local tensor shape {dist_tensor._local_value().shape} not equal to local_tensor.shape:{local_tensor.shape}"
assert dist_tensor._local_value().shape == local_tensor.shape, (
f"local tensor shape {dist_tensor._local_value().shape} not equal to local_tensor.shape:{local_tensor.shape}"
)
paddle.assign(local_tensor, dist_tensor._local_value())
return dist_tensor

global_state_dict = {}
with paddle.base.dygraph.guard():
for var_name, tensor in local_state_dict.items():
assert (
var_name in dist_attrs
), f"var {var_name} not in dist attrs:{dist_attrs}."
assert var_name in dist_attrs, (
f"var {var_name} not in dist attrs:{dist_attrs}."
)
global_state_dict[var_name] = build_distributed_tensor(
tensor, dist_attrs[var_name]
)
@@ -3386,7 +3394,9 @@ def set_state_dict(self, state_dict: dict[str, Tensor]) -> None:
k
].process_mesh or check_placements_equal(
v.placements, cur_v.placements
), f"process_mesh:{v.process_mesh} != {cur_v.process_mesh} or placements:{v.placements} != {cur_v.placements} not match"
), (
f"process_mesh:{v.process_mesh} != {cur_v.process_mesh} or placements:{v.placements} != {cur_v.placements} not match"
)
param_name = (
self._structured_to_parameter_name[k]
if k in self._structured_to_parameter_name
@@ -3472,9 +3482,9 @@ def _get_shard_stage1_optimizer(self):
):
optimizer = optimizer._optimizer

assert isinstance(
optimizer, ShardingOptimizerStage1
), "The optimizer should be ShardingOptimizerStage1 when stage1 tensor fusion is enabled."
assert isinstance(optimizer, ShardingOptimizerStage1), (
"The optimizer should be ShardingOptimizerStage1 when stage1 tensor fusion is enabled."
)

return optimizer

@@ -3485,9 +3495,9 @@ def _convert_state_dict_tensor_fusion(self, state_dict, optimizer_function):
else False
)

assert (
enable_tensor_fusion
), "Can only convert state_dict when tensor fusion is enabled."
assert enable_tensor_fusion, (
"Can only convert state_dict when tensor fusion is enabled."
)
optimizer = self._get_shard_stage1_optimizer()
assert optimizer is not None, "The optimizer should not be None."

@@ -3690,9 +3700,9 @@ def to_static(
# Deduce sharding degree for static
# Note: Because limitation of architecture, we need to ensure that
# all parameters are sharded by the same mesh axis
assert (
sharding_degree is not None
), "Sharding degree can not be None."
assert sharding_degree is not None, (
"Sharding degree can not be None."
)

if isinstance(shard_fn, ShardingStage1):
strategy.sharding.enable = True
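Every change in `api.py` above, and in the three files that follow, applies the same mechanical rewrite to long `assert` statements: rather than parenthesizing the condition and letting the message trail the closing parenthesis, the condition stays on one line and the message is wrapped in parentheses. A minimal runnable sketch of the two layouts, reusing the optimizer check from the `_ShardOptimizer.__init__` hunk above (the wrapper functions are illustrative only):

# Old layout: the condition is parenthesized so it can span lines,
# and the message follows the closing parenthesis.
def check_old(optimizer):
    assert (
        optimizer is not None
    ), "The argument `optimizer` cannot be empty."

# New layout: the condition stays inline and the long message is
# parenthesized instead.
def check_new(optimizer):
    assert optimizer is not None, (
        "The argument `optimizer` cannot be empty."
    )

check_old(object())
check_new(object())

The two forms are functionally identical; only the implicit line continuation moves from the condition to the message, which keeps the asserted condition on a single, easier-to-scan line.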
6 changes: 3 additions & 3 deletions python/paddle/distributed/auto_parallel/auto_dp_utils.py
@@ -21,9 +21,9 @@

def _fake_replicate_grad_to_partial(grad, partial_axis):
new_placements = grad.placements
assert (
new_placements[partial_axis] == dist.Replicate()
), "when reshard fake replicated grad to partial, the partial axis of grad should be Replicate"
assert new_placements[partial_axis] == dist.Replicate(), (
"when reshard fake replicated grad to partial, the partial axis of grad should be Replicate"
)

new_placements[partial_axis] = dist.Partial(dist.ReduceType.kRedSum)

40 changes: 21 additions & 19 deletions python/paddle/distributed/auto_parallel/high_level_api.py
@@ -34,9 +34,9 @@ def __init__(self):

def cost_model(matched_programs, device_num, node_num):
# TODO(jeff41404): multi-node will be supported later
assert (
node_num == 1
), "we only support single node now, multi-node will be supported later"
assert node_num == 1, (
"we only support single node now, multi-node will be supported later"
)

# TODO(jeff41404): will evaluate the best combination of parallel strategies
# based on cost_model and return global_mesh, currently using pre-defined parallel strategy
@@ -224,7 +224,9 @@ def record_program_ops_post_hook(layer, inputs, outputs):
assert (
layer._op_recorder.start >= 0
and layer._op_recorder.is_valid is True
), f"{layer._full_name} has not recorded the start of the corresponding ops before"
), (
f"{layer._full_name} has not recorded the start of the corresponding ops before"
)
end = len(default_main_program().global_block().ops)
# some layers, such as rotary_embedding, will not add new ops to program
# assert end > layer._op_recorder.start, f"{layer._full_name} has not added new ops to the program"
@@ -754,19 +756,19 @@ def to_distributed(
for pattern_name, matched_patterns in results.items():
# process one pattern
pattern_ops_dist_infos = get_pattern(pattern_name).ops_dist_infos
assert (
pattern_ops_dist_infos is not None
), f"{pattern_name} does not contain ops_dist_infos, cannot reshard, please check"
assert pattern_ops_dist_infos is not None, (
f"{pattern_name} does not contain ops_dist_infos, cannot reshard, please check"
)
processed_patterns = []
for matched_pattern in matched_patterns:
# convert pattern_ops_dist_infos to program_ops_dist_infos
program_ops_dist_infos = {}
for pattern_ops_id, op_dist_info in pattern_ops_dist_infos.items():
program_ops_id = []
for pattern_op_id in pattern_ops_id:
assert (
pattern_op_id in matched_pattern.keys()
), f"please check ops_dist_infos of {pattern_name}, {pattern_op_id} not in matched_pattern: {matched_pattern.keys()}"
assert pattern_op_id in matched_pattern.keys(), (
f"please check ops_dist_infos of {pattern_name}, {pattern_op_id} not in matched_pattern: {matched_pattern.keys()}"
)
program_op_id = matched_pattern[pattern_op_id]
program_ops_id.append(program_op_id)
program_ops_dist_infos[tuple(program_ops_id)] = op_dist_info
@@ -789,9 +791,9 @@ def to_distributed(
if with_mp:
num_hidden_layers = len(matched_programs[DECODER_LAYER_NAME])
for pattern_name, processed_patterns in matched_programs.items():
assert (
len(processed_patterns) == num_hidden_layers
), "transformer patterns matched are incomplete"
assert len(processed_patterns) == num_hidden_layers, (
"transformer patterns matched are incomplete"
)
for idx, processed_pattern in enumerate(processed_patterns):
local_mesh = mesh
if with_pp:
@@ -801,9 +803,9 @@ def to_distributed(
local_mesh = mesh.get_mesh_with_dim("pp", pp_stage_id)

for program_ops_id, dist_infos in processed_pattern.items():
assert (
program_ops_id in ops_id_to_layer.keys()
), f"program_ops: {program_ops_id} is not corresponding to a dynamic layer"
assert program_ops_id in ops_id_to_layer.keys(), (
f"program_ops: {program_ops_id} is not corresponding to a dynamic layer"
)
dynamic_layer = ops_id_to_layer[program_ops_id]
mesh_num_dims = len(local_mesh.shape)
sharding_info = dist_infos.get_dist_info(mesh_num_dims)
@@ -832,9 +834,9 @@ def to_distributed(

if decoder_layers is not None:
num_decoder_blocks = len(decoder_layers)
assert (
num_decoder_blocks == num_hidden_layers
), f"decoder pattern layers matched are incomplete, num_decoder_blocks: {num_decoder_blocks} should be equal to num_hidden_layers: {num_hidden_layers}"
assert num_decoder_blocks == num_hidden_layers, (
f"decoder pattern layers matched are incomplete, num_decoder_blocks: {num_decoder_blocks} should be equal to num_hidden_layers: {num_hidden_layers}"
)

pp_degree = mesh.get_dim_size("pp")
num_blocks_per_stage = num_decoder_blocks // pp_degree
36 changes: 18 additions & 18 deletions python/paddle/distributed/auto_parallel/interface.py
@@ -73,17 +73,17 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):
"""

if process_mesh is not None:
assert isinstance(
process_mesh, core.ProcessMesh
), f"Argument process_mesh {process_mesh} is not an instance of ProcessMesh"
assert isinstance(process_mesh, core.ProcessMesh), (
f"Argument process_mesh {process_mesh} is not an instance of ProcessMesh"
)
else:
process_mesh = get_current_process_mesh()
assert (
process_mesh is not None
), "Specify the process mesh argument or use ProcessMesh context manager first."
assert isinstance(
shard_spec, list
), f"Argument shard_spec {shard_spec} is not an instance of list"
assert process_mesh is not None, (
"Specify the process mesh argument or use ProcessMesh context manager first."
)
assert isinstance(shard_spec, list), (
f"Argument shard_spec {shard_spec} is not an instance of list"
)
if isinstance(x, str):
x = (
paddle.static.default_main_program()
@@ -100,9 +100,9 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):
else:
tensor_shape = serial_tensor.shape
if shard_spec is not None:
assert verify_shard_spec(
shard_spec, tensor_shape, process_mesh
), f"For tensor {serial_tensor.name}, shard_spec {shard_spec} is invalid with tensor_shape {tensor_shape} and process_mesh {process_mesh}."
assert verify_shard_spec(shard_spec, tensor_shape, process_mesh), (
f"For tensor {serial_tensor.name}, shard_spec {shard_spec} is invalid with tensor_shape {tensor_shape} and process_mesh {process_mesh}."
)
dist_tensor.dist_attr.dims_mapping = convert_to_dims_mapping(
shard_spec, process_mesh
)
@@ -164,14 +164,14 @@ def shard_op(
"""

if process_mesh is not None:
assert isinstance(
process_mesh, ProcessMesh
), f"Argument process_mesh {process_mesh} is not an instance of ProcessMesh"
assert isinstance(process_mesh, ProcessMesh), (
f"Argument process_mesh {process_mesh} is not an instance of ProcessMesh"
)
else:
process_mesh = get_current_process_mesh()
assert (
process_mesh is not None
), "Specify the process mesh argument or use ProcessMesh context manager first."
assert process_mesh is not None, (
"Specify the process mesh argument or use ProcessMesh context manager first."
)
in_dims_mappings = []
if in_shard_specs is not None:
assert all(