Skip to content

Commit 913dca6

Browse files
fix: use same code in both example and test case
Signed-off-by: nachiketb <nachiketb@nvidia.com>
1 parent a63f758 commit 913dca6

File tree

1 file changed

+33
-51
lines changed

1 file changed

+33
-51
lines changed

lib/bindings/python/examples/basic_reasoning_parser/basic_parser.py

Lines changed: 33 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -28,24 +28,18 @@ def detect_and_parse_reasoning(
2828
One-time parsing: Detects and parses reasoning sections in the provided text.
2929
Returns both reasoning content and normal text separately.
3030
"""
31-
in_reasoning = self._in_reasoning or self.think_start_token in text
32-
33-
if not in_reasoning:
31+
start_idx = text.find(self.think_start_token)
32+
if start_idx == -1:
3433
return (text, "")
35-
36-
# The text is considered to be in a reasoning block.
37-
processed_text = text.replace(self.think_start_token, "").strip()
38-
39-
if self.think_end_token not in processed_text:
40-
# Assume reasoning was truncated before `</think>` token
41-
return ("", processed_text)
42-
43-
# Extract reasoning content
44-
splits = processed_text.split(self.think_end_token, maxsplit=1)
45-
reasoning_text = splits[0]
46-
normal_text = splits[1].strip()
47-
48-
return (normal_text, reasoning_text)
34+
normal_prefix = text[:start_idx]
35+
after_start = text[start_idx + len(self.think_start_token) :]
36+
end_idx = after_start.find(self.think_end_token)
37+
if end_idx == -1:
38+
# Reasoning started but not closed yet
39+
return (normal_prefix, after_start)
40+
reasoning_text = after_start[:end_idx]
41+
normal_suffix = after_start[end_idx + len(self.think_end_token) :]
42+
return (normal_prefix + normal_suffix, reasoning_text)
4943

5044
def parse_reasoning_streaming_incremental(
5145
self, new_text: str, _token_ids: Sequence[int]
@@ -60,44 +54,32 @@ def parse_reasoning_streaming_incremental(
6054
Streams reasoning content as it arrives
6155
"""
6256
self._buffer += new_text
63-
current_text = self._buffer
57+
current = self._buffer
58+
normal_out = ""
6459

65-
# If the current text is a prefix of the think token, keep buffering
66-
if any(
67-
token.startswith(current_text) and token != current_text
68-
for token in [self.think_start_token, self.think_end_token]
69-
):
70-
return ("", "")
71-
72-
# Strip `<think>` token if present
73-
if not self.stripped_think_start and self.think_start_token in current_text:
74-
current_text = current_text.replace(self.think_start_token, "")
75-
self.stripped_think_start = True
60+
# If not in reasoning, emit normal prefix up to `<think>`
61+
if not self._in_reasoning:
62+
start_idx = current.find(self.think_start_token)
63+
if start_idx == -1:
64+
self._buffer = ""
65+
return (current, "")
66+
normal_out = current[:start_idx]
67+
current = current[start_idx + len(self.think_start_token) :]
7668
self._in_reasoning = True
69+
self.stripped_think_start = True
7770

78-
# Handle end of reasoning block
79-
if self._in_reasoning and self.think_end_token in current_text:
80-
end_idx = current_text.find(self.think_end_token)
81-
82-
reasoning_text = current_text[:end_idx]
83-
71+
# In reasoning: check for `</think>`
72+
end_idx = current.find(self.think_end_token)
73+
if end_idx != -1:
74+
reasoning_delta = current[:end_idx]
75+
normal_suffix = current[end_idx + len(self.think_end_token) :]
8476
self._buffer = ""
8577
self._in_reasoning = False
86-
normal_text = current_text[end_idx + len(self.think_end_token) :]
87-
88-
return (normal_text, reasoning_text.rstrip())
78+
self.stripped_think_start = False
79+
return (normal_out + normal_suffix, reasoning_delta.rstrip())
8980

90-
# Continue with reasoning content
91-
if self._in_reasoning:
92-
if self.stream_reasoning:
93-
self._buffer = ""
94-
return ("", current_text)
95-
else:
96-
return ("", "")
97-
98-
# If we're not in a reasoning block return as normal text
99-
if not self._in_reasoning:
81+
# No end yet
82+
if self.stream_reasoning:
10083
self._buffer = ""
101-
return (current_text, "")
102-
103-
return ("", "")
84+
return (normal_out, current)
85+
return (normal_out, "")

0 commit comments

Comments
 (0)