From df10682b9a26f68fc036078b9024f72a2ab630f1 Mon Sep 17 00:00:00 2001
From: sandyhouse
Date: Thu, 24 Mar 2022 02:47:00 +0000
Subject: [PATCH] update

---
 python/paddle/distributed/fleet/base/fleet_base.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py
index daa436e001ab3d..e38b4839e5383b 100755
--- a/python/paddle/distributed/fleet/base/fleet_base.py
+++ b/python/paddle/distributed/fleet/base/fleet_base.py
@@ -46,7 +46,7 @@ class _RecomputeModelWrapper(paddle.nn.Layer):
-    def __init__(self, model, segments=1, preserve_rng_state=True):
+    def __init__(self, model, segments=2, preserve_rng_state=True):
         super(_RecomputeModelWrapper, self).__init__()
         assert isinstance(model, paddle.nn.Sequential), (
             "The model passed to RecomputeModelWrapper must be of type "
@@ -1024,7 +1024,7 @@ def forward(self, x):
             model = _RecomputeModelWrapper(model)

         if self._user_defined_strategy.heter_ccl_mode == True:
-            model = paddle.DataParallel(
+            distributed_model = paddle.DataParallel(
                 model,
                 comm_buffer_size=self._user_defined_strategy.
                 fuse_grad_size_in_MB,
@@ -1032,7 +1032,7 @@ def forward(self, x):
                 last_comm_group_size_MB,
                 find_unused_parameters=self._user_defined_strategy.
                 find_unused_parameters)
-            return model
+            return distributed_model

         if self._hcg.get_parallel_mode() == ParallelMode.SHARDING_PARALLEL:
             model = ShardingParallel(
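
For context (not part of the patch), below is a minimal dygraph sketch of the code path these hunks touch: fleet.distributed_model wraps a paddle.nn.Sequential model in _RecomputeModelWrapper when recompute is enabled and, when heter_ccl_mode is set, returns the paddle.DataParallel wrapper that this patch renames to distributed_model. It assumes a working PaddlePaddle install and a multi-process launch (e.g. python -m paddle.distributed.launch); the strategy flags, layer sizes, and data here are illustrative only, not taken from the patch.

# Hedged usage sketch, not part of the patch. Assumes a multi-worker launch;
# with a single worker, fleet.distributed_model() returns the model unchanged.
import paddle
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.recompute = True        # model gets wrapped in _RecomputeModelWrapper
strategy.heter_ccl_mode = True   # takes the paddle.DataParallel branch edited above

fleet.init(is_collective=True, strategy=strategy)

# _RecomputeModelWrapper asserts the model is a paddle.nn.Sequential.
model = paddle.nn.Sequential(
    paddle.nn.Linear(16, 64),
    paddle.nn.ReLU(),
    paddle.nn.Linear(64, 10),
)
optimizer = paddle.optimizer.Adam(parameters=model.parameters())

# Returns the DataParallel-wrapped model (the renamed `distributed_model` variable).
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)

x = paddle.randn([8, 16])
loss = model(x).mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()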