Wrap dist api for dygraph mode #40408

Merged (28 commits, Mar 24, 2022)

Commits
9ba08b1
rename TensorBase interface data_type() to dtype()
zyfncg Nov 16, 2021
3c1afc0
rename type to dtype of TensorMeta
zyfncg Nov 17, 2021
288f086
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Nov 17, 2021
701a0bd
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Nov 17, 2021
7bc3cbb
merge the code
zyfncg Nov 17, 2021
7b79b03
merge the code
zyfncg Nov 17, 2021
471a1bf
fix the problem when merge conflict
zyfncg Nov 18, 2021
d39a1d9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Nov 19, 2021
835e415
fix bug of ci caused by type of tensor_meta
zyfncg Nov 19, 2021
ab60a6d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 19, 2021
471741f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 20, 2021
691056a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Jan 20, 2022
9484f12
unify the support for dp/amp/recompute/groupsharded
Jan 21, 2022
8b501f6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Feb 17, 2022
4baa29f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Feb 17, 2022
ea480f4
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Feb 18, 2022
464a640
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Mar 10, 2022
986211d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Mar 14, 2022
4009041
update
Mar 15, 2022
70259a0
update
Mar 21, 2022
151d1f4
update
Mar 21, 2022
a5802b9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Mar 21, 2022
9510fa6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Mar 21, 2022
4a2ce73
update
Mar 22, 2022
a9ea543
update
Mar 22, 2022
51ca6b2
update
Mar 23, 2022
c02328f
update
Mar 23, 2022
df10682
update
Mar 24, 2022
Files changed (5)

84 changes: 78 additions & 6 deletions python/paddle/distributed/fleet/base/fleet_base.py
@@ -37,9 +37,45 @@
from paddle import _C_ops
from paddle.fluid import core
from paddle.fluid.dygraph import to_variable
from paddle.distributed.fleet.utils.recompute import RecomputeFunction
from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar

__all__ = []

_grad_scalar = None


class _RecomputeModelWrapper(paddle.nn.Layer):
def __init__(self, model, segments=2, preserve_rng_state=True):
super(_RecomputeModelWrapper, self).__init__()
assert isinstance(model, paddle.nn.Sequential), (
"The model passed to RecomputeModelWrapper must be of type "
"paddle.nn.Sequential.")
self._model = model
self._segments = segments
self._preserve_rng_state = preserve_rng_state
self._layers = list(model.children())
self._segment_size = len(self._layers) // segments

def _run_func(self, begin, end):
def do_run(input):
for i in range(begin, end):
input = self._layers[i](input)
return input

return do_run

def _checkpoint(self, func, *args, **kwargs):
return RecomputeFunction.apply(func, self._preserve_rng_state, *args)

def forward(self, input):
end = 0
for begin in range(0, self._segment_size * (self._segments - 1),
self._segment_size):
end = begin + self._segment_size
input = self._checkpoint(self._run_func(begin, end), input)
return self._run_func(end, len(self._layers))(input)


def apply_ir_passes(main_program, startup_program, config):
build_strategy = config._user_defined_strategy.build_strategy._copy()
@@ -952,6 +988,41 @@ def forward(self, x):
if self.worker_num() <= 1:
return model

amp_enable = False
recompute_enable = False
strategy = self._user_defined_strategy
if strategy.amp == True:
amp_enable = True
amp_level = "O2" if strategy.amp_configs['use_pure_fp16'] else "O1"
if amp_level.upper() == "O2":
model = paddle.amp.decorate(
models=model,
optimizers=None,
level="O2",
master_weight=None,
save_dtype=None)
init_loss_scaling = strategy.amp_configs['init_loss_scaling']
incr_ratio = strategy.amp_configs['incr_ratio']
decr_ratio = strategy.amp_configs['decr_ratio']
incr_every_n_steps = strategy.amp_configs['incr_every_n_steps']
decr_every_n_nan_or_inf = strategy.amp_configs[
'decr_every_n_nan_or_inf']
use_dynamic_loss_scaling = strategy.amp_configs[
'use_dynamic_loss_scaling']

global _grad_scalar
_grad_scalar = paddle.amp.GradScaler(
init_loss_scaling=init_loss_scaling,
incr_ratio=incr_ratio,
decr_ratio=decr_ratio,
incr_every_n_steps=incr_every_n_steps,
decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
use_dynamic_loss_scaling=use_dynamic_loss_scaling)

if strategy.recompute == True:
recompute_enable = True
model = _RecomputeModelWrapper(model)

if self._user_defined_strategy.heter_ccl_mode == True:
distributed_model = paddle.DataParallel(
model,
@@ -964,7 +1035,7 @@ def forward(self, x):
return distributed_model

if self._hcg.get_parallel_mode() == ParallelMode.SHARDING_PARALLEL:
distributed_model = ShardingParallel(
model = ShardingParallel(
model, self._hcg, strategy=self._user_defined_strategy)
elif self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL:

@@ -975,22 +1046,23 @@ def forward(self, x):
assert self.sharding_degree == self._hcg.get_sharding_parallel_world_size(
)
broadcast_sharding_parameters(model, self._hcg)
distributed_model = paddle.DataParallel(
model = paddle.DataParallel(
model,
comm_buffer_size=self._user_defined_strategy.
fuse_grad_size_in_MB,
last_comm_buffer_size=self._user_defined_strategy.
last_comm_group_size_MB,
find_unused_parameters=self._user_defined_strategy.
find_unused_parameters)
find_unused_parameters,
static_graph=True if recompute_enable else False)
elif self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL:
distributed_model = TensorParallel(
model = TensorParallel(
model, self._hcg, strategy=self._user_defined_strategy)
elif self._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
distributed_model = PipelineParallel(
model = PipelineParallel(
model, self._hcg, strategy=self._user_defined_strategy)

return distributed_model
return model

@dygraph_only
def state_dict(self):
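
A minimal, self-contained sketch of the segmentation arithmetic used by _RecomputeModelWrapper.forward above: every segment except the last is routed through RecomputeFunction, while the final segment runs without recomputation. The helper name recompute_segments and the layer/segment counts are illustrative assumptions, not part of this PR.

# Sketch only: reproduces the index arithmetic of _RecomputeModelWrapper.forward,
# without calling RecomputeFunction itself.
def recompute_segments(num_layers, segments=2):
    """Return (checkpointed (begin, end) ranges, final (begin, end) range)."""
    segment_size = num_layers // segments
    checkpointed = []
    end = 0
    for begin in range(0, segment_size * (segments - 1), segment_size):
        end = begin + segment_size
        checkpointed.append((begin, end))   # these ranges go through RecomputeFunction
    return checkpointed, (end, num_layers)  # the tail segment runs normally

if __name__ == "__main__":
    # With 5 layers and 2 segments: layers [0, 2) are checkpointed, layers [2, 5) are not.
    print(recompute_segments(5, segments=2))  # ([(0, 2)], (2, 5))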
5 changes: 5 additions & 0 deletions python/paddle/fluid/dygraph/varbase_patch_methods.py
@@ -30,6 +30,8 @@
import paddle.utils.deprecated as deprecated
from paddle import _C_ops

_grad_scalar = None


class TensorHookRemoveHelper(object):
"""
@@ -261,6 +263,9 @@ def backward(self, grad_tensor=None, retain_graph=False):
grad_tensor = []
else:
grad_tensor = [grad_tensor]
if _grad_scalar:
# When using amp with Fleet DistributedStrategy, we do loss scaling implicitly.
self = _grad_scalar.scale(self)
if paddle.is_compiled_with_xpu() or paddle.is_compiled_with_npu():
# TODO(liuyuhui): Currently only for xpu. Will be removed in the future.
scaled_loss = scale_loss(self)
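
The snippet below is a paddle-free sketch of the implicit loss-scaling flow wired up by this hunk: fleet.distributed_model installs a module-level GradScaler in _grad_scalar when strategy.amp is enabled, and the patched Tensor.backward multiplies the loss by the current scale before running autograd, so user code never calls scaler.scale(loss) itself. _FakeScaler and the numeric values are stand-ins, not real Paddle APIs.

# Sketch only: a stand-in scaler mimicking the one installed by fleet.distributed_model.
_grad_scalar = None  # set by distributed_model() when strategy.amp is enabled

class _FakeScaler:
    def __init__(self, init_loss_scaling=1024.0):
        self.loss_scaling = init_loss_scaling

    def scale(self, loss):
        # GradScaler.scale multiplies the loss by the current loss-scaling factor.
        return loss * self.loss_scaling

def backward(loss):
    # Mirrors the patched Tensor.backward: scale implicitly if a scaler was installed.
    if _grad_scalar:
        loss = _grad_scalar.scale(loss)
    return loss  # the real method would now run autograd on the scaled loss

_grad_scalar = _FakeScaler(init_loss_scaling=2.0)
print(backward(3.0))  # 6.0 -- user code never calls scale() explicitly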
1 change: 1 addition & 0 deletions python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -939,6 +939,7 @@ if (WITH_DISTRIBUTE)
set_tests_properties(test_dist_fleet_infer PROPERTIES TIMEOUT 200)
set_tests_properties(test_dist_fleet_raw_program_optimizer PROPERTIES TIMEOUT 120)
set_tests_properties(test_dist_fleet_raw_program_optimizer_fuse_allreduce PROPERTIES TIMEOUT 60)
set_tests_properties(test_dist_dygraph_apis PROPERTIES TIMEOUT 120)
endif()

if (WITH_DISTRIBUTE AND NOT APPLE)
60 changes: 60 additions & 0 deletions python/paddle/fluid/tests/unittests/dygraph_fleet_api.py
@@ -0,0 +1,60 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import random
import numpy as np
import os
import shutil

import paddle
import paddle.nn as nn
from paddle.fluid import core
import datetime
from datetime import timedelta
import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.dygraph.parallel import ParallelEnv


class TestDygraphFleetAPI(unittest.TestCase):
def setUp(self):
paddle.seed(2022)
random.seed(2022)
np.random.seed(2022)
self.config()

def config(self):
self.dtype = "float32"
self.shape = (2, 10, 5)

def test_dygraph_fleet_api(self):
import paddle.distributed.fleet as fleet
import paddle.distributed as dist
strategy = fleet.DistributedStrategy()
strategy.amp = True
strategy.recompute = True
fleet.init(is_collective=True, strategy=strategy)
net = paddle.nn.Sequential(
paddle.nn.Linear(10, 1), paddle.nn.Linear(1, 2))
net = dist.fleet.distributed_model(net)
data = np.random.uniform(-1, 1, [30, 10]).astype('float32')
data = paddle.to_tensor(data)
net(data)


if __name__ == "__main__":
unittest.main()
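
As a rough usage sketch, the wrapped model from this test could be driven through a normal training step as below. The optimizer and loss are illustrative additions, not part of the PR, and the script is only meaningful under a multi-process launch (for example python -m paddle.distributed.launch --gpus "0,1" ...), which the test_dist_dygraph_apis.py driver below automates.

# Sketch only: optimizer/loss choices are assumptions, not part of the PR.
import numpy as np
import paddle
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.amp = True        # installs the implicit GradScaler via distributed_model
strategy.recompute = True  # wraps the Sequential model in _RecomputeModelWrapper
fleet.init(is_collective=True, strategy=strategy)

net = paddle.nn.Sequential(paddle.nn.Linear(10, 1), paddle.nn.Linear(1, 2))
net = fleet.distributed_model(net)
opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=net.parameters())

data = paddle.to_tensor(np.random.uniform(-1, 1, [30, 10]).astype('float32'))
loss = paddle.mean(net(data))
loss.backward()   # loss scaling happens implicitly in the patched backward
opt.step()
opt.clear_grad()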
27 changes: 27 additions & 0 deletions python/paddle/fluid/tests/unittests/test_dist_dygraph_apis.py
@@ -0,0 +1,27 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
from test_parallel_dygraph_dataparallel import TestMultipleGpus


class TestDygraphFleetApi(TestMultipleGpus):
def test_dygraph_fleet_api(self):
self.run_mnist_2gpu('dygraph_fleet_api.py')


if __name__ == "__main__":
unittest.main()