From 784c170919956ae7bbad089cde474fa31c9e82a8 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Sun, 2 Mar 2025 07:02:58 +0000 Subject: [PATCH 01/20] Implement granite reasoning parser for non streaming Signed-off-by: Alex-Brooks --- .../granite_reasoning_parser.py | 66 +++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py diff --git a/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py b/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py new file mode 100644 index 000000000000..b556f78bd326 --- /dev/null +++ b/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 + +import re +from typing import Optional, Sequence, Tuple, Union + +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage) +from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import ( + ReasoningParser, ReasoningParserManager) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@ReasoningParserManager.register_module("granite") +class GraniteReasoningParser(ReasoningParser): + """ + Reasoning parser for IBM Granite. + + IBM granite models currently use "Here is my start process:" + and "Here is my response:" to separate its thinking / response outputs. + NOTE: There have been some observed occurrences of quantized instances of + the current models using "Here's" instead of "Here is", so to be safe, we + match on both + """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + self.think_start_expr = r"(?:Here's|Here is) my thought process:" + self.think_end_expr = r"(?:Here's|Here is) my response:" + + self.reasoning_regex = re.compile( + rf"{self.think_start_expr}(.*?){self.think_end_expr}(.*)", + re.DOTALL) + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + """ + Extract reasoning content from a delta message. + Handles streaming output where previous + delta = current. + Uses token IDs for faster processing. 
+ For text: Here is my thinking:abcHere is my response:xyz: + - 'abc' goes to reasoning_content + - 'xyz' goes to content + """ + raise NotImplementedError("Streaming not implemented") + + def extract_reasoning_content( + self, model_output: str, request: ChatCompletionRequest + ) -> Tuple[Optional[str], Optional[str]]: + re_match = self.reasoning_regex.findall(model_output) + if not re_match: + return model_output, None + reasoning_content, response_content = re_match[0] + if not response_content: + return reasoning_content, None + return reasoning_content, response_content From 4438a38fba738b1c06de857f04739f24f58e333c Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Sun, 2 Mar 2025 07:04:24 +0000 Subject: [PATCH 02/20] Add granite reasoning parser to init pkg Signed-off-by: Alex-Brooks --- vllm/entrypoints/openai/reasoning_parsers/__init__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/reasoning_parsers/__init__.py b/vllm/entrypoints/openai/reasoning_parsers/__init__.py index 80354d69b50a..45132a780e5b 100644 --- a/vllm/entrypoints/openai/reasoning_parsers/__init__.py +++ b/vllm/entrypoints/openai/reasoning_parsers/__init__.py @@ -2,7 +2,11 @@ from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser +from .granite_reasoning_parser import GraniteReasoningParser __all__ = [ - "ReasoningParser", "ReasoningParserManager", "DeepSeekR1ReasoningParser" + "ReasoningParser", + "ReasoningParserManager", + "DeepSeekR1ReasoningParser", + "GraniteReasoningParser", ] From 3278ca7f72db9e90db92746cfe7238f8377ba061 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Sun, 2 Mar 2025 07:10:03 +0000 Subject: [PATCH 03/20] Add preliminary test for non streaming granite rparser Signed-off-by: Alex-Brooks --- .../test_granite_reasoning_parser.py | 92 +++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py diff --git a/tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py b/tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py new file mode 100644 index 000000000000..156fd9de8bb1 --- /dev/null +++ b/tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import List + +import pytest +from transformers import AutoTokenizer + +from tests.entrypoints.openai.reasoning_parsers.utils import ( + run_reasoning_extraction) +from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser, + ReasoningParserManager) + +parser_name = "granite" +START_REASONING = "Here is my thought process:" +START_RESPONSE = "Here is my response:" + +SIMPLE_REASONING = { + "output": + f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest", #noqa: E501 + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", +} +COMPLETE_REASONING = { + "output": f"{START_REASONING}This is a reasoning section{START_RESPONSE}", + "reasoning_content": "This is a reasoning section", + "content": None, +} +NO_CONTENT = { + "output": "This is content", + "reasoning_content": "This is content", + "content": None, +} +MULTIPLE_LINES = { + "output": + f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat", + "reasoning_content": "This\nThat", + "content": "This is the rest\nThat", +} +REASONING_WITH_THINK = { + "output": + 
f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest", #noqa: E501 + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", +} +COMPLETE_REASONING_WITH_THINK = { + "output": f"{START_REASONING}This is a reasoning section{START_RESPONSE}", + "reasoning_content": "This is a reasoning section", + "content": None, +} +MULTIPLE_LINES_WITH_THINK = { + "output": + f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat", + "reasoning_content": "This\nThat", + "content": "This is the rest\nThat", +} + +TEST_CASES = [ + pytest.param(False, SIMPLE_REASONING, id="simple_reasoning"), + pytest.param(False, COMPLETE_REASONING, id="complete_reasoning"), + pytest.param(False, NO_CONTENT, id="no_content"), + pytest.param(False, MULTIPLE_LINES, id="multiple_lines"), + pytest.param(False, REASONING_WITH_THINK, id="reasoning_with_think"), + pytest.param(False, + COMPLETE_REASONING_WITH_THINK, + id="complete_reasoning_with_think"), + pytest.param(False, + MULTIPLE_LINES_WITH_THINK, + id="multiple_lines_with_think"), +] + +# Global tokenizer initialization to avoid repeated loading +tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") + + +@pytest.mark.parametrize("streaming, param_dict", TEST_CASES) +def test_reasoning( + streaming: bool, + param_dict: dict, +): + output = tokenizer.tokenize(param_dict["output"]) + # decode everything to tokens + output_tokens: List[str] = [ + tokenizer.convert_tokens_to_string([token]) for token in output + ] + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( + parser_name)(tokenizer) + + reasoning, content = run_reasoning_extraction(parser, + output_tokens, + streaming=streaming) + + assert reasoning == param_dict["reasoning_content"] + assert content == param_dict["content"] From 07e58a8392d243290bc8a42194fd4a7ad3f728af Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 4 Mar 2025 00:44:12 +0000 Subject: [PATCH 04/20] Implement granite reasoning parser streaming Signed-off-by: Alex-Brooks --- .../granite_reasoning_parser.py | 209 ++++++++++++++++-- 1 file changed, 196 insertions(+), 13 deletions(-) diff --git a/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py b/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py index b556f78bd326..388548e93366 100644 --- a/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py +++ b/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py @@ -19,7 +19,7 @@ class GraniteReasoningParser(ReasoningParser): """ Reasoning parser for IBM Granite. - IBM granite models currently use "Here is my start process:" + IBM granite models currently use "Here is my thought process:" and "Here is my response:" to separate its thinking / response outputs. 
NOTE: There have been some observed occurrences of quantized instances of the current models using "Here's" instead of "Here is", so to be safe, we @@ -29,12 +29,189 @@ class GraniteReasoningParser(ReasoningParser): def __init__(self, tokenizer: PreTrainedTokenizerBase): super().__init__(tokenizer) self.think_start_expr = r"(?:Here's|Here is) my thought process:" - self.think_end_expr = r"(?:Here's|Here is) my response:" + self.response_start_expr = r"(?:Here's|Here is) my response:" self.reasoning_regex = re.compile( - rf"{self.think_start_expr}(.*?){self.think_end_expr}(.*)", + rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)", re.DOTALL) + self.valid_think_starts = [ + "Here's my thought process:", "Here is my thought process:" + ] + self.valid_response_starts = [ + "Here's my response:", "Here is my response:" + ] + self.seq_boundary_end = ":" + self.seq_boundary_start = "Here" + + # The longest any thinking / start of response message can be + self.longest_think_start = max( + len(think_start) for think_start in self.valid_think_starts) + + def _get_delta_message_with_no_reasoning_bounds(self, current_text, + delta_text): + # Even before this, we already had a longer text str than any expected + # message; The output is not parsable; assume we rectified this + # previously and just add the delta text to the content. + prev_longest_length = len(current_text) - len(delta_text) + if prev_longest_length > self.longest_think_start: + return DeltaMessage(reasoning_content=None, content=delta_text) + + is_substr = any(current_text in think_start + for think_start in self.valid_think_starts) + was_substr = any(current_text[:prev_longest_length] in think_start + for think_start in self.valid_think_starts) + # Check if we just generated something NOT in the special token seq; + # if so, we add everything that we previously skipped with this delta + # message and append everything to content in the future. + if was_substr and not is_substr: + return DeltaMessage( + reasoning_content=None, + content=current_text, + ) + if is_substr: + # Might still be in the special token sequence; return nothing + return DeltaMessage(reasoning_content=None, content=None) + # Otherwise the sequence has already been broken and we already + # corrected; just return the delta text as normal content. + return DeltaMessage(reasoning_content=None, content=delta_text) + + def _get_delta_message_with_no_response_bounds(self, current_text, + reasoning_content, + delta_text): + # If we have no reasoning content or explicitly end with the start of + # response sequence, we are in transition to the response; need to be + # careful here, since the final token (:) will match the reasoning + # content and fully parse it out; we should not pass the : back. 
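+        # E.g., once current_text ends with "Here is my response:", the final
+        # ":" completes the response sequence, so the delta that produced it
+        # should yield neither reasoning content nor content.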
+ ends_with_start_response_seq = any( + current_text.endswith(response_start) + for response_start in self.valid_response_starts) + if reasoning_content is None or ends_with_start_response_seq: + return DeltaMessage(reasoning_content=None, content=None) + + # Consider previous / current text only within context of the reasoning + previous_text = reasoning_content[:-len(delta_text)] + current_text = reasoning_content + + # We need to be careful about adding unfinished response sequences; + # Find the place at which we MIGHT be starting a response sequence + prev_idx = previous_text.rfind(self.seq_boundary_start) + delta_idx = delta_text.rfind(self.seq_boundary_start) + + # And check if either we had one in the previous text, or have on + prev_was_substr = any( + previous_text[prev_idx:] in response_start for response_start in + self.valid_response_starts) if prev_idx >= 0 else False + delta_continues_substr = any( + current_text[prev_idx:] in response_start for response_start in + self.valid_response_starts) if prev_idx >= 0 else False + delta_new_substr = any( + delta_text[delta_idx:] in response_start for response_start in + self.valid_response_starts) if delta_idx >= 0 else False + # It was a substring before, and all delta tokens are + # part of the potential start of response sequence + + if prev_was_substr and delta_continues_substr: + return DeltaMessage(reasoning_content=None, content=None) + + if not prev_was_substr: + # Don't add the potential unfinished response sequence + if delta_new_substr: + return DeltaMessage(reasoning_content=delta_text[:delta_idx], + content=None) + # No possible places to start the response seq; continue normally + return DeltaMessage(reasoning_content=delta_text, content=None) + # The substring being continued is the same one as before; + # the whole delta message is part of the potential response seq + if delta_continues_substr: + return DeltaMessage(reasoning_content=None, content=None) + # The substring that previously seemed to be a potential response + # seq wasn't one; we need to add the content to the delta message, + # and also slice off the potential response sequence + elif delta_new_substr: + return DeltaMessage(reasoning_content=previous_text[prev_idx:] + + delta_text[:delta_idx], + content=None) + # No new substring yet, and we broke our old one; take the whole delta + return DeltaMessage(reasoning_content=previous_text[prev_idx:] + + delta_text, + content=None) + + def _get_delta_message_with_both_bounds(self, delta_text, + reasoning_content, + response_content, current_text): + # We have reasoning and response content, but it may not all be in the + # delta text; we need to consider the length of the start of response + # sequence and divide just the delta text part. 
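+        # E.g., for the delta " bar Here is my response: baz", only " bar "
+        # belongs to the reasoning content and only " baz" to the response;
+        # the start of response sequence between them belongs to neither.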
+ ##### HACK pass this through + for rs in self.valid_response_starts: + if rs in current_text: + response_seq_len = len(rs) + break + ##### + + # Always have content; take length to the end + delta_content = delta_text[-len(response_content):] + reasoning_end_idx = len(delta_text) - (len(response_content) + + response_seq_len) + if reasoning_end_idx < 0: + delta_reasoning_content = None + else: + # Get the starting offset + start_reasoning_content_idx = len( + reasoning_content) + response_seq_len + len( + response_content) - 1 + delta_offset = len(current_text) - len(delta_text) + start_offset = start_reasoning_content_idx - delta_offset + if start_offset < 0: + start_offset = 0 + delta_reasoning_content = delta_text[ + start_offset:reasoning_end_idx] + + return DeltaMessage( + reasoning_content=delta_reasoning_content, + content=delta_content, + ) + + def get_content_sections(self, current_text: str): + + delimiter_idxs = [ + idx for idx, char in enumerate(current_text) + if char == self.seq_boundary_end + ] + current_chunk_start = 0 + start_reasoning_content = None + start_response_content = None + + for current_chunk_end in delimiter_idxs: + current_chunk = current_text[current_chunk_start:current_chunk_end] + # Check to see if this is start of reasoning + if start_reasoning_content is None: + for think_start in self.valid_think_starts: + if current_chunk == think_start[:-1]: + start_reasoning_content = current_chunk_end + 1 + current_chunk_start = current_chunk_end + 1 + break + + # Check to see if this is start of response + elif start_response_content is None: + for response_start in self.valid_response_starts: + if current_chunk[-len(response_start) + + 1:] == response_start[:-1]: + end_reasoning_content = current_chunk_end - len( + response_start) + reasoning_content = current_text[ + start_reasoning_content:end_reasoning_content] + response_content = current_text[current_chunk_end + 1:] + # Ensure we handle empty strings / None consistently + if not response_content: + response_content = None + return reasoning_content, response_content + # Set the actual response content + if start_reasoning_content and start_response_content is None: + return current_text[start_reasoning_content:], None + return None, None + def extract_reasoning_content_streaming( self, previous_text: str, @@ -44,22 +221,28 @@ def extract_reasoning_content_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], ) -> Union[DeltaMessage, None]: - """ - Extract reasoning content from a delta message. - Handles streaming output where previous + delta = current. - Uses token IDs for faster processing. 
- For text: Here is my thinking:abcHere is my response:xyz: - - 'abc' goes to reasoning_content - - 'xyz' goes to content - """ - raise NotImplementedError("Streaming not implemented") + reasoning_content, response_content = self.get_content_sections( + current_text) + if not reasoning_content: + delta_message = self._get_delta_message_with_no_reasoning_bounds( + current_text, delta_text) + # We have a start of reasoning message, but not the response message yet + elif not response_content: + delta_message = self._get_delta_message_with_no_response_bounds( + current_text, reasoning_content, delta_text) + else: + delta_message = self._get_delta_message_with_both_bounds( + delta_text, reasoning_content, response_content, current_text) + if not delta_message.content and not delta_message.reasoning_content: + return None + return delta_message def extract_reasoning_content( self, model_output: str, request: ChatCompletionRequest ) -> Tuple[Optional[str], Optional[str]]: re_match = self.reasoning_regex.findall(model_output) if not re_match: - return model_output, None + return None, model_output reasoning_content, response_content = re_match[0] if not response_content: return reasoning_content, None From 6980ea8ef56adb9788231cf75b9015df2c40bbad Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 4 Mar 2025 00:46:00 +0000 Subject: [PATCH 05/20] Add additional granite reasoning parser tests Signed-off-by: Alex-Brooks --- .../test_granite_reasoning_parser.py | 180 +++++++++++++++++- 1 file changed, 175 insertions(+), 5 deletions(-) diff --git a/tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py b/tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py index 156fd9de8bb1..18a537f6bb26 100644 --- a/tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py +++ b/tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py @@ -5,7 +5,7 @@ from transformers import AutoTokenizer from tests.entrypoints.openai.reasoning_parsers.utils import ( - run_reasoning_extraction) + DeltaMessage, run_reasoning_extraction) from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser, ReasoningParserManager) @@ -24,10 +24,10 @@ "reasoning_content": "This is a reasoning section", "content": None, } -NO_CONTENT = { +NO_REASONING = { "output": "This is content", - "reasoning_content": "This is content", - "content": None, + "reasoning_content": None, + "content": "This is content", } MULTIPLE_LINES = { "output": @@ -56,7 +56,7 @@ TEST_CASES = [ pytest.param(False, SIMPLE_REASONING, id="simple_reasoning"), pytest.param(False, COMPLETE_REASONING, id="complete_reasoning"), - pytest.param(False, NO_CONTENT, id="no_content"), + pytest.param(False, NO_REASONING, id="no_reasoning"), pytest.param(False, MULTIPLE_LINES, id="multiple_lines"), pytest.param(False, REASONING_WITH_THINK, id="reasoning_with_think"), pytest.param(False, @@ -65,6 +65,19 @@ pytest.param(False, MULTIPLE_LINES_WITH_THINK, id="multiple_lines_with_think"), + pytest.param(True, SIMPLE_REASONING, id="simple_reasoning_streaming"), + pytest.param(True, COMPLETE_REASONING, id="complete_reasoning_streaming"), + pytest.param(True, NO_REASONING, id="no_reasoning_streaming"), + pytest.param(True, MULTIPLE_LINES, id="multiple_lines_streaming"), + pytest.param(True, + REASONING_WITH_THINK, + id="reasoning_with_think_streaming"), + pytest.param(True, + COMPLETE_REASONING_WITH_THINK, + id="complete_reasoning_with_think_streaming"), + pytest.param(True, + MULTIPLE_LINES_WITH_THINK, + 
id="multiple_lines_with_think_streaming"), ] # Global tokenizer initialization to avoid repeated loading @@ -90,3 +103,160 @@ def test_reasoning( assert reasoning == param_dict["reasoning_content"] assert content == param_dict["content"] + + +# Additional tests for verifying the correctness of granite streaming; this +# is complicated because granite uses multiple tokens to indicate when thinking +# is starting / when it's starting its response, so skipping special tokens +# is awkward. + +### Handling the start of reasoning +STREAMING_1 = { + "previous_text": None, + "current_text": "Here", + "delta_text": "Here", + "reasoning_content": None, + "content": None, +} +# When we fail, we should give what was previously being silenced first +STREAMING_2 = { + "previous_text": "Here is my thought", + "current_text": "Here is my thought failure", + "delta_text": " failure", + "reasoning_content": None, + "content": "Here is my thought failure", +} +# But then after the first one, we should only add the delta text to content +STREAMING_3 = { + "previous_text": "Here wrong", + "current_text": " words", + "delta_text": " Here wrong words", + "reasoning_content": None, + "content": " words", +} +# But then after the first one, we should only add the delta text to content +STREAMING_4 = { + "previous_text": "Here is my thought", + "current_text": "Here is my thought process:", + "delta_text": " process:", + "reasoning_content": None, + "content": None, +} +# Reasoning started successfully; parse reasoning content +STREAMING_5 = { + "previous_text": "Here is my thought process:", + "current_text": "Here is my thought process: foo", + "delta_text": " foo", + "reasoning_content": " foo", + "content": None, +} +# Response special sequence has started, but not finished. +STREAMING_6 = { + "previous_text": "Here is my thought process: foo", + "current_text": "Here is my thought process: foo Here is", + "delta_text": " Here is", + "reasoning_content": " ", + "content": None, +} +# Response special sequence started, but was broken; the reasoning +# content should be the content that was previously unused. 
+STREAMING_7 = { + "previous_text": "Here is my thought process: foo Here is", + "current_text": "Here is my thought process: foo Here is Here", + "delta_text": " Here", + "reasoning_content": "Here is ", + "content": None, +} +# Response special sequence is ongoing +STREAMING_8 = { + "previous_text": "Here is my thought process: foo Here is my response:", + "current_text": "Here is my thought process: foo Here is my response: bar", + "delta_text": " bar", + "reasoning_content": None, + "content": " bar", +} +# The delta text has everything; we should be able to correctly parse both +STREAMING_9 = { + "previous_text": None, + "current_text": "Here is my thought process: foo Here is my response: bar", + "delta_text": "Here is my thought process: foo Here is my response: bar", + "reasoning_content": " foo ", + "content": " bar", +} +## The Response is ongoing, and the delta mixes reasoning content / content +STREAMING_10 = { + "previous_text": "Here is my thought process: foo", + "current_text": + "Here is my thought process: foo bar Here is my response: baz", + "delta_text": " bar Here is my response: baz", + "reasoning_content": " bar ", + "content": " baz", +} +# The delta text starts a new substring that might be a response special seq +STREAMING_11 = { + "previous_text": + "Here is my thought process: This is a reasoning section ", + "current_text": + "Here is my thought process: This is a reasoning section Here", + "delta_text": "Here", + "reasoning_content": None, + "content": None, +} +# The delta text is finishing the response special seq +STREAMING_12 = { + "previous_text": "Here is my thought process: foo Here is my response", + "current_text": "Here is my thought process: foo Here is my response:", + "delta_text": ":", + "reasoning_content": None, + "content": None, +} +STREAMING_SUBCASES = [ + pytest.param(STREAMING_1, id="Starting reasoning special sequence"), + pytest.param(STREAMING_2, id="Unexpected start reasoning sequence"), + pytest.param(STREAMING_3, + id="Continuing unexpected start reasoning sequence"), + pytest.param(STREAMING_4, + id="Only start reasoning sequence and nothing else"), + pytest.param(STREAMING_5, id="Reasoning content has started"), + pytest.param(STREAMING_6, id="Response special sequence has started"), + pytest.param(STREAMING_7, id="Response special sequence reset"), + pytest.param(STREAMING_8, id="Response text has started"), + pytest.param(STREAMING_9, id="Delta contains everything"), + pytest.param(STREAMING_10, + id="Delta contains some reasoning and response"), + pytest.param(STREAMING_11, id="Delta starts response sequence"), + pytest.param(STREAMING_12, id="Delta finishes response sequence"), +] + + +@pytest.mark.parametrize("param_dict", STREAMING_SUBCASES) +def test_streaming_subcases(param_dict): + # Get all of the token IDs + previous_token_ids = tokenizer.encode( + param_dict["previous_text"] + ) if param_dict["previous_text"] is not None else [] + current_token_ids = tokenizer.encode( + param_dict["current_text"] + ) if param_dict["current_text"] is not None else [] + delta_token_ids = tokenizer.encode( + param_dict["delta_text"] + ) if param_dict["delta_text"] is not None else [] + + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( + parser_name)(tokenizer) + + response = parser.extract_reasoning_content_streaming( + previous_text=param_dict["previous_text"], + current_text=param_dict["current_text"], + delta_text=param_dict["delta_text"], + previous_token_ids=previous_token_ids, + 
current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + ) + if param_dict["reasoning_content"] is None and param_dict[ + "content"] is None: + assert response is None + else: + assert isinstance(response, DeltaMessage) + assert param_dict["reasoning_content"] == response.reasoning_content + assert param_dict["content"] == response.content From f6ff0bccb550af578f0aaf67289809c104f12ddc Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 4 Mar 2025 02:05:30 +0000 Subject: [PATCH 06/20] Add docstrings for granite reasoning parser Signed-off-by: Alex-Brooks --- .../granite_reasoning_parser.py | 181 +++++++++++++----- 1 file changed, 136 insertions(+), 45 deletions(-) diff --git a/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py b/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py index 388548e93366..7a4394f38324 100644 --- a/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py +++ b/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py @@ -48,8 +48,96 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase): self.longest_think_start = max( len(think_start) for think_start in self.valid_think_starts) - def _get_delta_message_with_no_reasoning_bounds(self, current_text, - delta_text): + def extract_reasoning_content( + self, model_output: str, request: ChatCompletionRequest + ) -> Tuple[Optional[str], Optional[str]]: + """Extract the reasoning content & content sections, respectively. + If the sequence doesn't match what we expect, i.e., the model generates + something else, all content is considered non-reasoning content. + + Args: + model_output (str): Output of the model to be parsed. + request (ChatCompletionReqest): Request being processed. + + Returns: + Tuple[Optional[str], Optional[str]]: Tuple pair containing the + reasoning content and non-reasoning content. + """ + re_match = self.reasoning_regex.findall(model_output) + if not re_match: + return None, model_output + reasoning_content, response_content = re_match[0] + if not response_content: + return reasoning_content, None + return reasoning_content, response_content + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + """Extract the reasoning content / content emitted by granite models; + If the sequence doesn't match what we expect, i.e., the model generates + something else, all content is considered non-reasoning content. + + NOTE: Granite models do not use a special token to start their reasoning + and response sections; instead they have token sequences, e.g., + + Here is my thought process: Foo Here is my response: Bar + + This increases the complexity of correctly handling streams by a lot, + since we need to watch for specific sequences and correctly handle + parsing them without dropping content that is potentially overlapping + & spanning across multiple delta messages. + + Args: + previous_text (str): Previous text outside of this delta message. + current_text (str): Previous text + delta text. + delta_text (str): Text to consider and parse content from. + previous_token_ids (Sequence[int]): Token IDs of previous_text. + current_token_ids (Sequence[int]): Token IDs of current_text. + delta_token_ids (Sequence[int]): Token IDs of delta_text. 
+ + Returns: + Union[DeltaMessage, None] + DeltaMessage with either reasoning content or content, or None. + """ + reasoning_content, response_content = self._get_content_sections( + current_text) + if not reasoning_content: + delta_message = self._get_delta_message_with_no_reasoning_bounds( + current_text, delta_text) + # We have a start of reasoning message, but not the response message yet + elif not response_content: + delta_message = self._get_delta_message_with_no_response_bounds( + current_text, reasoning_content, delta_text) + else: + delta_message = self._get_delta_message_with_both_bounds( + delta_text, reasoning_content, response_content, current_text) + if not delta_message.content and not delta_message.reasoning_content: + return None + return delta_message + + #### Implementation details of stream parsing for granite models + def _get_delta_message_with_no_reasoning_bounds( + self, + current_text: str, + delta_text: str, + ) -> DeltaMessage: + """Parse the delta message when the current text has not yet completed + its start of reasoning sequence. + + Args: + current_text (str): The full previous + delta text. + delta_text (str): Text to consider and parse content from. + + Returns: + DeltaMessage: Message containing the parsed content. + """ # Even before this, we already had a longer text str than any expected # message; The output is not parsable; assume we rectified this # previously and just add the delta text to the content. @@ -76,9 +164,22 @@ def _get_delta_message_with_no_reasoning_bounds(self, current_text, # corrected; just return the delta text as normal content. return DeltaMessage(reasoning_content=None, content=delta_text) - def _get_delta_message_with_no_response_bounds(self, current_text, - reasoning_content, - delta_text): + def _get_delta_message_with_no_response_bounds( + self, current_text: str, reasoning_content: str, + delta_text: str) -> DeltaMessage: + """Parse the delta message when the current text has both reasoning + content with no (response) content. NOTE that we may have overlapping + tokens with the start of reasoning / start of response sequences on + either side of the delta text. + + Args: + current_text (str): The full previous + delta text. + reasoning_content (str): reasoning content parsed from current_text. + delta_text (str): Text to consider and parse content from. + + Returns: + DeltaMessage: Message containing the parsed content. + """ # If we have no reasoning content or explicitly end with the start of # response sequence, we are in transition to the response; need to be # careful here, since the final token (:) will match the reasoning @@ -137,9 +238,25 @@ def _get_delta_message_with_no_response_bounds(self, current_text, delta_text, content=None) - def _get_delta_message_with_both_bounds(self, delta_text, - reasoning_content, - response_content, current_text): + def _get_delta_message_with_both_bounds( + self, + delta_text: str, + reasoning_content: str, + response_content: str, + current_text: str, + ) -> DeltaMessage: + """Parse the delta message when the current text has both reasoning + content and normal (response) content. + + Args: + delta_text (str): Text to consider and parse content from. + reasoning_content (str): reasoning content parsed from current_text. + response_content (str): response content parsed from current_text. + current_text (str): The full previous + delta text. + + Returns: + DeltaMessage: Message containing the parsed content. 
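+
+        For example, if the text seen so far is "Here is my thought" and the
+        delta " failure" arrives, no start of reasoning sequence can match
+        anymore, so the full accumulated text is flushed as normal content.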
+ """ # We have reasoning and response content, but it may not all be in the # delta text; we need to consider the length of the start of response # sequence and divide just the delta text part. @@ -173,8 +290,18 @@ def _get_delta_message_with_both_bounds(self, delta_text, content=delta_content, ) - def get_content_sections(self, current_text: str): + def _get_content_sections( + self, current_text: str) -> Tuple[Optional[str], Optional[str]]: + """Parse the text to extract the reasoning content / content + if we have them. + Args: + current_text (str): The full previous + delta text. + + Returns: + Tuple[Optional[str], Optional[str]]: Tuple pair containing the + reasoning content and non-reasoning content. + """ delimiter_idxs = [ idx for idx, char in enumerate(current_text) if char == self.seq_boundary_end @@ -211,39 +338,3 @@ def get_content_sections(self, current_text: str): if start_reasoning_content and start_response_content is None: return current_text[start_reasoning_content:], None return None, None - - def extract_reasoning_content_streaming( - self, - previous_text: str, - current_text: str, - delta_text: str, - previous_token_ids: Sequence[int], - current_token_ids: Sequence[int], - delta_token_ids: Sequence[int], - ) -> Union[DeltaMessage, None]: - reasoning_content, response_content = self.get_content_sections( - current_text) - if not reasoning_content: - delta_message = self._get_delta_message_with_no_reasoning_bounds( - current_text, delta_text) - # We have a start of reasoning message, but not the response message yet - elif not response_content: - delta_message = self._get_delta_message_with_no_response_bounds( - current_text, reasoning_content, delta_text) - else: - delta_message = self._get_delta_message_with_both_bounds( - delta_text, reasoning_content, response_content, current_text) - if not delta_message.content and not delta_message.reasoning_content: - return None - return delta_message - - def extract_reasoning_content( - self, model_output: str, request: ChatCompletionRequest - ) -> Tuple[Optional[str], Optional[str]]: - re_match = self.reasoning_regex.findall(model_output) - if not re_match: - return None, model_output - reasoning_content, response_content = re_match[0] - if not response_content: - return reasoning_content, None - return reasoning_content, response_content From 1f2f6902e3e2008869abf12953e5aab0bf696fb4 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 4 Mar 2025 09:27:39 +0000 Subject: [PATCH 07/20] Add more streaming tests & cleanup Signed-off-by: Alex-Brooks --- .../test_granite_reasoning_parser.py | 179 +++++++++++++----- 1 file changed, 134 insertions(+), 45 deletions(-) diff --git a/tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py b/tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py index 18a537f6bb26..4495edc13c0f 100644 --- a/tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py +++ b/tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py @@ -54,30 +54,76 @@ } TEST_CASES = [ - pytest.param(False, SIMPLE_REASONING, id="simple_reasoning"), - pytest.param(False, COMPLETE_REASONING, id="complete_reasoning"), - pytest.param(False, NO_REASONING, id="no_reasoning"), - pytest.param(False, MULTIPLE_LINES, id="multiple_lines"), - pytest.param(False, REASONING_WITH_THINK, id="reasoning_with_think"), - pytest.param(False, - COMPLETE_REASONING_WITH_THINK, - id="complete_reasoning_with_think"), - pytest.param(False, - MULTIPLE_LINES_WITH_THINK, - 
id="multiple_lines_with_think"), - pytest.param(True, SIMPLE_REASONING, id="simple_reasoning_streaming"), - pytest.param(True, COMPLETE_REASONING, id="complete_reasoning_streaming"), - pytest.param(True, NO_REASONING, id="no_reasoning_streaming"), - pytest.param(True, MULTIPLE_LINES, id="multiple_lines_streaming"), - pytest.param(True, - REASONING_WITH_THINK, - id="reasoning_with_think_streaming"), - pytest.param(True, - COMPLETE_REASONING_WITH_THINK, - id="complete_reasoning_with_think_streaming"), - pytest.param(True, - MULTIPLE_LINES_WITH_THINK, - id="multiple_lines_with_think_streaming"), + pytest.param( + False, + SIMPLE_REASONING, + id="simple_reasoning", + ), + pytest.param( + False, + COMPLETE_REASONING, + id="complete_reasoning", + ), + pytest.param( + False, + NO_REASONING, + id="no_reasoning", + ), + pytest.param( + False, + MULTIPLE_LINES, + id="multiple_lines", + ), + pytest.param( + False, + REASONING_WITH_THINK, + id="reasoning_with_think", + ), + pytest.param( + False, + COMPLETE_REASONING_WITH_THINK, + id="complete_reasoning_with_think", + ), + pytest.param( + False, + MULTIPLE_LINES_WITH_THINK, + id="multiple_lines_with_think", + ), + pytest.param( + True, + SIMPLE_REASONING, + id="simple_reasoning_streaming", + ), + pytest.param( + True, + COMPLETE_REASONING, + id="complete_reasoning_streaming", + ), + pytest.param( + True, + NO_REASONING, + id="no_reasoning_streaming", + ), + pytest.param( + True, + MULTIPLE_LINES, + id="multiple_lines_streaming", + ), + pytest.param( + True, + REASONING_WITH_THINK, + id="reasoning_with_think_streaming", + ), + pytest.param( + True, + COMPLETE_REASONING_WITH_THINK, + id="complete_reasoning_with_think_streaming", + ), + pytest.param( + True, + MULTIPLE_LINES_WITH_THINK, + id="multiple_lines_with_think_streaming", + ), ] # Global tokenizer initialization to avoid repeated loading @@ -210,22 +256,67 @@ def test_reasoning( "reasoning_content": None, "content": None, } +STREAMING_13 = { + "previous_text": "Here is my thought process: foo Here", + "current_text": "Here is my thought process: foo Here was", + "delta_text": " was", + "reasoning_content": "Here was", + "content": None, +} + STREAMING_SUBCASES = [ - pytest.param(STREAMING_1, id="Starting reasoning special sequence"), - pytest.param(STREAMING_2, id="Unexpected start reasoning sequence"), - pytest.param(STREAMING_3, - id="Continuing unexpected start reasoning sequence"), - pytest.param(STREAMING_4, - id="Only start reasoning sequence and nothing else"), - pytest.param(STREAMING_5, id="Reasoning content has started"), - pytest.param(STREAMING_6, id="Response special sequence has started"), - pytest.param(STREAMING_7, id="Response special sequence reset"), - pytest.param(STREAMING_8, id="Response text has started"), - pytest.param(STREAMING_9, id="Delta contains everything"), - pytest.param(STREAMING_10, - id="Delta contains some reasoning and response"), - pytest.param(STREAMING_11, id="Delta starts response sequence"), - pytest.param(STREAMING_12, id="Delta finishes response sequence"), + pytest.param( + STREAMING_1, + id="Starting reasoning special sequence", + ), + pytest.param( + STREAMING_2, + id="Unexpected start reasoning sequence", + ), + pytest.param( + STREAMING_3, + id="Continuing unexpected start reasoning sequence", + ), + pytest.param( + STREAMING_4, + id="Only start reasoning sequence and nothing else", + ), + pytest.param( + STREAMING_5, + id="Reasoning content has started", + ), + pytest.param( + STREAMING_6, + id="Response special sequence has started", + ), + 
pytest.param( + STREAMING_7, + id="Response special sequence reset", + ), + pytest.param( + STREAMING_8, + id="Response text has started", + ), + pytest.param( + STREAMING_9, + id="Delta contains everything", + ), + pytest.param( + STREAMING_10, + id="Delta contains some reasoning and response", + ), + pytest.param( + STREAMING_11, + id="Delta starts response sequence", + ), + pytest.param( + STREAMING_12, + id="Delta finishes response sequence", + ), + pytest.param( + STREAMING_13, + id="Delta breaks potential responise sequence", + ), ] @@ -235,12 +326,8 @@ def test_streaming_subcases(param_dict): previous_token_ids = tokenizer.encode( param_dict["previous_text"] ) if param_dict["previous_text"] is not None else [] - current_token_ids = tokenizer.encode( - param_dict["current_text"] - ) if param_dict["current_text"] is not None else [] - delta_token_ids = tokenizer.encode( - param_dict["delta_text"] - ) if param_dict["delta_text"] is not None else [] + current_token_ids = tokenizer.encode(param_dict["current_text"]) + delta_token_ids = tokenizer.encode(param_dict["delta_text"]) parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( parser_name)(tokenizer) @@ -253,6 +340,8 @@ def test_streaming_subcases(param_dict): current_token_ids=current_token_ids, delta_token_ids=delta_token_ids, ) + # Streaming currently expects at least one of reasoning content / content, + # so the response should return None in that case. if param_dict["reasoning_content"] is None and param_dict[ "content"] is None: assert response is None From 2c9251c37b8cef164d26cda0ea1f70ed08b91166 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 4 Mar 2025 09:29:35 +0000 Subject: [PATCH 08/20] Refactoring and code formatting Signed-off-by: Alex-Brooks --- .../granite_reasoning_parser.py | 114 +++++++++++------- 1 file changed, 69 insertions(+), 45 deletions(-) diff --git a/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py b/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py index 7a4394f38324..601bbbc4548b 100644 --- a/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py +++ b/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py @@ -21,13 +21,14 @@ class GraniteReasoningParser(ReasoningParser): IBM granite models currently use "Here is my thought process:" and "Here is my response:" to separate its thinking / response outputs. - NOTE: There have been some observed occurrences of quantized instances of - the current models using "Here's" instead of "Here is", so to be safe, we - match on both """ def __init__(self, tokenizer: PreTrainedTokenizerBase): super().__init__(tokenizer) + + # NOTE: There have been some observed occurrences of quantized + # instances of the current models using "Here's" instead of "Here is", + # so to be safe, we match on both. 
self.think_start_expr = r"(?:Here's|Here is) my thought process:" self.response_start_expr = r"(?:Here's|Here is) my response:" @@ -41,6 +42,8 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase): self.valid_response_starts = [ "Here's my response:", "Here is my response:" ] + + # Substrings to match for sequence boundaries on raw text self.seq_boundary_end = ":" self.seq_boundary_start = "Here" @@ -89,10 +92,10 @@ def extract_reasoning_content_streaming( Here is my thought process: Foo Here is my response: Bar - This increases the complexity of correctly handling streams by a lot, - since we need to watch for specific sequences and correctly handle - parsing them without dropping content that is potentially overlapping - & spanning across multiple delta messages. + This increases the complexity of correctly handling streams, since we + need to watch for specific sequences and correctly parse them without + dropping content that is potentially overlapping & spanning multiple + delta messages. Args: previous_text (str): Previous text outside of this delta message. @@ -108,13 +111,17 @@ def extract_reasoning_content_streaming( """ reasoning_content, response_content = self._get_content_sections( current_text) + # Either we haven't finished the start of the reasoning sequence, + # or the model is generating something unexpected. if not reasoning_content: delta_message = self._get_delta_message_with_no_reasoning_bounds( current_text, delta_text) - # We have a start of reasoning message, but not the response message yet + # We have a start of reasoning message, but have not yet finished + # the start of response sequence. elif not response_content: delta_message = self._get_delta_message_with_no_response_bounds( current_text, reasoning_content, delta_text) + # We've finished both the start of reasoning and start of response seq. else: delta_message = self._get_delta_message_with_both_bounds( delta_text, reasoning_content, response_content, current_text) @@ -123,6 +130,32 @@ def extract_reasoning_content_streaming( return delta_message #### Implementation details of stream parsing for granite models + def _is_reasoning_start_substr(self, text: str) -> bool: + """Check if a text matches one of the possible start reasoning seqs. + + Args: + text (str): Text to check for leading substr. + + Returns: + bool: True if any of the possible reasoning start seqs match. + """ + return any( + think_start.startswith(text) + for think_start in self.valid_think_starts) + + def _is_response_start_substr(self, text: str) -> bool: + """Check if a text matches one of the possible start response seqs. + + Args: + text (str): Text to check for leading substr. + + Returns: + bool: True if any of the possible response start seqs match. + """ + return any( + response_start.startswith(text) + for response_start in self.valid_response_starts) + def _get_delta_message_with_no_reasoning_bounds( self, current_text: str, @@ -138,19 +171,13 @@ def _get_delta_message_with_no_reasoning_bounds( Returns: DeltaMessage: Message containing the parsed content. """ - # Even before this, we already had a longer text str than any expected - # message; The output is not parsable; assume we rectified this - # previously and just add the delta text to the content. 
prev_longest_length = len(current_text) - len(delta_text) - if prev_longest_length > self.longest_think_start: - return DeltaMessage(reasoning_content=None, content=delta_text) + is_substr = self._is_reasoning_start_substr(current_text) + was_substr = self._is_reasoning_start_substr( + current_text[:prev_longest_length]) - is_substr = any(current_text in think_start - for think_start in self.valid_think_starts) - was_substr = any(current_text[:prev_longest_length] in think_start - for think_start in self.valid_think_starts) # Check if we just generated something NOT in the special token seq; - # if so, we add everything that we previously skipped with this delta + # if so, add everything that we previously skipped with this delta # message and append everything to content in the future. if was_substr and not is_substr: return DeltaMessage( @@ -165,8 +192,11 @@ def _get_delta_message_with_no_reasoning_bounds( return DeltaMessage(reasoning_content=None, content=delta_text) def _get_delta_message_with_no_response_bounds( - self, current_text: str, reasoning_content: str, - delta_text: str) -> DeltaMessage: + self, + current_text: str, + reasoning_content: str, + delta_text: str, + ) -> DeltaMessage: """Parse the delta message when the current text has both reasoning content with no (response) content. NOTE that we may have overlapping tokens with the start of reasoning / start of response sequences on @@ -199,44 +229,38 @@ def _get_delta_message_with_no_response_bounds( prev_idx = previous_text.rfind(self.seq_boundary_start) delta_idx = delta_text.rfind(self.seq_boundary_start) - # And check if either we had one in the previous text, or have on - prev_was_substr = any( - previous_text[prev_idx:] in response_start for response_start in - self.valid_response_starts) if prev_idx >= 0 else False - delta_continues_substr = any( - current_text[prev_idx:] in response_start for response_start in - self.valid_response_starts) if prev_idx >= 0 else False - delta_new_substr = any( - delta_text[delta_idx:] in response_start for response_start in - self.valid_response_starts) if delta_idx >= 0 else False - # It was a substring before, and all delta tokens are - # part of the potential start of response sequence - - if prev_was_substr and delta_continues_substr: + # Check the state of potential start of response substring matches. + prev_was_substr = self._is_response_start_substr( + previous_text[prev_idx:]) if prev_idx >= 0 else False + delta_continues_substr = self._is_response_start_substr( + current_text[prev_idx:]) if prev_idx >= 0 else False + delta_new_substr = self._is_response_start_substr( + delta_text[delta_idx:]) if delta_idx >= 0 else False + + # Delta only contains potential continued response sequence text. + if delta_continues_substr: return DeltaMessage(reasoning_content=None, content=None) if not prev_was_substr: - # Don't add the potential unfinished response sequence + # Delta may be starting a new response seq but has other text too. if delta_new_substr: return DeltaMessage(reasoning_content=delta_text[:delta_idx], content=None) - # No possible places to start the response seq; continue normally + # Normal case for most reasoning text (no potential special seqs). 
return DeltaMessage(reasoning_content=delta_text, content=None) - # The substring being continued is the same one as before; - # the whole delta message is part of the potential response seq - if delta_continues_substr: - return DeltaMessage(reasoning_content=None, content=None) # The substring that previously seemed to be a potential response # seq wasn't one; we need to add the content to the delta message, # and also slice off the potential response sequence elif delta_new_substr: - return DeltaMessage(reasoning_content=previous_text[prev_idx:] + - delta_text[:delta_idx], + reasoning_content = previous_text[ + prev_idx:] + delta_text[:delta_idx] + return DeltaMessage(reasoning_content=reasoning_content, content=None) # No new substring yet, and we broke our old one; take the whole delta - return DeltaMessage(reasoning_content=previous_text[prev_idx:] + - delta_text, - content=None) + return DeltaMessage( + reasoning_content=previous_text[prev_idx:] + delta_text, + content=None, + ) def _get_delta_message_with_both_bounds( self, From 1604ca96371083a991ffe79c184d4f683925028c Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 4 Mar 2025 09:44:37 +0000 Subject: [PATCH 09/20] Pass response seq length through message parsing Signed-off-by: Alex-Brooks --- .../granite_reasoning_parser.py | 55 +++++++++---------- 1 file changed, 25 insertions(+), 30 deletions(-) diff --git a/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py b/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py index 601bbbc4548b..d702c74e5c80 100644 --- a/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py +++ b/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py @@ -109,7 +109,7 @@ def extract_reasoning_content_streaming( Union[DeltaMessage, None] DeltaMessage with either reasoning content or content, or None. """ - reasoning_content, response_content = self._get_content_sections( + reasoning_content, resp_seq_len, content = self._get_content_sections( current_text) # Either we haven't finished the start of the reasoning sequence, # or the model is generating something unexpected. @@ -118,13 +118,14 @@ def extract_reasoning_content_streaming( current_text, delta_text) # We have a start of reasoning message, but have not yet finished # the start of response sequence. - elif not response_content: + elif not content: delta_message = self._get_delta_message_with_no_response_bounds( current_text, reasoning_content, delta_text) # We've finished both the start of reasoning and start of response seq. else: delta_message = self._get_delta_message_with_both_bounds( - delta_text, reasoning_content, response_content, current_text) + delta_text, reasoning_content, content, current_text, + resp_seq_len) if not delta_message.content and not delta_message.reasoning_content: return None return delta_message @@ -204,7 +205,7 @@ def _get_delta_message_with_no_response_bounds( Args: current_text (str): The full previous + delta text. - reasoning_content (str): reasoning content parsed from current_text. + reasoning_content (str): reasoning content from current_text. delta_text (str): Text to consider and parse content from. Returns: @@ -268,33 +269,26 @@ def _get_delta_message_with_both_bounds( reasoning_content: str, response_content: str, current_text: str, + response_seq_len: int, ) -> DeltaMessage: """Parse the delta message when the current text has both reasoning content and normal (response) content. Args: delta_text (str): Text to consider and parse content from. 
- reasoning_content (str): reasoning content parsed from current_text. - response_content (str): response content parsed from current_text. + reasoning_content (str): reasoning content from current_text. + response_content (str): response content from current_text. current_text (str): The full previous + delta text. + response_seq_len(str): Len of the complete response sequence used. Returns: DeltaMessage: Message containing the parsed content. """ - # We have reasoning and response content, but it may not all be in the - # delta text; we need to consider the length of the start of response - # sequence and divide just the delta text part. - ##### HACK pass this through - for rs in self.valid_response_starts: - if rs in current_text: - response_seq_len = len(rs) - break - ##### - # Always have content; take length to the end delta_content = delta_text[-len(response_content):] reasoning_end_idx = len(delta_text) - (len(response_content) + response_seq_len) + if reasoning_end_idx < 0: delta_reasoning_content = None else: @@ -323,20 +317,21 @@ def _get_content_sections( current_text (str): The full previous + delta text. Returns: - Tuple[Optional[str], Optional[str]]: Tuple pair containing the - reasoning content and non-reasoning content. + Tuple[Optional[str], Optional[str], Optional[str]]: Tuple of len 3 + containing the reasoning content, the length of the response seq + (if there is one) and the non-reasoning content. """ + current_chunk_start = 0 + start_reasoning_content = None + start_response_content = None delimiter_idxs = [ idx for idx, char in enumerate(current_text) if char == self.seq_boundary_end ] - current_chunk_start = 0 - start_reasoning_content = None - start_response_content = None for current_chunk_end in delimiter_idxs: current_chunk = current_text[current_chunk_start:current_chunk_end] - # Check to see if this is start of reasoning + # Check to see if the start of reasoning seq if complete if start_reasoning_content is None: for think_start in self.valid_think_starts: if current_chunk == think_start[:-1]: @@ -344,21 +339,21 @@ def _get_content_sections( current_chunk_start = current_chunk_end + 1 break - # Check to see if this is start of response + # Check to see if the start of response seq if complete elif start_response_content is None: for response_start in self.valid_response_starts: if current_chunk[-len(response_start) + 1:] == response_start[:-1]: + # Mark end of reasoning and start response content + # after the start of response sequence. 
end_reasoning_content = current_chunk_end - len( response_start) reasoning_content = current_text[ start_reasoning_content:end_reasoning_content] response_content = current_text[current_chunk_end + 1:] - # Ensure we handle empty strings / None consistently - if not response_content: - response_content = None - return reasoning_content, response_content - # Set the actual response content + return reasoning_content, len( + response_start), response_content + if start_reasoning_content and start_response_content is None: - return current_text[start_reasoning_content:], None - return None, None + return current_text[start_reasoning_content:], None, None + return None, None, None From 118a051966155930542ed0b1f8c3c3cd021103cc Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 4 Mar 2025 09:47:02 +0000 Subject: [PATCH 10/20] Track parsed content as a bool Signed-off-by: Alex-Brooks --- .../openai/reasoning_parsers/granite_reasoning_parser.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py b/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py index d702c74e5c80..ac832e626956 100644 --- a/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py +++ b/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py @@ -323,7 +323,7 @@ def _get_content_sections( """ current_chunk_start = 0 start_reasoning_content = None - start_response_content = None + parsed_content = False delimiter_idxs = [ idx for idx, char in enumerate(current_text) if char == self.seq_boundary_end @@ -340,7 +340,7 @@ def _get_content_sections( break # Check to see if the start of response seq if complete - elif start_response_content is None: + elif not parsed_content: for response_start in self.valid_response_starts: if current_chunk[-len(response_start) + 1:] == response_start[:-1]: @@ -354,6 +354,6 @@ def _get_content_sections( return reasoning_content, len( response_start), response_content - if start_reasoning_content and start_response_content is None: + if start_reasoning_content and not parsed_content: return current_text[start_reasoning_content:], None, None return None, None, None From 5ac1c11bf80d6a7a28e02fa97ac5931605612ba8 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 4 Mar 2025 09:55:15 +0000 Subject: [PATCH 11/20] Add IBM 3.2 lang models to reasoning models Signed-off-by: Alex-Brooks --- docs/source/features/reasoning_outputs.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/features/reasoning_outputs.md b/docs/source/features/reasoning_outputs.md index b5fad26368bd..3169ea03c7a4 100644 --- a/docs/source/features/reasoning_outputs.md +++ b/docs/source/features/reasoning_outputs.md @@ -15,6 +15,8 @@ vLLM currently supports the following reasoning models: | [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` | | [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` | +- [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a), which look for `Here is my thought process: ... Here is my response: ...` + ## Quickstart To use reasoning models, you need to specify the `--enable-reasoning` and `--reasoning-parser` flags when making a request to the chat completion endpoint. 
The `--reasoning-parser` flag specifies the reasoning parser to use for extracting reasoning content from the model output. From 2b871f18157d42f64ac3fba7f9e5cdde84ece466 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 4 Mar 2025 11:07:28 +0000 Subject: [PATCH 12/20] Add note on thinking kwarg for granite reasoning Signed-off-by: Alex-Brooks --- docs/source/features/reasoning_outputs.md | 5 +++-- .../online_serving/openai_chat_completion_with_reasoning.py | 1 + .../openai_chat_completion_with_reasoning_streaming.py | 1 + 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/source/features/reasoning_outputs.md b/docs/source/features/reasoning_outputs.md index 3169ea03c7a4..fd620c5f4097 100644 --- a/docs/source/features/reasoning_outputs.md +++ b/docs/source/features/reasoning_outputs.md @@ -4,7 +4,7 @@ vLLM offers support for reasoning models like [DeepSeek R1](https://huggingface.co/deepseek-ai/DeepSeek-R1), which are designed to generate outputs containing both reasoning steps and final conclusions. -Reasoning models return a additional `reasoning_content` field in their outputs, which contains the reasoning steps that led to the final conclusion. This field is not present in the outputs of other models. +Reasoning models return an additional `reasoning_content` field in their outputs, which contains the reasoning steps that led to the final conclusion. This field is not present in the outputs of other models. ## Supported Models @@ -15,7 +15,7 @@ vLLM currently supports the following reasoning models: | [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` | | [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` | -- [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a), which look for `Here is my thought process: ... Here is my response: ...` +- [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a), which look for `Here is my thought process: ... Here is my response: ...`. Note that for granite models, you must also pass `thinking=True` in your `chat_template_kwargs`. 
## Quickstart @@ -45,6 +45,7 @@ model = models.data[0].id # Round 1 messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}] +# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}` response = client.chat.completions.create(model=model, messages=messages) reasoning_content = response.choices[0].message.reasoning_content diff --git a/examples/online_serving/openai_chat_completion_with_reasoning.py b/examples/online_serving/openai_chat_completion_with_reasoning.py index b5dbed1205d3..e753cedcdc08 100644 --- a/examples/online_serving/openai_chat_completion_with_reasoning.py +++ b/examples/online_serving/openai_chat_completion_with_reasoning.py @@ -31,6 +31,7 @@ # Round 1 messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}] +# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}` response = client.chat.completions.create(model=model, messages=messages) reasoning_content = response.choices[0].message.reasoning_content diff --git a/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py b/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py index fe4332576d43..cb13b0c614aa 100644 --- a/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py +++ b/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py @@ -38,6 +38,7 @@ model = models.data[0].id messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}] +# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}` stream = client.chat.completions.create(model=model, messages=messages, stream=True) From 721ab9fa4500fea7e84a5cf0f4f72e8cb33f2770 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 4 Mar 2025 11:57:53 +0000 Subject: [PATCH 13/20] Fix formatting Signed-off-by: Alex-Brooks --- .../test_granite_reasoning_parser.py | 4 +--- .../reasoning_parsers/granite_reasoning_parser.py | 14 +++++++++----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py b/tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py index 4495edc13c0f..84ac6600498b 100644 --- a/tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py +++ b/tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py @@ -1,6 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List - import pytest from transformers import AutoTokenizer @@ -137,7 +135,7 @@ def test_reasoning( ): output = tokenizer.tokenize(param_dict["output"]) # decode everything to tokens - output_tokens: List[str] = [ + output_tokens: list[str] = [ tokenizer.convert_tokens_to_string([token]) for token in output ] parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( diff --git a/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py b/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py index ac832e626956..117d051a7378 100644 --- a/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py +++ b/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import re -from typing import Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Optional, Union from transformers import PreTrainedTokenizerBase @@ -53,7 +54,7 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase): def extract_reasoning_content( self, 
model_output: str, request: ChatCompletionRequest - ) -> Tuple[Optional[str], Optional[str]]: + ) -> tuple[Optional[str], Optional[str]]: """Extract the reasoning content & content sections, respectively. If the sequence doesn't match what we expect, i.e., the model generates something else, all content is considered non-reasoning content. @@ -63,7 +64,7 @@ def extract_reasoning_content( request (ChatCompletionReqest): Request being processed. Returns: - Tuple[Optional[str], Optional[str]]: Tuple pair containing the + tuple[Optional[str], Optional[str]]: Tuple pair containing the reasoning content and non-reasoning content. """ re_match = self.reasoning_regex.findall(model_output) @@ -123,6 +124,8 @@ def extract_reasoning_content_streaming( current_text, reasoning_content, delta_text) # We've finished both the start of reasoning and start of response seq. else: + # This should never happen since we matched on the response + assert resp_seq_len is not None delta_message = self._get_delta_message_with_both_bounds( delta_text, reasoning_content, content, current_text, resp_seq_len) @@ -309,7 +312,8 @@ def _get_delta_message_with_both_bounds( ) def _get_content_sections( - self, current_text: str) -> Tuple[Optional[str], Optional[str]]: + self, current_text: str + ) -> tuple[Optional[str], Optional[int], Optional[str]]: """Parse the text to extract the reasoning content / content if we have them. @@ -317,7 +321,7 @@ def _get_content_sections( current_text (str): The full previous + delta text. Returns: - Tuple[Optional[str], Optional[str], Optional[str]]: Tuple of len 3 + tuple[Optional[str], Optional[int], Optional[str]]: Tuple of len 3 containing the reasoning content, the length of the response seq (if there is one) and the non-reasoning content. """ From 6b795861c27d7f0544d13e6ed6b1baf985adc5f3 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 6 Mar 2025 07:48:07 +0000 Subject: [PATCH 14/20] Add reasoning parser arg for granite Signed-off-by: Alex-Brooks --- vllm/engine/arg_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0d285acd15f3..983a1062c545 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1098,7 +1098,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( "--reasoning-parser", type=str, - choices=["deepseek_r1"], + choices=["deepseek_r1", "granite"], default=None, help= "Select the reasoning parser depending on the model that you're " From e460b5ec8529b6983cda7c3fd2d544f4f8d716b0 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 6 Mar 2025 07:51:52 +0000 Subject: [PATCH 15/20] Fix granite reasoning parser doc formatting Signed-off-by: Alex-Brooks --- docs/source/features/reasoning_outputs.md | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/source/features/reasoning_outputs.md b/docs/source/features/reasoning_outputs.md index fd620c5f4097..901b9ff6389c 100644 --- a/docs/source/features/reasoning_outputs.md +++ b/docs/source/features/reasoning_outputs.md @@ -13,9 +13,13 @@ vLLM currently supports the following reasoning models: | Model Series | Parser Name | Structured Output Support | |--------------|-------------|------------------| | [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` | +<<<<<<< HEAD | [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, 
`guided_regex` | +======= +| [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a)* | `granite` | | +>>>>>>> 4ecb1695 (Fix granite reasoning parser doc formatting) -- [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a), which look for `Here is my thought process: ... Here is my response: ...`. Note that for granite models, you must also pass `thinking=True` in your `chat_template_kwargs`. +* IBM Granite 3.2 reasoning is disabled by response; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`. ## Quickstart @@ -175,8 +179,8 @@ print("content: ", completion.choices[0].message.content) ## Limitations -- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`). -- It is not compatible with [`tool_calling`](#tool_calling). +* The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`). +* It is not compatible with [`tool_calling`](#tool_calling). ## How to support a new reasoning model From 658bf0ae99baf9c17725cc630ec890a25e04cbce Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 6 Mar 2025 08:24:13 +0000 Subject: [PATCH 16/20] Warn for unimplemented structured outputs reasoner Signed-off-by: Alex-Brooks --- vllm/model_executor/guided_decoding/reasoner/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/model_executor/guided_decoding/reasoner/__init__.py b/vllm/model_executor/guided_decoding/reasoner/__init__.py index d930d3dbe94c..ab6e47c007d2 100644 --- a/vllm/model_executor/guided_decoding/reasoner/__init__.py +++ b/vllm/model_executor/guided_decoding/reasoner/__init__.py @@ -19,6 +19,10 @@ def get_reasoner(tokenizer: PreTrainedTokenizer, return None elif reasoning_backend == "deepseek_r1": return DeepSeekReasoner.from_tokenizer(tokenizer) + elif reasoning_backend == "granite": + logger.warning( + "Granite reasoner not yet implemented for structured outputs") + return None else: # Raise a warning for unknown reasoning backend and return None # We cannot raise an error here because some reasoning models From 75dcdd276a17e06b905f566c608517afb2f41367 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 10 Mar 2025 16:54:16 +0000 Subject: [PATCH 17/20] Add granite thinking note to reasoning docs Signed-off-by: Alex-Brooks --- docs/source/features/reasoning_outputs.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/source/features/reasoning_outputs.md b/docs/source/features/reasoning_outputs.md index 901b9ff6389c..aa43df61c475 100644 --- a/docs/source/features/reasoning_outputs.md +++ b/docs/source/features/reasoning_outputs.md @@ -13,11 +13,8 @@ vLLM currently supports the following reasoning models: | Model Series | Parser Name | Structured Output Support | |--------------|-------------|------------------| | [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` | -<<<<<<< HEAD | [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` | -======= | [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a)* | `granite` | | ->>>>>>> 4ecb1695 (Fix granite reasoning parser doc formatting) * IBM Granite 3.2 reasoning is disabled by response; to enable 
it, you must also pass `thinking=True` in your `chat_template_kwargs`. @@ -104,6 +101,7 @@ models = client.models.list() model = models.data[0].id messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}] +# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}` stream = client.chat.completions.create(model=model, messages=messages, stream=True) From ab83ec170616175cafcf1ded076c9117ff0cebda Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 10 Mar 2025 14:06:32 -0600 Subject: [PATCH 18/20] Update docs/source/features/reasoning_outputs.md Co-authored-by: Joe Runde Signed-off-by: Alex-Brooks --- docs/source/features/reasoning_outputs.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/features/reasoning_outputs.md b/docs/source/features/reasoning_outputs.md index aa43df61c475..049958125c21 100644 --- a/docs/source/features/reasoning_outputs.md +++ b/docs/source/features/reasoning_outputs.md @@ -16,7 +16,7 @@ vLLM currently supports the following reasoning models: | [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` | | [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a)* | `granite` | | -* IBM Granite 3.2 reasoning is disabled by response; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`. +* IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`. ## Quickstart From a73c7895d8b35bf0d65657af6c491aa75d3bd74c Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 12 Mar 2025 22:19:19 +0000 Subject: [PATCH 19/20] Revert precommit bullet formatting Signed-off-by: Alex-Brooks --- docs/source/features/reasoning_outputs.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/features/reasoning_outputs.md b/docs/source/features/reasoning_outputs.md index 049958125c21..01e25f087cbf 100644 --- a/docs/source/features/reasoning_outputs.md +++ b/docs/source/features/reasoning_outputs.md @@ -14,9 +14,9 @@ vLLM currently supports the following reasoning models: |--------------|-------------|------------------| | [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` | | [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` | -| [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a)* | `granite` | | +| [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | | -* IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`. +- IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`. ## Quickstart @@ -177,8 +177,8 @@ print("content: ", completion.choices[0].message.content) ## Limitations -* The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`). -* It is not compatible with [`tool_calling`](#tool_calling). +- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`). +- It is not compatible with [`tool_calling`](#tool_calling). 
## How to support a new reasoning model From da09717b4111a27c9fc96fc509f53010baec41b3 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 26 Mar 2025 11:45:26 +0000 Subject: [PATCH 20/20] Fix precommit Signed-off-by: Alex-Brooks --- docs/source/features/reasoning_outputs.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/source/features/reasoning_outputs.md b/docs/source/features/reasoning_outputs.md index 01f665d0c0cf..879b16d4f7b5 100644 --- a/docs/source/features/reasoning_outputs.md +++ b/docs/source/features/reasoning_outputs.md @@ -10,7 +10,6 @@ Reasoning models return an additional `reasoning_content` field in their outputs vLLM currently supports the following reasoning models: - | Model Series | Parser Name | Structured Output Support | Tool Calling | |--------------|-------------|------------------|-------------| | [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` | ❌ | @@ -19,7 +18,6 @@ vLLM currently supports the following reasoning models: - IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`. - ## Quickstart To use reasoning models, you need to specify the `--enable-reasoning` and `--reasoning-parser` flags when making a request to the chat completion endpoint. The `--reasoning-parser` flag specifies the reasoning parser to use for extracting reasoning content from the model output.
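
The split the patches above converge on is easiest to see in isolation. Below is a minimal sketch of the non-streaming behavior, assuming only the two boundary phrases documented in this series; the regex, function name, and fallback handling are illustrative simplifications, not the parser's actual code.

```python
import re

# Reasoning sits between the two documented boundary phrases;
# the final answer follows the second one.
_BOUNDARIES = re.compile(
    r"Here is my thought process:(.*?)Here is my response:(.*)",
    re.DOTALL,
)


def split_granite_output(model_output: str):
    """Toy split of a completed output into (reasoning_content, content)."""
    match = _BOUNDARIES.search(model_output)
    if match is None:
        # The real parser has fallback handling for outputs that never
        # produce the expected phrases; it is omitted here.
        return None, None
    reasoning_content, content = match.group(1), match.group(2)
    # Surface an empty response section as None rather than "".
    return reasoning_content, content or None


print(split_granite_output(
    "Here is my thought process:abcHere is my response:xyz"))
# -> ('abc', 'xyz')
```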
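The index arithmetic in the streaming path (PATCH 09) is compact, so a worked call may help. This is a simplified sketch, not vLLM code: it assumes the start-of-response sequence has already been matched, that `response_content` is non-empty, and the function and variable names are hypothetical.

```python
def split_delta(delta_text: str, response_content: str,
                response_seq_len: int):
    """Split one streamed delta into (reasoning_part, content_part)."""
    # The response content always forms the tail of the delta once the
    # start-of-response sequence has been seen.
    delta_content = delta_text[-len(response_content):]
    # Whatever sits before the start-of-response sequence inside this
    # delta still belongs to the reasoning section.
    reasoning_end_idx = len(delta_text) - (len(response_content) +
                                           response_seq_len)
    delta_reasoning = (delta_text[:reasoning_end_idx]
                       if reasoning_end_idx > 0 else None)
    return delta_reasoning, delta_content


# Suppose the full text so far is
# "Here is my thought process:abcHere is my response:xy" and this delta
# carried "c", the 20-char response marker, and "xy":
assert split_delta("cHere is my response:xy", "xy", 20) == ("c", "xy")
```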
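Finally, here is the client-side flow the docs and example changes describe, assembled in one place. This is a hedged sketch: the server URL and API key are placeholders, and it assumes a vLLM server launched with `--enable-reasoning --reasoning-parser granite` serving a Granite 3.2 model.

```python
from openai import OpenAI

# Placeholder connection details for a locally running vLLM server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

models = client.models.list()
model = models.data[0].id

messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
response = client.chat.completions.create(
    model=model,
    messages=messages,
    # Granite reasoning is disabled by default; opt in via the chat template.
    extra_body={"chat_template_kwargs": {"thinking": True}},
)

print("reasoning_content:", response.choices[0].message.reasoning_content)
print("content:", response.choices[0].message.content)
```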