Add no_sync in data parallel for dynamic graph #34740
Changes from all commits
@@ -527,6 +527,7 @@ void Reducer::TraverseBackwardGraph(
 void Reducer::PrepareForBackward(
     const std::vector<std::shared_ptr<imperative::VarBase>> &outputs) {
   VLOG(3) << "after forward, then reset count for backward.";
+  grad_need_hooks_ = true;
   next_group_ = 0;
   std::for_each(groups_.begin(), groups_.end(), [](Group &group) {
     group.pending_ = group.variable_indices_.size();

@@ -599,6 +600,11 @@ void Reducer::AddDistHook(size_t var_index) {
                         "than %d, but it is %d",
                         variable_locators_.size(), var_index));

+  // gradient synchronization is not required when grad_need_hooks_ is false.
+  if (!grad_need_hooks_) {
+    return;
+  }
+
   VLOG(3) << "Var[" << var_index << "] ["
           << vars_[var_index]->GradVarBase()->Name()
           << "] arrived and triggered disthook";

@@ -692,8 +698,8 @@ void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) {
       auto var_base = vars_[var_index]->GradVarBase();
       auto tensor =
           var_base->MutableVar()->GetMutable<framework::LoDTensor>();
-      TensorCopy(*tensor, place_, *dev_ctx, &group_tensor);
-      group_tensor.Resize({static_cast<int64_t>(length)});
+      group_tensor.ShareDataWith(*tensor).Resize(
+          {static_cast<int64_t>(length)});
     } else {
       group_tensor.Resize({static_cast<int64_t>(length)});
       operators::math::set_constant(*dev_ctx, &group_tensor, 0.0);

@@ -907,6 +913,10 @@ void Reducer::ProcessUnusedDenseVars() {

     // 3. create grad var base or get grad var base
     auto grad_var_base_tmp = dest_var_base->MutableGradVarBase();
+    // NOTE(haohongxiang): Calling SetIsEmpty here is to make sure that
+    // gradient accumulation can continue normally after clear_gradients()
+    // especially in cases including complex control flow.
+    grad_var_base_tmp->SharedVar()->SetIsEmpty(false);

     // 4. set grad tensor
     auto *dest_grad_tensor =

Review comment on the SetIsEmpty change: please explain the reason for this modification.

@@ -942,6 +952,7 @@ bool Reducer::HasGrad(size_t var_index) {

 void Reducer::FinalizeBackward() {
   groups_need_finalize_ = false;
+  grad_need_hooks_ = false;
 #ifdef PADDLE_WITH_XPU_BKCL
   {
     std::unique_lock<std::mutex> lock(mutex_);
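Taken together, the C++ changes above arm the reducer's gradient hooks in PrepareForBackward, skip synchronization in AddDistHook while the hooks are disarmed, and disarm them again in FinalizeBackward. The sketch below is purely illustrative (plain Python, not Paddle's actual C++ reducer); it only models that flag-gated control flow.

    # Illustrative model of the flag-gated control flow; names mirror the diff
    # above, but this is not Paddle's real implementation.
    class ReducerSketch:
        def __init__(self):
            self.grad_need_hooks = False  # plays the role of grad_need_hooks_

        def prepare_for_backward(self):
            # called from DataParallel.forward() only when sync is wanted
            self.grad_need_hooks = True

        def add_dist_hook(self, var_index):
            # synchronization is skipped while the hooks are disarmed,
            # so gradients simply accumulate locally
            if not self.grad_need_hooks:
                return
            print("var %d is ready, fuse and allreduce its group" % var_index)

        def finalize_backward(self):
            # disarm so a later backward under no_sync() stays local
            self.grad_need_hooks = False

    reducer = ReducerSketch()
    reducer.add_dist_hook(0)        # no output: hooks disarmed (no_sync case)
    reducer.prepare_for_backward()
    reducer.add_dist_hook(0)        # prints: gradient would be synchronized
    reducer.finalize_backward()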
@@ -19,6 +19,7 @@
 from collections import OrderedDict
 import itertools
 import warnings
+from contextlib import contextmanager

 import paddle
 from paddle.fluid import core

@@ -483,6 +484,7 @@ def __init__(self,

         self._layers = layers
         self.find_unused_parameters = find_unused_parameters
+        self.grad_need_sync = True

         # NOTE(chenweihang): The ParallelStrategy here is not strictly a strategy.
         # It just stores some environment variables, which can be constructed by

Review comment on the no_sync() definition below: add a document description, API interface description, and usage description.

@@ -576,9 +578,55 @@ def _find_varbase(self, obj):
             return itertools.chain(*map(self._find_varbase, obj.values()))
         return []

+    @contextmanager
+    def no_sync(self):
+        """
+        A context manager to stop gradient synchronization. Within no_sync(),
+        gradients of parameters will only be accumulated on the model and not
+        synchronized until the first forward-backward outside of this context.
+
+        Examples:
+            .. code-block:: python
+
+                # required: distributed
+                import paddle
+                import paddle.nn as nn
+                import paddle.distributed as dist
+
+                class SimpleNet(nn.Layer):
+                    def __init__(self):
+                        super(SimpleNet, self).__init__()
+                        self._linear = nn.Linear(10, 1)
+
+                    def forward(self, x):
+                        return self._linear(x)
+
+                dist.init_parallel_env()
+                model = SimpleNet()
+                dp_model = paddle.DataParallel(model)
+
+                inputs_1 = paddle.randn([10, 10], 'float32')
+                inputs_2 = paddle.ones([10, 10], 'float32')
+
+                with dp_model.no_sync():
+                    # gradients will not be synchronized
+                    dp_model(inputs_1).backward()
+
+                # synchronization happens here
+                dp_model(inputs_2).backward()
+
+        """
+        tmp_grad_need_sync = self.grad_need_sync
+        self.grad_need_sync = False
+        try:
+            yield
+        finally:
+            self.grad_need_sync = tmp_grad_need_sync
+
     def forward(self, *inputs, **kwargs):
         outputs = self._layers(*inputs, **kwargs)
-        if self._strategy.nranks > 1 and framework._dygraph_tracer()._has_grad:
+        if self._strategy.nranks > 1 and framework._dygraph_tracer(
+        )._has_grad and self.grad_need_sync:
             self._reducer.prepare_for_backward(
                 list(self._find_varbase(outputs)))
         return outputs
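A hedged usage sketch of the accumulate-then-sync pattern that the new grad_need_sync flag and no_sync() enable. It assumes a multi-process launch (for example via python -m paddle.distributed.launch); apart from the no_sync() API added in this PR, only standard Paddle 2.x APIs are used, and the model, batch size, and step counts are placeholders.

    import paddle
    import paddle.nn as nn
    import paddle.distributed as dist

    dist.init_parallel_env()
    model = paddle.DataParallel(nn.Linear(10, 1))
    opt = paddle.optimizer.SGD(learning_rate=0.001,
                               parameters=model.parameters())

    accumulate_steps = 4
    for step in range(20):
        x = paddle.randn([8, 10], 'float32')
        if (step + 1) % accumulate_steps != 0:
            # gradients are only accumulated locally; no allreduce is triggered
            with model.no_sync():
                model(x).sum().backward()
        else:
            # leaving no_sync: this backward synchronizes the accumulated grads
            model(x).sum().backward()
            opt.step()
            opt.clear_grad()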
@@ -0,0 +1,175 @@ (new test file added by this PR)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import contextlib
import unittest
import numpy as np
import six
import pickle
import random

import paddle
import paddle.fluid as fluid
import paddle.distributed as dist
import paddle.fluid.dygraph as dygraph
from paddle.fluid import core
from paddle.fluid.dygraph.nn import Linear
from test_dist_base import print_to_err, print_to_out, runtime_main, TestParallelDyGraphRunnerBase

seed = 90
RUN_STEP = 20
batch_size = 4
batch_num = 1000


class SimpleNet(fluid.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.net_a = Linear(input_dim=10, output_dim=20)
        self.net_b = Linear(input_dim=20, output_dim=5)
        self.net_c = Linear(input_dim=5, output_dim=10)

    def forward(self, x):
        x = self.net_a(x)
        x = self.net_b(x)
        x = self.net_c(x)
        return x

Review comment on class SimpleNet(fluid.Layer): the unit test should also use paddle.nn.Layer rather than APIs under fluid.
Reply: OK.

class TestNoSync(TestParallelDyGraphRunnerBase):
    def get_model(self):
        model = SimpleNet()
        train_reader = paddle.batch(
            fake_sample_reader(), batch_size=batch_size, drop_last=True)
        optimizer = paddle.optimizer.SGD(learning_rate=0.001,
                                         parameters=model.parameters())
        return model, train_reader, optimizer

    def run_one_loop(self, model, optimizer, batch):
        x_data = np.array([x for x in batch])
        x_data = x_data.reshape((-1, 10))
        x = paddle.to_tensor(x_data)
        out = model(x)
        loss = out.sum() / len(batch)
        return loss

    def run_trainer(self, args):
        if fluid.core.is_compiled_with_cuda():
            device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
            place = fluid.CUDAPlace(device_id)
        else:
            assert ("Only support CUDAPlace for now.")

        with fluid.dygraph.guard(place):
            fluid.default_startup_program().random_seed = seed
            fluid.default_main_program().random_seed = seed
            np.random.seed(seed)
            random.seed(seed)
            model, train_reader, opt = self.get_model()

            if args.update_method == "nccl2":
                dist.init_parallel_env()
                print_to_err(
                    type(self).__name__,
                    "begin to prepare context in dygraph with nccl2")
                if not args.find_unused_parameters:
                    model = paddle.DataParallel(
                        model, find_unused_parameters=False)
                else:
                    model = paddle.DataParallel(
                        model, find_unused_parameters=True)
                print_to_err(type(self).__name__, "model built in dygraph")

            out_losses = []
            print_to_err(type(self).__name__, "begin to run dygraph training")
            for step_id, data in enumerate(train_reader()):
                data = self._get_data(data, args)
                if step_id == RUN_STEP:
                    break
                if step_id % 3 != 0:
                    if args.update_method == "nccl2":
                        with model.no_sync():
                            loss = self.run_one_loop(model, opt, data)
                            loss.backward()
                    else:
                        loss = self.run_one_loop(model, opt, data)
                        loss.backward()
                else:
                    loss = self.run_one_loop(model, opt, data)
                    loss.backward()
                    opt.minimize(loss)
                    print_to_err(
                        type(self).__name__,
                        "loss at step %d: %f" % (step_id, loss.numpy()))
                    out_losses.append(loss.numpy())

                    if not args.accumulate_gradient:
                        model.clear_gradients()
            print_to_out(out_losses)

    def run_trainer_with_spawn(self, args):
        fluid.default_startup_program().random_seed = seed
        fluid.default_main_program().random_seed = seed
        np.random.seed(seed)
        random.seed(seed)
        args.trainer_id = dist.get_rank()

        if args.update_method == "nccl2":
            dist.init_parallel_env()
        model, train_reader, opt = self.get_model()
        if args.update_method == "nccl2":
            if args.find_unused_parameters:
                model = paddle.DataParallel(model, find_unused_parameters=True)
            else:
                model = paddle.DataParallel(
                    model, find_unused_parameters=False)

        out_losses = []
        for step_id, data in enumerate(train_reader()):
            data = self._get_data(data, args)
            if step_id == RUN_STEP:
                break
            if step_id % 3 != 0:
                if args.update_method == "nccl2":
                    with model.no_sync():
                        loss = self.run_one_loop(model, opt, data)
                        loss.backward()
                else:
                    loss = self.run_one_loop(model, opt, data)
                    loss.backward()
            else:
                loss = self.run_one_loop(model, opt, data)
                loss.backward()
                opt.minimize(loss)
                print_to_err(
                    type(self).__name__,
                    "loss at step %d: %f" % (step_id, loss.numpy()))
                out_losses.append(loss.numpy())
                model.clear_gradients()
        print_to_out(out_losses)
        return out_losses


def fake_sample_reader():
    def __reader__():
        for i in range(batch_num):
            x_data = np.random.random_sample((10, )).astype('float32')
            yield x_data

    return __reader__


if __name__ == "__main__":
    runtime_main(TestNoSync)
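The test interleaves no_sync steps with ordinary synchronized steps and relies on dygraph gradients accumulating across backward calls. The single-process snippet below (not part of the PR, no distributed launch required) illustrates that accumulation: summing gradients over two micro-batches matches one backward over the whole batch.

    import numpy as np
    import paddle
    import paddle.nn as nn

    paddle.seed(2021)
    layer = nn.Linear(10, 1)
    data = paddle.randn([8, 10], 'float32')

    # accumulate gradients over two micro-batches
    layer(data[:4]).sum().backward()
    layer(data[4:]).sum().backward()
    accumulated = np.array(layer.weight.gradient())

    layer.clear_gradients()

    # a single backward over the whole batch
    layer(data).sum().backward()
    whole_batch = np.array(layer.weight.gradient())

    # the two gradients should match up to floating-point accumulation order
    print(np.abs(accumulated - whole_batch).max())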
Review comment: add a note to explain the role of this parameter.
Reply: Thanks. Notes have already been added at lines 212-215 of paddle/fluid/imperative/reducer.h.