[Frontend][Core] Move guided decoding params into sampling params #8252
Changes from 3 commits
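For context on what the refactor is aiming at: guided-decoding options are meant to travel inside SamplingParams (as a GuidedDecodingParams object) rather than as separate request-level arguments. The sketch below shows roughly how that could look from the offline API; only GuidedDecodingParams and its `backend` attribute are visible in this diff, so the `guided_decoding` keyword and the `json=` style fields are assumptions for illustration.

```python
# Rough usage sketch (assumed field names, not confirmed by this diff):
# guided-decoding options ride along inside SamplingParams.
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}},
    "required": ["name"],
}

params = SamplingParams(
    max_tokens=64,
    # Assumed keyword: the guided options object introduced by this PR.
    guided_decoding=GuidedDecodingParams(json=schema, backend="outlines"),
)

llm = LLM(model="facebook/opt-125m")
out = llm.generate(["Return a JSON object with a name field:"], params)
print(out[0].outputs[0].text)
```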
@@ -1,77 +1,44 @@
-from typing import Optional, Union
+from typing import Optional
 
-from vllm.entrypoints.openai.protocol import (
-    ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
-    CompletionRequest)
-from vllm.model_executor.guided_decoding.guided_fields import (
-    GuidedDecodingRequest)
-from vllm.sampling_params import LogitsProcessor
+from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor
 
 
 async def get_guided_decoding_logits_processor(
-        guided_decoding_backend: str, request: Union[CompletionRequest,
-                                                      ChatCompletionRequest],
+        guided_params: GuidedDecodingParams,
         tokenizer) -> Optional[LogitsProcessor]:
-    request = _adapt_request_for_tool_use(request)
-
-    if guided_decoding_backend == 'outlines':
+    if guided_params.backend == 'outlines':
         # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
         from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
             get_outlines_guided_decoding_logits_processor)
         return await get_outlines_guided_decoding_logits_processor(
-            request, tokenizer)
-    if guided_decoding_backend == 'lm-format-enforcer':
+            guided_params, tokenizer)
+    if guided_params.backend == 'lm-format-enforcer':
         from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
-            get_lm_format_enforcer_guided_decoding_logits_processor)
-        return await get_lm_format_enforcer_guided_decoding_logits_processor(
-            request, tokenizer)
+            get_local_lm_format_enforcer_guided_decoding_logits_processor)
+        return get_local_lm_format_enforcer_guided_decoding_logits_processor(
+            guided_params, tokenizer)
 
     raise ValueError(
-        f"Unknown guided decoding backend '{guided_decoding_backend}'. "
+        f"Unknown guided decoding backend '{guided_params.backend}'. "
         "Must be one of 'outlines, 'lm-format-enforcer'")
 
 
 def get_local_guided_decoding_logits_processor(
-        guided_decoding_backend: str, guided_options: GuidedDecodingRequest,
+        guided_params: GuidedDecodingParams,
         tokenizer) -> Optional[LogitsProcessor]:
     # request = _adapt_request_for_tool_use(request)
 
-    if guided_decoding_backend == 'outlines':
+    if guided_params.backend == 'outlines':
         # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
         from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
             get_local_outlines_guided_decoding_logits_processor)
         return get_local_outlines_guided_decoding_logits_processor(
-            guided_options, tokenizer)
-    if guided_decoding_backend == 'lm-format-enforcer':
+            guided_params, tokenizer)
+    if guided_params.backend == 'lm-format-enforcer':
        from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
             get_local_lm_format_enforcer_guided_decoding_logits_processor)
         return get_local_lm_format_enforcer_guided_decoding_logits_processor(
-            guided_options, tokenizer)
+            guided_params, tokenizer)
 
     raise ValueError(
-        f"Unknown guided decoding backend '{guided_decoding_backend}'. "
+        f"Unknown guided decoding backend '{guided_params.backend}'. "
         "Must be one of 'outlines, 'lm-format-enforcer'")
-
-
-def _adapt_request_for_tool_use(request: Union[CompletionRequest,
-                                               ChatCompletionRequest]):
-    # the legacy completion API does not support tool use
-    if type(request) is CompletionRequest:
-        return request
-
-    # user has chosen to not use any tool,
-    # OR is allowing the model to choose a tool.
-    if request.tool_choice == "none" or request.tool_choice == "auto":
-        return request
-
-    # user has chosen to use a named tool
-    if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam:
-        tool_name = request.tool_choice.function.name
-        tools = {tool.function.name: tool.function for tool in request.tools}
-        if tool_name not in tools:
-            raise ValueError(
-                f"Tool '{tool_name}' has not been passed in `tools`.")
-        tool = tools[tool_name]
-        request.guided_json = tool.parameters
-
-    return request

Review comment (on the `lm-format-enforcer` branch, where the awaited call becomes a direct call to `get_local_lm_format_enforcer_guided_decoding_logits_processor`): I think the difference here is that in the case of …. But the de-duplication is good .. maybe we could just add …

Reply: ah, right yeah we would want to use the async outlines construction in that case 🤔
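On the async point raised in that thread: the concern is that building the outlines processor can be slow, so the server-side path wants to construct it off the event loop, while the local path can construct it inline. Below is a minimal sketch of that idea using asyncio's executor offloading; `build_outlines_processor` is a hypothetical stand-in, not the actual vLLM helper.

```python
# Sketch only: shows how blocking logits-processor construction could be kept
# off the event loop in the async path. `build_outlines_processor` is a
# hypothetical placeholder for the real construction call.
import asyncio
from concurrent.futures import ThreadPoolExecutor

_executor = ThreadPoolExecutor(max_workers=2)


def build_outlines_processor(guided_params, tokenizer):
    # Placeholder for the potentially slow FSM/regex compilation that
    # outlines performs when building a guided-decoding logits processor.
    raise NotImplementedError


async def get_processor_async(guided_params, tokenizer):
    loop = asyncio.get_running_loop()
    # Run the blocking construction in a worker thread and await the result,
    # so other requests keep being served while the processor is built.
    return await loop.run_in_executor(
        _executor, build_outlines_processor, guided_params, tokenizer)
```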
@@ -3,6 +3,7 @@
 
 from pydantic import BaseModel
 
+# Nick leans towards ripping out so we don't have duplication
 
 class LLMGuidedOptions(TypedDict, total=False):
     guided_json: Union[Dict, BaseModel, str]

Review comment (on the added "ripping out" note): At a minimum we should deprecate for a version.

Reply: Yup, I'll stick a warning on!
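A minimal sketch of the kind of warning mentioned above, assuming the legacy guided_options / LLMGuidedOptions path is kept for one release before removal; the message text and the name `guided_options_request` are illustrative assumptions, not the actual change.

```python
import warnings


def _warn_guided_options_deprecated() -> None:
    # Hypothetical helper: warn when the legacy guided_options path is used
    # instead of passing GuidedDecodingParams on SamplingParams.
    warnings.warn(
        "guided_options_request is deprecated and will be removed in a "
        "future release; use SamplingParams.guided_decoding instead.",
        DeprecationWarning,
        stacklevel=3,
    )
```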
Review comment: Would it make sense to move these calls into `(Chat)CompletionRequest.to_sampling_params`?

Reply: Yeah, you're right 👍
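Following that suggestion, here is a rough sketch of what folding the mapping into the request's `to_sampling_params` conversion could look like. The request field names (`guided_json`, `guided_regex`, `guided_choice`, `guided_decoding_backend`) and the `guided_decoding` keyword on SamplingParams are assumptions for illustration, not necessarily what this PR implements.

```python
# Sketch: build GuidedDecodingParams inside to_sampling_params so everything
# downstream only deals with SamplingParams. Field names are assumed.
from typing import Optional

from vllm.sampling_params import GuidedDecodingParams, SamplingParams


class CompletionRequestSketch:
    def __init__(self, prompt: str, guided_json=None, guided_regex=None,
                 guided_choice=None, guided_decoding_backend: str = "outlines",
                 max_tokens: int = 128):
        self.prompt = prompt
        self.guided_json = guided_json
        self.guided_regex = guided_regex
        self.guided_choice = guided_choice
        self.guided_decoding_backend = guided_decoding_backend
        self.max_tokens = max_tokens

    def to_sampling_params(self) -> SamplingParams:
        # Map the request-level guided_* fields into a single params object.
        guided: Optional[GuidedDecodingParams] = None
        if self.guided_json or self.guided_regex or self.guided_choice:
            guided = GuidedDecodingParams(
                json=self.guided_json,
                regex=self.guided_regex,
                choice=self.guided_choice,
                backend=self.guided_decoding_backend,
            )
        return SamplingParams(
            max_tokens=self.max_tokens,
            guided_decoding=guided,
        )
```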