From 205a2b214990a82ffacb18c27de452f1fd3a912b Mon Sep 17 00:00:00 2001
From: Ce Gao
Date: Thu, 27 Mar 2025 00:08:04 +0800
Subject: [PATCH 1/3] refactor reasoning

Signed-off-by: Ce Gao
---
 .../__init__.py | 0
 .../test_deepseekr1_reasoning_parser.py | 52 +++++++--
 .../test_granite_reasoning_parser.py | 6 +-
 .../reasoning_parsers => reasoning}/utils.py | 2 +-
 vllm/engine/arg_utils.py | 3 +-
 vllm/engine/llm_engine.py | 5 +-
 vllm/entrypoints/openai/api_server.py | 2 +-
 vllm/entrypoints/openai/serving_chat.py | 3 +-
 .../guided_decoding/__init__.py | 15 ++-
 .../guided_decoding/outlines_decoding.py | 8 +-
 .../outlines_logits_processors.py | 14 +--
 .../reasoner/deepseek_reasoner.py | 38 -------
 .../guided_decoding/reasoner/reasoner.py | 23 ----
 .../guided_decoding/xgrammar_decoding.py | 6 +-
 .../__init__.py | 0
 .../abs_reasoning_parsers.py | 105 +++++++++---------
 .../deepseek_r1_reasoning_parser.py | 90 ++++++++-------
 .../granite_reasoning_parser.py | 57 ++++++----
 18 files changed, 207 insertions(+), 222 deletions(-)
 rename tests/{entrypoints/openai/reasoning_parsers => reasoning}/__init__.py (100%)
 rename tests/{entrypoints/openai/reasoning_parsers => reasoning}/test_deepseekr1_reasoning_parser.py (75%)
 rename tests/{entrypoints/openai/reasoning_parsers => reasoning}/test_granite_reasoning_parser.py (97%)
 rename tests/{entrypoints/openai/reasoning_parsers => reasoning}/utils.py (97%)
 delete mode 100644 vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py
 delete mode 100644 vllm/model_executor/guided_decoding/reasoner/reasoner.py
 rename vllm/{entrypoints/openai/reasoning_parsers => reasoning}/__init__.py (100%)
 rename vllm/{entrypoints/openai/reasoning_parsers => reasoning}/abs_reasoning_parsers.py (82%)
 rename vllm/{entrypoints/openai/reasoning_parsers => reasoning}/deepseek_r1_reasoning_parser.py (64%)
 rename vllm/{entrypoints/openai/reasoning_parsers => reasoning}/granite_reasoning_parser.py (92%)

diff --git a/tests/entrypoints/openai/reasoning_parsers/__init__.py b/tests/reasoning/__init__.py
similarity index 100%
rename from tests/entrypoints/openai/reasoning_parsers/__init__.py
rename to tests/reasoning/__init__.py
diff --git a/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py b/tests/reasoning/test_deepseekr1_reasoning_parser.py
similarity index 75%
rename from tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py
rename to tests/reasoning/test_deepseekr1_reasoning_parser.py
index 5ce5d9280f3e..7b6af183a86a 100644
--- a/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py
+++ b/tests/reasoning/test_deepseekr1_reasoning_parser.py
@@ -3,74 +3,92 @@
 import pytest
 from transformers import AutoTokenizer
 
-from tests.entrypoints.openai.reasoning_parsers.utils import (
-    run_reasoning_extraction)
-from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
-                                                       ReasoningParserManager)
+from tests.reasoning.utils import run_reasoning_extraction
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
 
 parser_name = "deepseek_r1"
 start_token = "<think>"
 end_token = "</think>"
 
+REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+
+
+@pytest.fixture(scope="module")
+def deepseek_r1_qwen_tokenizer():
+    return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
+
+
 SIMPLE_REASONING = {
     "output": "This is a reasoning section</think>This is the rest",
     "reasoning_content": "This is a reasoning section",
     "content": "This is the rest",
+    "is_reasoning_end": True,
 }
 COMPLETE_REASONING = {
     "output": "This is a reasoning section</think>",
     "reasoning_content": "This is a reasoning section",
     "content": None,
+    "is_reasoning_end": True,
 }
 NO_CONTENT = {
     "output": "This is content",
     "reasoning_content": "This is content",
     "content": None,
+    "is_reasoning_end": False,
 }
 NO_REASONING_STREAMING = {
     "output": "This is a reasoning section",
     "reasoning_content": "This is a reasoning section",
     "content": None,
+    "is_reasoning_end": False,
 }
 MULTIPLE_LINES = {
     "output": "This\nThat</think>This is the rest\nThat",
     "reasoning_content": "This\nThat",
     "content": "This is the rest\nThat",
+    "is_reasoning_end": True,
 }
 SHORTEST_REASONING_NO_STREAMING = {
     "output": "</think>This is the rest",
     "reasoning_content": "",
     "content": "This is the rest",
+    "is_reasoning_end": True,
 }
 SHORTEST_REASONING = {
     "output": "</think>This is the rest",
     "reasoning_content": None,
     "content": "This is the rest",
+    "is_reasoning_end": True,
 }
 REASONING_WITH_THINK = {
     "output": "<think>This is a reasoning section</think>This is the rest",
     "reasoning_content": "This is a reasoning section",
     "content": "This is the rest",
+    "is_reasoning_end": True,
 }
 COMPLETE_REASONING_WITH_THINK = {
     "output": "<think>This is a reasoning section</think>",
     "reasoning_content": "This is a reasoning section",
     "content": None,
+    "is_reasoning_end": True,
 }
 MULTIPLE_LINES_WITH_THINK = {
     "output": "<think>This\nThat</think>This is the rest\nThat",
     "reasoning_content": "This\nThat",
     "content": "This is the rest\nThat",
+    "is_reasoning_end": True,
 }
 SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
     "output": "</think>This is the rest",
     "reasoning_content": "",
     "content": "This is the rest",
+    "is_reasoning_end": True,
 }
 SHORTEST_REASONING_WITH_THINK = {
     "output": "</think>This is the rest",
     "reasoning_content": None,
     "content": "This is the rest",
+    "is_reasoning_end": True,
 }
 
 TEST_CASES = [
@@ -166,23 +184,21 @@
     ),
 ]
 
-# Global tokenizer initialization to avoid repeated loading
-tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
-tokenizer.add_tokens([start_token, end_token])
-
 
 @pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
 def test_reasoning(
     streaming: bool,
     param_dict: dict,
+    deepseek_r1_qwen_tokenizer,
 ):
-    output = tokenizer.tokenize(param_dict["output"])
+    output = deepseek_r1_qwen_tokenizer.tokenize(param_dict["output"])
     # decode everything to tokens
     output_tokens: list[str] = [
-        tokenizer.convert_tokens_to_string([token]) for token in output
+        deepseek_r1_qwen_tokenizer.convert_tokens_to_string([token])
+        for token in output
     ]
     parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
-        parser_name)(tokenizer)
+        parser_name)(deepseek_r1_qwen_tokenizer)
 
     reasoning, content = run_reasoning_extraction(parser,
                                                   output_tokens,
@@ -190,3 +206,17 @@ def test_reasoning(
 
     assert reasoning == param_dict["reasoning_content"]
     assert content == param_dict["content"]
+
+    # Test is_reasoning_end
+    output_ids = deepseek_r1_qwen_tokenizer.convert_tokens_to_ids(output)
+    is_reasoning_end = parser.is_reasoning_end(output_ids)
+    assert is_reasoning_end == param_dict["is_reasoning_end"]
+
+    # Test extract_content
+    if param_dict["content"] is not None:
+        content = parser.extract_content_ids(output_ids)
+        assert content == deepseek_r1_qwen_tokenizer.convert_tokens_to_ids(
+            deepseek_r1_qwen_tokenizer.tokenize(param_dict["content"]))
+    else:
+        content = parser.extract_content_ids(output)
+        assert content == []
diff --git a/tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py b/tests/reasoning/test_granite_reasoning_parser.py
similarity index 97%
rename from
tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py rename to tests/reasoning/test_granite_reasoning_parser.py index 84ac6600498b..48fb8c2f8d1b 100644 --- a/tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py +++ b/tests/reasoning/test_granite_reasoning_parser.py @@ -2,10 +2,8 @@ import pytest from transformers import AutoTokenizer -from tests.entrypoints.openai.reasoning_parsers.utils import ( - DeltaMessage, run_reasoning_extraction) -from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser, - ReasoningParserManager) +from tests.reasoning.utils import DeltaMessage, run_reasoning_extraction +from vllm.reasoning import ReasoningParser, ReasoningParserManager parser_name = "granite" START_REASONING = "Here is my thought process:" diff --git a/tests/entrypoints/openai/reasoning_parsers/utils.py b/tests/reasoning/utils.py similarity index 97% rename from tests/entrypoints/openai/reasoning_parsers/utils.py rename to tests/reasoning/utils.py index 01e43130bc6e..0f894ed800c6 100644 --- a/tests/entrypoints/openai/reasoning_parsers/utils.py +++ b/tests/reasoning/utils.py @@ -4,7 +4,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage) -from vllm.entrypoints.openai.reasoning_parsers import ReasoningParser +from vllm.reasoning import ReasoningParser class StreamingReasoningReconstructor: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index be00689f2b55..6c26b2a96e8b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -23,6 +23,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.plugins import load_general_plugins +from vllm.reasoning import ReasoningParserManager from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 from vllm.transformers_utils.utils import check_gguf_file from vllm.usage.usage_lib import UsageContext @@ -1099,7 +1100,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( "--reasoning-parser", type=str, - choices=["deepseek_r1", "granite"], + choices=list(ReasoningParserManager.reasoning_parsers), default=None, help= "Select the reasoning parser depending on the model that you're " diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3d019ea58c5e..1474c420bd18 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -2084,8 +2084,9 @@ def _build_logits_processors( guided_decoding.backend = guided_decoding.backend or \ self.decoding_config.guided_decoding_backend - logger.debug("Reasoning backend: %s", - self.decoding_config.reasoning_backend) + if self.decoding_config.reasoning_backend is not None: + logger.debug("Building with reasoning backend %s", + self.decoding_config.reasoning_backend) processor = get_local_guided_decoding_logits_processor( guided_params=guided_decoding, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 374e43fb1534..29b0e9c78196 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -67,7 +67,6 @@ TranscriptionRequest, TranscriptionResponse, UnloadLoRAAdapterRequest) -from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion @@ -84,6 +83,7 @@ from vllm.entrypoints.openai.tool_parsers import 
ToolParserManager from vllm.entrypoints.utils import load_aware_call, with_cancellation from vllm.logger import init_logger +from vllm.reasoning import ReasoningParserManager from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) from vllm.transformers_utils.tokenizer import MistralTokenizer diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 3c35a848ea3a..c3db21d66947 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -23,8 +23,6 @@ ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo, RequestResponseMetadata, ToolCall, UsageInfo) -from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser, - ReasoningParserManager) from vllm.entrypoints.openai.serving_engine import (OpenAIServing, clamp_prompt_logprobs) from vllm.entrypoints.openai.serving_models import OpenAIServingModels @@ -33,6 +31,7 @@ MistralToolCall) from vllm.logger import init_logger from vllm.outputs import CompletionOutput, RequestOutput +from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 0c26a60588c8..cecb3a8a1d4a 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -5,10 +5,10 @@ from typing import TYPE_CHECKING from vllm.logger import init_logger -from vllm.model_executor.guided_decoding.reasoner import get_reasoner from vllm.model_executor.guided_decoding.utils import ( convert_lark_to_gbnf, grammar_is_likely_lark, has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features) +from vllm.reasoning import ReasoningParserManager if TYPE_CHECKING: from transformers import PreTrainedTokenizer @@ -107,7 +107,11 @@ async def get_guided_decoding_logits_processor( model_config: ModelConfig, reasoning_backend: str | None = None) -> LogitsProcessor | None: - reasoner = get_reasoner(tokenizer, reasoning_backend) + reasoner = None + if reasoning_backend is not None: + reasoner_class = ReasoningParserManager.get_reasoning_parser( + reasoning_backend) + reasoner = reasoner_class(tokenizer) guided_params = maybe_backend_fallback(guided_params) @@ -146,8 +150,11 @@ def get_local_guided_decoding_logits_processor( reasoning_backend: str | None = None) -> LogitsProcessor | None: guided_params = maybe_backend_fallback(guided_params) - # Get the reasoner if needed, it will be None if reasoning_ - reasoner = get_reasoner(tokenizer, reasoning_backend) + reasoner = None + if reasoning_backend is not None: + reasoner_class = ReasoningParserManager.get_reasoning_parser( + reasoning_backend) + reasoner = reasoner_class(tokenizer) # CFG grammar not supported by LMFE, so we use outlines instead if guided_params.backend_name == 'outlines': diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index 97f63ae11f45..564f9277a83c 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -12,7 +12,7 @@ from vllm.model_executor.guided_decoding.outlines_logits_processors import ( CFGLogitsProcessor, 
JSONLogitsProcessor, RegexLogitsProcessor) -from vllm.model_executor.guided_decoding.reasoner import Reasoner +from vllm.reasoning import ReasoningParser from vllm.sampling_params import GuidedDecodingParams @@ -61,7 +61,7 @@ class GuidedDecodingMode(Enum): async def get_outlines_guided_decoding_logits_processor( guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase, - reasoner: Optional[Reasoner], + reasoner: Optional[ReasoningParser], ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, None]: """ @@ -92,7 +92,7 @@ async def get_outlines_guided_decoding_logits_processor( def get_local_outlines_guided_decoding_logits_processor( guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase, - reasoner: Optional[Reasoner], + reasoner: Optional[ReasoningParser], ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, None]: """ @@ -141,7 +141,7 @@ def _get_logits_processor( tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode, whitespace_pattern: Union[str, None], - reasoner: Optional[Reasoner], + reasoner: Optional[ReasoningParser], ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]: if mode == GuidedDecodingMode.JSON: return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern, diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 8b2a0f4cfe64..31af4593f112 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -34,8 +34,8 @@ import vllm.envs as envs from vllm.logger import init_logger -from vllm.model_executor.guided_decoding.reasoner import Reasoner from vllm.platforms import current_platform +from vllm.reasoning import ReasoningParser logger = init_logger(__name__) @@ -49,9 +49,9 @@ class BaseLogitsProcessor: - def __init__(self, guide: Guide, reasoner: Optional[Reasoner]): + def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]): self._guide: Guide = guide - self._reasoner: Optional[Reasoner] = reasoner + self._reasoner: Optional[ReasoningParser] = reasoner # CFGState is used for the FSM state for CFGGuide self._fsm_state: DefaultDict[int, Union[int, CFGState]] = defaultdict(int) @@ -69,7 +69,7 @@ def __call__(self, input_ids: List[int], # Remove the reasoning tokens from the input_ids # We need this because our implementation relies on the # hash of the input_ids to store the FSM state. - input_ids = self._reasoner.extract_content(input_ids) + input_ids = self._reasoner.extract_content_ids(input_ids) seq_id = hash(tuple(input_ids)) @@ -142,7 +142,7 @@ def __init__( self, regex_string: str, tokenizer: PreTrainedTokenizerBase, - reasoner: Optional[Reasoner], + reasoner: Optional[ReasoningParser], ): """Compile the FSM that drives the regex-structured generation. @@ -163,7 +163,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor): def __init__(self, schema: Union[str, Dict, BaseModel], tokenizer: PreTrainedTokenizerBase, whitespace_pattern: Union[str, None], - reasoner: Optional[Reasoner]): + reasoner: Optional[ReasoningParser]): """Compile the FSM that drives the JSON-guided generation. 
Parameters @@ -203,7 +203,7 @@ def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide: return CFGGuide(cfg, tokenizer) def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase, - reasoner: Optional[Reasoner]): + reasoner: Optional[ReasoningParser]): """Compile the FSM that drives the context free grammar generation. Parameters diff --git a/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py b/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py deleted file mode 100644 index 7e61e6a9620c..000000000000 --- a/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py +++ /dev/null @@ -1,38 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -from dataclasses import dataclass - -from transformers import PreTrainedTokenizer - -from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner - - -@dataclass -class DeepSeekReasoner(Reasoner): - """ - Reasoner for DeepSeek R series models. - """ - start_token_id: int - end_token_id: int - - start_token: str = "" - end_token: str = "" - - @classmethod - def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner: - return cls(start_token_id=tokenizer.encode( - "", add_special_tokens=False)[0], - end_token_id=tokenizer.encode("", - add_special_tokens=False)[0]) - - def is_reasoning_end(self, input_ids: list[int]) -> bool: - return self.end_token_id in input_ids - - def extract_content(self, input_ids: list[int]) -> list[int]: - """ - Extract the content after the end tokens - """ - if self.end_token_id not in input_ids or \ - input_ids.index(self.end_token_id) + 1 == len(input_ids): - return [] - else: - return input_ids[input_ids.index(self.end_token_id) + 1:] diff --git a/vllm/model_executor/guided_decoding/reasoner/reasoner.py b/vllm/model_executor/guided_decoding/reasoner/reasoner.py deleted file mode 100644 index df21b1db6221..000000000000 --- a/vllm/model_executor/guided_decoding/reasoner/reasoner.py +++ /dev/null @@ -1,23 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -from __future__ import annotations - -from abc import ABC, abstractmethod -from dataclasses import dataclass - -from transformers import PreTrainedTokenizer - - -@dataclass -class Reasoner(ABC): - - @abstractmethod - def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner: - pass - - @abstractmethod - def is_reasoning_end(self, input_ids: list[int]) -> bool: - pass - - @abstractmethod - def extract_content(self, input_ids: list[int]) -> list[int]: - pass diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index bc156223953e..47b1e7e3f981 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -27,7 +27,7 @@ from transformers import PreTrainedTokenizer from vllm.config import ModelConfig - from vllm.model_executor.guided_decoding.reasoner import Reasoner + from vllm.reasoning import ReasoningParser from vllm.sampling_params import GuidedDecodingParams logger = init_logger(__name__) @@ -37,7 +37,7 @@ def get_local_xgrammar_guided_decoding_logits_processor( guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, model_config: ModelConfig, - reasoner: Reasoner | None, + reasoner: ReasoningParser | None, max_threads: int = 8): config = GrammarConfig.from_guided_params(guided_params=guided_params, model_config=model_config, @@ -280,7 +280,7 @@ def tokenizer_info(tokenizer_data: TokenizerData) -> xgr.TokenizerInfo: class 
XGrammarLogitsProcessor: """Wrapper class to support pickle protocol""" config: GrammarConfig - reasoner: Reasoner | None = None + reasoner: ReasoningParser | None = None ctx: xgr.CompiledGrammar | None = None tokenizer_info: xgr.TokenizerInfo = None # type: ignore[assignment] diff --git a/vllm/entrypoints/openai/reasoning_parsers/__init__.py b/vllm/reasoning/__init__.py similarity index 100% rename from vllm/entrypoints/openai/reasoning_parsers/__init__.py rename to vllm/reasoning/__init__.py diff --git a/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py similarity index 82% rename from vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py rename to vllm/reasoning/abs_reasoning_parsers.py index c95ff191e4d2..454b8074ae1d 100644 --- a/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py +++ b/vllm/reasoning/abs_reasoning_parsers.py @@ -17,7 +17,7 @@ class ReasoningParser: """ - Abstract reasoning parser class that should not be used directly. + Abstract reasoning parser class that should not be used directly. Provided and methods should be used in derived classes. It is used to extract reasoning content from the model output. @@ -32,6 +32,38 @@ def vocab(self) -> dict[str, int]: # whereas all tokenizers have .get_vocab() return self.model_tokenizer.get_vocab() + @abstractmethod + def is_reasoning_end(self, input_ids: list[int]) -> bool: + """ + Check if the reasoning content ends in the input_ids. + + It is used in structured engines like `xgrammar` to check if the + reasoning content ends in the model output. + + Parameters: + input_ids: list[int] + The input_ids of the model output. + + Returns: + bool + True if the reasoning content ends in the input_ids. + """ + pass + + @abstractmethod + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + """ + Extract content token ids from the input_ids. + Parameters: + input_ids: list[int] + The input_ids of the model output. + Returns: + list[int] + The extracted content from the input_ids. + """ + pass + + @abstractmethod def extract_reasoning_content( self, model_output: str, request: ChatCompletionRequest ) -> tuple[Optional[str], Optional[str]]: @@ -52,11 +84,9 @@ def extract_reasoning_content( tuple[Optional[str], Optional[str]] A tuple containing the reasoning content and the content. """ + pass - raise NotImplementedError( - "AbstractReasoningParser.extract_reasoning_calls " - "has not been implemented!") - + @abstractmethod def extract_reasoning_content_streaming( self, previous_text: str, @@ -73,43 +103,7 @@ def extract_reasoning_content_streaming( the current tokens/diffs, but also the information about what has previously been parsed and extracted (see constructor) """ - raise NotImplementedError( - "AbstractReasoningParser.extract_reasoning_content_streaming " - "has not been implemented!") - - # TODO: need to rebase by PR #14428 - @abstractmethod - def is_reasoning_end(self, input_ids: list[int]) -> bool: - """ - Check if the reasoning content ends in the input_ids. - Parameters: - input_ids: list[int] - The input_ids of the model output. - Returns: - bool - True if the reasoning content ends in the input_ids. - """ - - raise NotImplementedError( - "AbstractReasoningParser.is_reasoning_end has" - "not been implemented!") - - # TODO: need to rebase by PR #14428 - @abstractmethod - def extract_content_ids(self, input_ids: list[int]) -> list[int]: - """ - Extract content token ids from the input_ids. 
- Parameters: - input_ids: list[int] - The input_ids of the model output. - Returns: - list[int] - The extracted content from the input_ids. - """ - - raise NotImplementedError( - "AbstractReasoningParser.extract_content_ids has" - " not been implemented!") + pass class ReasoningParserManager: @@ -125,14 +119,16 @@ def get_reasoning_parser(cls, name) -> type: if name in cls.reasoning_parsers: return cls.reasoning_parsers[name] - raise KeyError(f"reasoning helper: '{name}' not found in " - "reasoning_parsers") + raise KeyError( + f"reasoning helper: '{name}' not found in reasoning_parsers") @classmethod - def _register_module(cls, - module: type, - module_name: Optional[Union[str, list[str]]] = None, - force: bool = True) -> None: + def _register_module( + cls, + module: type, + module_name: Optional[Union[str, list[str]]] = None, + force: bool = True, + ) -> None: if not issubclass(module, ReasoningParser): raise TypeError("module must be subclass of ReasoningParser, " f"but got {type(module)}") @@ -149,13 +145,14 @@ def _register_module(cls, @classmethod def register_module( - cls, - name: Optional[Union[str, list[str]]] = None, - force: bool = True, - module: Union[type, None] = None) -> Union[type, Callable]: + cls, + name: Optional[Union[str, list[str]]] = None, + force: bool = True, + module: Union[type, None] = None, + ) -> Union[type, Callable]: """ Register module with the given name or name list. it can be used as a - decoder(with module as None) or normal function(with module as not + decoder(with module as None) or normal function(with module as not None). """ if not isinstance(force, bool): @@ -183,7 +180,7 @@ def _register(module): @classmethod def import_reasoning_parser(cls, plugin_path: str) -> None: """ - Import a user-defined reasoning parser by the path + Import a user-defined reasoning parser by the path of the reasoning parser define file. """ module_name = os.path.splitext(os.path.basename(plugin_path))[0] diff --git a/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py b/vllm/reasoning/deepseek_r1_reasoning_parser.py similarity index 64% rename from vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py rename to vllm/reasoning/deepseek_r1_reasoning_parser.py index 54e960168cf4..73be6d4d1ab1 100644 --- a/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py +++ b/vllm/reasoning/deepseek_r1_reasoning_parser.py @@ -8,9 +8,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage) -from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import ( - ReasoningParser, ReasoningParserManager) from vllm.logger import init_logger +from vllm.reasoning import ReasoningParser, ReasoningParserManager logger = init_logger(__name__) @@ -20,43 +19,45 @@ class DeepSeekR1ReasoningParser(ReasoningParser): """ Reasoning parser for DeepSeek R1 model. - The DeepSeek R1 model uses ... tokens to denote reasoning + The DeepSeek R1 model uses ... tokens to denote reasoning text. This parser extracts the reasoning content from the model output. 
""" + start_token_id: int + end_token_id: int + + start_token: str = "" + end_token: str = "" + def __init__(self, tokenizer: PreTrainedTokenizerBase): super().__init__(tokenizer) - self.think_start_token = "" - self.think_end_token = "" self.reasoning_regex = re.compile( - rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL) + rf"{self.start_token}(.*?){self.end_token}", re.DOTALL) if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ReasoningParser " "constructor during construction.") - self.think_start_token_id = self.vocab.get(self.think_start_token) - self.think_end_token_id = self.vocab.get(self.think_end_token) - if (self.think_start_token_id is None - or self.think_end_token_id is None): + self.start_token_id = self.vocab.get(self.start_token) + self.end_token_id = self.vocab.get(self.end_token) + if self.start_token_id is None or self.end_token_id is None: raise RuntimeError( "DeepSeek R1 reasoning parser could not locate think start/end " "tokens in the tokenizer!") - # TODO: need to rebase by PR #14428 def is_reasoning_end(self, input_ids: list[int]) -> bool: - return self.think_end_token_id in input_ids + return self.end_token_id in input_ids def extract_content_ids(self, input_ids: list[int]) -> list[int]: """ Extract the content after the end tokens """ - if self.think_end_token_id not in input_ids[:-1]: + if self.end_token_id not in input_ids[:-1]: return [] else: - return input_ids[input_ids.index(self.think_end_token_id) + 1:] + return input_ids[input_ids.index(self.end_token_id) + 1:] def extract_reasoning_content_streaming( self, @@ -77,22 +78,24 @@ def extract_reasoning_content_streaming( """ # Skip single special tokens if len(delta_token_ids) == 1 and (delta_token_ids[0] in [ - self.think_start_token_id, self.think_end_token_id + self.start_token_id, self.end_token_id ]): return None # Check if is present in previous or delta. # Keep compatibility with models that don't generate tokens. 
-        if self.think_start_token_id in previous_token_ids:
-            if self.think_end_token_id in delta_token_ids:
+        if self.start_token_id in previous_token_ids:
+            if self.end_token_id in delta_token_ids:
                 # <think> in previous, </think> in delta,
                 # extract reasoning content
-                end_index = delta_text.find(self.think_end_token)
+                end_index = delta_text.find(self.end_token)
                 reasoning_content = delta_text[:end_index]
-                content = delta_text[end_index + len(self.think_end_token):]
-                return DeltaMessage(reasoning_content=reasoning_content,
-                                    content=content if content else None)
-            elif self.think_end_token_id in previous_token_ids:
+                content = delta_text[end_index + len(self.end_token):]
+                return DeltaMessage(
+                    reasoning_content=reasoning_content,
+                    content=content if content else None,
+                )
+            elif self.end_token_id in previous_token_ids:
                 # <think> in previous, </think> in previous,
                 # reasoning content continues
                 return DeltaMessage(content=delta_text)
@@ -100,17 +103,18 @@
                 # <think> in previous, no </think> in previous or delta,
                 # reasoning content continues
                 return DeltaMessage(reasoning_content=delta_text)
-        elif self.think_start_token_id in delta_token_ids:
-            if self.think_end_token_id in delta_token_ids:
+        elif self.start_token_id in delta_token_ids:
+            if self.end_token_id in delta_token_ids:
                 # <think> in delta, </think> in delta, extract reasoning content
-                start_index = delta_text.find(self.think_start_token)
-                end_index = delta_text.find(self.think_end_token)
+                start_index = delta_text.find(self.start_token)
+                end_index = delta_text.find(self.end_token)
                 reasoning_content = delta_text[start_index +
-                                               len(self.think_start_token
-                                                   ):end_index]
-                content = delta_text[end_index + len(self.think_end_token):]
-                return DeltaMessage(reasoning_content=reasoning_content,
-                                    content=content if content else None)
+                                               len(self.start_token):end_index]
+                content = delta_text[end_index + len(self.end_token):]
+                return DeltaMessage(
+                    reasoning_content=reasoning_content,
+                    content=content if content else None,
+                )
             else:
                 # <think> in delta, no </think> in delta,
                 # reasoning content continues
@@ -119,15 +123,17 @@
             # No <think> in previous or delta, also need to check for </think>.
             # Because the model may have generated </think> without <think>
             # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
-            if self.think_end_token_id in delta_token_ids:
+            if self.end_token_id in delta_token_ids:
                 # </think> in delta with more tokens,
                 # extract reasoning content and content
-                end_index = delta_text.find(self.think_end_token)
+                end_index = delta_text.find(self.end_token)
                 reasoning_content = delta_text[:end_index]
-                content = delta_text[end_index + len(self.think_end_token):]
-                return DeltaMessage(reasoning_content=reasoning_content,
-                                    content=content if content else None)
-            elif self.think_end_token_id in previous_token_ids:
+                content = delta_text[end_index + len(self.end_token):]
+                return DeltaMessage(
+                    reasoning_content=reasoning_content,
+                    content=content if content else None,
+                )
+            elif self.end_token_id in previous_token_ids:
                 # </think> in previous, thinking content ends
                 return DeltaMessage(content=delta_text)
             else:
@@ -137,22 +143,20 @@
     def extract_reasoning_content(
             self, model_output: str, request: ChatCompletionRequest
     ) -> tuple[Optional[str], Optional[str]]:
-
         # DeepSeek R1 doesn't generate <think> now.
         # Thus we assume the reasoning content is always at the start.
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f - if self.think_end_token not in model_output: + if self.end_token not in model_output: return model_output, None else: # Add a start token if it's missing to keep compatibility. - if self.think_start_token not in model_output: - model_output = f"{self.think_start_token}{model_output}" + if self.start_token not in model_output: + model_output = f"{self.start_token}{model_output}" # Use a regex to find the reasoning content reasoning_content = self.reasoning_regex.findall(model_output)[0] end_index = len( - f"{self.think_start_token}{reasoning_content}{self.think_end_token}" - ) + f"{self.start_token}{reasoning_content}{self.end_token}") final_output = model_output[end_index:] if len(final_output) == 0: diff --git a/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py b/vllm/reasoning/granite_reasoning_parser.py similarity index 92% rename from vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py rename to vllm/reasoning/granite_reasoning_parser.py index 117d051a7378..4ae38f59c8c9 100644 --- a/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py +++ b/vllm/reasoning/granite_reasoning_parser.py @@ -8,9 +8,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage) -from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import ( - ReasoningParser, ReasoningParserManager) from vllm.logger import init_logger +from vllm.reasoning import ReasoningParser, ReasoningParserManager logger = init_logger(__name__) @@ -35,13 +34,16 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase): self.reasoning_regex = re.compile( rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)", - re.DOTALL) + re.DOTALL, + ) self.valid_think_starts = [ - "Here's my thought process:", "Here is my thought process:" + "Here's my thought process:", + "Here is my thought process:", ] self.valid_response_starts = [ - "Here's my response:", "Here is my response:" + "Here's my response:", + "Here is my response:", ] # Substrings to match for sequence boundaries on raw text @@ -127,8 +129,12 @@ def extract_reasoning_content_streaming( # This should never happen since we matched on the response assert resp_seq_len is not None delta_message = self._get_delta_message_with_both_bounds( - delta_text, reasoning_content, content, current_text, - resp_seq_len) + delta_text, + reasoning_content, + content, + current_text, + resp_seq_len, + ) if not delta_message.content and not delta_message.reasoning_content: return None return delta_message @@ -139,7 +145,7 @@ def _is_reasoning_start_substr(self, text: str) -> bool: Args: text (str): Text to check for leading substr. - + Returns: bool: True if any of the possible reasoning start seqs match. """ @@ -152,7 +158,7 @@ def _is_response_start_substr(self, text: str) -> bool: Args: text (str): Text to check for leading substr. - + Returns: bool: True if any of the possible response start seqs match. """ @@ -234,12 +240,12 @@ def _get_delta_message_with_no_response_bounds( delta_idx = delta_text.rfind(self.seq_boundary_start) # Check the state of potential start of response substring matches. 
- prev_was_substr = self._is_response_start_substr( - previous_text[prev_idx:]) if prev_idx >= 0 else False - delta_continues_substr = self._is_response_start_substr( - current_text[prev_idx:]) if prev_idx >= 0 else False - delta_new_substr = self._is_response_start_substr( - delta_text[delta_idx:]) if delta_idx >= 0 else False + prev_was_substr = (self._is_response_start_substr( + previous_text[prev_idx:]) if prev_idx >= 0 else False) + delta_continues_substr = (self._is_response_start_substr( + current_text[prev_idx:]) if prev_idx >= 0 else False) + delta_new_substr = (self._is_response_start_substr( + delta_text[delta_idx:]) if delta_idx >= 0 else False) # Delta only contains potential continued response sequence text. if delta_continues_substr: @@ -256,8 +262,8 @@ def _get_delta_message_with_no_response_bounds( # seq wasn't one; we need to add the content to the delta message, # and also slice off the potential response sequence elif delta_new_substr: - reasoning_content = previous_text[ - prev_idx:] + delta_text[:delta_idx] + reasoning_content = (previous_text[prev_idx:] + + delta_text[:delta_idx]) return DeltaMessage(reasoning_content=reasoning_content, content=None) # No new substring yet, and we broke our old one; take the whole delta @@ -296,9 +302,9 @@ def _get_delta_message_with_both_bounds( delta_reasoning_content = None else: # Get the starting offset - start_reasoning_content_idx = len( - reasoning_content) + response_seq_len + len( - response_content) - 1 + start_reasoning_content_idx = (len(reasoning_content) + + response_seq_len + + len(response_content) - 1) delta_offset = len(current_text) - len(delta_text) start_offset = start_reasoning_content_idx - delta_offset if start_offset < 0: @@ -346,8 +352,8 @@ def _get_content_sections( # Check to see if the start of response seq if complete elif not parsed_content: for response_start in self.valid_response_starts: - if current_chunk[-len(response_start) + - 1:] == response_start[:-1]: + if (current_chunk[-len(response_start) + + 1:] == response_start[:-1]): # Mark end of reasoning and start response content # after the start of response sequence. 
end_reasoning_content = current_chunk_end - len( @@ -355,8 +361,11 @@ def _get_content_sections( reasoning_content = current_text[ start_reasoning_content:end_reasoning_content] response_content = current_text[current_chunk_end + 1:] - return reasoning_content, len( - response_start), response_content + return ( + reasoning_content, + len(response_start), + response_content, + ) if start_reasoning_content and not parsed_content: return current_text[start_reasoning_content:], None, None From b2d8761007fb486fa18ddf7618959ad889d1ff60 Mon Sep 17 00:00:00 2001 From: Ce Gao Date: Thu, 27 Mar 2025 12:34:51 +0800 Subject: [PATCH 2/3] revert Signed-off-by: Ce Gao --- vllm/reasoning/granite_reasoning_parser.py | 54 +++++++++------------- 1 file changed, 22 insertions(+), 32 deletions(-) diff --git a/vllm/reasoning/granite_reasoning_parser.py b/vllm/reasoning/granite_reasoning_parser.py index 4ae38f59c8c9..249ace1f167f 100644 --- a/vllm/reasoning/granite_reasoning_parser.py +++ b/vllm/reasoning/granite_reasoning_parser.py @@ -34,16 +34,13 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase): self.reasoning_regex = re.compile( rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)", - re.DOTALL, - ) + re.DOTALL) self.valid_think_starts = [ - "Here's my thought process:", - "Here is my thought process:", + "Here's my thought process:", "Here is my thought process:" ] self.valid_response_starts = [ - "Here's my response:", - "Here is my response:", + "Here's my response:", "Here is my response:" ] # Substrings to match for sequence boundaries on raw text @@ -129,12 +126,8 @@ def extract_reasoning_content_streaming( # This should never happen since we matched on the response assert resp_seq_len is not None delta_message = self._get_delta_message_with_both_bounds( - delta_text, - reasoning_content, - content, - current_text, - resp_seq_len, - ) + delta_text, reasoning_content, content, current_text, + resp_seq_len) if not delta_message.content and not delta_message.reasoning_content: return None return delta_message @@ -145,7 +138,7 @@ def _is_reasoning_start_substr(self, text: str) -> bool: Args: text (str): Text to check for leading substr. - + Returns: bool: True if any of the possible reasoning start seqs match. """ @@ -158,7 +151,7 @@ def _is_response_start_substr(self, text: str) -> bool: Args: text (str): Text to check for leading substr. - + Returns: bool: True if any of the possible response start seqs match. """ @@ -240,12 +233,12 @@ def _get_delta_message_with_no_response_bounds( delta_idx = delta_text.rfind(self.seq_boundary_start) # Check the state of potential start of response substring matches. - prev_was_substr = (self._is_response_start_substr( - previous_text[prev_idx:]) if prev_idx >= 0 else False) - delta_continues_substr = (self._is_response_start_substr( - current_text[prev_idx:]) if prev_idx >= 0 else False) - delta_new_substr = (self._is_response_start_substr( - delta_text[delta_idx:]) if delta_idx >= 0 else False) + prev_was_substr = self._is_response_start_substr( + previous_text[prev_idx:]) if prev_idx >= 0 else False + delta_continues_substr = self._is_response_start_substr( + current_text[prev_idx:]) if prev_idx >= 0 else False + delta_new_substr = self._is_response_start_substr( + delta_text[delta_idx:]) if delta_idx >= 0 else False # Delta only contains potential continued response sequence text. 
if delta_continues_substr: @@ -262,8 +255,8 @@ def _get_delta_message_with_no_response_bounds( # seq wasn't one; we need to add the content to the delta message, # and also slice off the potential response sequence elif delta_new_substr: - reasoning_content = (previous_text[prev_idx:] + - delta_text[:delta_idx]) + reasoning_content = previous_text[ + prev_idx:] + delta_text[:delta_idx] return DeltaMessage(reasoning_content=reasoning_content, content=None) # No new substring yet, and we broke our old one; take the whole delta @@ -302,9 +295,9 @@ def _get_delta_message_with_both_bounds( delta_reasoning_content = None else: # Get the starting offset - start_reasoning_content_idx = (len(reasoning_content) + - response_seq_len + - len(response_content) - 1) + start_reasoning_content_idx = len( + reasoning_content) + response_seq_len + len( + response_content) - 1 delta_offset = len(current_text) - len(delta_text) start_offset = start_reasoning_content_idx - delta_offset if start_offset < 0: @@ -352,8 +345,8 @@ def _get_content_sections( # Check to see if the start of response seq if complete elif not parsed_content: for response_start in self.valid_response_starts: - if (current_chunk[-len(response_start) + - 1:] == response_start[:-1]): + if current_chunk[-len(response_start) + + 1:] == response_start[:-1]: # Mark end of reasoning and start response content # after the start of response sequence. end_reasoning_content = current_chunk_end - len( @@ -361,11 +354,8 @@ def _get_content_sections( reasoning_content = current_text[ start_reasoning_content:end_reasoning_content] response_content = current_text[current_chunk_end + 1:] - return ( - reasoning_content, - len(response_start), - response_content, - ) + return reasoning_content, len( + response_start), response_content if start_reasoning_content and not parsed_content: return current_text[start_reasoning_content:], None, None From 18024cfe97f9d3f31d160017673d6af37a7e20c5 Mon Sep 17 00:00:00 2001 From: Ce Gao Date: Thu, 27 Mar 2025 14:34:28 +0800 Subject: [PATCH 3/3] Remove `pass` Signed-off-by: Ce Gao --- vllm/reasoning/abs_reasoning_parsers.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py index 454b8074ae1d..454167a0dc95 100644 --- a/vllm/reasoning/abs_reasoning_parsers.py +++ b/vllm/reasoning/abs_reasoning_parsers.py @@ -48,7 +48,6 @@ def is_reasoning_end(self, input_ids: list[int]) -> bool: bool True if the reasoning content ends in the input_ids. """ - pass @abstractmethod def extract_content_ids(self, input_ids: list[int]) -> list[int]: @@ -61,7 +60,6 @@ def extract_content_ids(self, input_ids: list[int]) -> list[int]: list[int] The extracted content from the input_ids. """ - pass @abstractmethod def extract_reasoning_content( @@ -84,7 +82,6 @@ def extract_reasoning_content( tuple[Optional[str], Optional[str]] A tuple containing the reasoning content and the content. """ - pass @abstractmethod def extract_reasoning_content_streaming( @@ -103,7 +100,6 @@ def extract_reasoning_content_streaming( the current tokens/diffs, but also the information about what has previously been parsed and extracted (see constructor) """ - pass class ReasoningParserManager:
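
Usage note (illustrative, not part of the patch series): after this refactor the reasoning parsers live in `vllm.reasoning` and are resolved through `ReasoningParserManager`, the same lookup used by `--reasoning-parser` and by `get_local_guided_decoding_logits_processor` in this patch. A minimal sketch of that flow, assuming the DeepSeek-R1-Distill-Qwen-1.5B tokenizer used by the tests:

    from transformers import AutoTokenizer

    from vllm.reasoning import ReasoningParserManager

    tokenizer = AutoTokenizer.from_pretrained(
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")

    # Resolve the parser class by registered name and bind it to a tokenizer,
    # mirroring how the guided decoding backends obtain a reasoner here.
    parser_cls = ReasoningParserManager.get_reasoning_parser("deepseek_r1")
    parser = parser_cls(tokenizer)

    # Token-id level helpers consumed by the structured-output backends.
    output_ids = tokenizer.encode("okay</think>The answer is 42",
                                  add_special_tokens=False)
    print(parser.is_reasoning_end(output_ids))     # True once </think> was emitted
    print(parser.extract_content_ids(output_ids))  # token ids after </think>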