[Frontend][Core] Move guided decoding params into sampling params #8252

Merged: 39 commits, Oct 1, 2024
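
At a high level, this PR moves guided decoding configuration into SamplingParams via a new GuidedDecodingParams object, so the engine/client builds the guided-decoding logits processors instead of the OpenAI frontend. Below is a minimal usage sketch, assuming the post-merge API and the field names visible in the diffs further down; the `llm` object is a hypothetical engine instance, not something defined in this PR.

```python
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

# Guided decoding is now described declaratively on SamplingParams; the
# engine/client constructs the backend-specific logits processor from it.
guided = GuidedDecodingParams(
    choice=["positive", "negative"],  # constrain output to one of these strings
    backend="outlines",               # or "lm-format-enforcer"
)
params = SamplingParams(temperature=0.0, max_tokens=8, guided_decoding=guided)

# outputs = llm.generate(["Classify the review: 'Great movie!'"], params)
```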
Changes shown below are from 3 of the 39 commits.

Commits (39)
2673021  :recycle: refactor guided decoding construction (joerunde, Sep 6, 2024)
382a240  :recycle: refactor to_sampling_params (joerunde, Sep 6, 2024)
ff6c147  :recycle: update sampling params in openai servers (joerunde, Sep 6, 2024)
ab0fea0  :zap: build LPs in engine / client (joerunde, Sep 9, 2024)
41a18d5  :recycle: move test file (joerunde, Sep 9, 2024)
40260eb  :white_check_mark: fixup tests (joerunde, Sep 9, 2024)
9d3b185  :wastebasket: deprecate GuidedDecodingRequest usage (joerunde, Sep 10, 2024)
610423e  Merge remote-tracking branch 'upstream/main' into lp-scratch (joerunde, Sep 10, 2024)
9b6ab55  :goal_net: disallow both guided options (joerunde, Sep 10, 2024)
429704d  :bug: fixup LLM tests (joerunde, Sep 10, 2024)
2df4783  :art: fmt (joerunde, Sep 10, 2024)
97a1116  :art: more fmt (joerunde, Sep 10, 2024)
24d95fe  :zap: ensure use of async outlines LP construction (joerunde, Sep 10, 2024)
0fa5080  :recycle: move guided params construction (joerunde, Sep 10, 2024)
2920459  Merge remote-tracking branch 'upstream/main' into lp-scratch (joerunde, Sep 12, 2024)
450c76e  :art: fmt (joerunde, Sep 12, 2024)
6bfa8a8  Merge remote-tracking branch 'upstream/main' into lp-scratch (joerunde, Sep 16, 2024)
07a17ef  Merge remote-tracking branch 'upstream/main' into lp-scratch (joerunde, Sep 17, 2024)
d4540ca  :recycle: refactor mistral unsupported errors (joerunde, Sep 17, 2024)
1d580d2  :white_check_mark: test guided decoding validation (joerunde, Sep 17, 2024)
b984187  :art: fmt (joerunde, Sep 17, 2024)
1c2bbf1  :bug: fixup engine client test (joerunde, Sep 17, 2024)
aa84827  :bug: start to fixup for msgspec (joerunde, Sep 18, 2024)
c52075b  Merge remote-tracking branch 'upstream/main' into lp-scratch (joerunde, Sep 18, 2024)
891c526  :zap: add LP construction to mqllmengine client (joerunde, Sep 18, 2024)
2b59d03  Merge remote-tracking branch 'upstream/main' into lp-scratch (joerunde, Sep 25, 2024)
8f079e5  :art: fmt (joerunde, Sep 25, 2024)
ed29ec2  :recycle: extract BaseModel schema up-front (joerunde, Sep 25, 2024)
bea1716  Merge remote-tracking branch 'upstream/main' into lp-scratch (joerunde, Sep 25, 2024)
74e9187  Merge remote-tracking branch 'upstream/main' into lp-scratch (joerunde, Sep 26, 2024)
7b17aba  Merge remote-tracking branch 'upstream/main' into lp-scratch (joerunde, Sep 26, 2024)
76182e4  Merge remote-tracking branch 'upstream/main' into lp-scratch (joerunde, Sep 27, 2024)
6281044  :recycle: v -> s (joerunde, Sep 27, 2024)
a4e2523  Update vllm/engine/async_llm_engine.py (joerunde, Sep 27, 2024)
8acc98e  Update vllm/engine/llm_engine.py (joerunde, Sep 27, 2024)
485ce1d  Revert "Update vllm/engine/async_llm_engine.py" (joerunde, Sep 30, 2024)
f0a3f9d  :recycle: refactor guided decoding lp construction (joerunde, Sep 30, 2024)
ec92fd5  Merge remote-tracking branch 'upstream/main' into lp-scratch (joerunde, Sep 30, 2024)
5385175  retrigger CI (joerunde, Sep 30, 2024)
36 changes: 10 additions & 26 deletions vllm/entrypoints/openai/protocol.py
@@ -10,11 +10,9 @@
from typing_extensions import Annotated, Required, TypedDict

from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid

# torch is mocked during docs generation,
@@ -270,8 +268,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params

def to_sampling_params(
self, tokenizer: AnyTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
self,
guided_decoding: Optional[GuidedDecodingParams],
default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
@@ -281,15 +279,6 @@ def to_sampling_params(
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.top_logprobs

# We now allow logprobs being true without top_logrobs.
logits_processors = get_logits_processors(
logit_bias=self.logit_bias,
allowed_token_ids=None,
tokenizer=tokenizer,
)
if guided_decode_logits_processor:
logits_processors.append(guided_decode_logits_processor)

return SamplingParams.from_optional(
n=self.n,
best_of=self.best_of,
@@ -314,8 +303,9 @@ def to_sampling_params(
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
guided_decoding=guided_decoding,
logit_bias=self.logit_bias
)

@model_validator(mode="before")
@@ -512,8 +502,8 @@ class CompletionRequest(OpenAIBaseModel):
# doc: end-completion-extra-params

def to_sampling_params(
self, tokenizer: AnyTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
self,
guided_decoding: Optional[GuidedDecodingParams],
default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
@@ -525,14 +515,6 @@ def to_sampling_params(

echo_without_generation = self.echo and self.max_tokens == 0

logits_processors = get_logits_processors(
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids,
tokenizer=tokenizer,
)
if guided_decode_logits_processor:
logits_processors.append(guided_decode_logits_processor)

return SamplingParams.from_optional(
n=self.n,
best_of=self.best_of,
@@ -557,8 +539,10 @@ def to_sampling_params(
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
guided_decoding=guided_decoding,
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids
)

@model_validator(mode="before")
36 changes: 31 additions & 5 deletions vllm/entrypoints/openai/serving_chat.py
@@ -156,9 +156,13 @@ async def create_chat_completion(

request_id = f"chat-{random_uuid()}"
try:
guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))

guided_decoding_params = \
self._create_guided_decoding_params(request)
# Some requests for tools will use guided decoding
if (guided_json :=
self._get_guided_json_from_tool(request)) is not None:
guided_decoding_params.guided_json = guided_json

Review comment (Member): Would it make sense to move these calls into (Chat)CompletionRequest.to_sampling_params?

Reply (Contributor Author): Yeah, you're right 👍 (A toy sketch of this suggestion follows this file's diff, below.)

if isinstance(prompt, str):
prompt_inputs = self._tokenize_prompt_input(
request,
@@ -177,8 +181,7 @@ async def create_chat_completion(
assert prompt_inputs is not None

sampling_params = request.to_sampling_params(
tokenizer,
guided_decode_logits_processor,
guided_decoding_params,
default_max_tokens=self.max_model_len -
len(prompt_inputs["prompt_token_ids"]))

@@ -779,3 +782,26 @@ def _should_check_for_unstreamed_tool_arg_tokens(
and delta_message.tool_calls[0].function.arguments is not None
and output.finish_reason is not None
)

@staticmethod
def _get_guided_json_from_tool(
request: ChatCompletionRequest
) -> Optional[Union[str, dict, BaseModel]]:
# user has chosen to not use any tool
if request.tool_choice == "none" or request.tools is None:
return None

# user has chosen to use a named tool
if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam:
tool_name = request.tool_choice.function.name
tools = {
tool.function.name: tool.function
for tool in request.tools
}
if tool_name not in tools:
raise ValueError(
f"Tool '{tool_name}' has not been passed in `tools`.")
tool = tools[tool_name]
return tool.parameters

return None
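
Regarding the review suggestion above about folding the tool-choice handling into `to_sampling_params`, here is a toy, self-contained sketch of the idea. All classes are simplified stand-ins rather than vLLM's actual request and param types, and this is not necessarily how the final PR implements it.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class ToolFunction:
    name: str
    parameters: dict


@dataclass
class GuidedParams:  # stand-in for GuidedDecodingParams
    json: Optional[dict] = None


@dataclass
class ChatRequest:  # stand-in for ChatCompletionRequest
    tools: List[ToolFunction] = field(default_factory=list)
    tool_choice: Optional[str] = None  # a tool name, "none", or None

    def _guided_json_from_tool(self) -> Optional[dict]:
        # No tool requested: nothing to constrain.
        if self.tool_choice in (None, "none"):
            return None
        by_name: Dict[str, ToolFunction] = {t.name: t for t in self.tools}
        if self.tool_choice not in by_name:
            raise ValueError(
                f"Tool '{self.tool_choice}' has not been passed in `tools`.")
        return by_name[self.tool_choice].parameters

    def to_sampling_params(self, guided: GuidedParams) -> GuidedParams:
        # A named tool constrains generation to that tool's JSON schema.
        if (schema := self._guided_json_from_tool()) is not None:
            guided.json = schema
        return guided


req = ChatRequest(tools=[ToolFunction("get_weather", {"type": "object"})],
                  tool_choice="get_weather")
print(req.to_sampling_params(GuidedParams()).json)  # {'type': 'object'}
```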
8 changes: 4 additions & 4 deletions vllm/entrypoints/openai/serving_completion.py
@@ -98,8 +98,9 @@ async def create_completion(
tokenizer = await self.async_engine_client.get_tokenizer(
lora_request)

guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))
guided_decoding_params = self._create_guided_decoding_params(
request
)
prompts = list(
self._tokenize_prompt_input_or_inputs(
request,
@@ -111,8 +112,7 @@

for i, prompt_inputs in enumerate(prompts):
sampling_params = request.to_sampling_params(
tokenizer,
guided_decode_logits_processor,
guided_decoding_params,
default_max_tokens=self.max_model_len -
len(prompt_inputs["prompt_token_ids"]))

34 changes: 22 additions & 12 deletions vllm/entrypoints/openai/serving_engine.py
@@ -27,11 +27,9 @@
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import AtomicCounter
@@ -156,15 +154,6 @@ def create_streaming_error_response(
})
return json_str

async def _guided_decode_logits_processor(
self, request: Union[ChatCompletionRequest, CompletionRequest],
tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]:
decoding_config = await self.async_engine_client.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
return await get_guided_decoding_logits_processor(
guided_decoding_backend, request, tokenizer)

async def _check_model(
self,
request: AnyRequest,
@@ -480,3 +469,24 @@ async def unload_lora_adapter(
if lora_request.lora_name != lora_name
]
return f"Success: LoRA adapter '{lora_name}' removed successfully."

@staticmethod
def _create_guided_decoding_params(
api_request: Union[CompletionRequest, ChatCompletionRequest]
) -> GuidedDecodingParams:
"""Extract all of the guided decoding parameters from a frontend api
request"""
guided_json_object = None
if (api_request.response_format is not None
and api_request.response_format.type == "json_object"):
guided_json_object = True

return GuidedDecodingParams(
json=api_request.guided_json,
choice=api_request.guided_choice,
backend=api_request.guided_decoding_backend,
grammar=api_request.guided_grammar,
regex=api_request.guided_regex,
whitespace_pattern=api_request.guided_whitespace_pattern,
json_object=guided_json_object)

65 changes: 16 additions & 49 deletions vllm/model_executor/guided_decoding/__init__.py
@@ -1,77 +1,44 @@
from typing import Optional, Union
from typing import Optional

from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.sampling_params import LogitsProcessor
from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor


async def get_guided_decoding_logits_processor(
guided_decoding_backend: str, request: Union[CompletionRequest,
ChatCompletionRequest],
guided_params: GuidedDecodingParams,
tokenizer) -> Optional[LogitsProcessor]:
request = _adapt_request_for_tool_use(request)

if guided_decoding_backend == 'outlines':
if guided_params.backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
guided_params, tokenizer)
if guided_params.backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_lm_format_enforcer_guided_decoding_logits_processor)
return await get_lm_format_enforcer_guided_decoding_logits_processor(
request, tokenizer)
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)

Review comment (Member): I think the difference here is that in the case of request.guided_grammar, the behavior is to fall back to "outlines", and in that case it uses the async version of get_outlines_guided_decoding_logits_processor? But the de-duplication is good. Maybe we could just add `or request.guided_grammar` to the `if guided_params.backend == 'outlines'` check above (with an appropriate comment)?

Reply (Contributor Author): Ah, right, we would want to use the async outlines construction in that case 🤔 I'll toss that in with a comment, good catch! (A sketch of this tweak follows this file's diff, below.)

raise ValueError(
f"Unknown guided decoding backend '{guided_decoding_backend}'. "
f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'")


def get_local_guided_decoding_logits_processor(
guided_decoding_backend: str, guided_options: GuidedDecodingRequest,
guided_params: GuidedDecodingParams,
tokenizer) -> Optional[LogitsProcessor]:
# request = _adapt_request_for_tool_use(request)

if guided_decoding_backend == 'outlines':
if guided_params.backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor(
guided_options, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
guided_params, tokenizer)
if guided_params.backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_options, tokenizer)
guided_params, tokenizer)

raise ValueError(
f"Unknown guided decoding backend '{guided_decoding_backend}'. "
f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'")


def _adapt_request_for_tool_use(request: Union[CompletionRequest,
ChatCompletionRequest]):
# the legacy completion API does not support tool use
if type(request) is CompletionRequest:
return request

# user has chosen to not use any tool,
# OR is allowing the model to choose a tool.
if request.tool_choice == "none" or request.tool_choice == "auto":
return request

# user has chosen to use a named tool
if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam:
tool_name = request.tool_choice.function.name
tools = {tool.function.name: tool.function for tool in request.tools}
if tool_name not in tools:
raise ValueError(
f"Tool '{tool_name}' has not been passed in `tools`.")
tool = tools[tool_name]
request.guided_json = tool.parameters

return request
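
On the grammar discussion above, here is a sketch of the suggested tweak to `get_guided_decoding_logits_processor`: route grammar requests through the async outlines path as well. It assumes, per the discussion, that grammar handling falls back to outlines; it is not necessarily the exact code that was merged.

```python
from typing import Optional

from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor


async def get_guided_decoding_logits_processor(
        guided_params: GuidedDecodingParams,
        tokenizer) -> Optional[LogitsProcessor]:
    # Grammar-guided requests fall back to outlines, so they must also use
    # the async outlines construction path.
    if guided_params.backend == 'outlines' or guided_params.grammar:
        # NOTE: lazy import outlines to avoid import-time issues
        from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
            get_outlines_guided_decoding_logits_processor)
        return await get_outlines_guided_decoding_logits_processor(
            guided_params, tokenizer)
    if guided_params.backend == 'lm-format-enforcer':
        from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
            get_local_lm_format_enforcer_guided_decoding_logits_processor)
        return get_local_lm_format_enforcer_guided_decoding_logits_processor(
            guided_params, tokenizer)
    raise ValueError(
        f"Unknown guided decoding backend '{guided_params.backend}'. "
        "Must be one of 'outlines', 'lm-format-enforcer'")
```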
1 change: 1 addition & 0 deletions vllm/model_executor/guided_decoding/guided_fields.py
@@ -3,6 +3,7 @@

from pydantic import BaseModel

# Nick leans towards ripping out so we don't have duplication

Review comment (Member): At a minimum we should deprecate for a version.

Reply (Contributor Author): Yup, I'll stick a warning on! (A sketch of such a warning follows this diff, below.)

class LLMGuidedOptions(TypedDict, total=False):
guided_json: Union[Dict, BaseModel, str]
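
On the deprecation point above, here is a minimal sketch of the kind of warning that could be attached to the legacy `guided_fields` path; the name, wording, and placement are illustrative rather than the code merged in this PR.

```python
import warnings


def warn_guided_fields_deprecated() -> None:
    # Emitted when the legacy GuidedDecodingRequest / LLMGuidedOptions path is
    # used; callers should migrate to SamplingParams.guided_decoding.
    warnings.warn(
        "Passing guided decoding options via GuidedDecodingRequest / "
        "LLMGuidedOptions is deprecated; use GuidedDecodingParams on "
        "SamplingParams.guided_decoding instead.",
        DeprecationWarning,
        stacklevel=2,
    )
```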