
Commit 8b1ff9a

Andrew Lapp committed: use tuples as fsm state key
1 parent 0355ab4 commit 8b1ff9a

2 files changed: +98 -7 lines changed


outlines/serve/vllm.py

Lines changed: 8 additions & 7 deletions
@@ -29,7 +29,7 @@ def _patched_apply_logits_processors(
                 logits_row = logits[logits_row_idx]
                 token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
                 for logits_processor in logits_processors:
-                    logits_row = logits_processor(seq_id, token_ids, logits_row)
+                    logits_row = logits_processor(token_ids, logits_row)
                 logits[logits_row_idx] = logits_row
                 logits_row_idx += 1
         else:
@@ -56,20 +56,21 @@ def __init__(self, regex_string, llm):
         fsm = RegexFSM(regex_string, tokenizer)
         self.fsm = fsm

-    def __call__(
-        self, seq_id: int, input_ids: List[int], scores: torch.Tensor
-    ) -> torch.Tensor:
+    def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
         """Use the FSM to bias the logits before sampling the next token."""

+        state_id = hash(tuple(input_ids))
+
         if len(input_ids) == 0:  # Initialize the fsm states
             self.fsm_state: DefaultDict[int, int] = defaultdict(int)
         else:
+            prev_state_id = hash(tuple(input_ids[:-1]))
             last_token = input_ids[-1]
-            self.fsm_state[seq_id] = self.fsm.next_state(
-                self.fsm_state[seq_id], last_token
+            self.fsm_state[state_id] = self.fsm.next_state(
+                self.fsm_state[prev_state_id], last_token
             )

-        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
+        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[state_id])

         mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device)
         mask[allowed_tokens] = 0
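
The change above keys FSM state by a hash of the generated token prefix instead of a seq_id, so the processor no longer needs vLLM to tell it which sequence a logits row belongs to. Below is a minimal sketch of that keying scheme, using a hypothetical stand-in next_state function in place of outlines' RegexFSM; it is illustration only, not the library's code.

from collections import defaultdict
from typing import DefaultDict, List


def next_state(state: int, token_id: int) -> int:
    # Stand-in for RegexFSM.next_state: any deterministic transition works here.
    return state * 31 + token_id


# Unseen prefixes map to state 0, the FSM's initial state.
fsm_state: DefaultDict[int, int] = defaultdict(int)


def step(input_ids: List[int]) -> int:
    """Advance the FSM for one sequence, keyed by its token prefix."""
    state_key = hash(tuple(input_ids))
    if input_ids:
        # Look up the state stored for the prefix without the last token,
        # then transition on that last token.
        prev_key = hash(tuple(input_ids[:-1]))
        fsm_state[state_key] = next_state(fsm_state[prev_key], input_ids[-1])
    return fsm_state[state_key]


# Sequences that share a token prefix reuse the same stored state,
# and no seq_id is needed to tell them apart.
assert step([]) == 0
assert step([5]) == step([5])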
Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+import re
+
+import torch
+from transformers import AutoTokenizer
+
+from outlines.serve.vllm import RegexLogitsProcessor, _patched_apply_logits_processors
+
+
+class MockModel:
+    tokenizer = AutoTokenizer.from_pretrained("gpt2")
+
+
+def sample_from_logits(logits):
+    probs = torch.exp(logits) / torch.sum(torch.exp(logits))
+    return torch.multinomial(probs, 1).item()
+
+
+def test_time_regexp():
+    pattern = r"(0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?"
+    llm = MockModel()
+    logits_processor = RegexLogitsProcessor(pattern, llm)
+
+    token_ids = []
+    while True:
+        random_scores = -10 + 20 * torch.rand(len(llm.tokenizer.vocab))
+        logits = logits_processor(
+            input_ids=token_ids,
+            scores=random_scores,
+        )
+        new_token_id = sample_from_logits(logits)
+        if new_token_id == llm.tokenizer.eos_token_id:
+            break
+        token_ids.append(new_token_id)
+
+    assert re.fullmatch(pattern, llm.tokenizer.decode(token_ids)) is not None
+
+
+def test_time_regexp_multiple_samples():
+    num_seq = 64
+
+    pattern = r"(0?[1-9]|1[0-2]):[0-5]\d\ ?(am|pm)?"
+    llm = MockModel()
+
+    class MockSeqData:
+        def __init__(self):
+            self.output_token_ids = []
+
+    class MockSamplingParams:
+        logits_processors = [RegexLogitsProcessor(pattern, llm)]
+
+    class MockSamplingMeta:
+        seq_groups = [[range(num_seq), MockSamplingParams()]]  # seq_ids
+        seq_data = {seq_id: MockSeqData() for seq_id in range(num_seq)}
+
+    sampling_meta = MockSamplingMeta()
+
+    results = []
+    while True:
+        complete_seq_ids = set()
+
+        logits = torch.randn(len(sampling_meta.seq_data), len(llm.tokenizer.vocab))
+        new_logits = _patched_apply_logits_processors(logits, sampling_meta)
+        seq_ids = sorted(sampling_meta.seq_groups[0][0])
+        for logits_row, seq_id in zip(new_logits, seq_ids):
+            new_token_id = sample_from_logits(logits_row)
+            if new_token_id == llm.tokenizer.eos_token_id:
+                complete_seq_ids.add(seq_id)
+                results.append(sampling_meta.seq_data[seq_id].output_token_ids)
+            else:
+                sampling_meta.seq_data[seq_id].output_token_ids.append(new_token_id)
+
+        if complete_seq_ids:
+            seq_datas = [
+                sd
+                for seq_id, sd in sampling_meta.seq_data.items()
+                if seq_id not in complete_seq_ids
+            ]
+            sampling_meta.seq_data = {
+                i: seq_data for i, seq_data in enumerate(seq_datas)
+            }
+            sampling_meta.seq_groups[0][0] = range(len(sampling_meta.seq_data))
+
+        if not sampling_meta.seq_data:
+            break
+
+    assert len(results) == num_seq
+    for result in results:
+        print(llm.tokenizer.decode(result))
+        assert re.fullmatch(pattern, llm.tokenizer.decode(result)) is not None
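
For context, the new __call__ signature takes only the generated token ids and the logits row, which is the shape of callable vLLM accepts in SamplingParams' logits_processors list. The rough sketch below shows how the processor might be wired into vLLM end to end; it reuses the MockModel pattern from the tests to give RegexLogitsProcessor a tokenizer handle, the "gpt2" checkpoint is only a placeholder, and the logits_processors parameter is assumed to exist in the vLLM version in use.

import vllm
from transformers import AutoTokenizer

from outlines.serve.vllm import RegexLogitsProcessor


class TokenizerHandle:
    # Mirrors MockModel in the tests: RegexLogitsProcessor only needs a
    # `.tokenizer` attribute to build its FSM.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")


# Constrain decoding to the same time-of-day pattern the tests use.
processor = RegexLogitsProcessor(r"(0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?", TokenizerHandle())

# Assumption: this vLLM version exposes `logits_processors` on SamplingParams,
# calling each processor as processor(token_ids, logits) per sequence.
llm = vllm.LLM(model="gpt2")  # placeholder checkpoint
params = vllm.SamplingParams(max_tokens=16, logits_processors=[processor])

outputs = llm.generate(["The meeting starts at "], params)
print(outputs[0].outputs[0].text)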
