Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No recompute layers + Opt tensor fuse for sharding #752

Merged
merged 10 commits into from
Sep 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ppfleetx/configs/nlp/gpt/pretrain_gpt_1.3B_dp8.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Model:
initializer_range: 0.02
use_recompute: True
recompute_granularity:
no_recompute_layers:


Distributed:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Model:
initializer_range: 0.02
use_recompute: True
recompute_granularity:
no_recompute_layers:


Distributed:
Expand Down
1 change: 1 addition & 0 deletions ppfleetx/configs/nlp/gpt/pretrain_gpt_175B_mp8_pp16.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Model:
initializer_range: 0.02
use_recompute: True
recompute_granularity:
no_recompute_layers:
virtual_pp_degree: 1


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Model:
initializer_range: 0.02
use_recompute: False
recompute_granularity:
no_recompute_layers:


Distributed:
Expand Down
1 change: 1 addition & 0 deletions ppfleetx/configs/nlp/gpt/pretrain_gpt_6.7B_sharding16.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Model:
initializer_range: 0.02
use_recompute: True
recompute_granularity:
no_recompute_layers:


Distributed:
Expand Down
52 changes: 38 additions & 14 deletions ppfleetx/models/language_model/gpt/dygraph/hybrid_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ def __init__(self,
fused_linear=False,
use_recompute=False,
recompute_granularity="full",
sequence_parallel=False):
sequence_parallel=False,
do_recompute=True):
super(MultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
Expand All @@ -102,6 +103,7 @@ def __init__(self,
self.fuse = fuse
self.use_recompute = use_recompute
self.recompute_granularity = recompute_granularity
self.do_recompute = do_recompute
self.sequence_parallel = sequence_parallel

if sequence_parallel:
Expand Down Expand Up @@ -329,7 +331,7 @@ def forward(self,
q, k, v, cache = self._prepare_qkv(query, key, value,
use_cache, cache)

if self.use_recompute and self.recompute_granularity == "core_attn":
if self.use_recompute and self.recompute_granularity == "core_attn" and self.do_recompute:
out, weights = recompute(self.core_attn, q, k, v, attn_mask)
else:
out, weights = self.core_attn(q, k, v, attn_mask=attn_mask)
Expand Down Expand Up @@ -359,9 +361,14 @@ def __init__(self,
hidden_size=None,
use_recompute=False,
recompute_granularity="full",
sequence_parallel=False):
sequence_parallel=False,
no_recompute_layers=None):
super(TransformerDecoder, self).__init__()

if no_recompute_layers is None:
no_recompute_layers = []
self.no_recompute_layers = no_recompute_layers

self.num_layers = num_layers
self.layers = decoder_layers
self.norm = norm
Expand Down Expand Up @@ -403,7 +410,7 @@ def forward(self,
cache=cache)
new_caches.append(new_cache)
else:
if self.use_recompute and self.recompute_granularity == "full":
if self.use_recompute and self.recompute_granularity == "full" and i not in self.no_recompute_layers:
output = recompute(mod, output, memory, tgt_mask,
use_cache, cache)
else:
Expand Down Expand Up @@ -459,7 +466,8 @@ def __init__(self,
recompute_attn=False,
use_recompute=False,
recompute_granularity="full",
sequence_parallel=False):
sequence_parallel=False,
do_recompute=True):
self._config = locals()
self._config.pop("self")
self._config.pop("__class__", None) # py3
Expand All @@ -471,6 +479,7 @@ def __init__(self,
self.use_recompute = use_recompute
self.recompute_granularity = recompute_granularity
self.sequence_parallel = sequence_parallel
self.do_recompute = do_recompute

if sequence_parallel:
ColumnParallelLinear = ColumnSequenceParallelLinear
Expand All @@ -492,7 +501,8 @@ def __init__(self,
fused_linear=fused_linear,
use_recompute=use_recompute,
recompute_granularity=recompute_granularity,
sequence_parallel=sequence_parallel)
sequence_parallel=sequence_parallel,
do_recompute=do_recompute)

self.linear1 = ColumnParallelLinear(
d_model,
Expand Down Expand Up @@ -534,7 +544,7 @@ def forward(self,
tgt = self.norm1(tgt)

if use_cache is False:
if self.use_recompute and self.recompute_granularity == "full_attn":
if self.use_recompute and self.recompute_granularity == "full_attn" and self.do_recompute:
tgt = recompute(self.self_attn, tgt, None, None, tgt_mask,
use_cache, cache)
else:
Expand Down Expand Up @@ -638,10 +648,13 @@ def __init__(self,
use_recompute=False,
fused_linear=False,
recompute_granularity="full",
sequence_parallel=False):
sequence_parallel=False,
no_recompute_layers=None):

super(GPTModelHybrid, self).__init__()

if no_recompute_layers is None:
no_recompute_layers = []
self.initializer_range = initializer_range
self.hidden_size = hidden_size
self.vocab_size = vocab_size
Expand Down Expand Up @@ -676,7 +689,8 @@ def __init__(self,
fused_linear=fused_linear,
use_recompute=use_recompute,
recompute_granularity=recompute_granularity,
sequence_parallel=sequence_parallel))
sequence_parallel=sequence_parallel,
do_recompute=i not in no_recompute_layers))

self.decoder = TransformerDecoder(
decoder_layers,
Expand All @@ -685,7 +699,8 @@ def __init__(self,
hidden_size=hidden_size,
use_recompute=use_recompute,
recompute_granularity=recompute_granularity,
sequence_parallel=sequence_parallel)
sequence_parallel=sequence_parallel,
no_recompute_layers=no_recompute_layers)

def forward(self,
input_ids,
Expand Down Expand Up @@ -884,11 +899,19 @@ def __init__(self,
fused_linear=False,
recompute_granularity="full",
virtual_pp_degree=1,
sequence_parallel=False):
sequence_parallel=False,
no_recompute_layers=None):

# forward desc
self.descs = []

if no_recompute_layers is None:
no_recompute_layers = []
else:
if recompute_granularity == 'full':
assert len(no_recompute_layers) == 0, \
"for pp with full recompute, no_recompute_layers is not support"

assert sequence_parallel is False, "Sequence parallel strategy \
is not supported in GPTForPretrainingPipe model now."

Expand All @@ -904,7 +927,7 @@ def __init__(self,
type_vocab_size=type_vocab_size,
initializer_range=0.02))

for _ in range(num_layers):
for i in range(num_layers):
self.descs.append(
LayerDesc(
TransformerDecoderLayer,
Expand All @@ -922,7 +945,8 @@ def __init__(self,
num_partitions=num_partitions,
fused_linear=fused_linear,
use_recompute=use_recompute,
recompute_granularity=recompute_granularity))
recompute_granularity=recompute_granularity,
do_recompute=i not in no_recompute_layers))

self.descs.append(
LayerDesc(
Expand Down Expand Up @@ -957,7 +981,7 @@ def _logits_helper(embedding, output):
recompute_ctx={
"mp_group": fleet.fleet._hcg.get_model_parallel_group(),
"offload": False,
"partition": False
"partition": False,
},
num_virtual_pipeline_stages=virtual_pp_degree)

Expand Down
36 changes: 26 additions & 10 deletions ppfleetx/models/language_model/gpt/dygraph/single_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def __init__(self,
fuse=True,
fused_linear=False,
use_recompute=False,
recompute_granularity="full"):
recompute_granularity="full",
do_recompute=True):
super(MultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
Expand All @@ -74,6 +75,7 @@ def __init__(self,
self.fuse = fuse
self.use_recompute = use_recompute
self.recompute_granularity = recompute_granularity
self.do_recompute = do_recompute

self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
Expand Down Expand Up @@ -240,7 +242,7 @@ def forward(self,
q, k, v, cache = self._prepare_qkv(query, key, value,
use_cache, cache)

if self.use_recompute and self.recompute_granularity == "core_attn":
if self.use_recompute and self.recompute_granularity == "core_attn" and self.do_recompute:
out, weights = recompute(self.core_attn, q, k, v, attn_mask)
else:
out, weights = self.core_attn(q, k, v, attn_mask=attn_mask)
Expand All @@ -267,9 +269,14 @@ def __init__(self,
norm=None,
hidden_size=None,
use_recompute=False,
recompute_granularity="full"):
recompute_granularity="full",
no_recompute_layers=None):
super(TransformerDecoder, self).__init__()

if no_recompute_layers is None:
no_recompute_layers = []
self.no_recompute_layers = no_recompute_layers

self.num_layers = num_layers
self.layers = decoder_layers
self.norm = norm
Expand Down Expand Up @@ -305,7 +312,7 @@ def forward(self,
cache=cache)
new_caches.append(new_cache)
else:
if self.use_recompute and self.recompute_granularity == "full":
if self.use_recompute and self.recompute_granularity == "full" and i not in self.no_recompute_layers:
output = recompute(mod, output, memory, tgt_mask,
use_cache, cache)
else:
Expand Down Expand Up @@ -357,7 +364,8 @@ def __init__(self,
bias_attr=None,
fused_linear=False,
use_recompute=False,
recompute_granularity="full"):
recompute_granularity="full",
do_recompute=True):
self._config = locals()
self._config.pop("self")
self._config.pop("__class__", None) # py3
Expand All @@ -368,6 +376,7 @@ def __init__(self,
self.normalize_before = normalize_before
self.use_recompute = use_recompute
self.recompute_granularity = recompute_granularity
self.do_recompute = do_recompute

weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
bias_attrs = _convert_param_attr_to_list(bias_attr, 3)
Expand All @@ -382,7 +391,8 @@ def __init__(self,
bias_attr=bias_attrs[0],
fused_linear=fused_linear,
use_recompute=use_recompute,
recompute_granularity=recompute_granularity)
recompute_granularity=recompute_granularity,
do_recompute=do_recompute)
self.linear1 = Linear(
d_model, dim_feedforward, weight_attrs[2], bias_attr=bias_attrs[2])
self.linear2 = Linear(
Expand All @@ -401,7 +411,7 @@ def forward(self, tgt, memory, tgt_mask=None, use_cache=False, cache=None):
tgt = self.norm1(tgt)

if use_cache is False:
if self.use_recompute and self.recompute_granularity == "full_attn":
if self.use_recompute and self.recompute_granularity == "full_attn" and self.do_recompute:
tgt = recompute(self.self_attn, tgt, None, None, tgt_mask,
use_cache, cache)
else:
Expand Down Expand Up @@ -487,9 +497,13 @@ def __init__(self,
initializer_range=0.02,
fused_linear=False,
recompute_granularity="full",
sequence_parallel=False):
sequence_parallel=False,
no_recompute_layers=None):

super(GPTModel, self).__init__()

if no_recompute_layers is None:
no_recompute_layers = []
self.initializer_range = initializer_range
self.hidden_size = hidden_size
self.vocab_size = vocab_size
Expand All @@ -515,15 +529,17 @@ def __init__(self,
bias_attr=None,
fused_linear=fused_linear,
use_recompute=use_recompute,
recompute_granularity=recompute_granularity))
recompute_granularity=recompute_granularity,
do_recompute=i not in no_recompute_layers))

self.decoder = TransformerDecoder(
decoder_layers,
num_layers,
norm="LayerNorm",
hidden_size=hidden_size,
use_recompute=use_recompute,
recompute_granularity=recompute_granularity)
recompute_granularity=recompute_granularity,
no_recompute_layers=no_recompute_layers)

def forward(self,
input_ids,
Expand Down
15 changes: 14 additions & 1 deletion ppfleetx/models/language_model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,17 @@ def process_model_configs(config):
if configs['use_recompute']:
if not configs['recompute_granularity']:
configs['recompute_granularity'] = 'full'
if not configs['no_recompute_layers']:
configs['no_recompute_layers'] = []
else:
assert isinstance(configs['no_recompute_layers'], list), "no_recompute_layers should be a list"
for i in configs['no_recompute_layers']:
assert isinstance(i, int), "all values in no_recompute_layers should be an integer"
assert min(configs['no_recompute_layers']) >= 0, \
"the min value in no_recompute_layers should >= 0"
assert max(configs['no_recompute_layers']) < configs['num_layers'], \
"the max value in no_recompute_layers should < num_layers"
configs['no_recompute_layers'] = sorted(list(set(configs['no_recompute_layers'])))

if configs['fused_linear'] and not is_fused_matmul_bias_supported():
configs['fused_linear'] = False
Expand Down Expand Up @@ -110,8 +121,10 @@ def process_optim_configs(config):

nranks = dist.get_world_size()
dp_degree = config['Distributed']['dp_degree']
sharding_degree = config['Distributed']['sharding']['sharding_degree']
if config['Optimizer']['tensor_fusion']:
assert nranks == dp_degree, "tensor_fusion only support single card train or data parallel train"
assert nranks == dp_degree * sharding_degree, \
"tensor_fusion only support single card train or data/sharding parallel train"


def process_data_configs(config):
Expand Down
7 changes: 6 additions & 1 deletion ppfleetx/optims/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import sys
import paddle
import paddle.distributed.fleet as fleet

from ppfleetx.utils.tensor_fusion_helper import fused_parameters
from paddle.optimizer import Adam, AdamW, Momentum
Expand All @@ -30,9 +31,13 @@ class FusedAdamW(paddle.optimizer.AdamW):
def __init__(self, learning_rate, parameters, grad_clip, **config):
tensor_fusion = config.pop("tensor_fusion", False)

if paddle.distributed.get_world_size() > 1:
hcg = fleet.get_hybrid_communicate_group()
sharding_size = hcg.get_sharding_parallel_world_size()

if tensor_fusion:
self.decay_fused_tensors, self.all_fused_tensors = fused_parameters(
parameters)
parameters, sharding_size > 1)
decay_params = [p.name for p in self.decay_fused_tensors]
else:
decay_params = [
Expand Down
Loading