 from typing import TYPE_CHECKING, Any, List

 import torch
-from transformers import PreTrainedTokenizerFast

 from vllm.logger import init_logger

 try:
     import xgrammar as xgr
-    from xgrammar.base import _core as xgr_core
     xgr_installed = True
 except ImportError:
     xgr_installed = False
@@ -35,7 +33,6 @@
 logger = init_logger(__name__)


-# TODO: passing batch size to max threads here
 def get_local_xgrammar_guided_decoding_logits_processor(
         guided_params: GuidedDecodingParams,
         tokenizer: PreTrainedTokenizer,
@@ -52,65 +49,61 @@ def get_local_xgrammar_guided_decoding_logits_processor(
 @dataclass(frozen=True)
 class TokenizerData:
     """Immutable container for cached tokenizer data."""
+    metadata: str
     encoded_vocab: list[str] = field(default_factory=list)
-    stop_token_ids: list[int] | None = None
-    # These fields are mutually exclusive: `backend_str` is used to create a
-    # TokenizeInfo with `TokenizerInfo.from_huggingface` while `vocab_type` is
-    # used within the constructor of TokenizeInfo
-    backend_str: str | None = None
-    vocab_type: xgr.VocabType | None = None
-
-    def __post_init__(self):
-        # Check for mutual exclusive
-        assert not (self.backend_str and self.vocab_type), \
-            "backend_str and vocab_type are mutual exclusive"


 class TokenizerDataCache:
     """Cache manager for tokenizer data to avoid repeated processing."""
     _cache: dict[int, TokenizerData] = {}

     @classmethod
-    def get_tokenizer_data(cls,
-                           tokenizer: PreTrainedTokenizer) -> TokenizerData:
-        tokenizer_hash = hash(tokenizer)
+    def get_tokenizer_data(
+        cls,
+        tokenizer: PreTrainedTokenizer,
+        /,
+        *,
+        tokenizer_hash: int,
+        vocab_size: int,
+    ) -> TokenizerData:

         if tokenizer_hash not in cls._cache:
-            # Vendored from xgrammar logic since we cannot pickle the tokenizer
-            # https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 # noqa: E501
+            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
+                tokenizer,
+                # NOTE: We will need to use lm_head's vocab_size
+                # to determine correct special_token_ids for this tokenizer.
+                # See https://github.com/mlc-ai/xgrammar/commit/70c959fb6d9cea75aae33c414763cd0602022d92 # noqa: E501
+                vocab_size=vocab_size,
+            )
+            metadata = json.loads(tokenizer_info.dump_metadata())
+
+            # Vendored from xgrammar logic to get encoded_vocab
+            # https://github.com/mlc-ai/xgrammar/blob/989222175c2a30fb7987d8bcce35bec1bf6817f2/python/xgrammar/tokenizer_info.py#L127 # noqa: E501
             try:
-                encoded_vocab = [
-                    token for token, _ in sorted(tokenizer.get_vocab().items(),
-                                                 key=lambda x: x[1])
-                ]
+                vocab_dict = tokenizer.get_vocab()
             except AttributeError as e:
                 raise ValueError(
                     f"Cannot get the vocabulary of the tokenizer "
                     f"{type(tokenizer)}. The tokenizer should have a "
                     "get_vocab method.") from e

-            stop_token_ids = None
-            backend_str = ""
-            vocab_type = xgr.VocabType.RAW
-
-            if stop_token_ids is None and hasattr(
-                    tokenizer,
-                    "eos_token_id") and tokenizer.eos_token_id is not None:
-                stop_token_ids = [tokenizer.eos_token_id]
-
-            if isinstance(tokenizer, PreTrainedTokenizerFast):
-                backend_str = tokenizer.backend_tokenizer.to_str()
-                vocab_type = None
+            # maintain tokenizer's indexing
+            encoded_vocab = [""] * tokenizer_info.vocab_size
+            for token, idx in vocab_dict.items():
+                if idx < tokenizer_info.vocab_size:
+                    encoded_vocab[idx] = token

-            elif isinstance(tokenizer, MistralTokenizer):
+            if isinstance(tokenizer, MistralTokenizer):
                 # REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
-                vocab_type = xgr.VocabType.BYTE_FALLBACK
+                metadata.update({
+                    "vocab_type": xgr.VocabType.BYTE_FALLBACK,
+                    "add_prefix_space": True
+                })

             cls._cache[tokenizer_hash] = TokenizerData(
                 encoded_vocab=encoded_vocab,
-                stop_token_ids=stop_token_ids,
-                backend_str=backend_str,
-                vocab_type=vocab_type)
+                metadata=json.dumps(metadata),
+            )

         return cls._cache[tokenizer_hash]

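As an aside, here is a minimal sketch of the serialization round-trip this hunk sets up (illustrative only, not part of the diff; `to_cacheable`, `from_cacheable`, `tok`, and `lm_head_vocab_size` are assumed names, not identifiers from the PR). The tokenizer is reduced to an `encoded_vocab` list plus a JSON `metadata` string, both plain picklable data, and an equivalent `xgr.TokenizerInfo` can be rebuilt from them later.

```python
import xgrammar as xgr

def to_cacheable(tok, lm_head_vocab_size: int) -> tuple[list[str], str]:
    # Build TokenizerInfo once from the live tokenizer, then keep only
    # plain data (a list of strings + a JSON string) so it can be pickled.
    info = xgr.TokenizerInfo.from_huggingface(tok, vocab_size=lm_head_vocab_size)
    metadata = info.dump_metadata()  # JSON string
    encoded_vocab = [""] * info.vocab_size
    for token, idx in tok.get_vocab().items():
        if idx < info.vocab_size:
            encoded_vocab[idx] = token
    return encoded_vocab, metadata

def from_cacheable(encoded_vocab: list[str], metadata: str) -> xgr.TokenizerInfo:
    # The rebuild side, mirroring what the compiler cache below does.
    return xgr.TokenizerInfo.from_vocab_and_metadata(
        encoded_vocab=encoded_vocab, metadata=metadata)
```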
@@ -129,30 +122,15 @@ def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler:
         cache_key = str(config.tokenizer_hash)

         if cache_key not in cls._cache:
-            assert config.tokenizer_data is not None
-            assert config.tokenizer_data.encoded_vocab is not None
-
             config_data = config.tokenizer_data

             # In TokenizerDataCache.get_tokenizer_data, a serializable
             # tokenizer_data is created and cached. This data is used to build
             # a tokenizer_info and create an xgrammar compiler.
-            # - If tokenizer_data has backend_str set, use
-            #   xgr_core.TokenizerInfo.from_huggingface (a C++ bind).
-            # - Otherwise, use the default constructor with vocab_type.
-            # - xgr_core.TokenizerInfo.from_huggingface !=
-            #   xgr.TokenizerInfo.from_huggingface.
-            if config_data.backend_str:
-                tokenizer_info = xgr.TokenizerInfo._create_from_handle(
-                    xgr_core.TokenizerInfo.from_huggingface(
-                        config_data.encoded_vocab, config_data.backend_str,
-                        config.vocab_size, config_data.stop_token_ids))
-            else:
-                tokenizer_info = xgr.TokenizerInfo(
-                    config_data.encoded_vocab,
-                    config_data.vocab_type,
-                    vocab_size=config.vocab_size,
-                    stop_token_ids=config_data.stop_token_ids)
+            tokenizer_info = xgr.TokenizerInfo.from_vocab_and_metadata(
+                encoded_vocab=config_data.encoded_vocab,
+                metadata=config_data.metadata,
+            )
             cls._cache[cache_key] = xgr.GrammarCompiler(
                 tokenizer_info, max_threads=config.max_threads)

@@ -163,13 +141,12 @@ def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler:
 class GrammarConfig:
     """Serializable configuration for grammar compilation"""
     tokenizer_hash: int
-    vocab_size: int
+    tokenizer_data: TokenizerData
     json_str: str | None = None
     grammar_str: str | None = None
     json_object: bool | None = None
     any_whitespace: bool = True
     max_threads: int = 8
-    tokenizer_data: TokenizerData | None = None

     @classmethod
     def from_guided_params(cls,
@@ -179,7 +156,11 @@ def from_guided_params(cls,
                            max_threads: int = 8) -> GrammarConfig:

         tokenizer_hash = hash(tokenizer)
-        tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer)
+        tokenizer_data = TokenizerDataCache.get_tokenizer_data(
+            tokenizer,
+            tokenizer_hash=tokenizer_hash,
+            vocab_size=model_config.hf_text_config.vocab_size,
+        )

         if guided_params.json:
             if not isinstance(guided_params.json, str):
@@ -218,7 +199,6 @@ def from_guided_params(cls,
                 raise ValueError(str(err)) from err

             return cls(json_str=json_str,
-                       vocab_size=model_config.hf_text_config.vocab_size,
                        tokenizer_hash=tokenizer_hash,
                        max_threads=max_threads,
                        tokenizer_data=tokenizer_data,
@@ -246,14 +226,12 @@ def from_guided_params(cls,
                 raise ValueError(str(err)) from err

             return cls(grammar_str=grammar_str,
-                       vocab_size=model_config.hf_text_config.vocab_size,
                        tokenizer_hash=tokenizer_hash,
                        max_threads=max_threads,
                        tokenizer_data=tokenizer_data)
         elif guided_params.json_object:
             return cls(
                 json_object=True,
-                vocab_size=model_config.hf_text_config.vocab_size,
                 tokenizer_hash=tokenizer_hash,
                 max_threads=max_threads,
                 tokenizer_data=tokenizer_data,
@@ -267,7 +245,6 @@ def from_guided_params(cls,

             return cls(
                 grammar_str=choice_str,
-                vocab_size=model_config.hf_text_config.vocab_size,
                 tokenizer_hash=tokenizer_hash,
                 max_threads=max_threads,
                 tokenizer_data=tokenizer_data,
@@ -291,6 +268,13 @@ def choice_as_grammar(choice: List[str] | None) -> str:
         grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices))
         return grammar

+    @staticmethod
+    def tokenizer_info(tokenizer_data: TokenizerData) -> xgr.TokenizerInfo:
+        return xgr.TokenizerInfo.from_vocab_and_metadata(
+            encoded_vocab=tokenizer_data.encoded_vocab,
+            metadata=tokenizer_data.metadata,
+        )
+

 @dataclass
 class XGrammarLogitsProcessor:
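Stepping outside the diff briefly: the `choice_as_grammar` construction at the top of the previous hunk builds a one-rule GBNF-style grammar from the choice list. A small worked example, assuming the escaping step leaves plain ASCII choices unchanged (a hypothetical input, not taken from the PR):

```python
escaped_choices = ["yes", "no"]  # assumes escaping left these untouched
grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices))
assert grammar == 'root ::= "yes" | "no"'
```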
@@ -299,18 +283,25 @@ class XGrammarLogitsProcessor:
     reasoner: Reasoner | None = None

     ctx: xgr.CompiledGrammar | None = None
+    tokenizer_info: xgr.TokenizerInfo = None  # type: ignore[assignment]
     token_bitmask: torch.Tensor = None  # type: ignore[assignment]
     matchers: list[xgr.GrammarMatcher] = field(default_factory=list)
     batch_size: int = field(default=1)
     prefilled: bool = field(default=False)

+    def __post_init__(self):
+        self.tokenizer_info = self.config.tokenizer_info(
+            self.config.tokenizer_data)
+
     def __getstate__(self) -> dict[str, Any]:
         return {'config': self.config, 'reasoner': self.reasoner}

     def __setstate__(self, state: dict[str, Any]):
         self.config = state['config']
         self.reasoner = state['reasoner']

+        self.tokenizer_info = GrammarConfig.tokenizer_info(
+            self.config.tokenizer_data)
         self.ctx = None
         self.matchers = []
         self.batch_size = 1
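A brief sketch of the pickling behaviour these dunder methods give the processor (illustrative only; `clone_for_worker` and `proc` are hypothetical names, and the engine code that constructs the processor is not shown in this hunk):

```python
import pickle

def clone_for_worker(proc: "XGrammarLogitsProcessor") -> "XGrammarLogitsProcessor":
    # Only `config` and `reasoner` cross the pickle boundary (__getstate__);
    # __setstate__ then rebuilds tokenizer_info from the cached TokenizerData
    # and resets ctx, matchers, and batch_size, so the compiled grammar state
    # is re-created lazily in the next hunk's __call__.
    return pickle.loads(pickle.dumps(proc))
```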
@@ -352,7 +343,7 @@ def __call__(self, input_ids: list[int],
                 xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
             ]
             self.token_bitmask = xgr.allocate_token_bitmask(
-                self.batch_size, self.config.vocab_size)
+                self.batch_size, self.tokenizer_info.vocab_size)

         if not self.prefilled:
             # Have not sampled a token yet