11"""Make vLLM compatible with Outlines' guided generation."""
22import json
33import math
4- from typing import Dict , List
4+ from typing import Dict , List , Tuple
55
66import torch
77
@@ -38,7 +38,59 @@ def _patched_apply_logits_processors(
     return logits


+def adapt_tokenizer(tokenizer):
+    """Adapt vLLM's tokenizer so it can be used to compile the FSM.
+
+    The API of Outlines tokenizers is slightly different from that of
+    `transformers`. In addition, we need to handle the spaces that Llama's
+    tokenizer drops in order to compile FSMs for this model.
+
+    """
+    tokenizer.vocabulary = tokenizer.get_vocab()
+    tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+    def convert_token_to_string(token: str) -> str:
+        from transformers.file_utils import SPIECE_UNDERLINE
+
+        string = tokenizer.convert_tokens_to_string([token])
+
+        # A hack to restore the spaces missing from HF's Llama tokenizers
+        if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+            return " " + string
+
+        return string
+
+    tokenizer.convert_token_to_string = convert_token_to_string
+
+    return tokenizer
+
+
+class CachedRegexFSM(RegexFSM):
+    def __init__(self, regex_string: str, adapted_tokenizer):
+        super().__init__(regex_string, adapted_tokenizer)
+        self.state_cache: Dict[int, FSMState] = {}
+
+    def get_state_by_token_ids(self, input_ids: Tuple[int, ...]) -> FSMState:
+        state_key = hash(input_ids)
+
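+        # An empty sequence corresponds to the FSM's initial state.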
+        if not input_ids:
+            self.state_cache[state_key] = FSMState(0)
+
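+        # Otherwise advance by one token from the cached state of the
+        # sequence's prefix, computed on the previous decoding step.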
+        elif state_key not in self.state_cache:
+            prev_state_key = hash(input_ids[:-1])
+            prev_state = self.state_cache[prev_state_key]
+
+            last_token = input_ids[-1]
+            new_state = self.next_state(prev_state, last_token)
+            self.state_cache[state_key] = new_state
+
+        return self.state_cache[state_key]
+
+
 class RegexLogitsProcessor:
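+    # The compiled FSMs and the adapted tokenizer are cached at class level
+    # so they are built once and shared by every processor instance.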
+    fsm_cache: Dict[str, CachedRegexFSM] = {}
+    adapted_tokenizer = None
+
     def __init__(self, regex_string, llm):
         """Compile the FSM that drives the regex-guided generation.

@@ -50,15 +102,21 @@ def __init__(self, regex_string, llm):
             An instance of `vllm.LLM`

         """
-        tokenizer = self.adapt_tokenizer(llm.tokenizer)
+        cls = self.__class__
+
+        if cls.adapted_tokenizer is None:
+            cls.adapted_tokenizer = adapt_tokenizer(llm.tokenizer)
+
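+        # Reuse the FSM compiled for this regex if it has been seen before.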
+        fsm = self.fsm_cache.get(regex_string)
+        if fsm is None:
+            fsm = CachedRegexFSM(regex_string, cls.adapted_tokenizer)
+            self.fsm_cache[regex_string] = fsm

-        fsm = RegexFSM(regex_string, tokenizer)
         self.fsm = fsm
-        self.fsm_state_cache: Dict[int, FSMState] = {}

     def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
         """Use the FSM to bias the logits before sampling the next token."""
-        state = self.get_fsm_state(input_ids)
+        state = self.fsm.get_state_by_token_ids(tuple(input_ids))
         allowed_tokens = self.fsm.allowed_token_ids(state)

         mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device)
@@ -67,48 +125,6 @@ def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:

         return biased_scores

-    def get_fsm_state(self, input_ids: List[int]) -> FSMState:
-        state_key = hash(tuple(input_ids))
-
-        if not input_ids:
-            self.fsm_state_cache[state_key] = FSMState(0)
-
-        elif state_key not in self.fsm_state_cache:
-            prev_state_key = hash(tuple(input_ids[:-1]))
-            prev_state = self.fsm_state_cache[prev_state_key]
-            last_token = input_ids[-1]
-            self.fsm_state_cache[state_key] = self.fsm.next_state(
-                prev_state, last_token
-            )
-
-        return self.fsm_state_cache[state_key]
-
-    def adapt_tokenizer(self, tokenizer):
-        """Adapt vLLM's tokenizer to use to compile the FSM.
-
-        The API of Outlines tokenizers is slightly different to that of
-        `transformers`. In addition we need to handle the missing spaces to
-        Llama's tokenizer to be able to compile FSMs for this model.
-
-        """
-        tokenizer.vocabulary = tokenizer.get_vocab()
-        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
-
-        def convert_token_to_string(token: str) -> str:
-            from transformers.file_utils import SPIECE_UNDERLINE
-
-            string = tokenizer.convert_tokens_to_string([token])
-
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-            return string
-
-        tokenizer.convert_token_to_string = convert_token_to_string
-
-        return tokenizer
-

 class JSONLogitsProcessor(RegexLogitsProcessor):
     def __init__(self, schema, llm):
@@ -124,5 +140,7 @@ def __init__(self, schema, llm):
124140 """
125141 if isinstance (schema , dict ):
126142 schema = json .dumps (schema )
143+
127144 regex_string = build_regex_from_object (schema )
145+
128146 super ().__init__ (regex_string , llm )
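
A rough usage sketch for these processors (not part of the diff above): it assumes the sampler patch from this module has been applied and that the installed vLLM exposes a `logits_processors` argument on `SamplingParams`; the model name and schema are placeholders.

```python
import vllm

llm = vllm.LLM(model="mistralai/Mistral-7B-v0.1")  # placeholder model

# Hypothetical schema, used only for illustration.
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
processor = JSONLogitsProcessor(schema, llm)

# The patched sampler calls the processor with each sequence's token ids and
# its logits; the processor masks every token the FSM does not allow.
sampling_params = vllm.SamplingParams(max_tokens=100, logits_processors=[processor])
outputs = llm.generate("Describe a character as JSON:", sampling_params)
```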