From 58444931dc1f4a885b6e0e6ddb54dc4656a406bb Mon Sep 17 00:00:00 2001
From: slaren
Date: Tue, 21 Nov 2023 17:48:00 +0100
Subject: [PATCH 1/4] ggml-cuda : support stablelm rope

---
 ggml-cuda.cu | 37 ++++++++++++++++++++++---------------
 1 file changed, 22 insertions(+), 15 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 50e03de500747..d7b2c3394f590 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -4609,8 +4609,8 @@ static __global__ void rope(
 
 template<typename T, bool has_pos>
 static __global__ void rope_neox(
-    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims
+    const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
 ) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
@@ -4619,23 +4619,25 @@ static __global__ void rope_neox(
     }
 
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int i = row*ncols + col/2;
+    const int ib = col / n_dims;
+    const int ic = col % n_dims;
+
+    const int i = row*ncols + ib*n_dims + ic/2;
     const int i2 = row/p_delta_rows;
 
-    // simplified from `(ib * ncols + col) * (-1 / ncols)`, where ib is assumed to be zero
-    const float cur_rot = -float(col)/ncols;
+    float cur_rot = inv_ndims * ic - ib;
 
     const int p = has_pos ? pos[i2] : 0;
-    const float theta_base = p*powf(freq_base, cur_rot);
+    const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
 
     float cos_theta, sin_theta;
     rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
     const float x0 = x[i + 0];
-    const float x1 = x[i + ncols/2];
+    const float x1 = x[i + n_dims/2];
 
-    dst[i + 0]       = x0*cos_theta - x1*sin_theta;
-    dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
+    dst[i + 0]        = x0*cos_theta - x1*sin_theta;
+    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
 static __global__ void rope_glm_f32(
@@ -5738,20 +5740,26 @@ static void rope_cuda(
 
 template<typename T>
 static void rope_neox_cuda(
-    const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
 ) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float inv_ndims = -1.0f / n_dims;
+
     if (pos == nullptr) {
         rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims,
+            theta_scale, inv_ndims
        );
    } else {
         rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims,
+            theta_scale, inv_ndims
         );
     }
 }
@@ -6706,15 +6714,14 @@ inline void ggml_cuda_op_rope(
         GGML_ASSERT(false);
         rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
     } else if (is_neox) {
-        GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
         if (src0->type == GGML_TYPE_F32) {
             rope_neox_cuda(
-                (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
                 attn_factor, corr_dims, main_stream
             );
         } else if (src0->type == GGML_TYPE_F16) {
             rope_neox_cuda(
-                (const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                (const half *)src0_dd, (half *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
                 attn_factor, corr_dims, main_stream
             );
         } else {

From 4a3469f20ef4341711e6596894fe472d9ec233d5 Mon Sep 17 00:00:00 2001
From: slaren
Date: Tue, 21 Nov 2023 17:53:54 +0100
Subject: [PATCH 2/4] remove unused freq_base kernel parameter

---
 ggml-cuda.cu | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index d7b2c3394f590..c12c30f73f69b 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -4609,7 +4609,7 @@ static __global__ void rope(
 
 template<typename T, bool has_pos>
 static __global__ void rope_neox(
-    const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
+    const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
     float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
 ) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
@@ -5753,12 +5753,12 @@ static void rope_neox_cuda(
 
     if (pos == nullptr) {
         rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims,
+            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
             theta_scale, inv_ndims
         );
     } else {
         rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims,
+            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
             theta_scale, inv_ndims
         );
     }
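The two CUDA commits above change how rope_neox derives its rotation angle. Each row of ncols elements is treated as a sequence of n_dims-wide blocks: ib indexes the block, ic is the offset inside it, and cur_rot = inv_ndims*ic - ib feeds the YaRN correction, while the angle itself is built from the host-precomputed theta_scale = freq_base^(-2/n_dims). The sketch below is illustrative only (not code from the patch): it mirrors those formulas on the CPU and also prints the pre-patch angle freq_base^(-col/ncols). With n_dims == ncols and freq_scale == 1 the two forms agree, which is why the second commit can drop freq_base from the kernel arguments once theta_scale is computed on the host.

/* CPU sketch of the per-column quantities computed by the patched rope_neox
 * kernel; names mirror the kernel parameters, values are example inputs. */
#include <math.h>
#include <stdio.h>

int main(void) {
    const int   ncols      = 8;        /* row width (ne00)                      */
    const int   n_dims     = 8;        /* rotated width; try 4 for partial RoPE */
    const float freq_base  = 10000.0f;
    const float freq_scale = 1.0f;
    const int   p          = 3;        /* token position                        */

    const float theta_scale = powf(freq_base, -2.0f/n_dims);  /* precomputed on the host */
    const float inv_ndims   = -1.0f/n_dims;

    for (int col = 0; col < ncols; col += 2) {   /* the kernel visits even columns */
        const int   ib        = col / n_dims;    /* block index      */
        const int   ic        = col % n_dims;    /* offset in block  */
        const float cur_rot   = inv_ndims*ic - ib;
        const float theta_new = p*freq_scale*powf(theta_scale, col/2.0f);
        const float theta_old = p*powf(freq_base, -(float)col/ncols);   /* pre-patch formula */
        printf("col=%d ib=%d ic=%d cur_rot=%6.3f theta_new=%10.6f theta_old=%10.6f\n",
               col, ib, ic, cur_rot, theta_new, theta_old);
    }
    return 0;
}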
CUDA yet"); if (src0->type == GGML_TYPE_F32) { rope_neox_cuda( - (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor, + (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, attn_factor, corr_dims, main_stream ); } else if (src0->type == GGML_TYPE_F16) { rope_neox_cuda( - (const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor, + (const half *)src0_dd, (half *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, attn_factor, corr_dims, main_stream ); } else { From 4a3469f20ef4341711e6596894fe472d9ec233d5 Mon Sep 17 00:00:00 2001 From: slaren Date: Tue, 21 Nov 2023 17:53:54 +0100 Subject: [PATCH 2/4] remove unused freq_base kernel parameter --- ggml-cuda.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index d7b2c3394f590..c12c30f73f69b 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4609,7 +4609,7 @@ static __global__ void rope( template static __global__ void rope_neox( - const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, + const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims ) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); @@ -5753,12 +5753,12 @@ static void rope_neox_cuda( if (pos == nullptr) { rope_neox<<>>( - x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, + x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, theta_scale, inv_ndims ); } else { rope_neox<<>>( - x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, + x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, theta_scale, inv_ndims ); } From 84adb5412cd60c029fc3dc6695462ad9d86c1db3 Mon Sep 17 00:00:00 2001 From: slaren Date: Wed, 22 Nov 2023 00:45:43 +0100 Subject: [PATCH 3/4] add n_dims parameter to llm_build_k_shift, default to n_rot via overload --- llama.cpp | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index c2ad048699472..370ca6b7a6b5e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3460,6 +3460,7 @@ static void llm_build_k_shift( llm_rope_type type, int64_t n_ctx, int64_t n_rot, + int n_dims, float freq_base, float freq_scale, const llm_build_cb & cb) { @@ -3495,13 +3496,28 @@ static void llm_build_k_shift( ggml_element_size(kv.k)*n_embd_head, ggml_element_size(kv.k)*n_embd_gqa, ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il), - K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + K_shift, n_dims, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(tmp, "K_shifted", il); ggml_build_forward_expand(graph, tmp); } } +static void llm_build_k_shift( + struct ggml_context * ctx, + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache & kv, + struct ggml_cgraph * graph, + llm_rope_type type, + int64_t n_ctx, + int64_t n_rot, + float freq_base, + float freq_scale, + const llm_build_cb & cb) { + llm_build_k_shift(ctx, hparams, cparams, kv, graph, type, n_ctx, n_rot, n_rot, freq_base, freq_scale, cb); +} + static void llm_build_kv_store( struct ggml_context * ctx, const 
From 5ed3e1a8f27212bc09c174d11ce5c9c96cc48ec7 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 24 Nov 2023 18:58:03 +0200
Subject: [PATCH 4/4] llama : fix llm_build_k_shift args

---
 llama.cpp | 24 ++++--------------------
 1 file changed, 4 insertions(+), 20 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 370ca6b7a6b5e..85c0ee0c1c6e8 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -3459,8 +3459,7 @@ static void llm_build_k_shift(
          struct ggml_cgraph * graph,
             llm_rope_type   type,
                   int64_t   n_ctx,
-                  int64_t   n_rot,
-                      int   n_dims,
+                      int   n_rot,
                     float   freq_base,
                     float   freq_scale,
        const llm_build_cb & cb) {
@@ -3492,32 +3491,17 @@ static void llm_build_k_shift(
             // we rotate only the first n_rot dimensions
             ggml_rope_custom_inplace(ctx,
                     ggml_view_3d(ctx, kv.k,
-                        n_rot, n_head_kv, n_ctx,
+                        n_embd_head, n_head_kv, n_ctx,
                         ggml_element_size(kv.k)*n_embd_head,
                         ggml_element_size(kv.k)*n_embd_gqa,
                         ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il),
-                    K_shift, n_dims, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow);
         cb(tmp, "K_shifted", il);
         ggml_build_forward_expand(graph, tmp);
     }
 }
 
-static void llm_build_k_shift(
-      struct ggml_context * ctx,
-      const llama_hparams & hparams,
-      const llama_cparams & cparams,
-     const llama_kv_cache & kv,
-         struct ggml_cgraph * graph,
-            llm_rope_type   type,
-                  int64_t   n_ctx,
-                  int64_t   n_rot,
-                    float   freq_base,
-                    float   freq_scale,
-       const llm_build_cb & cb) {
-    llm_build_k_shift(ctx, hparams, cparams, kv, graph, type, n_ctx, n_rot, n_rot, freq_base, freq_scale, cb);
-}
-
 static void llm_build_kv_store(
       struct ggml_context * ctx,
       const llama_hparams & hparams,
@@ -4814,7 +4798,7 @@ struct llm_build_context {
 
         // shift the entire K-cache if needed
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, hparams.n_rot, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, hparams.n_rot, freq_base, freq_scale, cb);
         }
 
         for (int il = 0; il < n_layer; ++il) {