|
1 | 1 | # SPDX-License-Identifier: Apache-2.0 |
2 | 2 |
|
3 | | -import re |
4 | 3 | from collections.abc import Sequence |
5 | 4 | from typing import Optional, Union |
6 | 5 |
|
@@ -32,9 +31,6 @@ class DeepSeekR1ReasoningParser(ReasoningParser): |
32 | 31 | def __init__(self, tokenizer: PreTrainedTokenizerBase): |
33 | 32 | super().__init__(tokenizer) |
34 | 33 |
|
35 | | - self.reasoning_regex = re.compile( |
36 | | - rf"{self.start_token}(.*?){self.end_token}", re.DOTALL) |
37 | | - |
38 | 34 | if not self.model_tokenizer: |
39 | 35 | raise ValueError( |
40 | 36 | "The model tokenizer must be passed to the ReasoningParser " |
@@ -143,23 +139,34 @@ def extract_reasoning_content_streaming( |
143 | 139 | def extract_reasoning_content( |
144 | 140 | self, model_output: str, request: ChatCompletionRequest |
145 | 141 | ) -> tuple[Optional[str], Optional[str]]: |
| 142 | + """ |
| 143 | + Extract reasoning content from the model output. |
| 144 | +
|
| 145 | + For text <think>abc</think>xyz: |
| 146 | + - 'abc' goes to reasoning_content |
| 147 | + - 'xyz' goes to content |
| 148 | +
|
| 149 | + Returns: |
| 150 | + tuple[Optional[str], Optional[str]]: reasoning content and content |
| 151 | + """ |
| 152 | + |
| 153 | + # Check if the start token is present in the model output, remove it |
| 154 | + # if it is present. |
| 155 | + model_output_parts = model_output.partition(self.start_token) |
| 156 | + model_output = model_output_parts[2] if model_output_parts[ |
| 157 | + 1] else model_output_parts[0] |
| 158 | + |
146 | 159 | # DeepSeek R1 doesn't generate <think> now. |
147 | 160 | # Thus we assume the reasoning content is always at the start. |
148 | 161 | # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f |
149 | 162 | if self.end_token not in model_output: |
150 | 163 | return model_output, None |
151 | 164 | else: |
152 | | - # Add a start token if it's missing to keep compatibility. |
153 | | - if self.start_token not in model_output: |
154 | | - model_output = f"{self.start_token}{model_output}" |
155 | | - # Use a regex to find the reasoning content |
156 | | - reasoning_content = self.reasoning_regex.findall(model_output)[0] |
157 | | - |
158 | | - end_index = len( |
159 | | - f"{self.start_token}{reasoning_content}{self.end_token}") |
160 | | - final_output = model_output[end_index:] |
161 | | - |
162 | | - if len(final_output) == 0: |
163 | | - return reasoning_content, None |
164 | | - |
165 | | - return reasoning_content, final_output |
| 165 | + reasoning_content, _, content = model_output.partition( |
| 166 | + self.end_token) |
| 167 | + # If the end token is not found, return the model output as is. |
| 168 | + # It should not happen since we already checked for the presence |
| 169 | + # of the end token. |
| 170 | + # If generation stops right after end-of-think, return null content |
| 171 | + final_content = content or None |
| 172 | + return reasoning_content, final_content |
0 commit comments