[HybridParallel]Fix precision problem of model parallel (#32897)
* fix precision of mp

* fix bug of seed

* fix dp

* print group
ForFishes authored May 17, 2021
1 parent 906db71 commit c809530
Showing 14 changed files with 151 additions and 58 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/framework/distributed_strategy.proto
@@ -141,6 +141,7 @@ message PipelineConfig {

message TensorParallelConfig {
optional int32 tensor_parallel_degree = 1 [ default = 1 ];
optional int32 tensor_init_seed = 2 [ default = -1 ];
}

message DistributedStrategy {
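
For context (not part of the diff): the new tensor_init_seed field is surfaced to users through fleet.DistributedStrategy, as the docstring updated later in this commit shows. A minimal sketch:

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.tensor_parallel = True
# tensor_init_seed defaults to -1, which means "pick a random seed at runtime"
strategy.tensor_parallel_configs = {"tensor_parallel_degree": 4,
                                    "tensor_init_seed": 123}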
7 changes: 7 additions & 0 deletions python/paddle/distributed/collective.py
@@ -99,6 +99,13 @@ def get_group_rank(self, rank):
else:
return -1

def __repr__(self):
debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
self.rank, self.nranks, self.id)
debug_str += ", ".join(map(str, self.ranks))
debug_str += ". "
return debug_str
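
For illustration (not part of the diff): printing a hypothetical two-rank group with this __repr__ would render roughly as:

>>> print(group)   # assuming a Group with rank=0, nranks=2, id=0, ranks=[0, 1]
rank: 0, nranks: 2, id: 0, ranks: 0, 1.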


_global_env = None

5 changes: 4 additions & 1 deletion python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -949,6 +949,8 @@ def tensor_parallel_configs(self):
**Notes**:
**Detailed arguments for tensor_parallel_configs**
**tensor_parallel_degree**: degree of tensor parallel
**tensor_init_seed**: parameter initialization random seed
Examples:
@@ -957,7 +959,8 @@ def tensor_parallel_configs(self):
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.tensor_parallel = True
strategy.tensor_parallel_configs = {"tensor_parallel_degree": 4}
strategy.tensor_parallel_configs = {"tensor_parallel_degree": 4,
"tensor_init_seed": 123}
"""
return get_msg_dict(self.strategy.tensor_parallel_configs)
15 changes: 12 additions & 3 deletions python/paddle/distributed/fleet/base/fleet_base.py
@@ -17,6 +17,7 @@
import warnings
import paddle
import os
import numpy as np
from paddle.fluid.framework import dygraph_only
from paddle.fluid import compiler
from .role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker, RoleMakerBase
@@ -28,7 +29,7 @@
from paddle.fluid.dygraph import parallel_helper
from . import topology as tp
from .topology import ParallelMode
from ..meta_parallel import ModelParallel
from ..meta_parallel import TensorParallel, model_parallel_random_seed
from ..meta_parallel import PipelineParallel
from ..meta_optimizers import HybridParallelOptimizer
from ..meta_optimizers import HybridParallelGradScaler
@@ -279,6 +280,14 @@ def _init_hybrid_parallel_env(self):

self._hcg = tp.HybridCommunicateGroup(self._topology)

if self.mp_degree > 1:
tensor_parallel_configs = self._user_defined_strategy.tensor_parallel_configs
tensor_init_seed = tensor_parallel_configs["tensor_init_seed"]
if tensor_init_seed == -1:
model_parallel_random_seed()
else:
model_parallel_random_seed(tensor_init_seed)

def get_hybrid_communicate_group(self):
assert self._hcg is not None
return self._hcg
@@ -829,8 +838,8 @@ def forward(self, x):
last_comm_group_size_MB,
find_unused_parameters=self._user_defined_strategy.
find_unused_parameters)
elif self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL:
distributed_model = ModelParallel(
elif self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL:
distributed_model = TensorParallel(
model, self._hcg, strategy=self._user_defined_strategy)
elif self._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
distributed_model = PipelineParallel(
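
As an aside (not part of the diff), these pieces fit together roughly as follows from the user's side. The hybrid_configs entry, the degree values, and SomeTensorParallelModel are illustrative assumptions; fleet.init and fleet.distributed_model are the entry points shown in this file.

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
strategy.tensor_parallel_configs = {"tensor_init_seed": 123}   # reproducible shard init

fleet.init(is_collective=True, strategy=strategy)   # with mp_degree > 1, seeds the RNG tracker
model = SomeTensorParallelModel()                   # hypothetical model built from the mp layers
model = fleet.distributed_model(model)              # wrapped in TensorParallel by the hunk above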
6 changes: 3 additions & 3 deletions python/paddle/distributed/fleet/base/topology.py
@@ -28,7 +28,7 @@

class ParallelMode(object):
DATA_PARALLEL = 0
MODEL_PARALLEL = 1
TENSOR_PARALLEL = 1
PIPELINE_PARALLEL = 2


@@ -155,12 +155,12 @@ def __init__(self, topology):
_HYBRID_PARALLEL_GROUP = self

def get_parallel_mode(self):
# there are three modes : DataParallel / ModelParallel / PipelineParallel
# there are three modes : DataParallel / TensorParallel / PipelineParallel
if self._mp_degree == 1 and self._pp_degree == 1:
return ParallelMode.DATA_PARALLEL
elif self._mp_degree > 1 and self._pp_degree == 1:
# initialize the seed
return ParallelMode.MODEL_PARALLEL
return ParallelMode.TENSOR_PARALLEL
elif self._pp_degree > 1:
return ParallelMode.PIPELINE_PARALLEL
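
A compact restatement of the selection rule above, with hypothetical degrees for illustration:

# dp_degree=8, mp_degree=1, pp_degree=1  -> ParallelMode.DATA_PARALLEL
# dp_degree=2, mp_degree=4, pp_degree=1  -> ParallelMode.TENSOR_PARALLEL
# dp_degree=1, mp_degree=2, pp_degree=4  -> ParallelMode.PIPELINE_PARALLEL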

@@ -31,7 +31,7 @@ def __init__(self, scaler, hcg):
self._scaler = scaler
self._hcg = hcg
self._is_mp = (
self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL)
self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL)

def scale(self, var):
return self._scaler.scale(var)
@@ -90,12 +90,12 @@ def __init__(self, optimizer, hcg, strategy):
self._strategy = strategy
self._hcg = hcg
self._is_mp = (
self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL)
self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL)
self._need_dp = (self._hcg.get_data_parallel_world_size() > 1)

if isinstance(self._inner_opt._grad_clip,
ClipGradByGlobalNorm) and self._is_mp:
logger.warning("using ClipGradByGlobalNorm in ModelParallel, the origin " \
logger.warning("using ClipGradByGlobalNorm in TensorParallel, the origin " \
"optmizer'grad clip will be changed.")
self._inner_opt._grad_clip = HybridParallelClipGrad(
self._inner_opt._grad_clip, hcg)
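
A hedged note on why the clip is wrapped (my reading of the change, not stated in the diff):

# Each rank holds only its shard of every distributed weight, so a plain
# ClipGradByGlobalNorm would clip against a partial global norm.
# HybridParallelClipGrad is presumably responsible for aggregating the norm
# across the model-parallel group before gradients are scaled.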
2 changes: 1 addition & 1 deletion python/paddle/distributed/fleet/meta_parallel/__init__.py
@@ -20,7 +20,7 @@
from .parallel_layers import RNGStatesTracker # noqa: F401
from .parallel_layers import model_parallel_random_seed # noqa: F401
from .parallel_layers import get_rng_state_tracker # noqa: F401
from .model_parallel import ModelParallel # noqa: F401
from .tensor_parallel import TensorParallel # noqa: F401
from .pipeline_parallel import PipelineParallel # noqa: F401

__all__ = []
@@ -41,6 +41,7 @@ def __init__(self,
self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()

self.origin_num_embeddings = num_embeddings
self.is_mp = (self.world_size > 1)

per_part_size = (
num_embeddings + self.world_size - 1) // self.world_size
@@ -50,16 +51,36 @@ def __init__(self,
per_part_size += 1 # make the last row as the padding index
self.per_part_size = per_part_size

self.embedding = paddle.nn.Embedding(
per_part_size,
embedding_dim,
padding_idx=per_part_size - 1,
sparse=False,
weight_attr=weight_attr,
name=name)
self.embedding.weight.is_distributed = True
self._dtype = self._helper.get_default_dtype()
self._size = [per_part_size, embedding_dim]
self._weight_attr = weight_attr
self._name = name

if self.is_mp:
with get_rng_state_tracker().rng_state():
self.weight = self.create_parameter(
attr=self._weight_attr,
shape=self._size,
dtype=self._dtype,
is_bias=False)
self.weight[per_part_size - 1] = 0.0
self.weight.is_distributed = True
else:
self.weight = self.create_parameter(
attr=self._weight_attr,
shape=[num_embeddings, embedding_dim],
dtype=self._dtype,
is_bias=False)

def forward(self, x):
if not self.is_mp:
return F.embedding(
x,
weight=self.weight,
padding_idx=None,
sparse=False,
name=self._name)

origin_input_shape = x.shape
if len(origin_input_shape) == 2:
x = paddle.unsqueeze(x, axis=-1)
@@ -72,13 +93,18 @@ def forward(self, x):
if len(origin_input_shape) == 2:
x_shard = paddle.squeeze(x_shard, axis=-1)

emb_out = self.embedding(x_shard)
if self.world_size > 1:
emb_out = paddle.distributed.collective._mp_allreduce(
emb_out,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True)
emb_out = F.embedding(
x_shard,
weight=self.weight,
padding_idx=self.per_part_size - 1,
sparse=False,
name=self._name)

emb_out = paddle.distributed.collective._mp_allreduce(
emb_out,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True)
return emb_out
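
A worked sizing example for the sharded embedding above (hypothetical values; the last-rank adjustment elided by the hunk is ignored):

num_embeddings, embedding_dim, world_size = 1000, 64, 4
per_part_size = (num_embeddings + world_size - 1) // world_size   # 250 rows per rank
per_part_size += 1   # reserve the shard's last row as the padding index -> 251
# Each rank creates a [251, 64] weight shard inside
# get_rng_state_tracker().rng_state(), so shards come from per-rank,
# reproducible random streams, and the padding row is zeroed afterwards.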


@@ -96,8 +122,9 @@ def __init__(self,
)
self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
)
self._name = name
self.is_mp = (self.world_size > 1)

self.name = name
self.gather_output = gather_output
assert out_features % self.world_size == 0, (
"Number of column of the weight for linear ({}) must be"
@@ -108,29 +135,45 @@ def __init__(self,
self._weight_attr = weight_attr
self._dtype = self._helper.get_default_dtype()

self.weight = self.create_parameter(
shape=[in_features, self.output_size_per_partition],
attr=self._weight_attr,
dtype=self._dtype)
if self.is_mp:
with get_rng_state_tracker().rng_state():
self.weight = self.create_parameter(
shape=[in_features, self.output_size_per_partition],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)
else:
self.weight = self.create_parameter(
shape=[in_features, self.output_size_per_partition],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)

self.weight.is_distributed = True

if has_bias:
# initialize bias to zero like Megatron
self.bias = self.create_parameter(
shape=[self.output_size_per_partition],
attr=paddle.nn.initializer.Constant(value=0.0),
dtype=self._dtype)
dtype=self._dtype,
is_bias=True)
self.bias.is_distributed = True
else:
self.bias = None

def forward(self, x):
# use inner api to process identity
input_parallel = paddle.distributed.collective._c_identity(
x, group=self.model_parallel_group)
if self.is_mp:
input_parallel = paddle.distributed.collective._c_identity(
x, group=self.model_parallel_group)
else:
input_parallel = x

output_parallel = F.linear(
input_parallel, self.weight, self.bias, name=self.name)
if self.gather_output:
input_parallel, self.weight, self.bias, name=self._name)

if self.gather_output and self.is_mp:
output = paddle.distributed.collective._c_concat(
output_parallel,
nranks=self.world_size,
@@ -155,37 +198,49 @@ def __init__(self,
self.input_is_parallel = input_is_parallel
self._weight_attr = weight_attr
self._dtype = self._helper.get_default_dtype()
self.name = name
self._name = name

self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
)
self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
)
self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()

self.is_mp = (self.world_size > 1)
assert in_features % self.world_size == 0, (
"Number of row of the weight for linear ({}) must be"
" divisible by model parallel size ({})".format(in_features,
self.world_size))

self.input_size_per_partition = in_features // self.world_size

self.weight = self.create_parameter(
shape=[self.input_size_per_partition, self.out_features],
attr=self._weight_attr,
dtype=self._dtype)
if self.is_mp:
with get_rng_state_tracker().rng_state():
self.weight = self.create_parameter(
shape=[self.input_size_per_partition, self.out_features],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)
else:
self.weight = self.create_parameter(
shape=[self.input_size_per_partition, self.out_features],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)

self.weight.is_distributed = True

if has_bias:
self.bias = self.create_parameter(
shape=[self.out_features],
attr=paddle.nn.initializer.Constant(value=0.0),
dtype=self._dtype)
dtype=self._dtype,
is_bias=True)
else:
self.bias = None

def forward(self, x):
if self.input_is_parallel:
if self.input_is_parallel or (not self.is_mp):
input_parallel = x
else:
# split last dim
@@ -195,12 +250,16 @@ def forward(self, x):
nranks=self.world_size,
group=self.model_parallel_group)

output_parallel = F.linear(input_parallel, self.weight, name=self.name)
output_ = paddle.distributed.collective._mp_allreduce(
output_parallel,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True)
output_parallel = F.linear(input_parallel, self.weight, name=self._name)

if self.is_mp:
output_ = paddle.distributed.collective._mp_allreduce(
output_parallel,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True)
else:
output_ = output_parallel

output = output_ + self.bias if self.bias is not None else output_
return output
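
For orientation (not part of the diff), the shard shapes implied by the assertions above, with hypothetical sizes and world_size = 4:

# ColumnParallelLinear(in_features=1024, out_features=4096)
#   -> per-rank weight [1024, 4096 // 4] = [1024, 1024]; gather_output concatenates
#      the per-rank outputs along the last dimension via _c_concat.
# RowParallelLinear(in_features=1024, out_features=4096)
#   -> per-rank weight [1024 // 4, 4096] = [256, 4096]; partial results are summed
#      with _mp_allreduce.
# In both cases the shard is now created under get_rng_state_tracker().rng_state(),
# so each rank initializes its slice from its own reproducible random stream.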
@@ -14,6 +14,7 @@

import paddle
import contextlib
import numpy as np

__all__ = []

@@ -65,14 +66,18 @@ def get_rng_state_tracker():
return RNG_STATE_TRACKER


def model_parallel_random_seed(seed=2048):
def model_parallel_random_seed(seed=None):
import paddle.distributed.fleet as fleet
hcg = fleet.get_hybrid_communicate_group()
rank = hcg.get_model_parallel_rank()

local_seed = seed + 1024 + rank
global_seed = seed
if seed:
global_seed = seed
local_seed = seed * 1024 + rank * 100
else:
global_seed = np.random.randint(0, 655350)
local_seed = np.random.randint(rank * 10000, (rank + 1) * 10000 - 1)

RNG_STATE_TRACKER.reset()
paddle.seed(global_seed)
RNG_STATE_TRACKER.add(MODEL_PARALLEL_RNG, local_seed)
paddle.seed(global_seed)
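
A worked example of the new seed derivation (hypothetical seed=123, two model-parallel ranks):

# global_seed = 123                      -> paddle.seed(123) on every rank
# rank 0: local_seed = 123 * 1024 + 0   = 125952  -> MODEL_PARALLEL_RNG stream
# rank 1: local_seed = 123 * 1024 + 100 = 126052
# With seed=None, both seeds are drawn from numpy instead (local_seed from the
# per-rank range [rank * 10000, (rank + 1) * 10000 - 1)), so initialization is
# randomized rather than reproducible.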