diff --git a/ppfleetx/configs/nlp/gpt/auto/generation_gpt_345M_single_card.yaml b/ppfleetx/configs/nlp/gpt/auto/generation_gpt_345M_single_card.yaml new file mode 100644 index 000000000..a3728aa8b --- /dev/null +++ b/ppfleetx/configs/nlp/gpt/auto/generation_gpt_345M_single_card.yaml @@ -0,0 +1,47 @@ +_base_: ./pretrain_gpt_base.yaml + + +Engine: + mix_precision: + level: "o2" + scale_loss: 32768.0 + custom_black_list: ["reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div", "where"] + custom_white_list: ["lookup_table", "lookup_table_v2"] + use_fp16_guard: False + + +Generation: + top_k: 0 + top_p: 0.9 + use_topp_sampling: True + inference: True + temperature: 1.0 + min_dec_len: 8 + max_dec_len: 8 + num_return_sequences: 1 + decode_strategy: "sampling" + + +Model: + module: GPTGenerationModuleAuto + vocab_size: 50304 + hidden_size: 1024 + num_layers: 24 + num_attention_heads: 16 + ffn_hidden_size: 4096 + hidden_dropout_prob: 0.1 + attention_probs_dropout_prob: 0.1 + max_position_embeddings: 1024 + type_vocab_size: 16 + initializer_range: 0.02 + use_recompute: False + fuse_attn_qkv: True + + +Distributed: + dp_degree: 1 + mp_degree: 1 + pp_degree: 1 + sharding: + sharding_degree: 1 + sharding_stage: 1 diff --git a/ppfleetx/models/language_model/gpt/auto/auto_model.py b/ppfleetx/models/language_model/gpt/auto/auto_model.py index f8ce8d0a4..004ec9640 100644 --- a/ppfleetx/models/language_model/gpt/auto/auto_model.py +++ b/ppfleetx/models/language_model/gpt/auto/auto_model.py @@ -709,6 +709,8 @@ def __init__(self, gpt, configs): self.temperature = self.configs.get('temperature', 1.0) self.top_k = self.configs.get('top_k', 0) self.top_p = self.configs.get('top_p', 1.0) + self.use_topp_sampling = self.configs.get('use_topp_sampling', False) + self.inference = self.configs.get('inference', False) self.repetition_penalty = self.configs.get('repetition_penalty', 1.0) self.num_beams = self.configs.get('num_beams', 1) self.num_beam_groups = self.configs.get('num_beam_groups', 1) @@ -852,10 +854,6 @@ def prepare_inputs_for_generation(self, if "int" in paddle.common_ops_import.convert_dtype( attention_mask.dtype): attention_mask = (1.0 - attention_mask) * -1e4 - if cache is not None: - input_ids = input_ids[:, -1].unsqueeze(-1) - if position_ids is not None: - position_ids = position_ids[:, -1].unsqueeze(-1) return { "input_ids": input_ids, "position_ids": position_ids, @@ -864,6 +862,7 @@ def prepare_inputs_for_generation(self, } def update_model_kwargs_for_generation(self, + next_tokens, outputs, model_kwargs, is_encoder_decoder=False): @@ -888,8 +887,7 @@ def update_model_kwargs_for_generation(self, if "position_ids" in model_kwargs and model_kwargs[ "position_ids"] is not None: position_ids = model_kwargs["position_ids"] - model_kwargs["position_ids"] = paddle.concat( - [position_ids, position_ids[:, -1:] + 1], axis=-1) + model_kwargs["position_ids"] = position_ids[:, -1:] + 1 # update attention_mask if not is_encoder_decoder and "attention_mask" in model_kwargs: @@ -926,6 +924,9 @@ def update_model_kwargs_for_generation(self, model_kwargs["role_ids"] = paddle.concat( [role_ids, role_ids[:, -1:]], axis=-1) + model_kwargs['res'] = paddle.concat( + [model_kwargs['res'], next_tokens], axis=1) + return model_kwargs def sample(self, @@ -977,11 +978,20 @@ def TopPProcess(probs, top_p, min_tokens_to_keep): return probs batch_size, cur_len = paddle.shape(input_ids) + # used for compute on gpu, avoid memcpy D2H + cur_len_gpu = paddle.full([1], cur_len, dtype='int64') + origin_len = 
paddle.shape(input_ids)[1] + # used for compute on gpu, avoid memcpy D2H + origin_len_gpu = paddle.full([1], origin_len, dtype='int64') + unfinished_flag = paddle.full([batch_size, 1], True, dtype='bool') scores = paddle.full( [batch_size, 1], 0.0, dtype=paddle.get_default_dtype()) + res = paddle.assign(input_ids) + model_kwargs['res'] = res + # use_cache is immutable, we split it off other mutable kwargs. assert 'use_cache' in model_kwargs immutable = {'use_cache': model_kwargs['use_cache']} @@ -1021,15 +1031,33 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, # sample origin_probs = F.softmax(logits) - origin_probs = paddle.log(origin_probs) - if temperature is not None and temperature != 1.0: + if temperature is None or temperature == 1.0: + probs = paddle.assign(origin_probs) + origin_probs = paddle.log(origin_probs) + else: + origin_probs = paddle.log(origin_probs) logits = logits / temperature - probs = F.softmax(logits) + probs = F.softmax(logits) if top_k is not None and top_k != 0: probs = TopKProcess(probs, top_k, min_tokens_to_keep) if top_p is not None and top_p < 1.0: - probs = TopPProcess(probs, top_p, min_tokens_to_keep) - next_tokens = paddle.multinomial(probs) + if self.use_topp_sampling: + try: + from ppfleetx_ops import topp_sampling + except ImportError: + raise ImportError( + "please install ppfleetx_ops by 'cd ppfleetx/ops && python setup_cuda.py install'!" + ) + top_ps_tensor = paddle.full( + shape=[paddle.shape(probs)[0]], + fill_value=top_p, + dtype=probs.dtype) + next_tokens = topp_sampling(probs, top_ps_tensor) + else: + probs = TopPProcess(probs, top_p, min_tokens_to_keep) + + if not self.use_topp_sampling: + next_tokens = paddle.multinomial(probs) next_scores = paddle.index_sample(origin_probs, next_tokens) @@ -1041,13 +1069,14 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, scores = self.update_scores_for_generation( scores, next_scores, cur_len - origin_len, unfinished_flag) - input_ids = paddle.concat([input_ids, next_tokens], axis=1) + input_ids = next_tokens if eos_token_id is not None: unfinished_flag = paddle.logical_and( unfinished_flag, next_tokens != eos_token_id) model_kwargs = self.update_model_kwargs_for_generation( + next_tokens, outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder) @@ -1059,9 +1088,14 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, outputs = _forward_(**model_kwargs) input_ids, scores, unfinished_flag, model_kwargs = _post_process_( - outputs, input_ids, cur_len, origin_len, scores, unfinished_flag, - model_kwargs) - cur_len += 1 + outputs, input_ids, cur_len_gpu, origin_len_gpu, scores, + unfinished_flag, model_kwargs) + if not self.inference: + cur_len += 1 + else: + # Note(ZhenyuLi): Avoid the synchronization caused by scale in dy2static + paddle.increment(cur_len) + paddle.increment(cur_len_gpu) attn_mask = model_kwargs['attention_mask'] # make the shape of attention_mask = (-1, -1, -1, -1) in dy2static. 
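Note on the decode-loop changes above: input_ids is no longer grown by concatenation each step; only the newest token is fed back in, the full sequence is accumulated in model_kwargs['res'], and the step counter lives on the GPU and is advanced with paddle.increment so the dy2static graph avoids a device-to-host copy. A minimal, self-contained sketch of that bookkeeping (a hypothetical fake_step stands in for the real decoder forward, which in the actual code consumes the whole prompt on the first step to build the cache):

import paddle

def fake_step(last_tokens):
    # hypothetical stand-in for one decoder forward pass; returns "next" token ids
    return (last_tokens + 7) % 50304

prompt = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype='int64')
res = paddle.assign(prompt)                      # full sequence accumulated here
input_ids = prompt                               # later steps keep only the last token
cur_len_gpu = paddle.full([1], prompt.shape[1], dtype='int64')  # on-device counter
max_dec_len = 8

while True:
    next_tokens = fake_step(input_ids[:, -1:])   # feed only the newest token
    input_ids = next_tokens
    res = paddle.concat([res, next_tokens], axis=1)
    paddle.increment(cur_len_gpu)                # stays on device, unlike a Python `cur_len += 1`
    if int(cur_len_gpu) >= prompt.shape[1] + max_dec_len:  # host-side stop check, sketch only
        break

print(res)  # prompt followed by the generated ids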
@@ -1075,14 +1109,19 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, # and change it to pass directly to _post_process_ to avoid # closed-loop problem of dynamic-to-static model input_ids, scores, unfinished_flag, model_kwargs = _post_process_( - _forward_(**model_kwargs), input_ids, cur_len, origin_len, - scores, unfinished_flag, model_kwargs) - cur_len += 1 + _forward_(**model_kwargs), input_ids, cur_len_gpu, + origin_len_gpu, scores, unfinished_flag, model_kwargs) + if not self.inference: + cur_len += 1 + else: + # Note(ZhenyuLi): Avoid the synchronization caused by scale in dy2static + paddle.increment(cur_len) + paddle.increment(cur_len_gpu) if not paddle.any(unfinished_flag): break - return input_ids[:, origin_len:], scores + return model_kwargs['res'][:, origin_len:], scores def forward(self, input_ids=None, **model_kwargs): @@ -1136,16 +1175,31 @@ def forward(self, input_ids=None, **model_kwargs): model_kwargs[ "attention_mask"] = self.prepare_attention_mask_for_generation( input_ids, pad_token_id, eos_token_id) + + if model_kwargs.get("position_ids", None) is None: + model_kwargs['position_ids'] = paddle.arange( + 0, + paddle.shape(model_kwargs['attention_mask'])[-1], + dtype=input_ids.dtype).unsqueeze(0) + self.is_encoder_decoder = False model_kwargs["use_cache"] = use_cache - max_length += paddle.shape(input_ids)[-1] - min_length += paddle.shape(input_ids)[-1] + if self.inference: + # Note(ZhenyuLi): Avoid the synchronization caused by scale in dy2static + min_len = input_ids.shape[-1] + max_len = input_ids.shape[-1] + paddle.increment(min_len, min_length) + paddle.increment(max_len, max_length) + else: + input_len = input_ids.shape[-1] + max_len = max_length + input_len + min_len = min_length + input_len logits_processors = self.get_logits_processor( - min_length=min_length, - max_length=max_length, + min_length=min_len, + max_length=max_len, eos_token_id=eos_token_id, forced_bos_token_id=forced_bos_token_id, forced_eos_token_id=forced_eos_token_id, @@ -1161,7 +1215,7 @@ def forward(self, input_ids=None, **model_kwargs): expand_size=num_return_sequences, **model_kwargs) - ret = self.sample(input_ids, logits_processors, max_length, + ret = self.sample(input_ids, logits_processors, max_len, pad_token_id, eos_token_id, top_k, top_p, temperature, **model_kwargs) else: diff --git a/ppfleetx/models/language_model/gpt/dygraph/single_model.py b/ppfleetx/models/language_model/gpt/dygraph/single_model.py index 16e2ad810..84ccab544 100644 --- a/ppfleetx/models/language_model/gpt/dygraph/single_model.py +++ b/ppfleetx/models/language_model/gpt/dygraph/single_model.py @@ -799,6 +799,8 @@ def __init__(self, gpt, configs): self.temperature = self.configs.get('temperature', 1.0) self.top_k = self.configs.get('top_k', 0) self.top_p = self.configs.get('top_p', 1.0) + self.use_topp_sampling = self.configs.get('use_topp_sampling', False) + self.inference = self.configs.get('inference', False) self.repetition_penalty = self.configs.get('repetition_penalty', 1.0) self.num_beams = self.configs.get('num_beams', 1) self.num_beam_groups = self.configs.get('num_beam_groups', 1) @@ -942,10 +944,6 @@ def prepare_inputs_for_generation(self, if "int" in paddle.common_ops_import.convert_dtype( attention_mask.dtype): attention_mask = (1.0 - attention_mask) * -1e4 - if cache is not None: - input_ids = input_ids[:, -1].unsqueeze(-1) - if position_ids is not None: - position_ids = position_ids[:, -1].unsqueeze(-1) return { "input_ids": input_ids, "position_ids": position_ids, @@ -954,6 
+952,7 @@ def prepare_inputs_for_generation(self, } def update_model_kwargs_for_generation(self, + next_tokens, outputs, model_kwargs, is_encoder_decoder=False): @@ -978,8 +977,7 @@ def update_model_kwargs_for_generation(self, if "position_ids" in model_kwargs and model_kwargs[ "position_ids"] is not None: position_ids = model_kwargs["position_ids"] - model_kwargs["position_ids"] = paddle.concat( - [position_ids, position_ids[:, -1:] + 1], axis=-1) + model_kwargs["position_ids"] = position_ids[:, -1:] + 1 # update attention_mask if not is_encoder_decoder and "attention_mask" in model_kwargs: @@ -1016,6 +1014,9 @@ def update_model_kwargs_for_generation(self, model_kwargs["role_ids"] = paddle.concat( [role_ids, role_ids[:, -1:]], axis=-1) + model_kwargs['res'] = paddle.concat( + [model_kwargs['res'], next_tokens], axis=1) + return model_kwargs def sample(self, @@ -1067,11 +1068,20 @@ def TopPProcess(probs, top_p, min_tokens_to_keep): return probs batch_size, cur_len = input_ids.shape + # used for compute on gpu, avoid memcpy D2H + cur_len_gpu = paddle.full([1], cur_len, dtype='int64') + origin_len = input_ids.shape[1] + # used for compute on gpu, avoid memcpy D2H + origin_len_gpu = paddle.full([1], origin_len, dtype='int64') + unfinished_flag = paddle.full([batch_size, 1], True, dtype='bool') scores = paddle.full( [batch_size, 1], 0.0, dtype=paddle.get_default_dtype()) + res = paddle.assign(input_ids) + model_kwargs['res'] = res + # use_cache is immutable, we split it off other mutable kwargs. assert 'use_cache' in model_kwargs immutable = {'use_cache': model_kwargs['use_cache']} @@ -1100,15 +1110,33 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, # sample origin_probs = F.softmax(logits) - origin_probs = paddle.log(origin_probs) - if temperature is not None and temperature != 1.0: + if temperature is None or temperature == 1.0: + probs = paddle.assign(origin_probs) + origin_probs = paddle.log(origin_probs) + else: + origin_probs = paddle.log(origin_probs) logits = logits / temperature - probs = F.softmax(logits) + probs = F.softmax(logits) if top_k is not None and top_k != 0: probs = TopKProcess(probs, top_k, min_tokens_to_keep) if top_p is not None and top_p < 1.0: - probs = TopPProcess(probs, top_p, min_tokens_to_keep) - next_tokens = paddle.multinomial(probs) + if self.use_topp_sampling: + try: + from ppfleetx_ops import topp_sampling + except ImportError: + raise ImportError( + "please install ppfleetx_ops by 'cd ppfleetx/ops && python setup_cuda.py install'!" 
+ ) + top_ps_tensor = paddle.full( + shape=[paddle.shape(probs)[0]], + fill_value=top_p, + dtype=probs.dtype) + next_tokens = topp_sampling(probs, top_ps_tensor) + else: + probs = TopPProcess(probs, top_p, min_tokens_to_keep) + + if not self.use_topp_sampling: + next_tokens = paddle.multinomial(probs) next_scores = paddle.index_sample(origin_probs, next_tokens) @@ -1120,13 +1148,14 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, scores = self.update_scores_for_generation( scores, next_scores, cur_len - origin_len, unfinished_flag) - input_ids = paddle.concat([input_ids, next_tokens], axis=1) + input_ids = next_tokens if eos_token_id is not None: unfinished_flag = paddle.logical_and( unfinished_flag, next_tokens != eos_token_id) model_kwargs = self.update_model_kwargs_for_generation( + next_tokens, outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder) @@ -1138,9 +1167,14 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, outputs = _forward_(**model_kwargs) input_ids, scores, unfinished_flag, model_kwargs = _post_process_( - outputs, input_ids, cur_len, origin_len, scores, unfinished_flag, - model_kwargs) - cur_len += 1 + outputs, input_ids, cur_len_gpu, origin_len_gpu, scores, + unfinished_flag, model_kwargs) + if not self.inference: + cur_len += 1 + else: + # Note(ZhenyuLi): Avoid the synchronization caused by scale in dy2static + paddle.increment(cur_len) + paddle.increment(cur_len_gpu) attn_mask = model_kwargs['attention_mask'] # make the shape of attention_mask = (-1, -1, -1, -1) in dy2static. @@ -1153,14 +1187,19 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, # and change it to pass directly to _post_process_ to avoid # closed-loop problem of dynamic-to-static model input_ids, scores, unfinished_flag, model_kwargs = _post_process_( - _forward_(**model_kwargs), input_ids, cur_len, origin_len, - scores, unfinished_flag, model_kwargs) - cur_len += 1 + _forward_(**model_kwargs), input_ids, cur_len_gpu, + origin_len_gpu, scores, unfinished_flag, model_kwargs) + if not self.inference: + cur_len += 1 + else: + # Note(ZhenyuLi): Avoid the synchronization caused by scale in dy2static + paddle.increment(cur_len) + paddle.increment(cur_len_gpu) if not paddle.any(unfinished_flag): break - return input_ids[:, origin_len:], scores + return model_kwargs['res'][:, origin_len:], scores def forward(self, input_ids=None, **model_kwargs): @@ -1214,16 +1253,31 @@ def forward(self, input_ids=None, **model_kwargs): model_kwargs[ "attention_mask"] = self.prepare_attention_mask_for_generation( input_ids, pad_token_id, eos_token_id) + + if model_kwargs.get("position_ids", None) is None: + model_kwargs['position_ids'] = paddle.arange( + 0, + paddle.shape(model_kwargs['attention_mask'])[-1], + dtype=input_ids.dtype).unsqueeze(0) + self.is_encoder_decoder = False model_kwargs["use_cache"] = use_cache - max_length += input_ids.shape[-1] - min_length += input_ids.shape[-1] + if self.inference: + # Note(ZhenyuLi): Avoid the synchronization caused by scale in dy2static + min_len = input_ids.shape[-1] + max_len = input_ids.shape[-1] + paddle.increment(min_len, min_length) + paddle.increment(max_len, max_length) + else: + input_len = input_ids.shape[-1] + max_len = max_length + input_len + min_len = min_length + input_len logits_processors = self.get_logits_processor( - min_length=min_length, - max_length=max_length, + min_length=min_len, + max_length=max_len, eos_token_id=eos_token_id, forced_bos_token_id=forced_bos_token_id, 
forced_eos_token_id=forced_eos_token_id, @@ -1239,7 +1293,7 @@ def forward(self, input_ids=None, **model_kwargs): expand_size=num_return_sequences, **model_kwargs) - ret = self.sample(input_ids, logits_processors, max_length, + ret = self.sample(input_ids, logits_processors, max_len, pad_token_id, eos_token_id, top_k, top_p, temperature, **model_kwargs) else: diff --git a/ppfleetx/ops/setup_cuda.py b/ppfleetx/ops/setup_cuda.py new file mode 100644 index 000000000..9f1da47b6 --- /dev/null +++ b/ppfleetx/ops/setup_cuda.py @@ -0,0 +1,19 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.utils.cpp_extension import CUDAExtension, setup + +setup( + name='ppfleetx_ops', + ext_modules=CUDAExtension(sources=['topp_sampling.cu'])) diff --git a/ppfleetx/ops/test_topp_sampling.py b/ppfleetx/ops/test_topp_sampling.py new file mode 100644 index 000000000..80525fba2 --- /dev/null +++ b/ppfleetx/ops/test_topp_sampling.py @@ -0,0 +1,25 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import numpy as np +from ppfleetx.ops import topp_sampling + +paddle.seed(2022) + +x = paddle.randn([1, 51200], dtype="float16") +x = paddle.nn.functional.softmax(x) +top_ps = paddle.to_tensor(np.random.uniform(0, 1, [1]).astype(np.float16)) +out = topp_sampling(x, top_ps) +print(out) diff --git a/ppfleetx/ops/topp_sampling.cu b/ppfleetx/ops/topp_sampling.cu new file mode 100644 index 000000000..3dd1a8638 --- /dev/null +++ b/ppfleetx/ops/topp_sampling.cu @@ -0,0 +1,584 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cub/cub.cuh" +#include +#include + +#include "paddle/extension.h" + +#define CHECK_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.") + +#define FINAL_MASK 0xFFFFFFFF + +#define FIXED_BLOCK_DIM_BASE(dim, ...) 
\ + case (dim): { \ + constexpr auto kBlockDim = (dim); \ + __VA_ARGS__; \ + } break + +#define FIXED_BLOCK_DIM(...) \ + FIXED_BLOCK_DIM_BASE(1024, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__) + +template class PDTraits; + +template <> class PDTraits { +public: + typedef float DataType; + typedef float data_t; +}; + +template <> class PDTraits { +public: + typedef half DataType; + typedef paddle::float16 data_t; +}; + +struct SegmentOffsetIter { + explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {} + + __host__ __device__ __forceinline__ int operator()(int idx) const { + return idx * num_cols_; + } + + int num_cols_; +}; + +template struct Pair { + __device__ __forceinline__ Pair() {} + __device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {} + + __device__ __forceinline__ void set(T value, int id) { + v = value; + id = id; + } + + __device__ __forceinline__ void operator=(const Pair &in) { + v = in.v; + id = in.id; + } + + __device__ __forceinline__ bool operator<(const T value) const { + return (static_cast(v) < static_cast(value)); + } + + __device__ __forceinline__ bool operator>(const T value) const { + return (static_cast(v) > static_cast(value)); + } + __device__ __forceinline__ bool operator<(const Pair &in) const { + return (static_cast(v) < static_cast(in.v)) || + ((static_cast(v) == static_cast(in.v)) && + (id > in.id)); + } + + __device__ __forceinline__ bool operator>(const Pair &in) const { + return (static_cast(v) > static_cast(in.v)) || + ((static_cast(v) == static_cast(in.v)) && + (id < in.id)); + } + + T v; + int id; +}; + +inline int div_up(int a, int n) { return (a + n - 1) / n; } + +__global__ void setup_kernel(curandState_t *state, const uint64_t seed, + const int bs) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = idx; i < bs; i += gridDim.x * blockDim.x) { + curand_init(seed, 0, 0, &state[i]); + } +} + +template +__device__ __forceinline__ void AddTo(Pair topk[], const Pair &p, + int beam_size) { + for (int k = beam_size - 2; k >= 0; k--) { + if (topk[k] < p) { + topk[k + 1] = topk[k]; + } else { + topk[k + 1] = p; + return; + } + } + topk[0] = p; +} + +template +__device__ __forceinline__ void GetTopK(Pair topk[], const T *src, int idx, + int dim, int beam_size) { + while (idx < dim) { + if (topk[beam_size - 1] < src[idx]) { + Pair tmp(src[idx], idx); + AddTo(topk, tmp, beam_size); + } + idx += BlockSize; + } +} + +template +__device__ __forceinline__ void GetTopK(Pair topk[], const T *src, int idx, + int dim, const Pair &max, + int beam_size) { + while (idx < dim) { + if (topk[beam_size - 1] < src[idx]) { + Pair tmp(src[idx], idx); + if (tmp < max) { + AddTo(topk, tmp, beam_size); + } + } + idx += BlockSize; + } +} + +template +__device__ __forceinline__ void +ThreadGetTopK(Pair topk[], int *beam, int beam_size, const T *src, + bool *firstStep, bool *is_empty, Pair *max, int dim, + const int tid) { + if (*beam > 0) { + int length = (*beam) < beam_size ? 
*beam : beam_size; + if (*firstStep) { + *firstStep = false; + GetTopK(topk, src, tid, dim, length); + } else { + for (int k = 0; k < MaxLength; k++) { + if (k < MaxLength - (*beam)) { + topk[k] = topk[k + *beam]; + } else { + topk[k].set(std::numeric_limits::min(), -1); + } + } + if (!(*is_empty)) { + GetTopK(topk + MaxLength - *beam, src, tid, dim, *max, + length); + } + } + + *max = topk[MaxLength - 1]; + if ((*max).id == -1) + *is_empty = true; + *beam = 0; + } +} + +template +__forceinline__ __device__ Pair WarpReduce(Pair input) { +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + T tmp_val = __shfl_down_sync(FINAL_MASK, input.v, + static_cast(offset), 32); + int tmp_id = __shfl_down_sync(FINAL_MASK, input.id, + static_cast(offset), 32); + if (static_cast(input.v) < static_cast(tmp_val)) { + input.v = tmp_val; + input.id = tmp_id; + } + } + return input; +} + +template +__device__ __forceinline__ void +BlockReduce(Pair shared_max[], Pair topk[], Pair beam_max[], int *beam, + int *k, int *count, const int tid, const int wid, const int lane) { + while (true) { + __syncthreads(); + Pair input_now = topk[0]; + input_now = WarpReduce(input_now); + + if (lane == 0) { + shared_max[wid] = input_now; + } + __syncthreads(); + input_now = (tid < BlockSize / 32) + ? shared_max[lane] + : Pair(std::numeric_limits::min(), -1); + if (wid == 0) { + input_now = WarpReduce(input_now); + if (lane == 0) + shared_max[0] = input_now; + } + __syncthreads(); + if (tid == 0) { + beam_max[*count] = shared_max[0]; + (*count)++; + } + int tid_max = shared_max[0].id % BlockSize; + if (tid == tid_max) { + (*beam)++; + } + if (--(*k) == 0) + break; + __syncthreads(); + + if (tid == tid_max) { + if (*beam < MaxLength) { + topk[0] = topk[*beam]; + } + } + + if (MaxLength < 5) { + if (*beam >= MaxLength) + break; + } else { + unsigned mask = 0u; + mask = __ballot_sync(FINAL_MASK, true); + if (tid_max / 32 == wid) { + if (__shfl_down_sync(FINAL_MASK, *beam, tid_max % 32, 32) == MaxLength) + break; + } + } + } +} + +template +__global__ void KeMatrixTopPBeamTopK(const T *src, T *top_ps, int64_t *out_id, + T *out_val, int vocab_size, + curandState_t *state, int *count_iter, + int *count_iter_begin) { + const int tid = threadIdx.x; + const int wid = tid / 32; + const int lane = tid % 32; + const int bid = blockIdx.x; + + int top_num = TopPBeamTopK; + float top_p_num = static_cast(top_ps[bid]); + + __shared__ Pair shared_max[BlockSize / 32]; + __shared__ Pair beam_max[TopPBeamTopK]; + + Pair topk[MaxLength]; + int beam = MaxLength; + Pair max; + bool is_empty = false; + bool firststep = true; + __shared__ int count; + + if (tid == 0) { + count = 0; + } + + for (int j = 0; j < MaxLength; j++) { + topk[j].set(std::numeric_limits::min(), -1); + } + + while (top_num) { + ThreadGetTopK(topk, &beam, TopPBeamTopK, + src + bid * vocab_size, &firststep, + &is_empty, &max, vocab_size, tid); + BlockReduce(shared_max, topk, beam_max, &beam, + &top_num, &count, tid, wid, lane); + } + if (tid == 0) { + count_iter_begin[bid] = count_iter[bid]; + float rand_top_p = curand_uniform(state + bid) * top_p_num; + top_ps[bid] = (T)rand_top_p; + float sum_prob = 0.0f; +#pragma unroll + for (int i = 0; i < TopPBeamTopK; i++) { + sum_prob += static_cast(beam_max[i].v); + if (sum_prob >= rand_top_p) { + count_iter_begin[bid] += 1; + out_id[bid] = static_cast(beam_max[i].id); + out_val[bid] = beam_max[i].v; + break; + } + } + } +} + +__global__ void SetCountIter(int *count_iter, int num) { + int tid = threadIdx.x; + int bid = 
blockIdx.x; + int idx = bid * blockDim.x + tid; + for (int i = idx; i < num; i += gridDim.x * blockDim.x) { + count_iter[i] = i; + } +} + +template +__global__ void FillIndex(T *indices, T num_rows, T num_cols) { + int col_id = threadIdx.x; + int row_id = blockIdx.x; + + for (T j = row_id; j < num_rows; j += gridDim.x) { + for (T i = col_id; i < num_cols; i += blockDim.x) { + indices[j * num_cols + i] = i; + } + } +} + +struct BlockPrefixCallbackOp { + // Running prefix + float running_total; + // Constructor + __device__ BlockPrefixCallbackOp(float running_total) + : running_total(running_total) {} + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide + // scan. + __device__ float operator()(float block_aggregate) { + float old_prefix = running_total; + running_total += block_aggregate; + return old_prefix; + } +}; + +template +__global__ void topp_sampling(T *sorted_probs, int64_t *sorted_id, T *out_val, + int64_t *out_id, const T *top_ps, int p_num, + int vocab_size, int *count_iter, + int *count_iter_begin) { + __shared__ int stop_shared; + __shared__ float rand_p; + const int tid = threadIdx.x; + const int bid = blockIdx.x; + constexpr int WARP_SIZE = 32; + constexpr int NUM_WARPS = BLOCK_SIZE / WARP_SIZE; + const int lane_id = tid % WARP_SIZE; + const int warp_id = tid / WARP_SIZE; + const float p_t = static_cast(top_ps[bid]); + if (tid == 0) { + stop_shared = 0; + rand_p = p_t; + } + if (count_iter_begin[bid] == count_iter[bid + 1]) { + // topk + return; + } + + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + __shared__ uint32_t selected_shared[NUM_WARPS]; + + // Initialize running total + BlockPrefixCallbackOp prefix_op(0); + + if (lane_id == 0) { + selected_shared[warp_id] = 0; + } + __syncthreads(); + + int offset = bid * vocab_size; + int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; + int i_activate = 0; + float thread_offset = 0; + for (int i = tid; i < end; i += BLOCK_SIZE) { + float thread_count = + (i < vocab_size) ? static_cast(sorted_probs[offset + i]) : 0.f; + BlockScan(temp_storage) + .InclusiveSum(thread_count, thread_offset, prefix_op); + + uint32_t activate_mask = __ballot_sync(FINAL_MASK, rand_p <= thread_offset); + + i_activate = i; + if (activate_mask != 0) { + if (lane_id == 0) { + atomicAdd(&stop_shared, 1); + selected_shared[warp_id] = activate_mask; + } + } + __syncthreads(); + if (stop_shared > 0) { + break; + } + } + + bool skip = (selected_shared[warp_id] > 0) ? 
false : true; + for (int i = 0; i < warp_id; i++) { + if (selected_shared[i] != 0) { + skip = true; + } + } + if (!skip) { + int active_lane_id = WARP_SIZE - __popc(selected_shared[warp_id]); + if (lane_id == active_lane_id) { + out_id[bid] = sorted_id[offset + i_activate]; + out_val[bid] = sorted_probs[offset + i_activate]; + } + } +} + +int GetBlockSize(int vocab_size) { + if (vocab_size > 512) { + return 1024; + } else if (vocab_size > 256) { + return 512; + } else if (vocab_size > 128) { + return 256; + } else if (vocab_size > 64) { + return 128; + } else { + return 64; + } +} + +template +std::vector +top_p_sampling_kernel(const paddle::Tensor &x, const paddle::Tensor &top_ps) { + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + std::vector shape = x.shape(); + auto cu_stream = x.stream(); + + int bs = shape[0]; + int p_num = top_ps.numel(); + PD_CHECK(bs == p_num, "PD_CHECK returns ", false, ", expected bs == p_num."); + int vocab_size = shape[1]; + auto topp_ids = paddle::full({bs, 1}, 1, paddle::DataType::INT64, x.place()); + auto topp_probs = paddle::full({bs, 1}, 1, x.dtype(), x.place()); + auto inds_input = + paddle::full({bs, vocab_size}, 1, paddle::DataType::INT64, x.place()); + auto sorted_out = paddle::full({bs, vocab_size}, 1, x.dtype(), x.place()); + auto sorted_id = + paddle::full({bs, vocab_size}, 1, paddle::DataType::INT64, x.place()); + + int BlockSize = GetBlockSize(vocab_size); + switch (BlockSize) { + FIXED_BLOCK_DIM(FillIndex<<>>( + inds_input.data(), bs, vocab_size)); + default: + PD_THROW("the input data shape has error in the FillIndex kernel."); + } + + static int count = 0; + constexpr int max_bs = 128; + static curandState_t *dev_curand_states; + if (count == 0) { +#if CUDA_VERSION >= 11020 + cudaMallocAsync(&dev_curand_states, max_bs * sizeof(curandState_t), + cu_stream); +#else + cudaMalloc(&dev_curand_states, max_bs * sizeof(curandState_t)); +#endif + setup_kernel<<<1, 256, 0, cu_stream>>>(dev_curand_states, 2022, max_bs); + } + count = 1; + PD_CHECK(bs == p_num, "PD_CHECK returns ", false, ", expected bs == p_num."); + + auto count_iter = paddle::empty({bs + 1}, paddle::DataType::INT32, x.place()); + auto count_iter_begin = + paddle::empty({bs}, paddle::DataType::INT32, x.place()); + SetCountIter<<<1, 256, 0, cu_stream>>>(count_iter.data(), bs + 1); + + constexpr int TopKMaxLength = 1; + constexpr int TopPBeamTopK = 1; + switch (BlockSize) { + FIXED_BLOCK_DIM( + KeMatrixTopPBeamTopK + <<>>( + reinterpret_cast( + const_cast(x.data())), + reinterpret_cast( + const_cast(top_ps.data())), + topp_ids.data(), + reinterpret_cast(topp_probs.data()), + vocab_size, dev_curand_states, count_iter.data(), + count_iter_begin.data())); + default: + PD_THROW("the input data shape has error in the topp_beam_topk kernel."); + } + + size_t temp_storage_bytes = 0; + + cub::TransformInputIterator + segment_offsets_t_begin(count_iter_begin.data(), + SegmentOffsetIter(vocab_size)); + + cub::TransformInputIterator + segment_offsets_t_end(count_iter.data(), + SegmentOffsetIter(vocab_size)); + + DataType_ *x_ptr = + reinterpret_cast(const_cast(x.data())); + DataType_ *sorted_out_ptr = reinterpret_cast( + const_cast(sorted_out.data())); + int64_t *in_id_ptr = inds_input.data(); + int64_t *out_id_ptr = sorted_id.data(); + + cub::DeviceSegmentedRadixSort::SortPairsDescending( + nullptr, temp_storage_bytes, x_ptr, sorted_out_ptr, in_id_ptr, out_id_ptr, + vocab_size * bs, bs, segment_offsets_t_begin, 
segment_offsets_t_end + 1, + 0, sizeof(data_t) * 8, cu_stream); + + temp_storage_bytes = div_up(temp_storage_bytes, 256) * 256; + int64_t temp_size = temp_storage_bytes; + auto temp_storage = + paddle::empty({temp_size}, paddle::DataType::UINT8, x.place()); + + cub::DeviceSegmentedRadixSort::SortPairsDescending( + temp_storage.data(), temp_storage_bytes, x_ptr, sorted_out_ptr, + in_id_ptr, out_id_ptr, vocab_size * bs, bs, segment_offsets_t_begin, + segment_offsets_t_end + 1, 0, sizeof(data_t) * 8, cu_stream); + + switch (BlockSize) { + FIXED_BLOCK_DIM( + topp_sampling<<>>( + sorted_out_ptr, out_id_ptr, + reinterpret_cast(topp_probs.data()), + topp_ids.data(), + reinterpret_cast( + const_cast(top_ps.data())), + p_num, vocab_size, count_iter.data(), + count_iter_begin.data())); + default: + PD_THROW("the input data shape has error in the topp_sampling kernel."); + } + return {topp_ids}; +} + +std::vector TopPSampling(const paddle::Tensor &x, + const paddle::Tensor &top_ps) { + switch (x.type()) { + case paddle::DataType::FLOAT16: { + return top_p_sampling_kernel(x, top_ps); + } + case paddle::DataType::FLOAT32: { + return top_p_sampling_kernel(x, top_ps); + } + default: { + PD_THROW("NOT supported data type. " + "Only float16 and float32 are supported. "); + break; + } + } +} + +std::vector> +TopPSamplingInferShape(const std::vector &x_shape, + const std::vector &top_ps_shape) { + std::vector out_ids_shape = {x_shape[0], 1}; + return {out_ids_shape}; +} + +std::vector +TopPSamplingInferDtype(const paddle::DataType &x_dtype, + const paddle::DataType &top_ps_dtype) { + return {paddle::DataType::INT64}; +} + +PD_BUILD_OP(topp_sampling) + .Inputs({"x", "top_ps"}) + .Outputs({"topp_ids"}) + .SetKernelFn(PD_KERNEL(TopPSampling)) + .SetInferShapeFn(PD_INFER_SHAPE(TopPSamplingInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(TopPSamplingInferDtype)); diff --git a/projects/gpt/auto_export_gpt_345M_single_card.sh b/projects/gpt/auto_export_gpt_345M_single_card.sh new file mode 100644 index 000000000..6e90c2652 --- /dev/null +++ b/projects/gpt/auto_export_gpt_345M_single_card.sh @@ -0,0 +1,22 @@ +#! /bin/bash + +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +log_dir=log_345m_mp1 +rm -rf $log_dir + +python -m paddle.distributed.launch --log_dir $log_dir --devices "1" \ + ./tools/auto_export.py \ + -c ./ppfleetx/configs/nlp/gpt/auto/generation_gpt_345M_single_card.yaml \
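Once the extension is built (per the hint in the model code: cd ppfleetx/ops && python setup_cuda.py install), the fused kernel can be smoke-tested along the lines of ppfleetx/ops/test_topp_sampling.py. A hedged usage sketch, assuming the op installed as ppfleetx_ops and using float32 probabilities for a batch of two (the kernel allocates curand states for at most 128 rows, so batch size should stay within that limit):

import numpy as np
import paddle
from ppfleetx_ops import topp_sampling  # available only after building the extension

paddle.seed(2022)

# probs: [batch_size, vocab_size] rows summing to 1; one top-p threshold per row
probs = paddle.nn.functional.softmax(paddle.randn([2, 50304], dtype='float32'), axis=-1)
top_ps = paddle.to_tensor(np.array([0.9, 0.9], dtype=np.float32))

# returns int64 token ids of shape [batch_size, 1], matching the op's InferShape/InferDtype
next_tokens = topp_sampling(probs, top_ps)
print(next_tokens)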