From 7b0a0dfb22907505441f8a4a5eb882cbca4d2acf Mon Sep 17 00:00:00 2001
From: Breno Faria
Date: Thu, 6 Jun 2024 01:49:12 +0200
Subject: [PATCH] [Frontend][Core] Update Outlines Integration from `FSM` to
 `Guide` (#4109)

Co-authored-by: Simon Mo
Co-authored-by: Breno Faria
---
 requirements-common.txt                          |  2 +-
 tests/entrypoints/test_guided_processors.py      |  2 -
 .../guided_decoding/outlines_decoding.py         | 31 ++++------
 .../outlines_logits_processors.py                | 62 +++++++++++--------
 4 files changed, 49 insertions(+), 48 deletions(-)

diff --git a/requirements-common.txt b/requirements-common.txt
index f41873570aa67..bf9987e3af014 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -17,6 +17,6 @@ prometheus_client >= 0.18.0
 prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0 # Required for DBRX tokenizer
 lm-format-enforcer == 0.10.1
-outlines == 0.0.34 # Requires torch >= 2.1.0
+outlines >= 0.0.43 # Requires torch >= 2.1.0
 typing_extensions
 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
diff --git a/tests/entrypoints/test_guided_processors.py b/tests/entrypoints/test_guided_processors.py
index 5d4163e96fd87..fb32a9d155bc0 100644
--- a/tests/entrypoints/test_guided_processors.py
+++ b/tests/entrypoints/test_guided_processors.py
@@ -63,7 +63,6 @@ def test_guided_logits_processors():
                                   tokenizer,
                                   whitespace_pattern=None)
 
-    regex_LP.init_state()
     token_ids = tokenizer.encode(
         f"Give an example IPv4 address with this regex: {TEST_REGEX}")
     tensor = torch.rand(32000)
@@ -72,7 +71,6 @@ def test_guided_logits_processors():
     assert tensor.shape == original_tensor.shape
     assert not torch.allclose(tensor, original_tensor)
 
-    json_LP.init_state()
     token_ids = tokenizer.encode(
         f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
     tensor = torch.rand(32000)
diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py
index 8403604286903..721f7e0530cb7 100644
--- a/vllm/model_executor/guided_decoding/outlines_decoding.py
+++ b/vllm/model_executor/guided_decoding/outlines_decoding.py
@@ -1,8 +1,6 @@
 import asyncio
 import concurrent.futures
-from copy import copy
 from enum import Enum
-from functools import lru_cache
 from json import dumps as json_dumps
 from re import escape as regex_escape
 from typing import Tuple, Union
@@ -54,8 +52,10 @@ class GuidedDecodingMode(Enum):
 
 
 async def get_outlines_guided_decoding_logits_processor(
-    request: Union[CompletionRequest, ChatCompletionRequest],
-    tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
+    request: Union[CompletionRequest,
+                   ChatCompletionRequest], tokenizer: PreTrainedTokenizerBase
+) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
+           None]:
     """
     Given an OpenAI-compatible request, check for guided decoding parameters
     and get the necessary logits processor for the given guide.
@@ -64,7 +64,7 @@ async def get_outlines_guided_decoding_logits_processor(
     """
     global global_thread_pool
    guide, mode = _get_guide_and_mode(request)
-    if not guide:
+    if not guide or not mode:
        return None
 
     if global_thread_pool is None:
@@ -72,15 +72,9 @@
            max_workers=2)
     loop = asyncio.get_running_loop()
 
-    result = await loop.run_in_executor(global_thread_pool,
-                                        _get_cached_logits_processor, guide,
-                                        tokenizer, mode,
-                                        request.guided_whitespace_pattern)
-
-    logits_processor = copy(result)
-    # reset logits processor's internal state
-    logits_processor.init_state()
-    return logits_processor
+    return await loop.run_in_executor(global_thread_pool,
+                                      _get_logits_processor, guide, tokenizer,
+                                      mode, request.guided_whitespace_pattern)
 
 
 def _get_guide_and_mode(
@@ -115,11 +109,10 @@ def _get_guide_and_mode(
     return None, None
 
 
-@lru_cache(maxsize=32)
-def _get_cached_logits_processor(guide: str,
-                                 tokenizer: PreTrainedTokenizerBase,
-                                 mode: GuidedDecodingMode,
-                                 whitespace_pattern: Union[str, None]):
+def _get_logits_processor(
+    guide: str, tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode,
+    whitespace_pattern: Union[str, None]
+) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
     if mode == GuidedDecodingMode.JSON:
         return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
     elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
index a131c6a1b92b4..1618705ff2983 100644
--- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py
+++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
@@ -21,7 +21,7 @@
 from typing import Callable, DefaultDict, Dict, List, Union
 
 import torch
-from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
+from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write
 from outlines.fsm.json_schema import build_regex_from_schema
 from pydantic import BaseModel
 from transformers import PreTrainedTokenizerBase
@@ -29,28 +29,32 @@
 
 class BaseLogitsProcessor:
 
-    def __init__(self):
-        # Child class should use initialize in their init.
-        self.fsm: FSM
-
-    def init_state(self):
-        """Initialize the FSM states."""
-        self.fsm_state: DefaultDict[int, int] = defaultdict(int)
+    def __init__(self, guide: Guide):
+        self._guide: Guide = guide
+        self._fsm_state: DefaultDict[int, int] = defaultdict(int)
 
     def __call__(self, input_ids: List[int],
                  scores: torch.Tensor) -> torch.Tensor:
         """Use the FSM to bias the logits before sampling the next token."""
         seq_id = hash(tuple(input_ids))
 
-        if len(input_ids) == 0:
-            self.init_state()
-        else:
+        if len(input_ids) > 0:
             last_token = input_ids[-1]
             last_seq_id = hash(tuple(input_ids[:-1]))
-            self.fsm_state[seq_id] = self.fsm.next_state(
-                self.fsm_state[last_seq_id], last_token)
+            self._fsm_state[seq_id] = self._guide.get_next_state(
+                state=self._fsm_state[last_seq_id], token_id=last_token)
+
+        instruction = self._guide.get_next_instruction(
+            state=self._fsm_state[seq_id])
 
-        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
+        if type(instruction) == Generate:
+            allowed_tokens = instruction.tokens
+        elif type(instruction) == Write:
+            # TODO: support fast forward tokens
+            allowed_tokens = [instruction.tokens[0]]
+        else:
+            raise TypeError(
+                f"Unsupported instruction type {type(instruction)}")
 
         mask = torch.full((scores.shape[-1], ),
                           -math.inf,
@@ -62,6 +66,13 @@ def __call__(self, input_ids: List[int],
 
 class RegexLogitsProcessor(BaseLogitsProcessor):
 
+    @classmethod
+    @lru_cache(maxsize=32)
+    def _get_guide(cls, regex_string: str,
+                   tokenizer: PreTrainedTokenizerBase) -> Guide:
+        tokenizer = _adapt_tokenizer(tokenizer)
+        return RegexGuide(regex_string, tokenizer)
+
     def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
         """Compile the FSM that drives the regex-structured generation.
 
@@ -73,9 +84,8 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
             The model's tokenizer
 
         """
-        tokenizer = _adapt_tokenizer(tokenizer)
-        fsm = RegexFSM(regex_string, tokenizer)
-        self.fsm = fsm
+        super().__init__(
+            RegexLogitsProcessor._get_guide(regex_string, tokenizer))
 
 
 class JSONLogitsProcessor(RegexLogitsProcessor):
@@ -115,6 +125,12 @@ def __init__(self, schema: Union[str, Dict, BaseModel],
 
 class CFGLogitsProcessor(BaseLogitsProcessor):
 
+    @classmethod
+    @lru_cache(maxsize=32)
+    def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide:
+        tokenizer = _adapt_tokenizer(tokenizer)
+        return CFGGuide(cfg, tokenizer)
+
     def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
         """Compile the FSM that drives the context free grammar generation.
 
@@ -126,17 +142,11 @@ def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
             The model's tokenizer
 
        """
-        tokenizer = _adapt_tokenizer(tokenizer)
-        fsm = CFGFSM(cfg, tokenizer)
-        self.fsm = fsm
-
-    def init_state(self):
-        """Initialize state with a CFGFSM copy."""
-        super().init_state()
-        self.fsm = self.fsm.copy()
+        super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer))
+        self._guide = self._guide.copy()
 
 
-@lru_cache
+@lru_cache(maxsize=32)
 def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
     """Adapt vLLM's tokenizer to use to compile the FSM.
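
Usage note (not part of the patch): after this change a processor is constructed per request and is immediately callable; the old init_state() reset step is gone. Per-sequence FSM state now lives in the instance's _fsm_state dict, keyed by a hash of the token-id prefix, while the expensive Guide compilation is shared across instances through the lru_cache on each class's _get_guide. The following is a minimal sketch of driving RegexLogitsProcessor directly, mirroring the updated test; the tokenizer name and regex are illustrative assumptions, not taken from the patch:

    import torch
    from transformers import AutoTokenizer

    from vllm.model_executor.guided_decoding.outlines_logits_processors import (
        RegexLogitsProcessor)

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed: any HF tokenizer
    ip_regex = r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}"   # illustrative pattern

    # No init_state() call: the compiled RegexGuide comes from the class-level
    # cache, and per-sequence state starts empty inside this instance.
    processor = RegexLogitsProcessor(ip_regex, tokenizer)

    token_ids = tokenizer.encode("Give an example IPv4 address: ")
    scores = torch.rand(len(tokenizer))
    scores = processor(token_ids, scores)  # tokens the guide disallows go to -inf

Repeated calls with growing token_ids advance the guide one token at a time via get_next_state, so the same instance can score an entire decode loop without any explicit reset.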