
[Frontend][Core] Move guided decoding params into sampling params #8252

Merged: 39 commits, merged on Oct 1, 2024
Commits (39 total; changes shown from 37 commits)
2673021
:recycle: refactor guided decoding construction
joerunde Sep 6, 2024
382a240
:recycle: refactor to_sampling_params
joerunde Sep 6, 2024
ff6c147
:recycle: update sampling params in openai servers
joerunde Sep 6, 2024
ab0fea0
:zap: build LPs in engine / client
joerunde Sep 9, 2024
41a18d5
:recycle: move test file
joerunde Sep 9, 2024
40260eb
:white_check_mark: fixup tests
joerunde Sep 9, 2024
9d3b185
:wastebasket: deprecate GuidedDecodingRequest usage
joerunde Sep 10, 2024
610423e
Merge remote-tracking branch 'upstream/main' into lp-scratch
joerunde Sep 10, 2024
9b6ab55
:goal_net: disallow both guided options
joerunde Sep 10, 2024
429704d
:bug: fixup LLM tests
joerunde Sep 10, 2024
2df4783
:art: fmt
joerunde Sep 10, 2024
97a1116
:art: more fmt
joerunde Sep 10, 2024
24d95fe
:zap: ensure use of async outlines LP construction
joerunde Sep 10, 2024
0fa5080
:recycle: move guided params construction
joerunde Sep 10, 2024
2920459
Merge remote-tracking branch 'upstream/main' into lp-scratch
joerunde Sep 12, 2024
450c76e
:art: fmt
joerunde Sep 12, 2024
6bfa8a8
Merge remote-tracking branch 'upstream/main' into lp-scratch
joerunde Sep 16, 2024
07a17ef
Merge remote-tracking branch 'upstream/main' into lp-scratch
joerunde Sep 17, 2024
d4540ca
:recycle: refactor mistral unsupported errors
joerunde Sep 17, 2024
1d580d2
:white_check_mark: test guided decoding validation
joerunde Sep 17, 2024
b984187
:art: fmt
joerunde Sep 17, 2024
1c2bbf1
:bug: fixup engine client test
joerunde Sep 17, 2024
aa84827
:bug: start to fixup for msgspec
joerunde Sep 18, 2024
c52075b
Merge remote-tracking branch 'upstream/main' into lp-scratch
joerunde Sep 18, 2024
891c526
:zap: add LP construction to mqllmengine client
joerunde Sep 18, 2024
2b59d03
Merge remote-tracking branch 'upstream/main' into lp-scratch
joerunde Sep 25, 2024
8f079e5
:art: fmt
joerunde Sep 25, 2024
ed29ec2
:recycle: extract BaseModel schema up-front
joerunde Sep 25, 2024
bea1716
Merge remote-tracking branch 'upstream/main' into lp-scratch
joerunde Sep 25, 2024
74e9187
Merge remote-tracking branch 'upstream/main' into lp-scratch
joerunde Sep 26, 2024
7b17aba
Merge remote-tracking branch 'upstream/main' into lp-scratch
joerunde Sep 26, 2024
76182e4
Merge remote-tracking branch 'upstream/main' into lp-scratch
joerunde Sep 27, 2024
6281044
:recycle: v -> s
joerunde Sep 27, 2024
a4e2523
Update vllm/engine/async_llm_engine.py
joerunde Sep 27, 2024
8acc98e
Update vllm/engine/llm_engine.py
joerunde Sep 27, 2024
485ce1d
Revert "Update vllm/engine/async_llm_engine.py"
joerunde Sep 30, 2024
f0a3f9d
:recycle: refactor guided decoding lp construction
joerunde Sep 30, 2024
ec92fd5
Merge remote-tracking branch 'upstream/main' into lp-scratch
joerunde Sep 30, 2024
5385175
retrigger CI
joerunde Sep 30, 2024
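
For reviewers skimming the diff below, the net effect of this PR is that guided decoding options move out of the separate guided_options_request argument of LLM.generate and into SamplingParams, carried by the new GuidedDecodingParams class. A minimal usage sketch of the new API (the model name here is illustrative, not something this PR prescribes):

from vllm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

llm = LLM(model="HuggingFaceH4/zephyr-7b-beta")  # illustrative model choice

# Guided decoding now rides along inside SamplingParams.
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    guided_decoding=GuidedDecodingParams(choice=["Positive", "Negative"]),
)
outputs = llm.generate(
    prompts="Classify the sentiment of: vLLM is wonderful!",
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)  # constrained to one of the two choices

The old guided_options_request keyword still works for now but emits a DeprecationWarning, and combining it with SamplingParams.guided_decoding raises a ValueError, as the updated tests below check.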
66 changes: 43 additions & 23 deletions tests/entrypoints/llm/test_guided_generate.py
@@ -7,7 +7,7 @@

from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

from ...conftest import cleanup

@@ -31,14 +31,12 @@ def test_guided_regex(sample_regex, llm):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
)
outputs = llm.generate(
prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_regex=sample_regex))
guided_decoding=GuidedDecodingParams(regex=sample_regex))
outputs = llm.generate(prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)

assert outputs is not None
for output in outputs:
@@ -57,15 +55,13 @@ def test_guided_json_completion(sample_json_schema, llm):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
)
outputs = llm.generate(
prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_json=sample_json_schema))
guided_decoding=GuidedDecodingParams(json=sample_json_schema))
outputs = llm.generate(prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)

assert outputs is not None

@@ -86,12 +82,11 @@ def test_guided_choice_completion(sample_guided_choice, llm):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
)
guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
outputs = llm.generate(
prompts="The best language for type-safe systems programming is ",
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_choice=sample_guided_choice))
use_tqdm=True)

assert outputs is not None
for output in outputs:
@@ -112,13 +107,13 @@ def test_guided_grammar(sample_sql_statements, llm):
temperature=0.8,
top_p=0.95,
max_tokens=1000,
)
guided_decoding=GuidedDecodingParams(grammar=sample_sql_statements))
outputs = llm.generate(
prompts=("Generate a sql state that select col_1 from "
"table_1 where it is equals to 1"),
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_grammar=sample_sql_statements))
)

assert outputs is not None
for output in outputs:
@@ -140,3 +135,28 @@ def test_guided_grammar(sample_sql_statements, llm):
assert generated_text.strip() == ground_truth

print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
def test_guided_options_request_deprecation_warning(sample_regex, llm):
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

with pytest.warns(DeprecationWarning, match="guided_options_request"):
llm.generate(prompts="This should fail",
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_regex=sample_regex))


@pytest.mark.skip_global_cleanup
def test_validation_against_both_guided_decoding_options(sample_regex, llm):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(regex=sample_regex))

with pytest.raises(ValueError, match="Cannot set both"):
llm.generate(prompts="This should fail",
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_regex=sample_regex))
49 changes: 49 additions & 0 deletions tests/model_executor/conftest.py
@@ -0,0 +1,49 @@
import pytest


@pytest.fixture
def sample_regex():
return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")


@pytest.fixture
def sample_json_schema():
return {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"age": {
"type": "integer"
},
"skills": {
"type": "array",
"items": {
"type": "string",
"maxLength": 10
},
"minItems": 3
},
"work_history": {
"type": "array",
"items": {
"type": "object",
"properties": {
"company": {
"type": "string"
},
"duration": {
"type": "number"
},
"position": {
"type": "string"
}
},
"required": ["company", "position"]
}
}
},
"required": ["name", "age", "skills", "work_history"]
}
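
For orientation, the sample_json_schema fixture above requires name, age, skills (at least three strings of at most 10 characters each) and work_history (objects that must carry company and position). A document that satisfies it, purely as an illustration and not part of the test suite, could be:

valid_profile = {
    "name": "Ada Lovelace",
    "age": 36,
    "skills": ["math", "analysis", "writing"],
    "work_history": [
        {
            "company": "Analytical Engines Ltd",
            "position": "Engineer",
            "duration": 9.5,
        },
    ],
}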
@@ -1,14 +1,12 @@
# This unit test should be moved to a new
# tests/test_guided_decoding directory.
import pytest
import torch
from transformers import AutoTokenizer

from vllm.entrypoints.openai.protocol import CompletionRequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.sampling_params import GuidedDecodingParams


def test_guided_logits_processors(sample_regex, sample_json_schema):
@@ -44,11 +42,9 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
token_ids = tokenizer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
regex_request = CompletionRequest(model='test',
prompt=token_ids,
guided_regex=sample_regex)
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
regex_lp = await get_guided_decoding_logits_processor(
backend, regex_request, tokenizer)
regex_request, tokenizer)
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
@@ -59,14 +55,31 @@
token_ids = tokenizer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
json_request = CompletionRequest(model='test',
prompt=token_ids,
guided_json=sample_json_schema)
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = await get_guided_decoding_logits_processor(
backend, json_request, tokenizer)
json_request, tokenizer)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)


def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):
with pytest.raises(ValueError,
match="You can only use one kind of guided"):
GuidedDecodingParams(json=sample_json_schema, regex=sample_regex)

with pytest.raises(ValueError,
match="You can only use one kind of guided"):
GuidedDecodingParams(json=sample_json_schema, json_object=True)

with pytest.raises(ValueError,
match="You can only use one kind of guided"):
GuidedDecodingParams(json=sample_json_schema, choice=["a", "b"])

with pytest.raises(ValueError,
match="You can only use one kind of guided"):
GuidedDecodingParams(json=sample_json_schema, grammar="test grammar")
44 changes: 44 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -20,6 +20,8 @@
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
@@ -464,6 +466,18 @@ async def add_request_async(
)
processed_inputs = self.input_processor(preprocessed_inputs)

if isinstance(params, SamplingParams) and \
params.guided_decoding is not None:
# Guided decoding has an async implementation for building logits
# processors in a separate threadpool.
# We want to invoke that here instead of using the blocking
# implementation in the LLMEngine
params = await build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=self.get_tokenizer(lora_request),
default_guided_backend=self.decoding_config.
guided_decoding_backend)

self._add_processed_request(
request_id=request_id,
processed_inputs=processed_inputs,
@@ -480,6 +494,36 @@ async def check_health_async(self) -> None:
self.model_executor.check_health()


async def build_guided_decoding_logits_processor_async(
sampling_params: SamplingParams, tokenizer: AnyTokenizer,
default_guided_backend: str) -> SamplingParams:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
those fields and adds the constructed logits processors to the
logits_processors field. Modifies sampling params in-place and returns
the modified sampling params."""
if (guided_decoding := sampling_params.guided_decoding) is None:
return sampling_params

logger.debug("Building guided decoding logits processor. "
"Params: %s", guided_decoding)

guided_decoding.backend = guided_decoding.backend or default_guided_backend

processor = await get_guided_decoding_logits_processor(
guided_params=guided_decoding, tokenizer=tokenizer)

if processor:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.append(processor)

# Unset guided decoding params after constructing the lp from them
sampling_params.guided_decoding = None

return sampling_params


class AsyncLLMEngine:
"""An asynchronous wrapper for :class:`LLMEngine`.

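
A rough sketch of driving the new async helper above in isolation; the tokenizer, model name, and the "outlines" backend choice are assumptions for illustration, and in normal operation add_request_async calls this for you:

import asyncio

from transformers import AutoTokenizer

from vllm.engine.async_llm_engine import (
    build_guided_decoding_logits_processor_async)
from vllm.sampling_params import GuidedDecodingParams, SamplingParams


async def main() -> None:
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    params = SamplingParams(
        temperature=0.0,
        guided_decoding=GuidedDecodingParams(choice=["yes", "no"]),
    )
    # Builds the guided-decoding logits processor off the main event loop,
    # appends it to params.logits_processors, and clears params.guided_decoding.
    params = await build_guided_decoding_logits_processor_async(
        sampling_params=params,
        tokenizer=tokenizer,
        default_guided_backend="outlines",
    )
    assert params.guided_decoding is None
    assert params.logits_processors


asyncio.run(main())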
54 changes: 54 additions & 0 deletions vllm/engine/llm_engine.py
@@ -25,6 +25,7 @@
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
@@ -33,6 +34,8 @@
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory)
@@ -833,6 +836,9 @@ def _create_sequence_group_with_sampling(
raise ValueError(f"Cannot request more than "
f"{max_logprobs} logprobs.")

sampling_params = self._build_logits_processors(
sampling_params, lora_request)

# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone()
@@ -1799,3 +1805,51 @@ def _validate_model_inputs(self, inputs: Union[LLMInputs,
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens

def _build_logits_processors(
self, sampling_params: SamplingParams,
lora_request: Optional[LoRARequest]) -> SamplingParams:
"""Constructs logits processors based on the guided_decoding,
logit_bias, and allowed_token_ids fields in sampling_params. Deletes
those fields and adds the constructed logits processors to the
logits_processors field. Returns the modified sampling params."""

logits_processors = []
if (guided_decoding := sampling_params.guided_decoding) is not None:

logger.debug(
"Building guided decoding logits processor in "
"LLMEngine. Params: %s", guided_decoding)

tokenizer = self.get_tokenizer(lora_request=lora_request)
guided_decoding.backend = guided_decoding.backend or \
self.decoding_config.guided_decoding_backend

processor = get_local_guided_decoding_logits_processor(
guided_params=guided_decoding, tokenizer=tokenizer)
if processor:
logits_processors.append(processor)

# Unset so this doesn't get passed down to the model
sampling_params.guided_decoding = None

if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
tokenizer = self.get_tokenizer(lora_request=lora_request)

processors = get_logits_processors(
logit_bias=sampling_params.logit_bias,
allowed_token_ids=sampling_params.allowed_token_ids,
tokenizer=tokenizer)
logits_processors.extend(processors)

# Unset so these don't get passed down to the model
sampling_params.logit_bias = None
sampling_params.allowed_token_ids = None

if logits_processors:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = logits_processors
else:
sampling_params.logits_processors.extend(logits_processors)

return sampling_params
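
The logit_bias and allowed_token_ids handling above reuses the existing OpenAI-entrypoint helper. A small sketch of that helper on its own (the tokenizer is illustrative, and whether logit_bias keys are passed as ints or strings follows the OpenAI-protocol handling, so treat the key type here as an assumption):

from transformers import AutoTokenizer

from vllm.entrypoints.openai.logits_processors import get_logits_processors

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

# Turns the declarative fields into callable logits processors, much as
# LLMEngine._build_logits_processors does before handing params to the sampler.
processors = get_logits_processors(
    logit_bias={42: 5.0},      # nudge one token id upward
    allowed_token_ids=None,
    tokenizer=tokenizer,
)
print(len(processors))  # a processor produced for the logit bias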