diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index c22eee3df6f29..46b84697e5a61 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -230,8 +230,6 @@ from .framework import no_grad #DEFINE_ALIAS from .framework import save #DEFINE_ALIAS from .framework import load #DEFINE_ALIAS -from .framework import prepare_context #DEFINE_ALIAS -from .framework import ParallelEnv #DEFINE_ALIAS from .framework import DataParallel #DEFINE_ALIAS from .framework import NoamDecay #DEFINE_ALIAS diff --git a/python/paddle/distributed/__init__.py b/python/paddle/distributed/__init__.py index 34dd605f901b4..d66577102c713 100644 --- a/python/paddle/distributed/__init__.py +++ b/python/paddle/distributed/__init__.py @@ -12,4 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. +from . import spawn +from .spawn import spawn + +from . import parallel +from .parallel import init_parallel_env +from .parallel import get_rank +from .parallel import get_world_size +from paddle.fluid.dygraph.parallel import prepare_context #DEFINE_ALIAS +from paddle.fluid.dygraph.parallel import ParallelEnv #DEFINE_ALIAS + +from . import collective from .collective import * + +# start multiprocess apis +__all__ = ["spawn"] + +# dygraph parallel apis +__all__ += [ + "init_parallel_env", "get_rank", "get_world_size", "prepare_context", + "ParallelEnv" +] + +# collective apis +__all__ += collective.__all__ diff --git a/python/paddle/distributed/launch.py b/python/paddle/distributed/launch.py index ecd1cf0ca7bef..e2ab321f9aebd 100644 --- a/python/paddle/distributed/launch.py +++ b/python/paddle/distributed/launch.py @@ -44,11 +44,9 @@ import six import copy from argparse import ArgumentParser, REMAINDER -import paddle -import paddle.fluid as fluid from paddle.distributed.utils import * -import paddle.distributed.cloud_utils as cloud_utils +from paddle.distributed import cloud_utils def _print_arguments(args): @@ -167,7 +165,8 @@ def get_cluster_from_args(args, selected_gpus): def get_gpus(selected_gpus): if selected_gpus is None: - gpus_num = fluid.core.get_cuda_device_count() + from paddle.fluid import core + gpus_num = core.get_cuda_device_count() selected_gpus = [str(x) for x in range(0, gpus_num)] else: cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") @@ -190,7 +189,7 @@ def get_gpus(selected_gpus): return selected_gpus -def launch(args): +def get_cluster_and_pod(args): # parse arguments, used for cloud-single-machine and local selected_gpus = get_gpus(args.selected_gpus) trainers_num = cloud_utils.get_trainers_num() @@ -209,6 +208,12 @@ def launch(args): cluster, pod = get_cluster_from_args(args, selected_gpus) logger.info("get cluster from args:{}".format(cluster)) + return cluster, pod + + +def launch(args): + cluster, pod = get_cluster_and_pod(args) + procs = start_local_trainers( cluster, pod, diff --git a/python/paddle/distributed/parallel.py b/python/paddle/distributed/parallel.py new file mode 100644 index 0000000000000..0c806747217ad --- /dev/null +++ b/python/paddle/distributed/parallel.py @@ -0,0 +1,184 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except jin compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import six +import warnings + +from paddle import compat as cpt + +# deprecated module import +from paddle.fluid import core +from paddle.fluid.framework import _set_expected_place +from paddle.fluid.dygraph import parallel_helper +from paddle.fluid.dygraph.parallel import ParallelEnv + +__all__ = ["init_parallel_env"] + +ParallelStrategy = core.ParallelStrategy + + +def init_parallel_env(backend='nccl'): + """ + Initialize parallel training environments in dynamic mode. + + Args: + backend(str, optional): The backend to communication between multiple devices. + Now only support ``nccl`` . Default value is ``nccl`` . + + Returns: + None + + Examples: + .. code-block:: python + + import paddle + import paddle.nn as nn + import paddle.optimizer as opt + import paddle.distributed as dist + + class LinearNet(nn.Layer): + def __init__(self): + super(LinearNet, self).__init__() + self._linear1 = nn.Linear(10, 10) + self._linear2 = nn.Linear(10, 1) + + def forward(self, x): + return self._linear2(self._linear1(x)) + + def train(): + # 1. enable dynamic mode + paddle.disable_static() + + # 2. initialize parallel environment + dist.init_parallel_env() + + # 3. create data parallel layer & optimizer + layer = LinearNet() + dp_layer = paddle.DataParallel(layer) + + loss_fn = nn.MSELoss() + adam = opt.Adam( + learning_rate=0.001, parameters=dp_layer.parameters()) + + # 4. run layer + inputs = paddle.randn([10, 10], 'float32') + outputs = dp_layer(inputs) + labels = paddle.randn([10, 1], 'float32') + loss = loss_fn(outputs, labels) + + loss = dp_layer.scale_loss(loss) + loss.backward() + dp_layer.apply_collective_grads() + + adam.step() + adam.clear_grad() + + if __name__ == '__main__': + dist.spawn(train) + """ + + # 1. input check + if not isinstance(backend, six.string_types): + raise TypeError("input `backend` type error, expected type is str, " + "but received type is %s." % type(backend)) + if cpt.to_text(backend) != 'nccl': + raise ValueError( + "backend `%s` is not supported, now only supports `nccl` backend." % + backend) + + # 2. check env + def _check_var_exists(var_name): + var = os.environ.get(var_name, None) + if var is None: + raise ValueError("paddle.distributed initialize error, " + "environment variable %s is needed, but not set." % + var_name) + + _check_var_exists("FLAGS_selected_gpus") + _check_var_exists("PADDLE_TRAINER_ID") + _check_var_exists("PADDLE_CURRENT_ENDPOINT") + _check_var_exists("PADDLE_TRAINERS_NUM") + _check_var_exists("PADDLE_TRAINER_ENDPOINTS") + + # 3. init ParallelStrategy + strategy = ParallelStrategy() + if cpt.to_text(backend) == 'nccl': + if parallel_helper._is_parallel_ctx_initialized(): + warnings.warn("The parallel environment has been initialized.") + strategy.nranks = ParallelEnv().world_size + strategy.local_rank = ParallelEnv().rank + strategy.trainer_endpoints = ParallelEnv().trainer_endpoints + strategy.current_endpoint = ParallelEnv().current_endpoint + if strategy.nranks < 2: + return + # NOTE(chenweihang): [ why config global place here? ] + # the dygraph mode will be set to default mode, + # users will not call `dygraph.guard` or `enable_dygraph` + # directly, if they want to switch default place, + # they need to call a function to change default place, + # here just set correctly place to users + place = core.CUDAPlace(ParallelEnv().device_id) + _set_expected_place(place) + + # init nccl context + parallel_helper._set_parallel_ctx( + core.NCCLParallelContext(strategy, place)) + parallel_helper._init_parallel_ctx() + + +def get_rank(): + """ + Returns the rank of current trainer. + + Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` . + The default value is 0. + + Returns: + (int) The rank of current trainer. + + Examples: + .. code-block:: python + + import paddle + import paddle.distributed as dist + + # execute this command in terminal: export PADDLE_TRAINER_ID=0 + print("The rank is %d" % dist.get_rank()) + # The rank is 0 + """ + return ParallelEnv().rank + + +def get_world_size(): + """ + The number of trainers (number of processes participating in current job). + + Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` . + The default value is 1. + + Returns: + (int) The number of trainers. + + Examples: + .. code-block:: python + + import paddle + import paddle.distributed as dist + + # execute this command in terminal: export PADDLE_TRAINERS_NUM=4 + print("The world_size is %d" % dist.get_world_size()) + # The world_size is 4 + """ + return ParallelEnv().world_size diff --git a/python/paddle/distributed/spawn.py b/python/paddle/distributed/spawn.py new file mode 100644 index 0000000000000..1ca2ebaa8d4bd --- /dev/null +++ b/python/paddle/distributed/spawn.py @@ -0,0 +1,415 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function, division + +import multiprocessing +import os +import signal +import six +import sys +import warnings + +from paddle.distributed.launch import get_cluster_and_pod, _print_arguments +from paddle.distributed.utils import _prepare_trainer_env +from paddle.device import get_device + +# deprecated module import +from paddle.fluid import core +from paddle.fluid.framework import _cpu_num + + +# NOTE(chenweihang): The existence of this class leads to +# the maintenance of two arguments. When the launch.py arguments +# is updated, the arguments here also need to be updated, +# but I have not thought of a better way here +class ParallelEnvArgs(object): + def __init__(self): + # Paddle cluster nodes ips, such as 192.168.0.16,192.168.0.17.. + self.cluster_node_ips = None + + # The current node ip. + self.node_ip = None + + # whether to use paddlecloud platform to run your multi-process job. + # If false, no need to set this argument. + self.use_paddlecloud = None + + # The trainer's started port on a single node + self.started_port = None + + # Print the config or not + self.print_config = True + + # It's for gpu training and the training process will run + # on the selected_gpus, each process is bound to a single GPU. + # And if it's not set, this module will use all the gpu cards + # for training. + self.selected_gpus = None + + +def _py_supported_check(): + if not sys.version_info >= (3, 4): + raise RuntimeError( + "Use `paddle.distributed.spawn` to start parallel training " + "requires python version greater than 3.4, if your python " + "is lower than this version, please use " + "`paddle.distributed.launch` instead.") + + +def _get_subprocess_env_list(nprocs, options): + # contruct processes env list + processes_env_list = [] + + # get args from kwargs + args = ParallelEnvArgs() + + # set default `node_ip` and `cluster_node_ips` + args.cluster_node_ips = options.get('cluster_node_ips', None) + args.node_ip = options.get('node_ip', None) + if args.cluster_node_ips is not None and args.node_ip is None: + raise ValueError("please input current node ip, " + "cannot only give `cluster_node_ips`.") + default_node_ip = "127.0.0.1" + if args.node_ip is None: + args.node_ip = default_node_ip + if args.cluster_node_ips is None: + args.cluster_node_ips = default_node_ip + + # set default selected gpus + # e.g. if the nprocs is 4, the selected gpus is "0,1,2,3" + # NOTE(chenweihang): [ why not use FLAGS_selected_gpus directly? ] + # because the FLAGS_selected_gpus may be used in other place, + # if we set FLAGS_selected_gpus to be `0,1,2,3`, it may cause error + # when using `ParallelEnv` + # NOTE(chenweihang): use absolute gpu card id + args.selected_gpus = options.get('selected_gpus', None) + env_devices = os.getenv("CUDA_VISIBLE_DEVICES", None) + if env_devices is None or env_devices == "": + env_devices_list = [ + str(x) for x in six.moves.range(core.get_cuda_device_count()) + ] + else: + env_devices_list = env_devices.split(',') + if args.selected_gpus is None: + if len(env_devices_list) < nprocs: + raise RuntimeError( + "the number of visible devices(%d) is less than the number " + "of spawn processes(%d), please ensure that the correct " + "`nprocs` argument is passed or the environment variable " + "`CUDA_VISIBLE_DEVICES` is correctly configured." % + (len(env_devices_list), nprocs)) + args.selected_gpus = ",".join( + [str(env_devices_list[x]) for x in range(0, nprocs)]) + else: + for card_id in args.selected_gpus.split(','): + if card_id not in env_devices_list: + raise ValueError("The selected gpu card %s cannot found in " + "CUDA_VISIBLE_DEVICES (%s)." % + (card_id, ",".join(env_devices_list))) + + # set other arguments + args.started_port = options.get('started_port', None) + args.use_paddlecloud = options.get('use_paddlecloud', False) + args.print_config = options.get('print_config', False) + + # reuse code of launch.py + cluster, pod = get_cluster_and_pod(args) + + # prepare subprocess env list + for trainer in pod.trainers: + processes_env_list.append(_prepare_trainer_env(cluster, trainer)) + + # print config + if args.print_config: + _print_arguments(args) + + return processes_env_list + + +def _remove_risky_env(): + # remove useless env vars, same as launch.py + # no copy, each process will hold env vars itself + os.environ.pop("http_proxy", None) + os.environ.pop("https_proxy", None) + + +def _set_trainer_env(env_dict): + for var_name in env_dict: + os.environ[var_name] = env_dict[var_name] + + +def _func_wrapper(func, args, error_queue, return_queue, env_dict): + try: + # config subprocess environment variables + _remove_risky_env() + _set_trainer_env(env_dict) + # execute function + result = func(*args) + # record function return value + return_queue.put(result) + except KeyboardInterrupt: + pass + except Exception: + import traceback + error_queue.put(traceback.format_exc()) + sys.exit(1) + + +class MultiprocessContext(object): + def __init__(self, processes, error_queues, return_queues): + _py_supported_check() + self.error_queues = error_queues + # NOTE(chenweihang): The `spawn` method is mainly used + # to wrap the outermost execution function of the program for + # parallel execution. Generally, the return value is not concerned, + # but if the user needs to obtain the return value, users can get + # the return result of each process from context.return_queues + self.return_queues = return_queues + self.processes = processes + self.sentinels = { + process.sentinel: index + for index, process in enumerate(processes) + } + + def join(self, timeout=None): + if len(self.sentinels) == 0: + return True + + ready = multiprocessing.connection.wait( + self.sentinels.keys(), timeout=timeout) + + error_index = None + for sentinel in ready: + index = self.sentinels.pop(sentinel) + process = self.processes[index] + process.join() + if process.exitcode != 0: + error_index = index + break + + if error_index is None: + return len(self.sentinels) == 0 + + for process in self.processes: + if process.is_alive(): + process.terminate() + process.join() + + self._throw_exception(error_index) + + def _throw_exception(self, error_index): + if self.error_queues[error_index].empty(): + exitcode = self.processes[error_index].exitcode + if exitcode < 0: + name = signal.Signals(-exitcode).name + raise Exception("Process %d terminated with signal %s." % + (error_index, name)) + else: + raise Exception("Process %d terminated with exit code %d." & ( + error_index, exitcode)) + + original_trace = self.error_queues[error_index].get() + msg = "\n\n----------------------------------------------\n" \ + "Process %d terminated with the following error:\n" \ + "----------------------------------------------\n\n" % error_index + msg += original_trace + raise Exception(msg) + + +def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options): + """ + Start multiple processes with ``spawn`` method for parallel training. + + Args: + func (function): The target function is called by spawned process. + This function need to be able to pickled, so it must be defined + at the top level of a module. + This function should be called as ``func(i, *args)``, ``i`` is + the process index and ``args`` contains other arguments as tuple. + args (tuple, optional): Arguments passed to ``func``. + nprocs (int, optional): Number of processed to start. Default: -1. + when nprocs is -1, the available device will be obtained from + the environment variable when the model is executed: If use GPU, + the currently available device ID is obtained from the environment + variable CUDA_VISIBLE_DEVICES; If use CPU, the currently available + CPU number is obtained from the environment variable CPU_NUM. + For example, export CPU_NUM=4, if the environment variable is not set, + the executor will add the variable to the environment variable and + set its value to 1. + join (bool, optional): Perform a blocking join on all spawned processes. + Default: True. + daemon (bool, optional): The spawned processes' daemon flag. Default: False. + **options(dict, optional): Other initial parallel execution environment + configuration options. The following options are currently supported: + (1) start_method (string): the way to start a process. + The start method can be ``spawn`` , ``fork`` , ``forkserver`` . + Because the CUDA runtime does not support the ``fork`` start method, + when use CUDA in subprocesses, we should start process by ``spawn`` + or ``forkserver`` method. Default: "spawn" ; + (2) cluster_node_ips (string): Paddle cluster nodes ips, such as + "192.168.0.16,192.168.0.17". Default: "127.0.0.1"; + (3) node_ip (string): The current node ip, such as "192.168.0.16". + Default: "127.0.0.1"; + (4) started_port (int): The trainer's started port on a single node, + such as 6170. Default: None; + (5) selected_gpus (string): The training process will run on the + selected_gpus, such as "0,1,2,3". Default: None; + (6) print_config: Print current parallel training config. Default: False; + (7) use_paddlecloud: Whether to use paddlecloud platform to run your + multi-process job. Default: False. + + Returns: + ``MultiprocessContext`` object, it hold the spawned processes. + + Examples: + .. code-block:: python + + from __future__ import print_function + + import paddle + import paddle.nn as nn + import paddle.optimizer as opt + import paddle.distributed as dist + + class LinearNet(nn.Layer): + def __init__(self): + super(LinearNet, self).__init__() + self._linear1 = nn.Linear(10, 10) + self._linear2 = nn.Linear(10, 1) + + def forward(self, x): + return self._linear2(self._linear1(x)) + + def train(print_result=False): + # 1. enable dynamic mode + paddle.disable_static() + + # 2. initialize parallel environment + dist.init_parallel_env() + + # 3. create data parallel layer & optimizer + layer = LinearNet() + dp_layer = paddle.DataParallel(layer) + + loss_fn = nn.MSELoss() + adam = opt.Adam( + learning_rate=0.001, parameters=dp_layer.parameters()) + + # 4. run layer + inputs = paddle.randn([10, 10], 'float32') + outputs = dp_layer(inputs) + labels = paddle.randn([10, 1], 'float32') + loss = loss_fn(outputs, labels) + + if print_result is True: + print("loss:", loss.numpy()) + + loss = dp_layer.scale_loss(loss) + loss.backward() + dp_layer.apply_collective_grads() + + adam.step() + adam.clear_grad() + + # Usage 1: only pass function. + # If your training method no need any argument, and + # use all visible devices for parallel training. + if __name__ == '__main__': + dist.spawn(train) + + # Usage 2: pass function and arguments. + # If your training method need some arguments, and + # use all visible devices for parallel training. + if __name__ == '__main__': + dist.spawn(train, args=(True,)) + + # Usage 3: pass function, arguments and nprocs. + # If your training method need some arguments, and + # only use part of visible devices for parallel training. + # If your machine hold 8 cards {0,1,2,3,4,5,6,7}, + # this case will use cards {0,1}; If you set + # CUDA_VISIBLE_DEVICES=4,5,6,7, this case will use + # cards {4,5} + if __name__ == '__main__': + dist.spawn(train, args=(True,), nprocs=2) + + # Usage 4: pass function, arguments, nprocs and selected_gpus. + # If your training method need some arguments, and + # only use part of visible devices for parallel training, + # but you can't set your machine's environment varibale + # CUDA_VISIBLE_DEVICES, such as it is None or all cards + # {0,1,2,3,4,5,6,7}, you can pass `selelcted_gpus` to + # select the GPU cards you want to use. For example, + # this case will use cards {4,5} if your machine hold 8 cards. + if __name__ == '__main__': + dist.spawn(train, args=(True,), nprocs=2, selelcted_gpus='4,5') + """ + # NOTE(chenweihang): [ why only supports python3.4+ ? ] + # Python supported setting the child process startup method + # since 3.4. The previous version can only use the default startup + # method, while the default startup method of Unix is fork, which + # cannot support CUDA runtime multi-process + _py_supported_check() + + # get default nprocs + if nprocs == -1: + device = get_device() + if device == 'cpu': + # TODO: not supports cpu parallel now + nprocs = _cpu_num + else: + nprocs = core.get_cuda_device_count() + + # NOTE(chenweihang): [ why need get cluster info before run? ] + # when using `paddle.distributed.spawn` start parallel training, + # we should get cluster info before starting subprocess, and pass + # correct info to each subprocess + procs_env_list = _get_subprocess_env_list(nprocs, options) + + # start processes + # NOTE(chenweihang): [ why default start method is spawn? ] + # The CUDA runtime does not support the fork start method, + # either the spawn or forkserver start method are required + # to use CUDA in subprocesses. + start_method = options.get('start_method', None) + if start_method is None: + start_method = 'spawn' + mp = multiprocessing.get_context(start_method) + + error_queues = [] + return_queues = [] + processes = [] + for i in range(nprocs): + error_queue = mp.SimpleQueue() + return_queue = mp.SimpleQueue() + process = mp.Process( + target=_func_wrapper, + args=(func, args, error_queue, return_queue, procs_env_list[i])) + process.daemon = daemon + process.start() + error_queues.append(error_queue) + return_queues.append(return_queue) + processes.append(process) + + context = MultiprocessContext(processes, error_queues, return_queues) + if not join: + return context + + # loop until all process end + while not context.join(): + pass + + # finally return context + return context diff --git a/python/paddle/distributed/utils.py b/python/paddle/distributed/utils.py index 7c8fa257f778e..1fa307c4d1b89 100644 --- a/python/paddle/distributed/utils.py +++ b/python/paddle/distributed/utils.py @@ -327,6 +327,17 @@ def __free_port(): return None +def _prepare_trainer_env(cluster, trainer): + proc_env = { + "FLAGS_selected_gpus": "%s" % ",".join([str(g) for g in trainer.gpus]), + "PADDLE_TRAINER_ID": "%d" % trainer.rank, + "PADDLE_CURRENT_ENDPOINT": "%s" % trainer.endpoint, + "PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(), + "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()) + } + return proc_env + + class TrainerProc(object): def __init__(self): self.proc = None @@ -352,14 +363,7 @@ def start_local_trainers(cluster, procs = [] for idx, t in enumerate(pod.trainers): - proc_env = { - "FLAGS_selected_gpus": "%s" % ",".join([str(g) for g in t.gpus]), - "PADDLE_TRAINER_ID": "%d" % t.rank, - "PADDLE_CURRENT_ENDPOINT": "%s" % t.endpoint, - "PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(), - "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()) - } - + proc_env = _prepare_trainer_env(cluster, t) current_env.update(proc_env) logger.debug("trainer proc env:{}".format(current_env)) diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index 54d2cda4ca685..bd578e6ba98a0 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -11,21 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os import six import numpy as np +import warnings from collections import OrderedDict -from .. import core -from . import layers -from . import parallel_helper -from .. import framework -from . import to_variable, no_grad + +from paddle.fluid import core +from paddle.fluid import framework +from paddle.fluid.dygraph import layers +from paddle.fluid.dygraph import parallel_helper +from paddle.fluid.dygraph import to_variable, no_grad +from paddle.utils import deprecated __all__ = ["prepare_context", "ParallelEnv", "DataParallel"] ParallelStrategy = core.ParallelStrategy +@deprecated(since="2.0.0", update_to="paddle.distributed.init_parallel_env") def prepare_context(strategy=None): ''' :api_attr: imperative @@ -39,17 +44,18 @@ def prepare_context(strategy=None): if strategy.nranks < 2: return assert framework.in_dygraph_mode() is True, \ - "dygraph.prepare_context should be used with dygrahp mode." + "dygraph.prepare_context should be used with dygraph mode." place = framework._current_expected_place() assert place is not None, \ "dygraph.prepare_context should be used in fluid.dygraph.guard(place) guard." - if isinstance(place, core.CUDAPlace): - parallel_helper._set_parallel_ctx( - core.NCCLParallelContext(strategy, place)) - else: - # TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation - assert ("Only support CUDAPlace for now.") - parallel_helper._init_parallel_ctx() + if not parallel_helper._is_parallel_ctx_initialized(): + if isinstance(place, core.CUDAPlace): + parallel_helper._set_parallel_ctx( + core.NCCLParallelContext(strategy, place)) + else: + # TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation + assert ("Only support CUDAPlace for now.") + parallel_helper._init_parallel_ctx() return strategy @@ -112,84 +118,84 @@ class ParallelEnv(object): """ def __init__(self): - self._nranks = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) - self._local_rank = int(os.getenv("PADDLE_TRAINER_ID", "0")) - self._dev_id = int(os.getenv("FLAGS_selected_gpus", "0")) + self._rank = int(os.getenv("PADDLE_TRAINER_ID", "0")) + self._world_size = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) + self._device_id = int(os.getenv("FLAGS_selected_gpus", "0")) self._trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", "").split(",") self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT", "") @property - def nranks(self): + def rank(self): """ - The number of trainers, generally refers to the number of GPU cards used in training. + Rank of current trainer. - Its value is equal to the value of the environment variable PADDLE_TRAINERS_NUM. The default value is 1. + Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` . The default value is 0. Examples: .. code-block:: python - # execute this command in terminal: export PADDLE_TRAINERS_NUM=4 - import paddle.fluid as fluid + # execute this command in terminal: export PADDLE_TRAINER_ID=0 + import paddle.distributed as dist - env = fluid.dygraph.ParallelEnv() - print("The nranks is %d" % env.nranks) - # The nranks is 4 + env = dist.ParallelEnv() + print("The rank is %d" % env.rank) + # The rank is 0 """ - return self._nranks + return self._rank @property - def local_rank(self): + def world_size(self): """ - The current trainer number. + The number of trainers (number of processes participating in current job). - Its value is equal to the value of the environment variable PADDLE_TRAINER_ID. The default value is 0. + Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` . The default value is 1. Examples: .. code-block:: python - # execute this command in terminal: export PADDLE_TRAINER_ID=0 - import paddle.fluid as fluid + # execute this command in terminal: export PADDLE_TRAINERS_NUM=4 + import paddle.distributed as dist - env = fluid.dygraph.ParallelEnv() - print("The local rank is %d" % env.local_rank) - # The local rank is 0 + env = dist.ParallelEnv() + print("The world_size is %d" % env.world_size) + # The world_size is 4 """ - return self._local_rank + return self._world_size @property - def dev_id(self): + def device_id(self): """ The ID of selected GPU card for parallel training. - Its value is equal to the value of the environment variable FLAGS_selected_gpus. The default value is 0. + Its value is equal to the value of the environment variable ``FLAGS_selected_gpus`` . The default value is 0. Examples: .. code-block:: python # execute this command in terminal: export FLAGS_selected_gpus=1 - import paddle.fluid as fluid + import paddle.distributed as dist - env = fluid.dygraph.ParallelEnv() - print("The device id are %d" % env.dev_id) + env = dist.ParallelEnv() + print("The device id are %d" % env.device_id) # The device id are 1 """ - return self._dev_id + return self._device_id @property def current_endpoint(self): """ The endpoint of current trainer, it is in the form of (node IP + port). - Its value is equal to the value of the environment variable PADDLE_CURRENT_ENDPOINT. The default value is "". + Its value is equal to the value of the environment variable ``PADDLE_CURRENT_ENDPOINT`` . The default value is "". Examples: .. code-block:: python # execute this command in terminal: export PADDLE_CURRENT_ENDPOINT=127.0.0.1:6170 - import paddle.fluid as fluid + import paddle.distributed as dist - env = fluid.dygraph.ParallelEnv() + env = dist.ParallelEnv() print("The current endpoint are %s" % env.current_endpoint) # The current endpoint are 127.0.0.1:6170 """ @@ -201,20 +207,25 @@ def trainer_endpoints(self): The endpoints of all trainer nodes in the task, which are used to broadcast the NCCL ID when NCCL2 is initialized. - Its value is equal to the value of the environment variable PADDLE_TRAINER_ENDPOINTS. The default value is "". + Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ENDPOINTS`` . The default value is "". Examples: .. code-block:: python # execute this command in terminal: export PADDLE_TRAINER_ENDPOINTS=127.0.0.1:6170,127.0.0.1:6171 - import paddle.fluid as fluid + import paddle.distributed as dist - env = fluid.dygraph.ParallelEnv() + env = dist.ParallelEnv() print("The trainer endpoints are %s" % env.trainer_endpoints) # The trainer endpoints are ['127.0.0.1:6170', '127.0.0.1:6171'] """ return self._trainer_endpoints + # [aliases] Compatible with old method names + local_rank = rank + nranks = world_size + dev_id = device_id + # NOTE: [ Compatible ] Originally this class name is `Env`. The semantics of the old class names # are inaccurate and may confuse users, so replace it with `ParallelEnv`, but to be compatible @@ -227,61 +238,98 @@ class DataParallel(layers.Layer): Run the dygraph module with data parallelism. Currently, DataParallel class only supports to run the dynamic graph - with multi-process. The usage is: - `python -m paddle.distributed.launch --selected_gpus=0,1 dynamic_graph_test.py`. - And the content of `dynamic_graph_test.py` is the code of examples. + with multi-process. + + Now supports two ways to start training: + + 1. start by ``paddle.distributed.spawn`` method, for example: + + ``python demo.py`` (spawn need to be called in ``__main__`` method) + + 2. start by ``paddle.distributed.launch`` module, for example: + + ``python -m paddle.distributed.launch --selected_gpus=0,1 demo.py`` . + + And the content of `demo.py` is the code of examples. Args: layers(Layer): The module that should be executed by data parallel. - strategy(ParallelStrategy): The strategy of data parallelism, contains - environment configuration related to parallel execution. - + strategy(ParallelStrategy, optional): (deprecated) The strategy of data parallelism, + contains environment configuration related to parallel execution. Default: None. + Returns: Layer: The data paralleled module. Examples: .. code-block:: python - import numpy as np - import paddle.fluid as fluid + import paddle + import paddle.nn as nn + import paddle.optimizer as opt + import paddle.distributed as dist - place = fluid.CUDAPlace(fluid.dygraph.ParallelEnv().dev_id) - with fluid.dygraph.guard(place): - - # prepare the data parallel context - strategy = fluid.dygraph.prepare_context() - - linear = fluid.dygraph.Linear(1, 10, act="softmax") - adam = fluid.optimizer.AdamOptimizer( - learning_rate=0.001, parameter_list=linear.parameters()) - - # make the module become the data parallelism module - linear = fluid.dygraph.DataParallel(linear, strategy) - - x_data = np.random.random(size=[10, 1]).astype(np.float32) - data = fluid.dygraph.to_variable(x_data) - - hidden = linear(data) - avg_loss = fluid.layers.mean(hidden) - - # scale the loss according to the number of trainers. - avg_loss = linear.scale_loss(avg_loss) - - avg_loss.backward() - - # collect the gradients of trainers. - linear.apply_collective_grads() - - adam.minimize(avg_loss) - linear.clear_gradients() + class LinearNet(nn.Layer): + def __init__(self): + super(LinearNet, self).__init__() + self._linear1 = nn.Linear(10, 10) + self._linear2 = nn.Linear(10, 1) + + def forward(self, x): + return self._linear2(self._linear1(x)) + + def train(): + # 1. enable dynamic mode + paddle.disable_static() + + # 2. initialize parallel environment + dist.init_parallel_env() + + # 3. create data parallel layer & optimizer + layer = LinearNet() + dp_layer = paddle.DataParallel(layer) + + loss_fn = nn.MSELoss() + adam = opt.Adam( + learning_rate=0.001, parameters=dp_layer.parameters()) + + # 4. run layer + inputs = paddle.randn([10, 10], 'float32') + outputs = dp_layer(inputs) + labels = paddle.randn([10, 1], 'float32') + loss = loss_fn(outputs, labels) + + loss = dp_layer.scale_loss(loss) + loss.backward() + dp_layer.apply_collective_grads() + + adam.step() + adam.clear_grad() + + if __name__ == '__main__': + # 1. start by ``paddle.distributed.spawn`` (default) + dist.spawn(train, nprocs=2) + # 2. start by ``paddle.distributed.launch`` + # train() """ - def __init__(self, layers, strategy): + def __init__(self, layers, strategy=None): super(DataParallel, self).__init__(layers.full_name() + "_data_parallel") self._layers = layers - self._strategy = strategy + + # NOTE(chenweihang): The ParallelStrategy here is not strictly a strategy. + # It just stores some environment variables, which can be constructed by + # ParallelEnv. Here it is set as an optional argument. + # This parameter is not removed because of compatibility with 1.x writing. + if strategy is not None: + self._strategy = strategy + else: + self._strategy = ParallelStrategy() + self._strategy.nranks = ParallelEnv().nranks + self._strategy.local_rank = ParallelEnv().local_rank + self._strategy.trainer_endpoints = ParallelEnv().trainer_endpoints + self._strategy.current_endpoint = ParallelEnv().current_endpoint def forward(self, *inputs, **kwargs): return self._layers(*inputs, **kwargs) diff --git a/python/paddle/fluid/dygraph/parallel_helper.py b/python/paddle/fluid/dygraph/parallel_helper.py index f378211de2b8a..ff1675f0ae8a4 100644 --- a/python/paddle/fluid/dygraph/parallel_helper.py +++ b/python/paddle/fluid/dygraph/parallel_helper.py @@ -23,6 +23,11 @@ def _is_data_parallel_mode(): os.getenv("PADDLE_TRAINERS_NUM", "1")) > 1 +def _is_parallel_ctx_initialized(): + global __parallel_ctx__clz__ + return __parallel_ctx__clz__ is not None + + def _set_parallel_ctx(nccl_parallel_context): global __parallel_ctx__clz__ assert __parallel_ctx__clz__ is None, \ diff --git a/python/paddle/fluid/tests/unittests/spawn_runner_base.py b/python/paddle/fluid/tests/unittests/spawn_runner_base.py new file mode 100644 index 0000000000000..278d7b27c5288 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/spawn_runner_base.py @@ -0,0 +1,81 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function, division + +import numpy as np +import unittest + +import paddle + +# used by model.run_trainer in test_dist_base +from test_dist_base import RUN_STEP + + +# NOTE: compatible TestParallelDyGraphRunnerBase args +class SpawnAssistTestArgs(object): + update_method = "local" + trainer_id = 0 + + +class TestDistSpawnRunner(unittest.TestCase): + def setUp(self): + # NOTE(chenweihang): keep consistent with + # TestDistBase.check_with_place + self.nprocs = 2 + + def _run(self, model, args): + args.update_method = "local" + return model.run_trainer_with_spawn(args) + + def _run_parallel(self, model, args): + args.update_method = "nccl2" + context = paddle.distributed.spawn( + func=model.run_trainer_with_spawn, + args=(args, ), + nprocs=self.nprocs, + join=True) + result_list = [] + for res_queue in context.return_queues: + result_list.append(res_queue.get()) + return result_list + + def check_dist_result_with_spawn(self, test_class, delta=1e-3): + # 0. prepare model and args + model = test_class() + args = SpawnAssistTestArgs() + + # 1. calc signal card loss + losses = self._run(model, args) + + # 2. calc multi card loss (nccl mode) + dist_losses_list = self._run_parallel(model, args) + + # 3. compare losses + for step_id in range(RUN_STEP): + loss = losses[step_id] + dist_loss_sum = None + for dist_losses in dist_losses_list: + if dist_loss_sum is None: + dist_loss_sum = np.array(dist_losses[step_id]) + else: + dist_loss_sum += np.array(dist_losses[step_id]) + dist_loss = dist_loss_sum / self.nprocs + self.assertAlmostEqual( + loss, + dist_loss, + delta=delta, + msg="The results of single-card execution and multi-card execution are inconsistent." + "signal-card loss is:\n{}\nmulti-card average loss is:\n{}\n". + format(loss, dist_loss)) diff --git a/python/paddle/fluid/tests/unittests/test_directory_migration.py b/python/paddle/fluid/tests/unittests/test_directory_migration.py index 74cc87bd9dbd6..2919ec5e9ca97 100644 --- a/python/paddle/fluid/tests/unittests/test_directory_migration.py +++ b/python/paddle/fluid/tests/unittests/test_directory_migration.py @@ -38,9 +38,10 @@ def test_new_directory(self): 'paddle.enable_static', 'paddle.disable_static', 'paddle.in_dynamic_mode', 'paddle.to_variable', 'paddle.grad', 'paddle.no_grad', 'paddle.save', 'paddle.load', - 'paddle.static.save', 'paddle.static.load', 'paddle.ParallelEnv', - 'paddle.prepare_context', 'paddle.DataParallel', 'paddle.jit', - 'paddle.jit.TracedLayer', 'paddle.jit.to_static', + 'paddle.static.save', 'paddle.static.load', + 'paddle.distributed.ParallelEnv', + 'paddle.distributed.prepare_context', 'paddle.DataParallel', + 'paddle.jit', 'paddle.jit.TracedLayer', 'paddle.jit.to_static', 'paddle.jit.ProgramTranslator', 'paddle.jit.TranslatedLayer', 'paddle.jit.save', 'paddle.jit.load', 'paddle.jit.SaveLoadConfig', 'paddle.NoamDecay', 'paddle.PiecewiseDecay', diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index ba292f2d87c37..faff81fa84fb5 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -23,8 +23,11 @@ import six import argparse import pickle +import random import numpy as np import time + +import paddle import paddle.fluid as fluid from paddle.fluid import compiler import paddle.fluid.dygraph as dygraph @@ -382,22 +385,22 @@ def run_one_loop(self, model, opt, data): raise NotImplementedError( "train_one_loop should be implemented by the child classes.") + def _get_data(self, batch, args): + if args.update_method != "local": + new_batch = [] + for offset, item in enumerate(batch): + if offset % 2 == args.trainer_id: + new_batch.append(item) + return new_batch + else: + return batch + def run_trainer(self, args): seed = 90 device_id = int(os.getenv("FLAGS_selected_gpus", "0")) place = fluid.CUDAPlace(device_id) - def _get_data(batch): - if args.update_method != "local": - new_batch = [] - for offset, item in enumerate(batch): - if offset % 2 == args.trainer_id: - new_batch.append(item) - return new_batch - else: - return batch - with fluid.dygraph.guard(place): fluid.default_startup_program().random_seed = seed fluid.default_main_program().random_seed = seed @@ -422,7 +425,7 @@ def _get_data(batch): out_losses = [] print_to_err(type(self).__name__, "begin to run dygraph training") for step_id, data in enumerate(train_reader()): - data = _get_data(data) + data = self._get_data(data, args) if step_id == RUN_STEP: break loss = self.run_one_loop(model, opt, data) @@ -444,6 +447,47 @@ def _get_data(batch): model.clear_gradients() print_to_out(out_losses) + def run_trainer_with_spawn(self, args): + # 1. enable dygraph + paddle.disable_static() + + # 2. init seed + seed = 90 + paddle.static.default_startup_program().random_seed = seed + paddle.static.default_main_program().random_seed = seed + np.random.seed(seed) + random.seed = seed + # get trainer id + args.trainer_id = paddle.distributed.get_rank() + + # 3. init parallel env + if args.update_method == "nccl2": + paddle.distributed.init_parallel_env() + + # 4. train model + model, train_reader, opt = self.get_model() + if args.update_method == "nccl2": + model = paddle.DataParallel(model) + + out_losses = [] + for step_id, data in enumerate(train_reader()): + data = self._get_data(data, args) + if step_id == RUN_STEP: + break + loss = self.run_one_loop(model, opt, data) + out_losses.append(loss.numpy()) + + if args.update_method == "nccl2": + loss = model.scale_loss(loss) + + loss.backward() + if args.update_method == "nccl2": + model.apply_collective_grads() + + opt.minimize(loss) + model.clear_gradients() + return out_losses + def runtime_main(test_class): parser = argparse.ArgumentParser(description='Run dist test.') diff --git a/python/paddle/fluid/tests/unittests/test_imperative_data_parallel.py b/python/paddle/fluid/tests/unittests/test_imperative_data_parallel.py index d3f488d92ac45..428f97c0af818 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_data_parallel.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_data_parallel.py @@ -43,7 +43,7 @@ def forward(self, inputs): class TestDataParallelStateDict(unittest.TestCase): def test_data_parallel_state_dict(self): with fluid.dygraph.guard(): - strategy = paddle.prepare_context() + strategy = paddle.distributed.prepare_context() mlp = MLP() parallel_mlp = dygraph.parallel.DataParallel(mlp, strategy) diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py index 5677157fde8d7..bac196b1ab52b 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py @@ -13,11 +13,16 @@ # limitations under the License. from __future__ import print_function + +import os +import sys import unittest -from test_dist_base import TestDistBase + import paddle.fluid as fluid +from test_dist_base import TestDistBase +from spawn_runner_base import TestDistSpawnRunner +from parallel_dygraph_mnist import TestMnist -import os flag_name = os.path.splitext(__file__)[0] @@ -36,5 +41,11 @@ def test_mnist(self): log_name=flag_name) +class TestParallelDygraphMnistSpawn(TestDistSpawnRunner): + def test_mnist_with_spawn(self): + if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4): + self.check_dist_result_with_spawn(test_class=TestMnist, delta=1e-5) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_se_resnext.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_se_resnext.py index 8c5cdf8321a4b..cf89dc484c488 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_se_resnext.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_se_resnext.py @@ -13,11 +13,16 @@ # limitations under the License. from __future__ import print_function + +import os +import sys import unittest -from test_dist_base import TestDistBase + import paddle.fluid as fluid +from test_dist_base import TestDistBase +from spawn_runner_base import TestDistSpawnRunner +from parallel_dygraph_se_resnext import TestSeResNeXt -import os flag_name = os.path.splitext(__file__)[0] @@ -36,5 +41,12 @@ def test_se_resnext(self): log_name=flag_name) +class TestParallelDygraphSeResNeXtSpawn(TestDistSpawnRunner): + def test_se_resnext_with_spawn(self): + if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4): + self.check_dist_result_with_spawn( + test_class=TestSeResNeXt, delta=0.01) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding.py index 40b5833053d29..7f051f1005c7b 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding.py @@ -15,10 +15,13 @@ from __future__ import print_function import os +import sys import unittest -import paddle.fluid as fluid +import paddle.fluid as fluid from test_dist_base import TestDistBase +from spawn_runner_base import TestDistSpawnRunner +from parallel_dygraph_sparse_embedding import TestSparseEmbedding flag_name = os.path.splitext(__file__)[0] @@ -38,5 +41,12 @@ def test_sparse_embedding(self): log_name=flag_name) +class TestParallelDygraphSparseEmdeddingSpawn(TestDistSpawnRunner): + def test_sparse_embedding_with_spawn(self): + if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4): + self.check_dist_result_with_spawn( + test_class=TestSparseEmbedding, delta=1e-5) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer.py index 385c4d892a650..c8d47eab2c519 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer.py @@ -15,10 +15,13 @@ from __future__ import print_function import os +import sys import unittest -import paddle.fluid as fluid +import paddle.fluid as fluid from test_dist_base import TestDistBase +from spawn_runner_base import TestDistSpawnRunner +from parallel_dygraph_transformer import TestTransformer flag_name = os.path.splitext(__file__)[0] @@ -38,5 +41,12 @@ def test_transformer(self): log_name=flag_name) +class TestParallelDygraphTransformerSpawn(TestDistSpawnRunner): + def test_transformer_with_spawn(self): + if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4): + self.check_dist_result_with_spawn( + test_class=TestTransformer, delta=1e-5) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_spawn_and_init_parallel_env.py b/python/paddle/fluid/tests/unittests/test_spawn_and_init_parallel_env.py new file mode 100644 index 0000000000000..ca92bc75245ce --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_spawn_and_init_parallel_env.py @@ -0,0 +1,87 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +import numpy as np +import unittest + +import paddle +import paddle.distributed as dist +from paddle.distributed.spawn import _get_subprocess_env_list + +from paddle.fluid import core +from paddle.fluid.dygraph import parallel_helper + +# NOTE(chenweihang): Coverage CI is currently not able to count python3 +# unittest, so the unittests here covers some cases that will only be +# executed in the python3 sub-process. + + +class TestInitParallelEnv(unittest.TestCase): + def test_beckend_type_error(self): + with self.assertRaises(TypeError): + dist.init_parallel_env(backend=1) + + def test_backend_value_error(self): + with self.assertRaises(ValueError): + dist.init_parallel_env(backend="mpi") + + def test_check_env_failed(self): + os.environ['FLAGS_selected_gpus'] = '0' + os.environ['PADDLE_TRAINER_ID'] = '0' + os.environ['PADDLE_CURRENT_ENDPOINT'] = '127.0.0.1:6170' + os.environ['PADDLE_TRAINERS_NUM'] = '1' + with self.assertRaises(ValueError): + dist.init_parallel_env() + + def test_init_parallel_env_break(self): + os.environ['FLAGS_selected_gpus'] = '0' + os.environ['PADDLE_TRAINER_ID'] = '0' + os.environ['PADDLE_CURRENT_ENDPOINT'] = '127.0.0.1:6170' + os.environ['PADDLE_TRAINERS_NUM'] = '1' + os.environ['PADDLE_TRAINER_ENDPOINTS'] = '127.0.0.1:6170' + # coverage success branch + dist.init_parallel_env() + self.assertFalse(parallel_helper._is_parallel_ctx_initialized()) + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSpawnAssistMethod(unittest.TestCase): + def test_only_cluster_node_ips_error(self): + with self.assertRaises(ValueError): + options = dict() + options['cluster_node_ips'] = "127.0.0.1,127.0.0.2" + _get_subprocess_env_list(nprocs=1, options=options) + + def test_nprocs_greater_than_device_num_error(self): + with self.assertRaises(RuntimeError): + _get_subprocess_env_list(nprocs=100, options=dict()) + + def test_selected_gpus_error(self): + with self.assertRaises(ValueError): + options = dict() + options['selected_gpus'] = "100,101" + _get_subprocess_env_list(nprocs=2, options=options) + + def test_get_correct_env(self): + env_dict = _get_subprocess_env_list(nprocs=1, options=dict())[0] + self.assertEqual(env_dict['PADDLE_TRAINER_ID'], '0') + self.assertEqual(env_dict['PADDLE_TRAINERS_NUM'], '1') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/framework/__init__.py b/python/paddle/framework/__init__.py index 95a0cb5204679..b2975283fbef0 100644 --- a/python/paddle/framework/__init__.py +++ b/python/paddle/framework/__init__.py @@ -50,8 +50,6 @@ from ..fluid.dygraph.base import grad #DEFINE_ALIAS from ..fluid.dygraph.checkpoint import load_dygraph as load #DEFINE_ALIAS from ..fluid.dygraph.checkpoint import save_dygraph as save #DEFINE_ALIAS -from ..fluid.dygraph.parallel import prepare_context #DEFINE_ALIAS -from ..fluid.dygraph.parallel import ParallelEnv #DEFINE_ALIAS from ..fluid.dygraph.parallel import DataParallel #DEFINE_ALIAS from ..fluid.dygraph.learning_rate_scheduler import NoamDecay #DEFINE_ALIAS