From 2870e1e2c31b4f5c67b5486cbbe8c90769811329 Mon Sep 17 00:00:00 2001 From: Morteza Date: Mon, 15 Jan 2024 14:01:20 -0700 Subject: [PATCH 1/4] Expect vllm.LLMEngine as processor's argument --- docs/reference/vllm.md | 4 ++ examples/vllm_integration.py | 2 +- outlines/serve/serve.py | 5 ++ outlines/serve/vllm.py | 110 ++++++++++++++++++++++++----------- 4 files changed, 86 insertions(+), 35 deletions(-) diff --git a/docs/reference/vllm.md b/docs/reference/vllm.md index 063699a6d..9fee80292 100644 --- a/docs/reference/vllm.md +++ b/docs/reference/vllm.md @@ -23,7 +23,11 @@ You can then query the model in shell by passing a prompt and either 1. a [JSON Schema][jsonschema]{:target="_blank"} specification or 2. a [Regex][regex]{:target="_blank"} pattern +<<<<<<< HEAD with the `schema` or `regex` parameters, respectively, to the `/generate` endpoint. If both are specified, the schema will be used. If neither is specified, the generated text will be unconstrained. +======= +with the `schema`, `regex` or `cfg` parameters, respectively, to the `/generate` endpoint. If both are specified, the schema will be used. If neither is specified, the generated text will be unconstrained. 
+>>>>>>> 43ff5c5 (Expect vllm.LLMEngine as processor's argument) For example, to generate a string that matches the schema `{"type": "string"}` (any string): diff --git a/examples/vllm_integration.py b/examples/vllm_integration.py index c2d38883a..bd7d0afb2 100644 --- a/examples/vllm_integration.py +++ b/examples/vllm_integration.py @@ -14,7 +14,7 @@ class User(BaseModel): llm = vllm.LLM(model="gpt2") -logits_processor = JSONLogitsProcessor(User, llm) +logits_processor = JSONLogitsProcessor(User, llm.llm_engine) result = llm.generate( ["A prompt", "Another prompt"], sampling_params=vllm.SamplingParams( diff --git a/outlines/serve/serve.py b/outlines/serve/serve.py index a669c5b50..8ccbb4250 100644 --- a/outlines/serve/serve.py +++ b/outlines/serve/serve.py @@ -27,6 +27,7 @@ from .vllm import ( JSONLogitsProcessor, RegexLogitsProcessor, + CFGLogitsProcessor, _patched_apply_logits_processors, ) @@ -65,10 +66,14 @@ async def generate(request: Request) -> Response: json_schema = request_dict.pop("schema", None) regex_string = request_dict.pop("regex", None) + grammar_string = request_dict.pop("grammar", None) + if json_schema is not None: logits_processors = [JSONLogitsProcessor(json_schema, engine.engine)] elif regex_string is not None: logits_processors = [RegexLogitsProcessor(regex_string, engine.engine)] + elif regex_string is not None: + logits_processors = [CFGLogitsProcessor(grammar_string, engine.engine)] else: logits_processors = [] diff --git a/outlines/serve/vllm.py b/outlines/serve/vllm.py index ee9c7000e..d4f339735 100644 --- a/outlines/serve/vllm.py +++ b/outlines/serve/vllm.py @@ -2,11 +2,12 @@ import json import math from collections import defaultdict -from typing import DefaultDict, List +from typing import DefaultDict, List, Callable import torch +from vllm import LLMEngine -from outlines.fsm.fsm import RegexFSM +from outlines.fsm.fsm import RegexFSM, CFGFSM, FSM from outlines.fsm.json_schema import build_regex_from_object @@ -39,21 +40,54 @@ def 
_patched_apply_logits_processors( return logits -class RegexLogitsProcessor: - def __init__(self, regex_string, llm): - """Compile the FSM that drives the regex-guided generation. +def _adapt_tokenizer(tokenizer): + """Adapt vLLM's tokenizer to use to compile the FSM. - Parameters - ---------- - regex_string - A string that represents a regular expression - llm - An instance of `vllm.LLM` + The API of Outlines tokenizers is slightly different to that of + `transformers`. The decoder of outlines, returns a list whereas + the decode of vLLM returns an str. To sync the vLLM decoder with + outlines internal api, the decoder should be adapted. In addition + we need to handle the missing spaces to Llama's tokenizer to be + able to compile FSMs for this model. - """ - tokenizer = self.adapt_tokenizer(llm.tokenizer.tokenizer) + """ + if getattr(tokenizer, "_outlines_adapted", False): + return tokenizer + + tokenizer.vocabulary = tokenizer.get_vocab() + tokenizer.special_tokens = set(tokenizer.all_special_tokens) + + def convert_token_to_string(token: str) -> str: + from transformers.file_utils import SPIECE_UNDERLINE + + string = tokenizer.convert_tokens_to_string([token]) + + # A hack to handle missing spaces to HF's Llama tokenizers + if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": + return " " + string + + return string + + def change_decoder( + decoder: Callable[[List[int]], str] + ) -> Callable[[List[int]], List[str]]: + """Sync vLLM's decoder with the outlines expectations by returning list""" + + def new_decoder(inp_tokens: List[int]) -> List[str]: + return [decoder(inp_tokens)] + + return new_decoder + + tokenizer.convert_token_to_string = convert_token_to_string + tokenizer.decode = change_decoder(tokenizer.decode) + setattr(tokenizer, "_outlines_adapted", True) + + return tokenizer - fsm = RegexFSM(regex_string, tokenizer) + +class FSMLogitsProcessor: + def __init__(self): + fsm = FSM() self.fsm = fsm def __call__( @@ -77,35 +111,43 @@ def __call__( 
return biased_scores - def adapt_tokenizer(self, tokenizer): - """Adapt vLLM's tokenizer to use to compile the FSM. - - The API of Outlines tokenizers is slightly different to that of - `transformers`. In addition we need to handle the missing spaces to - Llama's tokenizer to be able to compile FSMs for this model. - """ - tokenizer.vocabulary = tokenizer.get_vocab() - tokenizer.special_tokens = set(tokenizer.all_special_tokens) +class RegexLogitsProcessor(FSMLogitsProcessor): + def __init__(self, regex_string, llm: LLMEngine): + """Compile the FSM that drives the regex-guided generation. - def convert_token_to_string(token: str) -> str: - from transformers.file_utils import SPIECE_UNDERLINE + Parameters + ---------- + regex_string + A string that represents a regular expression + llm + An instance of `vllm.LLMEngine` - string = tokenizer.convert_tokens_to_string([token]) + """ + adapted_tokenizer = _adapt_tokenizer(llm.tokenizer) + fsm = RegexFSM(regex_string, adapted_tokenizer) + self.fsm = fsm - # A hack to handle missing spaces to HF's Llama tokenizers - if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": - return " " + string - return string +class CFGLogitsProcessor(FSMLogitsProcessor): + def __init__(self, cfg_string, llm: LLMEngine): + """Compile the FSM that drives the cfg-guided generation. - tokenizer.convert_token_to_string = convert_token_to_string + Parameters + ---------- + cfg_string + A string that represents a context-free grammar + llm + An instance of `vllm.LLMEngine` - return tokenizer + """ + adapted_tokenizer = _adapt_tokenizer(llm.tokenizer) + fsm = CFGFSM(cfg_string, adapted_tokenizer) + self.fsm = fsm class JSONLogitsProcessor(RegexLogitsProcessor): - def __init__(self, schema, llm: LLMEngine): + """Compile the FSM that drives the JSON-guided generation. 
Parameters @@ -113,7 +155,7 @@ def __init__(self, schema, llm): schema A JSON schema that encodes the structure we want the model to generate llm - An instance of `vllm.LLM` + An instance of `vllm.LLMEngine` """ if isinstance(schema, dict): From 7e5ec9e5bbb0f90a64eef422fcb49c0a3cfa8248 Mon Sep 17 00:00:00 2001 From: Morteza Date: Thu, 18 Jan 2024 15:02:48 -0700 Subject: [PATCH 2/4] Change cfg key in vllm serving --- docs/reference/vllm.md | 16 +++++++++++----- examples/vllm_integration.py | 8 ++++++-- outlines/serve/serve.py | 7 ++++--- outlines/serve/vllm.py | 4 ++-- 4 files changed, 23 insertions(+), 12 deletions(-) diff --git a/docs/reference/vllm.md b/docs/reference/vllm.md index 9fee80292..25bc3c894 100644 --- a/docs/reference/vllm.md +++ b/docs/reference/vllm.md @@ -23,11 +23,7 @@ You can then query the model in shell by passing a prompt and either 1. a [JSON Schema][jsonschema]{:target="_blank"} specification or 2. a [Regex][regex]{:target="_blank"} pattern -<<<<<<< HEAD -with the `schema` or `regex` parameters, respectively, to the `/generate` endpoint. If both are specified, the schema will be used. If neither is specified, the generated text will be unconstrained. -======= -with the `schema`, `regex` or `cfg` parameters, respectively, to the `/generate` endpoint. If both are specified, the schema will be used. If neither is specified, the generated text will be unconstrained. ->>>>>>> 43ff5c5 (Expect vllm.LLMEngine as processor's argument) +with the `schema`, `regex` or `grammar` parameters, respectively, to the `/generate` endpoint. If both are specified, the schema will be used. If neither is specified, the generated text will be unconstrained. For example, to generate a string that matches the schema `{"type": "string"}` (any string): @@ -49,6 +45,16 @@ curl http://127.0.0.1:8000/generate \ }' ``` +To generate a string that matches the grammar ``: + +```bash +curl http://127.0.0.1:8000/generate \ + -d '{ + "prompt": "What is Pi? 
Give me the first 15 digits: ", + "grammar": "start: DECIMAL \r\nDIGIT: \"0\"..\"9\"\r\nINT: DIGIT+\r\nDECIMAL: INT \".\" INT? | \".\" INT" + }' +``` + Instead of `curl`, you can also use the [requests][requests]{:target="_blank"} library from another python program. Please consult the [vLLM documentation][vllm]{:target="_blank"} for details on additional request parameters. You can also [read the code](https://github.com/outlines-dev/outlines/blob/main/outlines/serve/serve.py) in case you need to customize the solution to your needs. diff --git a/examples/vllm_integration.py b/examples/vllm_integration.py index bd7d0afb2..e57775633 100644 --- a/examples/vllm_integration.py +++ b/examples/vllm_integration.py @@ -15,10 +15,14 @@ class User(BaseModel): llm = vllm.LLM(model="gpt2") logits_processor = JSONLogitsProcessor(User, llm.llm_engine) -result = llm.generate( +outputs = llm.generate( ["A prompt", "Another prompt"], sampling_params=vllm.SamplingParams( max_tokens=100, logits_processors=[logits_processor] ), ) -print(result) + +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/outlines/serve/serve.py b/outlines/serve/serve.py index 8ccbb4250..919be0eab 100644 --- a/outlines/serve/serve.py +++ b/outlines/serve/serve.py @@ -25,6 +25,7 @@ from vllm.utils import random_uuid from .vllm import ( + CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, @@ -66,14 +67,14 @@ async def generate(request: Request) -> Response: json_schema = request_dict.pop("schema", None) regex_string = request_dict.pop("regex", None) - grammar_string = request_dict.pop("grammar", None) + cfg_string = request_dict.pop("grammar", None) if json_schema is not None: logits_processors = [JSONLogitsProcessor(json_schema, engine.engine)] elif regex_string is not None: logits_processors = [RegexLogitsProcessor(regex_string, engine.engine)] - elif 
regex_string is not None: - logits_processors = [CFGLogitsProcessor(grammar_string, engine.engine)] + elif cfg_string is not None: + logits_processors = [CFGLogitsProcessor(cfg_string, engine.engine)] else: logits_processors = [] diff --git a/outlines/serve/vllm.py b/outlines/serve/vllm.py index d4f339735..ded66223d 100644 --- a/outlines/serve/vllm.py +++ b/outlines/serve/vllm.py @@ -2,12 +2,12 @@ import json import math from collections import defaultdict -from typing import DefaultDict, List, Callable +from typing import Callable, DefaultDict, List import torch from vllm import LLMEngine -from outlines.fsm.fsm import RegexFSM, CFGFSM, FSM +from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM from outlines.fsm.json_schema import build_regex_from_object From 9b092e5909708c7785ac8d6ba709978af99ec665 Mon Sep 17 00:00:00 2001 From: Morteza Date: Thu, 25 Jan 2024 03:40:57 -0700 Subject: [PATCH 3/4] Save the vllm tokenizer adapted state --- docs/reference/vllm.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/reference/vllm.md b/docs/reference/vllm.md index 25bc3c894..44c3a1541 100644 --- a/docs/reference/vllm.md +++ b/docs/reference/vllm.md @@ -45,7 +45,7 @@ curl http://127.0.0.1:8000/generate \ }' ``` -To generate a string that matches the grammar ``: +To generate a string that matches a given grammar ``: ```bash curl http://127.0.0.1:8000/generate \ From 609fc5d003bb11e7936539866dee76090ff5ea3f Mon Sep 17 00:00:00 2001 From: Morteza Date: Thu, 25 Jan 2024 20:42:21 -0700 Subject: [PATCH 4/4] Add test for logit processor --- outlines/serve/serve.py | 1 - outlines/serve/vllm.py | 7 +++---- tests/test_vllm.py | 43 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 46 insertions(+), 5 deletions(-) create mode 100644 tests/test_vllm.py diff --git a/outlines/serve/serve.py b/outlines/serve/serve.py index 919be0eab..aebebdd35 100644 --- a/outlines/serve/serve.py +++ b/outlines/serve/serve.py @@ -28,7 +28,6 @@ CFGLogitsProcessor, 
JSONLogitsProcessor, RegexLogitsProcessor, - CFGLogitsProcessor, _patched_apply_logits_processors, ) diff --git a/outlines/serve/vllm.py b/outlines/serve/vllm.py index ded66223d..bf12e2c56 100644 --- a/outlines/serve/vllm.py +++ b/outlines/serve/vllm.py @@ -5,7 +5,6 @@ from typing import Callable, DefaultDict, List import torch -from vllm import LLMEngine from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM from outlines.fsm.json_schema import build_regex_from_object @@ -113,7 +112,7 @@ def __call__( class RegexLogitsProcessor(FSMLogitsProcessor): - def __init__(self, regex_string, llm: LLMEngine): + def __init__(self, regex_string, llm): """Compile the FSM that drives the regex-guided generation. Parameters @@ -130,7 +129,7 @@ def __init__(self, regex_string, llm: LLMEngine): class CFGLogitsProcessor(FSMLogitsProcessor): - def __init__(self, cfg_string, llm: LLMEngine): + def __init__(self, cfg_string, llm): """Compile the FSM that drives the cfg-guided generation. Parameters @@ -147,7 +146,7 @@ def __init__(self, cfg_string, llm: LLMEngine): class JSONLogitsProcessor(RegexLogitsProcessor): - def __init__(self, schema, llm: LLMEngine): + def __init__(self, schema, llm): """Compile the FSM that drives the JSON-guided generation. Parameters diff --git a/tests/test_vllm.py b/tests/test_vllm.py new file mode 100644 index 000000000..6715ce45f --- /dev/null +++ b/tests/test_vllm.py @@ -0,0 +1,43 @@ +import pytest +import torch +from transformers import AutoTokenizer + +from outlines.serve.vllm import ( + CFGLogitsProcessor, + JSONLogitsProcessor, + RegexLogitsProcessor, +) + +TEST_REGEX = r"(-)?(0|[1-9][0-9]*)(.[0-9]+)?([eE][+-][0-9]+)?" +TEST_CFG = """ +start: DECIMAL +DIGIT: "0".."9" +INT: DIGIT+ +DECIMAL: INT "." INT? | "." 
INT +""" +TEST_SCHEMA = '{"type": "string", "maxLength": 5}' + +LOGIT_PROCESSORS = ( + (CFGLogitsProcessor, TEST_CFG), + (RegexLogitsProcessor, TEST_REGEX), + (JSONLogitsProcessor, TEST_SCHEMA), +) + +TEST_MODEL = "hf-internal-testing/tiny-random-GPTJForCausalLM" + + +@pytest.mark.parametrize("logit_processor, fsm_str", LOGIT_PROCESSORS) +def test_logit_processor(logit_processor, fsm_str: str): + class MockvLLMEngine: + def __init__(self, tokenizer): + self.tokenizer = tokenizer + + def __call__(*_): + return torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.float), None + + tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL) + engine = MockvLLMEngine(tokenizer) + logit_processor(fsm_str, engine) + assert isinstance(engine.tokenizer.decode([0, 1, 2, 3]), list) + logit_processor(fsm_str, engine) + assert isinstance(engine.tokenizer.decode([0, 1, 2, 3]), list)