Commit

[Frontend][Core] Move guided decoding params into sampling params (#8252)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
joerunde and njhill authored Oct 1, 2024
1 parent bce3244 commit 062c89e
Showing 16 changed files with 441 additions and 281 deletions.
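
In short, guided-decoding constraints move out of the separate guided_options_request argument of LLM.generate() and into SamplingParams itself, via the new GuidedDecodingParams class. A minimal before/after sketch based on the updated tests below (llm, prompts, and sample_regex are stand-ins for objects the tests construct):

    from vllm.sampling_params import GuidedDecodingParams, SamplingParams

    # Before this commit: the constraint rode alongside the sampling params.
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
    outputs = llm.generate(
        prompts=prompts,
        sampling_params=sampling_params,
        guided_options_request=dict(guided_regex=sample_regex))

    # After this commit: the constraint lives inside SamplingParams.
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        guided_decoding=GuidedDecodingParams(regex=sample_regex))
    outputs = llm.generate(prompts=prompts, sampling_params=sampling_params)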
66 changes: 43 additions & 23 deletions tests/entrypoints/llm/test_guided_generate.py
@@ -7,7 +7,7 @@

 from vllm.entrypoints.llm import LLM
 from vllm.outputs import RequestOutput
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import GuidedDecodingParams, SamplingParams

 from ...conftest import cleanup

@@ -31,14 +31,12 @@ def test_guided_regex(sample_regex, llm):
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
-    )
-    outputs = llm.generate(
-        prompts=[
-            f"Give an example IPv4 address with this regex: {sample_regex}"
-        ] * 2,
-        sampling_params=sampling_params,
-        use_tqdm=True,
-        guided_options_request=dict(guided_regex=sample_regex))
+        guided_decoding=GuidedDecodingParams(regex=sample_regex))
+    outputs = llm.generate(prompts=[
+        f"Give an example IPv4 address with this regex: {sample_regex}"
+    ] * 2,
+                           sampling_params=sampling_params,
+                           use_tqdm=True)

assert outputs is not None
for output in outputs:
@@ -57,15 +55,13 @@ def test_guided_json_completion(sample_json_schema, llm):
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
-    )
-    outputs = llm.generate(
-        prompts=[
-            f"Give an example JSON for an employee profile "
-            f"that fits this schema: {sample_json_schema}"
-        ] * 2,
-        sampling_params=sampling_params,
-        use_tqdm=True,
-        guided_options_request=dict(guided_json=sample_json_schema))
+        guided_decoding=GuidedDecodingParams(json=sample_json_schema))
+    outputs = llm.generate(prompts=[
+        f"Give an example JSON for an employee profile "
+        f"that fits this schema: {sample_json_schema}"
+    ] * 2,
+                           sampling_params=sampling_params,
+                           use_tqdm=True)

assert outputs is not None

@@ -86,12 +82,11 @@ def test_guided_choice_completion(sample_guided_choice, llm):
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
-    )
+        guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
    outputs = llm.generate(
        prompts="The best language for type-safe systems programming is ",
        sampling_params=sampling_params,
-        use_tqdm=True,
-        guided_options_request=dict(guided_choice=sample_guided_choice))
+        use_tqdm=True)

assert outputs is not None
for output in outputs:
@@ -112,13 +107,13 @@ def test_guided_grammar(sample_sql_statements, llm):
        temperature=0.8,
        top_p=0.95,
        max_tokens=1000,
-    )
+        guided_decoding=GuidedDecodingParams(grammar=sample_sql_statements))
    outputs = llm.generate(
        prompts=("Generate a sql state that select col_1 from "
                 "table_1 where it is equals to 1"),
        sampling_params=sampling_params,
        use_tqdm=True,
-        guided_options_request=dict(guided_grammar=sample_sql_statements))
+    )

assert outputs is not None
for output in outputs:
@@ -140,3 +135,28 @@ def test_guided_grammar(sample_sql_statements, llm):
assert generated_text.strip() == ground_truth

print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
def test_guided_options_request_deprecation_warning(sample_regex, llm):
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

with pytest.warns(DeprecationWarning, match="guided_options_request"):
llm.generate(prompts="This should fail",
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_regex=sample_regex))


@pytest.mark.skip_global_cleanup
def test_validation_against_both_guided_decoding_options(sample_regex, llm):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(regex=sample_regex))

with pytest.raises(ValueError, match="Cannot set both"):
llm.generate(prompts="This should fail",
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_regex=sample_regex))
49 changes: 49 additions & 0 deletions tests/model_executor/conftest.py
@@ -0,0 +1,49 @@
import pytest


@pytest.fixture
def sample_regex():
return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")


@pytest.fixture
def sample_json_schema():
return {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"age": {
"type": "integer"
},
"skills": {
"type": "array",
"items": {
"type": "string",
"maxLength": 10
},
"minItems": 3
},
"work_history": {
"type": "array",
"items": {
"type": "object",
"properties": {
"company": {
"type": "string"
},
"duration": {
"type": "number"
},
"position": {
"type": "string"
}
},
"required": ["company", "position"]
}
}
},
"required": ["name", "age", "skills", "work_history"]
}
@@ -1,14 +1,12 @@
# This unit test should be moved to a new
# tests/test_guided_decoding directory.
import pytest
import torch
from transformers import AutoTokenizer

-from vllm.entrypoints.openai.protocol import CompletionRequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
JSONLogitsProcessor, RegexLogitsProcessor)
+from vllm.sampling_params import GuidedDecodingParams


def test_guided_logits_processors(sample_regex, sample_json_schema):
@@ -44,11 +42,9 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
    tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
    token_ids = tokenizer.encode(
        f"Give an example IPv4 address with this regex: {sample_regex}")
-    regex_request = CompletionRequest(model='test',
-                                      prompt=token_ids,
-                                      guided_regex=sample_regex)
+    regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
    regex_lp = await get_guided_decoding_logits_processor(
-        backend, regex_request, tokenizer)
+        regex_request, tokenizer)
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
@@ -59,14 +55,31 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
    token_ids = tokenizer.encode(
        f"Give an employee profile that fits this schema: {sample_json_schema}"
    )
-    json_request = CompletionRequest(model='test',
-                                     prompt=token_ids,
-                                     guided_json=sample_json_schema)
+    json_request = GuidedDecodingParams(json=sample_json_schema,
+                                        backend=backend)
    json_lp = await get_guided_decoding_logits_processor(
-        backend, json_request, tokenizer)
+        json_request, tokenizer)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)


def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):
with pytest.raises(ValueError,
match="You can only use one kind of guided"):
GuidedDecodingParams(json=sample_json_schema, regex=sample_regex)

with pytest.raises(ValueError,
match="You can only use one kind of guided"):
GuidedDecodingParams(json=sample_json_schema, json_object=True)

with pytest.raises(ValueError,
match="You can only use one kind of guided"):
GuidedDecodingParams(json=sample_json_schema, choice=["a", "b"])

with pytest.raises(ValueError,
match="You can only use one kind of guided"):
GuidedDecodingParams(json=sample_json_schema, grammar="test grammar")
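
The test above pins down the constructor-time validation GuidedDecodingParams performs: at most one kind of constraint (json, regex, choice, grammar, json_object) may be set at once. The real check lives in vllm/sampling_params.py, which is not part of this excerpt; the following is only a rough sketch of the kind of mutual-exclusion validation the tests imply, with field names taken from the tests:

    from dataclasses import dataclass
    from typing import Any, Dict, List, Optional, Union

    @dataclass
    class GuidedDecodingParamsSketch:
        # Illustrative stand-in, not the class added by this commit.
        json: Optional[Union[str, Dict[str, Any]]] = None
        regex: Optional[str] = None
        choice: Optional[List[str]] = None
        grammar: Optional[str] = None
        json_object: Optional[bool] = None
        backend: Optional[str] = None

        def __post_init__(self):
            # Count how many guided constraints were supplied; more than one
            # kind is ambiguous, so reject it up front.
            guide_count = sum(x is not None
                              for x in (self.json, self.regex, self.choice,
                                        self.grammar, self.json_object))
            if guide_count > 1:
                raise ValueError(
                    "You can only use one kind of guided decoding, "
                    f"but got {guide_count} options set.")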
44 changes: 44 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -20,6 +20,8 @@
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
@@ -477,6 +479,18 @@ async def add_request_async(
)
processed_inputs = self.input_processor(preprocessed_inputs)

if isinstance(params, SamplingParams) and \
params.guided_decoding is not None:
# Guided decoding has an async implementation for building logits
# processors in a separate threadpool.
# We want to invoke that here instead of using the blocking
# implementation in the LLMEngine
params = await build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=self.get_tokenizer(lora_request),
default_guided_backend=self.decoding_config.
guided_decoding_backend)

self._add_processed_request(
request_id=request_id,
processed_inputs=processed_inputs,
@@ -494,6 +508,36 @@ async def check_health_async(self) -> None:
self.model_executor.check_health()


async def build_guided_decoding_logits_processor_async(
sampling_params: SamplingParams, tokenizer: AnyTokenizer,
default_guided_backend: str) -> SamplingParams:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
those fields and adds the constructed logits processors to the
logits_processors field. Modifies sampling params in-place and returns
the modified sampling params."""
if (guided_decoding := sampling_params.guided_decoding) is None:
return sampling_params

logger.debug("Building guided decoding logits processor. "
"Params: %s", guided_decoding)

guided_decoding.backend = guided_decoding.backend or default_guided_backend

processor = await get_guided_decoding_logits_processor(
guided_params=guided_decoding, tokenizer=tokenizer)

if processor:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.append(processor)

# Unset guided decoding params after constructing the lp from them
sampling_params.guided_decoding = None

return sampling_params


class AsyncLLMEngine:
"""An asynchronous wrapper for :class:`LLMEngine`.
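
Read together with add_request_async above, build_guided_decoding_logits_processor_async is the async engine's counterpart to the engine-side builder: it fills in the backend, turns guided_decoding into a logits processor, and clears the field so it never reaches the model. A rough usage sketch follows; the "outlines" default backend and the tokenizer argument are illustrative assumptions, not values taken from this diff:

    from vllm.engine.async_llm_engine import (
        build_guided_decoding_logits_processor_async)
    from vllm.sampling_params import GuidedDecodingParams, SamplingParams

    async def build_example(tokenizer):
        # `tokenizer` is assumed to be any tokenizer vLLM accepts here.
        params = SamplingParams(
            temperature=0.8,
            guided_decoding=GuidedDecodingParams(regex=r"\d+"))
        params = await build_guided_decoding_logits_processor_async(
            sampling_params=params,
            tokenizer=tokenizer,
            default_guided_backend="outlines")
        # At this point params.guided_decoding is None and
        # params.logits_processors holds the constructed regex processor.
        return params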
54 changes: 54 additions & 0 deletions vllm/engine/llm_engine.py
@@ -25,6 +25,7 @@
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
@@ -33,6 +34,8 @@
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory)
@@ -843,6 +846,9 @@ def _create_sequence_group_with_sampling(
raise ValueError(f"Cannot request more than "
f"{max_logprobs} logprobs.")

sampling_params = self._build_logits_processors(
sampling_params, lora_request)

# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone()
@@ -1895,3 +1901,51 @@ def _validate_model_inputs(self, inputs: Union[LLMInputs,
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens

def _build_logits_processors(
self, sampling_params: SamplingParams,
lora_request: Optional[LoRARequest]) -> SamplingParams:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
those fields and adds the constructed logits processors to the
logits_processors field. Returns the modified sampling params."""

logits_processors = []
if (guided_decoding := sampling_params.guided_decoding) is not None:

logger.debug(
"Building guided decoding logits processor in "
"LLMEngine. Params: %s", guided_decoding)

tokenizer = self.get_tokenizer(lora_request=lora_request)
guided_decoding.backend = guided_decoding.backend or \
self.decoding_config.guided_decoding_backend

processor = get_local_guided_decoding_logits_processor(
guided_params=guided_decoding, tokenizer=tokenizer)
if processor:
logits_processors.append(processor)

# Unset so this doesn't get passed down to the model
sampling_params.guided_decoding = None

if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
tokenizer = self.get_tokenizer(lora_request=lora_request)

processors = get_logits_processors(
logit_bias=sampling_params.logit_bias,
allowed_token_ids=sampling_params.allowed_token_ids,
tokenizer=tokenizer)
logits_processors.extend(processors)

# Unset so these don't get passed down to the model
sampling_params.logit_bias = None
sampling_params.allowed_token_ids = None

if logits_processors:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = logits_processors
else:
sampling_params.logits_processors.extend(logits_processors)

return sampling_params
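
The processors assembled by _build_logits_processors end up in sampling_params.logits_processors, where the sampler applies them as plain callables over (token_ids, logits), just as the black-box test earlier in this commit does with tensor = regex_lp(token_ids, tensor). A toy processor with the same call shape, shown only to illustrate what kind of object this list holds (not part of the commit):

    from typing import List

    import torch

    from vllm.sampling_params import SamplingParams

    def block_token_zero(token_ids: List[int],
                         logits: torch.Tensor) -> torch.Tensor:
        # Toy example: forbid token id 0 regardless of the generated prefix.
        logits[0] = float("-inf")
        return logits

    # User-supplied processors ride alongside the ones the engine builds here.
    params = SamplingParams(temperature=0.8,
                            logits_processors=[block_token_zero])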
(Diffs for the remaining changed files were not loaded.)
