Add CFG-guided generation to the vLLM integration #541
Changes from all commits.

Changes to the vLLM integration module (imported in the new test below as `outlines.serve.vllm`):
@@ -2,11 +2,11 @@
 import json
 import math
 from collections import defaultdict
-from typing import DefaultDict, List
+from typing import Callable, DefaultDict, List

 import torch

-from outlines.fsm.fsm import RegexFSM
+from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
 from outlines.fsm.json_schema import build_regex_from_object
@@ -39,21 +39,54 @@ def _patched_apply_logits_processors(
     return logits


-class RegexLogitsProcessor:
-    def __init__(self, regex_string, llm):
-        """Compile the FSM that drives the regex-guided generation.
-
-        Parameters
-        ----------
-        regex_string
-            A string that represents a regular expression
-        llm
-            An instance of `vllm.LLM`
-
-        """
-        tokenizer = self.adapt_tokenizer(llm.tokenizer.tokenizer)
-
-        fsm = RegexFSM(regex_string, tokenizer)
-        self.fsm = fsm
+def _adapt_tokenizer(tokenizer):
+    """Adapt vLLM's tokenizer so it can be used to compile the FSM.
+
+    The API of Outlines tokenizers is slightly different to that of
+    `transformers`. Outlines' decoder returns a list, whereas vLLM's
+    `decode` returns a str, so the decoder must be adapted to match
+    Outlines' internal API. In addition, we need to handle the missing
+    spaces in Llama's tokenizer to be able to compile FSMs for this model.
+
+    """
+    if getattr(tokenizer, "_outlines_adapted", False):
+        return tokenizer
+
+    tokenizer.vocabulary = tokenizer.get_vocab()
+    tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+    def convert_token_to_string(token: str) -> str:
+        from transformers.file_utils import SPIECE_UNDERLINE
+
+        string = tokenizer.convert_tokens_to_string([token])
+
+        # A hack to handle missing spaces in HF's Llama tokenizers
+        if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+            return " " + string
+
+        return string
+
+    def change_decoder(
+        decoder: Callable[[List[int]], str]
+    ) -> Callable[[List[int]], List[str]]:
+        """Sync vLLM's decoder with Outlines' expectations by returning a list."""
+
+        def new_decoder(inp_tokens: List[int]) -> List[str]:
+            return [decoder(inp_tokens)]
+
+        return new_decoder
+
+    tokenizer.convert_token_to_string = convert_token_to_string
+    tokenizer.decode = change_decoder(tokenizer.decode)
+    setattr(tokenizer, "_outlines_adapted", True)
+
+    return tokenizer
+
+
+class FSMLogitsProcessor:
+    def __init__(self):
+        fsm = FSM()
+        self.fsm = fsm

     def __call__(
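For reviewers, a minimal sketch of what the new `_adapt_tokenizer` helper does to a plain `transformers` tokenizer. This is not part of the diff; the `gpt2` checkpoint is an illustrative choice, and the import path is taken from the new test's imports:

    from transformers import AutoTokenizer

    from outlines.serve.vllm import _adapt_tokenizer  # module path assumed from the new test

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative tokenizer choice
    print(type(tokenizer.decode([0, 1, 2])))  # <class 'str'>: what vLLM's tokenizer returns

    adapted = _adapt_tokenizer(tokenizer)
    print(type(adapted.decode([0, 1, 2])))  # <class 'list'>: what Outlines' FSMs expect
    print(len(adapted.vocabulary), len(adapted.special_tokens))  # vocabulary and special tokens now attached

    # The `_outlines_adapted` flag makes a second call a no-op.
    assert _adapt_tokenizer(adapted) is adapted

Note that the adaptation is in place: the tokenizer passed in and the one returned are the same object.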
@@ -77,31 +110,39 @@ def __call__(

         return biased_scores

-    def adapt_tokenizer(self, tokenizer):
-        """Adapt vLLM's tokenizer to use to compile the FSM.
-
-        The API of Outlines tokenizers is slightly different to that of
-        `transformers`. In addition we need to handle the missing spaces to
-        Llama's tokenizer to be able to compile FSMs for this model.
-
-        """
-        tokenizer.vocabulary = tokenizer.get_vocab()
-        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
-
-        def convert_token_to_string(token: str) -> str:
-            from transformers.file_utils import SPIECE_UNDERLINE
-
-            string = tokenizer.convert_tokens_to_string([token])
+class RegexLogitsProcessor(FSMLogitsProcessor):
+    def __init__(self, regex_string, llm):
+        """Compile the FSM that drives the regex-guided generation.
+
+        Parameters
+        ----------
+        regex_string
+            A string that represents a regular expression
+        llm
+            An instance of `vllm.LLMEngine`
+
+        """
+        adapted_tokenizer = _adapt_tokenizer(llm.tokenizer)
Review comment on the line above: I think this should be
+
+        fsm = RegexFSM(regex_string, adapted_tokenizer)
+        self.fsm = fsm
-
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-            return string
-
-        tokenizer.convert_token_to_string = convert_token_to_string
-
-        return tokenizer
+
+
+class CFGLogitsProcessor(FSMLogitsProcessor):
+    def __init__(self, cfg_string, llm):
+        """Compile the FSM that drives the CFG-guided generation.
+
+        Parameters
+        ----------
+        cfg_string
+            A string that represents a context-free grammar
+        llm
+            An instance of `vllm.LLMEngine`
+
+        """
+        adapted_tokenizer = _adapt_tokenizer(llm.tokenizer)
Review comment on the line above: same as above,
+
+        fsm = CFGFSM(cfg_string, adapted_tokenizer)
+        self.fsm = fsm


 class JSONLogitsProcessor(RegexLogitsProcessor):

@@ -113,7 +154,7 @@ def __init__(self, schema, llm):
         schema
             A JSON schema that encodes the structure we want the model to generate
         llm
-            An instance of `vllm.LLM`
+            An instance of `vllm.LLMEngine`

         """
         if isinstance(schema, dict):
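Taken together, `FSMLogitsProcessor` now owns the compiled FSM, and the subclasses only differ in how that FSM is built: a `RegexFSM` from a regex, a `CFGFSM` from a grammar, and (via `RegexLogitsProcessor`) a `RegexFSM` derived from a JSON schema through `build_regex_from_object`. A rough construction sketch, mirroring the new test below; the engine stand-in, the example regex, and the tiny test model are illustrative assumptions, not vLLM API:

    from transformers import AutoTokenizer

    from outlines.serve.vllm import (
        CFGLogitsProcessor,
        JSONLogitsProcessor,
        RegexLogitsProcessor,
    )

    class EngineStub:
        """Hypothetical stand-in for vllm.LLMEngine: only the .tokenizer attribute is used here."""

        def __init__(self, tokenizer):
            self.tokenizer = tokenizer

    engine = EngineStub(
        AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-GPTJForCausalLM")
    )

    regex_proc = RegexLogitsProcessor(r"[0-9]+", engine)  # illustrative pattern: digits only
    cfg_proc = CFGLogitsProcessor(
        'start: DECIMAL\nDIGIT: "0".."9"\nINT: DIGIT+\nDECIMAL: INT "." INT? | "." INT',
        engine,
    )  # the decimal grammar used in the new test
    json_proc = JSONLogitsProcessor('{"type": "string", "maxLength": 5}', engine)  # the schema from the new test

    # Regex- and JSON-guided generation both end up with a RegexFSM; CFG-guided generation gets a CFGFSM.
    print(type(regex_proc.fsm), type(cfg_proc.fsm), type(json_proc.fsm))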
New test file added by the PR:

@@ -0,0 +1,43 @@
+import pytest
+import torch
+from transformers import AutoTokenizer
+
+from outlines.serve.vllm import (
+    CFGLogitsProcessor,
+    JSONLogitsProcessor,
+    RegexLogitsProcessor,
+)
+
+TEST_REGEX = r"(-)?(0|[1-9][0-9]*)(.[0-9]+)?([eE][+-][0-9]+)?"
+TEST_CFG = """
+start: DECIMAL
+DIGIT: "0".."9"
+INT: DIGIT+
+DECIMAL: INT "." INT? | "." INT
+"""
+TEST_SCHEMA = '{"type": "string", "maxLength": 5}'
+
+LOGIT_PROCESSORS = (
+    (CFGLogitsProcessor, TEST_CFG),
+    (RegexLogitsProcessor, TEST_REGEX),
+    (JSONLogitsProcessor, TEST_SCHEMA),
+)
+
+TEST_MODEL = "hf-internal-testing/tiny-random-GPTJForCausalLM"
+
+
+@pytest.mark.parametrize("logit_processor, fsm_str", LOGIT_PROCESSORS)
+def test_logit_processor(logit_processor, fsm_str: str):
Review comment: What is this test doing?
+    class MockvLLMEngine:
+        def __init__(self, tokenizer):
+            self.tokenizer = tokenizer
+
+        def __call__(*_):
+            return torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.float), None
+
+    tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL)
+    engine = MockvLLMEngine(tokenizer)
+    logit_processor(fsm_str, engine)
+    assert isinstance(engine.tokenizer.decode([0, 1, 2, 3]), list)
+    logit_processor(fsm_str, engine)
+    assert isinstance(engine.tokenizer.decode([0, 1, 2, 3]), list)
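In short, the test builds each processor against a mock engine and checks that the engine's tokenizer is adapted in place (its `decode` now returns a list), and that constructing a second processor on the already-adapted tokenizer still works. Assuming a standard pytest setup for the repository, the test can be run in isolation with something like:

    pytest -k test_logit_processor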