diff --git a/mindnlp/accelerate/__init__.py b/mindnlp/accelerate/__init__.py
index a3966932e..573430bcb 100644
--- a/mindnlp/accelerate/__init__.py
+++ b/mindnlp/accelerate/__init__.py
@@ -5,8 +5,9 @@
     # DDPCommunicationHookType,
     # DeepSpeedPlugin,
     # DistributedDataParallelKwargs,
-    # DistributedType,
     # FullyShardedDataParallelPlugin,
+    accelerate_distributed_type,
+    DistributedType,
     # GradScalerKwargs,
     # InitProcessGroupKwargs,
     # ProfileKwargs,
diff --git a/mindnlp/accelerate/accelerator.py b/mindnlp/accelerate/accelerator.py
index 059fe96d0..3e039b00a 100644
--- a/mindnlp/accelerate/accelerator.py
+++ b/mindnlp/accelerate/accelerator.py
@@ -1,19 +1,20 @@
 """accelerate"""
 import os
+import mindspore
+import numpy
+
 from contextlib import contextmanager
 from typing import Optional
-
-import mindspore
 from mindspore import nn
 from mindspore.communication import init

 from .state import AcceleratorState
 from .utils import (
-    DistributedType,
     MindFormersPlugin,
     is_mindformers_available,
     wait_for_everyone
 )
+from .utils import DistributedType, accelerate_distributed_type
 from ..utils import logging

 if is_mindformers_available():
@@ -45,7 +46,7 @@ def __init__(
         # init mindformers_plugin from env variables
         if mindformers_plugin is None:
             mindformers_plugin = (
-                MindFormersPlugin() if os.environ.get("ACCELERATE_USE_MINDFORMERS", "false") == "true" else None
+                MindFormersPlugin() if accelerate_distributed_type == DistributedType.MINDFORMERS else None
             )
         else:
             os.environ["ACCELERATE_USE_MINDFORMERS"] = "true"
@@ -104,12 +105,20 @@ def prepare(self, *args):
         """
         result = []

-        # Only support mindsormers now
+        # Only support MINDFORMERS and MULTI_NPU_DP now
         if self.distributed_type == DistributedType.MINDFORMERS:
             result = self._prepare_mindformers(*args)
-
+        elif self.distributed_type == DistributedType.MULTI_NPU_DP:
+            result = self._prepare_data_parallel_native_mindspore(*args)
         return result

+    def _prepare_data_parallel_native_mindspore(self, *args):
+        # initialize data parallelism for native MindSpore
+        mindspore.set_context(mode=mindspore.GRAPH_MODE)
+        mindspore.set_auto_parallel_context(parallel_mode=mindspore.ParallelMode.DATA_PARALLEL, gradients_mean=True)
+        mindspore.communication.init()
+        mindspore.set_seed(numpy.random.randint(0, 2**31 - 1))
+
     def _prepare_mindformers(self, *args):
         mindformers_plugin = self.state.mindformers_plugin
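For reference, the data-parallel setup that Accelerator.prepare now performs for MULTI_NPU_DP is roughly equivalent to the standalone sketch below (a minimal illustration, assuming the process was started by a distributed launcher such as msrun; the explicit device_num and the fixed seed are illustrative choices, not part of the patch):

import mindspore
from mindspore.communication import init, get_group_size

# Graph mode + DATA_PARALLEL: every worker holds a full model replica, and
# gradients_mean=True makes MindSpore average (not just sum) gradients across workers.
mindspore.set_context(mode=mindspore.GRAPH_MODE)
init()  # picks up the cluster information provided by the launcher
mindspore.set_auto_parallel_context(
    parallel_mode=mindspore.ParallelMode.DATA_PARALLEL,
    gradients_mean=True,
    device_num=get_group_size(),
)
mindspore.set_seed(0)  # any valid integer seed; the patch draws one at random
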
diff --git a/mindnlp/accelerate/state.py b/mindnlp/accelerate/state.py
index 059085714..f82679fc4 100644
--- a/mindnlp/accelerate/state.py
+++ b/mindnlp/accelerate/state.py
@@ -4,14 +4,16 @@
 from contextlib import contextmanager
 from typing import Callable, Any
 from mindspore import communication
+
 try:
     from mindspore.communication.comm_func import barrier
 except:
     barrier = None

 from .utils import (
-    DistributedType, is_mindformers_available
+    is_mindformers_available
 )
+from ..accelerate.utils import accelerate_distributed_type, DistributedType

 SharedDict = dict
@@ -341,11 +343,14 @@ def print(self, *args, **kwargs):
             print(*args, **kwargs)

     def _prepare_backend(self):
-        # now mindformers only
-        if is_mindformers_available():
+        # now mindformers and mindspore data parallel only
+        if accelerate_distributed_type == DistributedType.MINDFORMERS and is_mindformers_available():
             self.backend = "hccl"
             self.distributed_type = DistributedType.MINDFORMERS
-
+        elif accelerate_distributed_type == DistributedType.MULTI_NPU_DP:
+            self.backend = "hccl"
+            self.distributed_type = DistributedType.MULTI_NPU_DP
+
     @num_processes.setter
     def num_processes(self, value):
         self._num_processes = value
@@ -366,10 +371,14 @@ def __init__(self, mindformers_plugin=None, **kwargs):
         if PartialState._shared_state:
             PartialState(**kwargs)
             self.__dict__.update(PartialState._shared_state)
-
-        if os.environ.get("ACCELERATE_USE_MINDFORMERS", "false") == "true":
+        # set distributed_type
+        if accelerate_distributed_type == DistributedType.MULTI_NPU_DP:
+            self.distributed_type = DistributedType.MULTI_NPU_DP
+        elif accelerate_distributed_type == DistributedType.MINDFORMERS:
             self.distributed_type = DistributedType.MINDFORMERS
             self.mindformers_plugin = mindformers_plugin
+        else:
+            self.distributed_type = DistributedType.NO

         PartialState._shared_state["distributed_type"] = self.distributed_type
diff --git a/mindnlp/accelerate/utils/__init__.py b/mindnlp/accelerate/utils/__init__.py
index fac142e81..98bdbb08c 100644
--- a/mindnlp/accelerate/utils/__init__.py
+++ b/mindnlp/accelerate/utils/__init__.py
@@ -1,4 +1,5 @@
 """accelerate utils"""
+from .constants import accelerate_distributed_type
 from .dataclasses import (
     DistributedType,
     MindFormersPlugin
 )
diff --git a/mindnlp/accelerate/utils/constants.py b/mindnlp/accelerate/utils/constants.py
new file mode 100644
index 000000000..dfdc16739
--- /dev/null
+++ b/mindnlp/accelerate/utils/constants.py
@@ -0,0 +1,21 @@
+"""constants"""
+import os
+from .dataclasses import DistributedType
+
+def detect_accelerate_distributed_type():
+    """
+    Detect the distributed type to use.
+
+    Returns:
+        DistributedType: The parallel strategy chosen from the user-specified scheme and the
+            parallel software and hardware environment available on the current system.
+    """
+    if os.environ.get("MULTI_NPU_DP", None) == "true":
+        return DistributedType.MULTI_NPU_DP
+    if os.environ.get("ACCELERATE_USE_MINDFORMERS", "false") == "true":
+        return DistributedType.MINDFORMERS
+    else:
+        return DistributedType.NO
+
+accelerate_distributed_type = detect_accelerate_distributed_type()
+
\ No newline at end of file
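Because accelerate_distributed_type is evaluated once when mindnlp.accelerate.utils is imported, the selector environment variable has to be set before mindnlp is imported (or exported by the launcher). A minimal usage sketch under that assumption:

import os

# Choose the strategy before any mindnlp import: MULTI_NPU_DP wins over
# ACCELERATE_USE_MINDFORMERS, and everything else falls back to DistributedType.NO.
os.environ["MULTI_NPU_DP"] = "true"

from mindnlp.accelerate import DistributedType
from mindnlp.accelerate.utils import accelerate_distributed_type

assert accelerate_distributed_type == DistributedType.MULTI_NPU_DP
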
""" + MULTI_NPU_DP = "MULTI_NPU_DP" MINDFORMERS = "MINDFORMERS" NO = "NO" diff --git a/mindnlp/dataset/load.py b/mindnlp/dataset/load.py index a1e0eb942..f23325382 100644 --- a/mindnlp/dataset/load.py +++ b/mindnlp/dataset/load.py @@ -23,6 +23,10 @@ from datasets import Dataset, IterableDataset, Split, Features, \ DownloadConfig, DownloadMode, VerificationMode, Version from mindnlp.configs import DEFAULT_ROOT +from mindspore.communication import get_rank, get_group_size +from ..accelerate import DistributedType +from ..accelerate.utils import accelerate_distributed_type + class TransferIterableDataset(): """TransferDataset for Huggingface Dataset.""" @@ -331,12 +335,19 @@ def load_dataset( column_names = list(raw_ds.features.keys()) source = TransferDataset(raw_ds, column_names) if isinstance(raw_ds, Dataset) \ else TransferIterableDataset(raw_ds, column_names) - ms_ds = GeneratorDataset( - source=source, + if accelerate_distributed_type == DistributedType.MULTI_NPU_DP: + ms_ds = GeneratorDataset(source=source, + column_names=column_names, + shuffle=shuffle, + num_parallel_workers=num_proc if num_proc else 1, + num_shards=get_group_size(), shard_id=get_rank()) + datasets_dict[key] = ms_ds + else: + ms_ds = GeneratorDataset(source=source, column_names=column_names, shuffle=shuffle, num_parallel_workers=num_proc if num_proc else 1) - datasets_dict[key] = ms_ds + datasets_dict[key] = ms_ds if len(datasets_dict) == 1: return datasets_dict.popitem()[1] diff --git a/mindnlp/engine/trainer/base.py b/mindnlp/engine/trainer/base.py index 940bd38a4..343fa05d8 100644 --- a/mindnlp/engine/trainer/base.py +++ b/mindnlp/engine/trainer/base.py @@ -45,6 +45,8 @@ WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME from ...dataset import BaseMapFunction from ...utils import logging, find_labels, can_return_loss +from ...accelerate.utils import DistributedType +from ...accelerate.utils import accelerate_distributed_type from ...utils.import_utils import is_safetensors_available from ...transformers.modeling_utils import PreTrainedModel from ...transformers.configuration_utils import PretrainedConfig @@ -88,6 +90,7 @@ TrainerControl, TrainerState, ) +from ..utils import _get_learning_rate logger = logging.get_logger(__name__) @@ -124,7 +127,6 @@ class Trainer: """ Trainer is a simple but feature-complete training and eval loop for MindSpore, optimized for 🤗 Transformers. 
""" - from ..utils import _get_learning_rate def __init__( self, model: Union[PreTrainedModel, nn.Module] = None, @@ -284,6 +286,7 @@ def __init__( # Internal variables to help with automatic batch size reduction self._train_batch_size = args.train_batch_size self._created_lr_scheduler = False + self.actual_distributed_type = accelerate_distributed_type def _activate_neftune(self, model): r""" @@ -1373,6 +1376,14 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[mindspore.Tens inputs = self._prepare_inputs(inputs) def forward(inputs): + if accelerate_distributed_type == DistributedType.MULTI_NPU_DP: + from mindspore.communication import get_group_size + import mindspore.ops as msops + rank_size = get_group_size() + for parameter in model.parameters(): + all_reduce_sum = msops.AllReduce(msops.ReduceOp.SUM) + new_grads_mean = all_reduce_sum(parameter.grad) / rank_size + parameter.grad = new_grads_mean return self.compute_loss(model, inputs) if getattr(self, 'grad_fn', None) is None or self.model_reload: diff --git a/mindnlp/engine/trainer/default_func.py b/mindnlp/engine/trainer/default_func.py index 03493efbf..d7b4839a5 100644 --- a/mindnlp/engine/trainer/default_func.py +++ b/mindnlp/engine/trainer/default_func.py @@ -15,10 +15,12 @@ """ utils for trainer. """ -from mindspore import ops, value_and_grad +from mindspore import nn, ops, value_and_grad from mindspore.amp import all_finite from mindnlp.utils import ModelOutput +from ...accelerate.utils import DistributedType +from ...accelerate.utils import accelerate_distributed_type def get_default_forward_fn_with_loss_fn(network, loss_fn, loss_scaler): """get default forward function with loss function""" diff --git a/tests/accelerate/grad_Reduce_ut/test_grad_Reduce.py b/tests/accelerate/grad_Reduce_ut/test_grad_Reduce.py new file mode 100644 index 000000000..405b1b444 --- /dev/null +++ b/tests/accelerate/grad_Reduce_ut/test_grad_Reduce.py @@ -0,0 +1,35 @@ + + +def test_AllReduce_mean(): + import numpy as np + from mindspore.communication import init, get_rank, get_group_size + import mindspore as ms + import mindspore.nn as nn + import mindspore.ops as ops + + init() + rank_size = get_group_size() + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.all_reduce_sum = ops.AllReduce(ops.ReduceOp.SUM) + + def construct(self, x): + new_grads_mean = self.all_reduce_sum(x) / rank_size + new_grad = new_grads_mean + return new_grad + + rank_id_value = get_rank() # Current NPU number 0,...,7 + print('rank_id_value=',rank_id_value) + input_x = ms.Tensor(np.array([[rank_id_value]]).astype(np.float32)) + print('input_x=',input_x) + net = Net() + output = net(input_x) + print("mean:",output) # sum(0, rank_size) / rank_size + + + + +if __name__ == '__main__': + test_AllReduce_mean() + diff --git a/tests/accelerate/grad_Reduce_ut/test_grad_Reduce.sh b/tests/accelerate/grad_Reduce_ut/test_grad_Reduce.sh new file mode 100644 index 000000000..7508de39e --- /dev/null +++ b/tests/accelerate/grad_Reduce_ut/test_grad_Reduce.sh @@ -0,0 +1,3 @@ +# mpirun -n 8 -H 127.0.0.1:8 --output-filename bak/log_output_mpirun_single/log_ python test_grad_Reduce.py +msrun --worker_num=4 --local_worker_num=4 --master_port=8123 --log_dir=bak/msrun_log --join=True --cluster_time_out=100 test_grad_Reduce.py +# msrun --worker_num=8 --local_worker_num=8 --master_port=8123 --log_dir=bak/msrun_log --join=True --cluster_time_out=100 test_grad_Reduce.py \ No newline at end of file