
[3D-parallel] Reformat pipeline parallel #31786

Merged: 9 commits merged into PaddlePaddle:develop from reformat-pp on Mar 26, 2021

Conversation

@sandyhouse commented Mar 22, 2021

PR types
Others

PR changes
Others

Describe
  1. Reformat the implementation of pipeline parallelism.
  2. Unify the implementation of group initialization for creating multiple communication groups.
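
As a rough illustration of point 2 (this is not the PR's actual code; every name below is made up for the sketch), unified group initialization boils down to deriving each rank's data-, model-, and pipeline-parallel peer lists from a single 3D process mesh, and then creating one communication group (ring) per distinct peer list:

    # Hypothetical sketch only: derive dp/mp/pp peer ranks from a 3D mesh.
    def build_groups(rank, dp_degree, pp_degree, mp_degree):
        assert rank < dp_degree * pp_degree * mp_degree
        # Lay ranks out as a (dp, pp, mp) mesh: mp varies fastest, dp slowest.
        dp_idx = rank // (pp_degree * mp_degree)
        pp_idx = (rank // mp_degree) % pp_degree
        mp_idx = rank % mp_degree
        mp_group = [dp_idx * pp_degree * mp_degree + pp_idx * mp_degree + i
                    for i in range(mp_degree)]
        pp_group = [dp_idx * pp_degree * mp_degree + i * mp_degree + mp_idx
                    for i in range(pp_degree)]
        dp_group = [i * pp_degree * mp_degree + pp_idx * mp_degree + mp_idx
                    for i in range(dp_degree)]
        return dp_group, mp_group, pp_group

    # Example: rank 5 in a dp=2, pp=2, mp=2 mesh.
    print(build_groups(5, 2, 2, 2))  # ([1, 5], [4, 5], [5, 7])

A communication group would then be initialized once per distinct peer list; unifying that initialization step is what item 2 refers to.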

@paddle-bot-old

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@CLAassistant commented Mar 22, 2021

CLA assistant check
All committers have signed the CLA.

@wangxicoding (Contributor) left a comment

LGTM

    outputs={'Out': [sync_var]},
    attrs={
        'ring_id': global_ring_id,
        'use_calc_stream': True,

Contributor:

A calc stream sync is needed here.

Author:

I'll add it in the next PR.
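
For reference, a minimal sketch of the kind of fix being discussed, assuming a fluid Block object and the c_sync_calc_stream op; block, insert_idx and sync_var are placeholders, not names taken from this PR:

    # Sketch only: make pending calc-stream work on sync_var finish before the
    # collective op that follows at insert_idx reads it.
    def insert_calc_stream_sync(block, insert_idx, sync_var):
        block._insert_op(
            insert_idx,
            type='c_sync_calc_stream',
            inputs={'X': sync_var},
            outputs={'Out': sync_var})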

    origin_param = origin_block.vars[op_role_var[i]]
    if origin_param.is_distributed:
        continue
    if offset == idx:
        offset += 1
    if not add_sync_calc_stream:

Contributor:

If c_allreduce_sum uses the calc stream, this sync op is unnecessary.

Author:

Yes, I'll remove it in the next PR.
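
To make the point concrete, a hedged sketch of the two patterns; the helper and variable names are assumptions, only the op types (c_allreduce_sum, c_sync_comm_stream) are existing Paddle collectives:

    # Sketch: when the allreduce runs on the calc stream it stays ordered with
    # computation, so no extra sync op is needed; only the comm-stream variant
    # needs a following c_sync_comm_stream.
    def append_allreduce(block, grad_var, ring_id, on_calc_stream=True):
        block.append_op(
            type='c_allreduce_sum',
            inputs={'X': grad_var},
            outputs={'Out': grad_var},
            attrs={'ring_id': ring_id,
                   'use_calc_stream': on_calc_stream})
        if not on_calc_stream:
            block.append_op(
                type='c_sync_comm_stream',
                inputs={'X': grad_var},
                outputs={'Out': grad_var},
                attrs={'ring_id': ring_id})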

@@ -123,7 +123,8 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
         outputs={"Out": out_var},
         attrs={
             "in_dtype": in_var.dtype,
-            "out_dtype": out_var.dtype
+            "out_dtype": out_var.dtype,
+            "op_device": op.attr("op_device")

Contributor:

If the cast is fp32 -> fp16, consider setting the cast's "op_device" to prev_op's "op_device" attribute. That way, when (send, recv) ops are added, they are inserted as cast -- (send, recv) --> op, so what the (send, recv) pair transfers is the fp16 output. If the cast is fp16 -> fp32, the current setting is fine.

Author:

Designing it as the comment suggests can reduce communication overhead. We can consider this as a later optimization.
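
A hedged sketch of the optimization being deferred here (an illustrative helper, not the PR's code): pick the cast's op_device from the cast direction, so that fp32 -> fp16 casts sit with the producing op and any inserted (send, recv) pair moves the smaller fp16 tensor:

    import paddle.fluid.core as core

    # Sketch only: choose the device attribute for an inserted cast op.
    def choose_cast_device(op, prev_op, in_dtype, out_dtype):
        fp32, fp16 = core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP16
        if in_dtype == fp32 and out_dtype == fp16 and prev_op is not None:
            # Down-cast: co-locate with the producer so fp16 crosses the stage boundary.
            return prev_op.attr("op_device")
        # Up-cast (fp16 -> fp32) or no previous op: keep the current behaviour.
        return op.attr("op_device")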

    attrs={
        "in_dtype": target_var.dtype,
        "out_dtype": cast_var.dtype,
        "op_device": op.attr("op_device")

Contributor:

same as above

@@ -3937,6 +4030,11 @@ def _find_post_op(self, ops, cur_op, var_name):
                           var_name as output.
            var_name (string): Variable name.
        """
        # To skip the cast op added by amp which has no op_device set

Contributor:

op_device is already added in the amp cast; is this still needed?

Author:

I'll verify on a large model; if this code is not used, it will be removed in the next PR.
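
For context, a hedged sketch of what such a skip could look like when searching for the next consumer of a variable; this is illustrative and may differ from the PR's actual _find_post_op:

    # Sketch only: return the first op after cur_op that consumes var_name,
    # skipping amp-inserted cast ops that carry no op_device attribute.
    def find_post_op(ops, cur_op, var_name):
        seen_cur = False
        for op in ops:
            if op is cur_op:
                seen_cur = True
                continue
            if not seen_cur:
                continue
            if op.type == 'cast' and not op.attr('op_device'):
                continue  # cast added by amp without op_device set
            if var_name in op.input_arg_names:
                return op
        return None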

    attrs={
        self._op_device_key: prev_device,
        self._op_role_key: op_role,
        'use_calc_stream': True,

Contributor:

Possible optimization: change this to the comm stream; for the forward pass, a sync can be added at some point in the backward pass.

Author:

I'll fix it in the next PR.
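
A hedged sketch of the suggested optimization, using the existing send_v2 and c_sync_comm_stream op types but hypothetical helpers: issue the forward-stage send on the comm stream so it can overlap with computation, and insert a single comm-stream sync later (for example at a suitable point in the backward pass):

    # Sketch only: overlap pipeline sends with computation.
    def append_send(block, var, ring_id, peer):
        block.append_op(
            type='send_v2',
            inputs={'X': var},
            attrs={'ring_id': ring_id,
                   'peer': peer,
                   'use_calc_stream': False})  # comm stream, returns immediately

    def append_deferred_sync(block, dep_var, ring_id):
        # Placed once at a suitable point in the backward pass.
        block.append_op(
            type='c_sync_comm_stream',
            inputs={'X': dep_var},
            outputs={'Out': dep_var},
            attrs={'ring_id': ring_id})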

    'dtype': var.dtype,
    self._op_device_key: cur_device,
    self._op_role_key: op_role,
    'use_calc_stream': True,

Contributor:

Future optimization: move this earlier so it can overlap.

    'shape': merged_param_grad_var.shape,
    'dtype': merged_param_grad_var.dtype,
    'value': float(0),
    # a trick to run this op once per mini-batch

Contributor:

A more detailed comment is required

Author:

I'll fix it in the next PR.
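
As a possible starting point for that more detailed comment (hedged; this reading of the op-role trick is an assumption based on how LRSched-role ops are scheduled by the pipeline pass):

    # The fill_constant above zeroes the merged gradient buffer. It appears to
    # be tagged with the LR-scheduling op role so that the pipeline scheduler
    # runs it once per mini-batch, before micro-batch gradients start
    # accumulating, rather than once per micro-batch.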

@@ -4444,36 +4647,68 @@ def _process_persistable_vars_in_multi_sections(self, main_program,
             'out_shape': read_block.var(var_name).shape,
             'dtype': read_block.var(var_name).dtype,
             self._op_device_key: read_device,
-            'use_calc_stream': True,
+            'use_calc_stream': False,

Contributor:

The following op is a sync_comm, so this can simply be changed to use_calc_stream=True.

Author:

I'll fix it in the next PR.

-        place_list.append(core.CUDAPlace(local_rank))
+        for dev in device_list:
+            dev_index = int(dev.split(":")[1])
+            place_list.append(core.CUDAPlace(dev_index % 8))

Contributor:

Why the % 8?

Author:

I'll fix it in the next PR by using the fixed value 0.
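
A hedged sketch of that follow-up fix, assuming each trainer process is restricted to a single visible GPU (for example via CUDA_VISIBLE_DEVICES), in which case device index 0 always refers to that process's card and the eight-GPUs-per-node assumption behind % 8 goes away:

    # Sketch only: replace dev_index % 8 with a fixed device index.
    place_list = []
    for _ in device_list:
        place_list.append(core.CUDAPlace(0))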

@sandyhouse merged commit c3974d0 into PaddlePaddle:develop on Mar 26, 2021
@sandyhouse deleted the reformat-pp branch on March 8, 2022