Add no_sync in data parallel for dynamic graph #34740
Conversation
Thanks for your contribution!
@@ -576,9 +578,19 @@ def _find_varbase(self, obj):
            return itertools.chain(*map(self._find_varbase, obj.values()))
        return []

    @contextmanager
    def no_sync(self):
Please add documentation: a description of the feature, the API interface, and usage.
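A minimal sketch of how such a context manager is usually completed (only the decorator and the signature appear in the diff above; the class, the internal flag name grad_need_sync, and the save/restore pattern are assumptions, not the PR's exact code):

```python
from contextlib import contextmanager

class SyncSwitchSketch:
    def __init__(self):
        self.grad_need_sync = True  # assumed internal flag name

    @contextmanager
    def no_sync(self):
        old_flag = self.grad_need_sync
        self.grad_need_sync = False  # gradient hooks skip allreduce while False
        try:
            yield  # user code runs without inter-card gradient sync
        finally:
            self.grad_need_sync = old_flag  # restored even if the block raises
```

The try/finally guarantees synchronization is re-enabled even when the user's code raises inside the with block.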
@@ -527,6 +527,7 @@ void Reducer::TraverseBackwardGraph(
void Reducer::PrepareForBackward(
    const std::vector<std::shared_ptr<imperative::VarBase>> &outputs) {
  VLOG(3) << "after forward, then reset count for backward.";
  grad_need_hooks_ = true;
Add a note to explain the role of this parameter
Thanks. Notes have already been added at lines 212~215 of paddle/fluid/imperative/reducer.h.
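For readers without that header at hand, here is a hedged pseudocode rendering of the flag's role (names follow the C++ diff above; the hook-side check is an assumption consistent with this reply, not the literal implementation):

```python
class ReducerSketch:
    """Illustrates the assumed role of grad_need_hooks_ (not the real C++)."""

    def __init__(self):
        self.grad_need_hooks = False  # mirrors grad_need_hooks_ in reducer.h

    def prepare_for_backward(self, outputs):
        # Armed only when a real backward pass is about to run
        # (the C++ diff above sets grad_need_hooks_ = true here).
        self.grad_need_hooks = True

    def add_dist_hook(self, var_index):
        # Gradient hooks fired while the flag is False (e.g. a backward
        # executed under no_sync) return early, so no allreduce is launched.
        if not self.grad_need_hooks:
            return
        self.mark_var_ready(var_index)

    def mark_var_ready(self, var_index):
        pass  # placeholder for the real bucketed allreduce logic
```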
@@ -907,6 +912,7 @@ void Reducer::ProcessUnusedDenseVars() {

    // 3. create grad var base or get grad var base
    auto grad_var_base_tmp = dest_var_base->MutableGradVarBase();
    grad_var_base_tmp->SharedVar()->SetIsEmpty(false);
Explain the reason for this modification
@@ -0,0 +1,175 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2018? This should be 2021.
@@ -0,0 +1,176 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
@@ -0,0 +1,179 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Should be 2021.
@@ -0,0 +1,100 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Should be 2021.
LGTM
LGTM for API
batch_num = 1000


class SimpleNet(fluid.Layer):
In the unit tests, please also use paddle.nn.Layer, not the APIs under fluid.
OK, will do.
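A minimal sketch of the requested change, moving the network onto paddle.nn.Layer (the layer sizes here are illustrative placeholders, not the PR's actual values):

```python
import paddle
import paddle.nn as nn

class SimpleNet(nn.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        # Placeholder sizes; the real test wires up its own shapes.
        self.net_a = nn.Linear(10, 20)
        self.net_b = nn.Linear(20, 5)

    def forward(self, x):
        x = self.net_a(x)
        x = self.net_b(x)
        return x
```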
        return x


class TestNoSyncControlFlow(TestParallelDyGraphRunnerBase):
Do the unit tests here check gradient correctness with no_sync enabled through TestParallelDyGraphRunnerBase? I don't see how correctness is checked in the code below.
The no_sync unit tests are implemented by overriding four methods of TestParallelDyGraphRunnerBase: get_model, run_one_loop, run_trainer, and run_trainer_with_spawn. The check of gradient correctness with no_sync enabled is already implemented by the test framework, so it does not need to be implemented here.
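A hedged sketch of that override structure (the method names come from the reply above; the signatures and helper names such as fake_sample_reader are assumptions for illustration, not the exact test code):

```python
import paddle

class TestNoSyncControlFlow(TestParallelDyGraphRunnerBase):
    def get_model(self):
        # Build model, data reader, and optimizer for the runner framework.
        model = SimpleNet()  # as sketched above
        train_reader = paddle.batch(fake_sample_reader(), batch_size=16)
        optimizer = paddle.optimizer.SGD(
            learning_rate=0.001, parameters=model.parameters())
        return model, train_reader, optimizer

    def run_one_loop(self, model, optimizer, batch):
        # One forward pass; the shared runner compares losses/gradients
        # across single-card and multi-card executions.
        x = paddle.to_tensor(batch, dtype='float32')
        loss = model(x).mean()
        return loss
```

The run_trainer and run_trainer_with_spawn overrides wrap these in the distributed launch path; the correctness check itself lives in the shared runner base, as noted above.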
LGTM for API
LGTM
PR types
New features
PR changes
APIs
Describe
Add no_sync in data parallel for dynamic graph
1. API form:
2. Documentation:
Chinese:
English version:
3. Functionality:
no_sync pauses gradient synchronization in dynamic-graph data parallel training and supports accum_gradient;
it removes unnecessary synchronization inside gradient-accumulation loops, improving performance to some degree without affecting accuracy (see the usage sketch after this list).
4. Test plan:
Given the complexity of possible network structures, a unit test is provided for every case once no_sync is implemented, comparing accuracy single-card vs. multi-card and multi-card vs. multi-card.
The covered network structures include unused_params, complex control flow, etc., following PR #32826.
Self-test results: no_sync covers all of the above cases, with no accuracy difference from single-card runs.
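The usage sketch referenced in item 3: pausing synchronization during gradient accumulation. This assumes the semantics described above, where synchronization resumes on the first backward executed outside no_sync; loader and the model are placeholders.

```python
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
model = paddle.DataParallel(SimpleNet())  # SimpleNet as sketched earlier
optimizer = paddle.optimizer.SGD(
    learning_rate=0.001, parameters=model.parameters())

accum_steps = 4
for step, data in enumerate(loader()):  # loader is a placeholder
    loss = model(data).mean()
    if (step + 1) % accum_steps != 0:
        # Accumulate local gradients; no inter-card allreduce here.
        with model.no_sync():
            loss.backward()
    else:
        # The first backward outside no_sync synchronizes the
        # accumulated gradients across cards, then we update.
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
```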