[Cherry-Pick][HybridParallel]Fix precision problem of model parallel #33087

Merged 1 commit on May 25, 2021
1 change: 1 addition & 0 deletions paddle/fluid/framework/distributed_strategy.proto
@@ -141,6 +141,7 @@ message PipelineConfig {

message TensorParallelConfig {
optional int32 tensor_parallel_degree = 1 [ default = 1 ];
optional int32 tensor_init_seed = 2 [ default = -1 ];
}

message DistributedStrategy {
7 changes: 7 additions & 0 deletions python/paddle/distributed/collective.py
@@ -98,6 +98,13 @@ def get_group_rank(self, rank):
else:
return -1

def __repr__(self):
debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
self.rank, self.nranks, self.id)
debug_str += ", ".join(map(str, self.ranks))
debug_str += ". "
return debug_str


_global_env = None

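For reference, the new `__repr__` makes communication groups print their rank, size, id, and member ranks, which helps when debugging hybrid-parallel topologies. A minimal sketch, not part of this diff, assuming a two-card launch via `paddle.distributed.launch`:

```python
# Hypothetical two-card example; the printed id and ranks depend on the launch.
import paddle.distributed as dist

dist.init_parallel_env()
group = dist.new_group(ranks=[0, 1])
print(group)
# expected form: "rank: 0, nranks: 2, id: 1, ranks: 0, 1. "
```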
5 changes: 4 additions & 1 deletion python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -923,6 +923,8 @@ def tensor_parallel_configs(self):
**Notes**:
**Detailed arguments for tensor_parallel_configs**
**tensor_parallel_degree**: degree of tensor parallel
**tensor_init_seed**: parameter initialization random seed


Examples:

@@ -931,7 +933,8 @@ def tensor_parallel_configs(self):
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.tensor_parallel = True
strategy.tensor_parallel_configs = {"tensor_parallel_degree": 4}
strategy.tensor_parallel_configs = {"tensor_parallel_degree": 4,
"tensor_init_seed": 123}

"""
return get_msg_dict(self.strategy.tensor_parallel_configs)
15 changes: 12 additions & 3 deletions python/paddle/distributed/fleet/base/fleet_base.py
@@ -17,6 +17,7 @@
import warnings
import paddle
import os
import numpy as np
from paddle.fluid.framework import dygraph_only
from paddle.fluid import compiler
from .role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker, RoleMakerBase
@@ -28,7 +29,7 @@
from paddle.fluid.dygraph import parallel_helper
from . import topology as tp
from .topology import ParallelMode
from ..meta_parallel import ModelParallel
from ..meta_parallel import TensorParallel, model_parallel_random_seed
from ..meta_parallel import PipelineParallel
from ..meta_optimizers import HybridParallelOptimizer
from ..meta_optimizers import HybridParallelGradScaler
@@ -279,6 +280,14 @@ def _init_hybrid_parallel_env(self):

self._hcg = tp.HybridCommunicateGroup(self._topology)

if self.mp_degree > 1:
tensor_parallel_configs = self._user_defined_strategy.tensor_parallel_configs
tensor_init_seed = tensor_parallel_configs["tensor_init_seed"]
if tensor_init_seed == -1:
model_parallel_random_seed()
else:
model_parallel_random_seed(tensor_init_seed)

def get_hybrid_communicate_group(self):
assert self._hcg is not None
return self._hcg
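For context, this is the knob the seeding above responds to: with the proto default of -1 the model-parallel RNG is seeded randomly on each run, while any other value makes tensor-parallel parameter initialization reproducible. A minimal sketch of the user-facing configuration, not part of this diff; the degrees are illustrative:

```python
# Hypothetical configuration; requires a multi-card launch via paddle.distributed.launch.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
strategy.tensor_parallel_configs = {"tensor_init_seed": 123}  # -1 (default) => random seed
fleet.init(is_collective=True, strategy=strategy)  # seeds the MP RNG tracker when mp_degree > 1
```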
@@ -780,8 +789,8 @@ def forward(self, x):
last_comm_group_size_MB,
find_unused_parameters=self._user_defined_strategy.
find_unused_parameters)
elif self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL:
distributed_model = ModelParallel(
elif self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL:
distributed_model = TensorParallel(
model, self._hcg, strategy=self._user_defined_strategy)
elif self._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
distributed_model = PipelineParallel(
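The renamed wrapper is reached through the usual `fleet.distributed_model` call; nothing changes on the user side. A minimal sketch, not part of this diff, assuming `fleet.init` was called with `mp_degree > 1` as above:

```python
# Hypothetical model; fleet.distributed_model returns a TensorParallel wrapper
# when the hybrid group reports ParallelMode.TENSOR_PARALLEL.
import paddle
import paddle.distributed.fleet as fleet

model = paddle.nn.Linear(16, 16)
optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())

model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)
```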
6 changes: 3 additions & 3 deletions python/paddle/distributed/fleet/base/topology.py
@@ -28,7 +28,7 @@

class ParallelMode(object):
DATA_PARALLEL = 0
MODEL_PARALLEL = 1
TENSOR_PARALLEL = 1
PIPELINE_PARALLEL = 2


@@ -155,12 +155,12 @@ def __init__(self, topology):
_HYBRID_PARALLEL_GROUP = self

def get_parallel_mode(self):
# there are three modes : DataParallel / ModelParallel / PipelineParallel
# there are three modes : DataParallel / TensorParallel / PipelineParallel
if self._mp_degree == 1 and self._pp_degree == 1:
return ParallelMode.DATA_PARALLEL
elif self._mp_degree > 1 and self._pp_degree == 1:
# initialize the seed
return ParallelMode.MODEL_PARALLEL
return ParallelMode.TENSOR_PARALLEL
elif self._pp_degree > 1:
return ParallelMode.PIPELINE_PARALLEL

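Downstream code can keep checking the mode through the hybrid group; only the enum member name changes. A minimal sketch, not part of this diff:

```python
# Hypothetical check after fleet.init(); TENSOR_PARALLEL replaces the old MODEL_PARALLEL value.
import paddle.distributed.fleet as fleet
from paddle.distributed.fleet.base.topology import ParallelMode

hcg = fleet.get_hybrid_communicate_group()
if hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL:
    print("tensor (model) parallelism is active")
```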
(changes to another file)
@@ -31,7 +31,7 @@ def __init__(self, scaler, hcg):
self._scaler = scaler
self._hcg = hcg
self._is_mp = (
self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL)
self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL)

def scale(self, var):
return self._scaler.scale(var)
(changes to another file)
@@ -90,12 +90,12 @@ def __init__(self, optimizer, hcg, strategy):
self._strategy = strategy
self._hcg = hcg
self._is_mp = (
self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL)
self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL)
self._need_dp = (self._hcg.get_data_parallel_world_size() > 1)

if isinstance(self._inner_opt._grad_clip,
ClipGradByGlobalNorm) and self._is_mp:
logger.warning("using ClipGradByGlobalNorm in ModelParallel, the origin " \
logger.warning("using ClipGradByGlobalNorm in TensorParallel, the origin " \
"optmizer'grad clip will be changed.")
self._inner_opt._grad_clip = HybridParallelClipGrad(
self._inner_opt._grad_clip, hcg)
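In practice this means an optimizer configured with `ClipGradByGlobalNorm` keeps working under tensor parallelism: wrapping it with `fleet.distributed_optimizer` swaps the clip for `HybridParallelClipGrad`, which accounts for the sharded parameters, and logs the warning above. A minimal sketch, not part of this diff; it assumes `fleet.init` ran with `mp_degree > 1`:

```python
# Hypothetical setup; the model and clip_norm value are illustrative.
import paddle
import paddle.distributed.fleet as fleet

model = paddle.nn.Linear(16, 16)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
opt = paddle.optimizer.AdamW(parameters=model.parameters(), grad_clip=clip)
opt = fleet.distributed_optimizer(opt)  # grad clip is replaced under tensor parallel
```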
2 changes: 1 addition & 1 deletion python/paddle/distributed/fleet/meta_parallel/__init__.py
@@ -20,7 +20,7 @@
from .parallel_layers import RNGStatesTracker # noqa: F401
from .parallel_layers import model_parallel_random_seed # noqa: F401
from .parallel_layers import get_rng_state_tracker # noqa: F401
from .model_parallel import ModelParallel # noqa: F401
from .tensor_parallel import TensorParallel # noqa: F401
from .pipeline_parallel import PipelineParallel # noqa: F401

__all__ = []
(changes to another file)
@@ -41,6 +41,7 @@ def __init__(self,
self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()

self.origin_num_embeddings = num_embeddings
self.is_mp = (self.world_size > 1)

per_part_size = (
num_embeddings + self.world_size - 1) // self.world_size
@@ -50,16 +51,36 @@ def __init__(self,
per_part_size += 1 # make the last row as the padding index
self.per_part_size = per_part_size

self.embedding = paddle.nn.Embedding(
per_part_size,
embedding_dim,
padding_idx=per_part_size - 1,
sparse=False,
weight_attr=weight_attr,
name=name)
self.embedding.weight.is_distributed = True
self._dtype = self._helper.get_default_dtype()
self._size = [per_part_size, embedding_dim]
self._weight_attr = weight_attr
self._name = name

if self.is_mp:
with get_rng_state_tracker().rng_state():
self.weight = self.create_parameter(
attr=self._weight_attr,
shape=self._size,
dtype=self._dtype,
is_bias=False)
self.weight[per_part_size - 1] = 0.0
self.weight.is_distributed = True
else:
self.weight = self.create_parameter(
attr=self._weight_attr,
shape=[num_embeddings, embedding_dim],
dtype=self._dtype,
is_bias=False)

def forward(self, x):
if not self.is_mp:
return F.embedding(
x,
weight=self.weight,
padding_idx=None,
sparse=False,
name=self._name)

origin_input_shape = x.shape
if len(origin_input_shape) == 2:
x = paddle.unsqueeze(x, axis=-1)
@@ -72,13 +93,18 @@ def forward(self, x):
if len(origin_input_shape) == 2:
x_shard = paddle.squeeze(x_shard, axis=-1)

emb_out = self.embedding(x_shard)
if self.world_size > 1:
emb_out = paddle.distributed.collective._mp_allreduce(
emb_out,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True)
emb_out = F.embedding(
x_shard,
weight=self.weight,
padding_idx=self.per_part_size - 1,
sparse=False,
name=self._name)

emb_out = paddle.distributed.collective._mp_allreduce(
emb_out,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True)
return emb_out
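The sharding arithmetic behind the rewritten embedding, spelled out on small numbers (not part of this diff): each rank holds `per_part_size` rows, the last of which is a zeroed padding row, and after the shard mapping (collapsed above) out-of-range ids land on that padding row, so the `_mp_allreduce` sums exactly one real lookup per token.

```python
# Hypothetical numbers: a vocabulary of 10 ids sharded across 4 model-parallel ranks.
num_embeddings, world_size = 10, 4
per_part_size = (num_embeddings + world_size - 1) // world_size  # 3 real rows per rank
per_part_size += 1                                               # plus a zeroed padding row -> 4
# Rank r keeps original ids [r * 3, r * 3 + 3); anything else maps to the local
# padding index (per_part_size - 1), whose weight row is kept at zero, so the
# allreduce over the four partial lookups contributes one nonzero row per token.
```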


@@ -96,8 +122,9 @@ def __init__(self,
)
self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
)
self._name = name
self.is_mp = (self.world_size > 1)

self.name = name
self.gather_output = gather_output
assert out_features % self.world_size == 0, (
"Number of column of the weight for linear ({}) must be"
@@ -108,29 +135,45 @@ def forward(self, x):
self._weight_attr = weight_attr
self._dtype = self._helper.get_default_dtype()

self.weight = self.create_parameter(
shape=[in_features, self.output_size_per_partition],
attr=self._weight_attr,
dtype=self._dtype)
if self.is_mp:
with get_rng_state_tracker().rng_state():
self.weight = self.create_parameter(
shape=[in_features, self.output_size_per_partition],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)
else:
self.weight = self.create_parameter(
shape=[in_features, self.output_size_per_partition],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)

self.weight.is_distributed = True

if has_bias:
# initialize bias to zero like Megatron
self.bias = self.create_parameter(
shape=[self.output_size_per_partition],
attr=paddle.nn.initializer.Constant(value=0.0),
dtype=self._dtype)
dtype=self._dtype,
is_bias=True)
self.bias.is_distributed = True
else:
self.bias = None

def forward(self, x):
# use inner api to process identity
input_parallel = paddle.distributed.collective._c_identity(
x, group=self.model_parallel_group)
if self.is_mp:
input_parallel = paddle.distributed.collective._c_identity(
x, group=self.model_parallel_group)
else:
input_parallel = x

output_parallel = F.linear(
input_parallel, self.weight, self.bias, name=self.name)
if self.gather_output:
input_parallel, self.weight, self.bias, name=self._name)

if self.gather_output and self.is_mp:
output = paddle.distributed.collective._c_concat(
output_parallel,
nranks=self.world_size,
@@ -155,37 +198,49 @@ def __init__(self,
self.input_is_parallel = input_is_parallel
self._weight_attr = weight_attr
self._dtype = self._helper.get_default_dtype()
self.name = name
self._name = name

self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
)
self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
)
self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()

self.is_mp = (self.world_size > 1)
assert in_features % self.world_size == 0, (
"Number of row of the weight for linear ({}) must be"
" divisible by model parallel size ({})".format(in_features,
self.world_size))

self.input_size_per_partition = in_features // self.world_size

self.weight = self.create_parameter(
shape=[self.input_size_per_partition, self.out_features],
attr=self._weight_attr,
dtype=self._dtype)
if self.is_mp:
with get_rng_state_tracker().rng_state():
self.weight = self.create_parameter(
shape=[self.input_size_per_partition, self.out_features],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)
else:
self.weight = self.create_parameter(
shape=[self.input_size_per_partition, self.out_features],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)

self.weight.is_distributed = True

if has_bias:
self.bias = self.create_parameter(
shape=[self.out_features],
attr=paddle.nn.initializer.Constant(value=0.0),
dtype=self._dtype)
dtype=self._dtype,
is_bias=True)
else:
self.bias = None

def forward(self, x):
if self.input_is_parallel:
if self.input_is_parallel or (not self.is_mp):
input_parallel = x
else:
# split last dim
@@ -195,12 +250,16 @@ def forward(self, x):
nranks=self.world_size,
group=self.model_parallel_group)

output_parallel = F.linear(input_parallel, self.weight, name=self.name)
output_ = paddle.distributed.collective._mp_allreduce(
output_parallel,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True)
output_parallel = F.linear(input_parallel, self.weight, name=self._name)

if self.is_mp:
output_ = paddle.distributed.collective._mp_allreduce(
output_parallel,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True)
else:
output_ = output_parallel

output = output_ + self.bias if self.bias is not None else output_
return output
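The two linear layers are meant to be composed: a `ColumnParallelLinear` with `gather_output=False` leaves its output sharded along the feature dimension, and a `RowParallelLinear` with `input_is_parallel=True` consumes it directly, so the only communication in the pair is the final `_mp_allreduce`; with the new `is_mp` guards the same module also runs unchanged on a single card. A minimal sketch of that pairing, not part of this diff; it assumes both layers are exported from `paddle.distributed.fleet.meta_parallel` and that `fleet.init` already set up a model-parallel group:

```python
# Hypothetical Megatron-style MLP block built from the layers in this file.
import paddle
import paddle.nn.functional as F
from paddle.distributed.fleet.meta_parallel import (ColumnParallelLinear,
                                                    RowParallelLinear)

class ParallelMLP(paddle.nn.Layer):
    def __init__(self, hidden_size, ffn_size):
        super(ParallelMLP, self).__init__()
        self.fc1 = ColumnParallelLinear(
            hidden_size, ffn_size, has_bias=True, gather_output=False)
        self.fc2 = RowParallelLinear(
            ffn_size, hidden_size, has_bias=True, input_is_parallel=True)

    def forward(self, x):
        return self.fc2(F.gelu(self.fc1(x)))  # one allreduce, inside fc2
```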
(changes to another file)
@@ -14,6 +14,7 @@

import paddle
import contextlib
import numpy as np

__all__ = []

@@ -65,14 +66,18 @@ def get_rng_state_tracker():
return RNG_STATE_TRACKER


def model_parallel_random_seed(seed=2048):
def model_parallel_random_seed(seed=None):
import paddle.distributed.fleet as fleet
hcg = fleet.get_hybrid_communicate_group()
rank = hcg.get_model_parallel_rank()

local_seed = seed + 1024 + rank
global_seed = seed
if seed:
global_seed = seed
local_seed = seed * 1024 + rank * 100
else:
global_seed = np.random.randint(0, 655350)
local_seed = np.random.randint(rank * 10000, (rank + 1) * 10000 - 1)

RNG_STATE_TRACKER.reset()
paddle.seed(global_seed)
RNG_STATE_TRACKER.add(MODEL_PARALLEL_RNG, local_seed)
paddle.seed(global_seed)
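For reference, the tracker seeded here is what the parallel layers above wrap their parameter creation in; the same context manager can be used around dropout so that masks differ across model-parallel ranks while the surrounding computation stays driven by the shared global seed. A minimal sketch, not part of this diff; it assumes `fleet.init` already called `model_parallel_random_seed`:

```python
# Hypothetical helper: run an op under the per-rank model-parallel RNG stream.
import paddle
import paddle.nn.functional as F
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker

def mp_dropout(x, p=0.1):
    with get_rng_state_tracker().rng_state():  # local_seed stream (differs per rank)
        return F.dropout(x, p=p)
```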