
Commit 0ce2269

first rough draft
1 parent 4787ec3 commit 0ce2269

File tree

2 files changed: +117 −328 lines changed


llama_cpp/_internals.py

+85 −1
@@ -818,4 +818,88 @@ def sample(
     def accept(self, ctx_main: _LlamaContext, id: int, apply_grammar: bool):
         if apply_grammar and self.grammar is not None:
             ctx_main.grammar_accept_token(self.grammar, id)
-        self.prev.append(id)
+        self.prev.append(id)
+
+
+class _TokenTextQueue:
+    # Buffers generated tokens and releases decoded UTF-8 text only when it is
+    # safe to do so: never mid-way through a multi-byte character, and never
+    # while the tail of the buffer could be the start of a stop sequence.
+    def __init__(self, detokenize, stop_sequences: Optional[List[bytes]] = None):
+        # settings
+        self.detokenize = detokenize
+        self.stop_sequences = stop_sequences or []
+
+        # current state
+        self.tokens: List[int] = []
+
+    def __len__(self):
+        return len(self.tokens)
+
+    @staticmethod
+    def decode_robust(bstr: bytes) -> Optional[str]:
+        try:
+            return bstr.decode("utf-8")
+        except UnicodeError:
+            return None
+
+    def detect_stop_token(self) -> Optional[bytes]:
+        # if a full stop sequence has appeared, return the text preceding
+        # its earliest occurrence
+        text = self.detokenize(self.tokens)
+        stop_idxs = [text.index(s) for s in self.stop_sequences if s in text]
+        if len(stop_idxs) > 0:
+            return text[: min(stop_idxs)]
+        return None
+
+    def first_stop_position(self) -> int:
+        # detect first index of a partial stop sequence at the end of the buffer
+        text = self.detokenize(self.tokens)
+        length = len(text)
+        first_stop_len = 0
+        for s in self.stop_sequences:
+            for i in range(min(len(s), length), 0, -1):
+                if text.endswith(s[:i]):
+                    if i > first_stop_len:
+                        first_stop_len = i
+                    break
+        return length - first_stop_len
+
+    def push_token(self, token: int):
+        self.tokens.append(token)
+
+    def pop_text(self):
+        if len(self) == 0:
+            return None
+
+        # attempt decode on successively longer token prefixes
+        for i in range(1, len(self.tokens) + 1):
+            bstr = self.detokenize(self.tokens[:i])
+            text = self.decode_robust(bstr)
+            if text is not None:
+                break
+
+        # no prefix of the buffered tokens decodes to complete UTF-8 characters
+        if text is None:
+            return None
+
+        # avoid yielding while a possible stop sequence is in progress
+        if len(bstr) > self.first_stop_position():
+            return None
+
+        # trim token list
+        self.tokens = self.tokens[i:]
+
+        return i, bstr, text
+
+    def empty_text(self) -> str:
+        # flush whatever remains, stopping short of any partial stop sequence
+        text = ""
+        position = 0
+        end_position = self.first_stop_position()
+
+        for token in self.tokens:
+            last_text = self.detokenize([token])
+            position += len(last_text)
+
+            if position >= end_position:
+                text += last_text[
+                    : len(last_text) - (position - end_position)
+                ].decode("utf-8", errors="ignore")
+                break
+
+            text += last_text.decode("utf-8", errors="ignore")
+
+        return text
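
For context, here is a minimal sketch of how this queue might be driven from a streaming decode loop. It is illustrative only: VOCAB and fake_detokenize are hypothetical stand-ins for a real byte-level detokenizer, and _TokenTextQueue is a private helper, so importing it directly is for demonstration rather than a supported API.

from llama_cpp._internals import _TokenTextQueue  # private module; illustration only

# hypothetical vocabulary: token ids map to byte fragments
VOCAB = {0: b"Hel", 1: b"lo", 2: b" wor", 3: b"ld", 4: b"<", 5: b"|end|>"}

def fake_detokenize(tokens):
    return b"".join(VOCAB[t] for t in tokens)

queue = _TokenTextQueue(fake_detokenize, stop_sequences=[b"<|end|>"])

for token in [0, 1, 2, 3, 4, 5]:
    queue.push_token(token)

    # a completed stop sequence ends generation; text before it comes back as bytes
    stop = queue.detect_stop_token()
    if stop is not None:
        print(stop.decode("utf-8", errors="ignore"))
        break

    popped = queue.pop_text()
    if popped is not None:
        n_tokens, bstr, text = popped
        print(text, end="")  # complete UTF-8 with no stop sequence pending
else:
    # generation ended without a stop: flush anything still buffered
    print(queue.empty_text())

Tokens 4 and 5 together spell the stop sequence, so the queue holds back the "<" rather than streaming it, and nothing beyond "Hello world" is emitted.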
