@@ -818,4 +818,88 @@ def sample(
818
818
def accept(self, ctx_main: _LlamaContext, id: int, apply_grammar: bool):
    """Record *id* as the most recently sampled token.

    When grammar application is requested and a grammar is attached, the
    token is first fed to the grammar so its internal parse state advances.
    """
    grammar_active = apply_grammar and self.grammar is not None
    if grammar_active:
        ctx_main.grammar_accept_token(self.grammar, id)
    self.prev.append(id)
822
+
823
+ class _TokenTextQueue :
824
+ def __init__ (self , detokenize , stop_sequences : List [int ] = None ):
825
+ # settings
826
+ self .detokenize = detokenize
827
+ self .stop_sequences = stop_sequences or []
828
+
829
+ # current state
830
+ self .tokens : List [int ] = []
831
+
832
+ def __len__ (self ):
833
+ return len (self .tokens )
834
+
835
+ @staticmethod
836
+ def decode_robust (bstr ):
837
+ try :
838
+ return bstr .decode ("utf-8" )
839
+ except UnicodeError :
840
+ return
841
+
842
+ def detect_stop_token (self ):
843
+ text = self .detokenize (self .tokens )
844
+ stop_idxs = [text .index (s ) for s in self .stop_sequences if s in text ]
845
+ if len (stop_idxs ) > 0 :
846
+ return text [:min (stop_idxs )]
847
+
848
+ # detect first index of partial stop sequence
849
+ def first_stop_position (self ):
850
+ text = self .detokenize (self .tokens )
851
+ length = len (text )
852
+ first_stop_len = 0
853
+ for s in self .stop_sequences :
854
+ for i in range (min (len (s ), length ), 0 , - 1 ):
855
+ if text .endswith (s [:i ]):
856
+ if i > first_stop_len :
857
+ first_stop_len = i
858
+ break
859
+ return length - first_stop_len
860
+
861
+ def push_token (self , token : int ):
862
+ self .tokens .append (token )
863
+
864
+ def pop_text (self ) -> bytes :
865
+ if len (self ) == 0 :
866
+ return
867
+
868
+ # attempt decode on substrings
869
+ for i in range (1 , len (self .tokens ) + 1 ):
870
+ bstr = self .detokenize (self .tokens [:i ])
871
+ text = self .decode_robust (bstr )
872
+ if text is not None :
873
+ break
874
+
875
+ # all remaining tokens cannot be decoded to a UTF-8 character
876
+ if text is None :
877
+ return
878
+
879
+ # avoid yield if possible stop sequence in progress
880
+ if len (bstr ) > self .first_stop_position ():
881
+ return
882
+
883
+ # trim token list
884
+ self .tokens = self .tokens [i :]
885
+
886
+ return i , bstr , text
887
+
888
+ def empty_text (self ):
889
+ text = ""
890
+ position = 0
891
+ end_position = self .first_stop_position ()
892
+
893
+ for token in self .tokens :
894
+ last_text = self .detokenize ([token ])
895
+ position += len (last_text )
896
+
897
+ if position >= end_position :
898
+ text += last_text [
899
+ : len (last_text ) - (position - end_position )
900
+ ].decode ("utf-8" , errors = "ignore" )
901
+ break
902
+
903
+ text += last_text .decode ("utf-8" , errors = "ignore" )
904
+
905
+ return text
0 commit comments