diff --git a/docs/source/features/reasoning_outputs.md b/docs/source/features/reasoning_outputs.md index 0b170aadc344..879b16d4f7b5 100644 --- a/docs/source/features/reasoning_outputs.md +++ b/docs/source/features/reasoning_outputs.md @@ -4,7 +4,7 @@ vLLM offers support for reasoning models like [DeepSeek R1](https://huggingface.co/deepseek-ai/DeepSeek-R1), which are designed to generate outputs containing both reasoning steps and final conclusions. -Reasoning models return a additional `reasoning_content` field in their outputs, which contains the reasoning steps that led to the final conclusion. This field is not present in the outputs of other models. +Reasoning models return an additional `reasoning_content` field in their outputs, which contains the reasoning steps that led to the final conclusion. This field is not present in the outputs of other models. ## Supported Models @@ -14,6 +14,9 @@ vLLM currently supports the following reasoning models: |--------------|-------------|------------------|-------------| | [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` | ❌ | | [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` | ✅ | +| [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ | + +- IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`. 
## Quickstart @@ -43,6 +46,7 @@ model = models.data[0].id # Round 1 messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}] +# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}` response = client.chat.completions.create(model=model, messages=messages) reasoning_content = response.choices[0].message.reasoning_content @@ -97,6 +101,7 @@ models = client.models.list() model = models.data[0].id messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}] +# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}` stream = client.chat.completions.create(model=model, messages=messages, stream=True) diff --git a/examples/online_serving/openai_chat_completion_with_reasoning.py b/examples/online_serving/openai_chat_completion_with_reasoning.py index b5dbed1205d3..e753cedcdc08 100644 --- a/examples/online_serving/openai_chat_completion_with_reasoning.py +++ b/examples/online_serving/openai_chat_completion_with_reasoning.py @@ -31,6 +31,7 @@ # Round 1 messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}] +# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}` response = client.chat.completions.create(model=model, messages=messages) reasoning_content = response.choices[0].message.reasoning_content diff --git a/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py b/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py index fe4332576d43..cb13b0c614aa 100644 --- a/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py +++ b/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py @@ -38,6 +38,7 @@ model = models.data[0].id messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}] +# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}` stream = client.chat.completions.create(model=model, messages=messages, stream=True) diff --git 
a/tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py b/tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py new file mode 100644 index 000000000000..84ac6600498b --- /dev/null +++ b/tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py @@ -0,0 +1,349 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +from transformers import AutoTokenizer + +from tests.entrypoints.openai.reasoning_parsers.utils import ( + DeltaMessage, run_reasoning_extraction) +from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser, + ReasoningParserManager) + +parser_name = "granite" +START_REASONING = "Here is my thought process:" +START_RESPONSE = "Here is my response:" + +SIMPLE_REASONING = { + "output": + f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest", #noqa: E501 + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", +} +COMPLETE_REASONING = { + "output": f"{START_REASONING}This is a reasoning section{START_RESPONSE}", + "reasoning_content": "This is a reasoning section", + "content": None, +} +NO_REASONING = { + "output": "This is content", + "reasoning_content": None, + "content": "This is content", +} +MULTIPLE_LINES = { + "output": + f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat", + "reasoning_content": "This\nThat", + "content": "This is the rest\nThat", +} +REASONING_WITH_THINK = { + "output": + f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest", #noqa: E501 + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", +} +COMPLETE_REASONING_WITH_THINK = { + "output": f"{START_REASONING}This is a reasoning section{START_RESPONSE}", + "reasoning_content": "This is a reasoning section", + "content": None, +} +MULTIPLE_LINES_WITH_THINK = { + "output": + f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat", + "reasoning_content": "This\nThat", 
+ "content": "This is the rest\nThat", +} + +TEST_CASES = [ + pytest.param( + False, + SIMPLE_REASONING, + id="simple_reasoning", + ), + pytest.param( + False, + COMPLETE_REASONING, + id="complete_reasoning", + ), + pytest.param( + False, + NO_REASONING, + id="no_reasoning", + ), + pytest.param( + False, + MULTIPLE_LINES, + id="multiple_lines", + ), + pytest.param( + False, + REASONING_WITH_THINK, + id="reasoning_with_think", + ), + pytest.param( + False, + COMPLETE_REASONING_WITH_THINK, + id="complete_reasoning_with_think", + ), + pytest.param( + False, + MULTIPLE_LINES_WITH_THINK, + id="multiple_lines_with_think", + ), + pytest.param( + True, + SIMPLE_REASONING, + id="simple_reasoning_streaming", + ), + pytest.param( + True, + COMPLETE_REASONING, + id="complete_reasoning_streaming", + ), + pytest.param( + True, + NO_REASONING, + id="no_reasoning_streaming", + ), + pytest.param( + True, + MULTIPLE_LINES, + id="multiple_lines_streaming", + ), + pytest.param( + True, + REASONING_WITH_THINK, + id="reasoning_with_think_streaming", + ), + pytest.param( + True, + COMPLETE_REASONING_WITH_THINK, + id="complete_reasoning_with_think_streaming", + ), + pytest.param( + True, + MULTIPLE_LINES_WITH_THINK, + id="multiple_lines_with_think_streaming", + ), +] + +# Global tokenizer initialization to avoid repeated loading +tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") + + +@pytest.mark.parametrize("streaming, param_dict", TEST_CASES) +def test_reasoning( + streaming: bool, + param_dict: dict, +): + output = tokenizer.tokenize(param_dict["output"]) + # decode everything to tokens + output_tokens: list[str] = [ + tokenizer.convert_tokens_to_string([token]) for token in output + ] + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( + parser_name)(tokenizer) + + reasoning, content = run_reasoning_extraction(parser, + output_tokens, + streaming=streaming) + + assert reasoning == param_dict["reasoning_content"] + assert content == 
param_dict["content"] + + +# Additional tests for verifying the correctness of granite streaming; this +# is complicated because granite uses multiple tokens to indicate when thinking +# is starting / when it's starting its response, so skipping special tokens +# is awkward. + +### Handling the start of reasoning +STREAMING_1 = { + "previous_text": None, + "current_text": "Here", + "delta_text": "Here", + "reasoning_content": None, + "content": None, +} +# When we fail, we should give what was previously being silenced first +STREAMING_2 = { + "previous_text": "Here is my thought", + "current_text": "Here is my thought failure", + "delta_text": " failure", + "reasoning_content": None, + "content": "Here is my thought failure", +} +# But then after the first one, we should only add the delta text to content +STREAMING_3 = { + "previous_text": "Here wrong", + "current_text": " words", + "delta_text": " Here wrong words", + "reasoning_content": None, + "content": " words", +} +# The start of reasoning sequence just completed; there is nothing to emit yet +STREAMING_4 = { + "previous_text": "Here is my thought", + "current_text": "Here is my thought process:", + "delta_text": " process:", + "reasoning_content": None, + "content": None, +} +# Reasoning started successfully; parse reasoning content +STREAMING_5 = { + "previous_text": "Here is my thought process:", + "current_text": "Here is my thought process: foo", + "delta_text": " foo", + "reasoning_content": " foo", + "content": None, +} +# Response special sequence has started, but not finished. +STREAMING_6 = { + "previous_text": "Here is my thought process: foo", + "current_text": "Here is my thought process: foo Here is", + "delta_text": " Here is", + "reasoning_content": " ", + "content": None, +} +# Response special sequence started, but was broken; the reasoning +# content should be the content that was previously unused. 
+STREAMING_7 = { + "previous_text": "Here is my thought process: foo Here is", + "current_text": "Here is my thought process: foo Here is Here", + "delta_text": " Here", + "reasoning_content": "Here is ", + "content": None, +} +# Response special sequence is ongoing +STREAMING_8 = { + "previous_text": "Here is my thought process: foo Here is my response:", + "current_text": "Here is my thought process: foo Here is my response: bar", + "delta_text": " bar", + "reasoning_content": None, + "content": " bar", +} +# The delta text has everything; we should be able to correctly parse both +STREAMING_9 = { + "previous_text": None, + "current_text": "Here is my thought process: foo Here is my response: bar", + "delta_text": "Here is my thought process: foo Here is my response: bar", + "reasoning_content": " foo ", + "content": " bar", +} +## The Response is ongoing, and the delta mixes reasoning content / content +STREAMING_10 = { + "previous_text": "Here is my thought process: foo", + "current_text": + "Here is my thought process: foo bar Here is my response: baz", + "delta_text": " bar Here is my response: baz", + "reasoning_content": " bar ", + "content": " baz", +} +# The delta text starts a new substring that might be a response special seq +STREAMING_11 = { + "previous_text": + "Here is my thought process: This is a reasoning section ", + "current_text": + "Here is my thought process: This is a reasoning section Here", + "delta_text": "Here", + "reasoning_content": None, + "content": None, +} +# The delta text is finishing the response special seq +STREAMING_12 = { + "previous_text": "Here is my thought process: foo Here is my response", + "current_text": "Here is my thought process: foo Here is my response:", + "delta_text": ":", + "reasoning_content": None, + "content": None, +} +STREAMING_13 = { + "previous_text": "Here is my thought process: foo Here", + "current_text": "Here is my thought process: foo Here was", + "delta_text": " was", + "reasoning_content": 
"Here was", + "content": None, +} + +STREAMING_SUBCASES = [ + pytest.param( + STREAMING_1, + id="Starting reasoning special sequence", + ), + pytest.param( + STREAMING_2, + id="Unexpected start reasoning sequence", + ), + pytest.param( + STREAMING_3, + id="Continuing unexpected start reasoning sequence", + ), + pytest.param( + STREAMING_4, + id="Only start reasoning sequence and nothing else", + ), + pytest.param( + STREAMING_5, + id="Reasoning content has started", + ), + pytest.param( + STREAMING_6, + id="Response special sequence has started", + ), + pytest.param( + STREAMING_7, + id="Response special sequence reset", + ), + pytest.param( + STREAMING_8, + id="Response text has started", + ), + pytest.param( + STREAMING_9, + id="Delta contains everything", + ), + pytest.param( + STREAMING_10, + id="Delta contains some reasoning and response", + ), + pytest.param( + STREAMING_11, + id="Delta starts response sequence", + ), + pytest.param( + STREAMING_12, + id="Delta finishes response sequence", + ), + pytest.param( + STREAMING_13, + id="Delta breaks potential responise sequence", + ), +] + + +@pytest.mark.parametrize("param_dict", STREAMING_SUBCASES) +def test_streaming_subcases(param_dict): + # Get all of the token IDs + previous_token_ids = tokenizer.encode( + param_dict["previous_text"] + ) if param_dict["previous_text"] is not None else [] + current_token_ids = tokenizer.encode(param_dict["current_text"]) + delta_token_ids = tokenizer.encode(param_dict["delta_text"]) + + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( + parser_name)(tokenizer) + + response = parser.extract_reasoning_content_streaming( + previous_text=param_dict["previous_text"], + current_text=param_dict["current_text"], + delta_text=param_dict["delta_text"], + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + ) + # Streaming currently expects at least one of reasoning content / content, + # so the response 
should return None in that case. + if param_dict["reasoning_content"] is None and param_dict[ + "content"] is None: + assert response is None + else: + assert isinstance(response, DeltaMessage) + assert param_dict["reasoning_content"] == response.reasoning_content + assert param_dict["content"] == response.content diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 75ac326aaa3d..be00689f2b55 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1099,7 +1099,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( "--reasoning-parser", type=str, - choices=["deepseek_r1"], + choices=["deepseek_r1", "granite"], default=None, help= "Select the reasoning parser depending on the model that you're " diff --git a/vllm/entrypoints/openai/reasoning_parsers/__init__.py b/vllm/entrypoints/openai/reasoning_parsers/__init__.py index 80354d69b50a..45132a780e5b 100644 --- a/vllm/entrypoints/openai/reasoning_parsers/__init__.py +++ b/vllm/entrypoints/openai/reasoning_parsers/__init__.py @@ -2,7 +2,11 @@ from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser +from .granite_reasoning_parser import GraniteReasoningParser __all__ = [ - "ReasoningParser", "ReasoningParserManager", "DeepSeekR1ReasoningParser" + "ReasoningParser", + "ReasoningParserManager", + "DeepSeekR1ReasoningParser", + "GraniteReasoningParser", ] diff --git a/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py b/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py new file mode 100644 index 000000000000..117d051a7378 --- /dev/null +++ b/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py @@ -0,0 +1,363 @@ +# SPDX-License-Identifier: Apache-2.0 + +import re +from collections.abc import Sequence +from typing import Optional, Union + +from transformers import PreTrainedTokenizerBase + +from 
vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage) +from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import ( + ReasoningParser, ReasoningParserManager) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@ReasoningParserManager.register_module("granite") +class GraniteReasoningParser(ReasoningParser): + """ + Reasoning parser for IBM Granite. + + IBM granite models currently use "Here is my thought process:" + and "Here is my response:" to separate its thinking / response outputs. + """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + + # NOTE: There have been some observed occurrences of quantized + # instances of the current models using "Here's" instead of "Here is", + # so to be safe, we match on both. + self.think_start_expr = r"(?:Here's|Here is) my thought process:" + self.response_start_expr = r"(?:Here's|Here is) my response:" + + self.reasoning_regex = re.compile( + rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)", + re.DOTALL) + + self.valid_think_starts = [ + "Here's my thought process:", "Here is my thought process:" + ] + self.valid_response_starts = [ + "Here's my response:", "Here is my response:" + ] + + # Substrings to match for sequence boundaries on raw text + self.seq_boundary_end = ":" + self.seq_boundary_start = "Here" + + # The longest any thinking / start of response message can be + self.longest_think_start = max( + len(think_start) for think_start in self.valid_think_starts) + + def extract_reasoning_content( + self, model_output: str, request: ChatCompletionRequest + ) -> tuple[Optional[str], Optional[str]]: + """Extract the reasoning content & content sections, respectively. + If the sequence doesn't match what we expect, i.e., the model generates + something else, all content is considered non-reasoning content. + + Args: + model_output (str): Output of the model to be parsed. 
+ request (ChatCompletionRequest): Request being processed. + + Returns: + tuple[Optional[str], Optional[str]]: Tuple pair containing the + reasoning content and non-reasoning content. + """ + re_match = self.reasoning_regex.findall(model_output) + if not re_match: + return None, model_output + reasoning_content, response_content = re_match[0] + if not response_content: + return reasoning_content, None + return reasoning_content, response_content + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + """Extract the reasoning content / content emitted by granite models; + If the sequence doesn't match what we expect, i.e., the model generates + something else, all content is considered non-reasoning content. + + NOTE: Granite models do not use a special token to start their reasoning + and response sections; instead they have token sequences, e.g., + + Here is my thought process: Foo Here is my response: Bar + + This increases the complexity of correctly handling streams, since we + need to watch for specific sequences and correctly parse them without + dropping content that is potentially overlapping & spanning multiple + delta messages. + + Args: + previous_text (str): Previous text outside of this delta message. + current_text (str): Previous text + delta text. + delta_text (str): Text to consider and parse content from. + previous_token_ids (Sequence[int]): Token IDs of previous_text. + current_token_ids (Sequence[int]): Token IDs of current_text. + delta_token_ids (Sequence[int]): Token IDs of delta_text. + + Returns: + Union[DeltaMessage, None] + DeltaMessage with either reasoning content or content, or None. 
+ """ + reasoning_content, resp_seq_len, content = self._get_content_sections( + current_text) + # Either we haven't finished the start of the reasoning sequence, + # or the model is generating something unexpected. + if not reasoning_content: + delta_message = self._get_delta_message_with_no_reasoning_bounds( + current_text, delta_text) + # We have a start of reasoning message, but have not yet finished + # the start of response sequence. + elif not content: + delta_message = self._get_delta_message_with_no_response_bounds( + current_text, reasoning_content, delta_text) + # We've finished both the start of reasoning and start of response seq. + else: + # This should never happen since we matched on the response + assert resp_seq_len is not None + delta_message = self._get_delta_message_with_both_bounds( + delta_text, reasoning_content, content, current_text, + resp_seq_len) + if not delta_message.content and not delta_message.reasoning_content: + return None + return delta_message + + #### Implementation details of stream parsing for granite models + def _is_reasoning_start_substr(self, text: str) -> bool: + """Check if a text matches one of the possible start reasoning seqs. + + Args: + text (str): Text to check for leading substr. + + Returns: + bool: True if any of the possible reasoning start seqs match. + """ + return any( + think_start.startswith(text) + for think_start in self.valid_think_starts) + + def _is_response_start_substr(self, text: str) -> bool: + """Check if a text matches one of the possible start response seqs. + + Args: + text (str): Text to check for leading substr. + + Returns: + bool: True if any of the possible response start seqs match. 
+ """ + return any( + response_start.startswith(text) + for response_start in self.valid_response_starts) + + def _get_delta_message_with_no_reasoning_bounds( + self, + current_text: str, + delta_text: str, + ) -> DeltaMessage: + """Parse the delta message when the current text has not yet completed + its start of reasoning sequence. + + Args: + current_text (str): The full previous + delta text. + delta_text (str): Text to consider and parse content from. + + Returns: + DeltaMessage: Message containing the parsed content. + """ + prev_longest_length = len(current_text) - len(delta_text) + is_substr = self._is_reasoning_start_substr(current_text) + was_substr = self._is_reasoning_start_substr( + current_text[:prev_longest_length]) + + # Check if we just generated something NOT in the special token seq; + # if so, add everything that we previously skipped with this delta + # message and append everything to content in the future. + if was_substr and not is_substr: + return DeltaMessage( + reasoning_content=None, + content=current_text, + ) + if is_substr: + # Might still be in the special token sequence; return nothing + return DeltaMessage(reasoning_content=None, content=None) + # Otherwise the sequence has already been broken and we already + # corrected; just return the delta text as normal content. + return DeltaMessage(reasoning_content=None, content=delta_text) + + def _get_delta_message_with_no_response_bounds( + self, + current_text: str, + reasoning_content: str, + delta_text: str, + ) -> DeltaMessage: + """Parse the delta message when the current text has both reasoning + content with no (response) content. NOTE that we may have overlapping + tokens with the start of reasoning / start of response sequences on + either side of the delta text. + + Args: + current_text (str): The full previous + delta text. + reasoning_content (str): reasoning content from current_text. + delta_text (str): Text to consider and parse content from. 
+ + Returns: + DeltaMessage: Message containing the parsed content. + """ + # If we have no reasoning content or explicitly end with the start of + # response sequence, we are in transition to the response; need to be + # careful here, since the final token (:) will match the reasoning + # content and fully parse it out; we should not pass the : back. + ends_with_start_response_seq = any( + current_text.endswith(response_start) + for response_start in self.valid_response_starts) + if reasoning_content is None or ends_with_start_response_seq: + return DeltaMessage(reasoning_content=None, content=None) + + # Consider previous / current text only within context of the reasoning + previous_text = reasoning_content[:-len(delta_text)] + current_text = reasoning_content + + # We need to be careful about adding unfinished response sequences; + # Find the place at which we MIGHT be starting a response sequence + prev_idx = previous_text.rfind(self.seq_boundary_start) + delta_idx = delta_text.rfind(self.seq_boundary_start) + + # Check the state of potential start of response substring matches. + prev_was_substr = self._is_response_start_substr( + previous_text[prev_idx:]) if prev_idx >= 0 else False + delta_continues_substr = self._is_response_start_substr( + current_text[prev_idx:]) if prev_idx >= 0 else False + delta_new_substr = self._is_response_start_substr( + delta_text[delta_idx:]) if delta_idx >= 0 else False + + # Delta only contains potential continued response sequence text. + if delta_continues_substr: + return DeltaMessage(reasoning_content=None, content=None) + + if not prev_was_substr: + # Delta may be starting a new response seq but has other text too. + if delta_new_substr: + return DeltaMessage(reasoning_content=delta_text[:delta_idx], + content=None) + # Normal case for most reasoning text (no potential special seqs). 
+ return DeltaMessage(reasoning_content=delta_text, content=None) + # The substring that previously seemed to be a potential response + # seq wasn't one; we need to add the content to the delta message, + # and also slice off the potential response sequence + elif delta_new_substr: + reasoning_content = previous_text[ + prev_idx:] + delta_text[:delta_idx] + return DeltaMessage(reasoning_content=reasoning_content, + content=None) + # No new substring yet, and we broke our old one; take the whole delta + return DeltaMessage( + reasoning_content=previous_text[prev_idx:] + delta_text, + content=None, + ) + + def _get_delta_message_with_both_bounds( + self, + delta_text: str, + reasoning_content: str, + response_content: str, + current_text: str, + response_seq_len: int, + ) -> DeltaMessage: + """Parse the delta message when the current text has both reasoning + content and normal (response) content. + + Args: + delta_text (str): Text to consider and parse content from. + reasoning_content (str): reasoning content from current_text. + response_content (str): response content from current_text. + current_text (str): The full previous + delta text. + response_seq_len (int): Len of the complete response sequence used. + + Returns: + DeltaMessage: Message containing the parsed content. 
+ """ + # Always have content; take length to the end + delta_content = delta_text[-len(response_content):] + reasoning_end_idx = len(delta_text) - (len(response_content) + + response_seq_len) + + if reasoning_end_idx < 0: + delta_reasoning_content = None + else: + # Get the starting offset + start_reasoning_content_idx = len( + reasoning_content) + response_seq_len + len( + response_content) - 1 + delta_offset = len(current_text) - len(delta_text) + start_offset = start_reasoning_content_idx - delta_offset + if start_offset < 0: + start_offset = 0 + delta_reasoning_content = delta_text[ + start_offset:reasoning_end_idx] + + return DeltaMessage( + reasoning_content=delta_reasoning_content, + content=delta_content, + ) + + def _get_content_sections( + self, current_text: str + ) -> tuple[Optional[str], Optional[int], Optional[str]]: + """Parse the text to extract the reasoning content / content + if we have them. + + Args: + current_text (str): The full previous + delta text. + + Returns: + tuple[Optional[str], Optional[int], Optional[str]]: Tuple of len 3 + containing the reasoning content, the length of the response seq + (if there is one) and the non-reasoning content. 
+ """ + current_chunk_start = 0 + start_reasoning_content = None + parsed_content = False + delimiter_idxs = [ + idx for idx, char in enumerate(current_text) + if char == self.seq_boundary_end + ] + + for current_chunk_end in delimiter_idxs: + current_chunk = current_text[current_chunk_start:current_chunk_end] + # Check to see if the start of reasoning seq is complete + if start_reasoning_content is None: + for think_start in self.valid_think_starts: + if current_chunk == think_start[:-1]: + start_reasoning_content = current_chunk_end + 1 + current_chunk_start = current_chunk_end + 1 + break + + # Check to see if the start of response seq is complete + elif not parsed_content: + for response_start in self.valid_response_starts: + if current_chunk[-len(response_start) + + 1:] == response_start[:-1]: + # Mark end of reasoning and start response content + # after the start of response sequence. + end_reasoning_content = current_chunk_end - len( + response_start) + reasoning_content = current_text[ + start_reasoning_content:end_reasoning_content] + response_content = current_text[current_chunk_end + 1:] + return reasoning_content, len( + response_start), response_content + + if start_reasoning_content and not parsed_content: + return current_text[start_reasoning_content:], None, None + return None, None, None diff --git a/vllm/model_executor/guided_decoding/reasoner/__init__.py b/vllm/model_executor/guided_decoding/reasoner/__init__.py index d930d3dbe94c..ab6e47c007d2 100644 --- a/vllm/model_executor/guided_decoding/reasoner/__init__.py +++ b/vllm/model_executor/guided_decoding/reasoner/__init__.py @@ -19,6 +19,10 @@ def get_reasoner(tokenizer: PreTrainedTokenizer, return None elif reasoning_backend == "deepseek_r1": return DeepSeekReasoner.from_tokenizer(tokenizer) + elif reasoning_backend == "granite": + logger.warning( + "Granite reasoner not yet implemented for structured outputs") + return None else: # Raise a warning for unknown reasoning backend and return None 
# We cannot raise an error here because some reasoning models