Openmind to master #1

Merged
merged 2 commits on Nov 29, 2024
3 changes: 2 additions & 1 deletion mindnlp/accelerate/__init__.py
@@ -5,8 +5,9 @@
# DDPCommunicationHookType,
# DeepSpeedPlugin,
# DistributedDataParallelKwargs,
# DistributedType,
# FullyShardedDataParallelPlugin,
accelerate_distributed_type,
DistributedType,
# GradScalerKwargs,
# InitProcessGroupKwargs,
# ProfileKwargs,
21 changes: 15 additions & 6 deletions mindnlp/accelerate/accelerator.py
@@ -1,19 +1,20 @@
"""accelerate"""
import os
import mindspore
import numpy

from contextlib import contextmanager
from typing import Optional

import mindspore
from mindspore import nn
from mindspore.communication import init

from .state import AcceleratorState
from .utils import (
DistributedType,
MindFormersPlugin,
is_mindformers_available,
wait_for_everyone
)
from .utils import DistributedType, accelerate_distributed_type
from ..utils import logging

if is_mindformers_available():
@@ -45,7 +46,7 @@ def __init__(
# init mindformers_plugin from env variables
if mindformers_plugin is None:
mindformers_plugin = (
MindFormersPlugin() if os.environ.get("ACCELERATE_USE_MINDFORMERS", "false") == "true" else None
MindFormersPlugin() if accelerate_distributed_type == DistributedType.MINDFORMERS else None
)
else:
os.environ["ACCELERATE_USE_MINDFORMERS"] = "true"
@@ -104,12 +105,20 @@ def prepare(self, *args):
"""
result = []

# Only supports mindformers now
# Only supports mindformers and MULTI_NPU_DP now
if self.distributed_type == DistributedType.MINDFORMERS:
result = self._prepare_mindformers(*args)

elif self.distributed_type == DistributedType.MULTI_NPU_DP:
result = self._prepare_data_parallel_native_minspore(*args)
return result

def _prepare_data_parallel_native_minspore(self, *args):
# initialize data parallel for native mindspore
mindspore.set_context(mode=mindspore.GRAPH_MODE)
mindspore.set_auto_parallel_context(parallel_mode=mindspore.ParallelMode.DATA_PARALLEL, gradients_mean=True)
mindspore.communication.init()
mindspore.set_seed(numpy.random.randint(0, 2**31 - 1))  # note: numpy.random.seed() returns None, so set_seed needs an explicit integer

def _prepare_mindformers(self, *args):
mindformers_plugin = self.state.mindformers_plugin

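For context, a rough usage sketch of the new MULTI_NPU_DP path. This is hedged: the exact Accelerator constructor arguments are not shown in this diff, nn.Dense(4, 2) is just a placeholder network, and the script would have to be launched with msrun/mpirun so that HCCL initialization can succeed.

import os

# Select the native MindSpore data-parallel path before importing mindnlp;
# the distributed type is resolved from environment variables at import time.
os.environ["MULTI_NPU_DP"] = "true"

from mindspore import nn
from mindnlp.accelerate.accelerator import Accelerator

model = nn.Dense(4, 2)        # placeholder network for illustration
accelerator = Accelerator()   # assumed default construction; mindformers_plugin stays None
# In MULTI_NPU_DP mode, prepare() sets GRAPH_MODE, enables DATA_PARALLEL with
# gradients_mean=True, and initializes HCCL; the branch above has no explicit
# return value for this mode, so prepare() mainly initializes the parallel context.
accelerator.prepare(model)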
21 changes: 15 additions & 6 deletions mindnlp/accelerate/state.py
@@ -4,14 +4,16 @@
from contextlib import contextmanager
from typing import Callable, Any
from mindspore import communication

try:
from mindspore.communication.comm_func import barrier
except:
barrier = None

from .utils import (
DistributedType, is_mindformers_available
is_mindformers_available
)
from ..accelerate.utils import accelerate_distributed_type, DistributedType

SharedDict = dict

@@ -341,11 +343,14 @@ def print(self, *args, **kwargs):
print(*args, **kwargs)

def _prepare_backend(self):
# now mindformers only
if is_mindformers_available():
# now mindformers and mindspore data parallel only
if accelerate_distributed_type == DistributedType.MINDFORMERS and is_mindformers_available():
self.backend = "hccl"
self.distributed_type = DistributedType.MINDFORMERS

elif accelerate_distributed_type == DistributedType.MULTI_NPU_DP:
self.backend = "hccl"
self.distributed_type = DistributedType.MULTI_NPU_DP

@num_processes.setter
def num_processes(self, value):
self._num_processes = value
@@ -366,10 +371,14 @@ def __init__(self, mindformers_plugin=None, **kwargs):
if PartialState._shared_state:
PartialState(**kwargs)
self.__dict__.update(PartialState._shared_state)

if os.environ.get("ACCELERATE_USE_MINDFORMERS", "false") == "true":
# set distributed_type
if accelerate_distributed_type == DistributedType.MULTI_NPU_DP:
self.distributed_type = DistributedType.MULTI_NPU_DP
elif accelerate_distributed_type == DistributedType.MINDFORMERS:
self.distributed_type = DistributedType.MINDFORMERS
self.mindformers_plugin = mindformers_plugin
else:
self.distributed_type = DistributedType.NO

PartialState._shared_state["distributed_type"] = self.distributed_type

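A minimal standalone mirror of the backend selection added above, assuming these names are all re-exported from mindnlp.accelerate.utils as the imports in this file suggest:

from mindnlp.accelerate.utils import (
    DistributedType,
    accelerate_distributed_type,
    is_mindformers_available,
)

def resolve_backend():
    # Both distributed modes map to the HCCL backend on Ascend NPUs;
    # anything else falls back to single-process execution.
    if accelerate_distributed_type == DistributedType.MINDFORMERS and is_mindformers_available():
        return "hccl", DistributedType.MINDFORMERS
    if accelerate_distributed_type == DistributedType.MULTI_NPU_DP:
        return "hccl", DistributedType.MULTI_NPU_DP
    return None, DistributedType.NO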
1 change: 1 addition & 0 deletions mindnlp/accelerate/utils/__init__.py
@@ -1,4 +1,5 @@
"""accelerate utils"""
from .constants import accelerate_distributed_type
from .dataclasses import (
DistributedType,
MindFormersPlugin
21 changes: 21 additions & 0 deletions mindnlp/accelerate/utils/constants.py
@@ -0,0 +1,21 @@
"""constants"""
import os
from .dataclasses import DistributedType

def detect_accelerate_distributed_type():
"""
Detect which distributed strategy to use.

Returns:
    DistributedType: The strategy chosen from the parallel software/hardware environment
    available on the current system and the user-specified environment variables.
"""
if os.environ.get("MULTI_NPU_DP", None) == "true":
return DistributedType.MULTI_NPU_DP
if os.environ.get("ACCELERATE_USE_MINDFORMERS", "false") == "true":
return DistributedType.MINDFORMERS
else:
return DistributedType.NO

accelerate_distributed_type = detect_accelerate_distributed_type()
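A quick, hedged illustration of the detection order: MULTI_NPU_DP is checked before ACCELERATE_USE_MINDFORMERS, and the result is fixed at import time, so the variables must be set before mindnlp is imported.

import os

# Set the switch before the first mindnlp import; detection runs once at import time.
os.environ["MULTI_NPU_DP"] = "true"
os.environ["ACCELERATE_USE_MINDFORMERS"] = "true"   # ignored here: MULTI_NPU_DP wins

from mindnlp.accelerate.utils import accelerate_distributed_type, DistributedType

assert accelerate_distributed_type == DistributedType.MULTI_NPU_DP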

3 changes: 3 additions & 0 deletions mindnlp/accelerate/utils/dataclasses.py
@@ -17,8 +17,11 @@ class DistributedType(str, enum.Enum):

Values:
- **MINDFORMERS** -- Using mindformers
- **NO** -- Not a distributed environment, just a single process.
- **MULTI_NPU_DP** -- Distributed data parallel on multiple NPUs.
"""

MULTI_NPU_DP = "MULTI_NPU_DP"
MINDFORMERS = "MINDFORMERS"
NO = "NO"

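Since DistributedType subclasses str, its members compare equal to their string values; a tiny hedged sanity check:

from mindnlp.accelerate.utils import DistributedType

assert DistributedType.MULTI_NPU_DP == "MULTI_NPU_DP"                  # str-enum members equal their values
assert DistributedType("MINDFORMERS") is DistributedType.MINDFORMERS   # lookup by value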
17 changes: 14 additions & 3 deletions mindnlp/dataset/load.py
@@ -23,6 +23,10 @@
from datasets import Dataset, IterableDataset, Split, Features, \
DownloadConfig, DownloadMode, VerificationMode, Version
from mindnlp.configs import DEFAULT_ROOT
from mindspore.communication import get_rank, get_group_size
from ..accelerate import DistributedType
from ..accelerate.utils import accelerate_distributed_type


class TransferIterableDataset():
"""TransferDataset for Huggingface Dataset."""
@@ -331,12 +335,19 @@ def load_dataset(
column_names = list(raw_ds.features.keys())
source = TransferDataset(raw_ds, column_names) if isinstance(raw_ds, Dataset) \
else TransferIterableDataset(raw_ds, column_names)
ms_ds = GeneratorDataset(
source=source,
if accelerate_distributed_type == DistributedType.MULTI_NPU_DP:
ms_ds = GeneratorDataset(source=source,
column_names=column_names,
shuffle=shuffle,
num_parallel_workers=num_proc if num_proc else 1,
num_shards=get_group_size(), shard_id=get_rank())
datasets_dict[key] = ms_ds
else:
ms_ds = GeneratorDataset(source=source,
column_names=column_names,
shuffle=shuffle,
num_parallel_workers=num_proc if num_proc else 1)
datasets_dict[key] = ms_ds
datasets_dict[key] = ms_ds

if len(datasets_dict) == 1:
return datasets_dict.popitem()[1]
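The branch above is the core of data-parallel loading: under MULTI_NPU_DP every rank receives a distinct 1/world_size shard of the data. A hedged sketch of the same decision factored into a helper (build_sharded_dataset is a hypothetical name, not part of the PR):

from mindspore.dataset import GeneratorDataset
from mindspore.communication import get_rank, get_group_size

def build_sharded_dataset(source, column_names, shuffle=True, num_proc=None, data_parallel=False):
    """Hypothetical helper mirroring the load_dataset branch above."""
    kwargs = dict(
        source=source,
        column_names=column_names,
        shuffle=shuffle,
        num_parallel_workers=num_proc if num_proc else 1,
    )
    if data_parallel:
        # Each rank reads a distinct shard so samples are not duplicated across NPUs.
        kwargs["num_shards"] = get_group_size()
        kwargs["shard_id"] = get_rank()
    return GeneratorDataset(**kwargs)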
13 changes: 12 additions & 1 deletion mindnlp/engine/trainer/base.py
@@ -45,6 +45,8 @@
WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
from ...dataset import BaseMapFunction
from ...utils import logging, find_labels, can_return_loss
from ...accelerate.utils import DistributedType
from ...accelerate.utils import accelerate_distributed_type
from ...utils.import_utils import is_safetensors_available
from ...transformers.modeling_utils import PreTrainedModel
from ...transformers.configuration_utils import PretrainedConfig
@@ -88,6 +90,7 @@
TrainerControl,
TrainerState,
)
from ..utils import _get_learning_rate


logger = logging.get_logger(__name__)
@@ -124,7 +127,6 @@ class Trainer:
"""
Trainer is a simple but feature-complete training and eval loop for MindSpore, optimized for 🤗 Transformers.
"""
from ..utils import _get_learning_rate
def __init__(
self,
model: Union[PreTrainedModel, nn.Module] = None,
@@ -284,6 +286,7 @@ def __init__(
# Internal variables to help with automatic batch size reduction
self._train_batch_size = args.train_batch_size
self._created_lr_scheduler = False
self.actual_distributed_type = accelerate_distributed_type

def _activate_neftune(self, model):
r"""
@@ -1373,6 +1376,14 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[mindspore.Tens
inputs = self._prepare_inputs(inputs)

def forward(inputs):
if accelerate_distributed_type == DistributedType.MULTI_NPU_DP:
from mindspore.communication import get_group_size
import mindspore.ops as msops
rank_size = get_group_size()
for parameter in model.parameters():
all_reduce_sum = msops.AllReduce(msops.ReduceOp.SUM)
new_grads_mean = all_reduce_sum(parameter.grad) / rank_size
parameter.grad = new_grads_mean
return self.compute_loss(model, inputs)

if getattr(self, 'grad_fn', None) is None or self.model_reload:
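The hunk above averages gradients across ranks with an AllReduce sum divided by the world size. As a hedged, standalone sketch of that collective (not the Trainer's exact integration: in MindSpore's functional training loop the gradients normally come from value_and_grad rather than a .grad attribute):

import mindspore.ops as msops
from mindspore.communication import get_group_size

def average_gradients(grads):
    # Sum each gradient tensor across all ranks, then divide by the number of
    # ranks to obtain the mean gradient used for the synchronized update.
    all_reduce_sum = msops.AllReduce(msops.ReduceOp.SUM)
    world_size = get_group_size()
    return tuple(all_reduce_sum(g) / world_size for g in grads)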
4 changes: 3 additions & 1 deletion mindnlp/engine/trainer/default_func.py
@@ -15,10 +15,12 @@
"""
utils for trainer.
"""
from mindspore import ops, value_and_grad
from mindspore import nn, ops, value_and_grad
from mindspore.amp import all_finite

from mindnlp.utils import ModelOutput
from ...accelerate.utils import DistributedType
from ...accelerate.utils import accelerate_distributed_type

def get_default_forward_fn_with_loss_fn(network, loss_fn, loss_scaler):
"""get default forward function with loss function"""
35 changes: 35 additions & 0 deletions tests/accelerate/grad_Reduce_ut/test_grad_Reduce.py
@@ -0,0 +1,35 @@


def test_AllReduce_mean():
import numpy as np
from mindspore.communication import init, get_rank, get_group_size
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops

init()
rank_size = get_group_size()
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.all_reduce_sum = ops.AllReduce(ops.ReduceOp.SUM)

def construct(self, x):
new_grads_mean = self.all_reduce_sum(x) / rank_size
new_grad = new_grads_mean
return new_grad

rank_id_value = get_rank() # Current NPU number 0,...,7
print('rank_id_value=',rank_id_value)
input_x = ms.Tensor(np.array([[rank_id_value]]).astype(np.float32))
print('input_x=',input_x)
net = Net()
output = net(input_x)
print("mean:",output) # sum(0, rank_size) / rank_size




if __name__ == '__main__':
test_AllReduce_mean()
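For reference, with the 4-worker msrun launch in test_grad_Reduce.sh each rank feeds its own rank id as input, so every rank should print a mean of (0 + 1 + 2 + 3) / 4 = 1.5.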

3 changes: 3 additions & 0 deletions tests/accelerate/grad_Reduce_ut/test_grad_Reduce.sh
@@ -0,0 +1,3 @@
# mpirun -n 8 -H 127.0.0.1:8 --output-filename bak/log_output_mpirun_single/log_ python test_grad_Reduce.py
msrun --worker_num=4 --local_worker_num=4 --master_port=8123 --log_dir=bak/msrun_log --join=True --cluster_time_out=100 test_grad_Reduce.py
# msrun --worker_num=8 --local_worker_num=8 --master_port=8123 --log_dir=bak/msrun_log --join=True --cluster_time_out=100 test_grad_Reduce.py