
Commit 77a318b

[V1][Core] Support MistralTokenizer for Structured Output (#14625)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
1 parent 80e78d0 commit 77a318b
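
For context, this change teaches the V1 structured-output path to build xgrammar's TokenizerInfo from Mistral's native tokenizer, so guided decoding also works for models such as mistralai/Ministral-8B-Instruct-2410. A minimal usage sketch of the feature (the schema, prompt, and printout are illustrative only and not part of this commit; it assumes V1 is enabled via VLLM_USE_V1=1, as the tests below do, and that the model weights are available):

import json
import os

os.environ["VLLM_USE_V1"] = "1"  # the tests below set this via monkeypatch

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Hypothetical schema and prompt, for illustration only.
schema = {"type": "object", "properties": {"city": {"type": "string"}}}
llm = LLM(model="mistralai/Ministral-8B-Instruct-2410", max_model_len=1024)
params = SamplingParams(
    max_tokens=100,
    guided_decoding=GuidedDecodingParams(json=schema),
)
outputs = llm.generate("Reply as JSON: where is the Eiffel Tower?",
                       sampling_params=params)
print(json.loads(outputs[0].outputs[0].text))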

File tree: 2 files changed (+102, -26 lines)


tests/v1/entrypoints/llm/test_struct_output_generate.py

Lines changed: 64 additions & 23 deletions
@@ -1,7 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from __future__ import annotations
+
 import json
 import re
+from typing import Any
 
 import jsonschema
 import pytest
@@ -10,17 +13,27 @@
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
-MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
 GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"]
 
 
+@pytest.fixture
+def model_name():
+    return [
+        "Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
+    ]
+
+
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
-def test_guided_json_completion(monkeypatch, sample_json_schema,
-                                guided_decoding_backend: str):
+def test_guided_json_completion(
+    monkeypatch: pytest.MonkeyPatch,
+    sample_json_schema: dict[str, Any],
+    guided_decoding_backend: str,
+    model_name: str,
+):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=model_name, max_model_len=1024)
     sampling_params = SamplingParams(temperature=1.0,
                                      max_tokens=1000,
                                      guided_decoding=GuidedDecodingParams(
@@ -50,9 +63,13 @@ def test_guided_json_completion(monkeypatch, sample_json_schema,
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
-def test_guided_json_object(monkeypatch, guided_decoding_backend: str):
+def test_guided_json_object(
+    monkeypatch: pytest.MonkeyPatch,
+    guided_decoding_backend: str,
+    model_name: str,
+):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=model_name, max_model_len=1024)
     sampling_params = SamplingParams(temperature=1.0,
                                      max_tokens=100,
                                      n=2,
@@ -84,10 +101,14 @@ def test_guided_json_object(monkeypatch, guided_decoding_backend: str):
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
-def test_guided_json_unsupported_schema(monkeypatch, unsupported_json_schema,
-                                        guided_decoding_backend: str):
+def test_guided_json_unsupported_schema(
+    monkeypatch: pytest.MonkeyPatch,
+    unsupported_json_schema: dict[str, Any],
+    guided_decoding_backend: str,
+    model_name: str,
+):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=model_name, max_model_len=1024)
     sampling_params = SamplingParams(temperature=1.0,
                                      max_tokens=1000,
                                      guided_decoding=GuidedDecodingParams(
@@ -107,10 +128,14 @@ def test_guided_json_unsupported_schema(monkeypatch, unsupported_json_schema,
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
-def test_guided_grammar_ebnf(monkeypatch, sample_sql_ebnf,
-                             guided_decoding_backend: str):
+def test_guided_grammar_ebnf(
+    monkeypatch: pytest.MonkeyPatch,
+    sample_sql_ebnf: str,
+    guided_decoding_backend: str,
+    model_name: str,
+):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=model_name, max_model_len=1024)
     sampling_params = SamplingParams(temperature=0.8,
                                      top_p=0.95,
                                      max_tokens=1000,
@@ -145,10 +170,14 @@ def test_guided_grammar_ebnf(monkeypatch, sample_sql_ebnf,
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
-def test_guided_grammar_lark(monkeypatch, sample_sql_lark,
-                             guided_decoding_backend: str):
+def test_guided_grammar_lark(
+    monkeypatch: pytest.MonkeyPatch,
+    sample_sql_lark: str,
+    guided_decoding_backend: str,
+    model_name: str,
+):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=model_name, max_model_len=1024)
     sampling_params = SamplingParams(temperature=0.8,
                                      top_p=0.95,
                                      max_tokens=1000,
@@ -188,10 +217,13 @@ def test_guided_grammar_lark(monkeypatch, sample_sql_lark,
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
-def test_guided_grammar_ebnf_invalid(monkeypatch,
-                                     guided_decoding_backend: str):
+def test_guided_grammar_ebnf_invalid(
+    monkeypatch: pytest.MonkeyPatch,
+    guided_decoding_backend: str,
+    model_name: str,
+):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=model_name, max_model_len=1024)
     sampling_params = SamplingParams(temperature=0.8,
                                      top_p=0.95,
                                      max_tokens=1000,
@@ -212,9 +244,14 @@ def test_guided_grammar_ebnf_invalid(monkeypatch,
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
-def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str):
+def test_guided_regex(
+    monkeypatch: pytest.MonkeyPatch,
+    sample_regex: str,
+    guided_decoding_backend: str,
+    model_name: str,
+):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=model_name, max_model_len=1024)
     sampling_params = SamplingParams(temperature=0.8,
                                      top_p=0.95,
                                      guided_decoding=GuidedDecodingParams(
@@ -243,10 +280,14 @@ def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str):
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
-def test_guided_choice_completion(monkeypatch, sample_guided_choice,
-                                  guided_decoding_backend: str):
+def test_guided_choice_completion(
+    monkeypatch: pytest.MonkeyPatch,
+    sample_guided_choice: str,
+    guided_decoding_backend: str,
+    model_name: str,
+):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=model_name, max_model_len=1024)
     sampling_params = SamplingParams(temperature=0.8,
                                      top_p=0.95,
                                      guided_decoding=GuidedDecodingParams(
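
The assertion bodies of these tests are unchanged context and therefore not shown in the diff; they follow the usual generate-then-validate pattern. A rough sketch of that validation step, with illustrative names rather than code copied from the test file:

import json

import jsonschema


def check_guided_json(generated_text: str, schema: dict) -> None:
    # Guided decoding should produce text that parses as JSON and
    # conforms to the requested schema.
    parsed = json.loads(generated_text)
    jsonschema.validate(instance=parsed, schema=schema)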

vllm/v1/structured_output/__init__.py

Lines changed: 38 additions & 3 deletions
@@ -8,6 +8,7 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
+from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.utils import LazyLoader
 from vllm.v1.structured_output.grammar import Grammar, StructuredOutputOptions
 
@@ -40,8 +41,40 @@ def _delayed_init(self):
         tokenizer_group.ping()
 
         tokenizer = tokenizer_group.get_lora_tokenizer(None)
-        tokenizer_info = xgr.TokenizerInfo.from_huggingface(
-            tokenizer, vocab_size=self.vocab_size)
+        if isinstance(tokenizer, MistralTokenizer):
+            # NOTE: ideally, xgrammar should handle this accordingly.
+            # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
+            try:
+                encoded_vocab = [
+                    token for token, _ in sorted(
+                        tokenizer.get_vocab().items(),
+                        key=lambda x: x[1],
+                    )
+                ]
+                stop_token_ids = None
+                if hasattr(
+                        tokenizer,
+                        "eos_token_id",
+                ) and tokenizer.eos_token_id is not None:
+                    stop_token_ids = [tokenizer.eos_token_id]
+            except AttributeError as e:
+                raise ValueError(
+                    f"Cannot get the vocabulary of the tokenizer "
+                    f"{type(tokenizer)}. The tokenizer should have a "
+                    "get_vocab method.") from e
+            tokenizer_info = xgr.TokenizerInfo(
+                encoded_vocab=encoded_vocab,
+                # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
+                vocab_type=xgr.VocabType.BYTE_FALLBACK,
+                vocab_size=self.vocab_size,
+                stop_token_ids=stop_token_ids,
+                add_prefix_space=True,
+            )
+        else:
+            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
+                tokenizer,
+                vocab_size=self.vocab_size,
+            )
         self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
 
         # The default max_workers if not specified is the number of CPUs * 5,
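
The important detail in the MistralTokenizer branch above is that xgrammar's TokenizerInfo takes the vocabulary as a list whose index is the token id, which is why get_vocab() is sorted by its values before being handed over. A toy illustration with a made-up three-token vocabulary:

# Made-up vocabulary, only to show the ordering encoded_vocab must have.
vocab = {"<s>": 0, "</s>": 2, "hello": 1}
encoded_vocab = [tok for tok, _ in sorted(vocab.items(), key=lambda x: x[1])]
assert encoded_vocab == ["<s>", "hello", "</s>"]  # list index == token id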
@@ -51,7 +84,9 @@ def _delayed_init(self):
         max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
         self.executor = ThreadPoolExecutor(max_workers=max_workers)
         self._grammar_bitmask = xgr.allocate_token_bitmask(
-            self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size)
+            self.vllm_config.scheduler_config.max_num_seqs,
+            self.vocab_size,
+        )
 
         self.init_complete = True
 
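The allocate_token_bitmask change above is only a formatting cleanup; the bitmask itself is the per-request mask of allowed tokens. A small sketch of how its shape relates to the two arguments, assuming xgrammar's packed layout of 32 tokens per int32 word (an assumption about the library, not something stated in this diff):

import math

# Illustration only: approximate bitmask shape for the given arguments.
max_num_seqs, vocab_size = 16, 32768
words_per_row = math.ceil(vocab_size / 32)
print((max_num_seqs, words_per_row))  # (16, 1024)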