
Commit 0e94a70

Add in-memory longest prefix cache. Closes ggml-org#158
1 parent 8dfde63 commit 0e94a70

File tree: 1 file changed, llama_cpp/llama.py (+64, -27)

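For context, a minimal usage sketch of what this change enables (not part of the diff itself): the cache is now sized in bytes and attached to a `Llama` instance so that completions sharing a token prefix can reuse saved state. The `set_cache` call and the model path below are assumptions for illustration.

```python
# Illustrative sketch only, not part of this commit. Assumes the llama_cpp
# Python API exposes Llama.set_cache(); the model path is a placeholder.
from llama_cpp import Llama, LlamaCache

llm = Llama(model_path="./models/7B/ggml-model.bin")

# Keep up to ~2 GiB of saved completion states (the new default capacity_bytes).
llm.set_cache(LlamaCache(capacity_bytes=2 << 30))

# Prompts that share a token prefix with a cached completion can restore that
# state instead of re-evaluating the shared prefix from scratch.
out = llm("Q: Name the planets in the solar system. A:", max_tokens=32)
print(out["choices"][0]["text"])
```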

Diff for: llama_cpp/llama.py

@@ -5,7 +5,7 @@
 import math
 import multiprocessing
 from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
-from collections import deque
+from collections import deque, OrderedDict
 
 from . import llama_cpp
 from .llama_types import *
@@ -14,37 +14,50 @@
 class LlamaCache:
     """Cache for a llama.cpp model."""
 
-    def __init__(self):
-        self.cache_state: Dict[Tuple[llama_cpp.llama_token, ...], "LlamaState"] = dict()
+    def __init__(self, capacity_bytes: int = (2 << 30)):
+        self.cache_state: OrderedDict[
+            Tuple[llama_cpp.llama_token, ...], "LlamaState"
+        ] = OrderedDict()
+        self.capacity_bytes = capacity_bytes
 
-    def _sorted_keys(self) -> List[Tuple[llama_cpp.llama_token, ...]]:
-        return [
-            key
-            for _, key in sorted(
-                ((len(key), key) for key in self.cache_state.keys()), reverse=True
-            )
-        ]
+    @property
+    def cache_size(self):
+        return sum([state.llama_state_size for state in self.cache_state.values()])
 
-    def _find_key(
-        self, key: Tuple[llama_cpp.llama_token, ...]
+    def _find_longest_prefix_key(
+        self,
+        key: Tuple[llama_cpp.llama_token, ...],
     ) -> Optional[Tuple[llama_cpp.llama_token, ...]]:
-        for k in self._sorted_keys():
-            if key[: len(k)] == k:
-                return k
-        return None
+        min_len = 0
+        min_key = None
+        keys = (
+            (k, Llama.longest_token_prefix(k, key)) for k in self.cache_state.keys()
+        )
+        for k, prefix_len in keys:
+            if prefix_len > min_len:
+                min_len = prefix_len
+                min_key = k
+        return min_key
 
     def __getitem__(self, key: Sequence[llama_cpp.llama_token]) -> "LlamaState":
-        _key = self._find_key(tuple(key))
+        key = tuple(key)
+        _key = self._find_longest_prefix_key(key)
         if _key is None:
-            raise KeyError(f"Key not found: {key}")
-        return self.cache_state[_key]
+            raise KeyError(f"Key not found")
+        value = self.cache_state[_key]
+        self.cache_state.move_to_end(_key)
+        return value
 
     def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
-        return self._find_key(tuple(key)) is not None
+        return self._find_longest_prefix_key(tuple(key)) is not None
 
     def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
-        self.cache_state = dict()  # NOTE: Currently limit to one cache entry.
-        self.cache_state[tuple(key)] = value
+        key = tuple(key)
+        if key in self.cache_state:
+            del self.cache_state[key]
+        self.cache_state[key] = value
+        while self.cache_size > self.capacity_bytes:
+            self.cache_state.popitem(last=False)
 
 
 class LlamaState:
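Taken on its own, the new `LlamaCache` is a byte-capped LRU keyed by token tuples: lookups return the entry sharing the longest token prefix with the query and refresh its position, and inserts evict from the least-recently-used end until the summed `llama_state_size` values fit under `capacity_bytes`. Below is a self-contained sketch of that policy, with plain integers standing in for llama tokens and a stub class in place of `LlamaState`; the names `PrefixLRUCache` and `FakeState` are illustrative, not from the library.

```python
# Self-contained sketch of the cache policy added in this commit. "Tokens" are
# plain ints and FakeState stands in for LlamaState; only llama_state_size
# matters for eviction.
from collections import OrderedDict
from typing import Optional, Sequence, Tuple


class FakeState:
    def __init__(self, llama_state_size: int):
        self.llama_state_size = llama_state_size


def longest_token_prefix(a: Sequence[int], b: Sequence[int]) -> int:
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


class PrefixLRUCache:
    def __init__(self, capacity_bytes: int):
        self.capacity_bytes = capacity_bytes
        self.state: "OrderedDict[Tuple[int, ...], FakeState]" = OrderedDict()

    def _find_longest_prefix_key(self, key: Tuple[int, ...]) -> Optional[Tuple[int, ...]]:
        best_len, best_key = 0, None
        for k in self.state:
            prefix_len = longest_token_prefix(k, key)
            if prefix_len > best_len:
                best_len, best_key = prefix_len, k
        return best_key

    def __getitem__(self, key: Sequence[int]) -> FakeState:
        k = self._find_longest_prefix_key(tuple(key))
        if k is None:
            raise KeyError("Key not found")
        self.state.move_to_end(k)  # refresh LRU position on a hit
        return self.state[k]

    def __setitem__(self, key: Sequence[int], value: FakeState) -> None:
        k = tuple(key)
        self.state.pop(k, None)
        self.state[k] = value
        while sum(s.llama_state_size for s in self.state.values()) > self.capacity_bytes:
            self.state.popitem(last=False)  # evict the least-recently-used entry


cache = PrefixLRUCache(capacity_bytes=200)
cache[[1, 2, 3]] = FakeState(100)
cache[[1, 2, 3, 4, 5]] = FakeState(100)
assert cache[[1, 2, 3, 4, 9]].llama_state_size == 100  # longest shared prefix wins
cache[[7, 8]] = FakeState(100)                         # pushes the total over 200 bytes
assert len(cache.state) == 2                           # oldest entry was evicted
```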
@@ -53,7 +66,7 @@ def __init__(
         eval_tokens: Deque[llama_cpp.llama_token],
         eval_logits: Deque[List[float]],
         llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
-        llama_state_size: llama_cpp.c_size_t,
+        llama_state_size: int,
     ):
         self.eval_tokens = eval_tokens
         self.eval_logits = eval_logits
@@ -526,10 +539,22 @@ def _create_completion(
                 "logprobs is not supported for models created with logits_all=False"
             )
 
-        if self.cache and prompt_tokens in self.cache:
-            if self.verbose:
-                print("Llama._create_completion: cache hit", file=sys.stderr)
-            self.load_state(self.cache[prompt_tokens])
+        if self.cache:
+            try:
+                cache_item = self.cache[prompt_tokens]
+                cache_prefix_len = Llama.longest_token_prefix(
+                    cache_item.eval_tokens, prompt_tokens
+                )
+                eval_prefix_len = Llama.longest_token_prefix(
+                    self.eval_tokens, prompt_tokens
+                )
+                if cache_prefix_len > eval_prefix_len:
+                    self.load_state(cache_item)
+                    if self.verbose:
+                        print("Llama._create_completion: cache hit", file=sys.stderr)
+            except KeyError:
+                if self.verbose:
+                    print("Llama._create_completion: cache miss", file=sys.stderr)
 
         finish_reason = "length"
         multibyte_fix = 0
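The rewritten block above restores a cached state only when doing so saves work: the token prefix the cached entry shares with the new prompt must be longer than the prefix already covered by the model's current `eval_tokens`, and a `KeyError` from the cache is treated as a miss. A minimal sketch of that comparison, using the new `Llama.longest_token_prefix` helper with plain integers as stand-in token values:

```python
# Sketch of the restore decision added to _create_completion; the integer
# token values below are illustrative.
from llama_cpp import Llama


def should_restore(cached_eval_tokens, current_eval_tokens, prompt_tokens):
    cache_prefix_len = Llama.longest_token_prefix(cached_eval_tokens, prompt_tokens)
    eval_prefix_len = Llama.longest_token_prefix(current_eval_tokens, prompt_tokens)
    # Reload the saved state only if it covers more of the new prompt than the
    # tokens the model has already evaluated in this session.
    return cache_prefix_len > eval_prefix_len


prompt = [1, 2, 3, 4, 5, 6]
print(should_restore([1, 2, 3, 4], [1, 2], prompt))  # True: the cached state skips more of the prompt
print(should_restore([1, 2], [1, 2, 3], prompt))     # False: the live context is already further along
```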
@@ -1004,3 +1029,15 @@ def logits_to_logprobs(logits: List[float]) -> List[float]:
         exps = [math.exp(float(x)) for x in logits]
         sum_exps = sum(exps)
         return [math.log(x / sum_exps) for x in exps]
+
+    @staticmethod
+    def longest_token_prefix(
+        a: Sequence[llama_cpp.llama_token], b: Sequence[llama_cpp.llama_token]
+    ):
+        longest_prefix = 0
+        for _a, _b in zip(a, b):
+            if _a == _b:
+                longest_prefix += 1
+            else:
+                break
+        return longest_prefix
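For reference, the new static helper simply counts how many leading positions two token sequences agree on; a quick illustration with plain integers in place of llama tokens:

```python
from llama_cpp import Llama

print(Llama.longest_token_prefix([1, 2, 3, 4], [1, 2, 3, 9]))  # 3
print(Llama.longest_token_prefix([1, 2], [3, 4]))              # 0
print(Llama.longest_token_prefix([], [1, 2, 3]))               # 0
```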
