diff --git a/src/fastertransformer/kernels/decoding_kernels.cu b/src/fastertransformer/kernels/decoding_kernels.cu
index ff28bae8b..89f0d5011 100644
--- a/src/fastertransformer/kernels/decoding_kernels.cu
+++ b/src/fastertransformer/kernels/decoding_kernels.cu
@@ -98,19 +98,19 @@ template void invokeDecodingInitialize(bool* finished,
 
 // PROMPT_SRC: 0 --> no prompts, 1 --> from loaded prompts, 2 --> from request prompts
 template<typename T>
-__global__ void embeddingLookupPosEncoding(T*         from_tensor,
-                                           const T*   embedding_table,
-                                           const T*   position_encoding,
-                                           const int* all_ids,
-                                           const int* padding_count,
-                                           const int* input_lengths,
-                                           const int  local_token_num,
-                                           const int  hidden_units,
-                                           const int  step,
-                                           const int  max_input_length,
-                                           const int  token_num,
-                                           const int  ite,
-                                           const T    scale)
+__global__ void embeddingLookupPosEncoding(T*            from_tensor,
+                                           const T*      embedding_table,
+                                           const T*      position_encoding,
+                                           const int*    all_ids,
+                                           const int*    padding_count,
+                                           const int*    input_lengths,
+                                           const int     local_token_num,
+                                           const int64_t hidden_units,
+                                           const int     step,
+                                           const int     max_input_length,
+                                           const int     token_num,
+                                           const int     ite,
+                                           const T       scale)
 {
     // 1. lookup from embedding table
     // 2. multiply scale
@@ -120,7 +120,7 @@ __global__ void embeddingLookupPosEncoding(T* from_tensor,
     const bool use_padding_count = padding_count != nullptr;
     const bool use_input_len     = input_lengths != nullptr;
 
-    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < local_token_num * hidden_units;
+    for (int64_t index = blockIdx.x * blockDim.x + threadIdx.x; index < local_token_num * hidden_units;
          index += blockDim.x * gridDim.x) {
         const int row_index = index / hidden_units;
         const int col_index = index % hidden_units;
@@ -148,7 +148,7 @@ __global__ void embeddingLookup(T* from_tensor,
                                 const int*            all_ids,
                                 pPromptTuningParam<T> prompt_param,
                                 const int             local_token_num,
-                                const int             hidden_units,
+                                const int64_t         hidden_units,
                                 const int             step,
                                 const int             token_num,
                                 const int             ite,
@@ -159,7 +159,7 @@ __global__ void embeddingLookup(T* from_tensor,
     // 2. multiply scale
     const int id_offset = step * token_num + ite * local_token_num;
 
-    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < local_token_num * hidden_units;
+    for (int64_t index = blockIdx.x * blockDim.x + threadIdx.x; index < local_token_num * hidden_units;
          index += blockDim.x * gridDim.x) {
 
         const int word_index = index / hidden_units;
@@ -313,15 +313,15 @@ INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT(__nv_bfloat16);
 #undef INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT
 
 template<typename T>
-__global__ void paddingEmbedding(T*        padded_embedding_kernel,
-                                 T*        padded_embedding_bias,
-                                 const T*  embedding_kernel,
-                                 const T*  embedding_bias,
-                                 const int hidden_unit,
-                                 const int vocab_size,
-                                 const int vocab_size_padded)
+__global__ void paddingEmbedding(T*            padded_embedding_kernel,
+                                 T*            padded_embedding_bias,
+                                 const T*      embedding_kernel,
+                                 const T*      embedding_bias,
+                                 const int64_t hidden_unit,
+                                 const int64_t vocab_size,
+                                 const int64_t vocab_size_padded)
 {
-    for (int id = threadIdx.x + blockIdx.x * blockDim.x; id < hidden_unit * vocab_size_padded;
+    for (int64_t id = threadIdx.x + blockIdx.x * blockDim.x; id < hidden_unit * vocab_size_padded;
          id += blockDim.x * gridDim.x) {
         int row_id = id / vocab_size_padded;
         int col_id = id % vocab_size_padded;
diff --git a/src/fastertransformer/kernels/gpt_kernels.cu b/src/fastertransformer/kernels/gpt_kernels.cu
index abb3b5db4..9402b57fa 100644
--- a/src/fastertransformer/kernels/gpt_kernels.cu
+++ b/src/fastertransformer/kernels/gpt_kernels.cu
@@ -39,7 +39,7 @@ __global__ void start_id_embedding_position_lookups_kernel(T*
                                                            const int     length,
                                                            const int     max_length,
                                                            const int     batch_size,
-                                                           const int     hidden_units)
+                                                           const int64_t hidden_units)
 {
     for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * length * hidden_units;
          index += blockDim.x * gridDim.x) {
@@ -250,20 +250,20 @@ __global__ void inputIdsEmbeddingLookupPosEncodingSoftPrompt(inputIdsEmbeddingLo
     const int beam_id   = tmp_index % param.beam_width;
     tmp_index           = (tmp_index - beam_id) / param.beam_width;
     const int batch_id  = tmp_index % param.batch_size;
+    const int64_t hidden_units = param.hidden_units;
 
     T embedding =
         (seq_id < param.prefix_soft_prompt_lengths[batch_id]) ?
-            (T)param
-                .prefix_soft_prompt_embedding[batch_id * param.max_prefix_soft_prompt_length * param.hidden_units
-                                              + seq_id * param.hidden_units + hidden_id] :
-            param.embedding_table[param.input_ids[batch_id * param.beam_width * param.max_input_length
+            (T)param.prefix_soft_prompt_embedding[batch_id * param.max_prefix_soft_prompt_length * hidden_units
+                                                  + seq_id * hidden_units + hidden_id] :
+            param.embedding_table[param.input_ids[batch_id * param.beam_width * param.max_input_length
                                                   + beam_id * param.max_input_length
                                                   + (seq_id - param.prefix_soft_prompt_lengths[batch_id])]
-                                  * param.hidden_units
+                                  * hidden_units
                                   + hidden_id];
     T pos_embed = param.pos_table == nullptr ?
                       (T)0.0f :
-                      param.pos_table[(param.start_step + seq_id - 1) * param.hidden_units + hidden_id];
+                      param.pos_table[(param.start_step + seq_id - 1) * hidden_units + hidden_id];
     param.from_tensor[index] = embedding + pos_embed;
 
     if (seq_id == 0 && hidden_id == 0) {
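
Note on the change: loop bounds such as local_token_num * hidden_units and hidden_unit * vocab_size_padded, and the soft-prompt offsets above, were products of int values, so they were computed in 32-bit arithmetic and overflow once a tensor exceeds INT_MAX (2^31 - 1) elements; for example vocab_size_padded = 256000 with hidden_unit = 12288 is about 3.1e9 entries of the padded embedding. Declaring hidden_units as int64_t promotes each product and comparison to 64-bit, and the grid-stride loop counters are widened so the index itself cannot wrap either.

A minimal standalone sketch of the failure mode and the fix, using hypothetical kernel names (not part of this patch):

#include <cstdint>

// Broken: when n_rows * n_cols exceeds INT_MAX, the 32-bit multiply
// overflows (undefined behavior; in practice the bound wraps negative
// or small), so part of the buffer is never written.
__global__ void fillBroken(float* dst, int n_rows, int n_cols)
{
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n_rows * n_cols;
         i += blockDim.x * gridDim.x) {
        dst[i] = 1.0f;
    }
}

// Fixed, mirroring the patch: one int64_t operand promotes the product
// and the comparison to 64-bit, and the counter is int64_t so the
// grid-stride increment cannot wrap at INT_MAX.
__global__ void fillFixed(float* dst, int n_rows, int64_t n_cols)
{
    for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n_rows * n_cols;
         i += blockDim.x * gridDim.x) {
        dst[i] = 1.0f;
    }
}

Launch configuration is elided; because the loops are grid-stride, any grid size covers the full range once the bound is computed correctly.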