Skip to content

Commit

Permalink
Removes duplicate outlines processors
Browse files (browse the repository at this point in the history)
  • Loading branch information
lynkz-matt-psaltis committed Jul 29, 2024
1 parent 7cbd9ec commit 8d9b6de
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 214 deletions.
14 changes: 6 additions & 8 deletions tests/entrypoints/openai/test_guided_processors.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,15 +5,13 @@
from transformers import AutoTokenizer

from vllm.entrypoints.openai.protocol import CompletionRequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
from outlines.processors import JSONLogitsProcessor, RegexLogitsProcessor


def test_guided_logits_processors(sample_regex, sample_json_schema):
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta", )
regex_LP = RegexLogitsProcessor(sample_regex, tokenizer)
json_LP = JSONLogitsProcessor(sample_json_schema,
tokenizer,
Expand Down Expand Up @@ -41,10 +39,10 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"])
async def test_guided_logits_processor_black_box(backend: str, sample_regex,
sample_json_schema):
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
token_ids = tokenizer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
regex_request = CompletionRequest(model='test',
regex_request = CompletionRequest(model="test",
prompt=token_ids,
guided_regex=sample_regex)
regex_lp = await get_guided_decoding_logits_processor(
Expand All @@ -59,7 +57,7 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
token_ids = tokenizer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
json_request = CompletionRequest(model='test',
json_request = CompletionRequest(model="test",
prompt=token_ids,
guided_json=sample_json_schema)
json_lp = await get_guided_decoding_logits_processor(
Expand Down
31 changes: 20 additions & 11 deletions vllm/model_executor/guided_decoding/outlines_decoding.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -8,10 +8,12 @@
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest
from outlines.processors import (
CFGLogitsProcessor,
JSONLogitsProcessor,
RegexLogitsProcessor,
)


class GuidedDecodingMode(Enum):
Expand Down Expand Up @@ -52,8 +54,8 @@ class GuidedDecodingMode(Enum):


async def get_outlines_guided_decoding_logits_processor(
request: Union[CompletionRequest,
ChatCompletionRequest], tokenizer: PreTrainedTokenizerBase
request: Union[CompletionRequest, ChatCompletionRequest],
tokenizer: PreTrainedTokenizerBase,
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""
Expand All @@ -72,9 +74,14 @@ async def get_outlines_guided_decoding_logits_processor(
max_workers=2)
loop = asyncio.get_running_loop()

return await loop.run_in_executor(global_thread_pool,
_get_logits_processor, guide, tokenizer,
mode, request.guided_whitespace_pattern)
return await loop.run_in_executor(
global_thread_pool,
_get_logits_processor,
guide,
tokenizer,
mode,
request.guided_whitespace_pattern,
)


def _get_guide_and_mode(
Expand Down Expand Up @@ -110,8 +117,10 @@ def _get_guide_and_mode(


def _get_logits_processor(
guide: str, tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None]
guide: str,
tokenizer: PreTrainedTokenizerBase,
mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
Expand Down
195 changes: 0 additions & 195 deletions vllm/model_executor/guided_decoding/outlines_logits_processors.py

This file was deleted.

0 comments on commit 8d9b6de

Please sign in to comment.