@@ -1,6 +1,7 @@
 #include "common.h"
 #include "llama.h"
 
+#include <algorithm>
 #include <cmath>
 #include <cstdio>
 #include <string>
@@ -42,7 +43,9 @@ int main(int argc, char ** argv) {
     llama_context_params ctx_params = llama_context_default_params();
 
     ctx_params.seed = 1234;
-    ctx_params.n_ctx = 2048;
+    ctx_params.n_ctx = n_len*n_parallel; // FIXME: use n_kv_req instead (tokenize with model after #3301)
+    ctx_params.n_batch = std::max(n_len, n_parallel);
+    // ctx_params.n_gpu_layers = 99; // offload all layers to the GPU
 
     llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
 
@@ -66,11 +69,11 @@ int main(int argc, char ** argv) {
     const int n_ctx    = llama_n_ctx(ctx);
     const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
 
-    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_parallel, n_kv_req);
+    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_batch = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req);
 
     // make sure the KV cache is big enough to hold all the prompt and generated tokens
     if (n_kv_req > n_ctx) {
-        LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
+        LOG_TEE("%s: error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", __func__, n_kv_req);
         LOG_TEE("%s: either reduce n_parallel or increase n_ctx\n", __func__);
         return 1;
     }
@@ -88,7 +91,7 @@ int main(int argc, char ** argv) {
     // create a llama_batch with size 512
     // we use this object to submit token data for decoding
 
-    llama_batch batch = llama_batch_init(512, 0);
+    llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0);
 
     // evaluate the initial prompt
     batch.n_tokens = tokens_list.size();
@@ -133,12 +136,6 @@ int main(int argc, char ** argv) {
     const auto t_main_start = ggml_time_us();
 
     while (n_cur <= n_len) {
-        // evaluate the current batch with the transformer model
-        if (llama_decode(ctx, batch, params.n_threads)) {
-            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
-            return 1;
-        }
-
         // prepare the next batch
         batch.n_tokens = 0;
 
@@ -149,8 +146,8 @@ int main(int argc, char ** argv) {
                 continue;
             }
 
-            auto n_vocab = llama_n_vocab(ctx);
-            auto logits  = llama_get_logits_ith(ctx, i_batch[i]);
+            auto   n_vocab = llama_n_vocab(ctx);
+            auto * logits  = llama_get_logits_ith(ctx, i_batch[i]);
 
             std::vector<llama_token_data> candidates;
             candidates.reserve(n_vocab);
@@ -178,7 +175,7 @@ int main(int argc, char ** argv) {
                 i_batch[i] = -1;
                 LOG_TEE("\n");
                 if (n_parallel > 1) {
-                    LOG_TEE("%s: stream %d finished", __func__, i);
+                    LOG_TEE("%s: stream %d finished at n_cur = %d", __func__, i, n_cur);
                 }
 
                 continue;
@@ -211,6 +208,12 @@ int main(int argc, char ** argv) {
         }
 
         n_cur += 1;
+
+        // evaluate the current batch with the transformer model
+        if (llama_decode(ctx, batch, params.n_threads)) {
+            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
+            return 1;
+        }
     }
 
     LOG_TEE("\n");