This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 6c36f54

resubmit "Implement the YaRN RoPE scaling feature" (#147)
1 parent: 96dc559

File tree: 2 files changed (+109, -19 lines)


neural_speed/core/ne_layers.c

Lines changed: 95 additions & 19 deletions
@@ -3107,7 +3108,8 @@ struct ne_tensor* ne_soft_max_inplace(struct ne_context* ctx, struct ne_tensor*
 
 struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
                                int prompt_size, bool inplace, int n_keep, struct ne_tensor* cossin, int* n_padding,
-                               bool padding_left, float freq_base, float freq_scale) {
+                               bool padding_left, float freq_base, float freq_scale, int yarn_orig_ctx,
+                               float ext_factor, float attn_factor, float beta_fast, float beta_slow) {
   NE_ASSERT(n_past >= 0 || n_keep >= 0);
   NE_ASSERT(padding_left);
   bool is_node = false;
@@ -3147,7 +3148,9 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int
 
   ne_scratch_load(ctx);
 
-  float params[] = {freq_base, freq_scale};
+  /* TODO: what is the difference between passing parameters via b->data and via op_params, */
+  /* and do float and int values belong in separate buffers? */
+  float params[] = {freq_base, freq_scale, (float)yarn_orig_ctx, ext_factor, attn_factor, beta_fast, beta_slow};
   ne_set_op_params(result, &params, sizeof(params));
 
   result->op = NE_OP_ROPE;
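
All seven RoPE parameters now travel in the op_params blob as floats; the integer yarn_orig_ctx is cast to float going in and back to int in the kernel, which is lossless for any realistic context length (floats represent integers exactly up to 2^24). That also speaks to the TODO above: op_params appears to carry constant per-op attributes, while src1 (b->data) carries per-call values such as n_past. A minimal standalone sketch of the round-trip; fake_tensor, its buffer size, and this set_op_params are illustrative stand-ins, not the real ne_tensor layout:

#include <assert.h>
#include <string.h>

struct fake_tensor {
  float op_params[16]; /* illustrative stand-in for the real op_params blob */
};

static void set_op_params(struct fake_tensor* t, const void* p, size_t n) {
  assert(n <= sizeof(t->op_params));
  memcpy(t->op_params, p, n);
}

int main(void) {
  struct fake_tensor t;
  int yarn_orig_ctx = 4096;
  /* all seven values are stored as floats, the int cast on the way in ... */
  float params[] = {10000.0f, 1.0f, (float)yarn_orig_ctx, 1.0f, 1.0f, 32.0f, 1.0f};
  set_op_params(&t, params, sizeof(params));
  /* ... and recovered by the kernel with the inverse cast */
  assert((int)t.op_params[2] == 4096);
  return 0;
}
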
@@ -3161,19 +3164,36 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int
 
 struct ne_tensor* ne_rope(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
                           int prompt_size, float freq_base, float freq_scale) {
-  return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, NULL, true, freq_base, freq_scale);
+  return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, NULL, true, freq_base, freq_scale, 0,
+                      0.0f, 1.0f, 0.0f, 0.0f);
 }
 
 struct ne_tensor* ne_rope_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
                                   int prompt_size, float freq_base, float freq_scale) {
-  return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true, freq_base, freq_scale);
+  return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true, freq_base, freq_scale, 0,
+                      0.0f, 1.0f, 0.0f, 0.0f);
 }
 
 struct ne_tensor* ne_rope_shift_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_shift, int n_dims, int mode,
                                         int prompt_size, int n_keep, struct ne_tensor* cossin, float freq_base,
                                         float freq_scale) {
   return ne_rope_impl(ctx, a, n_shift, n_dims, mode, prompt_size, true, n_keep, cossin, NULL, true, freq_base,
-                      freq_scale);
+                      freq_scale, 0, 0.0f, 1.0f, 0.0f, 0.0f);
+}
+
+struct ne_tensor* ne_rope_custom_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
+                                         int prompt_size, float freq_base, float freq_scale, int yarn_orig_ctx,
+                                         float ext_factor, float attn_factor, float beta_fast, float beta_slow) {
+  return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true, freq_base, freq_scale,
+                      yarn_orig_ctx, ext_factor, attn_factor, beta_fast, beta_slow);
+}
+
+struct ne_tensor* ne_rope_custom_shift_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_shift, int n_dims,
+                                               int mode, int prompt_size, int n_keep, struct ne_tensor* cossin,
+                                               float freq_base, float freq_scale, int yarn_orig_ctx, float ext_factor,
+                                               float attn_factor, float beta_fast, float beta_slow) {
+  return ne_rope_impl(ctx, a, n_shift, n_dims, mode, prompt_size, true, n_keep, cossin, NULL, true, freq_base,
+                      freq_scale, yarn_orig_ctx, ext_factor, attn_factor, beta_fast, beta_slow);
 }
 
 // ne_rope_back
@@ -3211,14 +3231,14 @@ struct ne_tensor* ne_rope_back(struct ne_context* ctx, struct ne_tensor* a, int
 struct ne_tensor* ne_rope_with_padding(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
                                        int prompt_size, int* n_padding, float freq_base, float freq_scale) {
   return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, n_padding, true, freq_base,
-                      freq_scale);
+                      freq_scale, 0, 0.0f, 1.0f, 0.0f, 0.0f);
 }
 
 struct ne_tensor* ne_rope_with_padding_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims,
                                                int mode, int prompt_size, int* n_padding, float freq_base,
                                                float freq_scale) {
-  return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, n_padding, true, freq_base,
-                      freq_scale);
+  return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, n_padding, true, freq_base, freq_scale,
+                      0, 0.0f, 1.0f, 0.0f, 0.0f);
 }
 
 // ne_alibi
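
All of the legacy wrappers above forward the same neutral defaults, (0, 0.0f, 1.0f, 0.0f, 0.0f), for the five new arguments. A quick check that this leaves old call sites numerically unchanged: with ext_factor = 0, rope_yarn (added further down in this file) skips both the ramp mix and the magnitude correction, so

$$\theta \;=\; \theta_{\text{interp}} \;=\; \mathtt{freq\_scale}\cdot\theta_{\text{extrap}}, \qquad m_{\text{scale}} \;=\; \mathtt{attn\_factor} \;=\; 1,$$

which reproduces the former `theta = freq_scale * (float)p` rotation with unscaled cos/sin exactly.
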
@@ -8709,6 +8729,45 @@ static void ne_compute_forward_clamp(const struct ne_compute_params* params, con
   }
 }
 
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+  const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
+  return 1.0 - MIN(1.0, MAX(0.0, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor,
+                      float mscale, float* cos_theta, float* sin_theta) {
+  // Get n-d rotational scaling corrected for extrapolation
+  float theta_interp = freq_scale * theta_extrap;
+  float theta = theta_interp;
+  if (ext_factor != 0.0f) {
+    float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+    theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+    // Get n-d magnitude scaling corrected for interpolation
+    mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+  }
+  *cos_theta = cosf(theta) * mscale;
+  *sin_theta = sinf(theta) * mscale;
+}
+
+#ifndef NE_PI
+#define NE_PI (3.14159265358979323846)
+#endif
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
+  return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)NE_PI)) / (2 * logf(base));
+}
+
+void ggml_rope_yarn_corr_dims(int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow,
+                              float dims[2]) {
+  // start and end correction dims
+  dims[0] = MAX(0, floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base)));
+  dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base)));
+}
+
 // ne_compute_forward_rope
 #define NE_TENSOR_UNARY_OP_LOCALS \
   NE_TENSOR_LOCALS(int64_t, ne0, src0, ne); \
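
The correction dims bound YaRN's "NTK-by-parts" region: a rotary pair at index d completes about n_rot(d) = n_orig_ctx / (2 * pi * base^(2d/n_dims)) turns over the original context, and inverting that for d at n_rot = beta_fast and n_rot = beta_slow gives the start and end of the blend window. A standalone numeric sketch; the helper is duplicated here, and the inputs (head size 128, original context 4096, base 10000, beta_fast/beta_slow of 32/1) are common LLaMA-style defaults rather than values taken from this commit:

#include <math.h>
#include <stdio.h>

static const float PI_F = 3.14159265f;

/* duplicate of ggml_rope_yarn_corr_dim above, self-contained for the example */
static float corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
  return n_dims * logf(n_orig_ctx / (n_rot * 2 * PI_F)) / (2 * logf(base));
}

int main(void) {
  const int n_dims = 128, n_orig_ctx = 4096;
  const float base = 10000.0f, beta_fast = 32.0f, beta_slow = 1.0f;
  const float lo = floorf(corr_dim(n_dims, n_orig_ctx, beta_fast, base)); /* ~20 */
  const float hi = ceilf(corr_dim(n_dims, n_orig_ctx, beta_slow, base));  /* ~46 */
  /* pairs below lo keep extrapolation; pairs above hi are fully interpolated */
  printf("corr_dims = [%.0f, %.0f]\n", lo, hi);
  return 0;
}
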
@@ -8721,12 +8780,18 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
   if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) {
     return;
   }
+
   const int bs = src0->ne[3];
   NE_ASSERT(src1->type == NE_TYPE_I32);
   NE_ASSERT(ne_nelements(src1) == 5 + bs);  // 5 + bs params
 
   const float freq_base = ((float*)(dst->op_params))[0];
   const float freq_scale = 1 / ((float*)(dst->op_params))[1];
+  const int n_orig_ctx = (int)((float*)(dst->op_params))[2];
+  const float ext_factor = ((float*)(dst->op_params))[3];
+  const float attn_factor = ((float*)(dst->op_params))[4];
+  const float beta_fast = ((float*)(dst->op_params))[5];
+  const float beta_slow = ((float*)(dst->op_params))[6];
 
   const int64_t n_past = ((int32_t*)src1->data)[ROPE_NPAST_IDX];
   const int64_t n_dims = ((int32_t*)src1->data)[ROPE_NDIMS_IDX];
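
One convention worth flagging here: ne_rope_impl stores the caller's freq_scale verbatim in params[1], yet the kernel reads its reciprocal, so the freq_scale used by rope_yarn below is 1 over the stored value. Under that reading (an interpretation of the two lines above, worth checking against the model loaders), a 4x context extension would look like:

/* sketch of the reciprocal convention; the 4x factor is an assumed example */
float stored = 4.0f;            /* what the caller passes to ne_rope_impl */
float freq_scale = 1 / stored;  /* what the kernel computes: 0.25         */
/* theta_interp = freq_scale * theta_extrap then compresses positions 4x,
   i.e. token 8192 rotates like token 2048 did in the original context    */
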
@@ -8759,11 +8824,15 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
   int ir = 0;
 
   const float theta_scale = powf(freq_base, -2.0f / n_dims);
+  const float inv_ndims = -1.f / n_dims;
+  float corr_dims[2];
+  ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
 
   const bool skip = mode & 1;
   const bool is_neox = mode & 2;
   const bool is_glm = mode & 4;
   const bool is_shift = n_keep >= 0;
+  const bool use_yarn = ((mode & 0x8) != 0);
   NE_ASSERT(("RoPE shift not supported!", !is_shift));
 
   NE_ASSERT(ne3 == bs);
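
mode is decoded as a bitmask: bit 0 is the existing skip flag, bit 1 selects the GPT-NeoX layout, bit 2 the GLM variant, and this commit claims bit 3 (0x8) for YaRN via use_yarn. A hypothetical helper for composing the bits; the bit values come from the decode above, but the enum names are invented for illustration:

enum ne_rope_mode_bits {
  NE_ROPE_SKIP = 1 << 0, /* 'skip' flag              */
  NE_ROPE_NEOX = 1 << 1, /* GPT-NeoX style rotation  */
  NE_ROPE_GLM  = 1 << 2, /* ChatGLM block rotation   */
  NE_ROPE_YARN = 1 << 3, /* new in this commit: YaRN */
};

/* e.g. a NeoX-layout model with YaRN scaling enabled: */
int mode = NE_ROPE_NEOX | NE_ROPE_YARN; /* == 0xA */
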
@@ -8774,21 +8843,21 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
       if (ir++ < ir0) continue;
       if (ir > ir1) break;
 
-      float theta = freq_scale * (float)p;
+      float theta_base = (float)p;
 
       // only for glm when mode == 4
       if (is_glm) {
         const int64_t n_padding = ((int32_t*)src1->data)[ROPE_PARAMS_NUM + i3];
         // position ids
-        theta = MIN(MAX(p - n_padding, 0), prompt_size - 2 - n_padding);
+        theta_base = MIN(MAX(p - n_padding, 0), prompt_size - 2 - n_padding);
         float block_theta = MAX(p - (prompt_size - 2), 0);
         for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
-          const float cos_theta = cosf(theta);
-          const float sin_theta = sinf(theta);
+          const float cos_theta = cosf(theta_base);
+          const float sin_theta = sinf(theta_base);
           const float cos_block_theta = cosf(block_theta);
           const float sin_block_theta = sinf(block_theta);
 
-          theta *= theta_scale;
+          theta_base *= theta_scale;
           block_theta *= theta_scale;
 
           const float* const src = (float*)((char*)src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01 + i0 * nb00);
@@ -8805,11 +8874,12 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
           dst_data[n_dims / 2 * 3] = x2 * sin_block_theta + x3 * cos_block_theta;
         }
       } else if (!is_neox) {
+        // printf("theta_base = %ld, freq_scale %.4f, ne0 %d\n", p, freq_scale, ne0);
        for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-          const float cos_theta = cosf(theta);
-          const float sin_theta = sinf(theta);
+          float cos_theta, sin_theta;
+          rope_yarn(theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
-          theta *= theta_scale;  // theta = i2 * theta_scale^(i0/2)
+          theta_base *= theta_scale;
 
           const float* const src = (float*)((char*)src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01 + i0 * nb00);
           float* dst_data = (float*)((char*)dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
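
With corr_dims in hand, the non-NeoX path now obtains cos/sin from rope_yarn instead of taking them from the raw angle. Per dimension pair, the ramp decides how much of the extrapolated angle survives. A standalone sketch of that split, reusing the ~[20, 46] window from the earlier example with ext_factor = 1; the ramp is duplicated from rope_yarn_ramp above:

#include <stdio.h>

#define MAXF(a, b) ((a) > (b) ? (a) : (b))
#define MINF(a, b) ((a) < (b) ? (a) : (b))

/* duplicated from rope_yarn_ramp above */
static float ramp(float low, float high, int i0) {
  float y = (i0 / 2 - low) / MAXF(0.001f, high - low);
  return 1.0f - MINF(1.0f, MAXF(0.0f, y));
}

int main(void) {
  const float low = 20.0f, high = 46.0f; /* corr_dims from the earlier sketch */
  for (int i0 = 0; i0 < 128; i0 += 32) {
    /* ramp == 1: fast pair, theta stays extrapolated;
       ramp == 0: slow pair, theta is fully interpolated */
    printf("i0 = %3d  ramp = %.2f\n", i0, ramp(low, high, i0));
  }
  return 0;
}
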
@@ -8824,12 +8894,18 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
         // TODO: this is probably wrong, but I can't figure it out ..
         // ref:
         // https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+        theta_base = theta_base * freq_scale;
+
         for (int64_t ib = 0; ib < ne0 / n_dims; ++ib) {
           for (int64_t ic = 0; ic < n_dims; ic += 2) {
-            const float cos_theta = cosf(theta);
-            const float sin_theta = sinf(theta);
+            // simplified from `(ib * n_dims + ic) * inv_ndims`
+            float cur_rot = inv_ndims * ic - ib;
 
-            theta *= theta_scale;
+            float cos_theta, sin_theta;
+            rope_yarn(theta_base, freq_scale, corr_dims, (int)cur_rot, ext_factor, attn_factor, &cos_theta,
+                      &sin_theta);
+
+            theta_base *= theta_scale;
 
             const int64_t i0 = ib * n_dims + ic / 2;
 
neural_speed/core/ne_layers.h

Lines changed: 14 additions & 0 deletions
@@ -414,6 +414,20 @@ NE_API struct ne_tensor* ne_rope_shift_inplace(struct ne_context* ctx, struct ne
                                                int mode, int prompt_size, int n_keep, struct ne_tensor* cossin,
                                                float freq_base, float freq_scale);
 
+// in-place, returns view(a)
+NE_API struct ne_tensor* ne_rope_custom_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims,
+                                                int mode, int prompt_size, float freq_base, float freq_scale,
+                                                int yarn_orig_ctx, float ext_factor, float attn_factor, float beta_fast,
+                                                float beta_slow);
+
+// shift all tokens by a given p (n_shift)
+// Optionally pass a 1-D tensor of precomputed, interleaved cos/sin values of n_shift*scale^k for k \in [0, n_dims)
+NE_API struct ne_tensor* ne_rope_custom_shift_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_shift,
+                                                      int n_dims, int mode, int prompt_size, int n_keep,
+                                                      struct ne_tensor* cossin, float freq_base, float freq_scale,
+                                                      int yarn_orig_ctx, float ext_factor, float attn_factor,
+                                                      float beta_fast, float beta_slow);
+
 // rotary position embedding backward, i.e. compute dx from dy
 // a - dy
 NE_API struct ne_tensor* ne_rope_back(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode);
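
Putting the two new declarations to work, a sketch of how a graph builder might request YaRN-scaled RoPE; ctx, cur, and n_past are assumed to come from the surrounding model code, the mode bits follow the bitmask decoded by the kernel, and every numeric value is illustrative rather than a shipped model configuration:

#include "ne_layers.h" /* include path depends on the build setup */

static struct ne_tensor* apply_yarn_rope(struct ne_context* ctx, struct ne_tensor* cur, int n_past) {
  return ne_rope_custom_inplace(ctx, cur, n_past,
                                /*n_dims=*/128,
                                /*mode=*/2 | 8,         /* NeoX layout + YaRN bit          */
                                /*prompt_size=*/0,
                                /*freq_base=*/10000.0f,
                                /*freq_scale=*/4.0f,    /* see the reciprocal note above   */
                                /*yarn_orig_ctx=*/4096, /* pre-extension training context  */
                                /*ext_factor=*/1.0f,    /* enable the YaRN ramp            */
                                /*attn_factor=*/1.0f,
                                /*beta_fast=*/32.0f,
                                /*beta_slow=*/1.0f);
}

Setting ext_factor to 0.0f here would fall back to plain position interpolation, matching the behavior of the legacy wrappers.
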
