127 changes: 127 additions & 0 deletions tests/reasoning/test_gptoss_reasoning_parser.py
@@ -0,0 +1,127 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
from transformers import AutoTokenizer

from vllm.reasoning import ReasoningParser
from vllm.reasoning.gptoss_reasoning_parser import GptOssReasoningParser

REASONING_MODEL_NAME = "openai/gpt-oss-120b"


@pytest.fixture(scope="module")
def gpt_oss_tokenizer():
    return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)


USER_MESSAGE_START = "<|start|>user<|message|>"
REASONING_SECTION_START = "<|end|><|start|>assistant<|channel|>analysis<|message|>"
ASSISTANT_CONTENT_START_PREFIX = "<|end|><|start|>assistant<|channel|>final"
ASSISTANT_CONTENT_START_SUFFIX = "<|message|>"
ASSISTANT_CONTENT_START = (
    ASSISTANT_CONTENT_START_PREFIX + ASSISTANT_CONTENT_START_SUFFIX
)

BASIC_CONTENT = {
    "output": REASONING_SECTION_START
    + "This is reasoning"
    + ASSISTANT_CONTENT_START
    + "This is the rest",
    "is_reasoning_end": True,
}

BASIC_REASONING_ONLY = {
    "output": REASONING_SECTION_START + "This is reasoning" + "<|end|>",
    "is_reasoning_end": False,
}
BASIC_NO_REASONING_NO_ASSISTANT = {
    "output": USER_MESSAGE_START + "This is a user message",
    "is_reasoning_end": False,
}

# Edge-case where the model omits the assistant tag entirely.
BASIC_NO_REASONING_ASSISTANT = {
    "output": USER_MESSAGE_START + "This is a user message<|end|><|channel|>final",
    "is_reasoning_end": True,
}

COMPLEX_CONTENT_INCOMPLETE_PREFIX_ONLY = {
    "output": REASONING_SECTION_START
    + "This is reasoning"
    + ASSISTANT_CONTENT_START_PREFIX,
    "is_reasoning_end": False,
}

COMPLEX_CONTENT_SUFFIX_ONLY = {
    "output": REASONING_SECTION_START
    + "This is reasoning"
    + ASSISTANT_CONTENT_START_SUFFIX,
    "is_reasoning_end": False,
}

COMPLEX_CONTENT_1_NO_SUFFIX = {
    "output": REASONING_SECTION_START
    + "This is reasoning"
    + ASSISTANT_CONTENT_START_PREFIX
    + "<|constrain|> JSON ",
    "is_reasoning_end": False,
}

COMPLEX_CONTENT_1 = {
    "output": REASONING_SECTION_START
    + "This is reasoning"
    + ASSISTANT_CONTENT_START_PREFIX
    + "<|constrain|> JSON "
    + ASSISTANT_CONTENT_START_SUFFIX,
    "is_reasoning_end": True,
}

COMPLEX_CONTENT_1_WITH_CONTENT = {
    "output": REASONING_SECTION_START
    + "This is reasoning"
    + ASSISTANT_CONTENT_START_PREFIX
    + "<|constrain|> JSON "
    + ASSISTANT_CONTENT_START_SUFFIX
    + "This is the rest",
    "is_reasoning_end": True,
}

COMPLEX_CONTENT_2 = {
    "output": REASONING_SECTION_START
    + "This is reasoning"
    + ASSISTANT_CONTENT_START_PREFIX
    + "<|constrain|>ReplyAction "
    + ASSISTANT_CONTENT_START_SUFFIX
    + "This is the rest",
    "is_reasoning_end": True,
}

TEST_CASES = [
    BASIC_CONTENT,
    BASIC_REASONING_ONLY,
    COMPLEX_CONTENT_INCOMPLETE_PREFIX_ONLY,
    COMPLEX_CONTENT_SUFFIX_ONLY,
    COMPLEX_CONTENT_1_NO_SUFFIX,
    COMPLEX_CONTENT_1,
    COMPLEX_CONTENT_1_WITH_CONTENT,
    COMPLEX_CONTENT_2,
]


@pytest.mark.parametrize(
    "output, is_reasoning_end",
    [(t["output"], t["is_reasoning_end"]) for t in TEST_CASES],
)
def test_gptoss_is_reasoning_end(
    output,
    is_reasoning_end,
    gpt_oss_tokenizer,
):
    output = gpt_oss_tokenizer.tokenize(output)
    parser: ReasoningParser = GptOssReasoningParser(gpt_oss_tokenizer)

    # Test is_reasoning_end
    output_ids = gpt_oss_tokenizer.convert_tokens_to_ids(output)
    actual_is_reasoning_end = parser.is_reasoning_end(output_ids)
    assert is_reasoning_end == actual_is_reasoning_end
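
For reference, one way to run just this new test module locally (assuming a vLLM development checkout with pytest installed and access to the openai/gpt-oss-120b tokenizer) would be:

    pytest tests/reasoning/test_gptoss_reasoning_parser.py -v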
31 changes: 24 additions & 7 deletions vllm/reasoning/gptoss_reasoning_parser.py
@@ -67,18 +67,35 @@ class GptOssReasoningParser(ReasoningParser):

     def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
         super().__init__(tokenizer, *args, **kwargs)
-        self.reasoning_end_token_ids = self.model_tokenizer.encode(
-            "<|start|>assistant<|channel|>final<|message|>"
+        # The model can output some special tokens between "final" and "<|message|>"
+        # So we need to look for both sequences to determine the end of reasoning.
+        self.reasoning_end_token_ids_prefix = self.model_tokenizer.encode(
+            "<|channel|>final"
         )
+        self.reasoning_end_token_ids_suffix = self.model_tokenizer.encode("<|message|>")
+        self.reasoning_max_num_between_tokens = 20

     def is_reasoning_end(self, input_ids: list[int]) -> bool:
-        end_token_ids = self.reasoning_end_token_ids
-        assert len(end_token_ids) > 0, "reasoning_end_token_ids is empty"
+        end_token_ids_prefix = self.reasoning_end_token_ids_prefix
+        end_token_ids_suffix = self.reasoning_end_token_ids_suffix
+        assert len(end_token_ids_prefix) > 0, "reasoning_end_token_ids_prefix is empty"
+        assert len(end_token_ids_suffix) > 0, "reasoning_end_token_ids_suffix is empty"
         # Check if the end sequence is present in the input_ids.
         # We search from the end of input_ids to find the last match.
-        for i in range(len(input_ids) - len(end_token_ids), -1, -1):
-            if input_ids[i : i + len(end_token_ids)] == end_token_ids:
-                return True
+        for i in range(len(input_ids) - len(end_token_ids_prefix), -1, -1):
+            if input_ids[i : i + len(end_token_ids_prefix)] == end_token_ids_prefix:
+                # We have found the prefix, now we look for the suffix after the prefix.
+                suffix_start = i + len(end_token_ids_prefix)
+                for j in range(
+                    suffix_start, len(input_ids) - len(end_token_ids_suffix) + 1
+                ):
+                    if j - suffix_start >= self.reasoning_max_num_between_tokens:
+                        break
+                    if (
+                        input_ids[j : j + len(end_token_ids_suffix)]
+                        == end_token_ids_suffix
+                    ):
+                        return True
         return False

     def extract_content_ids(self, input_ids: list[int]) -> list[int]:
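
To make the new matching rule concrete outside of vLLM, here is a minimal standalone sketch of the same two-stage search: scan backwards for the "<|channel|>final" prefix, then accept only if the "<|message|>" suffix starts within a bounded gap after it. The function name ends_reasoning, the max_gap parameter, and the integer ids in the usage example are illustrative stand-ins, not part of the actual parser or tokenizer.

# Standalone sketch (not vLLM code): mirrors the prefix-then-suffix search above,
# with made-up integer ids standing in for real token ids.
def ends_reasoning(
    input_ids: list[int],
    prefix: list[int],
    suffix: list[int],
    max_gap: int = 20,  # plays the role of reasoning_max_num_between_tokens
) -> bool:
    """True if `prefix` occurs and `suffix` starts within `max_gap` tokens after it."""
    lp, ls = len(prefix), len(suffix)
    # Scan backwards so the most recent prefix occurrence is checked first.
    for i in range(len(input_ids) - lp, -1, -1):
        if input_ids[i : i + lp] == prefix:
            start = i + lp
            # Allow a few tokens (e.g. "<|constrain|> JSON") before the suffix.
            for j in range(start, min(len(input_ids) - ls + 1, start + max_gap)):
                if input_ids[j : j + ls] == suffix:
                    return True
    return False


# Hypothetical ids: 7 = "<|channel|>", 8 = "final", 9 = "<|constrain|>", 5 = "<|message|>"
assert ends_reasoning([1, 7, 8, 9, 5, 2], prefix=[7, 8], suffix=[5]) is True
assert ends_reasoning([1, 7, 8, 9], prefix=[7, 8], suffix=[5]) is False

In the real parser the prefix and suffix sequences come from tokenizer.encode("<|channel|>final") and tokenizer.encode("<|message|>"), so each may span multiple token ids; the bounded gap is what lets outputs like "<|channel|>final<|constrain|> JSON <|message|>" still count as the end of reasoning.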