From 26730211823ce36cd9a3f01479839c35dd648d84 Mon Sep 17 00:00:00 2001
From: Joe Runde
Date: Fri, 6 Sep 2024 15:11:46 -0600
Subject: [PATCH 01/28] :recycle: refactor guided decoding construction

Signed-off-by: Joe Runde
---
 .../guided_decoding/__init__.py          | 65 ++++------------
 .../guided_decoding/guided_fields.py     |  1 +
 .../lm_format_enforcer_decoding.py       | 74 +++----------------
 .../guided_decoding/outlines_decoding.py | 71 ++++++------------
 vllm/sampling_params.py                  | 35 +++++++++
 5 files changed, 84 insertions(+), 162 deletions(-)

diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py
index 7161e83952a3d..260363d7e24d0 100644
--- a/vllm/model_executor/guided_decoding/__init__.py
+++ b/vllm/model_executor/guided_decoding/__init__.py
@@ -1,77 +1,44 @@
-from typing import Optional, Union
+from typing import Optional

-from vllm.entrypoints.openai.protocol import (
-    ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
-    CompletionRequest)
-from vllm.model_executor.guided_decoding.guided_fields import (
-    GuidedDecodingRequest)
-from vllm.sampling_params import LogitsProcessor
+from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor


 async def get_guided_decoding_logits_processor(
-        guided_decoding_backend: str, request: Union[CompletionRequest,
-                                                      ChatCompletionRequest],
+        guided_params: GuidedDecodingParams,
         tokenizer) -> Optional[LogitsProcessor]:
-    request = _adapt_request_for_tool_use(request)
-
-    if guided_decoding_backend == 'outlines':
+    if guided_params.backend == 'outlines':
         # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
         from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
             get_outlines_guided_decoding_logits_processor)
         return await get_outlines_guided_decoding_logits_processor(
-            request, tokenizer)
-    if guided_decoding_backend == 'lm-format-enforcer':
+            guided_params, tokenizer)
+    if guided_params.backend == 'lm-format-enforcer':
         from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
-            get_lm_format_enforcer_guided_decoding_logits_processor)
-        return await get_lm_format_enforcer_guided_decoding_logits_processor(
-            request, tokenizer)
+            get_local_lm_format_enforcer_guided_decoding_logits_processor)
+        return get_local_lm_format_enforcer_guided_decoding_logits_processor(
+            guided_params, tokenizer)

     raise ValueError(
-        f"Unknown guided decoding backend '{guided_decoding_backend}'. "
+        f"Unknown guided decoding backend '{guided_params.backend}'. "
" "Must be one of 'outlines, 'lm-format-enforcer'") def get_local_guided_decoding_logits_processor( - guided_decoding_backend: str, guided_options: GuidedDecodingRequest, + guided_params: GuidedDecodingParams, tokenizer) -> Optional[LogitsProcessor]: - # request = _adapt_request_for_tool_use(request) - - if guided_decoding_backend == 'outlines': + if guided_params.backend == 'outlines': # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_local_outlines_guided_decoding_logits_processor) return get_local_outlines_guided_decoding_logits_processor( - guided_options, tokenizer) - if guided_decoding_backend == 'lm-format-enforcer': + guided_params, tokenizer) + if guided_params.backend == 'lm-format-enforcer': from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( - guided_options, tokenizer) + guided_params, tokenizer) raise ValueError( - f"Unknown guided decoding backend '{guided_decoding_backend}'. " + f"Unknown guided decoding backend '{guided_params.backend}'. " "Must be one of 'outlines, 'lm-format-enforcer'") - -def _adapt_request_for_tool_use(request: Union[CompletionRequest, - ChatCompletionRequest]): - # the legacy completion API does not support tool use - if type(request) is CompletionRequest: - return request - - # user has chosen to not use any tool, - # OR is allowing the model to choose a tool. - if request.tool_choice == "none" or request.tool_choice == "auto": - return request - - # user has chosen to use a named tool - if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam: - tool_name = request.tool_choice.function.name - tools = {tool.function.name: tool.function for tool in request.tools} - if tool_name not in tools: - raise ValueError( - f"Tool '{tool_name}' has not been passed in `tools`.") - tool = tools[tool_name] - request.guided_json = tool.parameters - - return request diff --git a/vllm/model_executor/guided_decoding/guided_fields.py b/vllm/model_executor/guided_decoding/guided_fields.py index 3082ac1510ccc..7ca4c21124c97 100644 --- a/vllm/model_executor/guided_decoding/guided_fields.py +++ b/vllm/model_executor/guided_decoding/guided_fields.py @@ -3,6 +3,7 @@ from pydantic import BaseModel +# Nick leans towards ripping out so we don't have duplication class LLMGuidedOptions(TypedDict, total=False): guided_json: Union[Dict, BaseModel, str] diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py index 51f947981cac8..53a9b32006d36 100644 --- a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py @@ -10,63 +10,11 @@ from pydantic import BaseModel from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - CompletionRequest) -from vllm.model_executor.guided_decoding.guided_fields import ( - GuidedDecodingRequest) -from vllm.sampling_params import LogitsProcessor - - -async def get_lm_format_enforcer_guided_decoding_logits_processor( - request: Union[CompletionRequest, ChatCompletionRequest], - tokenizer) -> Optional[LogitsProcessor]: - """ - Given an OpenAI-compatible request, check for guided decoding parameters - and get the necessary logits 
processor for the given guide. - We cache logit processors by (guide, tokenizer), and on cache hit - we make a shallow copy to reuse the same underlying FSM. - """ - - tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( - tokenizer) - character_level_parser: CharacterLevelParser - if request.guided_json: - schema = _normalize_json_schema_object(request.guided_json) - character_level_parser = JsonSchemaParser(schema) - elif request.guided_choice: - character_level_parser = UnionParser( - [StringParser(choice) for choice in request.guided_choice]) - elif request.guided_regex: - character_level_parser = RegexParser(request.guided_regex) - elif request.guided_grammar: - # CFG grammar not supported by LMFE, revert to outlines - - # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 - from vllm.model_executor.guided_decoding.outlines_decoding import ( - get_outlines_guided_decoding_logits_processor) - return await get_outlines_guided_decoding_logits_processor( - request, tokenizer) - elif (request.response_format is not None - and request.response_format.type == "json_object"): - character_level_parser = JsonSchemaParser( - None) # None means any json object - elif (request.response_format is not None - and request.response_format.type == "json_schema" - and request.response_format.json_schema is not None - and request.response_format.json_schema.json_schema is not None): - schema = _normalize_json_schema_object( - request.response_format.json_schema.json_schema) - character_level_parser = JsonSchemaParser(schema) - else: - return None - - logits_processor = build_vllm_logits_processor(tokenizer_data, - character_level_parser) - return logits_processor +from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor def get_local_lm_format_enforcer_guided_decoding_logits_processor( - guided_options: GuidedDecodingRequest, + guided_params: GuidedDecodingParams, tokenizer) -> Optional[LogitsProcessor]: """ Given an OpenAI-compatible request, check for guided decoding parameters @@ -78,23 +26,23 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor( tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( tokenizer) character_level_parser: CharacterLevelParser - if guided_options.guided_json: - schema = _normalize_json_schema_object(guided_options.guided_json) + if guided_params.json: + schema = _normalize_json_schema_object(guided_params.json) character_level_parser = JsonSchemaParser(schema) - elif guided_options.guided_choice: + elif guided_params.choice: character_level_parser = UnionParser( - [StringParser(choice) for choice in guided_options.guided_choice]) - elif guided_options.guided_regex: - character_level_parser = RegexParser(guided_options.guided_regex) - elif guided_options.guided_grammar: + [StringParser(choice) for choice in guided_params.choice]) + elif guided_params.regex: + character_level_parser = RegexParser(guided_params.regex) + elif guided_params.grammar: # CFG grammar not supported by LMFE, revert to outlines # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( get_local_outlines_guided_decoding_logits_processor) return get_local_outlines_guided_decoding_logits_processor( - guided_options, tokenizer) - elif guided_options.guided_json_object: + guided_params, tokenizer) + elif guided_params.json_object: # None means any json object character_level_parser = JsonSchemaParser(None) else: diff --git 
a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index e1f5b380120c5..2506aa5b76018 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -8,13 +8,9 @@ from pydantic import BaseModel from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.openai.protocol import ( - ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, - CompletionRequest) -from vllm.model_executor.guided_decoding.guided_fields import ( - GuidedDecodingRequest) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) +from vllm.sampling_params import GuidedDecodingParams class GuidedDecodingMode(Enum): @@ -55,8 +51,7 @@ class GuidedDecodingMode(Enum): async def get_outlines_guided_decoding_logits_processor( - request: Union[CompletionRequest, - ChatCompletionRequest], tokenizer: PreTrainedTokenizerBase + guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, None]: """ @@ -66,7 +61,7 @@ async def get_outlines_guided_decoding_logits_processor( we make a shallow copy to reuse the same underlying FSM. """ global global_thread_pool - guide, mode = _get_guide_and_mode(request) + guide, mode = _get_guide_and_mode(guided_params) if not guide or not mode: return None @@ -77,11 +72,11 @@ async def get_outlines_guided_decoding_logits_processor( return await loop.run_in_executor(global_thread_pool, _get_logits_processor, guide, tokenizer, - mode, request.guided_whitespace_pattern) + mode, guided_params.whitespace_pattern) def get_local_outlines_guided_decoding_logits_processor( - guided_options: GuidedDecodingRequest, tokenizer: PreTrainedTokenizerBase + guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, None]: """ @@ -90,65 +85,41 @@ def get_local_outlines_guided_decoding_logits_processor( We cache logit processors by (guide, tokenizer), and on cache hit we make a shallow copy to reuse the same underlying FSM. 
""" - guide, mode = _get_guide_and_mode(guided_options) + guide, mode = _get_guide_and_mode(guided_params) if not guide or not mode: return None return _get_logits_processor(guide, tokenizer, mode, - guided_options.guided_whitespace_pattern) + guided_params.whitespace_pattern) def _get_guide_and_mode( - request: Union[CompletionRequest, ChatCompletionRequest, - GuidedDecodingRequest] + guided_params: GuidedDecodingParams ) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]: - # if the request is a chat completion request, AND the tool choice is a - # named tool choice, do guided decoding - # using that tool as the JSON schema - if isinstance(request, ChatCompletionRequest) and isinstance( - request.tool_choice, ChatCompletionNamedToolChoiceParam): - # Guided generation for tools/functions parameters - if request.tool_choice.type == "function": - for tool in request.tools: - if (tool.type == "function" and tool.function.name - == request.tool_choice.function.name): - json = json_dumps(tool.function.parameters, sort_keys=True) - return json, GuidedDecodingMode.JSON - return None, None - - elif request.guided_json: - if isinstance(request.guided_json, dict): + if guided_params.json: + if isinstance(guided_params.json, dict): # turn dict into hashable string - json = json_dumps(request.guided_json) - elif isinstance(request.guided_json, BaseModel): + json = json_dumps(guided_params.json) + elif isinstance(guided_params.json, BaseModel): # use pydantic signature so that different model classes # with the same fields will get hashed the same - json = str(request.guided_json.__signature__) + json = str(guided_params.json.__signature__) else: - json = request.guided_json + json = guided_params.json return json, GuidedDecodingMode.JSON - elif request.guided_regex: - return request.guided_regex, GuidedDecodingMode.REGEX - elif request.guided_choice: + elif guided_params.regex: + return guided_params.regex, GuidedDecodingMode.REGEX + elif guided_params.choice: # choice just uses regex choices = [ - regex_escape(str(choice)) for choice in request.guided_choice + regex_escape(str(choice)) for choice in guided_params.choice ] choices_regex = "(" + "|".join(choices) + ")" return choices_regex, GuidedDecodingMode.CHOICE - elif request.guided_grammar: - return request.guided_grammar, GuidedDecodingMode.GRAMMAR - elif (not isinstance(request, GuidedDecodingRequest) - and request.response_format is not None - and request.response_format.type == "json_object"): + elif guided_params.grammar: + return guided_params.grammar, GuidedDecodingMode.GRAMMAR + elif guided_params.json_object: return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR - elif (not isinstance(request, GuidedDecodingRequest) - and request.response_format is not None - and request.response_format.type == "json_schema" - and request.response_format.json_schema is not None - and request.response_format.json_schema.json_schema is not None): - json = json_dumps(request.response_format.json_schema.json_schema) - return json, GuidedDecodingMode.JSON else: return None, None diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index c83ed5cca6791..91a81ae068cee 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -1,5 +1,6 @@ """Sampling parameters for text generation.""" import copy +from dataclasses import dataclass from enum import IntEnum from functools import cached_property from typing import Any, Callable, Dict, List, Optional, Set, Union @@ -33,6 +34,34 @@ class SamplingType(IntEnum): to sample from.""" +# maybe make 
msgspec? +@dataclass +class GuidedDecodingParams: + """One of these fields will be used to build a logit processor.""" + json: Optional[Union[Dict, BaseModel, str]] = None + regex: Optional[str] = None + choice: Optional[List[str]] = None + grammar: Optional[str] = None + json_object: Optional[bool] = None + + """These are other options that can be set""" + backend: Optional[str] = None + whitespace_pattern: Optional[str] = None + + def __post_init__(self): + """Validate that some fields are mutually exclusive.""" + guide_count = sum([ + self.json is not None, self.regex is not None, + self.choice is not None, self.grammar is not None, + self.json_object is not None + ]) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding but multiple are " + f"specified: {self.__dict__}") + + + class SamplingParams( msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] @@ -153,6 +182,12 @@ class SamplingParams( output_text_buffer_length: int = 0 _all_stop_token_ids: Set[int] = msgspec.field(default_factory=set) + # Fields used to construct logits processors + guided_decoding: Optional[GuidedDecodingParams] = None + logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]] = None + allowed_token_ids: Optional[List[int]] = None + + @staticmethod def from_optional( n: Optional[int] = 1, From 382a2409afcd142a1d0f225129ade63969809cf6 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Fri, 6 Sep 2024 15:35:44 -0600 Subject: [PATCH 02/28] :recycle: refactor to_sampling_params Signed-off-by: Joe Runde --- vllm/entrypoints/openai/protocol.py | 36 ++++++++--------------------- vllm/sampling_params.py | 6 +++++ 2 files changed, 16 insertions(+), 26 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 970262a4bd358..92b99de27af2c 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -10,11 +10,9 @@ from typing_extensions import Annotated, Required, TypedDict from vllm.entrypoints.chat_utils import ChatCompletionMessageParam -from vllm.entrypoints.openai.logits_processors import get_logits_processors from vllm.pooling_params import PoolingParams -from vllm.sampling_params import LogitsProcessor, SamplingParams +from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.sequence import Logprob -from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid # torch is mocked during docs generation, @@ -270,8 +268,8 @@ class ChatCompletionRequest(OpenAIBaseModel): # doc: end-chat-completion-extra-params def to_sampling_params( - self, tokenizer: AnyTokenizer, - guided_decode_logits_processor: Optional[LogitsProcessor], + self, + guided_decoding: Optional[GuidedDecodingParams], default_max_tokens: int) -> SamplingParams: max_tokens = self.max_tokens if max_tokens is None: @@ -281,15 +279,6 @@ def to_sampling_params( if prompt_logprobs is None and self.echo: prompt_logprobs = self.top_logprobs - # We now allow logprobs being true without top_logrobs. 
- logits_processors = get_logits_processors( - logit_bias=self.logit_bias, - allowed_token_ids=None, - tokenizer=tokenizer, - ) - if guided_decode_logits_processor: - logits_processors.append(guided_decode_logits_processor) - return SamplingParams.from_optional( n=self.n, best_of=self.best_of, @@ -314,8 +303,9 @@ def to_sampling_params( spaces_between_special_tokens=self.spaces_between_special_tokens, include_stop_str_in_output=self.include_stop_str_in_output, length_penalty=self.length_penalty, - logits_processors=logits_processors, truncate_prompt_tokens=self.truncate_prompt_tokens, + guided_decoding=guided_decoding, + logit_bias=self.logit_bias ) @model_validator(mode="before") @@ -512,8 +502,8 @@ class CompletionRequest(OpenAIBaseModel): # doc: end-completion-extra-params def to_sampling_params( - self, tokenizer: AnyTokenizer, - guided_decode_logits_processor: Optional[LogitsProcessor], + self, + guided_decoding: Optional[GuidedDecodingParams], default_max_tokens: int) -> SamplingParams: max_tokens = self.max_tokens if max_tokens is None: @@ -525,14 +515,6 @@ def to_sampling_params( echo_without_generation = self.echo and self.max_tokens == 0 - logits_processors = get_logits_processors( - logit_bias=self.logit_bias, - allowed_token_ids=self.allowed_token_ids, - tokenizer=tokenizer, - ) - if guided_decode_logits_processor: - logits_processors.append(guided_decode_logits_processor) - return SamplingParams.from_optional( n=self.n, best_of=self.best_of, @@ -557,8 +539,10 @@ def to_sampling_params( spaces_between_special_tokens=self.spaces_between_special_tokens, include_stop_str_in_output=self.include_stop_str_in_output, length_penalty=self.length_penalty, - logits_processors=logits_processors, truncate_prompt_tokens=self.truncate_prompt_tokens, + guided_decoding=guided_decoding, + logit_bias=self.logit_bias, + allowed_token_ids=self.allowed_token_ids ) @model_validator(mode="before") diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 91a81ae068cee..eb6557cb18884 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -217,6 +217,9 @@ def from_optional( logits_processors: Optional[List[LogitsProcessor]] = None, truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None, + guided_decoding: Optional[GuidedDecodingParams] = None, + logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]] = None, + allowed_token_ids: Optional[List[int]] = None, ) -> "SamplingParams": return SamplingParams( n=1 if n is None else n, @@ -248,6 +251,9 @@ def from_optional( spaces_between_special_tokens=spaces_between_special_tokens, logits_processors=logits_processors, truncate_prompt_tokens=truncate_prompt_tokens, + guided_decoding=guided_decoding, + logit_bias=logit_bias, + allowed_token_ids=allowed_token_ids, ) def __post_init__(self) -> None: From ff6c147419f56a362d375aa6d65ac536167bd6b2 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Fri, 6 Sep 2024 16:31:55 -0600 Subject: [PATCH 03/28] :recycle: update sampling params in openai servers Signed-off-by: Joe Runde --- vllm/entrypoints/openai/serving_chat.py | 36 ++++++++++++++++--- vllm/entrypoints/openai/serving_completion.py | 8 ++--- vllm/entrypoints/openai/serving_engine.py | 34 +++++++++++------- 3 files changed, 57 insertions(+), 21 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 78f355228012f..b52e45be937d6 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -156,9 +156,13 @@ async def 
create_chat_completion( request_id = f"chat-{random_uuid()}" try: - guided_decode_logits_processor = ( - await self._guided_decode_logits_processor(request, tokenizer)) - + guided_decoding_params = \ + self._create_guided_decoding_params(request) + # Some requests for tools will use guided decoding + if (guided_json := + self._get_guided_json_from_tool(request)) is not None: + guided_decoding_params.guided_json = guided_json + if isinstance(prompt, str): prompt_inputs = self._tokenize_prompt_input( request, @@ -177,8 +181,7 @@ async def create_chat_completion( assert prompt_inputs is not None sampling_params = request.to_sampling_params( - tokenizer, - guided_decode_logits_processor, + guided_decoding_params, default_max_tokens=self.max_model_len - len(prompt_inputs["prompt_token_ids"])) @@ -779,3 +782,26 @@ def _should_check_for_unstreamed_tool_arg_tokens( and delta_message.tool_calls[0].function.arguments is not None and output.finish_reason is not None ) + + @staticmethod + def _get_guided_json_from_tool( + request: ChatCompletionRequest + ) -> Optional[Union[str, dict, BaseModel]]: + # user has chosen to not use any tool + if request.tool_choice == "none" or request.tools is None: + return None + + # user has chosen to use a named tool + if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam: + tool_name = request.tool_choice.function.name + tools = { + tool.function.name: tool.function + for tool in request.tools + } + if tool_name not in tools: + raise ValueError( + f"Tool '{tool_name}' has not been passed in `tools`.") + tool = tools[tool_name] + return tool.parameters + + return None diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 34f1200753f8d..b94f217169676 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -98,8 +98,9 @@ async def create_completion( tokenizer = await self.async_engine_client.get_tokenizer( lora_request) - guided_decode_logits_processor = ( - await self._guided_decode_logits_processor(request, tokenizer)) + guided_decoding_params = self._create_guided_decoding_params( + request + ) prompts = list( self._tokenize_prompt_input_or_inputs( request, @@ -111,8 +112,7 @@ async def create_completion( for i, prompt_inputs in enumerate(prompts): sampling_params = request.to_sampling_params( - tokenizer, - guided_decode_logits_processor, + guided_decoding_params, default_max_tokens=self.max_model_len - len(prompt_inputs["prompt_token_ids"])) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index ac74527441cd9..60d7b339a8793 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -27,11 +27,9 @@ from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.model_executor.guided_decoding import ( - get_guided_decoding_logits_processor) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import LogitsProcessor, SamplingParams +from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import AtomicCounter @@ -156,15 +154,6 @@ def create_streaming_error_response( }) return json_str - async def _guided_decode_logits_processor( - self, request: 
Union[ChatCompletionRequest, CompletionRequest], - tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]: - decoding_config = await self.async_engine_client.get_decoding_config() - guided_decoding_backend = request.guided_decoding_backend \ - or decoding_config.guided_decoding_backend - return await get_guided_decoding_logits_processor( - guided_decoding_backend, request, tokenizer) - async def _check_model( self, request: AnyRequest, @@ -480,3 +469,24 @@ async def unload_lora_adapter( if lora_request.lora_name != lora_name ] return f"Success: LoRA adapter '{lora_name}' removed successfully." + + @staticmethod + def _create_guided_decoding_params( + api_request: Union[CompletionRequest, ChatCompletionRequest] + ) -> GuidedDecodingParams: + """Extract all of the guided decoding parameters from a frontend api + request""" + guided_json_object = None + if (api_request.response_format is not None + and api_request.response_format.type == "json_object"): + guided_json_object = True + + return GuidedDecodingParams( + json=api_request.guided_json, + choice=api_request.guided_choice, + backend=api_request.guided_decoding_backend, + grammar=api_request.guided_grammar, + regex=api_request.guided_regex, + whitespace_pattern=api_request.guided_whitespace_pattern, + json_object=guided_json_object) + From ab0fea044d9afe4e61d3069c5ed9ab9634177a86 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Mon, 9 Sep 2024 16:55:30 -0600 Subject: [PATCH 04/28] :zap: build LPs in engine / client Signed-off-by: Joe Runde --- vllm/engine/async_llm_engine.py | 43 +++++++++++++++ vllm/engine/llm_engine.py | 53 +++++++++++++++++++ vllm/entrypoints/llm.py | 5 +- vllm/entrypoints/openai/protocol.py | 20 +++---- vllm/entrypoints/openai/rpc/client.py | 39 ++++++++++++++ vllm/entrypoints/openai/serving_chat.py | 8 +-- vllm/entrypoints/openai/serving_completion.py | 3 +- vllm/entrypoints/openai/serving_engine.py | 3 +- .../guided_decoding/__init__.py | 1 - .../guided_decoding/guided_fields.py | 1 + .../guided_decoding/outlines_decoding.py | 2 +- vllm/sampling_params.py | 9 ++-- 12 files changed, 158 insertions(+), 29 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7fe8053fffb7b..1a9e26d2ad1a6 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -22,6 +22,8 @@ from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.guided_decoding import ( + get_guided_decoding_logits_processor) from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams @@ -561,6 +563,16 @@ async def add_request_async( prompt_adapter_request=prompt_adapter_request, ) + if isinstance(params, SamplingParams): + # Guided decoding has an async implementation for building logits + # processors in a separate threadpool. 
+ # We want to invoke that here instead of using the blocking + # implementation in the LLMEngine + params = await self._build_guided_decoding_logits_processor_async( + params, + lora_request, + ) + self._add_processed_request( request_id=request_id, processed_inputs=processed_inputs, @@ -576,6 +588,37 @@ async def check_health_async(self) -> None: self.tokenizer.check_health() self.model_executor.check_health() + async def _build_guided_decoding_logits_processor_async( + self, sampling_params: SamplingParams, + lora_request: Optional[LoRARequest]) -> SamplingParams: + """Constructs logits processors based on the guided_decoding, + logits_bias, and allowed_token_ids fields in sampling_params. Deletes + those fields and adds the constructed logits processors to the + logits_processors field. Returns the modified sampling params.""" + if (guided_decoding := sampling_params.guided_decoding) is None: + return sampling_params + + logger.debug( + "Building guided decoding logits processor in " + "AsyncLLMEngine. Params: %v", guided_decoding) + + tokenizer = self.get_tokenizer(lora_request=lora_request) + guided_decoding.backend = guided_decoding.backend or \ + self.decoding_config.guided_decoding_backend + + processor = await get_guided_decoding_logits_processor( + guided_params=guided_decoding, tokenizer=tokenizer) + + if processor: + if sampling_params.logits_processors is None: + sampling_params.logits_processors = [] + sampling_params.logits_processors.append(processor) + + # Unset guided decoding params after constructing the lp from them + sampling_params.guided_decoding = None + + return sampling_params + class AsyncLLMEngine: """An asynchronous wrapper for :class:`LLMEngine`. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 50dcb6937eb6f..cf27015ca21d5 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -25,6 +25,7 @@ SequenceGroupOutputProcessor) from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.util import create_output_by_sequence_group +from vllm.entrypoints.openai.logits_processors import get_logits_processors from vllm.executor.executor_base import ExecutorBase from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, @@ -33,6 +34,8 @@ from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.guided_decoding import ( + get_local_guided_decoding_logits_processor) from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MultiModalDataDict from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, @@ -1133,6 +1136,9 @@ def _create_sequence_group_with_sampling( raise ValueError(f"Cannot request more than " f"{max_logprobs} logprobs.") + sampling_params = self._build_logits_processors( + sampling_params, lora_request) + # Defensive copy of SamplingParams, which are used by the sampler, # this doesn't deep-copy LogitsProcessor objects sampling_params = sampling_params.clone() @@ -2016,3 +2022,50 @@ def _validate_model_inputs(self, inputs: Union[LLMInputs, # TODO: Find out how many placeholder tokens are there so we can # check that chunked prefill does not truncate them # max_batch_len = self.scheduler_config.max_num_batched_tokens + + def _build_logits_processors( + self, sampling_params: SamplingParams, + lora_request: Optional[LoRARequest]) -> SamplingParams: + """Constructs 
logits processors based on the guided_decoding, + logits_bias, and allowed_token_ids fields in sampling_params. Deletes + those fields and adds the constructed logits processors to the + logits_processors field. Returns the modified sampling params.""" + + logits_processors = [] + if (guided_decoding := sampling_params.guided_decoding) is not None: + + logger.debug( + "Building guided decoding logits processor in " + "LLMEngine. Params: %v", guided_decoding) + + tokenizer = self.get_tokenizer(lora_request=lora_request) + guided_decoding.backend = guided_decoding.backend or \ + self.decoding_config.guided_decoding_backend + + processor = get_local_guided_decoding_logits_processor( + guided_params=guided_decoding, tokenizer=tokenizer) + if processor: + logits_processors.append(processor) + + # Unset so this doesn't get passed down to the model + sampling_params.guided_decoding = None + + if (sampling_params.logit_bias or sampling_params.allowed_token_ids): + tokenizer = self.get_tokenizer(lora_request=lora_request) + + processors = get_logits_processors( + logit_bias=sampling_params.logit_bias, + allowed_token_ids=sampling_params.allowed_token_ids, + tokenizer=tokenizer) + logits_processors.extend(processors) + + # Unset so these don't get passed down to the model + sampling_params.logit_bias = None + sampling_params.allowed_token_ids = None + + if len(logits_processors) > 0: + if sampling_params.logits_processors is None: + sampling_params.logits_processors = [] + sampling_params.logits_processors.extend(logits_processors) + + return sampling_params diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b32c90a4df1aa..3758660cc88d6 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -13,8 +13,9 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.guided_decoding import ( - GuidedDecodingRequest, get_local_guided_decoding_logits_processor) -from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions + get_local_guided_decoding_logits_processor) +from vllm.model_executor.guided_decoding.guided_fields import ( + GuidedDecodingRequest, LLMGuidedOptions) from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 92b99de27af2c..d9488a159540f 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -267,10 +267,9 @@ class ChatCompletionRequest(OpenAIBaseModel): # doc: end-chat-completion-extra-params - def to_sampling_params( - self, - guided_decoding: Optional[GuidedDecodingParams], - default_max_tokens: int) -> SamplingParams: + def to_sampling_params(self, + guided_decoding: Optional[GuidedDecodingParams], + default_max_tokens: int) -> SamplingParams: max_tokens = self.max_tokens if max_tokens is None: max_tokens = default_max_tokens @@ -305,8 +304,7 @@ def to_sampling_params( length_penalty=self.length_penalty, truncate_prompt_tokens=self.truncate_prompt_tokens, guided_decoding=guided_decoding, - logit_bias=self.logit_bias - ) + logit_bias=self.logit_bias) @model_validator(mode="before") @classmethod @@ -501,10 +499,9 @@ class CompletionRequest(OpenAIBaseModel): # doc: end-completion-extra-params - def to_sampling_params( - self, - guided_decoding: Optional[GuidedDecodingParams], - default_max_tokens: int) -> SamplingParams: + def to_sampling_params(self, 
+ guided_decoding: Optional[GuidedDecodingParams], + default_max_tokens: int) -> SamplingParams: max_tokens = self.max_tokens if max_tokens is None: max_tokens = default_max_tokens @@ -542,8 +539,7 @@ def to_sampling_params( truncate_prompt_tokens=self.truncate_prompt_tokens, guided_decoding=guided_decoding, logit_bias=self.logit_bias, - allowed_token_ids=self.allowed_token_ids - ) + allowed_token_ids=self.allowed_token_ids) @model_validator(mode="before") @classmethod diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 9b88db746be5c..e62c9b8dffd3c 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -23,6 +23,8 @@ from vllm.inputs import PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.guided_decoding import ( + get_guided_decoding_logits_processor) from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams @@ -382,6 +384,13 @@ async def generate( ) -> AsyncGenerator[RequestOutput, None]: """Send an RPCGenerateRequest to the RPCServer and stream responses.""" + # Constructing guided decoding logits processors is expensive, so we do + # it here to avoid contending with cpu resources and the GIL on the + # backend process. + sampling_params = await \ + self._build_guided_decoding_logits_processor_async( + sampling_params, lora_request) + finished = False try: with self.to_proxy_socket() as socket: @@ -449,3 +458,33 @@ async def stop_profile(self) -> None: await self._send_one_way_rpc_request( request=RPCUtilityRequest.STOP_PROFILE, error_message="RPCRequest STOP_PROFILE failed.") + + async def _build_guided_decoding_logits_processor_async( + self, sampling_params: SamplingParams, + lora_request: Optional[LoRARequest]) -> SamplingParams: + """Constructs a logits processor based on any provided guided + decoding parameters. Updates logits_processors, deletes the guided + decoding params, and returns the modified sampling params""" + if (guided_decoding := sampling_params.guided_decoding) is None: + return sampling_params + + logger.debug( + "Building guided decoding logits processor in " + "AsyncEngineRPCClient. 
Params: %v", guided_decoding) + + tokenizer = await self.get_tokenizer(lora_request=lora_request) + guided_decoding.backend = guided_decoding.backend or \ + self.decoding_config.guided_decoding_backend + + processor = await get_guided_decoding_logits_processor( + guided_params=guided_decoding, tokenizer=tokenizer) + + if processor: + if sampling_params.logits_processors is None: + sampling_params.logits_processors = [] + sampling_params.logits_processors.append(processor) + + # Unset guided decoding params after constructing the lp from them + sampling_params.guided_decoding = None + + return sampling_params diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index b52e45be937d6..76de155a1a25a 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -7,6 +7,7 @@ from typing import Union from fastapi import Request +from pydantic import BaseModel from vllm.config import ModelConfig from vllm.engine.protocol import AsyncEngineClient @@ -159,10 +160,11 @@ async def create_chat_completion( guided_decoding_params = \ self._create_guided_decoding_params(request) # Some requests for tools will use guided decoding + # TODO: validate against any conflicts here? if (guided_json := self._get_guided_json_from_tool(request)) is not None: - guided_decoding_params.guided_json = guided_json - + guided_decoding_params.json = guided_json + if isinstance(prompt, str): prompt_inputs = self._tokenize_prompt_input( request, @@ -782,7 +784,7 @@ def _should_check_for_unstreamed_tool_arg_tokens( and delta_message.tool_calls[0].function.arguments is not None and output.finish_reason is not None ) - + @staticmethod def _get_guided_json_from_tool( request: ChatCompletionRequest diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index b94f217169676..ce6b430973c8a 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -99,8 +99,7 @@ async def create_completion( lora_request) guided_decoding_params = self._create_guided_decoding_params( - request - ) + request) prompts = list( self._tokenize_prompt_input_or_inputs( request, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 60d7b339a8793..a357a626618aa 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -469,7 +469,7 @@ async def unload_lora_adapter( if lora_request.lora_name != lora_name ] return f"Success: LoRA adapter '{lora_name}' removed successfully." - + @staticmethod def _create_guided_decoding_params( api_request: Union[CompletionRequest, ChatCompletionRequest] @@ -489,4 +489,3 @@ def _create_guided_decoding_params( regex=api_request.guided_regex, whitespace_pattern=api_request.guided_whitespace_pattern, json_object=guided_json_object) - diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 260363d7e24d0..763bd594c3044 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -41,4 +41,3 @@ def get_local_guided_decoding_logits_processor( raise ValueError( f"Unknown guided decoding backend '{guided_params.backend}'. 
" "Must be one of 'outlines, 'lm-format-enforcer'") - diff --git a/vllm/model_executor/guided_decoding/guided_fields.py b/vllm/model_executor/guided_decoding/guided_fields.py index 7ca4c21124c97..7676c0a29b1f8 100644 --- a/vllm/model_executor/guided_decoding/guided_fields.py +++ b/vllm/model_executor/guided_decoding/guided_fields.py @@ -5,6 +5,7 @@ # Nick leans towards ripping out so we don't have duplication + class LLMGuidedOptions(TypedDict, total=False): guided_json: Union[Dict, BaseModel, str] guided_regex: str diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index 2506aa5b76018..bce0bad0a0d97 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -118,7 +118,7 @@ def _get_guide_and_mode( return choices_regex, GuidedDecodingMode.CHOICE elif guided_params.grammar: return guided_params.grammar, GuidedDecodingMode.GRAMMAR - elif guided_params.json_object: + elif guided_params.json_object: return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR else: return None, None diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index eb6557cb18884..40bbfe18173f6 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -7,6 +7,7 @@ import msgspec import torch +from pydantic import BaseModel from typing_extensions import Annotated from vllm.logger import init_logger @@ -43,7 +44,6 @@ class GuidedDecodingParams: choice: Optional[List[str]] = None grammar: Optional[str] = None json_object: Optional[bool] = None - """These are other options that can be set""" backend: Optional[str] = None whitespace_pattern: Optional[str] = None @@ -51,9 +51,8 @@ class GuidedDecodingParams: def __post_init__(self): """Validate that some fields are mutually exclusive.""" guide_count = sum([ - self.json is not None, self.regex is not None, - self.choice is not None, self.grammar is not None, - self.json_object is not None + self.json is not None, self.regex is not None, self.choice + is not None, self.grammar is not None, self.json_object is not None ]) if guide_count > 1: raise ValueError( @@ -61,7 +60,6 @@ def __post_init__(self): f"specified: {self.__dict__}") - class SamplingParams( msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] @@ -187,7 +185,6 @@ class SamplingParams( logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]] = None allowed_token_ids: Optional[List[int]] = None - @staticmethod def from_optional( n: Optional[int] = 1, From 41a18d554eb0e90b08ef9a1043c2f7af64922a1e Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Mon, 9 Sep 2024 16:58:06 -0600 Subject: [PATCH 05/28] :recycle: move test file Signed-off-by: Joe Runde --- .../openai => model_executor}/test_guided_processors.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{entrypoints/openai => model_executor}/test_guided_processors.py (100%) diff --git a/tests/entrypoints/openai/test_guided_processors.py b/tests/model_executor/test_guided_processors.py similarity index 100% rename from tests/entrypoints/openai/test_guided_processors.py rename to tests/model_executor/test_guided_processors.py From 40260eb76dfff2e58f09d08d6bb9cd7bd3630e3f Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Mon, 9 Sep 2024 17:03:39 -0600 Subject: [PATCH 06/28] :white_check_mark: fixup tests Signed-off-by: Joe Runde --- .../model_executor/test_guided_processors.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git 
a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index 85cb4d52200c3..d4ef031e960ac 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -1,14 +1,15 @@ -# This unit test should be moved to a new -# tests/test_guided_decoding directory. import pytest import torch from transformers import AutoTokenizer -from vllm.entrypoints.openai.protocol import CompletionRequest from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( JSONLogitsProcessor, RegexLogitsProcessor) +from vllm.sampling_params import GuidedDecodingParams + +# Yoinking fixtures from entrypoints tests, could duplicate them here +from tests.entrypoints.conftest import sample_regex, sample_json_schema def test_guided_logits_processors(sample_regex, sample_json_schema): @@ -44,11 +45,9 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex, tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') token_ids = tokenizer.encode( f"Give an example IPv4 address with this regex: {sample_regex}") - regex_request = CompletionRequest(model='test', - prompt=token_ids, - guided_regex=sample_regex) + regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend) regex_lp = await get_guided_decoding_logits_processor( - backend, regex_request, tokenizer) + regex_request, tokenizer) assert regex_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) @@ -59,11 +58,10 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex, token_ids = tokenizer.encode( f"Give an employee profile that fits this schema: {sample_json_schema}" ) - json_request = CompletionRequest(model='test', - prompt=token_ids, - guided_json=sample_json_schema) + json_request = GuidedDecodingParams(json=sample_json_schema, + backend=backend) json_lp = await get_guided_decoding_logits_processor( - backend, json_request, tokenizer) + json_request, tokenizer) assert json_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) From 9d3b185d8630c0fbb813c30c1ab57eb9a64fa77e Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 10 Sep 2024 11:55:44 -0600 Subject: [PATCH 07/28] :wastebasket: deprecate GuidedDecodingRequest usage Signed-off-by: Joe Runde --- .../model_executor/test_guided_processors.py | 2 +- vllm/entrypoints/llm.py | 45 ++++++++++--------- vllm/entrypoints/openai/rpc/client.py | 2 +- 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index d4ef031e960ac..4338ec5d8dc51 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -58,7 +58,7 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex, token_ids = tokenizer.encode( f"Give an employee profile that fits this schema: {sample_json_schema}" ) - json_request = GuidedDecodingParams(json=sample_json_schema, + json_request = GuidedDecodingParams(json=sample_json_schema, backend=backend) json_lp = await get_guided_decoding_logits_processor( json_request, tokenizer) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 3758660cc88d6..5e6f147450897 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,3 +1,4 @@ +import warnings from contextlib import contextmanager from 
typing import ClassVar, List, Optional, Sequence, Union, cast, overload @@ -12,14 +13,12 @@ from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.model_executor.guided_decoding import ( - get_local_guided_decoding_logits_processor) from vllm.model_executor.guided_decoding.guided_fields import ( GuidedDecodingRequest, LLMGuidedOptions) from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.transformers_utils.tokenizer import (AnyTokenizer, get_cached_tokenizer) from vllm.transformers_utils.tokenizer_group import TokenizerGroup @@ -614,6 +613,14 @@ def _validate_and_add_requests( prompt_adapter_request: Optional[PromptAdapterRequest], guided_options: Optional[GuidedDecodingRequest] = None, ) -> None: + if guided_options is not None: + warnings.warn( + "guided_options is deprecated, use " + "SamplingParams.guided_decoding instead", + DeprecationWarning, + stacklevel=2, + ) + if isinstance(inputs, (str, dict)): # Convert a single prompt to a list. inputs = [inputs] @@ -629,12 +636,11 @@ def _validate_and_add_requests( if isinstance(params, list): params = [ - self._add_guided_processor(param, guided_options) - if isinstance(param, SamplingParams) else param - for param in params + self._add_guided_params(param, guided_options) if isinstance( + param, SamplingParams) else param for param in params ] elif isinstance(params, SamplingParams): - params = self._add_guided_processor(params, guided_options) + params = self._add_guided_params(params, guided_options) # Add requests to the engine. 
for i, request_inputs in enumerate(inputs): @@ -662,22 +668,21 @@ def _add_request( prompt_adapter_request=prompt_adapter_request, ) - def _add_guided_processor( + def _add_guided_params( self, params: SamplingParams, guided_options: Optional[GuidedDecodingRequest] = None): - if guided_options: - if guided_options.guided_decoding_backend is None: - decoding_config = self.llm_engine.get_decoding_config() - guided_options.guided_decoding_backend = ( - decoding_config.guided_decoding_backend) - guided_logits_processor = get_local_guided_decoding_logits_processor( #noqa - guided_options.guided_decoding_backend, guided_options, - self.get_tokenizer()) - if guided_logits_processor: - if params.logits_processors is None: - params.logits_processors = [] - params.logits_processors.append(guided_logits_processor) + if guided_options is None: + return params + + params.guided_decoding = GuidedDecodingParams( + json=guided_options.guided_json, + regex=guided_options.guided_regex, + choice=guided_options.guided_choice, + grammar=guided_options.guided_grammar, + json_object=guided_options.guided_json_object, + backend=guided_options.guided_decoding_backend, + whitespace_pattern=guided_options.guided_whitespace_pattern) return params def _run_engine( diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index e62c9b8dffd3c..846cd8f419683 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -269,7 +269,7 @@ async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE): raise response raise ValueError(error_message) - async def get_tokenizer(self, lora_request: LoRARequest): + async def get_tokenizer(self, lora_request: Optional[LoRARequest]): return await self.tokenizer.get_lora_tokenizer_async(lora_request) async def get_decoding_config(self) -> DecodingConfig: From 9b6ab5567c226f9c89d9c474cb1358b406d26a56 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 10 Sep 2024 12:04:18 -0600 Subject: [PATCH 08/28] :goal_net: disallow both guided options Signed-off-by: Joe Runde --- vllm/entrypoints/llm.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 3f74aa65eee67..bcb1c0a4e92f1 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -631,7 +631,7 @@ def _validate_and_add_requests( ) -> None: if guided_options is not None: warnings.warn( - "guided_options is deprecated, use " + "guided_options_request is deprecated, use " "SamplingParams.guided_decoding instead", DeprecationWarning, stacklevel=2, @@ -691,6 +691,10 @@ def _add_guided_params( if guided_options is None: return params + if params.guided_decoding is not None: + raise ValueError("Cannot set both guided_options_request and" + "params.guided_decoding.") + params.guided_decoding = GuidedDecodingParams( json=guided_options.guided_json, regex=guided_options.guided_regex, From 429704d066e15d287f26d78d1e04b867b322363e Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 10 Sep 2024 14:34:15 -0600 Subject: [PATCH 09/28] :bug: fixup LLM tests Signed-off-by: Joe Runde --- tests/entrypoints/llm/test_guided_generate.py | 66 ++++++++++++------- tests/model_executor/conftest.py | 50 ++++++++++++++ .../model_executor/test_guided_processors.py | 5 +- .../guided_decoding/guided_fields.py | 3 +- 4 files changed, 96 insertions(+), 28 deletions(-) create mode 100644 tests/model_executor/conftest.py diff --git a/tests/entrypoints/llm/test_guided_generate.py 
b/tests/entrypoints/llm/test_guided_generate.py index 873e115421257..2841dfc6bd9c2 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -7,7 +7,7 @@ from vllm.entrypoints.llm import LLM from vllm.outputs import RequestOutput -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import GuidedDecodingParams, SamplingParams from ...conftest import cleanup @@ -31,14 +31,12 @@ def test_guided_regex(sample_regex, llm): sampling_params = SamplingParams( temperature=0.8, top_p=0.95, - ) - outputs = llm.generate( - prompts=[ - f"Give an example IPv4 address with this regex: {sample_regex}" - ] * 2, - sampling_params=sampling_params, - use_tqdm=True, - guided_options_request=dict(guided_regex=sample_regex)) + guided_decoding=GuidedDecodingParams(regex=sample_regex)) + outputs = llm.generate(prompts=[ + f"Give an example IPv4 address with this regex: {sample_regex}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True) assert outputs is not None for output in outputs: @@ -57,15 +55,13 @@ def test_guided_json_completion(sample_json_schema, llm): sampling_params = SamplingParams( temperature=1.0, max_tokens=1000, - ) - outputs = llm.generate( - prompts=[ - f"Give an example JSON for an employee profile " - f"that fits this schema: {sample_json_schema}" - ] * 2, - sampling_params=sampling_params, - use_tqdm=True, - guided_options_request=dict(guided_json=sample_json_schema)) + guided_decoding=GuidedDecodingParams(json=sample_json_schema)) + outputs = llm.generate(prompts=[ + f"Give an example JSON for an employee profile " + f"that fits this schema: {sample_json_schema}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True) assert outputs is not None @@ -86,12 +82,11 @@ def test_guided_choice_completion(sample_guided_choice, llm): sampling_params = SamplingParams( temperature=0.8, top_p=0.95, - ) + guided_decoding=GuidedDecodingParams(choice=sample_guided_choice)) outputs = llm.generate( prompts="The best language for type-safe systems programming is ", sampling_params=sampling_params, - use_tqdm=True, - guided_options_request=dict(guided_choice=sample_guided_choice)) + use_tqdm=True) assert outputs is not None for output in outputs: @@ -112,13 +107,13 @@ def test_guided_grammar(sample_sql_statements, llm): temperature=0.8, top_p=0.95, max_tokens=1000, - ) + guided_decoding=GuidedDecodingParams(grammar=sample_sql_statements)) outputs = llm.generate( prompts=("Generate a sql state that select col_1 from " "table_1 where it is equals to 1"), sampling_params=sampling_params, use_tqdm=True, - guided_options_request=dict(guided_grammar=sample_sql_statements)) + ) assert outputs is not None for output in outputs: @@ -140,3 +135,28 @@ def test_guided_grammar(sample_sql_statements, llm): assert generated_text.strip() == ground_truth print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +@pytest.mark.skip_global_cleanup +def test_guided_options_request_deprecation_warning(sample_regex, llm): + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + with pytest.warns(DeprecationWarning, match="guided_options_request"): + llm.generate(prompts="This should fail", + sampling_params=sampling_params, + use_tqdm=True, + guided_options_request=dict(guided_regex=sample_regex)) + + +@pytest.mark.skip_global_cleanup +def test_validation_against_both_guided_decoding_options(sample_regex, llm): + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + 
guided_decoding=GuidedDecodingParams(regex=sample_regex)) + + with pytest.raises(ValueError, match="Cannot set both"): + llm.generate(prompts="This should fail", + sampling_params=sampling_params, + use_tqdm=True, + guided_options_request=dict(guided_regex=sample_regex)) diff --git a/tests/model_executor/conftest.py b/tests/model_executor/conftest.py new file mode 100644 index 0000000000000..4bf6852349a46 --- /dev/null +++ b/tests/model_executor/conftest.py @@ -0,0 +1,50 @@ +import pytest + + +@pytest.fixture +def sample_regex(): + return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") + + +@pytest.fixture +def sample_json_schema(): + return { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "age": { + "type": "integer" + }, + "skills": { + "type": "array", + "items": { + "type": "string", + "maxLength": 10 + }, + "minItems": 3 + }, + "work_history": { + "type": "array", + "items": { + "type": "object", + "properties": { + "company": { + "type": "string" + }, + "duration": { + "type": "number" + }, + "position": { + "type": "string" + } + }, + "required": ["company", "position"] + } + } + }, + "required": ["name", "age", "skills", "work_history"] + } + diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index 4338ec5d8dc51..f340467c171da 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -1,16 +1,15 @@ +import time import pytest import torch from transformers import AutoTokenizer +from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( JSONLogitsProcessor, RegexLogitsProcessor) from vllm.sampling_params import GuidedDecodingParams -# Yoinking fixtures from entrypoints tests, could duplicate them here -from tests.entrypoints.conftest import sample_regex, sample_json_schema - def test_guided_logits_processors(sample_regex, sample_json_schema): """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" diff --git a/vllm/model_executor/guided_decoding/guided_fields.py b/vllm/model_executor/guided_decoding/guided_fields.py index 7676c0a29b1f8..8deb4c949824a 100644 --- a/vllm/model_executor/guided_decoding/guided_fields.py +++ b/vllm/model_executor/guided_decoding/guided_fields.py @@ -3,9 +3,8 @@ from pydantic import BaseModel -# Nick leans towards ripping out so we don't have duplication - +# These classes are deprecated, see SamplingParams class LLMGuidedOptions(TypedDict, total=False): guided_json: Union[Dict, BaseModel, str] guided_regex: str From 2df47831cca16875c188d371349f2ab02b824fd4 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 10 Sep 2024 14:56:27 -0600 Subject: [PATCH 10/28] :art: fmt Signed-off-by: Joe Runde --- tests/model_executor/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/model_executor/conftest.py b/tests/model_executor/conftest.py index 4bf6852349a46..10792b0a04999 100644 --- a/tests/model_executor/conftest.py +++ b/tests/model_executor/conftest.py @@ -47,4 +47,3 @@ def sample_json_schema(): }, "required": ["name", "age", "skills", "work_history"] } - From 97a1116e82cc4a99db47449a7020d352154231e7 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 10 Sep 2024 15:14:14 -0600 Subject: [PATCH 11/28] :art: more fmt Signed-off-by: Joe Runde --- tests/model_executor/test_guided_processors.py | 2 -- 1 file changed, 2 
deletions(-) diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index f340467c171da..d41742d84f91c 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -1,9 +1,7 @@ -import time import pytest import torch from transformers import AutoTokenizer -from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( From 24d95febdf5a721570ef8ab3793138df89dd25c9 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 10 Sep 2024 15:25:59 -0600 Subject: [PATCH 12/28] :zap: ensure use of async outlines LP construction Signed-off-by: Joe Runde --- vllm/model_executor/guided_decoding/__init__.py | 6 ++++-- .../guided_decoding/lm_format_enforcer_decoding.py | 11 ++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 763bd594c3044..368436aa14613 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -6,7 +6,8 @@ async def get_guided_decoding_logits_processor( guided_params: GuidedDecodingParams, tokenizer) -> Optional[LogitsProcessor]: - if guided_params.backend == 'outlines': + # CFG grammar not supported by LMFE, so we use outlines instead + if guided_params.backend == 'outlines' or guided_params.grammar: # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_outlines_guided_decoding_logits_processor) @@ -26,7 +27,8 @@ async def get_guided_decoding_logits_processor( def get_local_guided_decoding_logits_processor( guided_params: GuidedDecodingParams, tokenizer) -> Optional[LogitsProcessor]: - if guided_params.backend == 'outlines': + # CFG grammar not supported by LMFE, so we use outlines instead + if guided_params.backend == 'outlines' or guided_params.grammar: # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_local_outlines_guided_decoding_logits_processor) diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py index 53a9b32006d36..c4120d4409b41 100644 --- a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py @@ -35,13 +35,10 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor( elif guided_params.regex: character_level_parser = RegexParser(guided_params.regex) elif guided_params.grammar: - # CFG grammar not supported by LMFE, revert to outlines - - # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 - from vllm.model_executor.guided_decoding.outlines_decoding import ( - get_local_outlines_guided_decoding_logits_processor) - return get_local_outlines_guided_decoding_logits_processor( - guided_params, tokenizer) + # CFG grammar not supported by LMFE + raise ValueError("Cannot construct a guided decoding logits processor" + " using the grammar option with the" + " lm_format_enforcer backend.") elif guided_params.json_object: # None means any json object character_level_parser = JsonSchemaParser(None) From 
0fa508003dfdc54b4205d69e56b70f6093133dbc Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 10 Sep 2024 16:02:30 -0600 Subject: [PATCH 13/28] :recycle: move guided params construction Signed-off-by: Joe Runde --- vllm/entrypoints/openai/protocol.py | 54 ++++++++++++++++--- vllm/entrypoints/openai/serving_chat.py | 33 ------------ vllm/entrypoints/openai/serving_completion.py | 3 -- vllm/entrypoints/openai/serving_engine.py | 22 +------- 4 files changed, 49 insertions(+), 63 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 1576caae3a1f0..c53222556d36e 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -267,9 +267,7 @@ class ChatCompletionRequest(OpenAIBaseModel): # doc: end-chat-completion-extra-params - def to_sampling_params(self, - guided_decoding: Optional[GuidedDecodingParams], - default_max_tokens: int) -> SamplingParams: + def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: max_tokens = self.max_tokens if max_tokens is None: max_tokens = default_max_tokens @@ -278,6 +276,20 @@ def to_sampling_params(self, if prompt_logprobs is None and self.echo: prompt_logprobs = self.top_logprobs + guided_json_object = None + if (self.response_format is not None + and self.response_format.type == "json_object"): + guided_json_object = True + + guided_decoding = GuidedDecodingParams( + json=self._get_guided_json_from_tool() or self.guided_json, + regex=self.guided_regex, + choice=self.guided_choice, + grammar=self.guided_grammar, + json_object=guided_json_object, + backend=self.guided_decoding_backend, + whitespace_pattern=self.guided_whitespace_pattern) + return SamplingParams.from_optional( n=self.n, best_of=self.best_of, @@ -306,6 +318,24 @@ def to_sampling_params(self, guided_decoding=guided_decoding, logit_bias=self.logit_bias) + def _get_guided_json_from_tool( + self) -> Optional[Union[str, dict, BaseModel]]: + # user has chosen to not use any tool + if self.tool_choice == "none" or self.tools is None: + return None + + # user has chosen to use a named tool + if type(self.tool_choice) is ChatCompletionNamedToolChoiceParam: + tool_name = self.tool_choice.function.name + tools = {tool.function.name: tool.function for tool in self.tools} + if tool_name not in tools: + raise ValueError( + f"Tool '{tool_name}' has not been passed in `tools`.") + tool = tools[tool_name] + return tool.parameters + + return None + @model_validator(mode="before") @classmethod def validate_stream_options(cls, data): @@ -499,9 +529,7 @@ class CompletionRequest(OpenAIBaseModel): # doc: end-completion-extra-params - def to_sampling_params(self, - guided_decoding: Optional[GuidedDecodingParams], - default_max_tokens: int) -> SamplingParams: + def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: max_tokens = self.max_tokens if max_tokens is None: max_tokens = default_max_tokens @@ -512,6 +540,20 @@ def to_sampling_params(self, echo_without_generation = self.echo and self.max_tokens == 0 + guided_json_object = None + if (self.response_format is not None + and self.response_format.type == "json_object"): + guided_json_object = True + + guided_decoding = GuidedDecodingParams( + json=self.guided_json, + regex=self.guided_regex, + choice=self.guided_choice, + grammar=self.guided_grammar, + json_object=guided_json_object, + backend=self.guided_decoding_backend, + whitespace_pattern=self.guided_whitespace_pattern) + return SamplingParams.from_optional( n=self.n, best_of=self.best_of, diff 
--git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index ac5359595bff1..e2a2c56db4dae 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -7,7 +7,6 @@ from typing import Union from fastapi import Request -from pydantic import BaseModel from vllm.config import ModelConfig from vllm.engine.protocol import AsyncEngineClient @@ -170,14 +169,6 @@ async def create_chat_completion( request_id = f"chat-{random_uuid()}" try: - guided_decoding_params = \ - self._create_guided_decoding_params(request) - # Some requests for tools will use guided decoding - # TODO: validate against any conflicts here? - if (guided_json := - self._get_guided_json_from_tool(request)) is not None: - guided_decoding_params.json = guided_json - if isinstance(prompt, str): prompt_inputs = self._tokenize_prompt_input( request, @@ -196,7 +187,6 @@ async def create_chat_completion( assert prompt_inputs is not None sampling_params = request.to_sampling_params( - guided_decoding_params, default_max_tokens=self.max_model_len - len(prompt_inputs["prompt_token_ids"])) @@ -800,26 +790,3 @@ def _should_check_for_unstreamed_tool_arg_tokens( and delta_message.tool_calls[0].function.arguments is not None and output.finish_reason is not None ) - - @staticmethod - def _get_guided_json_from_tool( - request: ChatCompletionRequest - ) -> Optional[Union[str, dict, BaseModel]]: - # user has chosen to not use any tool - if request.tool_choice == "none" or request.tools is None: - return None - - # user has chosen to use a named tool - if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam: - tool_name = request.tool_choice.function.name - tools = { - tool.function.name: tool.function - for tool in request.tools - } - if tool_name not in tools: - raise ValueError( - f"Tool '{tool_name}' has not been passed in `tools`.") - tool = tools[tool_name] - return tool.parameters - - return None diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index ce6b430973c8a..01b5b987f3a25 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -98,8 +98,6 @@ async def create_completion( tokenizer = await self.async_engine_client.get_tokenizer( lora_request) - guided_decoding_params = self._create_guided_decoding_params( - request) prompts = list( self._tokenize_prompt_input_or_inputs( request, @@ -111,7 +109,6 @@ async def create_completion( for i, prompt_inputs in enumerate(prompts): sampling_params = request.to_sampling_params( - guided_decoding_params, default_max_tokens=self.max_model_len - len(prompt_inputs["prompt_token_ids"])) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index a357a626618aa..564d5feb5d208 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -29,7 +29,7 @@ from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import GuidedDecodingParams, SamplingParams +from vllm.sampling_params import SamplingParams from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import AtomicCounter @@ -469,23 +469,3 @@ async def unload_lora_adapter( if lora_request.lora_name != lora_name ] return f"Success: LoRA adapter '{lora_name}' removed successfully." 
- - @staticmethod - def _create_guided_decoding_params( - api_request: Union[CompletionRequest, ChatCompletionRequest] - ) -> GuidedDecodingParams: - """Extract all of the guided decoding parameters from a frontend api - request""" - guided_json_object = None - if (api_request.response_format is not None - and api_request.response_format.type == "json_object"): - guided_json_object = True - - return GuidedDecodingParams( - json=api_request.guided_json, - choice=api_request.guided_choice, - backend=api_request.guided_decoding_backend, - grammar=api_request.guided_grammar, - regex=api_request.guided_regex, - whitespace_pattern=api_request.guided_whitespace_pattern, - json_object=guided_json_object) From 450c76eba7ebaf777da2d5c0df6afa78694618ef Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Thu, 12 Sep 2024 14:38:40 -0600 Subject: [PATCH 14/28] :art: fmt Signed-off-by: Joe Runde --- vllm/entrypoints/llm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 9aa9554d42cdc..3660a5fa7b5d4 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -19,7 +19,8 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind, SamplingParams) +from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind, + SamplingParams) from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, get_cached_tokenizer) from vllm.transformers_utils.tokenizer_group import TokenizerGroup @@ -650,8 +651,7 @@ def _validate_and_add_requests( raise ValueError("The lengths of prompts and lora_request " "must be the same.") - - for sp in params if isinstance(params, list) else (self._add_guided_params(params, guided_options), ): + for sp in params if isinstance(params, list) else (params, ): if isinstance(sp, SamplingParams): self._add_guided_params(sp, guided_options) From d4540ca405c656823eebd3b3758da89713790cf0 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 17 Sep 2024 09:03:03 -0600 Subject: [PATCH 15/28] :recycle: refactor mistral unsupported errors Signed-off-by: Joe Runde --- .../guided_decoding/__init__.py | 32 ++++++++----------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 8f73f47e444c7..197a98fb8b40f 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -10,10 +10,7 @@ async def get_guided_decoding_logits_processor( # CFG grammar not supported by LMFE, so we use outlines instead if guided_params.backend == 'outlines' or guided_params.grammar: if isinstance(tokenizer, MistralTokenizer): - raise NotImplementedError( - "Guided decoding with 'outlines' is currently not supported " - "for Mistral tokenizer. 
Please consider contributing to the " - "'outlines' project if you are interested in this feature.") + _mistral_unsupported('outlines') # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_outlines_guided_decoding_logits_processor) @@ -21,11 +18,7 @@ async def get_guided_decoding_logits_processor( guided_params, tokenizer) if guided_params.backend == 'lm-format-enforcer': if isinstance(tokenizer, MistralTokenizer): - raise NotImplementedError( - "Guided decoding with 'lm-format-enforcer' is currently not " - "supported for Mistral tokenizer. Please consider contributing " - "to the 'lm-format-enforcer' project if you are interested " - "in this feature.") + _mistral_unsupported('lm-format-enforcer') from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( @@ -42,23 +35,15 @@ def get_local_guided_decoding_logits_processor( # CFG grammar not supported by LMFE, so we use outlines instead if guided_params.backend == 'outlines' or guided_params.grammar: if isinstance(tokenizer, MistralTokenizer): - raise NotImplementedError( - "Guided decoding with 'outlines' is currently not supported " - "for Mistral tokenizer. Please consider contributing to the " - "'outlines' project if you are interested in this feature.") + _mistral_unsupported('outlines') # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_local_outlines_guided_decoding_logits_processor) return get_local_outlines_guided_decoding_logits_processor( guided_params, tokenizer) if guided_params.backend == 'lm-format-enforcer': - if isinstance(tokenizer, MistralTokenizer): - raise NotImplementedError( - "Guided decoding with 'lm-format-enforcer' is currently not " - "supported for Mistral tokenizer. Please consider contributing " - "to the 'lm-format-enforcer' project if you are interested " - "in this feature.") + _mistral_unsupported('lm-format-enforcer') from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( @@ -67,3 +52,12 @@ def get_local_guided_decoding_logits_processor( raise ValueError( f"Unknown guided decoding backend '{guided_params.backend}'. " "Must be one of 'outlines, 'lm-format-enforcer'") + + +def _mistral_unsupported(backend: str): + """Raise error that mistral tokenizers are unsupported""" + raise NotImplementedError( + f"Guided decoding with '{backend}' is currently not " + "supported for Mistral tokenizer. 
Please consider contributing " + f"to the '{backend}' project if you are interested " + "in this feature.") \ No newline at end of file From 1d580d274fc07ed291fe82d5de1c78b6a454d304 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 17 Sep 2024 10:16:33 -0600 Subject: [PATCH 16/28] :white_check_mark: test guided decoding validation Signed-off-by: Joe Runde --- tests/model_executor/test_guided_processors.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index d41742d84f91c..791a220be0a22 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -65,3 +65,16 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex, tensor = json_lp(token_ids, tensor) assert tensor.shape == original_tensor.shape assert not torch.allclose(tensor, original_tensor) + +def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex): + with pytest.raises(ValueError, match="You can only use one kind of guided"): + GuidedDecodingParams(json=sample_json_schema, regex=sample_regex) + + with pytest.raises(ValueError, match="You can only use one kind of guided"): + GuidedDecodingParams(json=sample_json_schema, json_object=True) + + with pytest.raises(ValueError, match="You can only use one kind of guided"): + GuidedDecodingParams(json=sample_json_schema, choice=["a", "b"]) + + with pytest.raises(ValueError, match="You can only use one kind of guided"): + GuidedDecodingParams(json=sample_json_schema, grammar="test grammar") From b984187d7602f0c8d70ad14d55afe0f988940796 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 17 Sep 2024 10:26:41 -0600 Subject: [PATCH 17/28] :art: fmt Signed-off-by: Joe Runde --- tests/model_executor/test_guided_processors.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index 791a220be0a22..45fab8e96b968 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -66,15 +66,20 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex, assert tensor.shape == original_tensor.shape assert not torch.allclose(tensor, original_tensor) + def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex): - with pytest.raises(ValueError, match="You can only use one kind of guided"): + with pytest.raises(ValueError, + match="You can only use one kind of guided"): GuidedDecodingParams(json=sample_json_schema, regex=sample_regex) - with pytest.raises(ValueError, match="You can only use one kind of guided"): + with pytest.raises(ValueError, + match="You can only use one kind of guided"): GuidedDecodingParams(json=sample_json_schema, json_object=True) - with pytest.raises(ValueError, match="You can only use one kind of guided"): + with pytest.raises(ValueError, + match="You can only use one kind of guided"): GuidedDecodingParams(json=sample_json_schema, choice=["a", "b"]) - with pytest.raises(ValueError, match="You can only use one kind of guided"): + with pytest.raises(ValueError, + match="You can only use one kind of guided"): GuidedDecodingParams(json=sample_json_schema, grammar="test grammar") From 1c2bbf1c070dd7508d390287ea14acdc0849e248 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 17 Sep 2024 15:58:56 -0600 Subject: [PATCH 18/28] :bug: fixup engine client test 
Signed-off-by: Joe Runde --- tests/entrypoints/openai/rpc/test_zmq_client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/entrypoints/openai/rpc/test_zmq_client.py b/tests/entrypoints/openai/rpc/test_zmq_client.py index cafd125c5a598..d4da72f2f7ae6 100644 --- a/tests/entrypoints/openai/rpc/test_zmq_client.py +++ b/tests/entrypoints/openai/rpc/test_zmq_client.py @@ -11,6 +11,7 @@ from vllm.entrypoints.openai.rpc.client import (AsyncEngineRPCClient, RPCClientClosedError) from vllm.entrypoints.openai.rpc.server import AsyncEngineRPCServer +from vllm.sampling_params import SamplingParams @pytest.fixture(scope="function") @@ -112,7 +113,8 @@ async def test_client_errors_after_closing(monkeypatch, dummy_server, with pytest.raises(RPCClientClosedError): await client.check_health() with pytest.raises(RPCClientClosedError): - async for _ in client.generate(None, None, None): + async for _ in client.generate("sample prompt", SamplingParams(), + uuid.uuid4()): pass # But no-ops like aborting will pass From aa84827156bea435e6c63aa8c9a2221faa6c03da Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Wed, 18 Sep 2024 11:41:08 -0600 Subject: [PATCH 19/28] :bug: start to fixup for msgspec Signed-off-by: Joe Runde --- vllm/entrypoints/openai/protocol.py | 4 +-- vllm/sampling_params.py | 43 +++++++++++++++++++++++++++-- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 0171b45ea0ed3..32066b604a0fa 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -282,7 +282,7 @@ def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: and self.response_format.type == "json_object"): guided_json_object = True - guided_decoding = GuidedDecodingParams( + guided_decoding = GuidedDecodingParams.from_optional( json=self._get_guided_json_from_tool() or self.guided_json, regex=self.guided_regex, choice=self.guided_choice, @@ -548,7 +548,7 @@ def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: and self.response_format.type == "json_object"): guided_json_object = True - guided_decoding = GuidedDecodingParams( + guided_decoding = GuidedDecodingParams.from_optional( json=self.guided_json, regex=self.guided_regex, choice=self.guided_choice, diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 3e82a473f703e..c231c3739f869 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from enum import Enum, IntEnum from functools import cached_property +from json import dumps from typing import Any, Callable, Dict, List, Optional, Set, Union import msgspec @@ -39,7 +40,7 @@ class SamplingType(IntEnum): @dataclass class GuidedDecodingParams: """One of these fields will be used to build a logit processor.""" - json: Optional[Union[Dict, BaseModel, str]] = None + json: Optional[Union[BaseModel, str]] = None regex: Optional[str] = None choice: Optional[List[str]] = None grammar: Optional[str] = None @@ -48,6 +49,31 @@ class GuidedDecodingParams: backend: Optional[str] = None whitespace_pattern: Optional[str] = None + @staticmethod + def from_optional( + json: Optional[Union[Dict, BaseModel, str]], + regex: Optional[str] = None, + choice: Optional[List[str]] = None, + grammar: Optional[str] = None, + json_object: Optional[bool] = None, + backend: Optional[str] = None, + whitespace_pattern: Optional[str] = None, + ) -> "GuidedDecodingParams": + if isinstance(json, dict): + # 
Serialize dicts to json strings up-front + # msgspec decoders will complain about a type union with both + # Dict and BaseModel in it + json = dumps(json) + return GuidedDecodingParams( + json=json, + regex=regex, + choice=choice, + grammar=grammar, + json_object=json_object, + backend=backend, + whitespace_pattern=whitespace_pattern, + ) + def __post_init__(self): """Validate that some fields are mutually exclusive.""" guide_count = sum([ @@ -150,6 +176,13 @@ class SamplingParams( truncate_prompt_tokens: If set to an integer k, will use only the last k tokens from the prompt (i.e., left truncation). Defaults to None (i.e., no truncation). + guided_decoding: If provided, the engine will construct a guided + decoding logits processor from these parameters. Defaults to None. + logit_bias: If provided, the engine will construct a logits processor + that applies these logit biases. Defaults to None. + allowed_token_ids: If provided, the engine will construct a logits + processor which only retains scores for the given token ids. + Defaults to None. """ n: int = 1 @@ -192,7 +225,7 @@ class SamplingParams( # Fields used to construct logits processors guided_decoding: Optional[GuidedDecodingParams] = None - logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]] = None + logit_bias: Optional[Dict[int, float]] = None allowed_token_ids: Optional[List[int]] = None @staticmethod @@ -229,6 +262,12 @@ def from_optional( logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]] = None, allowed_token_ids: Optional[List[int]] = None, ) -> "SamplingParams": + if logit_bias is not None: + logit_bias = { + int(token): bias + for token, bias in logit_bias.items() + } + return SamplingParams( n=1 if n is None else n, best_of=best_of, From 891c5262f29a02c3c272ac6f681a7531de6d550e Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Wed, 18 Sep 2024 11:50:52 -0600 Subject: [PATCH 20/28] :zap: add LP construction to mqllmengine client Signed-off-by: Joe Runde --- vllm/engine/multiprocessing/client.py | 39 +++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 18b620c74ddf9..fd38c688fdb56 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -27,6 +27,8 @@ from vllm.inputs import PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.guided_decoding import ( + get_guided_decoding_logits_processor) from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams @@ -395,6 +397,13 @@ async def generate( if self._errored_with is not None: raise ENGINE_DEAD_ERROR(self._errored_with) + # Constructing guided decoding logits processors is expensive, so we do + # it here to avoid contending with cpu resources and the GIL on the + # backend process. + sampling_params = await \ + self._build_guided_decoding_logits_processor_async( + sampling_params, lora_request) + # 1) Create output queue for this requests. 
queue: asyncio.Queue[Union[RequestOutput, BaseException]] = asyncio.Queue() @@ -450,3 +459,33 @@ async def encode(self, *args, **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: raise NotImplementedError( "Embeddings not supported with multiprocessing backend") + + async def _build_guided_decoding_logits_processor_async( + self, sampling_params: SamplingParams, + lora_request: Optional[LoRARequest]) -> SamplingParams: + """Constructs a logits processor based on any provided guided + decoding parameters. Updates logits_processors, deletes the guided + decoding params, and returns the modified sampling params""" + if (guided_decoding := sampling_params.guided_decoding) is None: + return sampling_params + + logger.debug( + "Building guided decoding logits processor in " + "AsyncEngineRPCClient. Params: %v", guided_decoding) + + tokenizer = await self.get_tokenizer(lora_request=lora_request) + guided_decoding.backend = guided_decoding.backend or \ + self.decoding_config.guided_decoding_backend + + processor = await get_guided_decoding_logits_processor( + guided_params=guided_decoding, tokenizer=tokenizer) + + if processor: + if sampling_params.logits_processors is None: + sampling_params.logits_processors = [] + sampling_params.logits_processors.append(processor) + + # Unset guided decoding params after constructing the lp from them + sampling_params.guided_decoding = None + + return sampling_params From 8f079e5f0575eddcba7605fab52a69069950e923 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Wed, 25 Sep 2024 12:09:02 -0600 Subject: [PATCH 21/28] :art: fmt Signed-off-by: Joe Runde --- vllm/engine/multiprocessing/client.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 5802647b16539..26316fbe438b1 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -517,9 +517,10 @@ async def _process_request( # Constructing guided decoding logits processors is expensive, so we do # it here to avoid contending with cpu resources and the GIL on the # backend process. - sampling_params = await \ - self._build_guided_decoding_logits_processor_async( - sampling_params, lora_request) + if isinstance(params, SamplingParams): + params = await \ + self._build_guided_decoding_logits_processor_async( + params, lora_request) # 1) Create output queue for this requests. 
queue: asyncio.Queue[Union[RequestOutput, @@ -601,7 +602,7 @@ async def _build_guided_decoding_logits_processor_async( sampling_params.guided_decoding = None return sampling_params - + async def start_profile(self) -> None: """Start profiling the engine""" From ed29ec25db2a2a58d1938635528343843ead57da Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Wed, 25 Sep 2024 17:01:37 -0600 Subject: [PATCH 22/28] :recycle: extract BaseModel schema up-front Signed-off-by: Joe Runde --- .../guided_decoding/lm_format_enforcer_decoding.py | 9 +++------ .../guided_decoding/outlines_decoding.py | 5 ----- vllm/sampling_params.py | 11 ++++------- 3 files changed, 7 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py index c4120d4409b41..cf2162ed7720d 100644 --- a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py @@ -7,7 +7,6 @@ TokenEnforcerTokenizerData, UnionParser) from lmformatenforcer.integrations.vllm import ( build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data) -from pydantic import BaseModel from transformers import PreTrainedTokenizerBase from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor @@ -27,8 +26,8 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor( tokenizer) character_level_parser: CharacterLevelParser if guided_params.json: - schema = _normalize_json_schema_object(guided_params.json) - character_level_parser = JsonSchemaParser(schema) + schema_dict = _normalize_json_schema_object(guided_params.json) + character_level_parser = JsonSchemaParser(schema_dict) elif guided_params.choice: character_level_parser = UnionParser( [StringParser(choice) for choice in guided_params.choice]) @@ -50,13 +49,11 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor( return logits_processor -def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict: +def _normalize_json_schema_object(schema: Union[str, dict]) -> dict: if isinstance(schema, str): return json_loads(schema) if isinstance(schema, dict): return schema - if isinstance(schema, BaseModel): - return schema.model_json_schema() raise AssertionError(f"Unsupported schema type {schema}") diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index bce0bad0a0d97..8a7ff38bfeb1a 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -5,7 +5,6 @@ from re import escape as regex_escape from typing import Tuple, Union -from pydantic import BaseModel from transformers import PreTrainedTokenizerBase from vllm.model_executor.guided_decoding.outlines_logits_processors import ( @@ -100,10 +99,6 @@ def _get_guide_and_mode( if isinstance(guided_params.json, dict): # turn dict into hashable string json = json_dumps(guided_params.json) - elif isinstance(guided_params.json, BaseModel): - # use pydantic signature so that different model classes - # with the same fields will get hashed the same - json = str(guided_params.json.__signature__) else: json = guided_params.json return json, GuidedDecodingMode.JSON diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 0f978c23292ab..8374bcc050a16 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -3,7 +3,6 @@ from dataclasses import dataclass 
from enum import Enum, IntEnum from functools import cached_property -from json import dumps from typing import Any, Callable, Dict, List, Optional, Set, Union import msgspec @@ -41,7 +40,7 @@ class SamplingType(IntEnum): @dataclass class GuidedDecodingParams: """One of these fields will be used to build a logit processor.""" - json: Optional[Union[BaseModel, str]] = None + json: Optional[Union[str, Dict]] = None regex: Optional[str] = None choice: Optional[List[str]] = None grammar: Optional[str] = None @@ -60,11 +59,9 @@ def from_optional( backend: Optional[str] = None, whitespace_pattern: Optional[str] = None, ) -> "GuidedDecodingParams": - if isinstance(json, dict): - # Serialize dicts to json strings up-front - # msgspec decoders will complain about a type union with both - # Dict and BaseModel in it - json = dumps(json) + # Extract json schemas from pydantic models + if isinstance(json, (BaseModel, type(BaseModel))): + json = json.model_json_schema() return GuidedDecodingParams( json=json, regex=regex, From 6281044096d7d93c6149f88814758824785b66be Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Fri, 27 Sep 2024 15:12:35 -0600 Subject: [PATCH 23/28] :recycle: v -> s Signed-off-by: Joe Runde --- vllm/engine/async_llm_engine.py | 2 +- vllm/engine/llm_engine.py | 2 +- vllm/engine/multiprocessing/client.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index e6f761b58c09a..b910213e23468 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -503,7 +503,7 @@ async def _build_guided_decoding_logits_processor_async( logger.debug( "Building guided decoding logits processor in " - "AsyncLLMEngine. Params: %v", guided_decoding) + "AsyncLLMEngine. Params: %s", guided_decoding) tokenizer = self.get_tokenizer(lora_request=lora_request) guided_decoding.backend = guided_decoding.backend or \ diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 64f8374e39990..4dae2490f6b89 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1819,7 +1819,7 @@ def _build_logits_processors( logger.debug( "Building guided decoding logits processor in " - "LLMEngine. Params: %v", guided_decoding) + "LLMEngine. Params: %s", guided_decoding) tokenizer = self.get_tokenizer(lora_request=lora_request) guided_decoding.backend = guided_decoding.backend or \ diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 26316fbe438b1..9b1dc3e419149 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -584,7 +584,7 @@ async def _build_guided_decoding_logits_processor_async( logger.debug( "Building guided decoding logits processor in " - "AsyncEngineRPCClient. Params: %v", guided_decoding) + "AsyncEngineRPCClient. 
Params: %s", guided_decoding) tokenizer = await self.get_tokenizer(lora_request=lora_request) guided_decoding.backend = guided_decoding.backend or \ From a4e25236379ae473a451be2a69ff3cf3aa321e45 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Fri, 27 Sep 2024 16:48:11 -0600 Subject: [PATCH 24/28] Update vllm/engine/async_llm_engine.py Co-authored-by: Nick Hill --- vllm/engine/async_llm_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index b910213e23468..275250a6fb8b5 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -505,7 +505,7 @@ async def _build_guided_decoding_logits_processor_async( "Building guided decoding logits processor in " "AsyncLLMEngine. Params: %s", guided_decoding) - tokenizer = self.get_tokenizer(lora_request=lora_request) + tokenizer = await self.get_tokenizer(lora_request=lora_request) guided_decoding.backend = guided_decoding.backend or \ self.decoding_config.guided_decoding_backend From 8acc98eba3e20196c2de7d0a049c9f53b5c09d51 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Fri, 27 Sep 2024 16:49:06 -0600 Subject: [PATCH 25/28] Update vllm/engine/llm_engine.py Co-authored-by: Nick Hill --- vllm/engine/llm_engine.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 4dae2490f6b89..e013bcb479e0c 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1846,9 +1846,10 @@ def _build_logits_processors( sampling_params.logit_bias = None sampling_params.allowed_token_ids = None - if len(logits_processors) > 0: + if logits_processors: if sampling_params.logits_processors is None: - sampling_params.logits_processors = [] - sampling_params.logits_processors.extend(logits_processors) + sampling_params.logits_processors = logits_processors + else: + sampling_params.logits_processors.extend(logits_processors) return sampling_params From 485ce1de5001adcecdfb4f89828432c0c2f9625c Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Mon, 30 Sep 2024 11:09:49 -0600 Subject: [PATCH 26/28] Revert "Update vllm/engine/async_llm_engine.py" This reverts commit a4e25236379ae473a451be2a69ff3cf3aa321e45. --- vllm/engine/async_llm_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 275250a6fb8b5..b910213e23468 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -505,7 +505,7 @@ async def _build_guided_decoding_logits_processor_async( "Building guided decoding logits processor in " "AsyncLLMEngine. 
Params: %s", guided_decoding) - tokenizer = await self.get_tokenizer(lora_request=lora_request) + tokenizer = self.get_tokenizer(lora_request=lora_request) guided_decoding.backend = guided_decoding.backend or \ self.decoding_config.guided_decoding_backend From f0a3f9d41030eae58736c42bc916f1403f73a10d Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Mon, 30 Sep 2024 11:48:43 -0600 Subject: [PATCH 27/28] :recycle: refactor guided decoding lp construction Signed-off-by: Joe Runde --- vllm/engine/async_llm_engine.py | 59 ++++++++++++++------------- vllm/engine/multiprocessing/client.py | 44 ++++---------------- vllm/sampling_params.py | 3 +- 3 files changed, 41 insertions(+), 65 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index b910213e23468..aa6f0ee188bdd 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -466,15 +466,17 @@ async def add_request_async( ) processed_inputs = self.input_processor(preprocessed_inputs) - if isinstance(params, SamplingParams): + if isinstance(params, SamplingParams) and \ + params.guided_decoding is not None: # Guided decoding has an async implementation for building logits # processors in a separate threadpool. # We want to invoke that here instead of using the blocking # implementation in the LLMEngine - params = await self._build_guided_decoding_logits_processor_async( - params, - lora_request, - ) + params = await build_guided_decoding_logits_processor_async( + sampling_params=params, + tokenizer=self.get_tokenizer(lora_request), + default_guided_backend=self.decoding_config. + guided_decoding_backend) self._add_processed_request( request_id=request_id, @@ -491,36 +493,35 @@ async def check_health_async(self) -> None: self.tokenizer.check_health() self.model_executor.check_health() - async def _build_guided_decoding_logits_processor_async( - self, sampling_params: SamplingParams, - lora_request: Optional[LoRARequest]) -> SamplingParams: - """Constructs logits processors based on the guided_decoding, - logits_bias, and allowed_token_ids fields in sampling_params. Deletes - those fields and adds the constructed logits processors to the - logits_processors field. Returns the modified sampling params.""" - if (guided_decoding := sampling_params.guided_decoding) is None: - return sampling_params - logger.debug( - "Building guided decoding logits processor in " - "AsyncLLMEngine. Params: %s", guided_decoding) +async def build_guided_decoding_logits_processor_async( + sampling_params: SamplingParams, tokenizer: AnyTokenizer, + default_guided_backend: str) -> SamplingParams: + """Constructs logits processors based on the guided_decoding, + logits_bias, and allowed_token_ids fields in sampling_params. Deletes + those fields and adds the constructed logits processors to the + logits_processors field. Modifies sampling params in-place and returns + the modified sampling params.""" + if (guided_decoding := sampling_params.guided_decoding) is None: + return sampling_params - tokenizer = self.get_tokenizer(lora_request=lora_request) - guided_decoding.backend = guided_decoding.backend or \ - self.decoding_config.guided_decoding_backend + logger.debug("Building guided decoding logits processor. 
" + "Params: %s", guided_decoding) - processor = await get_guided_decoding_logits_processor( - guided_params=guided_decoding, tokenizer=tokenizer) + guided_decoding.backend = guided_decoding.backend or default_guided_backend - if processor: - if sampling_params.logits_processors is None: - sampling_params.logits_processors = [] - sampling_params.logits_processors.append(processor) + processor = await get_guided_decoding_logits_processor( + guided_params=guided_decoding, tokenizer=tokenizer) - # Unset guided decoding params after constructing the lp from them - sampling_params.guided_decoding = None + if processor: + if sampling_params.logits_processors is None: + sampling_params.logits_processors = [] + sampling_params.logits_processors.append(processor) - return sampling_params + # Unset guided decoding params after constructing the lp from them + sampling_params.guided_decoding = None + + return sampling_params class AsyncLLMEngine: diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 9b1dc3e419149..79da0be97fdbf 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -16,6 +16,8 @@ from vllm.engine.arg_utils import AsyncEngineArgs # yapf conflicts with isort for this block # yapf: disable +from vllm.engine.async_llm_engine import ( + build_guided_decoding_logits_processor_async) from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, RPC_REQUEST_T, @@ -28,8 +30,6 @@ from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.model_executor.guided_decoding import ( - get_guided_decoding_logits_processor) from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams @@ -517,10 +517,14 @@ async def _process_request( # Constructing guided decoding logits processors is expensive, so we do # it here to avoid contending with cpu resources and the GIL on the # backend process. - if isinstance(params, SamplingParams): + if isinstance(params, SamplingParams) and \ + params.guided_decoding is not None: params = await \ - self._build_guided_decoding_logits_processor_async( - params, lora_request) + build_guided_decoding_logits_processor_async( + sampling_params=params, + tokenizer=await self.get_tokenizer(lora_request), + default_guided_backend=self.decoding_config.guided_decoding_backend + ) # 1) Create output queue for this requests. queue: asyncio.Queue[Union[RequestOutput, @@ -573,36 +577,6 @@ async def _process_request( finally: self.output_queues.pop(request_id) - async def _build_guided_decoding_logits_processor_async( - self, sampling_params: SamplingParams, - lora_request: Optional[LoRARequest]) -> SamplingParams: - """Constructs a logits processor based on any provided guided - decoding parameters. Updates logits_processors, deletes the guided - decoding params, and returns the modified sampling params""" - if (guided_decoding := sampling_params.guided_decoding) is None: - return sampling_params - - logger.debug( - "Building guided decoding logits processor in " - "AsyncEngineRPCClient. 
Params: %s", guided_decoding) - - tokenizer = await self.get_tokenizer(lora_request=lora_request) - guided_decoding.backend = guided_decoding.backend or \ - self.decoding_config.guided_decoding_backend - - processor = await get_guided_decoding_logits_processor( - guided_params=guided_decoding, tokenizer=tokenizer) - - if processor: - if sampling_params.logits_processors is None: - sampling_params.logits_processors = [] - sampling_params.logits_processors.append(processor) - - # Unset guided decoding params after constructing the lp from them - sampling_params.guided_decoding = None - - return sampling_params - async def start_profile(self) -> None: """Start profiling the engine""" diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 8374bcc050a16..83f76410882de 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -528,4 +528,5 @@ def __repr__(self) -> str: f"skip_special_tokens={self.skip_special_tokens}, " "spaces_between_special_tokens=" f"{self.spaces_between_special_tokens}, " - f"truncate_prompt_tokens={self.truncate_prompt_tokens})") + f"truncate_prompt_tokens={self.truncate_prompt_tokens}), " + f"guided_decoding={self.guided_decoding}") From 538517515027c496263b8d1846d7bc16c67e923f Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Mon, 30 Sep 2024 16:34:45 -0600 Subject: [PATCH 28/28] retrigger CI