Add interface to launch parallel dygraph by multiprocessing #26044

Merged

Changes from all commits (38 commits)
97b8bdc
add dygraph parallel run interface
chenwhql Aug 7, 2020
00b56d5
polish implement & unified env property name
chenwhql Aug 7, 2020
17f7fe9
add print config arg
chenwhql Aug 10, 2020
07c86aa
refactor init_parallel_env function
chenwhql Aug 11, 2020
4c955a1
Compatible with multiprocessing and launch modes
chenwhql Aug 13, 2020
523e007
set default trainer start port
chenwhql Aug 14, 2020
8101b03
support run in python 2
chenwhql Aug 15, 2020
d3b9a06
polish python2 support code
chenwhql Aug 17, 2020
48c46ff
remove python2 support
chenwhql Aug 17, 2020
b06d400
refine launch import
chenwhql Aug 19, 2020
e1df353
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
chenwhql Aug 19, 2020
2c7b3fd
polish demo design details
chenwhql Aug 19, 2020
39fddff
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
chenwhql Aug 19, 2020
d26f495
refactor api implementation & path
chenwhql Aug 20, 2020
bf985cc
use new method _set_expected_place
chenwhql Aug 20, 2020
7939384
add spawn unittest framework & mnist test
chenwhql Aug 24, 2020
95c0367
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
chenwhql Aug 24, 2020
04580d8
add more unittests & doc
chenwhql Aug 24, 2020
131afd4
fix unittest failed
chenwhql Aug 25, 2020
e170f10
polish english doc
chenwhql Aug 25, 2020
0ef215d
self review and polish details
chenwhql Aug 25, 2020
b27cfee
refactor code by reviewer's comments
chenwhql Aug 25, 2020
f50f343
fix unittest failed
chenwhql Aug 26, 2020
11221a8
fix parallel_env unittest
chenwhql Aug 26, 2020
0980c23
fix several typos
chenwhql Aug 26, 2020
af50518
fix error introduced when fixing typos
chenwhql Aug 27, 2020
a378140
add unpublic note for start_processes
chenwhql Aug 27, 2020
cca82b6
polish details by xiaoguang's comment
chenwhql Aug 27, 2020
82223a6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
chenwhql Aug 27, 2020
d39331c
verify correctly when spawn nprocs=-1
chenwhql Aug 27, 2020
10df04c
resolve collective api conflict
chenwhql Aug 27, 2020
3a2d7e8
refactor spawn & init_parallel_env design
chenwhql Aug 27, 2020
0582c4b
polish doc details
chenwhql Aug 27, 2020
9ceaeff
open spawn unittests
chenwhql Aug 27, 2020
4b7d810
try to fix doc compile error
chenwhql Aug 27, 2020
4261e22
try to fix unknown doc format error
chenwhql Aug 27, 2020
cad6872
add skip unittest when not gpu
chenwhql Aug 28, 2020
377c919
resolve develop conflict
chenwhql Aug 28, 2020
2 changes: 0 additions & 2 deletions python/paddle/__init__.py
@@ -230,8 +230,6 @@
 from .framework import no_grad #DEFINE_ALIAS
 from .framework import save #DEFINE_ALIAS
 from .framework import load #DEFINE_ALIAS
-from .framework import prepare_context #DEFINE_ALIAS
-from .framework import ParallelEnv #DEFINE_ALIAS
 from .framework import DataParallel #DEFINE_ALIAS

 from .framework import NoamDecay #DEFINE_ALIAS
23 changes: 23 additions & 0 deletions python/paddle/distributed/__init__.py
@@ -12,4 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from . import spawn
from .spawn import spawn

from . import parallel
from .parallel import init_parallel_env
from .parallel import get_rank
from .parallel import get_world_size
from paddle.fluid.dygraph.parallel import prepare_context #DEFINE_ALIAS
from paddle.fluid.dygraph.parallel import ParallelEnv #DEFINE_ALIAS

from . import collective
from .collective import *

# start multiprocess apis
__all__ = ["spawn"]

# dygraph parallel apis
__all__ += [
    "init_parallel_env", "get_rank", "get_world_size", "prepare_context",
    "ParallelEnv"
]

# collective apis
__all__ += collective.__all__
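
For orientation, the `spawn` function exported above is the multiprocessing counterpart to the `launch` module. A minimal usage sketch, assuming `spawn` accepts an `nprocs` count and forwards `args` to the target function (its full signature lives in python/paddle/distributed/spawn.py, which is not part of this excerpt):

    import paddle
    import paddle.distributed as dist

    def train(print_result=False):
        # each spawned process sets up its own parallel environment
        paddle.disable_static()
        dist.init_parallel_env()
        if print_result:
            print("trainer %d of %d started" % (dist.get_rank(), dist.get_world_size()))

    if __name__ == '__main__':
        # start 2 training processes on the local machine
        dist.spawn(train, args=(True,), nprocs=2)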
15 changes: 10 additions & 5 deletions python/paddle/distributed/launch.py
@@ -44,11 +44,9 @@
 import six
 import copy
 from argparse import ArgumentParser, REMAINDER
-import paddle
-import paddle.fluid as fluid

 from paddle.distributed.utils import *
-import paddle.distributed.cloud_utils as cloud_utils
+from paddle.distributed import cloud_utils


 def _print_arguments(args):
@@ -167,7 +165,8 @@ def get_cluster_from_args(args, selected_gpus):

 def get_gpus(selected_gpus):
     if selected_gpus is None:
-        gpus_num = fluid.core.get_cuda_device_count()
+        from paddle.fluid import core
+        gpus_num = core.get_cuda_device_count()
         selected_gpus = [str(x) for x in range(0, gpus_num)]
     else:
         cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
@@ -190,7 +189,7 @@ def get_gpus(selected_gpus):
     return selected_gpus


-def launch(args):
+def get_cluster_and_pod(args):
     # parse arguments, used for cloud-single-machine and local
     selected_gpus = get_gpus(args.selected_gpus)
     trainers_num = cloud_utils.get_trainers_num()
@@ -209,6 +208,12 @@ def launch(args):
     cluster, pod = get_cluster_from_args(args, selected_gpus)
     logger.info("get cluster from args:{}".format(cluster))

+    return cluster, pod
+
+
+def launch(args):
+    cluster, pod = get_cluster_and_pod(args)
+
     procs = start_local_trainers(
         cluster,
         pod,
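The refactor above splits cluster/pod discovery out of launch() so that other entry points can reuse it. A rough sketch of such reuse (a hypothetical helper, not part of this PR), assuming the returned pod exposes the trainers/endpoint/gpus attributes defined in paddle.distributed.utils:

    from paddle.distributed.launch import get_cluster_and_pod

    def describe_local_pod(args):
        # args is the parsed launch argument namespace (selected_gpus, ips, ...)
        cluster, pod = get_cluster_and_pod(args)
        # walk the trainers assigned to this pod and report their devices
        for trainer in pod.trainers:
            print("trainer endpoint:", trainer.endpoint, "gpus:", trainer.gpus)
        return cluster, pod
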
184 changes: 184 additions & 0 deletions python/paddle/distributed/parallel.py
@@ -0,0 +1,184 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import six
import warnings

from paddle import compat as cpt

# deprecated module import
from paddle.fluid import core
from paddle.fluid.framework import _set_expected_place
from paddle.fluid.dygraph import parallel_helper
from paddle.fluid.dygraph.parallel import ParallelEnv

__all__ = ["init_parallel_env"]

ParallelStrategy = core.ParallelStrategy


def init_parallel_env(backend='nccl'):
Review comment (Member):
NCCL is an underlying communication library; I don't think it's necessary to let users know we have different backends here. If we want to support operating systems such as Windows that don't support NCCL, it's better to detect the operating system inside the init function and use another communication library, such as gloo. I highly recommend removing the backend argument for now, for simplicity of usage.

Reply (Contributor, Author):
Thanks, I think it is okay to remove it; we can discuss removing this argument in a cherry-pick.

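To make the reviewer's suggestion concrete, here is an illustrative sketch (not part of this diff) of how a backend could be chosen from the operating system instead of being exposed as an argument:

    import platform

    def _choose_backend():
        # Hypothetical auto-detection along the lines suggested above:
        # NCCL is unavailable on Windows, so a portable init could fall
        # back to another communication library such as gloo there.
        if platform.system() == 'Windows':
            return 'gloo'
        return 'nccl'
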
"""
Initialize parallel training environments in dynamic mode.

Args:
backend(str, optional): The backend to communication between multiple devices.
Now only support ``nccl`` . Default value is ``nccl`` .

Returns:
None

Examples:
.. code-block:: python

import paddle
import paddle.nn as nn
import paddle.optimizer as opt
import paddle.distributed as dist

class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1)

def forward(self, x):
return self._linear2(self._linear1(x))

def train():
# 1. enable dynamic mode
paddle.disable_static()

# 2. initialize parallel environment
dist.init_parallel_env()

# 3. create data parallel layer & optimizer
layer = LinearNet()
dp_layer = paddle.DataParallel(layer)

loss_fn = nn.MSELoss()
adam = opt.Adam(
learning_rate=0.001, parameters=dp_layer.parameters())

# 4. run layer
inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32')
loss = loss_fn(outputs, labels)

loss = dp_layer.scale_loss(loss)
loss.backward()
dp_layer.apply_collective_grads()

adam.step()
adam.clear_grad()

if __name__ == '__main__':
dist.spawn(train)
"""

    # 1. input check
    if not isinstance(backend, six.string_types):
        raise TypeError("input `backend` type error, expected type is str, "
                        "but received type is %s." % type(backend))
    if cpt.to_text(backend) != 'nccl':
        raise ValueError(
            "backend `%s` is not supported, now only supports `nccl` backend." %
            backend)

    # 2. check env
    def _check_var_exists(var_name):
        var = os.environ.get(var_name, None)
        if var is None:
            raise ValueError("paddle.distributed initialize error, "
                             "environment variable %s is needed, but not set." %
                             var_name)

    _check_var_exists("FLAGS_selected_gpus")
    _check_var_exists("PADDLE_TRAINER_ID")
    _check_var_exists("PADDLE_CURRENT_ENDPOINT")
    _check_var_exists("PADDLE_TRAINERS_NUM")
    _check_var_exists("PADDLE_TRAINER_ENDPOINTS")

    # 3. init ParallelStrategy
    strategy = ParallelStrategy()
    if cpt.to_text(backend) == 'nccl':
        if parallel_helper._is_parallel_ctx_initialized():
            warnings.warn("The parallel environment has been initialized.")
        strategy.nranks = ParallelEnv().world_size
        strategy.local_rank = ParallelEnv().rank
        strategy.trainer_endpoints = ParallelEnv().trainer_endpoints
        strategy.current_endpoint = ParallelEnv().current_endpoint
        if strategy.nranks < 2:
            return
        # NOTE(chenweihang): [ why config global place here? ]
        # dygraph mode will be set as the default mode, so users will not
        # call `dygraph.guard` or `enable_dygraph` directly; if they want to
        # switch the default place, they need to call a function to change it,
        # so we just set the correct place for users here
        place = core.CUDAPlace(ParallelEnv().device_id)
        _set_expected_place(place)

    # init nccl context
    parallel_helper._set_parallel_ctx(
        core.NCCLParallelContext(strategy, place))
    parallel_helper._init_parallel_ctx()


def get_rank():
"""
Returns the rank of current trainer.

Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` .
The default value is 0.

Returns:
(int) The rank of current trainer.

Examples:
.. code-block:: python

import paddle
import paddle.distributed as dist

# execute this command in terminal: export PADDLE_TRAINER_ID=0
print("The rank is %d" % dist.get_rank())
# The rank is 0
"""
return ParallelEnv().rank


def get_world_size():
"""
The number of trainers (number of processes participating in current job).

Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` .
The default value is 1.

Returns:
(int) The number of trainers.

Examples:
.. code-block:: python

import paddle
import paddle.distributed as dist

# execute this command in terminal: export PADDLE_TRAINERS_NUM=4
print("The world_size is %d" % dist.get_world_size())
# The world_size is 4
"""
return ParallelEnv().world_size
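
Taken together, init_parallel_env is driven entirely by the environment variables checked above, which launch/spawn normally set for every child process. A minimal sketch of that contract for a single-trainer run (hypothetical endpoint values; with only one trainer the function returns early before creating the NCCL context, per the nranks < 2 branch above):

    import os
    import paddle
    import paddle.distributed as dist

    # launch/spawn would normally export these for each trainer process
    os.environ.setdefault("FLAGS_selected_gpus", "0")
    os.environ.setdefault("PADDLE_TRAINER_ID", "0")
    os.environ.setdefault("PADDLE_TRAINERS_NUM", "1")
    os.environ.setdefault("PADDLE_CURRENT_ENDPOINT", "127.0.0.1:6170")
    os.environ.setdefault("PADDLE_TRAINER_ENDPOINTS", "127.0.0.1:6170")

    paddle.disable_static()
    dist.init_parallel_env()  # env check passes; early return since world_size == 1

    # rank/world_size mirror PADDLE_TRAINER_ID / PADDLE_TRAINERS_NUM
    print("rank %d of %d" % (dist.get_rank(), dist.get_world_size()))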