server : remove self-extend features #9860

Merged 2 commits on Oct 12, 2024
6 changes: 3 additions & 3 deletions common/arg.cpp
@@ -1163,14 +1163,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, int value) {
params.grp_attn_n = value;
}
).set_env("LLAMA_ARG_GRP_ATTN_N"));
).set_env("LLAMA_ARG_GRP_ATTN_N").set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_PASSKEY}));
add_opt(common_arg(
{"-gaw", "--grp-attn-w"}, "N",
string_format("group-attention width (default: %.1f)", (double)params.grp_attn_w),
string_format("group-attention width (default: %d)", params.grp_attn_w),
[](common_params & params, int value) {
params.grp_attn_w = value;
}
).set_env("LLAMA_ARG_GRP_ATTN_W"));
).set_env("LLAMA_ARG_GRP_ATTN_W").set_examples({LLAMA_EXAMPLE_MAIN}));
add_opt(common_arg(
{"-dkvc", "--dump-kv-cache"},
"verbose print of the KV cache",
2 changes: 0 additions & 2 deletions examples/server/README.md
@@ -60,8 +60,6 @@ The project is under active development, and we are [looking for feedback and co
| `--yarn-attn-factor N` | YaRN: scale sqrt(t) or attention magnitude (default: 1.0)<br/>(env: LLAMA_ARG_YARN_ATTN_FACTOR) |
| `--yarn-beta-slow N` | YaRN: high correction dim or alpha (default: 1.0)<br/>(env: LLAMA_ARG_YARN_BETA_SLOW) |
| `--yarn-beta-fast N` | YaRN: low correction dim or beta (default: 32.0)<br/>(env: LLAMA_ARG_YARN_BETA_FAST) |
| `-gan, --grp-attn-n N` | group-attention factor (default: 1)<br/>(env: LLAMA_ARG_GRP_ATTN_N) |
| `-gaw, --grp-attn-w N` | group-attention width (default: 512.0)<br/>(env: LLAMA_ARG_GRP_ATTN_W) |
| `-dkvc, --dump-kv-cache` | verbose print of the KV cache |
| `-nkvo, --no-kv-offload` | disable KV offload<br/>(env: LLAMA_ARG_NO_KV_OFFLOAD) |
| `-ctk, --cache-type-k TYPE` | KV cache data type for K (default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_K) |
175 changes: 44 additions & 131 deletions examples/server/server.cpp
@@ -193,21 +193,15 @@ struct server_slot {

llama_token sampled;

int32_t ga_i = 0; // group-attention state
int32_t ga_n = 1; // group-attention factor
int32_t ga_w = 512; // group-attention width

int32_t n_past_se = 0; // self-extend

// stats
size_t n_sent_text = 0; // number of sent text character
size_t n_sent_text = 0; // number of sent text character
size_t n_sent_token_probs = 0;

int64_t t_start_process_prompt;
int64_t t_start_generation;

double t_prompt_processing; // ms
double t_token_generation; // ms
double t_token_generation; // ms

std::function<void(int)> callback_on_release;

@@ -225,8 +219,6 @@ struct server_slot {
n_sent_text = 0;
n_sent_token_probs = 0;
cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
ga_i = 0;
n_past_se = 0;

generated_token_probs.clear();
}
@@ -705,22 +697,6 @@ struct server_context {

SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);

const int ga_n = params.grp_attn_n;
const int ga_w = params.grp_attn_w;

if (ga_n != 1) {
GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
//GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT

SLT_INF(slot, "slot self-extend: ga_n = %d, ga_w = %d\n", ga_n, ga_w);
}

slot.ga_i = 0;
slot.ga_n = ga_n;
slot.ga_w = ga_w;

slot.sparams = params.sparams;

slot.callback_on_release = [this](int) {
@@ -906,19 +882,14 @@ struct server_context {
}
if (data.contains("json_schema") && !data.contains("grammar")) {
try {
auto schema = json_value(data, "json_schema", json::object());
slot.sparams.grammar = json_schema_to_grammar(schema);
auto schema = json_value(data, "json_schema", json::object());
slot.sparams.grammar = json_schema_to_grammar(schema);
} catch (const std::exception & e) {
send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
return false;
}
} else {
slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
}

if (slot.params.cache_prompt && slot.ga_n != 1) {
slot.params.cache_prompt = false;
SLT_WRN(slot, "%s", "group-attention is not supported with prompt caching. disabling cache\n");
slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
}

if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
@@ -1131,12 +1102,13 @@ struct server_context {
}

// if context shift is disabled, we stop when it reaches the context limit
if (slot.n_decoded >= slot.n_ctx) {
if (slot.n_past >= slot.n_ctx) {
slot.truncated = true;
slot.stopped_limit = true;
slot.has_next_token = false;

SLT_DBG(slot, "stopped due to running out of context capacity, n_decoded = %d, n_ctx = %d\n", slot.n_decoded, slot.n_ctx);
SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx);
}

if (llama_token_is_eog(model, result.tok)) {
@@ -1148,13 +1120,13 @@ struct server_context {

const auto n_ctx_train = llama_n_ctx_train(model);

if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
slot.truncated = true;
slot.stopped_limit = true;
slot.has_next_token = false; // stop prediction

SLT_WRN(slot,
"n_predict (%d) is not set and self-context extend is disabled. "
"n_predict (%d) is set for infinite generation. "
"Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
slot.params.n_predict, n_ctx_train);
}
@@ -1826,38 +1798,36 @@ struct server_context {
// apply context-shift if needed
// TODO: simplify and improve
for (server_slot & slot : slots) {
if (slot.ga_n == 1) {
if (slot.is_processing() && slot.n_past >= slot.n_ctx - 1) {
if (!params.ctx_shift) {
// this check is redundant (for good)
// we should never get here, because generation should already stopped in process_token()
slot.release();
send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
continue;
}

// Shift context
const int n_keep = slot.params.n_keep + add_bos_token;
const int n_left = slot.n_past - n_keep;
const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
if (!params.ctx_shift) {
// this check is redundant (for good)
// we should never get here, because generation should already stopped in process_token()
slot.release();
send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
continue;
}
Comment on lines +1802 to +1808
ggerganov (Member, Author):

@ngxson I think the comment is not entirely correct because in process_token() we check against the training context length (n_ctx_train), while the slot's context slot.n_ctx could be smaller. What do you think?

ggerganov (Member, Author), Oct 12, 2024:

Ah, I missed that, thanks.

Shouldn't we check this actually:

    if (slot.n_prompt_tokens + slot.n_decoded >= n_ctx) {

Hmm, or maybe:

    if (slot.n_past + slot.n_decoded >= n_ctx) {

Anyway, I will figure it out as I'm looking into this logic currently.

ngxson (Collaborator), Oct 12, 2024:

Ah yeah, I misunderstood n_decoded. Yeah, maybe we even need (int) system_tokens.size() + slot.n_prompt_tokens because system_tokens is already in KV cache before the first decode.

Thanks for looking into this.

ngxson (Collaborator):

No sorry, I haven't seen #9811
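
For reference, a small standalone sketch of the check being discussed in this thread; the helper name and signature are illustrative assumptions, not code from this PR. The point is to compare against the slot's own context size rather than the model's training context.

    #include <cstdint>

    // Sketch of the suggested stop condition: the tokens occupying the slot's
    // KV cache are the (possibly truncated) prompt plus everything generated
    // so far, and generation should stop once that reaches the slot's own
    // context size (slot.n_ctx), not the model's training context (n_ctx_train).
    static bool slot_context_exhausted(int32_t n_prompt_tokens, int32_t n_decoded, int32_t n_ctx) {
        return n_prompt_tokens + n_decoded >= n_ctx;
    }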


SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
// Shift context
const int n_keep = slot.params.n_keep + add_bos_token;
const int n_left = slot.n_past - n_keep;
const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);

llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, slot.n_past, -n_discard);
SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);

if (slot.params.cache_prompt) {
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
}
llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, slot.n_past, -n_discard);

slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
if (slot.params.cache_prompt) {
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
}

slot.n_past -= n_discard;

slot.truncated = true;
slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
}

slot.n_past -= n_discard;

slot.truncated = true;
}
}

@@ -1872,9 +1842,7 @@ struct server_context {

slot.i_batch = batch.n_tokens;

const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;

common_batch_add(batch, slot.sampled, slot_npast, { slot.id + 1 }, true);
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id + 1 }, true);

slot.n_past += 1;

@@ -1993,6 +1961,8 @@ struct server_context {
} else {
if (!params.ctx_shift) {
// if context shift is disabled, we make sure prompt size is smaller than KV size
// TODO: there should be a separate parameter that control prompt truncation
// context shift should be applied only during the generation phase
if (slot.n_prompt_tokens >= slot.n_ctx) {
slot.release();
send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
@@ -2005,7 +1975,7 @@ struct server_context {
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);

// if input prompt is too big, truncate it (if group attention self-extend is disabled)
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
if (slot.n_prompt_tokens >= slot.n_ctx) {
const int n_left = slot.n_ctx - slot.params.n_keep;

const int n_block_size = n_left / 2;
@@ -2032,12 +2002,7 @@ struct server_context {

common_sampler_reset(slot.smpl);

if (!slot.params.cache_prompt) {
slot.n_past_se = 0;
slot.ga_i = 0;
} else {
GGML_ASSERT(slot.ga_n == 1);

if (slot.params.cache_prompt) {
// reuse any previously computed tokens that are common with the new prompt
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);

@@ -2053,9 +2018,6 @@ struct server_context {
SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);

slot.n_past--;
if (slot.ga_i > 0) {
slot.n_past_se--;
}
}

slot.n_prompt_tokens_processed = 0;
@@ -2081,52 +2043,31 @@ struct server_context {
}

// keep only the common part
int p0 = slot.n_past;

if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1)) {
// could not partially delete (likely using a non-Transformer model)
llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);

p0 = 0;

// there is no common part left
slot.n_past = 0;
slot.n_past_se = 0;
slot.ga_i = 0;

common_sampler_reset(slot.smpl);
}

SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);

// remove the non-common part from the cache
slot.cache_tokens.resize(slot.n_past);

SLT_INF(slot, "kv cache rm [%d, end)\n", p0);

int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;

int32_t ga_i = slot.ga_i;
int32_t ga_n = slot.ga_n;
int32_t ga_w = slot.ga_w;

// add prompt tokens for processing in the current batch
// TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow
for (; slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch; ++slot.n_past) {
if (slot.ga_n != 1) {
while (slot_npast >= ga_i + ga_w) {
const int bd = (ga_w/ga_n)*(ga_n - 1);
slot_npast -= bd;
ga_i += ga_w/ga_n;
}
}

common_batch_add(batch, prompt_tokens[slot.n_past], slot_npast, { slot.id + 1 }, false);
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);

if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
}

slot.n_prompt_tokens_processed++;
slot_npast++;
slot.n_past++;
}

SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
@@ -2167,34 +2108,6 @@ struct server_context {
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

for (auto & slot : slots) {
if (slot.ga_n != 1) {
// context extension via Self-Extend
// TODO: simplify and/or abstract this
while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;

SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
SLT_DBG(slot, "div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);

llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);

slot.n_past_se -= bd;

slot.ga_i += slot.ga_w / slot.ga_n;

SLT_DBG(slot, "\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
}

slot.n_past_se += n_tokens;
}
}

llama_batch batch_view = {
n_tokens,
batch.token + i,
4 changes: 4 additions & 0 deletions examples/server/tests/features/ctx_shift.feature
@@ -13,6 +13,10 @@ Feature: llama.cpp server
And 32 as batch size
And 2 slots

# the prompt is 301 tokens
# the slot context is 256/2 = 128 tokens
# the prompt is truncated to keep the last 109 tokens
# 64 tokens are generated thanks to shifting the context when it gets full
Scenario: Inference with context shift
And 64 server max tokens to predict
Then the server is starting
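
The token counts in the comments of this test can be reproduced from the server's prompt-truncation arithmetic. Below is a minimal sketch, assuming n_keep = 0 and mirroring the truncation logic in server.cpp (only partly shown in the hunk above); the helper and the main() driver are illustrative only.

    #include <cstdio>

    // 301-token prompt, 256-token context split across 2 slots -> 128 tokens per slot.
    static int truncated_prompt_size(int n_prompt_tokens, int n_ctx, int n_keep) {
        const int n_left        = n_ctx - n_keep;                                           // 128
        const int n_block_size  = n_left / 2;                                               // 64
        const int erased_blocks = (n_prompt_tokens - n_keep - n_block_size) / n_block_size; // 3
        // keep the first n_keep tokens plus the tail that survives the erased blocks
        return n_keep + (n_prompt_tokens - n_keep - erased_blocks * n_block_size);          // 109
    }

    int main() {
        printf("%d\n", truncated_prompt_size(301, 256 / 2, 0)); // prints 109
    }

The 64 generated tokens then fit by repeatedly discarding roughly half of the non-kept cache (n_discard defaults to n_left / 2) whenever the slot context fills up, as in the context-shift hunk above.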