Frontend: Adding LM Format Enforcer support to V1 engine #22564
Merged
Changes from all commits (12 commits):
| Commit | Message | Author |
| --- | --- | --- |
| 0eb426c | Frontend: Adding LM Format Enforcer support to V1 engine | noamgat |
| b16edef | Adding missing file from previous commit | noamgat |
| 75bcc07 | Linting | noamgat |
| b60421a | Linting | noamgat |
| 619802b | Merge branch 'main' into lmfe-v1-dco | noamgat |
| 9515293 | Merge branch 'main' into lmfe-v1-dco | russellb |
| 2bc72cf | Merge branch 'main' into lmfe-v1-dco | DarkLight1337 |
| 3884a24 | Merge branch 'main' into lmfe-v1-dco | russellb |
| e4b7ca6 | Merge branch 'main' into lmfe-v1-dco | russellb |
| d941eba | Merge branch 'main' into lmfe-v1-dco | russellb |
| ee6b7f1 | Merge branch 'vllm-project:main' into lmfe-v1-dco | noamgat |
| 6d4fcbf | bumping requirement | noamgat |
vllm/v1/structured_output/backend_lm_format_enforcer.py (167 additions, 0 deletions)
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

import ast
import json
from dataclasses import dataclass, field
from functools import lru_cache
from typing import TYPE_CHECKING

import torch
from transformers import PreTrainedTokenizerBase

from vllm.sampling_params import SamplingParams
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                     StructuredOutputGrammar,
                                                     StructuredOutputOptions)

if TYPE_CHECKING:
    import lmformatenforcer
    import lmformatenforcer.integrations.vllm as lmfe_vllm
else:
    lmformatenforcer = LazyLoader("lmformatenforcer", globals(),
                                  "lmformatenforcer")
    lmfe_vllm = LazyLoader("lmformatenforcer.integrations.vllm", globals(),
                           "lmformatenforcer.integrations.vllm")


@lru_cache
def _cached_build_vllm_token_enforcer_tokenizer_data(
        tokenizer: PreTrainedTokenizerBase,
        vocab_size: int) -> lmfe_vllm.TokenEnforcerTokenizerData:
    return lmfe_vllm.build_vllm_token_enforcer_tokenizer_data(
        tokenizer, use_bitmask=True, vocab_size=vocab_size)


@dataclass
class LMFormatEnforcerGrammar(StructuredOutputGrammar):
    token_enforcer: lmformatenforcer.TokenEnforcer
    current_tokens_prefix: list[int] = field(default_factory=list)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        original_len = len(self.current_tokens_prefix)
        for token in tokens:
            if not self.token_enforcer.get_allowed_tokens(
                    self.current_tokens_prefix).is_token_allowed(token):
                # Rollback partial updates to ensure atomicity.
                del self.current_tokens_prefix[original_len:]
                return False
            self.current_tokens_prefix.append(token)
        return True

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        for prefix_length in range(len(tokens)):
            prefix = tokens[:prefix_length]
            next_token = tokens[prefix_length]
            if not self.token_enforcer.get_allowed_tokens(
                    self.current_tokens_prefix +
                    prefix).is_token_allowed(next_token):
                break
        else:
            return tokens

        return tokens[:prefix_length]

    def rollback(self, num_tokens: int) -> None:
        # Guard against num_tokens == 0: `prefix[:-0]` would clear the list.
        if num_tokens > 0:
            self.current_tokens_prefix = (
                self.current_tokens_prefix[:-num_tokens])

    def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
        allowed_tokens = self.token_enforcer.get_allowed_tokens(
            self.current_tokens_prefix)
        bitmask[batch_index] = allowed_tokens.allowed_tokens

    def is_terminated(self) -> bool:
        # We are considered terminated if the prefix ends with eos_token_id
        return_value = len(
            self.current_tokens_prefix) > 0 and self.current_tokens_prefix[
                -1] == self.token_enforcer.eos_token_id
        return return_value

    def reset(self):
        self.current_tokens_prefix = []


@dataclass
class LMFormatEnforcerBackend(StructuredOutputBackend):

    def __post_init__(self):
        self.tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
            self.tokenizer, self.vocab_size)

    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        character_level_parser: lmformatenforcer.CharacterLevelParser
        if request_type == StructuredOutputOptions.JSON:
            spec_dict = json.loads(grammar_spec)
            character_level_parser = lmformatenforcer.JsonSchemaParser(
                spec_dict)
        elif request_type == StructuredOutputOptions.JSON_OBJECT:
            character_level_parser = lmformatenforcer.JsonSchemaParser(None)
        elif request_type == StructuredOutputOptions.REGEX:
            character_level_parser = lmformatenforcer.RegexParser(grammar_spec)
        elif request_type == StructuredOutputOptions.CHOICE:
            choices = ast.literal_eval(grammar_spec)
            character_level_parser = lmformatenforcer.UnionParser(
                [lmformatenforcer.StringParser(choice) for choice in choices])
        else:
            raise ValueError(
                "Invalid request type for LM Format Enforcer backend "
                f"({request_type!s})")
        max_rollback_tokens = (
            self.vllm_config.speculative_config.num_speculative_tokens
            if self.vllm_config.speculative_config is not None else 0)

        if max_rollback_tokens > 0:
            raise ValueError(
                "LM Format Enforcer backend does not support speculative tokens"
            )

        token_enforcer = lmformatenforcer.TokenEnforcer(
            tokenizer_data=self.tokenizer_data,
            parser=character_level_parser,
        )
        return LMFormatEnforcerGrammar(token_enforcer)

    def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
        return torch.full(
            (max_num_seqs, (self.vocab_size + 31) // 32),
            -1,
            dtype=torch.int32,
            pin_memory=torch.cuda.is_available(),
        )

    def destroy(self):
        pass


def validate_structured_output_request_lm_format_enforcer(
        params: SamplingParams):
    if params.guided_decoding is None:
        return

    gd_params = params.guided_decoding

    if gd_params.regex:
        return
    elif gd_params.json:
        if isinstance(gd_params.json, str):
            try:
                # make sure schema is valid json
                json.loads(gd_params.json)
            except json.JSONDecodeError as e:
                raise ValueError("Invalid JSON grammar specification.") from e
        else:
            try:
                json.dumps(gd_params.json)
            except Exception as e:
                raise ValueError(
                    f"Error serializing guided decoding jsonschema: {e}"
                ) from e
        return
    elif gd_params.choice:
        return
    elif gd_params.grammar:
        raise ValueError("LM Format Enforcer guided decoding backend "
                         "does not support grammar specifications")
```
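To make the grammar's contract concrete: `accept_tokens` is atomic (either every proposed token is appended to the prefix or none is), while `validate_tokens` returns the longest prefix of the proposed tokens that the parser allows. A small illustration with a hypothetical stub standing in for `lmformatenforcer.TokenEnforcer`, assuming `LMFormatEnforcerGrammar` is importable and its base class imposes no extra construction requirements (the stub only permits the fixed continuation 1, 2, 3):

```python
# Hypothetical stub in place of lmformatenforcer.TokenEnforcer, purely to
# exercise LMFormatEnforcerGrammar's accept/validate contract.
class _StubAllowed:
    def __init__(self, allowed: set[int]):
        self._allowed = allowed

    def is_token_allowed(self, token: int) -> bool:
        return token in self._allowed

class _StubEnforcer:
    eos_token_id = 0

    def get_allowed_tokens(self, prefix: list[int]) -> _StubAllowed:
        expected = [1, 2, 3]  # the only continuation this stub accepts
        if len(prefix) < len(expected):
            return _StubAllowed({expected[len(prefix)]})
        return _StubAllowed({self.eos_token_id})

grammar = LMFormatEnforcerGrammar(_StubEnforcer())
assert grammar.validate_tokens([1, 2, 9]) == [1, 2]  # longest valid prefix
assert not grammar.accept_tokens("req-1", [1, 9])    # atomic: nothing kept
assert grammar.current_tokens_prefix == []
assert grammar.accept_tokens("req-1", [1, 2])
assert grammar.current_tokens_prefix == [1, 2]
```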
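Finally, an end-to-end usage sketch. The exact selection mechanism may differ by vLLM version; this assumes the backend is chosen via the `guided_decoding_backend` engine option and that `GuidedDecodingParams` carries the JSON schema, with a placeholder model name:

```python
# Hedged end-to-end sketch; flag and parameter names are assumptions based on
# vLLM's guided-decoding API and may differ across versions.
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'

llm = LLM(
    model="facebook/opt-125m",  # placeholder model
    guided_decoding_backend="lm-format-enforcer",
)
params = SamplingParams(
    max_tokens=64,
    guided_decoding=GuidedDecodingParams(json=schema),
)
outputs = llm.generate("Give me a JSON object describing a person.", params)
print(outputs[0].outputs[0].text)
```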