[3D-parallel] Reformat pipeline parallel #31786
Conversation
Thanks for your contribution!
LGTM
outputs={'Out': [sync_var]},
attrs={
    'ring_id': global_ring_id,
    'use_calc_stream': True,
A calc-stream sync is needed here.
Will add it in the next PR.
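For reference, a minimal sketch of what that sync could look like, assuming Paddle's fluid block API and the c_sync_calc_stream op; `block`, `insert_idx`, and the reuse of `sync_var` are hypothetical stand-ins for the surrounding code:

```python
# Sketch only: insert a c_sync_calc_stream op right after the op that
# writes sync_var, so later collective ops observe a finished calc stream.
block._insert_op(
    insert_idx + 1,
    type='c_sync_calc_stream',
    inputs={'X': sync_var},
    outputs={'Out': sync_var})
```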
origin_param = origin_block.vars[op_role_var[i]]
if origin_param.is_distributed:
    continue
if offset == idx:
    offset += 1
if not add_sync_calc_stream:
If c_allreduce_sum uses the calc stream, this sync op is unnecessary.
Yes, I'll remove it in the next PR.
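To illustrate the point, a hedged sketch (assuming the c_allreduce_sum op; `block`, `grad_var`, and `ring_id` are hypothetical names):

```python
# Sketch: a c_allreduce_sum launched with use_calc_stream=True runs on
# the same stream as the preceding compute kernels, so it is already
# ordered after them and no separate c_sync_calc_stream op is needed.
block.append_op(
    type='c_allreduce_sum',
    inputs={'X': grad_var},
    outputs={'Out': grad_var},
    attrs={
        'ring_id': ring_id,
        'use_calc_stream': True,
    })
```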
@@ -123,7 +123,8 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
    outputs={"Out": out_var},
    attrs={
        "in_dtype": in_var.dtype,
-       "out_dtype": out_var.dtype
+       "out_dtype": out_var.dtype,
+       "op_device": op.attr("op_device")
If the cast is fp32->fp16, consider setting the cast's "op_device" to prev_op's "op_device" attribute. That way, if a (send, recv) op pair is added, it is inserted as cast -- (send, recv) --> op, so (send, recv) transfers the fp16 output.
Of course, if the cast is fp16->fp32, the current setting is fine.
Designing it as the comment suggests would reduce communication overhead. We can keep this as an optimization to consider later.
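A sketch of the suggested placement rule, hedged: `choose_cast_device` is a hypothetical helper, and `prev_op` stands for the producer of the cast's input; only the fp32->fp16 branch changes placement.

```python
from paddle.fluid import core

# Hypothetical helper sketching the suggested rule: keep an fp32->fp16
# cast on the producer's device so an inserted (send, recv) pair moves
# the smaller fp16 tensor across devices; fp16->fp32 casts keep the
# current placement on the consuming op's device.
def choose_cast_device(op, prev_op, in_dtype, out_dtype):
    if (prev_op is not None
            and in_dtype == core.VarDesc.VarType.FP32
            and out_dtype == core.VarDesc.VarType.FP16):
        return prev_op.attr("op_device")
    return op.attr("op_device")
```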
attrs={
    "in_dtype": target_var.dtype,
    "out_dtype": cast_var.dtype,
    "op_device": op.attr("op_device")
same as above
@@ -3937,6 +4030,11 @@ def _find_post_op(self, ops, cur_op, var_name):
        var_name as output.
    var_name (string): Variable name.
    """
# To skip the cast op added by amp which has no op_device set |
op_device is already added in the amp cast; is this still needed?
We'll verify this on a large model; if this code path turns out to be unused, it will be removed in the next PR.
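For context, a hedged sketch of the skip that the comment describes (a hypothetical helper, not the actual _find_post_op):

```python
# Hypothetical helper: find the first op after position cur_idx that
# consumes var_name, skipping cast ops that amp inserted without setting
# an op_device attribute.
def find_post_op(ops, cur_idx, var_name):
    for op in ops[cur_idx + 1:]:
        if op.type == 'cast' and not op.attr('op_device'):
            continue  # amp-added cast with no device assignment
        if var_name in op.input_arg_names:
            return op
    return None
```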
attrs={
    self._op_device_key: prev_device,
    self._op_role_key: op_role,
    'use_calc_stream': True,
Possible optimization: switch this to the comm stream; for the forward pass, a sync can be added at some point in the backward pass.
Fix it in the next PR.
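A hedged sketch of what that optimization could look like (send_v2 is assumed as the op type; `block`, `var`, `ring_id`, and `peer` are hypothetical names):

```python
# Launch the forward-side send on the comm stream instead of the calc
# stream, then synchronize that stream once at a chosen point in the
# backward pass rather than blocking on every send.
block.append_op(
    type='send_v2',
    inputs={'X': var},
    attrs={
        'ring_id': ring_id,
        'peer': peer,
        'use_calc_stream': False,  # run on the comm stream
    })

# ... later, at the chosen point in the backward pass:
block.append_op(
    type='c_sync_comm_stream',
    inputs={'X': var},
    outputs={'Out': var},
    attrs={'ring_id': ring_id})
```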
'dtype': var.dtype,
self._op_device_key: cur_device,
self._op_role_key: op_role,
'use_calc_stream': True,
Possible future optimization: move it earlier so it overlaps with computation.
'shape': merged_param_grad_var.shape,
'dtype': merged_param_grad_var.dtype,
'value': float(0),
# a trick to run this op once per mini-batch
A more detailed comment is required
Fix it in the next PR.
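A hedged sketch of the kind of detail the comment could carry, under the assumption that the trick relies on the op-role scheduling (LRSched-role ops running once per mini-batch rather than once per micro-batch); `block` and `merged_param_grad_var` are hypothetical names:

```python
from paddle.fluid import core

# Sketch only: zero the merged gradient with a fill_constant tagged with
# the LRSched role, so the pipeline scheduler executes it once per
# mini-batch, before micro-batch gradient accumulation starts.
block.append_op(
    type='fill_constant',
    outputs={'Out': merged_param_grad_var},
    attrs={
        'shape': merged_param_grad_var.shape,
        'dtype': merged_param_grad_var.dtype,
        'value': float(0),
        # assumption: LRSched-role ops are scheduled once per mini-batch
        'op_role': int(core.op_proto_and_checker_maker.OpRole.LRSched),
    })
```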
@@ -4444,36 +4647,68 @@ def _process_persistable_vars_in_multi_sections(self, main_program,
    'out_shape': read_block.var(var_name).shape,
    'dtype': read_block.var(var_name).dtype,
    self._op_device_key: read_device,
-   'use_calc_stream': True,
+   'use_calc_stream': False,
The following op is a sync_comm, so this can simply be changed to use_calc_stream=True.
Fix it in the next PR.
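A hedged sketch of the suggestion (recv_v2 is assumed as the op type; `block`, `read_block`, `var_name`, `ring_id`, and `peer` are hypothetical names):

```python
# If the receive runs on the calc stream it is already ordered with the
# subsequent compute ops, so the trailing c_sync_comm_stream can be
# dropped.
block.append_op(
    type='recv_v2',
    outputs={'Out': read_block.var(var_name)},
    attrs={
        'out_shape': read_block.var(var_name).shape,
        'dtype': read_block.var(var_name).dtype,
        'ring_id': ring_id,
        'peer': peer,
        'use_calc_stream': True,
    })
```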
place_list.append(core.CUDAPlace(local_rank))
for dev in device_list:
    dev_index = int(dev.split(":")[1])
    place_list.append(core.CUDAPlace(dev_index % 8))
Why the % 8?
Will fix it in the next PR by using the fixed value 0.
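A sketch of the agreed fix, under the assumption that each pipeline worker is launched with CUDA_VISIBLE_DEVICES restricted to a single GPU so the local device index is always 0; `device_list` is a hypothetical name:

```python
from paddle.fluid import core

place_list = []
for _ in device_list:
    place_list.append(core.CUDAPlace(0))  # fixed index instead of dev_index % 8
```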
PR types
Others
PR changes
Others
Describe