support ClipGradByGlobalNorm in sharding (#36012)
* support ClipGradByGlobalNorm in sharding

* support ClipGradByGlobalNorm in sharding

* test=allcase
zhaoyinglia authored Oct 9, 2021
1 parent 2fd8dee commit 623df42
Showing 3 changed files with 27 additions and 9 deletions.
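
As context for the diff below, a minimal usage sketch of the feature this commit enables: passing a ClipGradByGlobalNorm to the dygraph sharding optimizer. The construction mirrors the test changes in this commit; `model` and `strategy` are placeholders, and the import path for DygraphShardingOptimizer is assumed from the `__init__.py` change in the first hunk.

# sketch only: ClipGradByGlobalNorm combined with dygraph sharding (mirrors the test code below)
import paddle
from paddle.distributed import fleet
# import path assumed from the package __init__.py touched in the first hunk
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import DygraphShardingOptimizer

model = paddle.nn.Linear(8, 8)                    # placeholder model
strategy = fleet.DistributedStrategy()            # the strategy also passed to fleet.init(...)
clip = paddle.nn.ClipGradByGlobalNorm(0.5)
optimizer = DygraphShardingOptimizer(
    hcg=fleet.get_hybrid_communicate_group(),     # requires fleet.init(...) to have run
    user_defined_strategy=strategy,
    params=model.parameters(),
    inner_optimizer_class=paddle.optimizer.AdamW,
    learning_rate=0.001,
    weight_decay=0.00001,
    grad_clip=clip)                               # forwarded to the inner AdamW, then wrapped by HybridParallelClipGrad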
@@ -12,5 +12,6 @@
# See the License for the specific language governing permissions and
from .hybrid_parallel_optimizer import HybridParallelOptimizer
from .hybrid_parallel_gradscaler import HybridParallelGradScaler
from .dygraph_sharding_optimizer import DygraphShardingOptimizer

__all__ = []
@@ -88,6 +88,13 @@ def _dygraph_clip(self, params_grads):
paddle.distributed.all_reduce(
global_norm_var_dist, group=self._hcg.get_check_parallel_group())

# In sharding mode, params and grads are mapped to different ranks in the optimizer.
# ClipGradByGlobalNorm needs an all-reduce to get the global norm.
if self._hcg.get_sharding_parallel_world_size() > 1:
paddle.distributed.all_reduce(
global_norm_var_not_dist,
group=self._hcg.get_sharding_parallel_group())

global_norm_var = layers.sqrt(global_norm_var_dist +
global_norm_var_not_dist)
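
For readers unfamiliar with the sharding layout: each sharding rank only holds the gradients of its own parameter shard, so the per-rank sum of squared norms must be all-reduced across the sharding group before the square root is taken. A simplified standalone sketch of that idea (illustrative only, not the `_dygraph_clip` implementation above):

# simplified sketch: global gradient norm under sharding (not Paddle's actual implementation)
import paddle
import paddle.distributed as dist

def sharded_global_norm(local_grads, sharding_group):
    # squared L2 norm of the gradients this rank actually holds
    local_sq_sum = paddle.add_n([paddle.sum(paddle.square(g)) for g in local_grads])
    # other shards live on other ranks, so sum the partial results across the sharding group
    dist.all_reduce(local_sq_sum, group=sharding_group)
    return paddle.sqrt(local_sq_sum)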

@@ -139,8 +146,13 @@ def __init__(self, optimizer, hcg, strategy):
logger.warning("using ClipGradByGlobalNorm in TensorParallel, the original " \
"optimizer's grad clip will be changed.")

self._inner_opt._grad_clip = HybridParallelClipGrad(
self._inner_opt._grad_clip, hcg)
if self._sharding_enable:
# change sharding inner_optimizer's _grad_clip
self._inner_opt._inner_optimizer._grad_clip = HybridParallelClipGrad(
self._inner_opt._grad_clip, hcg)
else:
self._inner_opt._grad_clip = HybridParallelClipGrad(
self._inner_opt._grad_clip, hcg)
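
In practice this branch is reached once the user optimizer (carrying its ClipGradByGlobalNorm) is handed to fleet in dygraph hybrid-parallel mode, where `fleet.distributed_optimizer` returns the HybridParallelOptimizer wrapper shown here. A hedged setup sketch; the `hybrid_configs` keys, in particular `sharding_degree`, are assumed from the usual hybrid-parallel test setup and are not part of this diff:

# hedged sketch of the setup that exercises the sharding branch above
import paddle
from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 2,   # assumed key: enables sharding, so _sharding_enable is True
}
fleet.init(is_collective=True, strategy=strategy)

# `model` / `optimizer` built as in the test code below (DygraphShardingOptimizer + ClipGradByGlobalNorm)
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)   # wraps into HybridParallelOptimizer,
                                                     # which installs HybridParallelClipGrad as above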

@imperative_base.no_grad
@framework.dygraph_only
@@ -183,32 +183,37 @@ def build_optimizer(self,
strategy=None,
is_sharding=True,
Optimizer="adam"):

clip = paddle.nn.ClipGradByGlobalNorm(0.5)
if Optimizer == "adam":
if is_sharding:
optimizer = DygraphShardingOptimizer(
hcg=fleet.get_hybrid_communicate_group(),
user_defined_strategy=strategy,
params=model.parameters(),
inner_optimizer_class=paddle.optimizer.Adam,
inner_optimizer_class=paddle.optimizer.AdamW,
learning_rate=0.001,
weight_decay=0.00001, )
weight_decay=0.00001,
grad_clip=clip)
else:
optimizer = paddle.optimizer.Adam(
optimizer = paddle.optimizer.AdamW(
parameters=model.parameters(),
learning_rate=0.001,
weight_decay=0.00001, )
weight_decay=0.00001,
grad_clip=clip)
else:
if is_sharding:
optimizer = DygraphShardingOptimizer(
hcg=fleet.get_hybrid_communicate_group(),
user_defined_strategy=strategy,
params=model.parameters(),
inner_optimizer_class=paddle.optimizer.Momentum,
learning_rate=0.001, )
learning_rate=0.001,
grad_clip=clip)
else:
optimizer = paddle.optimizer.Momentum(
learning_rate=0.001, parameters=model.parameters())
learning_rate=0.001,
parameters=model.parameters(),
grad_clip=clip)
return optimizer

def build_model_optimizer(self, Optimizer="adam"):
