 import sys
 import uuid
 import time
+import math
 import multiprocessing
 from typing import List, Optional, Union, Generator, Sequence, Iterator
 from collections import deque
@@ -76,6 +77,9 @@ def __init__(
         )
         self.tokens_consumed = 0
         self.n_batch = min(n_ctx, n_batch)
+        self.n_tokens = 0
+        self.n_past = 0
+        self.all_logits: List[List[float]] = []  # TODO: Use an array instead of a list.
 
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
 
@@ -136,6 +140,9 @@ def reset(self):
             [llama_cpp.llama_token(0)] * self.last_n_tokens_size
         )
         self.tokens_consumed = 0
+        self.n_tokens = 0
+        self.n_past = 0
+        self.all_logits = []
 
     def eval(self, tokens: Sequence[llama_cpp.llama_token]):
         """Evaluate a list of tokens.
@@ -147,18 +154,31 @@ def eval(self, tokens: Sequence[llama_cpp.llama_token]):
         n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            n_past = min(n_ctx - len(batch), self.tokens_consumed)
+            self.n_past = min(n_ctx - len(batch), self.tokens_consumed)
+            self.n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(
                 ctx=self.ctx,
                 tokens=(llama_cpp.llama_token * len(batch))(*batch),
-                n_tokens=llama_cpp.c_int(len(batch)),
-                n_past=llama_cpp.c_int(n_past),
+                n_tokens=llama_cpp.c_int(self.n_tokens),
+                n_past=llama_cpp.c_int(self.n_past),
                 n_threads=llama_cpp.c_int(self.n_threads),
             )
             if int(return_code) != 0:
                 raise RuntimeError(f"llama_eval returned {return_code}")
             self.last_n_tokens_data.extend(batch)
             self.tokens_consumed += len(batch)
+            if self.params.logits_all:
+                self.all_logits.extend(self._logits())
+
+    def _logits(self) -> List[List[float]]:
+        """Return the logits from the last call to llama_eval."""
+        assert self.ctx is not None
+        n_vocab = llama_cpp.llama_n_vocab(self.ctx)
+        cols = int(n_vocab)
+        rows = self.n_tokens if self.params.logits_all else 1
+        logits_view = llama_cpp.llama_get_logits(self.ctx)
+        logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
+        return logits
 
     def sample(
         self,
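Note on the `_logits` helper added above: it reads the flat buffer returned by `llama_cpp.llama_get_logits` as `rows` rows of `cols = n_vocab` floats, one row per token in the last evaluated batch when `logits_all` is set, otherwise a single row. A minimal sketch of that row-major indexing, using a plain Python list and made-up sizes in place of the ctypes pointer:

```python
# Sketch only: a flat row-major buffer standing in for llama_get_logits' output.
# The sizes and values are invented for illustration.
rows, cols = 2, 4  # n_tokens x n_vocab (hypothetical)
flat = [0.1, 0.2, 0.3, 0.4,   # logits for token 0
        1.1, 1.2, 1.3, 1.4]   # logits for token 1

# Same indexing as _logits: element (i, j) lives at flat[i * cols + j].
nested = [[flat[i * cols + j] for j in range(cols)] for i in range(rows)]
assert nested == [[0.1, 0.2, 0.3, 0.4], [1.1, 1.2, 1.3, 1.4]]
```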
@@ -327,14 +347,55 @@ def _create_completion(
         else:
             stop_sequences = []
 
-        finish_reason = None
-        for token in self.generate(
-            prompt_tokens,
-            top_k=top_k,
-            top_p=top_p,
-            temp=temperature,
-            repeat_penalty=repeat_penalty,
-        ):
+        text_offset = 0
+        text_offsets: List[int] = []
+        token_logprobs: List[float] = []
+        tokens: List[str] = []
+        top_logprobs: List[Dict[str, float]] = []
+
+        self.reset()
+        self.eval(prompt_tokens)
+
+        if logprobs is not None and self.params.logits_all is False:
+            raise ValueError(
+                "logprobs is not supported for models created with logits_all=False"
+            )
+
+        if logprobs is not None:
+            token_strs = [
+                self.detokenize([token]).decode("utf-8") for token in prompt_tokens
+            ]
+            logprobs_all = [
+                [Llama.logit_to_logprob(logit) for logit in row]
+                for row in self.all_logits
+            ]
+            for token, token_str, logprobs_token in zip(
+                prompt_tokens, token_strs, logprobs_all
+            ):
+                text_offsets.append(text_offset)
+                text_offset += len(token_str)
+                tokens.append(token_str)
+                sorted_logprobs = list(
+                    sorted(
+                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
+                    )
+                )
+                token_logprobs.append(sorted_logprobs[int(token)][0])
+                top_logprob = {
+                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
+                    for logprob, i in sorted_logprobs[:logprobs]
+                }
+                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
+                top_logprobs.append(top_logprob)
+
+        finish_reason = "length"
+        while True:
+            token = self.sample(
+                top_k=top_k,
+                top_p=top_p,
+                temp=temperature,
+                repeat_penalty=repeat_penalty,
+            )
             if token == llama_cpp.llama_token_eos():
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
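To make the prompt-token bookkeeping in this hunk concrete, here is a toy walk-through of the `sorted_logprobs` / `top_logprob` construction, with an invented 3-entry vocabulary, invented scores, and `logprobs_arg` standing in for the caller's `logprobs` argument (nothing here is real model output):

```python
# Hypothetical scores for one position over a 3-token vocabulary.
logprobs_token = [-2.0, -0.5, -1.0]
vocab_strs = ["a", "b", "c"]          # stand-ins for self.detokenize([i])

# Pair each score with its token id and sort best-first, mirroring the diff.
sorted_logprobs = sorted(zip(logprobs_token, range(len(logprobs_token))), reverse=True)
assert sorted_logprobs == [(-0.5, 1), (-1.0, 2), (-2.0, 0)]

# Keep the top `logprobs_arg` entries, keyed by their detokenized text.
logprobs_arg = 2
top_logprob = {vocab_strs[i]: lp for lp, i in sorted_logprobs[:logprobs_arg]}
assert top_logprob == {"b": -0.5, "c": -1.0}
```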
@@ -377,13 +438,35 @@ def _create_completion(
                         }
                     ],
                 }
+
+            if logprobs is not None:
+                # TODO: Confirm whether this should happen before or after
+                # next eval.
+                token_str = self.detokenize([token]).decode("utf-8")
+                text_offsets.append(text_offset)
+                text_offset += len(token_str)
+                tokens.append(token_str)
+                logprobs_token = [
+                    Llama.logit_to_logprob(logit) for logit in self.all_logits[-1]
+                ]
+                sorted_logprobs = list(
+                    sorted(
+                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
+                    )
+                )
+                token_logprobs.append(sorted_logprobs[int(token)][0])
+                top_logprob = {
+                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
+                    for logprob, i in sorted_logprobs[:logprobs]
+                }
+                top_logprob.update({token_str: logprobs_token[int(token)]})
+                top_logprobs.append(top_logprob)
+
             if len(completion_tokens) >= max_tokens:
                 text = self.detokenize(completion_tokens)
                 finish_reason = "length"
                 break
-
-            if finish_reason is None:
-                finish_reason = "length"
+            self.eval([token])
 
         if stream:
             yield {
@@ -410,8 +493,14 @@ def _create_completion(
         if suffix is not None:
             text = text + suffix
 
+        logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
-            raise NotImplementedError("logprobs not implemented")
+            logprobs_or_none = {
+                "tokens": tokens,
+                "text_offset": text_offsets,
+                "token_logprobs": token_logprobs,
+                "top_logprobs": top_logprobs,
+            }
 
         if self.verbose:
             llama_cpp.llama_print_timings(self.ctx)
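The `logprobs_or_none` dict assembled above follows the OpenAI-style completion logprobs shape, with the four lists aligned per token. A purely illustrative example of the resulting structure (invented tokens and numbers, not real output):

```python
# Hypothetical payload; offsets track the cumulative length of the token strings.
logprobs_or_none = {
    "tokens": ["Hello", ",", " world"],
    "text_offset": [0, 5, 6],
    "token_logprobs": [-0.8, -0.2, -1.3],
    "top_logprobs": [
        {"Hello": -0.8, "Hi": -1.1},
        {",": -0.2, "!": -2.0},
        {" world": -1.3, " there": -1.9},
    ],
}
```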
@@ -425,7 +514,7 @@ def _create_completion(
                 {
                     "text": text,
                     "index": 0,
-                    "logprobs": None,
+                    "logprobs": logprobs_or_none,
                     "finish_reason": finish_reason,
                 }
             ],
@@ -704,3 +793,7 @@ def token_eos() -> llama_cpp.llama_token:
     def token_bos() -> llama_cpp.llama_token:
         """Return the beginning-of-sequence token."""
         return llama_cpp.llama_token_bos()
+
+    @staticmethod
+    def logit_to_logprob(x: float) -> float:
+        return math.log(1.0 + math.exp(x))
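End-to-end, the change is exercised through the public completion call. A minimal usage sketch, assuming the `Llama` constructor exposes `logits_all` and that `__call__` forwards `max_tokens` and `logprobs` to `_create_completion` (the model path is a placeholder):

```python
from llama_cpp import Llama

# logits_all=True is required, since _create_completion rejects logprobs otherwise.
llm = Llama(model_path="./models/ggml-model.bin", logits_all=True)

output = llm("Q: Name the planets in the solar system. A: ", max_tokens=16, logprobs=5)
# The choice now carries tokens, text_offset, token_logprobs, and top_logprobs.
print(output["choices"][0]["logprobs"])
```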