[AutoParallel] Auto Trans PP to VPP (PaddlePaddle#60467)
* [AutoParallel] Auto Trans PP to VPP

* add comment
zhaoyinglia authored and Wanglongzhi2001 committed Jan 7, 2024
1 parent baaa9e3 commit 217db4c
Showing 3 changed files with 207 additions and 64 deletions.
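
For orientation before the diff: the transform maps num_chunks = pp_degree * vpp_degree consecutive layer segments onto (pp_stage, chunk_id) pairs, so each pipeline stage owns vpp_degree non-adjacent chunks. The rewrite is only applied when the segments' existing pipeline stages are non-decreasing in program order; otherwise the ops' process_mesh is left untouched. The following standalone sketch (a hypothetical helper written for this summary, not code from the patch) reproduces the seg_pp_stages / seg_chunk_ids mapping that the completion pass computes below.

# Standalone illustration of the PP -> VPP segment mapping (not part of the patch).
def map_segments(pp_degree: int, vpp_degree: int):
    num_chunks = pp_degree * vpp_degree
    seg_pp_stages = [i % pp_degree for i in range(num_chunks)]   # stage that owns segment i
    seg_chunk_ids = [i // pp_degree for i in range(num_chunks)]  # virtual chunk of segment i
    return list(zip(seg_pp_stages, seg_chunk_ids))

# pp_degree=2, vpp_degree=2 gives 4 segments:
# segment 0 -> (stage 0, chunk 0), segment 1 -> (stage 1, chunk 0),
# segment 2 -> (stage 0, chunk 1), segment 3 -> (stage 1, chunk 1)
print(map_segments(2, 2))  # [(0, 0), (1, 0), (0, 1), (1, 1)]
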
189 changes: 146 additions & 43 deletions python/paddle/distributed/auto_parallel/static/completion.py
@@ -1057,30 +1057,51 @@ def set_chunk_id(block, op, chunk_id, var_to_chunk_id):
dist_op = self._dist_context.get_dist_op_for_program(op)
dist_op.dist_attr.chunk_id = chunk_id
for name in op.input_arg_names + op.output_arg_names:
var = block._find_var_recursive(name)
if "lod_tensor_blocking_queue" in name:
continue
if name not in var_to_chunk_id:
op_dist_attr = (
self._dist_context.get_op_dist_attr_for_program(op)
var = block._find_var_recursive(name)
dist_tensor = (
self._dist_context.get_dist_tensor_for_program(var)
)
tensor_dist_attr = (
self._dist_context.get_tensor_dist_attr_for_program(var)
if (
dist_op.dist_attr.process_mesh
== dist_tensor.dist_attr.process_mesh
):
dist_tensor.dist_attr.chunk_id = chunk_id
var_to_chunk_id[var.name] = chunk_id

def set_process_mesh(block, op, process_mesh, var_to_process_mesh):
dist_op = self._dist_context.get_dist_op_for_program(op)
for name in op.input_arg_names:
if name not in var_to_process_mesh:
var = block._find_var_recursive(name)
dist_tensor = (
self._dist_context.get_dist_tensor_for_program(var)
)
if (
op_dist_attr.process_mesh
== tensor_dist_attr.process_mesh
dist_op.dist_attr.process_mesh
== dist_tensor.dist_attr.process_mesh
):
tensor_dist_attr.chunk_id = op_dist_attr.chunk_id
var_to_chunk_id[var.name] = op_dist_attr.chunk_id
dist_tensor.dist_attr.process_mesh = process_mesh
var_to_process_mesh[var.name] = process_mesh
for name in op.output_arg_names:
if name not in var_to_process_mesh:
var = block._find_var_recursive(name)
dist_tensor = (
self._dist_context.get_dist_tensor_for_program(var)
)
dist_tensor.dist_attr.process_mesh = process_mesh
var_to_process_mesh[var.name] = process_mesh
dist_op.dist_attr.process_mesh = process_mesh

if (
not self._dist_context.strategy
or not self._dist_context.strategy.pipeline.enable
):
return

pp_degree = get_pp_degree(self._dist_context)
pp_degree, sub_process_meshes = get_pp_degree(self._dist_context)
vpp_degree = self._dist_context.strategy.pipeline.vpp_degree
seg_method = self._dist_context.strategy.pipeline.vpp_seg_method
schedule_mode = self._dist_context.strategy.pipeline.schedule_mode
@@ -1099,8 +1120,11 @@ def set_chunk_id(block, op, chunk_id, var_to_chunk_id):
block = serial_main_program.global_block()
ops = block.ops

# 1. search seg_method in op's struct_name, and get all ops of segments
seg_op_deps = collections.OrderedDict()
# Step1: search seg_method in op's struct_name
# 1. get op_idx of each segment
# 2. get process_mesh of each segment
seg_op_deps = collections.OrderedDict() # struct_name -> [idx]
seg_op_mesh = collections.OrderedDict() # struct_name -> process_mesh
regex = re.compile(seg_method, re.IGNORECASE)
for i, op in enumerate(ops):
struct_name = op.struct_name
@@ -1109,59 +1133,93 @@ def set_chunk_id(block, op, chunk_id, var_to_chunk_id):
continue

struct_name = struct_name[m.start(0) :].split("/")[0]
dist_op = self._dist_context.get_dist_op_for_program(op)
if struct_name not in seg_op_deps:
seg_op_deps[struct_name] = [i]
seg_op_mesh[struct_name] = dist_op.dist_attr.process_mesh
else:
assert (
seg_op_deps[struct_name][-1] + 1 == i
), "The segment's ops should be continuous."
pre_op = ops[seg_op_deps[struct_name][-1]]
pre_dist_op = self._dist_context.get_dist_op_for_program(pre_op)
dist_op = self._dist_context.get_dist_op_for_program(op)
pre_mesh = seg_op_mesh[struct_name]
assert (
pre_dist_op.dist_attr.process_mesh
== dist_op.dist_attr.process_mesh
pre_mesh == dist_op.dist_attr.process_mesh
), "The segment's ops should have same process_mesh."
seg_op_deps[struct_name].extend([i])

# the num of chunk is equal to vpp_degree
num_parts = pp_degree * vpp_degree
num_chunks = pp_degree * vpp_degree
assert (
len(seg_op_deps.keys()) % num_parts == 0
), "number of layers[{}] ({}) should be devided by part number ({}).".format(
seg_method, len(seg_op_deps.keys()), num_parts
len(seg_op_deps) % num_chunks == 0
), "The number of layers[{}] ({}) should be devided by part number ({}).".format(
seg_method, len(seg_op_deps), num_chunks
)

part_size = len(seg_op_deps.keys()) // vpp_degree
# Step2: analyze whether the pp_stage is non-decreasing among segments
# 1. if non_decreasing is True, the ops' process_mesh will be changed by the vpp strategy
# 2. if non_decreasing is False, the ops' process_mesh will not be changed.
non_decreasing = True
seg_pp_stages = [-1]
for seg_pm in seg_op_mesh.values():
assert seg_pm in sub_process_meshes
pp_stage = sub_process_meshes.index(seg_pm)
if seg_pp_stages[-1] > pp_stage:
non_decreasing = False
break
seg_pp_stages.append(pp_stage)

# 2. get boundary index of each chunk
results = [0] * (vpp_degree + 1)
memory_counter = 0
result_idx = 1
for struct_name, idxs in seg_op_deps.items():
if not non_decreasing:
_logger.info("Cannot Use Auto VPP")
else:
_logger.info("Using Auto VPP")

# Step3: Get op index boundary, pp_stage, chunk_id, struct_names of each segment
seg_pp_stages = [i % pp_degree for i in range(num_chunks)]
seg_chunk_ids = [i // pp_degree for i in range(num_chunks)]
part_size = len(seg_op_deps) // num_chunks
segment_struct_names = []
segment_parts = [0] * (num_chunks + 1)
memory_counter, seg_idx = 0, 1
struct_name = []
for name, idxs in seg_op_deps.items():
struct_name.append(name)
memory_counter += 1
if memory_counter == part_size:
results[result_idx] = idxs[-1] + 1
result_idx += 1
memory_counter = 0
results[vpp_degree] = len(ops)
segment_parts[seg_idx] = idxs[-1] + 1
memory_counter, seg_idx = 0, seg_idx + 1
segment_struct_names.append(struct_name)
struct_name = []
segment_parts[num_chunks] = len(ops)

# 3. set right chunk_id for each op
# Step4: set right chunk_id and process_mesh for each op and var
var_to_chunk_id = {}
for chunk_id in range(len(results) - 1):
start_idx = results[chunk_id]
end_idx = results[chunk_id + 1]
var_to_process_mesh = {}
for seg_id in range(len(segment_parts) - 1):
start_idx = segment_parts[seg_id]
end_idx = segment_parts[seg_id + 1]
pp_stage = seg_pp_stages[seg_id]
chunk_id = seg_chunk_ids[seg_id]
process_mesh = sub_process_meshes[pp_stage]
struct_names = segment_struct_names[seg_id]
seg_op_idx = []
for name in struct_names:
seg_op_idx.extend(seg_op_deps[name])

_logger.info(
"[chunk_{}] start op: [{}]: [{}] [{}]".format(
"stage=[{}], chunk_id=[{}], layer_name=[{}]".format(
pp_stage,
chunk_id,
struct_names,
)
)
_logger.info(
"start op: [{}]: [{}] [{}]".format(
ops[start_idx].type,
ops[start_idx].input_arg_names,
ops[start_idx].output_arg_names,
)
)
_logger.info(
"[chunk_{}] end op: [{}]: [{}] [{}]".format(
chunk_id,
"end op: [{}]: [{}] [{}]".format(
ops[end_idx - 1].type,
ops[end_idx - 1].input_arg_names,
ops[end_idx - 1].output_arg_names,
@@ -1173,9 +1231,28 @@ def set_chunk_id(block, op, chunk_id, var_to_chunk_id):
if op.has_attr("sub_block"):
block_id = op.attr('sub_block').id
sub_block = serial_main_program.blocks[block_id]
for op in sub_block.ops:
set_chunk_id(sub_block, op, chunk_id, var_to_chunk_id)
if non_decreasing and idx in seg_op_idx:
set_process_mesh(
block, op, process_mesh, var_to_process_mesh
)
set_chunk_id(block, op, chunk_id, var_to_chunk_id)

for sub_op in sub_block.ops:
if non_decreasing and idx in seg_op_idx:
set_process_mesh(
sub_block,
sub_op,
process_mesh,
var_to_process_mesh,
)
set_chunk_id(
sub_block, sub_op, chunk_id, var_to_chunk_id
)
else:
if non_decreasing and idx in seg_op_idx:
set_process_mesh(
block, op, process_mesh, var_to_process_mesh
)
set_chunk_id(block, op, chunk_id, var_to_chunk_id)

def _update_dist_attr_for_dp(self):
@@ -1915,8 +1992,34 @@ def infer_backward_op_partial_status(
grad_op_dist_attr.set_output_dims_mapping(
output_name, ref_fwd_dims_mapping
)
grad_op_dist_attr.process_mesh = ref_fwd_process_mesh
grad_op_dist_attr.chunk_id = ref_fwd_chunk_id
# NOTE(zhaoyingli):
# The sum op accumulates the grad values of the same forward var;
# its chunk_id is the same as that of the last op which generates the grad.
ref_chunk_id = None
ref_process_mesh = None
for pre_idx in range(
idx - 1, first_backward_op_idx + 1, -1
):
pre_grad_op = ops[pre_idx]
inter_arg_name = list(
set(pre_grad_op.output_arg_names)
& set(grad_op.input_arg_names)
)
if len(inter_arg_name) > 0:
pre_op_dist_attr = (
self._dist_context.get_op_dist_attr_for_program(
pre_grad_op
)
)
ref_chunk_id = pre_op_dist_attr.chunk_id
ref_process_mesh = pre_op_dist_attr.process_mesh
break
assert (
ref_chunk_id is not None
and ref_process_mesh is not None
)
grad_op_dist_attr.process_mesh = ref_process_mesh
grad_op_dist_attr.chunk_id = ref_chunk_id
self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr
)
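
Step 3 of the pass above splits the ordered layer segments (keyed by struct_name) into num_chunks equal parts and records the op-index boundary of each part. The sketch below re-implements that bookkeeping in isolation so it can be run on toy data; the layer names and op indices in the example are illustrative, not taken from a real program.

# Self-contained approximation of the Step-3 boundary computation (illustrative only).
from collections import OrderedDict

def split_segments(seg_op_deps, num_chunks, total_ops):
    # seg_op_deps: OrderedDict mapping struct_name -> sorted list of op indices.
    assert len(seg_op_deps) % num_chunks == 0
    part_size = len(seg_op_deps) // num_chunks
    segment_parts = [0] * (num_chunks + 1)
    segment_struct_names = []
    names, counter, seg_idx = [], 0, 1
    for name, idxs in seg_op_deps.items():
        names.append(name)
        counter += 1
        if counter == part_size:
            segment_parts[seg_idx] = idxs[-1] + 1  # boundary: one past the layer's last op
            segment_struct_names.append(names)
            names, counter, seg_idx = [], 0, seg_idx + 1
    segment_parts[num_chunks] = total_ops
    return segment_parts, segment_struct_names

# Toy example: 4 "layers" of 3 ops each, split into 2 chunks of 2 layers.
layers = OrderedDict(("block_%d" % i, [3 * i, 3 * i + 1, 3 * i + 2]) for i in range(4))
print(split_segments(layers, num_chunks=2, total_ops=12))
# ([0, 6, 12], [['block_0', 'block_1'], ['block_2', 'block_3']])
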
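The backward-completion change in the last hunk makes the gradient-accumulation sum op inherit its chunk_id and process_mesh from the nearest preceding grad op that produces one of its inputs. A minimal sketch of that backward scan, using plain dicts as stand-ins for the real operators and dist attributes (the field names and grad names below are assumptions for illustration):

# Illustrative stand-in for the backward scan; each op is a dict, not a real Operator.
def infer_sum_op_attrs(ops, sum_idx, first_backward_op_idx):
    sum_inputs = set(ops[sum_idx]["inputs"])
    for pre_idx in range(sum_idx - 1, first_backward_op_idx, -1):
        pre_op = ops[pre_idx]
        if sum_inputs & set(pre_op["outputs"]):
            # Inherit from the last op that produced one of the accumulated grads.
            return pre_op["chunk_id"], pre_op["process_mesh"]
    raise AssertionError("no producer found for the accumulated grads")

ops = [
    {"inputs": [], "outputs": ["x@GRAD@RENAME@0"], "chunk_id": 1, "process_mesh": [0, 1]},
    {"inputs": [], "outputs": ["x@GRAD@RENAME@1"], "chunk_id": 0, "process_mesh": [0, 1]},
    {"inputs": ["x@GRAD@RENAME@0", "x@GRAD@RENAME@1"], "outputs": ["x@GRAD"]},  # sum op
]
print(infer_sum_op_attrs(ops, sum_idx=2, first_backward_op_idx=-1))  # (0, [0, 1])
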
2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_parallel/static/utils.py
@@ -2335,7 +2335,7 @@ def get_pp_degree(dist_context):
for idx in reversed(global_pm_idx):
process_meshes.pop(idx)

return len(process_meshes)
return len(process_meshes), process_meshes


def get_pp_stage(dist_context, rank):
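With this change get_pp_degree returns both the pipeline degree and the per-stage sub process meshes, which is what lets the completion pass map a segment's process_mesh back to a pp_stage index. A usage illustration with mock values (real callers receive ProcessMesh objects derived from the dist_context, not plain lists):

# Mock illustration of the new return value; plain lists stand in for ProcessMesh objects.
sub_process_meshes = [[0, 1], [2, 3]]              # one mesh per pipeline stage
pp_degree = len(sub_process_meshes)                # what the old API returned alone

seg_process_mesh = [2, 3]                          # mesh attached to some segment's ops
pp_stage = sub_process_meshes.index(seg_process_mesh)
print(pp_degree, pp_stage)                         # 2 1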