@@ -3,74 +3,92 @@
import pytest
from transformers import AutoTokenizer

-from tests.entrypoints.openai.reasoning_parsers.utils import (
-    run_reasoning_extraction)
-from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
-                                                       ReasoningParserManager)
+from tests.reasoning.utils import run_reasoning_extraction
+from vllm.reasoning import ReasoningParser, ReasoningParserManager

parser_name = "deepseek_r1"
start_token = "<think>"
end_token = "</think>"

+REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+
+
+@pytest.fixture(scope="module")
+def deepseek_r1_qwen_tokenizer():
+    return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)


SIMPLE_REASONING = {
    "output": "This is a reasoning section</think>This is the rest",
    "reasoning_content": "This is a reasoning section",
    "content": "This is the rest",
+    "is_reasoning_end": True,
}
COMPLETE_REASONING = {
    "output": "This is a reasoning section</think>",
    "reasoning_content": "This is a reasoning section",
    "content": None,
+    "is_reasoning_end": True,
}
NO_CONTENT = {
    "output": "This is content",
    "reasoning_content": "This is content",
    "content": None,
+    "is_reasoning_end": False,
}
NO_REASONING_STREAMING = {
    "output": "This is a reasoning section",
    "reasoning_content": "This is a reasoning section",
    "content": None,
+    "is_reasoning_end": False,
}
MULTIPLE_LINES = {
    "output": "This\nThat</think>This is the rest\nThat",
    "reasoning_content": "This\nThat",
    "content": "This is the rest\nThat",
+    "is_reasoning_end": True,
}
SHORTEST_REASONING_NO_STREAMING = {
    "output": "</think>This is the rest",
    "reasoning_content": "",
    "content": "This is the rest",
+    "is_reasoning_end": True,
}
SHORTEST_REASONING = {
    "output": "</think>This is the rest",
    "reasoning_content": None,
    "content": "This is the rest",
+    "is_reasoning_end": True,
}
REASONING_WITH_THINK = {
    "output": "<think>This is a reasoning section</think>This is the rest",
    "reasoning_content": "This is a reasoning section",
    "content": "This is the rest",
+    "is_reasoning_end": True,
}
COMPLETE_REASONING_WITH_THINK = {
    "output": "<think>This is a reasoning section</think>",
    "reasoning_content": "This is a reasoning section",
    "content": None,
+    "is_reasoning_end": True,
}
MULTIPLE_LINES_WITH_THINK = {
    "output": "<think>This\nThat</think>This is the rest\nThat",
    "reasoning_content": "This\nThat",
    "content": "This is the rest\nThat",
+    "is_reasoning_end": True,
}
SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
    "output": "</think>This is the rest",
    "reasoning_content": "",
    "content": "This is the rest",
+    "is_reasoning_end": True,
}
SHORTEST_REASONING_WITH_THINK = {
    "output": "</think>This is the rest",
    "reasoning_content": None,
    "content": "This is the rest",
+    "is_reasoning_end": True,
}

TEST_CASES = [
@@ -166,27 +184,39 @@
),
]

-# Global tokenizer initialization to avoid repeated loading
-tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
-tokenizer.add_tokens([start_token, end_token])


@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
streaming: bool,
param_dict: dict,
+    deepseek_r1_qwen_tokenizer,
):
-    output = tokenizer.tokenize(param_dict["output"])
+    output = deepseek_r1_qwen_tokenizer.tokenize(param_dict["output"])
# decode everything to tokens
output_tokens: list[str] = [
-        tokenizer.convert_tokens_to_string([token]) for token in output
+        deepseek_r1_qwen_tokenizer.convert_tokens_to_string([token])
+        for token in output
]
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
-        parser_name)(tokenizer)
+        parser_name)(deepseek_r1_qwen_tokenizer)

reasoning, content = run_reasoning_extraction(parser,
output_tokens,
streaming=streaming)

assert reasoning == param_dict["reasoning_content"]
assert content == param_dict["content"]

+    # Test is_reasoning_end
+    output_ids = deepseek_r1_qwen_tokenizer.convert_tokens_to_ids(output)
+    is_reasoning_end = parser.is_reasoning_end(output_ids)
+    assert is_reasoning_end == param_dict["is_reasoning_end"]
+
+    # Test extract_content_ids
+    if param_dict["content"] is not None:
+        content = parser.extract_content_ids(output_ids)
+        assert content == deepseek_r1_qwen_tokenizer.convert_tokens_to_ids(
+            deepseek_r1_qwen_tokenizer.tokenize(param_dict["content"]))
+    else:
+        content = parser.extract_content_ids(output_ids)
+        assert content == []
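The two new assertion blocks exercise the token-id API that this PR adds to ReasoningParser. A minimal sketch of how those calls behave outside pytest, assuming the DeepSeek-R1 tokenizer can be downloaded (printed values are illustrative, not captured output):

from transformers import AutoTokenizer

from vllm.reasoning import ReasoningParserManager

# Resolve the registered parser class by name and bind it to a tokenizer.
tokenizer = AutoTokenizer.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
parser = ReasoningParserManager.get_reasoning_parser("deepseek_r1")(tokenizer)

ids = tokenizer.encode("Some reasoning</think>The answer",
                       add_special_tokens=False)
print(parser.is_reasoning_end(ids))     # True once </think> has appeared
print(parser.extract_content_ids(ids))  # ids of "The answer" only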
@@ -2,10 +2,8 @@
import pytest
from transformers import AutoTokenizer

-from tests.entrypoints.openai.reasoning_parsers.utils import (
-    DeltaMessage, run_reasoning_extraction)
-from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
-                                                       ReasoningParserManager)
+from tests.reasoning.utils import DeltaMessage, run_reasoning_extraction
+from vllm.reasoning import ReasoningParser, ReasoningParserManager

parser_name = "granite"
START_REASONING = "Here is my thought process:"
@@ -4,7 +4,7 @@

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
-from vllm.entrypoints.openai.reasoning_parsers import ReasoningParser
+from vllm.reasoning import ReasoningParser


class StreamingReasoningReconstructor:
3 changes: 2 additions & 1 deletion vllm/engine/arg_utils.py
@@ -23,6 +23,7 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.plugins import load_general_plugins
+from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
@@ -1099,7 +1100,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument(
"--reasoning-parser",
type=str,
-        choices=["deepseek_r1", "granite"],
+        choices=list(ReasoningParserManager.reasoning_parsers),
default=None,
help=
"Select the reasoning parser depending on the model that you're "
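Deriving the choices from the registry means any parser registered through ReasoningParserManager shows up in --reasoning-parser automatically. A quick sketch for inspecting what is registered (the exact names depend on which parsers have been registered at import time):

from vllm.reasoning import ReasoningParserManager

# The keys of this registry now back the --reasoning-parser CLI choices.
print(list(ReasoningParserManager.reasoning_parsers))
# e.g. ['deepseek_r1', 'granite']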
5 changes: 3 additions & 2 deletions vllm/engine/llm_engine.py
@@ -2084,8 +2084,9 @@ def _build_logits_processors(
guided_decoding.backend = guided_decoding.backend or \
self.decoding_config.guided_decoding_backend

-        logger.debug("Reasoning backend: %s",
-                     self.decoding_config.reasoning_backend)
+        if self.decoding_config.reasoning_backend is not None:
+            logger.debug("Building with reasoning backend %s",
+                         self.decoding_config.reasoning_backend)

processor = get_local_guided_decoding_logits_processor(
guided_params=guided_decoding,
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/api_server.py
@@ -67,7 +67,6 @@
TranscriptionRequest,
TranscriptionResponse,
UnloadLoRAAdapterRequest)
-from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
@@ -84,6 +83,7 @@
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import load_aware_call, with_cancellation
from vllm.logger import init_logger
+from vllm.reasoning import ReasoningParserManager
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import MistralTokenizer
3 changes: 1 addition & 2 deletions vllm/entrypoints/openai/serving_chat.py
@@ -23,8 +23,6 @@
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
RequestResponseMetadata, ToolCall, UsageInfo)
-from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
-                                                       ReasoningParserManager)
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
clamp_prompt_logprobs)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
@@ -33,6 +31,7 @@
MistralToolCall)
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
15 changes: 11 additions & 4 deletions vllm/model_executor/guided_decoding/__init__.py
@@ -5,10 +5,10 @@
from typing import TYPE_CHECKING

from vllm.logger import init_logger
-from vllm.model_executor.guided_decoding.reasoner import get_reasoner
from vllm.model_executor.guided_decoding.utils import (
convert_lark_to_gbnf, grammar_is_likely_lark,
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
+from vllm.reasoning import ReasoningParserManager

if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
@@ -107,7 +107,11 @@ async def get_guided_decoding_logits_processor(
model_config: ModelConfig,
reasoning_backend: str | None = None) -> LogitsProcessor | None:

-    reasoner = get_reasoner(tokenizer, reasoning_backend)
+    reasoner = None
+    if reasoning_backend is not None:
+        reasoner_class = ReasoningParserManager.get_reasoning_parser(
+            reasoning_backend)
+        reasoner = reasoner_class(tokenizer)

guided_params = maybe_backend_fallback(guided_params)

@@ -146,8 +150,11 @@ def get_local_guided_decoding_logits_processor(
reasoning_backend: str | None = None) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params)

-    # Get the reasoner if needed, it will be None if reasoning_
-    reasoner = get_reasoner(tokenizer, reasoning_backend)
+    reasoner = None
+    if reasoning_backend is not None:
+        reasoner_class = ReasoningParserManager.get_reasoning_parser(
+            reasoning_backend)
+        reasoner = reasoner_class(tokenizer)

# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend_name == 'outlines':
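The async and local paths now carry the same five-line construction block. If it grows again, it could be folded into a shared helper along these lines (a sketch only; _maybe_build_reasoner is a hypothetical name, not part of this PR):

from typing import Optional

from vllm.reasoning import ReasoningParser, ReasoningParserManager


def _maybe_build_reasoner(
        tokenizer,
        reasoning_backend: Optional[str]) -> Optional[ReasoningParser]:
    # Mirrors the duplicated blocks above: look up the registered parser
    # class for the backend name and bind it to the tokenizer.
    if reasoning_backend is None:
        return None
    reasoner_class = ReasoningParserManager.get_reasoning_parser(
        reasoning_backend)
    return reasoner_class(tokenizer)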
8 changes: 4 additions & 4 deletions vllm/model_executor/guided_decoding/outlines_decoding.py
@@ -12,7 +12,7 @@

from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
-from vllm.model_executor.guided_decoding.reasoner import Reasoner
+from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams


@@ -61,7 +61,7 @@ class GuidedDecodingMode(Enum):
async def get_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase,
-    reasoner: Optional[Reasoner],
+    reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""
@@ -92,7 +92,7 @@ async def get_outlines_guided_decoding_logits_processor(
def get_local_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase,
-    reasoner: Optional[Reasoner],
+    reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""
@@ -141,7 +141,7 @@ def _get_logits_processor(
tokenizer: PreTrainedTokenizerBase,
mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None],
-    reasoner: Optional[Reasoner],
+    reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,
@@ -34,8 +34,8 @@

import vllm.envs as envs
from vllm.logger import init_logger
-from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.platforms import current_platform
+from vllm.reasoning import ReasoningParser

logger = init_logger(__name__)

@@ -49,9 +49,9 @@

class BaseLogitsProcessor:

-    def __init__(self, guide: Guide, reasoner: Optional[Reasoner]):
+    def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]):
self._guide: Guide = guide
-        self._reasoner: Optional[Reasoner] = reasoner
+        self._reasoner: Optional[ReasoningParser] = reasoner
# CFGState is used for the FSM state for CFGGuide
self._fsm_state: DefaultDict[int, Union[int,
CFGState]] = defaultdict(int)
@@ -69,7 +69,7 @@ def __call__(self, input_ids: List[int],
# Remove the reasoning tokens from the input_ids
# We need this because our implementation relies on the
# hash of the input_ids to store the FSM state.
-                input_ids = self._reasoner.extract_content(input_ids)
+                input_ids = self._reasoner.extract_content_ids(input_ids)

seq_id = hash(tuple(input_ids))

@@ -142,7 +142,7 @@ def __init__(
self,
regex_string: str,
tokenizer: PreTrainedTokenizerBase,
-        reasoner: Optional[Reasoner],
+        reasoner: Optional[ReasoningParser],
):
"""Compile the FSM that drives the regex-structured generation.

@@ -163,7 +163,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self, schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Union[str, None],
-                 reasoner: Optional[Reasoner]):
+                 reasoner: Optional[ReasoningParser]):
"""Compile the FSM that drives the JSON-guided generation.

Parameters
@@ -203,7 +203,7 @@ def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide:
return CFGGuide(cfg, tokenizer)

def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase,
-                 reasoner: Optional[Reasoner]):
+                 reasoner: Optional[ReasoningParser]):
"""Compile the FSM that drives the context free grammar generation.

Parameters
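The extract_content -> extract_content_ids rename matters here because the FSM state is cached under a hash of the token ids: stripping the reasoning prefix keeps that key identical across sequences that differ only inside their <think>...</think> section. A condensed sketch of the keying logic in BaseLogitsProcessor.__call__ above (the function name is illustrative, not part of the PR):

def _fsm_cache_key(reasoner, input_ids: list[int]) -> int:
    # Once reasoning has ended, hash only the content tokens so the
    # guided-decoding FSM state is shared across reasoning prefixes.
    if reasoner is not None and reasoner.is_reasoning_end(input_ids):
        input_ids = reasoner.extract_content_ids(input_ids)
    return hash(tuple(input_ids))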
38 changes: 0 additions & 38 deletions vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py

This file was deleted.
