
Commit 9e18659

noamgat, russellb, and DarkLight1337 authored and committed
Frontend: Adding LM Format Enforcer support to V1 engine (vllm-project#22564)
Signed-off-by: Noam Gat <noamgat@gmail.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Signed-off-by: Xiao Yu <xiao.yu@amd.com>
1 parent 16ef192 commit 9e18659

File tree: 6 files changed, +190 −5 lines


requirements/common.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -18,7 +18,7 @@ prometheus_client >= 0.18.0
 pillow # Required for image processing
 prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0 # Required for DBRX tokenizer
-lm-format-enforcer >= 0.10.11, < 0.11
+lm-format-enforcer == 0.11.3
 llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
 outlines_core == 0.2.10 ; platform_machine != "s390x"
 outlines == 0.1.11 ; platform_machine == "s390x"
```
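
For context, a minimal usage sketch of the backend this dependency enables. It assumes the `guided_decoding_backend` engine argument and the `GuidedDecodingParams` API exercised by the test file below; the model and schema are illustrative only.

```python
# Sketch only: selects the lm-format-enforcer backend engine-wide and
# requests JSON-schema-constrained output (model/schema are examples).
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}},
    "required": ["name"],
}
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
          guided_decoding_backend="lm-format-enforcer")
params = SamplingParams(guided_decoding=GuidedDecodingParams(json=schema))
outputs = llm.generate("Reply with a JSON object containing a name.", params)
print(outputs[0].outputs[0].text)
```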

tests/v1/entrypoints/llm/test_struct_output_generate.py

Lines changed: 7 additions & 3 deletions
```diff
@@ -41,8 +41,11 @@
 PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
     ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None),
     ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
+    ("mistralai/Ministral-8B-Instruct-2410", "lm-format-enforcer", "auto",
+     None),
     ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
     ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
+    ("Qwen/Qwen2.5-1.5B-Instruct", "lm-format-enforcer", "auto", None),
     ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None),
     ("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None),
     ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto",
@@ -148,7 +151,8 @@ def test_structured_output(
 
         generated_text = output.outputs[0].text
         assert generated_text is not None
-        assert "\n" not in generated_text
+        if guided_decoding_backend != 'lm-format-enforcer':
+            assert "\n" not in generated_text
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
         output_json = json.loads(generated_text)
         jsonschema.validate(instance=output_json, schema=sample_json_schema)
@@ -225,7 +229,7 @@ def test_structured_output(
     parsed_json = json.loads(generated_text)
     assert isinstance(parsed_json, dict)
 
-    if guided_decoding_backend != "outlines":
+    if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]:
         #
         # Test 4: Generate SQL statement using EBNF grammar
         #
@@ -439,7 +443,7 @@ def test_structured_output(
     output_json = json.loads(generated_text)
     jsonschema.validate(instance=output_json, schema=json_schema)
 
-    if guided_decoding_backend != "outlines":
+    if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]:
         #
         # Test 11: Generate structured output using structural_tag format
         #
```
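
The relaxed newline assertion reflects that LM Format Enforcer allows JSON-insignificant whitespace in generated text. A small self-contained illustration (the sample string is made up; `jsonschema` is already a test dependency):

```python
# Newlines inside generated JSON are legal JSON whitespace, so output
# containing them still parses and validates against the schema.
import json

import jsonschema

sample_output = '{\n  "name": "vllm"\n}'  # illustrative model output
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
jsonschema.validate(instance=json.loads(sample_output), schema=schema)
```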

vllm/config/__init__.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -3057,7 +3057,8 @@ def get_served_model_name(model: str,
     return served_model_name
 
 
-GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines"]
+GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines",
+                                "lm-format-enforcer"]
 
 
 @config
```
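
For illustration, a hypothetical helper (not vLLM code) showing how a backend name can be checked against the widened `Literal` with `typing.get_args`:

```python
# Hypothetical validation helper against the Literal extended above.
from typing import Literal, get_args

GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines",
                                "lm-format-enforcer"]

def check_backend(name: str) -> str:
    if name not in get_args(GuidedDecodingBackend):
        raise ValueError(f"Unsupported guided decoding backend: {name!r}")
    return name

check_backend("lm-format-enforcer")  # accepted after this commit
```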

vllm/v1/engine/processor.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -21,6 +21,8 @@
 from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient
 from vllm.v1.structured_output.backend_guidance import (
     validate_guidance_grammar)
+from vllm.v1.structured_output.backend_lm_format_enforcer import (
+    validate_structured_output_request_lm_format_enforcer)
 from vllm.v1.structured_output.backend_outlines import (
     validate_structured_output_request_outlines)
 from vllm.v1.structured_output.backend_xgrammar import (
@@ -200,6 +202,9 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
         elif engine_level_backend == "outlines":
             # outlines backend
             validate_structured_output_request_outlines(params)
+        elif engine_level_backend == "lm-format-enforcer":
+            # lm format enforcer backend
+            validate_structured_output_request_lm_format_enforcer(params)
         else:
             # NOTE: engine_level_backend must be "auto" here, because we have
             # checked supported_backends above.
```
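
A sketch of the new validation branch in action: grammar specifications are rejected for this backend. It uses the validator added by this commit plus `SamplingParams`/`GuidedDecodingParams` from vLLM; the grammar string is illustrative.

```python
# Sketch: the lm-format-enforcer validator rejects grammar requests.
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.v1.structured_output.backend_lm_format_enforcer import (
    validate_structured_output_request_lm_format_enforcer)

params = SamplingParams(guided_decoding=GuidedDecodingParams(
    grammar='root ::= "yes" | "no"'))  # illustrative EBNF grammar
try:
    validate_structured_output_request_lm_format_enforcer(params)
except ValueError as err:
    print(err)  # "...does not support grammar specifications"
```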

vllm/v1/structured_output/__init__.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -108,6 +108,14 @@ def grammar_init(self, request: Request) -> None:
                 tokenizer=self.tokenizer,
                 vocab_size=vocab_size,
             )
+        elif backend == "lm-format-enforcer":
+            from vllm.v1.structured_output.backend_lm_format_enforcer import (  # noqa: E501
+                LMFormatEnforcerBackend)
+            self.backend = LMFormatEnforcerBackend(
+                self.vllm_config,
+                tokenizer=self.tokenizer,
+                vocab_size=vocab_size,
+            )
         else:
             raise ValueError(
                 f"Unsupported structured output backend: {backend}")
```
vllm/v1/structured_output/backend_lm_format_enforcer.py (new file)

Lines changed: 167 additions & 0 deletions
```diff
@@ -0,0 +1,167 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from __future__ import annotations
+
+import ast
+import json
+from dataclasses import dataclass, field
+from functools import lru_cache
+from typing import TYPE_CHECKING
+
+import torch
+from transformers import PreTrainedTokenizerBase
+
+from vllm.sampling_params import SamplingParams
+from vllm.utils import LazyLoader
+from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
+                                                     StructuredOutputGrammar,
+                                                     StructuredOutputOptions)
+
+if TYPE_CHECKING:
+    import lmformatenforcer
+    import lmformatenforcer.integrations.vllm as lmfe_vllm
+else:
+    lmformatenforcer = LazyLoader("lmformatenforcer", globals(),
+                                  "lmformatenforcer")
+    lmfe_vllm = LazyLoader("lmformatenforcer.integrations.vllm", globals(),
+                           "lmformatenforcer.integrations.vllm")
+
+
+@lru_cache
+def _cached_build_vllm_token_enforcer_tokenizer_data(
+        tokenizer: PreTrainedTokenizerBase,
+        vocab_size: int) -> lmfe_vllm.TokenEnforcerTokenizerData:
+    return lmfe_vllm.build_vllm_token_enforcer_tokenizer_data(
+        tokenizer, use_bitmask=True, vocab_size=vocab_size)
+
+
+@dataclass
+class LMFormatEnforcerGrammar(StructuredOutputGrammar):
+    token_enforcer: lmformatenforcer.TokenEnforcer
+    current_tokens_prefix: list[int] = field(default_factory=list)
+
+    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
+        original_len = len(self.current_tokens_prefix)
+        for token in tokens:
+            if not self.token_enforcer.get_allowed_tokens(
+                    self.current_tokens_prefix).is_token_allowed(token):
+                # Rollback partial updates to ensure atomicity.
+                del self.current_tokens_prefix[original_len:]
+                return False
+            self.current_tokens_prefix.append(token)
+        return True
+
+    def validate_tokens(self, tokens: list[int]) -> list[int]:
+        for prefix_length in range(len(tokens)):
+            prefix = tokens[:prefix_length]
+            next_token = tokens[prefix_length]
+            if not self.token_enforcer.get_allowed_tokens(
+                    self.current_tokens_prefix +
+                    prefix).is_token_allowed(next_token):
+                break
+        else:
+            return tokens
+
+        return tokens[:prefix_length]
+
+    def rollback(self, num_tokens: int) -> None:
+        self.current_tokens_prefix = self.current_tokens_prefix[:-num_tokens]
+
+    def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
+        allowed_tokens = self.token_enforcer.get_allowed_tokens(
+            self.current_tokens_prefix)
+        bitmask[batch_index] = allowed_tokens.allowed_tokens
+
+    def is_terminated(self) -> bool:
+        # We are considered terminated if the prefix ends with eos_token_id
+        return_value = len(
+            self.current_tokens_prefix) > 0 and self.current_tokens_prefix[
+                -1] == self.token_enforcer.eos_token_id
+        return return_value
+
+    def reset(self):
+        self.current_tokens_prefix = []
+
+
+@dataclass
+class LMFormatEnforcerBackend(StructuredOutputBackend):
+
+    def __post_init__(self):
+        self.tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
+            self.tokenizer, self.vocab_size)
+
+    def compile_grammar(self, request_type: StructuredOutputOptions,
+                        grammar_spec: str) -> StructuredOutputGrammar:
+        character_level_parser: lmformatenforcer.CharacterLevelParser
+        if request_type == StructuredOutputOptions.JSON:
+            spec_dict = json.loads(grammar_spec)
+            character_level_parser = lmformatenforcer.JsonSchemaParser(
+                spec_dict)
+        elif request_type == StructuredOutputOptions.JSON_OBJECT:
+            character_level_parser = lmformatenforcer.JsonSchemaParser(None)
+        elif request_type == StructuredOutputOptions.REGEX:
+            character_level_parser = lmformatenforcer.RegexParser(grammar_spec)
+        elif request_type == StructuredOutputOptions.CHOICE:
+            choices = ast.literal_eval(grammar_spec)
+            character_level_parser = lmformatenforcer.UnionParser(
+                [lmformatenforcer.StringParser(choice) for choice in choices])
+        else:
+            raise ValueError(
+                "Invalid request type for LM Format Enforcer backend"
+                f"({request_type!s})")
+        max_rollback_tokens = (
+            self.vllm_config.speculative_config.num_speculative_tokens
+            if self.vllm_config.speculative_config is not None else 0)
+
+        if max_rollback_tokens > 0:
+            raise ValueError(
+                "LM Format Enforcer backend does not support speculative tokens"
+            )
+
+        token_enforcer = lmformatenforcer.TokenEnforcer(
+            tokenizer_data=self.tokenizer_data,
+            parser=character_level_parser,
+        )
+        return LMFormatEnforcerGrammar(token_enforcer)
+
+    def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
+        return torch.full(
+            (max_num_seqs, (self.vocab_size + 31) // 32),
+            -1,
+            dtype=torch.int32,
+            pin_memory=torch.cuda.is_available(),
+        )
+
+    def destroy(self):
+        pass
+
+
+def validate_structured_output_request_lm_format_enforcer(
+        params: SamplingParams):
+    if params.guided_decoding is None:
+        return
+
+    gd_params = params.guided_decoding
+
+    if gd_params.regex:
+        return
+    elif gd_params.json:
+        if isinstance(gd_params.json, str):
+            try:
+                # make sure schema is valid json
+                json.loads(gd_params.json)
+            except json.JSONDecodeError as e:
+                raise ValueError("Invalid JSON grammar specification.") from e
+        else:
+            try:
+                json.dumps(gd_params.json)
+            except Exception as e:
+                raise ValueError(
+                    f"Error serializing guided decoding jsonschema: {e}"
+                ) from e
+        return
+    elif gd_params.choice:
+        return
+    elif gd_params.grammar:
+        raise ValueError("LM Format Enforcer guided decoding backend "
+                         "does not support grammar specifications")
```
