@@ -1,14 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from __future__ import annotations
+
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING
 
 import torch
 
 import vllm.envs
-from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.utils import LazyLoader
 from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
@@ -23,58 +23,49 @@
 logger = init_logger(__name__)
 
 
+@dataclass
 class XgrammarBackend(StructuredOutputBackend):
 
-    def __init__(self, vllm_config: VllmConfig):
-        self.vllm_config = vllm_config
+    def __post_init__(self):
         self.disable_any_whitespace = (
             "disable-any-whitespace"
-            in vllm_config.decoding_config.guided_decoding_backend)
-        tokenizer_group = init_tokenizer_from_configs(
-            model_config=vllm_config.model_config,
-            scheduler_config=vllm_config.scheduler_config,
-            parallel_config=vllm_config.parallel_config,
-            lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
-        tokenizer_group.ping()
-
-        tokenizer = tokenizer_group.get_lora_tokenizer(None)
-        self.vocab_size = vllm_config.model_config.get_vocab_size()
-        if isinstance(tokenizer, MistralTokenizer):
+            in self.vllm_config.decoding_config.guided_decoding_backend)
+        if isinstance(self.tokenizer, MistralTokenizer):
             # NOTE: ideally, xgrammar should handle this accordingly.
             # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
             try:
-                if tokenizer.is_tekken:
-                    encoded_vocab = tokenizer._vocab
+                if self.tokenizer.is_tekken:
+                    encoded_vocab = self.tokenizer._vocab
                 else:
                     encoded_vocab = [
                         token for token, _ in sorted(
-                            tokenizer.get_vocab().items(),
+                            self.tokenizer.get_vocab().items(),
                             key=lambda x: x[1],
                         )
                     ]
                 stop_token_ids = None
                 if hasattr(
-                        tokenizer,
+                        self.tokenizer,
                         "eos_token_id",
-                ) and tokenizer.eos_token_id is not None:
-                    stop_token_ids = [tokenizer.eos_token_id]
+                ) and self.tokenizer.eos_token_id is not None:
+                    stop_token_ids = [self.tokenizer.eos_token_id]
             except AttributeError as e:
                 raise ValueError(
                     f"Cannot get the vocabulary of the tokenizer "
-                    f"{type(tokenizer)}. The tokenizer should have a "
+                    f"{type(self.tokenizer)}. The tokenizer should have a "
                     "get_vocab method.") from e
             tokenizer_info = xgr.TokenizerInfo(  # type: ignore
                 encoded_vocab=encoded_vocab,
                 # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43  # noqa: E501
                 vocab_type=xgr.VocabType.RAW
-                if tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK,
+                if self.tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK,
                 vocab_size=self.vocab_size,
                 stop_token_ids=stop_token_ids,
                 add_prefix_space=True,
             )
         else:
             tokenizer_info = xgr.TokenizerInfo.from_huggingface(
-                tokenizer,
+                self.tokenizer,
                 vocab_size=self.vocab_size,
             )
         self.compiler = xgr.GrammarCompiler(
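
The diff converts XgrammarBackend from a hand-written __init__ into a @dataclass with __post_init__: the removed init_tokenizer_from_configs(...) / get_lora_tokenizer(None) block goes away, and the backend instead reads self.vllm_config, self.tokenizer, and self.vocab_size, presumably declared as fields on the StructuredOutputBackend base dataclass (the base class is not shown in this diff). A minimal sketch of that pattern, with hypothetical stand-in names:

from dataclasses import dataclass


@dataclass
class BaseBackendSketch:
    # Hypothetical stand-ins for the fields the real base class would
    # declare (vllm_config / tokenizer / vocab_size in the diff above).
    tokenizer: object
    vocab_size: int


@dataclass
class XgrammarBackendSketch(BaseBackendSketch):

    def __post_init__(self):
        # Runs after the dataclass-generated __init__ has assigned the
        # inherited fields, so the subclass can read self.tokenizer
        # directly instead of building a tokenizer group itself.
        self.is_tekken = getattr(self.tokenizer, "is_tekken", False)


backend = XgrammarBackendSketch(tokenizer="dummy-tokenizer", vocab_size=32000)
print(backend.is_tekken)  # False: the dummy tokenizer has no is_tekken attr

Since the generated __init__ assigns the declared fields in order, every backend subclass gets a uniform constructor, and the tokenizer can be built once by the caller rather than each backend constructing and pinging its own tokenizer group.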