@@ -5,7 +5,7 @@
 import math
 import multiprocessing
 from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
-from collections import deque
+from collections import deque, OrderedDict

 from . import llama_cpp
 from .llama_types import *
@@ -14,37 +14,50 @@
 class LlamaCache:
     """Cache for a llama.cpp model."""

-    def __init__(self):
-        self.cache_state: Dict[Tuple[llama_cpp.llama_token, ...], "LlamaState"] = dict()
+    def __init__(self, capacity_bytes: int = (2 << 30)):
+        self.cache_state: OrderedDict[
+            Tuple[llama_cpp.llama_token, ...], "LlamaState"
+        ] = OrderedDict()
+        self.capacity_bytes = capacity_bytes

-    def _sorted_keys(self) -> List[Tuple[llama_cpp.llama_token, ...]]:
-        return [
-            key
-            for _, key in sorted(
-                ((len(key), key) for key in self.cache_state.keys()), reverse=True
-            )
-        ]
+    @property
+    def cache_size(self):
+        return sum([state.llama_state_size for state in self.cache_state.values()])

-    def _find_key(
-        self, key: Tuple[llama_cpp.llama_token, ...]
+    def _find_longest_prefix_key(
+        self,
+        key: Tuple[llama_cpp.llama_token, ...],
     ) -> Optional[Tuple[llama_cpp.llama_token, ...]]:
-        for k in self._sorted_keys():
-            if key[: len(k)] == k:
-                return k
-        return None
+        min_len = 0
+        min_key = None
+        keys = (
+            (k, Llama.longest_token_prefix(k, key)) for k in self.cache_state.keys()
+        )
+        for k, prefix_len in keys:
+            if prefix_len > min_len:
+                min_len = prefix_len
+                min_key = k
+        return min_key

     def __getitem__(self, key: Sequence[llama_cpp.llama_token]) -> "LlamaState":
-        _key = self._find_key(tuple(key))
+        key = tuple(key)
+        _key = self._find_longest_prefix_key(key)
         if _key is None:
-            raise KeyError(f"Key not found: {key}")
-        return self.cache_state[_key]
+            raise KeyError("Key not found")
+        value = self.cache_state[_key]
+        self.cache_state.move_to_end(_key)
+        return value

     def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
-        return self._find_key(tuple(key)) is not None
+        return self._find_longest_prefix_key(tuple(key)) is not None

     def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
-        self.cache_state = dict()  # NOTE: Currently limit to one cache entry.
-        self.cache_state[tuple(key)] = value
+        key = tuple(key)
+        if key in self.cache_state:
+            del self.cache_state[key]
+        self.cache_state[key] = value
+        while self.cache_size > self.capacity_bytes:
+            self.cache_state.popitem(last=False)


 class LlamaState:
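The hunk above turns `LlamaCache` from a single-entry `dict` into a byte-budgeted LRU: entries live in an `OrderedDict`, `__getitem__` promotes a hit with `move_to_end`, and `__setitem__` evicts from the front with `popitem(last=False)` until `cache_size` fits under `capacity_bytes` (default `2 << 30`, i.e. 2 GiB). A minimal self-contained sketch of the same pattern, with a toy byte-string payload standing in for `LlamaState` and `llama_state_size` (class and method names here are illustrative, not part of the patch):

```python
from collections import OrderedDict
from typing import Tuple

class ByteBudgetLRU:
    """Toy LRU keyed by token tuples, evicting by total payload size."""

    def __init__(self, capacity_bytes: int):
        self._entries: "OrderedDict[Tuple[int, ...], bytes]" = OrderedDict()
        self.capacity_bytes = capacity_bytes

    @property
    def size(self) -> int:
        # Recomputed on demand, mirroring the cache_size property above.
        return sum(len(v) for v in self._entries.values())

    def get(self, key: Tuple[int, ...]) -> bytes:
        value = self._entries[key]          # raises KeyError on a miss
        self._entries.move_to_end(key)      # mark as most recently used
        return value

    def put(self, key: Tuple[int, ...], value: bytes) -> None:
        self._entries.pop(key, None)        # re-inserting refreshes recency
        self._entries[key] = value
        while self.size > self.capacity_bytes:
            self._entries.popitem(last=False)  # evict least recently used

cache = ByteBudgetLRU(capacity_bytes=8)
cache.put((1, 2), b"abcd")
cache.put((3, 4), b"efgh")
cache.put((5, 6), b"ijkl")  # 12 bytes total, so (1, 2) is evicted first
assert (1, 2) not in cache._entries and (5, 6) in cache._entries
```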
@@ -53,7 +66,7 @@ def __init__(
         eval_tokens: Deque[llama_cpp.llama_token],
         eval_logits: Deque[List[float]],
         llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
-        llama_state_size: llama_cpp.c_size_t,
+        llama_state_size: int,
     ):
         self.eval_tokens = eval_tokens
         self.eval_logits = eval_logits
@@ -526,10 +539,22 @@ def _create_completion(
                 "logprobs is not supported for models created with logits_all=False"
             )

-        if self.cache and prompt_tokens in self.cache:
-            if self.verbose:
-                print("Llama._create_completion: cache hit", file=sys.stderr)
-            self.load_state(self.cache[prompt_tokens])
+        if self.cache:
+            try:
+                cache_item = self.cache[prompt_tokens]
+                cache_prefix_len = Llama.longest_token_prefix(
+                    cache_item.eval_tokens, prompt_tokens
+                )
+                eval_prefix_len = Llama.longest_token_prefix(
+                    self.eval_tokens, prompt_tokens
+                )
+                if cache_prefix_len > eval_prefix_len:
+                    self.load_state(cache_item)
+                    if self.verbose:
+                        print("Llama._create_completion: cache hit", file=sys.stderr)
+            except KeyError:
+                if self.verbose:
+                    print("Llama._create_completion: cache miss", file=sys.stderr)

         finish_reason = "length"
         multibyte_fix = 0
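The completion path now treats a cache lookup as `try`/`except KeyError` rather than a separate membership check, and only restores the cached state when it shares a longer token prefix with the new prompt than the tokens already evaluated in the live context; otherwise loading the cache would discard context that is closer to the prompt. A rough sketch of that decision rule, with hypothetical helper names (only `longest_token_prefix` comes from the patch):

```python
from typing import Sequence

def longest_token_prefix(a: Sequence[int], b: Sequence[int]) -> int:
    # Same scan as the static method added at the bottom of this diff.
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n

def should_load_cached_state(
    cached_tokens: Sequence[int],
    live_tokens: Sequence[int],
    prompt_tokens: Sequence[int],
) -> bool:
    # Restore only if the cached state gets further into the new prompt
    # than the tokens the model has already evaluated.
    cache_prefix_len = longest_token_prefix(cached_tokens, prompt_tokens)
    eval_prefix_len = longest_token_prefix(live_tokens, prompt_tokens)
    return cache_prefix_len > eval_prefix_len

# Live context matches 2 prompt tokens, cached state matches 3: load it.
assert should_load_cached_state([1, 2, 3], [1, 2, 9], [1, 2, 3, 4])
```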
@@ -1004,3 +1029,15 @@ def logits_to_logprobs(logits: List[float]) -> List[float]:
         exps = [math.exp(float(x)) for x in logits]
         sum_exps = sum(exps)
         return [math.log(x / sum_exps) for x in exps]
+
+    @staticmethod
+    def longest_token_prefix(
+        a: Sequence[llama_cpp.llama_token], b: Sequence[llama_cpp.llama_token]
+    ):
+        longest_prefix = 0
+        for _a, _b in zip(a, b):
+            if _a == _b:
+                longest_prefix += 1
+            else:
+                break
+        return longest_prefix
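`longest_token_prefix` is a plain element-wise scan: `zip` stops at the shorter sequence, so it never reads past either argument and returns 0 for an empty input. Since it is a `@staticmethod`, it can be exercised without loading a model; assuming the patched package is importable under its usual name `llama_cpp`:

```python
from llama_cpp import Llama

assert Llama.longest_token_prefix([1, 2, 3, 4], [1, 2, 9, 9]) == 2
assert Llama.longest_token_prefix([1, 2], [1, 2, 3, 4]) == 2  # shorter side wins
assert Llama.longest_token_prefix([], [1, 2]) == 0
```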