Skip to content

Commit 519c981

Browse files
authored Aug 22, 2023
embedding : evaluate prompt in batches (#2713)
1 parent 1123f7f commit 519c981

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed
 

examples/embedding/embedding.cpp

+19-12
Original file line numberDiff line numberDiff line change
@@ -72,22 +72,29 @@ int main(int argc, char ** argv) {
7272
fprintf(stderr, "\n");
7373
}
7474

75-
if (params.embedding){
76-
if (embd_inp.size() > 0) {
77-
if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads)) {
78-
fprintf(stderr, "%s : failed to eval\n", __func__);
79-
return 1;
80-
}
75+
if (embd_inp.size() > (size_t)params.n_ctx) {
76+
fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n",
77+
__func__, embd_inp.size(), params.n_ctx);
78+
return 1;
79+
}
80+
81+
while (!embd_inp.empty()) {
82+
int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
83+
if (llama_eval(ctx, embd_inp.data(), n_tokens, n_past, params.n_threads)) {
84+
fprintf(stderr, "%s : failed to eval\n", __func__);
85+
return 1;
8186
}
87+
n_past += n_tokens;
88+
embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
89+
}
8290

83-
const int n_embd = llama_n_embd(ctx);
84-
const auto embeddings = llama_get_embeddings(ctx);
91+
const int n_embd = llama_n_embd(ctx);
92+
const auto embeddings = llama_get_embeddings(ctx);
8593

86-
for (int i = 0; i < n_embd; i++) {
87-
printf("%f ", embeddings[i]);
88-
}
89-
printf("\n");
94+
for (int i = 0; i < n_embd; i++) {
95+
printf("%f ", embeddings[i]);
9096
}
97+
printf("\n");
9198

9299
llama_print_timings(ctx);
93100
llama_free(ctx);

0 commit comments

Comments
 (0)
Please sign in to comment.