diff --git a/outlines/serve/vllm.py b/outlines/serve/vllm.py
index ee9c7000e..7bfd051ae 100644
--- a/outlines/serve/vllm.py
+++ b/outlines/serve/vllm.py
@@ -1,12 +1,11 @@
 """Make vLLM compatible with Outlines' guided generation."""
 import json
 import math
-from collections import defaultdict
-from typing import DefaultDict, List
+from typing import Dict, List, Tuple
 
 import torch
 
-from outlines.fsm.fsm import RegexFSM
+from outlines.fsm.fsm import FSMState, RegexFSM
 from outlines.fsm.json_schema import build_regex_from_object
 
 
@@ -29,7 +28,7 @@ def _patched_apply_logits_processors(
             logits_row = logits[logits_row_idx]
             token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
             for logits_processor in logits_processors:
-                logits_row = logits_processor(seq_id, token_ids, logits_row)
+                logits_row = logits_processor(token_ids, logits_row)
             logits[logits_row_idx] = logits_row
             logits_row_idx += 1
         else:
@@ -39,7 +38,58 @@ def _patched_apply_logits_processors(
     return logits
 
 
+def adapt_tokenizer(tokenizer):
+    """Adapt vLLM's tokenizer so it can be used to compile the FSM.
+
+    The API of Outlines tokenizers is slightly different from that of
+    `transformers`. In addition, we need to handle the missing spaces in
+    Llama's tokenizer to be able to compile FSMs for this model.
+
+    """
+    tokenizer.vocabulary = tokenizer.get_vocab()
+    tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+    def convert_token_to_string(token: str) -> str:
+        from transformers.file_utils import SPIECE_UNDERLINE
+
+        string = tokenizer.convert_tokens_to_string([token])
+
+        # A hack to handle the missing spaces in HF's Llama tokenizers
+        if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+            return " " + string
+
+        return string
+
+    tokenizer.convert_token_to_string = convert_token_to_string
+
+    return tokenizer
+
+
+class CachedRegexFSM(RegexFSM):
+    def __init__(self, regex_string: str, adapted_tokenizer):
+        super().__init__(regex_string, adapted_tokenizer)
+        self.state_cache: Dict[int, FSMState] = {}
+
+    def get_state_by_token_ids(self, input_ids: Tuple[int, ...]) -> FSMState:
+        state_key = hash(input_ids)
+
+        if not input_ids:
+            self.state_cache[state_key] = FSMState(0)
+
+        elif state_key not in self.state_cache:
+            prev_state_key = hash(input_ids[:-1])
+            prev_state = self.state_cache[prev_state_key]
+
+            last_token = input_ids[-1]
+            new_state = self.next_state(prev_state, last_token)
+            self.state_cache[state_key] = new_state
+
+        return self.state_cache[state_key]
+
+
 class RegexLogitsProcessor:
+    fsm_cache: Dict[str, CachedRegexFSM] = {}
+
     def __init__(self, regex_string, llm):
         """Compile the FSM that drives the regex-guided generation.
 
@@ -51,25 +101,19 @@ def __init__(self, regex_string, llm):
             An instance of `vllm.LLM`
 
         """
-        tokenizer = self.adapt_tokenizer(llm.tokenizer.tokenizer)
+        adapted_tokenizer = adapt_tokenizer(llm.tokenizer.tokenizer)
+
+        fsm = self.fsm_cache.get(regex_string)
+        if fsm is None:
+            fsm = CachedRegexFSM(regex_string, adapted_tokenizer)
+            self.fsm_cache[regex_string] = fsm
 
-        fsm = RegexFSM(regex_string, tokenizer)
         self.fsm = fsm
 
-    def __call__(
-        self, seq_id: int, input_ids: List[int], scores: torch.Tensor
-    ) -> torch.Tensor:
+    def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
         """Use the FSM to bias the logits before sampling the next token."""
-
-        if len(input_ids) == 0:  # Initialize the fsm states
-            self.fsm_state: DefaultDict[int, int] = defaultdict(int)
-        else:
-            last_token = input_ids[-1]
-            self.fsm_state[seq_id] = self.fsm.next_state(
-                self.fsm_state[seq_id], last_token
-            )
-
-        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
+        state = self.fsm.get_state_by_token_ids(tuple(input_ids))
+        allowed_tokens = self.fsm.allowed_token_ids(state)
 
         mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device)
         mask[allowed_tokens] = 0
@@ -77,32 +121,6 @@ def __call__(
 
         return biased_scores
 
-    def adapt_tokenizer(self, tokenizer):
-        """Adapt vLLM's tokenizer to use to compile the FSM.
-
-        The API of Outlines tokenizers is slightly different to that of
-        `transformers`. In addition we need to handle the missing spaces to
-        Llama's tokenizer to be able to compile FSMs for this model.
-
-        """
-        tokenizer.vocabulary = tokenizer.get_vocab()
-        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
-
-        def convert_token_to_string(token: str) -> str:
-            from transformers.file_utils import SPIECE_UNDERLINE
-
-            string = tokenizer.convert_tokens_to_string([token])
-
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-            return string
-
-        tokenizer.convert_token_to_string = convert_token_to_string
-
-        return tokenizer
-
 
 class JSONLogitsProcessor(RegexLogitsProcessor):
     def __init__(self, schema, llm):
@@ -118,5 +136,7 @@ def __init__(self, schema, llm):
         """
         if isinstance(schema, dict):
             schema = json.dumps(schema)
+
         regex_string = build_regex_from_object(schema)
+
         super().__init__(regex_string, llm)
diff --git a/tests/serve/test_vllm.py b/tests/serve/test_vllm.py
new file mode 100644
index 000000000..77280f4e5
--- /dev/null
+++ b/tests/serve/test_vllm.py
@@ -0,0 +1,118 @@
+import re
+
+import torch
+
+from outlines.serve.vllm import RegexLogitsProcessor, _patched_apply_logits_processors
+
+
+class MockTokenizer:
+    vocabulary = {
+        **{chr(i): i for i in range(256)},
+        **{"eos": 256},
+    }
+    special_tokens = {"eos"}
+    eos_token_id = 256
+
+    @property
+    def inverse_vocabulary(self):
+        return {v: k for k, v in self.vocabulary.items()}
+
+    def decode(self, token_ids):
+        return "".join([self.inverse_vocabulary[t] for t in token_ids])
+
+    ####
+    # vLLM tokenizer features
+    ####
+    all_special_tokens = list(special_tokens)
+
+    def convert_tokens_to_string(self, token):
+        return token[0]
+
+    def get_vocab(self):
+        return MockTokenizer.vocabulary
+
+
+class MockTokenizerGroup:
+    tokenizer = MockTokenizer()
+
+
+class MockModel:
+    tokenizer = MockTokenizerGroup()
+
+
+def sample_from_logits(logits):
+    probs = torch.exp(logits) / torch.sum(torch.exp(logits))
+    return torch.multinomial(probs, 1).item()
+
+
+def test_time_regexp():
+    pattern = r"(0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?"
r"(0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?" + llm = MockModel() + logits_processor = RegexLogitsProcessor(pattern, llm) + + token_ids = [] + while True: + random_scores = -10 + 20 * torch.rand(len(llm.tokenizer.vocabulary)) + logits = logits_processor( + input_ids=token_ids, + scores=random_scores, + ) + new_token_id = sample_from_logits(logits) + if new_token_id == llm.tokenizer.eos_token_id: + break + token_ids.append(new_token_id) + + assert re.fullmatch(pattern, llm.tokenizer.decode(token_ids)) is not None + + +def test_time_regexp_multiple_samples(): + num_seq = 64 + + pattern = r"(0?[1-9]|1[0-2]):[0-5]\d\ ?(am|pm)?" + llm = MockModel() + + class MockSeqData: + def __init__(self): + self.output_token_ids = [] + + class MockSamplingParams: + logits_processors = [RegexLogitsProcessor(pattern, llm)] + + class MockSamplingMeta: + seq_groups = [[range(num_seq), MockSamplingParams()]] # seq_ids + seq_data = {seq_id: MockSeqData() for seq_id in range(num_seq)} + + sampling_meta = MockSamplingMeta() + + results = [] + while True: + complete_seq_ids = set() + + logits = torch.randn(len(sampling_meta.seq_data), len(llm.tokenizer.vocabulary)) + new_logits = _patched_apply_logits_processors(logits, sampling_meta) + seq_ids = sorted(sampling_meta.seq_groups[0][0]) + for logits_row, seq_id in zip(new_logits, seq_ids): + new_token_id = sample_from_logits(logits_row) + if new_token_id == llm.tokenizer.eos_token_id: + complete_seq_ids.add(seq_id) + results.append(sampling_meta.seq_data[seq_id].output_token_ids) + else: + sampling_meta.seq_data[seq_id].output_token_ids.append(new_token_id) + + if complete_seq_ids: + seq_datas = [ + sd + for seq_id, sd in sampling_meta.seq_data.items() + if seq_id not in complete_seq_ids + ] + sampling_meta.seq_data = { + i: seq_data for i, seq_data in enumerate(seq_datas) + } + sampling_meta.seq_groups[0][0] = range(len(sampling_meta.seq_data)) + + if not sampling_meta.seq_data: + break + + assert len(results) == num_seq + for result in results: + assert re.fullmatch(pattern, llm.tokenizer.decode(result)) is not None