fix: change int of some kernels to int64_t to prevent overflow
byshiue committed Mar 14, 2023
1 parent 72d3dce commit bb94e2d
Showing 2 changed files with 31 additions and 31 deletions.
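
Background note (not part of the commit): the kernels touched below walk products such as local_token_num * hidden_units or hidden_unit * vocab_size_padded with grid-stride loops. With 32-bit int parameters and indices, those products can exceed INT_MAX for large models, so the loop bound and the derived offsets wrap around. A minimal host-side sketch with hypothetical GPT-scale sizes (chosen only for illustration) shows the magnitude involved:

```cuda
#include <climits>
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical sizes, not taken from the commit or any specific model.
    const int64_t hidden_units      = 12288;
    const int64_t vocab_size_padded = 256000;

    const int64_t total = hidden_units * vocab_size_padded;  // 3,145,728,000
    printf("hidden_units * vocab_size_padded = %lld\n", (long long)total);
    printf("INT_MAX                          = %d\n", INT_MAX);
    // With plain 32-bit int parameters and a 32-bit loop index, both the bound
    // and the index arithmetic overflow, so the grid-stride loop addresses the
    // wrong memory; widening to int64_t, as this commit does, avoids that.
    return 0;
}
```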
48 changes: 24 additions & 24 deletions src/fastertransformer/kernels/decoding_kernels.cu
@@ -98,19 +98,19 @@ template void invokeDecodingInitialize(bool* finished,
 
 // PROMPT_SRC: 0 --> no prompts, 1 --> from loaded prompts, 2 --> from request prompts
 template<typename T>
-__global__ void embeddingLookupPosEncoding(T* from_tensor,
-                                           const T* embedding_table,
-                                           const T* position_encoding,
-                                           const int* all_ids,
-                                           const int* padding_count,
-                                           const int* input_lengths,
-                                           const int local_token_num,
-                                           const int hidden_units,
-                                           const int step,
-                                           const int max_input_length,
-                                           const int token_num,
-                                           const int ite,
-                                           const T scale)
+__global__ void embeddingLookupPosEncoding(T* from_tensor,
+                                           const T* embedding_table,
+                                           const T* position_encoding,
+                                           const int* all_ids,
+                                           const int* padding_count,
+                                           const int* input_lengths,
+                                           const int local_token_num,
+                                           const int64_t hidden_units,
+                                           const int step,
+                                           const int max_input_length,
+                                           const int token_num,
+                                           const int ite,
+                                           const T scale)
 {
     // 1. lookup from embedding table
     // 2. multiply scale
@@ -120,7 +120,7 @@ __global__ void embeddingLookupPosEncoding(T* from_tensor,
     const bool use_padding_count = padding_count != nullptr;
     const bool use_input_len = input_lengths != nullptr;
 
-    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < local_token_num * hidden_units;
+    for (int64_t index = blockIdx.x * blockDim.x + threadIdx.x; index < local_token_num * hidden_units;
          index += blockDim.x * gridDim.x) {
         const int row_index = index / hidden_units;
         const int col_index = index % hidden_units;
@@ -148,7 +148,7 @@ __global__ void embeddingLookup(T* from_tensor,
                                 const int* all_ids,
                                 pPromptTuningParam<T> prompt_param,
                                 const int local_token_num,
-                                const int hidden_units,
+                                const int64_t hidden_units,
                                 const int step,
                                 const int token_num,
                                 const int ite,
@@ -159,7 +159,7 @@
     // 2. multiply scale
     const int id_offset = step * token_num + ite * local_token_num;
 
-    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < local_token_num * hidden_units;
+    for (int64_t index = blockIdx.x * blockDim.x + threadIdx.x; index < local_token_num * hidden_units;
          index += blockDim.x * gridDim.x) {
 
         const int word_index = index / hidden_units;
@@ -313,15 +313,15 @@ INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT(__nv_bfloat16);
 #undef INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT
 
 template<typename T>
-__global__ void paddingEmbedding(T* padded_embedding_kernel,
-                                 T* padded_embedding_bias,
-                                 const T* embedding_kernel,
-                                 const T* embedding_bias,
-                                 const int hidden_unit,
-                                 const int vocab_size,
-                                 const int vocab_size_padded)
+__global__ void paddingEmbedding(T* padded_embedding_kernel,
+                                 T* padded_embedding_bias,
+                                 const T* embedding_kernel,
+                                 const T* embedding_bias,
+                                 const int64_t hidden_unit,
+                                 const int64_t vocab_size,
+                                 const int64_t vocab_size_padded)
 {
-    for (int id = threadIdx.x + blockIdx.x * blockDim.x; id < hidden_unit * vocab_size_padded;
+    for (int64_t id = threadIdx.x + blockIdx.x * blockDim.x; id < hidden_unit * vocab_size_padded;
          id += blockDim.x * gridDim.x) {
         int row_id = id / vocab_size_padded;
         int col_id = id % vocab_size_padded;
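
Standalone illustration (a hypothetical kernel, not FasterTransformer's): the pattern the hunks above switch to. Making the width parameter and the loop index int64_t keeps the loop bound, the division, and the modulo in 64-bit arithmetic even when the token count stays a plain int.

```cuda
#include <cstdint>

// Hypothetical kernel sketching the widened grid-stride loop; only the use of
// int64_t for hidden_units and the index mirrors the commit.
__global__ void scaleRows(float* out, const float* in,
                          const int local_token_num,
                          const int64_t hidden_units,  // widened: products below are 64-bit
                          const float scale)
{
    for (int64_t index = blockIdx.x * blockDim.x + threadIdx.x; index < local_token_num * hidden_units;
         index += blockDim.x * gridDim.x) {
        const int64_t row = index / hidden_units;  // token index
        const int64_t col = index % hidden_units;  // hidden-unit index
        out[row * hidden_units + col] = scale * in[row * hidden_units + col];
    }
}
```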
14 changes: 7 additions & 7 deletions src/fastertransformer/kernels/gpt_kernels.cu
@@ -39,7 +39,7 @@ __global__ void start_id_embedding_position_lookups_kernel(T*
                                                            const int length,
                                                            const int max_length,
                                                            const int batch_size,
-                                                            const int hidden_units)
+                                                            const int64_t hidden_units)
 {
     for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * length * hidden_units;
          index += blockDim.x * gridDim.x) {
@@ -250,20 +250,20 @@ __global__ void inputIdsEmbeddingLookupPosEncodingSoftPrompt(inputIdsEmbeddingLo
     const int beam_id = tmp_index % param.beam_width;
     tmp_index = (tmp_index - beam_id) / param.beam_width;
     const int batch_id = tmp_index % param.batch_size;
+    const int64_t hidden_units = param.hidden_units;
     T embedding =
         (seq_id < param.prefix_soft_prompt_lengths[batch_id]) ?
-            (T)param
-                .prefix_soft_prompt_embedding[batch_id * param.max_prefix_soft_prompt_length * param.hidden_units
-                                              + seq_id * param.hidden_units + hidden_id] :
-            param.embedding_table[param.input_ids[batch_id * param.beam_width * param.max_input_length
+            (T)param.prefix_soft_prompt_embedding[batch_id * param.max_prefix_soft_prompt_length * hidden_units
+                                                  + seq_id * hidden_units + hidden_id] :
+            param.embedding_table[param.input_ids[batch_id * param.beam_width * param.max_input_length
                                                   + beam_id * param.max_input_length
                                                   + (seq_id - param.prefix_soft_prompt_lengths[batch_id])]
-                * param.hidden_units
+                * hidden_units
                 + hidden_id];
 
     T pos_embed = param.pos_table == nullptr ?
                       (T)0.0f :
-                      param.pos_table[(param.start_step + seq_id - 1) * param.hidden_units + hidden_id];
+                      param.pos_table[(param.start_step + seq_id - 1) * hidden_units + hidden_id];
     param.from_tensor[index] = embedding + pos_embed;
 
     if (seq_id == 0 && hidden_id == 0) {
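
The gpt_kernels.cu hunk takes a slightly different route: instead of widening a kernel parameter, it hoists param.hidden_units into a local const int64_t, so every product in the long addressing expression is evaluated in 64-bit without scattering casts. A minimal sketch of that idiom, using hypothetical struct and function names:

```cuda
#include <cstdint>

// Hypothetical parameter struct; only the hoisting idiom mirrors the commit.
struct SoftPromptParam {
    int hidden_units;
    int max_prefix_soft_prompt_length;
};

inline int64_t softPromptOffset(const SoftPromptParam& param, int batch_id, int seq_id, int hidden_id)
{
    // One widening conversion, reused by every term below, so the whole
    // expression is computed in 64-bit arithmetic.
    const int64_t hidden_units = param.hidden_units;
    // batch_id * max_prefix_soft_prompt_length remains a small 32-bit product,
    // as in the committed code; the widening matters once hidden_units enters.
    return batch_id * param.max_prefix_soft_prompt_length * hidden_units
           + seq_id * hidden_units + hidden_id;
}
```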