Commit e5f74fe

server : remove self-extend

ggml-ci

1 parent 0db72b6 commit e5f74fe

3 files changed: +42 -134 lines changed

common/arg.cpp (+3 -3)

@@ -1163,14 +1163,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, int value) {
             params.grp_attn_n = value;
         }
-    ).set_env("LLAMA_ARG_GRP_ATTN_N"));
+    ).set_env("LLAMA_ARG_GRP_ATTN_N").set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_PASSKEY}));
     add_opt(common_arg(
         {"-gaw", "--grp-attn-w"}, "N",
-        string_format("group-attention width (default: %.1f)", (double)params.grp_attn_w),
+        string_format("group-attention width (default: %d)", params.grp_attn_w),
         [](common_params & params, int value) {
             params.grp_attn_w = value;
         }
-    ).set_env("LLAMA_ARG_GRP_ATTN_W"));
+    ).set_env("LLAMA_ARG_GRP_ATTN_W").set_examples({LLAMA_EXAMPLE_MAIN}));
     add_opt(common_arg(
         {"-dkvc", "--dump-kv-cache"},
         "verbose print of the KV cache",

examples/server/README.md (-2)

@@ -60,8 +60,6 @@ The project is under active development, and we are [looking for feedback and co
 | `--yarn-attn-factor N` | YaRN: scale sqrt(t) or attention magnitude (default: 1.0)<br/>(env: LLAMA_ARG_YARN_ATTN_FACTOR) |
 | `--yarn-beta-slow N` | YaRN: high correction dim or alpha (default: 1.0)<br/>(env: LLAMA_ARG_YARN_BETA_SLOW) |
 | `--yarn-beta-fast N` | YaRN: low correction dim or beta (default: 32.0)<br/>(env: LLAMA_ARG_YARN_BETA_FAST) |
-| `-gan, --grp-attn-n N` | group-attention factor (default: 1)<br/>(env: LLAMA_ARG_GRP_ATTN_N) |
-| `-gaw, --grp-attn-w N` | group-attention width (default: 512.0)<br/>(env: LLAMA_ARG_GRP_ATTN_W) |
 | `-dkvc, --dump-kv-cache` | verbose print of the KV cache |
 | `-nkvo, --no-kv-offload` | disable KV offload<br/>(env: LLAMA_ARG_NO_KV_OFFLOAD) |
 | `-ctk, --cache-type-k TYPE` | KV cache data type for K (default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_K) |
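Note: the %.1f -> %d change in common/arg.cpp above is also why the README row removed here listed the --grp-attn-w default as 512.0 even though the parameter is an int. A minimal standalone sketch of the format fix, using plain printf rather than the project's string_format helper; the values are illustrative:

    // grp_attn_w is an int; the old help text cast it to double and printed it
    // with %.1f, which rendered the default as "512.0" instead of "512".
    #include <cstdio>

    int main() {
        int grp_attn_w = 512;
        printf("group-attention width (default: %.1f)\n", (double) grp_attn_w); // old: 512.0
        printf("group-attention width (default: %d)\n",   grp_attn_w);          // new: 512
        return 0;
    }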

examples/server/server.cpp (+39 -129)
@@ -193,21 +193,15 @@ struct server_slot {

     llama_token sampled;

-    int32_t ga_i = 0;   // group-attention state
-    int32_t ga_n = 1;   // group-attention factor
-    int32_t ga_w = 512; // group-attention width
-
-    int32_t n_past_se = 0; // self-extend
-
     // stats
-    size_t n_sent_text = 0; // number of sent text character
+    size_t n_sent_text = 0; // number of sent text character
     size_t n_sent_token_probs = 0;

     int64_t t_start_process_prompt;
     int64_t t_start_generation;

     double t_prompt_processing; // ms
-    double t_token_generation; // ms
+    double t_token_generation; // ms

     std::function<void(int)> callback_on_release;

@@ -225,8 +219,6 @@ struct server_slot {
         n_sent_text = 0;
         n_sent_token_probs = 0;
         cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
-        ga_i = 0;
-        n_past_se = 0;

         generated_token_probs.clear();
     }
@@ -705,22 +697,6 @@ struct server_context {

         SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);

-        const int ga_n = params.grp_attn_n;
-        const int ga_w = params.grp_attn_w;
-
-        if (ga_n != 1) {
-            GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
-            GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
-            //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
-            //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
-
-            SLT_INF(slot, "slot self-extend: ga_n = %d, ga_w = %d\n", ga_n, ga_w);
-        }
-
-        slot.ga_i = 0;
-        slot.ga_n = ga_n;
-        slot.ga_w = ga_w;
-
         slot.sparams = params.sparams;

         slot.callback_on_release = [this](int) {
@@ -906,19 +882,14 @@ struct server_context {
         }
         if (data.contains("json_schema") && !data.contains("grammar")) {
             try {
-                auto schema = json_value(data, "json_schema", json::object());
-                slot.sparams.grammar = json_schema_to_grammar(schema);
+                auto schema = json_value(data, "json_schema", json::object());
+                slot.sparams.grammar = json_schema_to_grammar(schema);
             } catch (const std::exception & e) {
                 send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
                 return false;
             }
         } else {
-            slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
-        }
-
-        if (slot.params.cache_prompt && slot.ga_n != 1) {
-            slot.params.cache_prompt = false;
-            SLT_WRN(slot, "%s", "group-attention is not supported with prompt caching. disabling cache\n");
+            slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
         }

         if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
@@ -1148,13 +1119,13 @@ struct server_context {

        const auto n_ctx_train = llama_n_ctx_train(model);

-       if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
+       if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
            slot.truncated      = true;
            slot.stopped_limit  = true;
            slot.has_next_token = false; // stop prediction

            SLT_WRN(slot,
-                   "n_predict (%d) is not set and self-context extend is disabled. "
+                   "n_predict (%d) is set for infinite generation. "
                    "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
                    slot.params.n_predict, n_ctx_train);
        }
@@ -1826,38 +1797,36 @@ struct server_context {
         // apply context-shift if needed
         // TODO: simplify and improve
         for (server_slot & slot : slots) {
-            if (slot.ga_n == 1) {
-                if (slot.is_processing() && slot.n_past >= slot.n_ctx - 1) {
-                    if (!params.ctx_shift) {
-                        // this check is redundant (for good)
-                        // we should never get here, because generation should already stopped in process_token()
-                        slot.release();
-                        send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
-                        continue;
-                    }
-
-                    // Shift context
-                    const int n_keep    = slot.params.n_keep + add_bos_token;
-                    const int n_left    = slot.n_past - n_keep;
-                    const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
+            if (slot.is_processing() && slot.n_past >= slot.n_ctx - 1) {
+                if (!params.ctx_shift) {
+                    // this check is redundant (for good)
+                    // we should never get here, because generation should already stopped in process_token()
+                    slot.release();
+                    send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
+                    continue;
+                }

-                    SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
+                // Shift context
+                const int n_keep    = slot.params.n_keep + add_bos_token;
+                const int n_left    = slot.n_past - n_keep;
+                const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);

-                    llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep            , n_keep + n_discard);
-                    llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, slot.n_past, -n_discard);
+                SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);

-                    if (slot.params.cache_prompt) {
-                        for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
-                            slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
-                        }
+                llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep            , n_keep + n_discard);
+                llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, slot.n_past, -n_discard);

-                        slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
+                if (slot.params.cache_prompt) {
+                    for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
+                        slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
                     }

-                    slot.n_past -= n_discard;
-
-                    slot.truncated = true;
+                    slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
                }
+
+                slot.n_past -= n_discard;
+
+                slot.truncated = true;
             }
         }

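Note: the body of the context-shift block above is unchanged by this commit; it only loses one level of nesting now that the slot.ga_n == 1 guard is gone. A minimal standalone model of what that block does to a slot's cached tokens, with a plain vector standing in for the real KV cache (llama_kv_cache_seq_rm / llama_kv_cache_seq_add are mirrored, not called); n_keep and the context size are illustrative:

    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<int> cache_tokens(32);               // pretend the context holds 32 tokens
        for (int i = 0; i < 32; i++) cache_tokens[i] = i;

        int n_past          = (int) cache_tokens.size(); // context is full
        const int n_keep    = 4;                         // slot.params.n_keep (+ BOS), illustrative
        const int n_left    = n_past - n_keep;
        const int n_discard = n_left / 2;                // default when params.n_discard == 0

        // discard [n_keep, n_keep + n_discard) and slide the tail down by n_discard,
        // mirroring the llama_kv_cache_seq_rm / llama_kv_cache_seq_add pair above
        for (size_t i = n_keep + n_discard; i < cache_tokens.size(); i++) {
            cache_tokens[i - n_discard] = cache_tokens[i];
        }
        cache_tokens.resize(cache_tokens.size() - n_discard);
        n_past -= n_discard;

        printf("n_keep = %d, n_discard = %d, n_past is now %d\n", n_keep, n_discard, n_past);
        return 0;
    }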
@@ -1872,9 +1841,7 @@ struct server_context {

                 slot.i_batch = batch.n_tokens;

-                const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-
-                common_batch_add(batch, slot.sampled, slot_npast, { slot.id + 1 }, true);
+                common_batch_add(batch, slot.sampled, slot.n_past, { slot.id + 1 }, true);

                 slot.n_past += 1;

@@ -2005,7 +1972,7 @@ struct server_context {
                    slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);

                    // if input prompt is too big, truncate it (if group attention self-extend is disabled)
-                   if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
+                   if (slot.n_prompt_tokens >= slot.n_ctx) {
                        const int n_left = slot.n_ctx - slot.params.n_keep;

                        const int n_block_size = n_left / 2;
@@ -2032,12 +1999,7 @@ struct server_context {

                    common_sampler_reset(slot.smpl);

-                   if (!slot.params.cache_prompt) {
-                       slot.n_past_se = 0;
-                       slot.ga_i = 0;
-                   } else {
-                       GGML_ASSERT(slot.ga_n == 1);
-
+                   if (slot.params.cache_prompt) {
                        // reuse any previously computed tokens that are common with the new prompt
                        slot.n_past = common_part(slot.cache_tokens, prompt_tokens);

@@ -2053,9 +2015,6 @@ struct server_context {
                            SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);

                            slot.n_past--;
-                           if (slot.ga_i > 0) {
-                               slot.n_past_se--;
-                           }
                        }

                        slot.n_prompt_tokens_processed = 0;
@@ -2081,52 +2040,31 @@ struct server_context {
                    }

                    // keep only the common part
-                   int p0 = slot.n_past;
-
-                   if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
+                   if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1)) {
                        // could not partially delete (likely using a non-Transformer model)
                        llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);

-                       p0 = 0;
-
                        // there is no common part left
                        slot.n_past = 0;
-                       slot.n_past_se = 0;
-                       slot.ga_i = 0;

                        common_sampler_reset(slot.smpl);
                    }

+                   SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
+
                    // remove the non-common part from the cache
                    slot.cache_tokens.resize(slot.n_past);

-                   SLT_INF(slot, "kv cache rm [%d, end)\n", p0);
-
-                   int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-
-                   int32_t ga_i = slot.ga_i;
-                   int32_t ga_n = slot.ga_n;
-                   int32_t ga_w = slot.ga_w;
-
                    // add prompt tokens for processing in the current batch
-                   // TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow
-                   for (; slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch; ++slot.n_past) {
-                       if (slot.ga_n != 1) {
-                           while (slot_npast >= ga_i + ga_w) {
-                               const int bd = (ga_w/ga_n)*(ga_n - 1);
-                               slot_npast -= bd;
-                               ga_i += ga_w/ga_n;
-                           }
-                       }
-
-                       common_batch_add(batch, prompt_tokens[slot.n_past], slot_npast, { slot.id + 1 }, false);
+                   while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+                       common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);

                        if (slot.params.cache_prompt) {
                            slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
                        }

                        slot.n_prompt_tokens_processed++;
-                       slot_npast++;
+                       slot.n_past++;
                    }

                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
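Note: for reference, the position bookkeeping deleted from this prompt-processing loop can be reproduced in isolation. With group attention enabled (ga_n != 1), the old loop pulled the running batch position back by bd = (ga_w / ga_n) * (ga_n - 1) whenever it reached ga_i + ga_w, so that new tokens were placed at grouped positions rather than absolute ones. A standalone sketch of the deleted loop only; ga_n, ga_w and the token count are illustrative:

    #include <cstdio>

    int main() {
        const int ga_n = 4;   // group-attention factor (the removed -gan)
        const int ga_w = 8;   // group-attention width  (the removed -gaw), a multiple of ga_n

        int ga_i  = 0;        // group-attention state  (the removed slot.ga_i)
        int npast = 0;        // position written to the batch (the removed slot_npast)

        for (int i = 0; i < 24; i++) {   // i plays the role of slot.n_past over the prompt
            while (npast >= ga_i + ga_w) {
                const int bd = (ga_w / ga_n) * (ga_n - 1);
                npast -= bd;
                ga_i  += ga_w / ga_n;
            }
            printf("prompt token %2d -> batch position %2d (ga_i = %d)\n", i, npast, ga_i);
            npast++;
        }
        return 0;
    }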
@@ -2167,34 +2105,6 @@ struct server_context {
         for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
             const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

-            for (auto & slot : slots) {
-                if (slot.ga_n != 1) {
-                    // context extension via Self-Extend
-                    // TODO: simplify and/or abstract this
-                    while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
-                        const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
-                        const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
-                        const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
-
-                        SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
-                        SLT_DBG(slot, "div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
-                        SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
-
-                        llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
-                        llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
-                        llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
-
-                        slot.n_past_se -= bd;
-
-                        slot.ga_i += slot.ga_w / slot.ga_n;
-
-                        SLT_DBG(slot, "\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
-                    }
-
-                    slot.n_past_se += n_tokens;
-                }
-            }
-
             llama_batch batch_view = {
                 n_tokens,
                 batch.token + i,
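Note: the block removed here was the KV-cache half of Self-Extend. Whenever a slot's extended position counter passed ga_i + ga_w, one ga_w-wide window of cache positions was shifted, divided by ga_n, and the remaining tail shifted back. A standalone simulation of that transform on a plain array of positions; the two lambdas stand in for llama_kv_cache_seq_add / llama_kv_cache_seq_div, and all constants are illustrative, not taken from a real run:

    #include <cstdio>
    #include <vector>

    int main() {
        const int ga_n = 4;                 // group-attention factor (illustrative)
        const int ga_w = 8;                 // group-attention width  (illustrative)

        std::vector<int> pos(32);           // one entry per cached cell, holding its position
        for (int i = 0; i < (int) pos.size(); i++) pos[i] = i;

        // models of llama_kv_cache_seq_add / llama_kv_cache_seq_div acting on positions in [p0, p1)
        auto seq_add = [&](int p0, int p1, int delta) {
            for (int & p : pos) if (p >= p0 && p < p1) p += delta;
        };
        auto seq_div = [&](int p0, int p1, int d) {
            for (int & p : pos) if (p >= p0 && p < p1) p /= d;
        };

        int ga_i      = 0;                  // the removed slot.ga_i
        int n_past_se = (int) pos.size();   // the removed slot.n_past_se

        while (n_past_se >= ga_i + ga_w) {
            const int ib = (ga_n * ga_i) / ga_w;
            const int bd = (ga_w / ga_n) * (ga_n - 1);
            const int dd = (ga_w / ga_n) - ib * bd - ga_w;

            seq_add(ga_i, n_past_se, ib * bd);                       // shift [ga_i, n_past_se) up by ib*bd
            seq_div(ga_i + ib * bd, ga_i + ib * bd + ga_w, ga_n);    // compress one ga_w-wide window by ga_n
            seq_add(ga_i + ib * bd + ga_w, n_past_se + ib * bd, dd); // move the rest down (dd is negative)

            n_past_se -= bd;
            ga_i      += ga_w / ga_n;
        }

        for (int i = 0; i < (int) pos.size(); i++) {
            printf("cell %2d: position %2d\n", i, pos[i]);
        }
        return 0;
    }

With these illustrative values (ga_n = 4, ga_w = 8), the 32 simulated cells end up sharing positions 0 through 7, i.e. the context is compressed by the group-attention factor. After this commit the server no longer performs this transform; per the common/arg.cpp hunk above, -gan / -gaw remain available in the main and passkey examples.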
