Cherry-pick from PR 906, fix conflicts (#946)
Co-authored-by: Liujie0926 <44688141+Liujie0926@users.noreply.github.com>
carryyu and Liujie0926 authored Nov 28, 2022
1 parent 6ffc799 commit c29cc2b
Showing 7 changed files with 853 additions and 48 deletions.
47 changes: 47 additions & 0 deletions ppfleetx/configs/nlp/gpt/auto/generation_gpt_345M_single_card.yaml
@@ -0,0 +1,47 @@
_base_: ./pretrain_gpt_base.yaml


Engine:
  mix_precision:
    level: "o2"
    scale_loss: 32768.0
    custom_black_list: ["reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div", "where"]
    custom_white_list: ["lookup_table", "lookup_table_v2"]
    use_fp16_guard: False


Generation:
  top_k: 0
  top_p: 0.9
  use_topp_sampling: True
  inference: True
  temperature: 1.0
  min_dec_len: 8
  max_dec_len: 8
  num_return_sequences: 1
  decode_strategy: "sampling"


Model:
  module: GPTGenerationModuleAuto
  vocab_size: 50304
  hidden_size: 1024
  num_layers: 24
  num_attention_heads: 16
  ffn_hidden_size: 4096
  hidden_dropout_prob: 0.1
  attention_probs_dropout_prob: 0.1
  max_position_embeddings: 1024
  type_vocab_size: 16
  initializer_range: 0.02
  use_recompute: False
  fuse_attn_qkv: True


Distributed:
  dp_degree: 1
  mp_degree: 1
  pp_degree: 1
  sharding:
    sharding_degree: 1
    sharding_stage: 1
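
For reference, a minimal sketch of reading this new YAML directly with PyYAML and pulling out the Generation block that drives the sampling changes below. This is an illustration only; the actual ppfleetx config loader also merges the _base_ file and applies defaults.

# Illustration: read the generation config with PyYAML (not the ppfleetx loader).
import yaml

with open("ppfleetx/configs/nlp/gpt/auto/generation_gpt_345M_single_card.yaml") as f:
    cfg = yaml.safe_load(f)

gen = cfg["Generation"]
# top_k == 0 together with use_topp_sampling: True means pure nucleus (top-p) sampling.
print(gen["top_p"], gen["max_dec_len"], gen["use_topp_sampling"])  # 0.9 8 True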
102 changes: 78 additions & 24 deletions ppfleetx/models/language_model/gpt/auto/auto_model.py
@@ -709,6 +709,8 @@ def __init__(self, gpt, configs):
self.temperature = self.configs.get('temperature', 1.0)
self.top_k = self.configs.get('top_k', 0)
self.top_p = self.configs.get('top_p', 1.0)
self.use_topp_sampling = self.configs.get('use_topp_sampling', False)
self.inference = self.configs.get('inference', False)
self.repetition_penalty = self.configs.get('repetition_penalty', 1.0)
self.num_beams = self.configs.get('num_beams', 1)
self.num_beam_groups = self.configs.get('num_beam_groups', 1)
@@ -852,10 +854,6 @@ def prepare_inputs_for_generation(self,
if "int" in paddle.common_ops_import.convert_dtype(
attention_mask.dtype):
attention_mask = (1.0 - attention_mask) * -1e4
if cache is not None:
input_ids = input_ids[:, -1].unsqueeze(-1)
if position_ids is not None:
position_ids = position_ids[:, -1].unsqueeze(-1)
return {
"input_ids": input_ids,
"position_ids": position_ids,
@@ -864,6 +862,7 @@ def update_model_kwargs_for_generation(self,
}

def update_model_kwargs_for_generation(self,
next_tokens,
outputs,
model_kwargs,
is_encoder_decoder=False):
@@ -888,8 +887,7 @@ def update_model_kwargs_for_generation(self,
if "position_ids" in model_kwargs and model_kwargs[
"position_ids"] is not None:
position_ids = model_kwargs["position_ids"]
model_kwargs["position_ids"] = paddle.concat(
[position_ids, position_ids[:, -1:] + 1], axis=-1)
model_kwargs["position_ids"] = position_ids[:, -1:] + 1

# update attention_mask
if not is_encoder_decoder and "attention_mask" in model_kwargs:
@@ -926,6 +924,9 @@ def update_model_kwargs_for_generation(self,
model_kwargs["role_ids"] = paddle.concat(
[role_ids, role_ids[:, -1:]], axis=-1)

model_kwargs['res'] = paddle.concat(
[model_kwargs['res'], next_tokens], axis=1)

return model_kwargs

def sample(self,
@@ -977,11 +978,20 @@ def TopPProcess(probs, top_p, min_tokens_to_keep):
return probs

batch_size, cur_len = paddle.shape(input_ids)
# used for compute on gpu, avoid memcpy D2H
cur_len_gpu = paddle.full([1], cur_len, dtype='int64')

origin_len = paddle.shape(input_ids)[1]
# used for compute on gpu, avoid memcpy D2H
origin_len_gpu = paddle.full([1], origin_len, dtype='int64')

unfinished_flag = paddle.full([batch_size, 1], True, dtype='bool')
scores = paddle.full(
[batch_size, 1], 0.0, dtype=paddle.get_default_dtype())

res = paddle.assign(input_ids)
model_kwargs['res'] = res

# use_cache is immutable, we split it off other mutable kwargs.
assert 'use_cache' in model_kwargs
immutable = {'use_cache': model_kwargs['use_cache']}
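
The cur_len_gpu / origin_len_gpu tensors introduced above keep the loop counters on the device. A minimal sketch of the idea, assuming the standard paddle.full and paddle.increment APIs:

# Sketch: hold the step counter as a 1-element int64 tensor and bump it with
# paddle.increment; per the note in the diff, this avoids the synchronization
# that the equivalent Python-side `cur_len += 1` causes under dy2static.
import paddle

cur_len = paddle.full([1], 5, dtype='int64')   # e.g. a prompt of length 5
for _ in range(3):
    paddle.increment(cur_len)                  # in-place add of 1, stays on device
print(cur_len.numpy())                         # [8]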
@@ -1021,15 +1031,33 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores,

# sample
origin_probs = F.softmax(logits)
origin_probs = paddle.log(origin_probs)
if temperature is not None and temperature != 1.0:
if temperature is None or temperature == 1.0:
probs = paddle.assign(origin_probs)
origin_probs = paddle.log(origin_probs)
else:
origin_probs = paddle.log(origin_probs)
logits = logits / temperature
probs = F.softmax(logits)
probs = F.softmax(logits)
if top_k is not None and top_k != 0:
probs = TopKProcess(probs, top_k, min_tokens_to_keep)
if top_p is not None and top_p < 1.0:
probs = TopPProcess(probs, top_p, min_tokens_to_keep)
next_tokens = paddle.multinomial(probs)
if self.use_topp_sampling:
try:
from ppfleetx_ops import topp_sampling
except ImportError:
raise ImportError(
"please install ppfleetx_ops by 'cd ppfleetx/ops && python setup_cuda.py install'!"
)
top_ps_tensor = paddle.full(
shape=[paddle.shape(probs)[0]],
fill_value=top_p,
dtype=probs.dtype)
next_tokens = topp_sampling(probs, top_ps_tensor)
else:
probs = TopPProcess(probs, top_p, min_tokens_to_keep)

if not self.use_topp_sampling:
next_tokens = paddle.multinomial(probs)

next_scores = paddle.index_sample(origin_probs, next_tokens)
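
When use_topp_sampling is disabled, the TopPProcess/paddle.multinomial fallback above is standard nucleus sampling. A plain-Paddle reference sketch of that computation (an illustration, not the fused ppfleetx_ops.topp_sampling kernel, whose exact semantics live in the CUDA op):

# Reference sketch of nucleus (top-p) sampling in plain Paddle: keep the
# smallest prefix of probability-sorted tokens whose cumulative mass reaches
# top_p, zero out the rest, and sample from the remaining weights.
import paddle
import paddle.nn.functional as F

def topp_sample_reference(logits, top_p=0.9):
    probs = F.softmax(logits, axis=-1)
    sorted_probs, sorted_idx = paddle.topk(probs, k=probs.shape[-1], axis=-1)
    cum_probs = paddle.cumsum(sorted_probs, axis=-1)
    # drop a token once the mass accumulated before it already exceeds top_p
    drop = (cum_probs - sorted_probs) > top_p
    kept = paddle.where(drop, paddle.zeros_like(sorted_probs), sorted_probs)
    # paddle.multinomial accepts unnormalized non-negative weights, matching
    # how the TopPProcess + multinomial path above uses it
    choice = paddle.multinomial(kept)                       # [batch, 1]
    return paddle.take_along_axis(sorted_idx, choice, axis=-1)

next_tokens = topp_sample_reference(paddle.randn([2, 50304]), top_p=0.9)
print(next_tokens.shape)                                    # [2, 1]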

@@ -1041,13 +1069,14 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores,
scores = self.update_scores_for_generation(
scores, next_scores, cur_len - origin_len, unfinished_flag)

input_ids = paddle.concat([input_ids, next_tokens], axis=1)
input_ids = next_tokens

if eos_token_id is not None:
unfinished_flag = paddle.logical_and(
unfinished_flag, next_tokens != eos_token_id)

model_kwargs = self.update_model_kwargs_for_generation(
next_tokens,
outputs,
model_kwargs,
is_encoder_decoder=self.is_encoder_decoder)
@@ -1059,9 +1088,14 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores,
outputs = _forward_(**model_kwargs)

input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
outputs, input_ids, cur_len, origin_len, scores, unfinished_flag,
model_kwargs)
cur_len += 1
outputs, input_ids, cur_len_gpu, origin_len_gpu, scores,
unfinished_flag, model_kwargs)
if not self.inference:
cur_len += 1
else:
# Note(ZhenyuLi): Avoid the synchronization caused by scale in dy2static
paddle.increment(cur_len)
paddle.increment(cur_len_gpu)

attn_mask = model_kwargs['attention_mask']
# make the shape of attention_mask = (-1, -1, -1, -1) in dy2static.
@@ -1075,14 +1109,19 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores,
# and change it to pass directly to _post_process_ to avoid
# closed-loop problem of dynamic-to-static model
input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
_forward_(**model_kwargs), input_ids, cur_len, origin_len,
scores, unfinished_flag, model_kwargs)
cur_len += 1
_forward_(**model_kwargs), input_ids, cur_len_gpu,
origin_len_gpu, scores, unfinished_flag, model_kwargs)
if not self.inference:
cur_len += 1
else:
# Note(ZhenyuLi): Avoid the synchronization caused by scale in dy2static
paddle.increment(cur_len)
paddle.increment(cur_len_gpu)

if not paddle.any(unfinished_flag):
break

return input_ids[:, origin_len:], scores
return model_kwargs['res'][:, origin_len:], scores
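
The changed return value above (model_kwargs['res'] instead of input_ids) goes together with feeding only the newest token each step: input_ids now holds just the latest token while res accumulates the full sequence. A toy sketch of that bookkeeping, where fake_decode_step is a hypothetical stand-in for the cached model forward:

# Toy sketch: input_ids carries only the latest token (the KV cache supplies
# the history), while `res` accumulates the transcript that gets returned.
import paddle

def fake_decode_step(last_token):
    return last_token + 1                                  # hypothetical stand-in

prompt = paddle.to_tensor([[11, 12, 13]], dtype='int64')   # [batch, prompt_len]
res = paddle.assign(prompt)                                # running result buffer
input_ids = prompt
for _ in range(4):                                         # max_dec_len steps
    next_tokens = fake_decode_step(input_ids[:, -1:])      # feed last token only
    res = paddle.concat([res, next_tokens], axis=1)
    input_ids = next_tokens
print(res[:, prompt.shape[1]:].numpy())                    # [[14 15 16 17]]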

def forward(self, input_ids=None, **model_kwargs):

@@ -1136,16 +1175,31 @@ def forward(self, input_ids=None, **model_kwargs):
model_kwargs[
"attention_mask"] = self.prepare_attention_mask_for_generation(
input_ids, pad_token_id, eos_token_id)

if model_kwargs.get("position_ids", None) is None:
model_kwargs['position_ids'] = paddle.arange(
0,
paddle.shape(model_kwargs['attention_mask'])[-1],
dtype=input_ids.dtype).unsqueeze(0)

self.is_encoder_decoder = False

model_kwargs["use_cache"] = use_cache

max_length += paddle.shape(input_ids)[-1]
min_length += paddle.shape(input_ids)[-1]
if self.inference:
# Note(ZhenyuLi): Avoid the synchronization caused by scale in dy2static
min_len = input_ids.shape[-1]
max_len = input_ids.shape[-1]
paddle.increment(min_len, min_length)
paddle.increment(max_len, max_length)
else:
input_len = input_ids.shape[-1]
max_len = max_length + input_len
min_len = min_length + input_len

logits_processors = self.get_logits_processor(
min_length=min_length,
max_length=max_length,
min_length=min_len,
max_length=max_len,
eos_token_id=eos_token_id,
forced_bos_token_id=forced_bos_token_id,
forced_eos_token_id=forced_eos_token_id,
Expand All @@ -1161,7 +1215,7 @@ def forward(self, input_ids=None, **model_kwargs):
expand_size=num_return_sequences,
**model_kwargs)

ret = self.sample(input_ids, logits_processors, max_length,
ret = self.sample(input_ids, logits_processors, max_len,
pad_token_id, eos_token_id, top_k, top_p,
temperature, **model_kwargs)
else:
(The diffs for the remaining 5 changed files were not loaded in this view.)
