Commit

fix YaRN ramp, make mscale conditional, add --yarn-orig-ctx (ggerganov#2…)
jquesnelle authored Oct 20, 2023
1 parent 9ae10b3 commit 14cf93b
Showing 7 changed files with 33 additions and 24 deletions.
common/common.cpp (8 changes: 8 additions & 0 deletions)
@@ -220,6 +220,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.rope_freq_scale = 1.0f/std::stof(argv[i]);
+        } else if (arg == "--yarn-orig-ctx") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_orig_ctx = std::stoi(argv[i]);
         } else if (arg == "--yarn-ext-factor") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -737,6 +743,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --rope-scale N        RoPE context scaling factor, expands context by a factor of N\n");
     printf("  --rope-freq-base N    RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n");
     printf("  --rope-freq-scale N   RoPE frequency scaling factor, expands context by a factor of 1/N\n");
+    printf("  --yarn-orig-ctx N     YaRN: original context size of model (default: 0 = model training context size)\n");
     printf("  --yarn-ext-factor N   YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n");
     printf("  --yarn-attn-factor N  YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
     printf("  --yarn-beta-slow N    YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
@@ -861,6 +868,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.yarn_attn_factor = params.yarn_attn_factor;
     cparams.yarn_beta_fast   = params.yarn_beta_fast;
     cparams.yarn_beta_slow   = params.yarn_beta_slow;
+    cparams.yarn_orig_ctx    = params.yarn_orig_ctx;
 
     return cparams;
 }
common/common.h (5 changes: 3 additions & 2 deletions)
@@ -57,8 +57,9 @@ struct gpt_params {
     float   rope_freq_scale   = 0.0f;  // RoPE frequency scaling factor
     float   yarn_ext_factor   = NAN;   // YaRN extrapolation mix factor
     float   yarn_attn_factor  = 1.0f;  // YaRN magnitude scaling factor
-    float   yarn_beta_fast    = 32.0f; // YaRN low correction dim
-    float   yarn_beta_slow    = 1.0f;  // YaRN high correction dim
+    float   yarn_beta_fast    = 32.0f;// YaRN low correction dim
+    float   yarn_beta_slow    = 1.0f; // YaRN high correction dim
+    int32_t yarn_orig_ctx     = 0;    // YaRN original context length
     int8_t  rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED;
 
     // // sampling parameters
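The new yarn_orig_ctx defaults to 0, which is treated as "use the model's training context" when the context is created (see the llama_new_context_with_model hunk further down). A minimal sketch of that fallback, with a hypothetical helper name that is not part of the patch:

// Sketch only: resolve_yarn_orig_ctx is a hypothetical helper illustrating the
// sentinel handling added to llama_new_context_with_model below.
#include <stdint.h>

static uint32_t resolve_yarn_orig_ctx(int32_t yarn_orig_ctx, uint32_t n_ctx_train) {
    // 0 means "not set on the command line": fall back to the model's training context
    return yarn_orig_ctx == 0 ? n_ctx_train : (uint32_t) yarn_orig_ctx;
}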
ggml-cuda.cu (7 changes: 3 additions & 4 deletions)
@@ -4406,7 +4406,7 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
 }
 
 static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
-    const float y = (i0 / 2 - low) / min(0.001f, high - low);
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
     return 1.0f - min(1.0f, max(0.0f, y));
 }

@@ -4426,11 +4426,10 @@ static __device__ void rope_yarn(
     if (ext_factor != 0.0f) {
         float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
         theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-    }
 
-    // Get n-d magnitude scaling corrected for interpolation
-    if (freq_scale < 1.0f)
+        // Get n-d magnitude scaling corrected for interpolation
         mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
     *cos_theta = cosf(theta) * mscale;
     *sin_theta = sinf(theta) * mscale;
 }
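The ramp fix above is the same in all three backends: high - low is the width of the corrected-dimension band, and the divisor is only meant to be clamped up to 0.001 so the division stays safe when high == low. With the old min(0.001f, high - low) the divisor was instead capped at 0.001, which made y explode and collapsed the ramp. A standalone reference sketch of the corrected behaviour (the name yarn_ramp is hypothetical, not part of the patch):

// Sketch of the corrected ramp; mirrors rope_yarn_ramp above.
static float yarn_ramp(const float low, const float high, const int i0) {
    float denom = high - low;
    if (denom < 0.001f) {
        denom = 0.001f;               // max(0.001f, high - low): only guards a zero-width band
    }
    float y = (i0 / 2 - low) / denom; // i0 / 2 keeps the integer division of the original
    if (y < 0.0f) y = 0.0f;
    if (y > 1.0f) y = 1.0f;
    return 1.0f - y;                  // 1 below `low` (extrapolation mix), 0 above `high` (pure interpolation)
}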
ggml-metal.metal (7 changes: 3 additions & 4 deletions)
@@ -880,7 +880,7 @@ kernel void kernel_alibi_f32(
 }
 
 static float rope_yarn_ramp(const float low, const float high, const int i0) {
-    const float y = (i0 / 2 - low) / min(0.001f, high - low);
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
     return 1.0f - min(1.0f, max(0.0f, y));
 }

@@ -896,11 +896,10 @@ static void rope_yarn(
     if (ext_factor != 0.0f) {
         ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
         theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-    }
 
-    // Get n-d magnitude scaling corrected for interpolation
-    if (freq_scale < 1.0f)
+        // Get n-d magnitude scaling corrected for interpolation
         mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
     *cos_theta = cosf(theta) * mscale;
     *sin_theta = sinf(theta) * mscale;
 }
ggml.c (7 changes: 3 additions & 4 deletions)
@@ -13345,7 +13345,7 @@ static void ggml_compute_forward_clamp(
 // ggml_compute_forward_rope
 
 static float rope_yarn_ramp(const float low, const float high, const int i0) {
-    const float y = (i0 / 2 - low) / MIN(0.001f, high - low);
+    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
     return 1 - MIN(1, MAX(0, y));
 }

@@ -13361,11 +13361,10 @@ static void rope_yarn(
     if (ext_factor != 0.0f) {
         float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
         theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-    }
 
-    // Get n-d magnitude scaling corrected for interpolation
-    if (freq_scale < 1.0f)
+        // Get n-d magnitude scaling corrected for interpolation
         mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
     *cos_theta = cosf(theta) * mscale;
     *sin_theta = sinf(theta) * mscale;
 }
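The second hunk in each backend is the "make mscale conditional" part of the commit: the magnitude correction 1.0f + 0.1f * logf(1.0f / freq_scale) is now applied only when ext_factor != 0.0f, i.e. only when YaRN mixing is actually active, replacing the previous freq_scale < 1.0f guard. A worked number, as a comment-only sketch (values chosen for illustration, not taken from the patch):

/* Example: a 4x YaRN context extension, freq_scale = 0.25, ext_factor != 0:
 *   mscale *= 1.0f + 0.1f * logf(1.0f / 0.25f)
 *           = 1.0f + 0.1f * logf(4.0f)     // logf(4.0f) ≈ 1.386
 *           ≈ 1.139
 * With ext_factor == 0 (plain RoPE or linear scaling) the branch is skipped and
 * mscale is left at its incoming value. */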
llama.cpp (10 changes: 6 additions & 4 deletions)
@@ -1113,6 +1113,7 @@ struct llama_cparams {
     float rope_freq_base;
     float rope_freq_scale;
 
+    uint32_t n_yarn_orig_ctx;
     // These hyperparameters are not exposed in GGUF, because all
     // existing YaRN models use the same values for them.
     float yarn_ext_factor;
@@ -3028,7 +3029,7 @@ static struct ggml_cgraph * llm_build_llama(
     const int32_t n_embd      = hparams.n_embd;
     const int32_t n_layer     = hparams.n_layer;
     const int32_t n_ctx       = cparams.n_ctx;
-    const int32_t n_orig_ctx  = hparams.n_yarn_orig_ctx;
+    const int32_t n_orig_ctx  = cparams.n_yarn_orig_ctx;
     const int32_t n_head      = hparams.n_head;
     const int32_t n_head_kv   = hparams.n_head_kv;
     const int32_t n_embd_head = hparams.n_embd_head();
@@ -3430,7 +3431,7 @@ static struct ggml_cgraph * llm_build_baichaun(
     const int32_t n_embd      = hparams.n_embd;
     const int32_t n_layer     = hparams.n_layer;
     const int32_t n_ctx       = cparams.n_ctx;
-    const int32_t n_orig_ctx  = hparams.n_yarn_orig_ctx;
+    const int32_t n_orig_ctx  = cparams.n_yarn_orig_ctx;
     const int32_t n_head      = hparams.n_head;
     const int32_t n_head_kv   = hparams.n_head_kv;
     const int32_t n_embd_head = hparams.n_embd_head();
@@ -4194,7 +4195,7 @@ static struct ggml_cgraph * llm_build_falcon(
     const int32_t n_embd      = hparams.n_embd;
     const int32_t n_layer     = hparams.n_layer;
     const int32_t n_ctx       = cparams.n_ctx;
-    const int32_t n_orig_ctx  = hparams.n_yarn_orig_ctx;
+    const int32_t n_orig_ctx  = cparams.n_yarn_orig_ctx;
     const int32_t n_head      = hparams.n_head;
     const int32_t n_head_kv   = hparams.n_head_kv;
     const int32_t n_embd_head = hparams.n_embd_head();
@@ -4818,7 +4819,7 @@ static struct ggml_cgraph * llm_build_persimmon(
     const int64_t n_embd      = hparams.n_embd;
     const int64_t n_layer     = hparams.n_layer;
     const int64_t n_ctx       = cparams.n_ctx;
-    const int32_t n_orig_ctx  = hparams.n_yarn_orig_ctx;
+    const int32_t n_orig_ctx  = cparams.n_yarn_orig_ctx;
     const int64_t n_head_kv   = hparams.n_head_kv;
     const int64_t n_head      = hparams.n_head;
     const int64_t n_embd_head = hparams.n_embd_head();
@@ -8676,6 +8677,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.mul_mat_q = params.mul_mat_q;
 
     cparams.n_ctx            = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
+    cparams.n_yarn_orig_ctx  = params.yarn_orig_ctx == 0 ? hparams.n_ctx_train : params.yarn_orig_ctx;
     cparams.rope_freq_base   = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
     cparams.rope_freq_scale  = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;

llama.h (13 changes: 7 additions & 6 deletions)
@@ -182,12 +182,13 @@ extern "C" {
         int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
 
         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
-        float rope_freq_base;   // RoPE base frequency, 0 = from model
-        float rope_freq_scale;  // RoPE frequency scaling factor, 0 = from model
-        float yarn_ext_factor;  // YaRN extrapolation mix factor, NaN = from model
-        float yarn_attn_factor; // YaRN magnitude scaling factor
-        float yarn_beta_fast;   // YaRN low correction dim
-        float yarn_beta_slow;   // YaRN high correction dim
+        float    rope_freq_base;   // RoPE base frequency, 0 = from model
+        float    rope_freq_scale;  // RoPE frequency scaling factor, 0 = from model
+        float    yarn_ext_factor;  // YaRN extrapolation mix factor, NaN = from model
+        float    yarn_attn_factor; // YaRN magnitude scaling factor
+        float    yarn_beta_fast;   // YaRN low correction dim
+        float    yarn_beta_slow;   // YaRN high correction dim
+        uint32_t yarn_orig_ctx;    // YaRN original context size
 
         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool mul_mat_q; // if true, use experimental mul_mat_q kernels
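For callers of the C API, the new field sits next to the other YaRN parameters in llama_context_params. A hypothetical usage sketch, not taken from this commit; LLAMA_ROPE_SCALING_YARN is assumed to be the YaRN value of enum llama_rope_scaling_type in this branch:

#include "llama.h"

// Hypothetical helper, for illustration only.
static struct llama_context_params make_yarn_params(void) {
    struct llama_context_params cparams = llama_context_default_params();
    cparams.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; // assumed enum value in this branch
    cparams.rope_freq_scale   = 0.25f;                   // run at 4x the trained context
    cparams.yarn_orig_ctx     = 0;                       // 0 = use the model's training context
    return cparams;
}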
