heter for collective #37613
Conversation
Thanks for your contribution!
This all looks like dygraph code; it can't support static graph mode, right?
@@ -176,6 +176,11 @@ void GLOOParallelContext::AllReduce(const framework::SelectedRows &src,
  }
}

void GLOOParallelContext::BroadCast(framework::Variable *src, int ring_id) {
Broadcast? "Broadcast" is a single word. Also, this interface has no implementation, so why add it?
@@ -47,6 +47,8 @@ class GLOOParallelContext : public ParallelContext {
      framework::Variable* dst, int ring_id,
      bool use_calc_stream) override;

  void BroadCast(framework::Variable* src, int ring_id) override;
- Same as above.
- Why does the gloo interface need a ring_id argument?
@@ -158,6 +158,29 @@ void HCCLParallelContext::AllReduceByStream(const framework::Variable &src,
  }
}

void HCCLParallelContext::BroadCast(framework::Variable *src, int ring_id) {
BroadCast -> Broadcast?
@@ -127,6 +135,20 @@ void NCCLParallelContext::AllReduceByStream(const framework::Variable &src,
  AllReduce(src, dst, strategy_, ring_id, use_calc_stream);
}

void NCCLParallelContext::BroadCast(framework::Variable *src, int ring_id) {
BroadCast -> Broadcast?
@@ -60,6 +60,8 @@ class NCCLParallelContext : public ParallelContext {
      framework::Variable* dst, int ring_id,
      bool use_calc_stream) override;

  void BroadCast(framework::Variable* src, int ring_id) override;
Same as above.
@@ -56,6 +56,8 @@ class ParallelContext {
      framework::Variable* dst, int ring_id,
      bool use_calc_stream) = 0;

  virtual void BroadCast(framework::Variable* src, int ring_id) = 0;
Same as above.
@@ -41,6 +42,9 @@ void Group::DivNRanks(const platform::DeviceContext &context, int64_t nranks) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
    DivNRanks(tensor, nranks, context);
#endif
  } else if (platform::is_npu_place(tensor->place())) {
    // TODO(kuizhiqing)
    VLOG(4) << "divnrank for npu not support yet";
Abort here instead of just logging?
LGTM
LGTM for const_cast
PR types
New features
PR changes
Others
Describe
Heterogeneous mixed training refers to training a model on heterogeneous hardware. Only dygraph mode is supported for now, and GPU/NPU/XPU are the target devices for this prototype work.
The basic idea is very similar to a hierarchical communication topology: the lower layer reduces data within each node, while the upper layer reduces across all nodes globally.
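To make the hierarchical idea concrete, below is a minimal, self-contained C++ sketch of the two-level reduce (illustrative only, not Paddle code; the node/device layout and all names are made up): values are reduced within each node first, then across the per-node results, and the global result is written back to every device. In the actual design, the intra-node step would presumably map to NCCL/HCCL and the inter-node step to gloo, per the description above.

```cpp
// Minimal illustration of the two-level (hierarchical) allreduce idea.
// Plain C++ only; not Paddle/NCCL/HCCL/gloo code. All values are made up.
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // 2 nodes, each with 3 devices, one value per device.
  std::vector<std::vector<double>> nodes = {{1, 2, 3}, {4, 5, 6}};

  // Level 1: reduce within each node (the intra-node collective).
  std::vector<double> node_sums;
  for (const auto& devices : nodes) {
    node_sums.push_back(std::accumulate(devices.begin(), devices.end(), 0.0));
  }

  // Level 2: reduce across the per-node results (the inter-node collective).
  double global_sum = std::accumulate(node_sums.begin(), node_sums.end(), 0.0);

  // Broadcast the global result back to every device within each node.
  for (auto& devices : nodes) {
    for (auto& value : devices) value = global_sum;
  }

  std::cout << "global sum = " << global_sum << std::endl;  // prints 21
  return 0;
}
```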