Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bugs in mp_layers、pp_layers and HybridParallelClipGrad #36144

Merged
merged 8 commits into from
Oct 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def _dygraph_clip(self, params_grads):
params_and_grads = []
sum_square_list_dist = []
sum_square_list_not_dist = []

for p, g in params_grads:
if g is None:
continue
Expand All @@ -64,29 +65,38 @@ def _dygraph_clip(self, params_grads):
square = layers.square(merge_grad)
sum_square = layers.reduce_sum(square)

if p.is_distributed:
sum_square_list_dist.append(sum_square)
else:
sum_square_list_not_dist.append(sum_square)
not_shared_enable = (not hasattr(p, 'is_firstly_shared')) or (
hasattr(p, 'is_firstly_shared') and
getattr(p, 'is_firstly_shared', True))

# all parameters have been filterd out
if len(sum_square_list_dist) + len(sum_square_list_not_dist) == 0:
return params_grads
if not_shared_enable:
if p.is_distributed:
sum_square_list_dist.append(sum_square)
else:
sum_square_list_not_dist.append(sum_square)

global_norm_var_dist = layers.concat(sum_square_list_dist) if len(
sum_square_list_dist) != 0 else layers.concat(
[paddle.to_tensor([0.])])
global_norm_var_dist = layers.reduce_sum(global_norm_var_dist)

global_norm_var_not_dist = layers.concat(
sum_square_list_not_dist) if len(
sum_square_list_not_dist) != 0 else layers.concat(
[paddle.to_tensor([0.])])
global_norm_var_not_dist = layers.reduce_sum(global_norm_var_not_dist)

# add all reduce to get global norm of distributed params_and_grads in world size
# all reduce is not needed while getting global norm of non-distributed params_and_grads
paddle.distributed.all_reduce(
global_norm_var_dist, group=self._hcg.get_check_parallel_group())
# add all reduce to get global norm of distributed params_and_grads
if self._hcg.get_model_parallel_world_size() > 1:
paddle.distributed.all_reduce(
global_norm_var_dist,
group=self._hcg.get_check_parallel_group())

# add all reduce to get global norm of non-distributed params_and_grads in groups of pp
if self._hcg.get_pipe_parallel_world_size() > 1:
paddle.distributed.all_reduce(
global_norm_var_not_dist,
group=self._hcg.get_pipe_parallel_group())

# In Sharding mode, param and grad is mapping different rank in optimizer.
# ClipGradByGlobalNorm need allreduce to get globol norm
Expand Down Expand Up @@ -143,8 +153,8 @@ def __init__(self, optimizer, hcg, strategy):

if isinstance(self._inner_opt._grad_clip,
ClipGradByGlobalNorm) and not self._use_dp_mode:
logger.warning("using ClipGradByGlobalNorm in TensorParallel, the origin " \
"optmizer'grad clip will be changed.")
logger.warning("While using ClipGradByGlobalNorm in TensorParallel, PipelineParallel " \
"or Sharding, the grad clip of original optimizer will be changed.")

if self._sharding_enable:
# change sharding inner_optimizer's _grad_clip
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(self,
dtype=self._dtype,
is_bias=False)

self.weight.is_distributed = True
self.weight.is_distributed = True if self.is_mp else False

def forward(self, x):
if self.is_mp:
Expand Down Expand Up @@ -135,7 +135,7 @@ def __init__(self,
dtype=self._dtype,
is_bias=False)

self.weight.is_distributed = True
self.weight.is_distributed = True if self.is_mp else False

if has_bias:
# initialize bias to zero like Megatron
Expand All @@ -144,7 +144,7 @@ def __init__(self,
attr=paddle.nn.initializer.Constant(value=0.0),
dtype=self._dtype,
is_bias=True)
self.bias.is_distributed = True
self.bias.is_distributed = True if self.is_mp else False
else:
self.bias = None

Expand Down Expand Up @@ -212,7 +212,7 @@ def __init__(self,
dtype=self._dtype,
is_bias=False)

self.weight.is_distributed = True
self.weight.is_distributed = True if self.is_mp else False

if has_bias:
self.bias = self.create_parameter(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,10 @@ def _synchronize_shared_weights(self):
src=min(comm['ranks']),
group=comm['group'])

for param in comm['layer'].parameters():
if self.global_rank != min(comm['ranks']):
setattr(param, 'is_firstly_shared', False)

def allreduce_shared_weight_gradients(self):
for key, comm in self.shared_comm.items():
param = getattr(self.shared_layers[key], comm['weight_attr'])
Expand Down Expand Up @@ -316,6 +320,9 @@ def _build_layer(self):
self.shared_layers[layer.layer_name] = layer.build_layer()
self.shared_weight_attrs[
layer.layer_name] = layer.shared_weight_attr
for param in self.shared_layers[
layer.layer_name].parameters():
setattr(param, "is_firstly_shared", True)

if layer.forward_func is None:
self.run_function.append(self.shared_layers[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ def setUp(self):
}
fleet.init(is_collective=True, strategy=strategy)

def build_optimizer(self, model):
scheduler = paddle.optimizer.lr.PiecewiseDecay(
boundaries=[2], values=[0.001, 0.002], verbose=True)
optimizer = paddle.optimizer.SGD(learning_rate=scheduler,
parameters=model.parameters())
return scheduler, optimizer

def test_pp_model(self):
hcg = fleet.get_hybrid_communicate_group()
word_size = hcg.get_model_parallel_world_size()
Expand All @@ -63,10 +70,7 @@ def test_pp_model(self):

#construct model a
model_a = AlexNet(10)
scheduler_a = paddle.optimizer.lr.PiecewiseDecay(
boundaries=[2], values=[0.001, 0.002], verbose=True)
optimizer_a = paddle.optimizer.SGD(learning_rate=scheduler_a,
parameters=model_a.parameters())
scheduler_a, optimizer_a = self.build_optimizer(model_a)

param_len = len(model_a.parameters())

Expand All @@ -76,10 +80,7 @@ def test_pp_model(self):

# construct model b
model_b = AlexNetPipeDesc(num_stages=self.pipeline_parallel_size)
scheduler_b = paddle.optimizer.lr.PiecewiseDecay(
boundaries=[2], values=[0.001, 0.002], verbose=True)
optimizer_b = paddle.optimizer.SGD(learning_rate=scheduler_b,
parameters=model_b.parameters())
scheduler_b, optimizer_b = self.build_optimizer(model_b)
model_b = fleet.distributed_model(model_b)
optimizer_b = fleet.distributed_optimizer(optimizer_b)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import print_function

import paddle
import unittest
from hybrid_parallel_pp_alexnet import TestDistPPTraning


class TestPPClipGrad(TestDistPPTraning):
def build_optimizer(self, model):
grad_clip = paddle.nn.ClipGradByGlobalNorm(0.5)
scheduler = paddle.optimizer.lr.PiecewiseDecay(
boundaries=[2], values=[0.001, 0.002], verbose=True)
optimizer = paddle.optimizer.SGD(learning_rate=scheduler,
grad_clip=grad_clip,
parameters=model.parameters())
return scheduler, optimizer


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def test_hybrid_parallel_save_load(self):
def test_hybrid_parallel_recompute(self):
self.run_mnist_2gpu('hybrid_parallel_pp_recompute.py')

def test_hybrid_parallel_pp_clip_grad(self):
self.run_mnist_2gpu('hybrid_parallel_pp_clip_grad.py')


if __name__ == "__main__":
unittest.main()