Skip to content

Commit 2b8830a

Browse files
committed
examples : do not eval prompt 2 times (close #3348)
1 parent a207561 commit 2b8830a

File tree

2 files changed

+24
-21
lines changed

2 files changed

+24
-21
lines changed

Diff for: examples/batched/batched.cpp

+16-13
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "common.h"
22
#include "llama.h"
33

4+
#include <algorithm>
45
#include <cmath>
56
#include <cstdio>
67
#include <string>
@@ -42,7 +43,9 @@ int main(int argc, char ** argv) {
4243
llama_context_params ctx_params = llama_context_default_params();
4344

4445
ctx_params.seed = 1234;
45-
ctx_params.n_ctx = 2048;
46+
ctx_params.n_ctx = n_len*n_parallel; // FIXME: use n_kv_req instead (tokenize with model after #3301)
47+
ctx_params.n_batch = std::max(n_len, n_parallel);
48+
// ctx_params.n_gpu_layers = 99; // offload all layers to the GPU
4649

4750
llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
4851

@@ -66,11 +69,11 @@ int main(int argc, char ** argv) {
6669
const int n_ctx = llama_n_ctx(ctx);
6770
const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
6871

69-
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_parallel, n_kv_req);
72+
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_batch = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req);
7073

7174
// make sure the KV cache is big enough to hold all the prompt and generated tokens
7275
if (n_kv_req > n_ctx) {
73-
LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
76+
LOG_TEE("%s: error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", __func__, n_kv_req);
7477
LOG_TEE("%s: either reduce n_parallel or increase n_ctx\n", __func__);
7578
return 1;
7679
}
@@ -88,7 +91,7 @@ int main(int argc, char ** argv) {
8891
// create a llama_batch with size 512
8992
// we use this object to submit token data for decoding
9093

91-
llama_batch batch = llama_batch_init(512, 0);
94+
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0);
9295

9396
// evaluate the initial prompt
9497
batch.n_tokens = tokens_list.size();
@@ -133,12 +136,6 @@ int main(int argc, char ** argv) {
133136
const auto t_main_start = ggml_time_us();
134137

135138
while (n_cur <= n_len) {
136-
// evaluate the current batch with the transformer model
137-
if (llama_decode(ctx, batch, params.n_threads)) {
138-
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
139-
return 1;
140-
}
141-
142139
// prepare the next batch
143140
batch.n_tokens = 0;
144141

@@ -149,8 +146,8 @@ int main(int argc, char ** argv) {
149146
continue;
150147
}
151148

152-
auto n_vocab = llama_n_vocab(ctx);
153-
auto logits = llama_get_logits_ith(ctx, i_batch[i]);
149+
auto n_vocab = llama_n_vocab(ctx);
150+
auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
154151

155152
std::vector<llama_token_data> candidates;
156153
candidates.reserve(n_vocab);
@@ -178,7 +175,7 @@ int main(int argc, char ** argv) {
178175
i_batch[i] = -1;
179176
LOG_TEE("\n");
180177
if (n_parallel > 1) {
181-
LOG_TEE("%s: stream %d finished", __func__, i);
178+
LOG_TEE("%s: stream %d finished at n_cur = %d", __func__, i, n_cur);
182179
}
183180

184181
continue;
@@ -211,6 +208,12 @@ int main(int argc, char ** argv) {
211208
}
212209

213210
n_cur += 1;
211+
212+
// evaluate the current batch with the transformer model
213+
if (llama_decode(ctx, batch, params.n_threads)) {
214+
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
215+
return 1;
216+
}
214217
}
215218

216219
LOG_TEE("\n");

Diff for: examples/simple/simple.cpp

+8-8
Original file line numberDiff line numberDiff line change
@@ -110,16 +110,10 @@ int main(int argc, char ** argv) {
110110
const auto t_main_start = ggml_time_us();
111111

112112
while (n_cur <= n_len) {
113-
// evaluate the current batch with the transformer model
114-
if (llama_decode(ctx, batch, params.n_threads)) {
115-
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
116-
return 1;
117-
}
118-
119113
// sample the next token
120114
{
121-
auto n_vocab = llama_n_vocab(ctx);
122-
auto logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
115+
auto n_vocab = llama_n_vocab(ctx);
116+
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
123117

124118
std::vector<llama_token_data> candidates;
125119
candidates.reserve(n_vocab);
@@ -158,6 +152,12 @@ int main(int argc, char ** argv) {
158152
}
159153

160154
n_cur += 1;
155+
156+
// evaluate the current batch with the transformer model
157+
if (llama_decode(ctx, batch, params.n_threads)) {
158+
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
159+
return 1;
160+
}
161161
}
162162

163163
LOG_TEE("\n");

0 commit comments

Comments
 (0)