
Commit b3805bb

Implement logprobs parameter for text completion. Closes #2
1 parent 2a60eb8 commit b3805bb
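
For orientation, here is a minimal usage sketch of the new parameter (assumptions: the model path is hypothetical, and the public Llama.__call__ / create_completion wrappers forward logprobs and max_tokens to the _create_completion generator changed below). The model must be loaded with logits_all=True, otherwise the new code path raises a ValueError:

from llama_cpp import Llama

# Hypothetical model path; logits_all=True is required so that eval()
# keeps a row of logits for every evaluated token.
llm = Llama(model_path="./models/ggml-model.bin", logits_all=True)

# Ask for the top-2 alternatives per position alongside the completion text.
out = llm("Q: Name the planets in the solar system. A:", max_tokens=16, logprobs=2)

choice = out["choices"][0]
print(choice["text"])
print(choice["logprobs"]["tokens"])          # detokenized token strings
print(choice["logprobs"]["token_logprobs"])  # score recorded for each token
print(choice["logprobs"]["top_logprobs"])    # per-position top-N mapping

The returned logprobs object mirrors the OpenAI completion format: parallel lists of tokens, text_offset, and token_logprobs, plus one top_logprobs dictionary per position.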

2 files changed (+111, -16)

llama_cpp/llama.py (+109, -16)
@@ -2,6 +2,7 @@
 import sys
 import uuid
 import time
+import math
 import multiprocessing
 from typing import List, Optional, Union, Generator, Sequence, Iterator
 from collections import deque
@@ -76,6 +77,9 @@ def __init__(
         )
         self.tokens_consumed = 0
         self.n_batch = min(n_ctx, n_batch)
+        self.n_tokens = 0
+        self.n_past = 0
+        self.all_logits: List[List[float]] = []  # TODO: Use an array instead of a list.

         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)

@@ -136,6 +140,9 @@ def reset(self):
             [llama_cpp.llama_token(0)] * self.last_n_tokens_size
         )
         self.tokens_consumed = 0
+        self.n_tokens = 0
+        self.n_past = 0
+        self.all_logits = []

     def eval(self, tokens: Sequence[llama_cpp.llama_token]):
         """Evaluate a list of tokens.
@@ -147,18 +154,31 @@ def eval(self, tokens: Sequence[llama_cpp.llama_token]):
         n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            n_past = min(n_ctx - len(batch), self.tokens_consumed)
+            self.n_past = min(n_ctx - len(batch), self.tokens_consumed)
+            self.n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(
                 ctx=self.ctx,
                 tokens=(llama_cpp.llama_token * len(batch))(*batch),
-                n_tokens=llama_cpp.c_int(len(batch)),
-                n_past=llama_cpp.c_int(n_past),
+                n_tokens=llama_cpp.c_int(self.n_tokens),
+                n_past=llama_cpp.c_int(self.n_past),
                 n_threads=llama_cpp.c_int(self.n_threads),
             )
             if int(return_code) != 0:
                 raise RuntimeError(f"llama_eval returned {return_code}")
             self.last_n_tokens_data.extend(batch)
             self.tokens_consumed += len(batch)
+            if self.params.logits_all:
+                self.all_logits.extend(self._logits())
+
+    def _logits(self) -> List[List[float]]:
+        """Return the logits from the last call to llama_eval."""
+        assert self.ctx is not None
+        n_vocab = llama_cpp.llama_n_vocab(self.ctx)
+        cols = int(n_vocab)
+        rows = self.n_tokens if self.params.logits_all else 1
+        logits_view = llama_cpp.llama_get_logits(self.ctx)
+        logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
+        return logits

     def sample(
         self,
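
The new _logits helper reads llama.cpp's flat, row-major logits buffer: with logits_all enabled there is one row of n_vocab scores per token in the last batch, and element (i, j) lives at index i * cols + j. A standalone sketch of the same indexing on a plain Python list (illustration only, not code from this commit):

from typing import List

def rows_from_flat(flat: List[float], rows: int, cols: int) -> List[List[float]]:
    # Split a row-major flat buffer into `rows` lists of `cols` scores each,
    # mirroring the comprehension in _logits: flat[i * cols + j] -> logits[i][j].
    return [flat[i * cols : (i + 1) * cols] for i in range(rows)]

# Toy example: 2 tokens evaluated, vocabulary of 3.
flat = [0.1, 0.2, 0.3, 1.0, 2.0, 3.0]
assert rows_from_flat(flat, rows=2, cols=3) == [[0.1, 0.2, 0.3], [1.0, 2.0, 3.0]]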
@@ -327,14 +347,55 @@ def _create_completion(
         else:
             stop_sequences = []

-        finish_reason = None
-        for token in self.generate(
-            prompt_tokens,
-            top_k=top_k,
-            top_p=top_p,
-            temp=temperature,
-            repeat_penalty=repeat_penalty,
-        ):
+        text_offset = 0
+        text_offsets: List[int] = []
+        token_logprobs: List[float] = []
+        tokens: List[str] = []
+        top_logprobs: List[Dict[str, float]] = []
+
+        self.reset()
+        self.eval(prompt_tokens)
+
+        if logprobs is not None and self.params.logits_all is False:
+            raise ValueError(
+                "logprobs is not supported for models created with logits_all=False"
+            )
+
+        if logprobs is not None:
+            token_strs = [
+                self.detokenize([token]).decode("utf-8") for token in prompt_tokens
+            ]
+            logprobs_all = [
+                [Llama.logit_to_logprob(logit) for logit in row]
+                for row in self.all_logits
+            ]
+            for token, token_str, logprobs_token in zip(
+                prompt_tokens, token_strs, logprobs_all
+            ):
+                text_offsets.append(text_offset)
+                text_offset += len(token_str)
+                tokens.append(token_str)
+                sorted_logprobs = list(
+                    sorted(
+                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
+                    )
+                )
+                token_logprobs.append(sorted_logprobs[int(token)][0])
+                top_logprob = {
+                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
+                    for logprob, i in sorted_logprobs[:logprobs]
+                }
+                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
+                top_logprobs.append(top_logprob)
+
+        finish_reason = "length"
+        while True:
+            token = self.sample(
+                top_k=top_k,
+                top_p=top_p,
+                temp=temperature,
+                repeat_penalty=repeat_penalty,
+            )
             if token == llama_cpp.llama_token_eos():
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
@@ -377,13 +438,35 @@ def _create_completion(
                     }
                 ],
             }
+
+            if logprobs is not None:
+                # TODO: Confirm whether this should happen before or after
+                # next eval.
+                token_str = self.detokenize([token]).decode("utf-8")
+                text_offsets.append(text_offset)
+                text_offset += len(token_str)
+                tokens.append(token_str)
+                logprobs_token = [
+                    Llama.logit_to_logprob(logit) for logit in self.all_logits[-1]
+                ]
+                sorted_logprobs = list(
+                    sorted(
+                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
+                    )
+                )
+                token_logprobs.append(sorted_logprobs[int(token)][0])
+                top_logprob = {
+                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
+                    for logprob, i in sorted_logprobs[:logprobs]
+                }
+                top_logprob.update({token_str: logprobs_token[int(token)]})
+                top_logprobs.append(top_logprob)
+
             if len(completion_tokens) >= max_tokens:
                 text = self.detokenize(completion_tokens)
                 finish_reason = "length"
                 break
-
-        if finish_reason is None:
-            finish_reason = "length"
+            self.eval([token])

         if stream:
             yield {
@@ -410,8 +493,14 @@ def _create_completion(
         if suffix is not None:
             text = text + suffix

+        logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
-            raise NotImplementedError("logprobs not implemented")
+            logprobs_or_none = {
+                "tokens": tokens,
+                "text_offset": text_offsets,
+                "token_logprobs": token_logprobs,
+                "top_logprobs": top_logprobs,
+            }

         if self.verbose:
             llama_cpp.llama_print_timings(self.ctx)
@@ -425,7 +514,7 @@ def _create_completion(
             {
                 "text": text,
                 "index": 0,
-                "logprobs": None,
+                "logprobs": logprobs_or_none,
                 "finish_reason": finish_reason,
             }
         ],
@@ -704,3 +793,7 @@ def token_eos() -> llama_cpp.llama_token:
     def token_bos() -> llama_cpp.llama_token:
         """Return the beginning-of-sequence token."""
         return llama_cpp.llama_token_bos()
+
+    @staticmethod
+    def logit_to_logprob(x: float) -> float:
+        return math.log(1.0 + math.exp(x))
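
The new logit_to_logprob helper applies a softplus, log(1 + exp(x)), to each logit independently. For comparison, the conventional conversion from a row of logits to log-probabilities is a log-softmax over the whole vocabulary; a minimal, numerically stable sketch of that standard formula (illustration only, not code from this commit):

import math
from typing import List

def log_softmax(logits: List[float]) -> List[float]:
    # Numerically stable log-softmax: subtract the max before exponentiating,
    # then normalize by the log-sum-exp so the outputs are true log-probabilities.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

# The exponentiated outputs sum to 1.
row = [2.0, 1.0, 0.1]
assert abs(sum(math.exp(lp) for lp in log_softmax(row)) - 1.0) < 1e-9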

llama_cpp/server/__main__.py (+2)

@@ -33,6 +33,7 @@ class Settings(BaseSettings):
     use_mlock: bool = False  # This causes a silent failure on platforms that don't support mlock (e.g. Windows) took forever to figure out...
     embedding: bool = True
     last_n_tokens_size: int = 64
+    logits_all: bool = False


 app = FastAPI(
@@ -52,6 +53,7 @@ class Settings(BaseSettings):
     f16_kv=settings.f16_kv,
     use_mlock=settings.use_mlock,
     embedding=settings.embedding,
+    logits_all=settings.logits_all,
     n_threads=settings.n_threads,
     n_batch=settings.n_batch,
     n_ctx=settings.n_ctx,
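
Because Settings derives from pydantic's BaseSettings, the new logits_all flag should also be configurable from the environment before the server module is loaded. A hypothetical launch sketch (the MODEL variable name, model path, and import path are assumptions; pydantic matches field names to environment variables case-insensitively):

import os

# Assumed environment variable names, mapped to the Settings fields above.
os.environ["MODEL"] = "./models/ggml-model.bin"  # hypothetical model path
os.environ["LOGITS_ALL"] = "true"                # required for logprobs requests

import uvicorn
from llama_cpp.server.__main__ import app  # settings are read when the module loads

uvicorn.run(app, host="localhost", port=8000)

With logits_all left at its default of False, a completion request that asks for logprobs fails with the ValueError added in llama_cpp/llama.py.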
