
Ft llama 06 48 #2

Merged: 19 commits, Oct 1, 2023
342 changes: 321 additions & 21 deletions src/fastertransformer/kernels/llama_kernels.cu

Large diffs are not rendered by default.
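The kernel implementations in llama_kernels.cu are not rendered above, but the header that follows declares what they compute. As a rough orientation only, here is a host-side sketch of the padding-offset and cumulative-sequence-length bookkeeping that invokeLLaMAGetPaddingOffsetAndCuSeqLens presumably performs on the device; the function below is illustrative, follows common FasterTransformer conventions, and is not code from this PR.

#include <vector>

// Illustrative host-side reference, not the PR's CUDA kernel. Given per-sequence
// input lengths and the padded sequence length, compute exclusive cumulative
// sequence lengths (cu_seqlens, size batch_size + 1) and, for every valid token,
// how many padding slots precede it once the padded [batch, seq_len] layout is
// packed into a contiguous token buffer.
void referencePaddingOffsetAndCuSeqLens(const std::vector<int>& input_lengths,
                                        int                     seq_len,
                                        std::vector<int>&       padding_offset,
                                        std::vector<int>&       cu_seqlens)
{
    const int batch_size = static_cast<int>(input_lengths.size());
    cu_seqlens.assign(batch_size + 1, 0);
    padding_offset.clear();

    int total_padding = 0;
    for (int b = 0; b < batch_size; ++b) {
        cu_seqlens[b + 1] = cu_seqlens[b] + input_lengths[b];
        for (int t = 0; t < input_lengths[b]; ++t) {
            // Every token of sequence b moves left by the padding accumulated
            // over all previous sequences when the batch is packed.
            padding_offset.push_back(total_padding);
        }
        total_padding += seq_len - input_lengths[b];
    }
}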

54 changes: 50 additions & 4 deletions src/fastertransformer/kernels/llama_kernels.h
@@ -1,15 +1,61 @@
#pragma once


#include "src/fastertransformer/utils/cuda_fp8_utils.h"
#include "src/fastertransformer/utils/memory_utils.h"
namespace fastertransformer {

void invokeLLaMAGetPaddingOffsetAndCuSeqLens(int* padding_offset,
                                             int* cu_seqlens,
                                             const int* input_lengths,
                                             const int batch_size,
                                             const int seq_len,
                                             cudaStream_t stream);

template<typename T>
void invokeLLaMABuildDecoderAttentionMask(T* attention_mask,
                                          const int* sequence_lengths,
                                          const int* sequence_length,
                                          const int* context_lengths,
                                          const int batch_size,
                                          const int seq_len,
                                          const int start_pos,
                                          const int max_length,
                                          cudaStream_t stream);
} // namespace fastertransformer
template<typename T>
void invokeLLaMAInputIdsEmbeddingLookup(T* from_tensor,
                                        const T* embedding_table,
                                        const int* input_ids,
                                        const int num_tokens,
                                        const int hidden_units,
                                        cudaStream_t stream);

template<typename T>
void invokeLLaMACopyKernel(T* dst, T* src, const int count, cudaStream_t stream);
template<typename T>
void invokeLLaMAMemset0(T* dst, const int count, cudaStream_t stream);

void invokeLLaMAGatherTokens(float* out,
                             const float* probs,
                             const int* input_lengths,
                             const int* target_ids,
                             const int* cu_seqlens,
                             const int batch_size,
                             const int vocab_size,
                             const int num_tokens,
                             cudaStream_t stream);

void invokeLLaMALogSoftmax(
    float* out, const float* logits, const int num_tokens, const int vocab_size, cudaStream_t stream);

template<typename T>
void invokeLLaMAGetLastTokens(
    T* out, T* in, const int* cu_seqlens, int batch_size, int hidden_size, cudaStream_t stream);

void invokeLLaMAExtractTargets(float* out,
                               float* in,
                               const int* target_ids,
                               const int* cu_seqlens,
                               int beam_width,
                               int batch_size,
                               int vocab_size,
                               int num_tokens,
                               cudaStream_t stream);
} // namespace fastertransformer
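Several of the new declarations above (invokeLLaMALogSoftmax, invokeLLaMAGatherTokens, invokeLLaMAGetLastTokens, invokeLLaMAExtractTargets) operate on a packed token buffer indexed through cu_seqlens. The actual call sites live elsewhere in the PR; the following is only a hypothetical composition of the declared signatures, with made-up buffer names, assuming the gather kernel picks each token's target-id score out of the [num_tokens, vocab_size] log-probability matrix.

#include <cuda_runtime.h>

#include "src/fastertransformer/kernels/llama_kernels.h"

// Hypothetical call-site sketch (names, buffer shapes, and the exact gather
// semantics are assumptions, not taken from the PR): score the target tokens
// of a packed batch.
//   d_logits      [num_tokens, vocab_size]  raw logits for every packed token
//   d_log_probs   [num_tokens, vocab_size]  log-softmax of d_logits
//   d_token_probs [num_tokens]              gathered score per target id (assumed shape)
void scoreTargets(float*       d_log_probs,
                  float*       d_token_probs,
                  const float* d_logits,
                  const int*   d_input_lengths,
                  const int*   d_target_ids,
                  const int*   d_cu_seqlens,
                  int          batch_size,
                  int          vocab_size,
                  int          num_tokens,
                  cudaStream_t stream)
{
    using namespace fastertransformer;

    // Normalize the logits over the vocabulary for every packed token.
    invokeLLaMALogSoftmax(d_log_probs, d_logits, num_tokens, vocab_size, stream);

    // Gather per-token scores; cu_seqlens locates each sequence's tokens in the
    // packed layout, input_lengths/target_ids select what to read.
    invokeLLaMAGatherTokens(d_token_probs, d_log_probs, d_input_lengths, d_target_ids,
                            d_cu_seqlens, batch_size, vocab_size, num_tokens, stream);
}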
109 changes: 53 additions & 56 deletions src/fastertransformer/kernels/unfused_attention_kernels.cu

Large diffs are not rendered by default.

18 changes: 9 additions & 9 deletions src/fastertransformer/kernels/unfused_attention_kernels.h
@@ -125,7 +125,7 @@ void invokeLLaMAAddFusedQKVBiasTranspose(T* q_buf,
                                         const int head_num,
                                         const int size_per_head,
                                         const int rotary_embedding_dim,
                                         const int start_pos,
                                         const int* start_pos,
                                         cudaStream_t stream);
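The change in this hunk is that start_pos turns from a scalar into a per-sequence array (const int* start_pos), so every sequence in the batch can resume rotary embedding and cache indexing from its own offset. Inside a kernel this typically reduces to replacing one scalar read with an indexed read; the snippet below illustrates that pattern only and is not the PR's kernel body.

// Illustrative only: the per-sequence offset is looked up with the batch index
// instead of being shared by the whole batch (names are hypothetical).
__global__ void perSequenceStartPosExample(const int* start_pos /* [batch_size] */)
{
    const int batch_idx = blockIdx.y;
    // Before: const int pos = start_pos + threadIdx.x;   // same offset for every sequence
    const int pos = start_pos[batch_idx] + threadIdx.x;    // per-sequence offset
    (void)pos;  // a real kernel would feed pos into the rotary embedding / cache index
}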

template<typename T>
@@ -209,24 +209,24 @@ void invokeLLaMASaveToCache(T* k_dst,
                            T* v_dst,
                            const T* k_src,
                            const T* v_src,
                            const int local_batch_size,
                            const int* context_lengths,
                            const int batch_size,
                            const int head_num,
                            const int size_per_head,
                            const int seq_len,
                            const int max_seq_len,
                            const int size_per_head,
                            const int local_head_num,
                            const int start_pos,
                            cudaStream_t stream);
template<typename T>
void invokeLLaMALoadFromCache(T* k_dst,
                              T* v_dst,
                              const T* k_src,
                              const T* v_src,
                              const int local_batch_size,
                              const int batch_size,
                              const int head_num,
                              const int size_per_head,
                              const int seq_len,
                              const int attn_len,
                              const int max_seq_len,
                              const int size_per_head,
                              const int local_head_num,
                              const int start_pos,
                              cudaStream_t stream);

template<typename T>
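invokeLLaMASaveToCache and invokeLLaMALoadFromCache above move the current step's keys and values into and out of a persistent cache sized by max_seq_len, and the signature changes replace the single start_pos with per-sequence context_lengths. The kernel below is a simplified sketch of the write half of that pattern, with an assumed [batch, head, max_seq_len, size_per_head] cache layout and hypothetical names; the PR's actual kernels live in unfused_attention_kernels.cu, which is not rendered here.

// Simplified illustration of a KV-cache save, not the PR's implementation.
// Assumed layouts:
//   k_src : [batch, head, seq_len,     size_per_head]  keys produced this step
//   k_dst : [batch, head, max_seq_len, size_per_head]  persistent cache
// Each sequence appends its new keys starting at its own context length.
__global__ void saveKeysToCacheExample(float*       k_dst,
                                       const float* k_src,
                                       const int*   context_lengths,  // [batch]
                                       int          head_num,
                                       int          size_per_head,
                                       int          seq_len,
                                       int          max_seq_len)
{
    const int b     = blockIdx.z;              // batch index
    const int h     = blockIdx.y;              // head index
    const int s     = blockIdx.x;              // position within this step
    const int start = context_lengths[b];      // per-sequence write offset

    const float* src = k_src + (((size_t)b * head_num + h) * seq_len + s) * size_per_head;
    float*       dst = k_dst + (((size_t)b * head_num + h) * max_seq_len + start + s) * size_per_head;

    for (int i = threadIdx.x; i < size_per_head; i += blockDim.x) {
        dst[i] = src[i];
    }
}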