@@ -193,21 +193,15 @@ struct server_slot {
 
     llama_token sampled;
 
-    int32_t ga_i = 0;   // group-attention state
-    int32_t ga_n = 1;   // group-attention factor
-    int32_t ga_w = 512; // group-attention width
-
-    int32_t n_past_se = 0; // self-extend
-
     // stats
-    size_t n_sent_text = 0; // number of sent text character
+    size_t n_sent_text        = 0; // number of sent text character
     size_t n_sent_token_probs = 0;
 
     int64_t t_start_process_prompt;
     int64_t t_start_generation;
 
     double t_prompt_processing; // ms
-    double t_token_generation; // ms
+    double t_token_generation;  // ms
 
     std::function<void(int)> callback_on_release;
 
@@ -225,8 +219,6 @@ struct server_slot {
         n_sent_text = 0;
         n_sent_token_probs = 0;
         cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
-        ga_i = 0;
-        n_past_se = 0;
 
         generated_token_probs.clear();
     }
@@ -705,22 +697,6 @@ struct server_context {
 
         SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
 
-        const int ga_n = params.grp_attn_n;
-        const int ga_w = params.grp_attn_w;
-
-        if (ga_n != 1) {
-            GGML_ASSERT(ga_n > 0                    && "ga_n must be positive");                       // NOLINT
-            GGML_ASSERT(ga_w % ga_n == 0            && "ga_w must be a multiple of ga_n");             // NOLINT
-          //GGML_ASSERT(n_ctx_train % ga_w == 0     && "n_ctx_train must be a multiple of ga_w");      // NOLINT
-          //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n");   // NOLINT
-
-            SLT_INF(slot, "slot self-extend: ga_n = %d, ga_w = %d\n", ga_n, ga_w);
-        }
-
-        slot.ga_i = 0;
-        slot.ga_n = ga_n;
-        slot.ga_w = ga_w;
-
         slot.sparams = params.sparams;
 
         slot.callback_on_release = [this](int) {
@@ -906,19 +882,14 @@ struct server_context {
             }
             if (data.contains("json_schema") && !data.contains("grammar")) {
                 try {
-                    auto schema = json_value(data, "json_schema", json::object());
-                    slot.sparams.grammar = json_schema_to_grammar(schema);
+                    auto schema          = json_value(data, "json_schema", json::object());
+                    slot.sparams.grammar = json_schema_to_grammar(schema);
                 } catch (const std::exception & e) {
                     send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
                     return false;
                 }
             } else {
-                slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
-            }
-
-            if (slot.params.cache_prompt && slot.ga_n != 1) {
-                slot.params.cache_prompt = false;
-                SLT_WRN(slot, "%s", "group-attention is not supported with prompt caching. disabling cache\n");
+                slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
             }
 
             if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
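
For context on the grammar handling above: a request may supply either a "json_schema" object, which json_schema_to_grammar converts to a GBNF grammar, or a raw "grammar" string, and the schema path only runs when no explicit grammar is given. A minimal sketch of that precedence, using the same nlohmann::json types the server uses; the request body and the stub converter are invented for illustration:

    #include <iostream>
    #include <string>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    // Stub standing in for the server's json_schema_to_grammar converter.
    static std::string schema_to_grammar_stub(const json & schema) {
        return "root ::= ...  # GBNF derived from: " + schema.dump();
    }

    int main() {
        // Hypothetical request body: has "json_schema", no "grammar".
        const json data = json::parse(R"({
            "json_schema": { "type": "object", "properties": { "name": { "type": "string" } } }
        })");

        std::string grammar;
        if (data.contains("json_schema") && !data.contains("grammar")) {
            grammar = schema_to_grammar_stub(data.at("json_schema")); // schema wins
        } else {
            grammar = data.value("grammar", std::string{});           // raw grammar (or empty)
        }
        std::cout << grammar << "\n";
    }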
@@ -1148,13 +1119,13 @@ struct server_context {
 
         const auto n_ctx_train = llama_n_ctx_train(model);
 
-        if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
+        if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
             slot.truncated      = true;
             slot.stopped_limit  = true;
             slot.has_next_token = false; // stop prediction
 
             SLT_WRN(slot,
-                    "n_predict (%d) is not set and self-context extend is disabled. "
+                    "n_predict (%d) is set for infinite generation. "
                     "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
                     slot.params.n_predict, n_ctx_train);
         }
@@ -1826,38 +1797,36 @@ struct server_context {
         // apply context-shift if needed
         // TODO: simplify and improve
         for (server_slot & slot : slots) {
-            if (slot.ga_n == 1) {
-                if (slot.is_processing() && slot.n_past >= slot.n_ctx - 1) {
-                    if (!params.ctx_shift) {
-                        // this check is redundant (for good)
-                        // we should never get here, because generation should already stopped in process_token()
-                        slot.release();
-                        send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
-                        continue;
-                    }
-
-                    // Shift context
-                    const int n_keep    = slot.params.n_keep + add_bos_token;
-                    const int n_left    = slot.n_past - n_keep;
-                    const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
+            if (slot.is_processing() && slot.n_past >= slot.n_ctx - 1) {
+                if (!params.ctx_shift) {
+                    // this check is redundant (for good)
+                    // we should never get here, because generation should already stopped in process_token()
+                    slot.release();
+                    send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
+                    continue;
+                }
 
-                    SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
+                // Shift context
+                const int n_keep    = slot.params.n_keep + add_bos_token;
+                const int n_left    = slot.n_past - n_keep;
+                const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
 
-                    llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep            , n_keep + n_discard);
-                    llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, slot.n_past, -n_discard);
+                SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
 
-                    if (slot.params.cache_prompt) {
-                        for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
-                            slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
-                        }
+                llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep            , n_keep + n_discard);
+                llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, slot.n_past,        -n_discard);
 
-                        slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
+                if (slot.params.cache_prompt) {
+                    for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
+                        slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
                     }
 
-                    slot.n_past -= n_discard;
-
-                    slot.truncated = true;
+                    slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
                 }
+
+                slot.n_past -= n_discard;
+
+                slot.truncated = true;
             }
         }
 
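
The shift arithmetic in the hunk above is easy to check in isolation: keep the first n_keep tokens, drop the n_discard that follow, and slide the tail down, which is what the paired llama_kv_cache_seq_rm/llama_kv_cache_seq_add calls do to the KV-cache positions. A self-contained replay of the same index math on a plain token vector (shift_context is an illustrative helper, not a server function):

    #include <cstdio>
    #include <vector>

    // Mirrors the cache_tokens bookkeeping above: drop n_discard tokens
    // right after the n_keep prefix and compact the remainder.
    static void shift_context(std::vector<int> & tokens, int n_keep, int n_discard) {
        for (size_t i = n_keep + n_discard; i < tokens.size(); i++) {
            tokens[i - n_discard] = tokens[i];
        }
        tokens.resize(tokens.size() - n_discard);
    }

    int main() {
        std::vector<int> tokens(16);
        for (int i = 0; i < 16; i++) tokens[i] = i;

        const int n_keep    = 4;
        const int n_left    = (int) tokens.size() - n_keep; // 12
        const int n_discard = n_left / 2;                   // 6, the default when params.n_discard is 0

        shift_context(tokens, n_keep, n_discard);

        for (int t : tokens) printf("%d ", t); // prints: 0 1 2 3 10 11 12 13 14 15
        printf("\n");
        return 0;
    }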
@@ -1872,9 +1841,7 @@ struct server_context {
 
             slot.i_batch = batch.n_tokens;
 
-            const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-
-            common_batch_add(batch, slot.sampled, slot_npast, { slot.id + 1 }, true);
+            common_batch_add(batch, slot.sampled, slot.n_past, { slot.id + 1 }, true);
 
             slot.n_past += 1;
 
@@ -2005,7 +1972,7 @@ struct server_context {
                     slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
 
                     // if input prompt is too big, truncate it (if group attention self-extend is disabled)
-                    if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
+                    if (slot.n_prompt_tokens >= slot.n_ctx) {
                         const int n_left = slot.n_ctx - slot.params.n_keep;
 
                         const int n_block_size = n_left / 2;
@@ -2032,12 +1999,7 @@ struct server_context {
 
                     common_sampler_reset(slot.smpl);
 
-                    if (!slot.params.cache_prompt) {
-                        slot.n_past_se = 0;
-                        slot.ga_i = 0;
-                    } else {
-                        GGML_ASSERT(slot.ga_n == 1);
-
+                    if (slot.params.cache_prompt) {
                         // reuse any previously computed tokens that are common with the new prompt
                         slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
 
@@ -2053,9 +2015,6 @@ struct server_context {
                             SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
 
                             slot.n_past--;
-                            if (slot.ga_i > 0) {
-                                slot.n_past_se--;
-                            }
                         }
 
                         slot.n_prompt_tokens_processed = 0;
@@ -2081,52 +2040,31 @@ struct server_context {
                         }
 
                         // keep only the common part
-                        int p0 = slot.n_past;
-
-                        if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
+                        if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1)) {
                             // could not partially delete (likely using a non-Transformer model)
                             llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
 
-                            p0 = 0;
-
                             // there is no common part left
                             slot.n_past = 0;
-                            slot.n_past_se = 0;
-                            slot.ga_i = 0;
 
                             common_sampler_reset(slot.smpl);
                         }
 
+                        SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
+
                         // remove the non-common part from the cache
                         slot.cache_tokens.resize(slot.n_past);
 
-                        SLT_INF(slot, "kv cache rm [%d, end)\n", p0);
-
-                        int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-
-                        int32_t ga_i = slot.ga_i;
-                        int32_t ga_n = slot.ga_n;
-                        int32_t ga_w = slot.ga_w;
-
                         // add prompt tokens for processing in the current batch
-                        // TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow
-                        for (; slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch; ++slot.n_past) {
-                            if (slot.ga_n != 1) {
-                                while (slot_npast >= ga_i + ga_w) {
-                                    const int bd = (ga_w/ga_n)*(ga_n - 1);
-                                    slot_npast -= bd;
-                                    ga_i += ga_w/ga_n;
-                                }
-                            }
-
-                            common_batch_add(batch, prompt_tokens[slot.n_past], slot_npast, { slot.id + 1 }, false);
+                        while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+                            common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);
 
                             if (slot.params.cache_prompt) {
                                 slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
                             }
 
                             slot.n_prompt_tokens_processed++;
-                            slot_npast++;
+                            slot.n_past++;
                         }
 
                         SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
@@ -2167,34 +2105,6 @@ struct server_context {
         for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
             const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
 
-            for (auto & slot : slots) {
-                if (slot.ga_n != 1) {
-                    // context extension via Self-Extend
-                    // TODO: simplify and/or abstract this
-                    while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
-                        const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
-                        const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
-                        const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
-
-                        SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
-                        SLT_DBG(slot, "div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
-                        SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
-
-                        llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
-                        llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
-                        llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
-
-                        slot.n_past_se -= bd;
-
-                        slot.ga_i += slot.ga_w / slot.ga_n;
-
-                        SLT_DBG(slot, "\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
-                    }
-
-                    slot.n_past_se += n_tokens;
-                }
-            }
-
             llama_batch batch_view = {
                 n_tokens,
                 batch.token + i,
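
For the record, here is what the deleted block computed. Self-Extend (grouped attention) stretches the usable context by letting ga_n raw positions share one compressed KV position once they fall outside the ga_w window; the server applied this incrementally through llama_kv_cache_seq_add/llama_kv_cache_seq_div. A standalone replay of the removed bd/ga_i bookkeeping shows the effective position mapping (same arithmetic as the deleted prompt loop, restarted per position rather than carried incrementally):

    #include <cstdio>

    // Replay of the removed Self-Extend position arithmetic: whenever a raw
    // position crosses the current group window (ga_i + ga_w), it is pulled
    // back by bd = (ga_w/ga_n)*(ga_n - 1), so ga_n raw positions collapse
    // into the span of one, and the window origin ga_i advances by ga_w/ga_n.
    int main() {
        const int ga_n = 4;   // group-attention factor
        const int ga_w = 512; // group-attention width (a multiple of ga_n)

        for (int n_past = 0; n_past <= 4096; n_past += 512) {
            int pos  = n_past;
            int ga_i = 0;
            while (pos >= ga_i + ga_w) {
                const int bd = (ga_w/ga_n)*(ga_n - 1);
                pos  -= bd;
                ga_i += ga_w/ga_n;
            }
            printf("raw position %4d -> grouped position %4d\n", n_past, pos);
        }
        return 0;
    }

With ga_n = 4 and ga_w = 512 the replay maps raw position 4096 to grouped position 1024, i.e. roughly a 4x stretch of the trained context; that is the capability this change trades away for simpler slot and KV-cache handling.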