@@ -3107,7 +3107,8 @@ struct ne_tensor* ne_soft_max_inplace(struct ne_context* ctx, struct ne_tensor*

 struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
                                int prompt_size, bool inplace, int n_keep, struct ne_tensor* cossin, int* n_padding,
-                               bool padding_left, float freq_base, float freq_scale) {
+                               bool padding_left, float freq_base, float freq_scale, int yarn_orig_ctx,
+                               float ext_factor, float attn_factor, float beta_fast, float beta_slow) {
   NE_ASSERT(n_past >= 0 || n_keep >= 0);
   NE_ASSERT(padding_left);
   bool is_node = false;
@@ -3147,7 +3148,9 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int

   ne_scratch_load(ctx);

-  float params[] = {freq_base, freq_scale};
+  /* What is the difference between setting parameters in b->data and in op_params? */
+  /* Are the float and int parameters stored in different buffers? */
+  float params[] = {freq_base, freq_scale, (float)yarn_orig_ctx, ext_factor, attn_factor, beta_fast, beta_slow};
   ne_set_op_params(result, &params, sizeof(params));

   result->op = NE_OP_ROPE;
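
The questions in the comments above come down to where parameters live: op_params is a small fixed-size byte buffer copied into the tensor itself, while the integer parameters (n_past, n_dims, mode, ...) travel in a separate I32 tensor's data. A minimal standalone sketch (the toy_* names are hypothetical stand-ins, not this library's API) of why smuggling an int through the float array is safe here: integers of magnitude up to 2^24 convert to float and back exactly.

#include <assert.h>
#include <string.h>

struct toy_tensor {
  char op_params[32]; /* mirrors the fixed op_params buffer assumed above */
};

static void toy_set_op_params(struct toy_tensor* t, const void* p, size_t n) {
  memcpy(t->op_params, p, n);
}

int main(void) {
  struct toy_tensor t;
  const int yarn_orig_ctx = 4096;
  float params[] = {10000.0f, 1.0f, (float)yarn_orig_ctx, 0.0f, 1.0f, 32.0f, 1.0f};
  toy_set_op_params(&t, &params, sizeof(params));

  /* read the int back the same way the rope kernel below does */
  float slot2;
  memcpy(&slot2, t.op_params + 2 * sizeof(float), sizeof(float));
  assert((int)slot2 == 4096); /* exact round trip: 4096 < 2^24 */
  return 0;
}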
@@ -3161,19 +3164,36 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int

 struct ne_tensor* ne_rope(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
                           int prompt_size, float freq_base, float freq_scale) {
-  return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, NULL, true, freq_base, freq_scale);
+  return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, NULL, true, freq_base, freq_scale, 0,
+                      0.0f, 1.0f, 0.0f, 0.0f);
 }

 struct ne_tensor* ne_rope_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
                                   int prompt_size, float freq_base, float freq_scale) {
-  return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true, freq_base, freq_scale);
+  return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true, freq_base, freq_scale, 0,
+                      0.0f, 1.0f, 0.0f, 0.0f);
 }

 struct ne_tensor* ne_rope_shift_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_shift, int n_dims, int mode,
                                         int prompt_size, int n_keep, struct ne_tensor* cossin, float freq_base,
                                         float freq_scale) {
   return ne_rope_impl(ctx, a, n_shift, n_dims, mode, prompt_size, true, n_keep, cossin, NULL, true, freq_base,
-                      freq_scale);
+                      freq_scale, 0, 0.0f, 1.0f, 0.0f, 0.0f);
+}
+
+struct ne_tensor* ne_rope_custom_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
+                                         int prompt_size, float freq_base, float freq_scale, int yarn_orig_ctx,
+                                         float ext_factor, float attn_factor, float beta_fast, float beta_slow) {
+  return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true, freq_base, freq_scale,
+                      yarn_orig_ctx, ext_factor, attn_factor, beta_fast, beta_slow);
+}
+
+struct ne_tensor* ne_rope_custom_shift_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_shift, int n_dims,
+                                               int mode, int prompt_size, int n_keep, struct ne_tensor* cossin,
+                                               float freq_base, float freq_scale, int yarn_orig_ctx, float ext_factor,
+                                               float attn_factor, float beta_fast, float beta_slow) {
+  return ne_rope_impl(ctx, a, n_shift, n_dims, mode, prompt_size, true, n_keep, cossin, NULL, true, freq_base,
+                      freq_scale, yarn_orig_ctx, ext_factor, attn_factor, beta_fast, beta_slow);
 }

 // ne_rope_back
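
For orientation, a hedged sketch of how a caller might use the new entry point; ctx and Qcur stand in for the surrounding model code and every value here is illustrative, not taken from this commit. Bit 0x8 of mode is the flag the kernel below tests to set use_yarn.

/* Hypothetical call site: YaRN-scaled RoPE applied in place to a query tensor. */
struct ne_tensor* q_rot = ne_rope_custom_inplace(
    ctx, Qcur, n_past, /*n_dims=*/128, /*mode=*/0x8, /*prompt_size=*/0,
    /*freq_base=*/10000.0f, /*freq_scale=*/1.0f, /*yarn_orig_ctx=*/4096,
    /*ext_factor=*/1.0f, /*attn_factor=*/1.0f, /*beta_fast=*/32.0f, /*beta_slow=*/1.0f);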
@@ -3211,14 +3231,14 @@ struct ne_tensor* ne_rope_back(struct ne_context* ctx, struct ne_tensor* a, int
 struct ne_tensor* ne_rope_with_padding(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
                                        int prompt_size, int* n_padding, float freq_base, float freq_scale) {
   return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, n_padding, true, freq_base,
-                      freq_scale);
+                      freq_scale, 0, 0.0f, 1.0f, 0.0f, 0.0f);
 }

 struct ne_tensor* ne_rope_with_padding_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims,
                                                int mode, int prompt_size, int* n_padding, float freq_base,
                                                float freq_scale) {
-  return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, n_padding, true, freq_base,
-                      freq_scale);
+  return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, n_padding, true, freq_base, freq_scale,
+                      0, 0.0f, 1.0f, 0.0f, 0.0f);
 }

 // ne_alibi
@@ -8709,6 +8729,45 @@ static void ne_compute_forward_clamp(const struct ne_compute_params* params, con
 }
 }

+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+  const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
+  return 1.0 - MIN(1.0, MAX(0.0, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor,
+                      float mscale, float* cos_theta, float* sin_theta) {
+  // Get n-d rotational scaling corrected for extrapolation
+  float theta_interp = freq_scale * theta_extrap;
+  float theta = theta_interp;
+  if (ext_factor != 0.0f) {
+    float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+    theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+    // Get n-d magnitude scaling corrected for interpolation
+    mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+  }
+  *cos_theta = cosf(theta) * mscale;
+  *sin_theta = sinf(theta) * mscale;
+}
+
+#ifndef NE_PI
+#define NE_PI (3.14159265358979323846)
+#endif
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
+  return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)NE_PI)) / (2 * logf(base));
+}
+
+void ggml_rope_yarn_corr_dims(int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow,
+                              float dims[2]) {
+  // start and end correction dims
+  dims[0] = MAX(0, floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base)));
+  dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base)));
+}
+
 // ne_compute_forward_rope
 #define NE_TENSOR_UNARY_OP_LOCALS \
   NE_TENSOR_LOCALS(int64_t, ne0, src0, ne); \
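
A worked example makes the correction-dim formula concrete. With LLaMA-like values assumed here (n_dims = 128, n_orig_ctx = 4096, freq_base = 10000, beta_fast = 32, beta_slow = 1): corr_dim(32) = 128 * ln(4096 / (32 * 2pi)) / (2 * ln(10000)) ~= 20.9 and corr_dim(1) ~= 45.0, so corr_dims comes out to {20, 46}. Dimensions below index 20 keep extrapolated frequencies, those above 46 are fully interpolated, and rope_yarn_ramp blends the band in between. A minimal check, compiled against the file above (possible because ggml_rope_yarn_corr_dims is non-static):

#include <stdio.h>

void ggml_rope_yarn_corr_dims(int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow,
                              float dims[2]);

int main(void) {
  float dims[2];
  ggml_rope_yarn_corr_dims(128, 4096, 10000.0f, 32.0f, 1.0f, dims);
  printf("corr_dims = {%.0f, %.0f}\n", dims[0], dims[1]); /* expected: {20, 46} */
  return 0;
}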
@@ -8721,12 +8780,18 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
   if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) {
     return;
   }
+
   const int bs = src0->ne[3];
   NE_ASSERT(src1->type == NE_TYPE_I32);
   NE_ASSERT(ne_nelements(src1) == 5 + bs);  // 5 + bs params

   const float freq_base = ((float*)(dst->op_params))[0];
   const float freq_scale = 1 / ((float*)(dst->op_params))[1];
+  const int n_orig_ctx = (int)((float*)(dst->op_params))[2];
+  const float ext_factor = ((float*)(dst->op_params))[3];
+  const float attn_factor = ((float*)(dst->op_params))[4];
+  const float beta_fast = ((float*)(dst->op_params))[5];
+  const float beta_slow = ((float*)(dst->op_params))[6];

   const int64_t n_past = ((int32_t*)src1->data)[ROPE_NPAST_IDX];
   const int64_t n_dims = ((int32_t*)src1->data)[ROPE_NDIMS_IDX];
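
For reference, this read is the mirror of the write in ne_rope_impl earlier in the diff; the one subtlety is that the kernel consumes the reciprocal of the stored freq_scale. A summary of the layout as comments:

/* op_params layout (float[7]) as written by ne_rope_impl:
 *   [0] freq_base        [1] freq_scale (the kernel reads 1/value)
 *   [2] yarn_orig_ctx    (stored as float, cast back to int here)
 *   [3] ext_factor       [4] attn_factor
 *   [5] beta_fast        [6] beta_slow */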
@@ -8759,11 +8824,15 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
   int ir = 0;

   const float theta_scale = powf(freq_base, -2.0f / n_dims);
+  const float inv_ndims = -1.f / n_dims;
+  float corr_dims[2];
+  ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);

   const bool skip = mode & 1;
   const bool is_neox = mode & 2;
   const bool is_glm = mode & 4;
   const bool is_shift = n_keep >= 0;
+  const bool use_yarn = ((mode & 0x8) != 0);
   NE_ASSERT(("RoPE shift not supported!", !is_shift));

   NE_ASSERT(ne3 == bs);
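
The mode flags tested here form a small bitmask; hypothetical named constants (not present in this codebase, shown only for readability) spell out the layout, with the new 0x8 bit selecting YaRN:

/* Hypothetical names for the mode bits tested above; the code itself uses raw ints. */
enum ne_rope_mode_bits {
  NE_ROPE_MODE_SKIP = 1 << 0, /* mode & 1 */
  NE_ROPE_MODE_NEOX = 1 << 1, /* mode & 2 */
  NE_ROPE_MODE_GLM  = 1 << 2, /* mode & 4 */
  NE_ROPE_MODE_YARN = 1 << 3, /* mode & 0x8 -> use_yarn */
};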
@@ -8774,21 +8843,21 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
       if (ir++ < ir0) continue;
       if (ir > ir1) break;

-      float theta = freq_scale * (float)p;
+      float theta_base = (float)p;

       // only for glm when mode == 4
       if (is_glm) {
         const int64_t n_padding = ((int32_t*)src1->data)[ROPE_PARAMS_NUM + i3];
         // position ids
-        theta = MIN(MAX(p - n_padding, 0), prompt_size - 2 - n_padding);
+        theta_base = MIN(MAX(p - n_padding, 0), prompt_size - 2 - n_padding);
         float block_theta = MAX(p - (prompt_size - 2), 0);
         for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
-          const float cos_theta = cosf(theta);
-          const float sin_theta = sinf(theta);
+          const float cos_theta = cosf(theta_base);
+          const float sin_theta = sinf(theta_base);
           const float cos_block_theta = cosf(block_theta);
           const float sin_block_theta = sinf(block_theta);

-          theta *= theta_scale;
+          theta_base *= theta_scale;
           block_theta *= theta_scale;

           const float* const src = (float*)((char*)src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01 + i0 * nb00);
@@ -8805,11 +8874,12 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
           dst_data[n_dims / 2 * 3] = x2 * sin_block_theta + x3 * cos_block_theta;
         }
       } else if (!is_neox) {
+        // printf("theta_base = %ld, freq_scale %.4f, ne0 %d\n", p, freq_scale, ne0);
         for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-          const float cos_theta = cosf(theta);
-          const float sin_theta = sinf(theta);
+          float cos_theta, sin_theta;
+          rope_yarn(theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

-          theta *= theta_scale;  // theta = i2 * theta_scale^(i0/2)
+          theta_base *= theta_scale;

           const float* const src = (float*)((char*)src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01 + i0 * nb00);
           float* dst_data = (float*)((char*)dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
@@ -8824,12 +8894,18 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
         // TODO: this is probably wrong, but I can't figure it out ..
         // ref:
         // https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+        theta_base = theta_base * freq_scale;
+
         for (int64_t ib = 0; ib < ne0 / n_dims; ++ib) {
           for (int64_t ic = 0; ic < n_dims; ic += 2) {
-            const float cos_theta = cosf(theta);
-            const float sin_theta = sinf(theta);
+            // simplified from `(ib * n_dims + ic) * inv_ndims`
+            float cur_rot = inv_ndims * ic - ib;

-            theta *= theta_scale;
+            float cos_theta, sin_theta;
+            rope_yarn(theta_base, freq_scale, corr_dims, (int)cur_rot, ext_factor, attn_factor, &cos_theta,
+                      &sin_theta);
+
+            theta_base *= theta_scale;

             const int64_t i0 = ib * n_dims + ic / 2;
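
The `simplified from` comment in the NeoX branch can be sanity-checked numerically; a small standalone program (assuming nothing beyond the arithmetic in the diff) confirms that `inv_ndims * ic - ib` equals `(ib * n_dims + ic) * inv_ndims` when `inv_ndims = -1.f / n_dims`:

#include <assert.h>
#include <math.h>

int main(void) {
  const int n_dims = 128;
  const float inv_ndims = -1.f / n_dims;
  for (int ib = 0; ib < 4; ++ib) {
    for (int ic = 0; ic < n_dims; ic += 2) {
      const float simplified = inv_ndims * ic - ib;
      const float original = (ib * n_dims + ic) * inv_ndims;
      assert(fabsf(simplified - original) < 1e-5f); /* identical up to rounding */
    }
  }
  return 0;
}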