@@ -72,22 +72,29 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "\n");
     }

-    if (params.embedding){
-        if (embd_inp.size() > 0) {
-            if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads)) {
-                fprintf(stderr, "%s : failed to eval\n", __func__);
-                return 1;
-            }
+    if (embd_inp.size() > (size_t)params.n_ctx) {
+        fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n",
+                __func__, embd_inp.size(), params.n_ctx);
+        return 1;
+    }
+
+    while (!embd_inp.empty()) {
+        int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
+        if (llama_eval(ctx, embd_inp.data(), n_tokens, n_past, params.n_threads)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return 1;
         }
+        n_past += n_tokens;
+        embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
+    }

-        const int n_embd = llama_n_embd(ctx);
-        const auto embeddings = llama_get_embeddings(ctx);
+    const int n_embd = llama_n_embd(ctx);
+    const auto embeddings = llama_get_embeddings(ctx);

-        for (int i = 0; i < n_embd; i++) {
-            printf("%f ", embeddings[i]);
-        }
-        printf("\n");
+    for (int i = 0; i < n_embd; i++) {
+        printf("%f ", embeddings[i]);
     }
+    printf("\n");

     llama_print_timings(ctx);
     llama_free(ctx);
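For reference, the loop added in this hunk follows a common chunked-evaluation pattern: feed at most n_batch tokens per llama_eval call, advance n_past by the number of tokens consumed, and drop those tokens from the front of the prompt. Below is a minimal standalone sketch of the same pattern; it uses a hypothetical eval_chunk() stand-in rather than the real llama_eval() so it compiles without llama.h, and the token values and sizes are dummies.

#include <algorithm>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for llama_eval(): "evaluates" n_tokens tokens
// starting at position n_past and returns 0 on success.
static int eval_chunk(const int * tokens, int n_tokens, int n_past) {
    (void) tokens;
    printf("eval %d tokens at n_past = %d\n", n_tokens, n_past);
    return 0;
}

int main() {
    std::vector<int> embd_inp(1000, 42); // dummy prompt tokens
    const int n_batch = 512;             // maximum tokens per eval call
    int n_past = 0;                      // position of the next token in the sequence

    while (!embd_inp.empty()) {
        // Never pass more than n_batch tokens in a single call.
        int n_tokens = std::min(n_batch, (int) embd_inp.size());
        if (eval_chunk(embd_inp.data(), n_tokens, n_past)) {
            fprintf(stderr, "failed to eval\n");
            return 1;
        }
        n_past += n_tokens;                                            // advance the position
        embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens); // drop consumed tokens
    }
    return 0;
}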