
Commit 6b2035e

Author: Andrew Lapp (committed)
Message: integrate CachedRegexFSM from @viktor-ferenczi
1 parent 0443d43

File tree

1 file changed: +65 additions, -47 deletions


outlines/serve/vllm.py

Lines changed: 65 additions & 47 deletions
@@ -1,7 +1,7 @@
 """Make vLLM compatible with Outlines' guided generation."""
 import json
 import math
-from typing import Dict, List
+from typing import Dict, List, Tuple

 import torch

@@ -38,7 +38,59 @@ def _patched_apply_logits_processors(
     return logits


+def adapt_tokenizer(tokenizer):
+    """Adapt vLLM's tokenizer to use to compile the FSM.
+
+    The API of Outlines tokenizers is slightly different to that of
+    `transformers`. In addition, we need to handle the missing spaces to
+    Llama's tokenizer to be able to compile FSMs for this model.
+
+    """
+    tokenizer.vocabulary = tokenizer.get_vocab()
+    tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+    def convert_token_to_string(token: str) -> str:
+        from transformers.file_utils import SPIECE_UNDERLINE
+
+        string = tokenizer.convert_tokens_to_string([token])
+
+        # A hack to handle missing spaces to HF's Llama tokenizers
+        if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+            return " " + string
+
+        return string
+
+    tokenizer.convert_token_to_string = convert_token_to_string
+
+    return tokenizer
+
+
+class CachedRegexFSM(RegexFSM):
+    def __init__(self, regex_string: str, adapted_tokenizer):
+        super().__init__(regex_string, adapted_tokenizer)
+        self.state_cache: Dict[int, FSMState] = {}
+
+    def get_state_by_token_ids(self, input_ids: Tuple[int]) -> FSMState:
+        state_key = hash(input_ids)
+
+        if not input_ids:
+            self.state_cache[state_key] = FSMState(0)
+
+        elif state_key not in self.state_cache:
+            prev_state_key = hash(input_ids[:-1])
+            prev_state = self.state_cache[prev_state_key]
+
+            last_token = input_ids[-1]
+            new_state = self.next_state(prev_state, last_token)
+            self.state_cache[state_key] = new_state
+
+        return self.state_cache[state_key]
+
+
 class RegexLogitsProcessor:
+    fsm_cache: Dict[str, CachedRegexFSM] = {}
+    adapted_tokenizer = None
+
     def __init__(self, regex_string, llm):
         """Compile the FSM that drives the regex-guided generation.

@@ -50,15 +102,21 @@ def __init__(self, regex_string, llm):
             An instance of `vllm.LLM`

         """
-        tokenizer = self.adapt_tokenizer(llm.tokenizer)
+        cls = self.__class__
+
+        if cls.adapted_tokenizer is None:
+            cls.adapted_tokenizer = adapt_tokenizer(llm.tokenizer)
+
+        fsm = self.fsm_cache.get(regex_string)
+        if fsm is None:
+            fsm = CachedRegexFSM(regex_string, cls.adapted_tokenizer)
+            self.fsm_cache[regex_string] = fsm

-        fsm = RegexFSM(regex_string, tokenizer)
         self.fsm = fsm
-        self.fsm_state_cache: Dict[int, FSMState] = {}

     def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
         """Use the FSM to bias the logits before sampling the next token."""
-        state = self.get_fsm_state(input_ids)
+        state = self.fsm.get_state_by_token_ids(tuple(input_ids))
         allowed_tokens = self.fsm.allowed_token_ids(state)

         mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device)
@@ -67,48 +125,6 @@ def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:

         return biased_scores

-    def get_fsm_state(self, input_ids: List[int]) -> FSMState:
-        state_key = hash(tuple(input_ids))
-
-        if not input_ids:
-            self.fsm_state_cache[state_key] = FSMState(0)
-
-        elif state_key not in self.fsm_state_cache:
-            prev_state_key = hash(tuple(input_ids[:-1]))
-            prev_state = self.fsm_state_cache[prev_state_key]
-            last_token = input_ids[-1]
-            self.fsm_state_cache[state_key] = self.fsm.next_state(
-                prev_state, last_token
-            )
-
-        return self.fsm_state_cache[state_key]
-
-    def adapt_tokenizer(self, tokenizer):
-        """Adapt vLLM's tokenizer to use to compile the FSM.
-
-        The API of Outlines tokenizers is slightly different to that of
-        `transformers`. In addition we need to handle the missing spaces to
-        Llama's tokenizer to be able to compile FSMs for this model.
-
-        """
-        tokenizer.vocabulary = tokenizer.get_vocab()
-        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
-
-        def convert_token_to_string(token: str) -> str:
-            from transformers.file_utils import SPIECE_UNDERLINE
-
-            string = tokenizer.convert_tokens_to_string([token])
-
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-            return string
-
-        tokenizer.convert_token_to_string = convert_token_to_string
-
-        return tokenizer
-

 class JSONLogitsProcessor(RegexLogitsProcessor):
     def __init__(self, schema, llm):
@@ -124,5 +140,7 @@ def __init__(self, schema, llm):
         """
         if isinstance(schema, dict):
             schema = json.dumps(schema)
+
         regex_string = build_regex_from_object(schema)
+
         super().__init__(regex_string, llm)
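
The heart of the change is that the prefix-keyed state cache moves from each RegexLogitsProcessor instance onto CachedRegexFSM, and a class-level fsm_cache lets repeated requests with the same regex reuse one compiled FSM. Below is a minimal standalone sketch of that prefix-hash caching pattern; ToyCachedFSM and its arithmetic next_state are illustrative stand-ins for the real Outlines classes, not the library API.

from typing import Dict, Tuple


class ToyCachedFSM:
    """Toy stand-in for CachedRegexFSM: cache one FSM state per token-id prefix."""

    def __init__(self) -> None:
        # Maps hash(tuple of token ids) -> FSM state.
        self.state_cache: Dict[int, int] = {}

    def next_state(self, state: int, token_id: int) -> int:
        # Illustrative transition only; the real RegexFSM walks a compiled
        # regex automaton here.
        return state + token_id + 1

    def get_state_by_token_ids(self, input_ids: Tuple[int, ...]) -> int:
        state_key = hash(input_ids)

        if not input_ids:
            # The empty prefix seeds the start state, just as the first call
            # of the logits processor does at the beginning of generation.
            self.state_cache[state_key] = 0
        elif state_key not in self.state_cache:
            # The previous decoding step already cached the state for the
            # prefix without the last token, so only one transition is needed.
            prev_state = self.state_cache[hash(input_ids[:-1])]
            self.state_cache[state_key] = self.next_state(prev_state, input_ids[-1])

        return self.state_cache[state_key]


fsm = ToyCachedFSM()
for step in range(4):
    prefix = tuple(range(step))  # the generated token ids grow by one each step
    print(step, fsm.get_state_by_token_ids(prefix))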

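For orientation only, here is a hypothetical sketch of how a logits processor of this shape is attached to a vLLM request. It assumes a vLLM release whose SamplingParams accepts a logits_processors list of callables mapping (token_ids, logits) to logits; the model name is a placeholder, and none of this is taken from the commit itself.

# Hypothetical usage sketch, not part of this commit. Assumes SamplingParams
# accepts `logits_processors` and that the placeholder model below is available.
from vllm import LLM, SamplingParams

from outlines.serve.vllm import RegexLogitsProcessor

llm = LLM("mistralai/Mistral-7B-v0.1")  # placeholder model name

# ISO-date regex; building a second processor for the same pattern hits the
# class-level fsm_cache instead of recompiling the FSM.
date_processor = RegexLogitsProcessor(r"[0-9]{4}-[0-9]{2}-[0-9]{2}", llm)

params = SamplingParams(max_tokens=16, logits_processors=[date_processor])
outputs = llm.generate(["Today's date is "], params)
print(outputs[0].outputs[0].text)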