Skip to content

Commit 36eebfe

Browse files
committed
Parse '<think>' tags in streamed text as thinking parts
1 parent 903f11e commit 36eebfe

File tree

6 files changed

+2227
-7
lines changed

6 files changed

+2227
-7
lines changed

pydantic_ai_slim/pydantic_ai/_parts_manager.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515

1616
from collections.abc import Hashable
1717
from dataclasses import dataclass, field, replace
18-
from typing import Any, Union
18+
from typing import Any, Literal, Union, overload
1919

20+
from pydantic_ai._thinking_part import END_THINK_TAG, START_THINK_TAG
2021
from pydantic_ai.exceptions import UnexpectedModelBehavior
2122
from pydantic_ai.messages import (
2223
ModelResponsePart,
@@ -66,12 +67,30 @@ def get_parts(self) -> list[ModelResponsePart]:
6667
"""
6768
return [p for p in self._parts if not isinstance(p, ToolCallPartDelta)]
6869

70+
@overload
6971
def handle_text_delta(
7072
self,
7173
*,
72-
vendor_part_id: Hashable | None,
74+
vendor_part_id: VendorId | None,
7375
content: str,
74-
) -> ModelResponseStreamEvent:
76+
) -> ModelResponseStreamEvent: ...
77+
78+
@overload
79+
def handle_text_delta(
80+
self,
81+
*,
82+
vendor_part_id: VendorId,
83+
content: str,
84+
extract_think_tags: Literal[True],
85+
) -> ModelResponseStreamEvent | None: ...
86+
87+
def handle_text_delta(
88+
self,
89+
*,
90+
vendor_part_id: VendorId | None,
91+
content: str,
92+
extract_think_tags: bool = False,
93+
) -> ModelResponseStreamEvent | None:
7594
"""Handle incoming text content, creating or updating a TextPart in the manager as appropriate.
7695
7796
When `vendor_part_id` is None, the latest part is updated if it exists and is a TextPart;
@@ -83,6 +102,7 @@ def handle_text_delta(
83102
of text. If None, a new part will be created unless the latest part is already
84103
a TextPart.
85104
content: The text content to append to the appropriate TextPart.
105+
extract_think_tags: Whether to extract `<think>` tags from the text content and handle them as thinking parts.
86106
87107
Returns:
88108
A `PartStartEvent` if a new part was created, or a `PartDeltaEvent` if an existing part was updated.
@@ -104,9 +124,24 @@ def handle_text_delta(
104124
part_index = self._vendor_id_to_part_index.get(vendor_part_id)
105125
if part_index is not None:
106126
existing_part = self._parts[part_index]
107-
if not isinstance(existing_part, TextPart):
127+
128+
if extract_think_tags and isinstance(existing_part, ThinkingPart):
129+
# We may be building a thinking part instead of a text part if we had previously seen a `<think>` tag
130+
if content == END_THINK_TAG:
131+
# When we see `</think>`, we're done with the thinking part and the next text delta will need a new part
132+
self._vendor_id_to_part_index.pop(vendor_part_id)
133+
return None
134+
else:
135+
return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=content)
136+
elif isinstance(existing_part, TextPart):
137+
existing_text_part_and_index = existing_part, part_index
138+
else:
108139
raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}')
109-
existing_text_part_and_index = existing_part, part_index
140+
141+
if extract_think_tags and content == START_THINK_TAG:
142+
# When we see a `<think>` tag (which is a single token), we'll build a new thinking part instead
143+
self._vendor_id_to_part_index.pop(vendor_part_id, None)
144+
return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')
110145

111146
if existing_text_part_and_index is None:
112147
# There is no existing text part that should be updated, so create a new one

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,11 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
415415
# Handle the text part of the response
416416
content = choice.delta.content
417417
if content is not None:
418-
yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
418+
maybe_event = self._parts_manager.handle_text_delta(
419+
vendor_part_id='content', content=content, extract_think_tags=True
420+
)
421+
if maybe_event is not None: # pragma: no branch
422+
yield maybe_event
419423

420424
# Handle the tool calls
421425
for dtc in choice.delta.tool_calls or []:

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1004,7 +1004,11 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
10041004
# Handle the text part of the response
10051005
content = choice.delta.content
10061006
if content:
1007-
yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
1007+
maybe_event = self._parts_manager.handle_text_delta(
1008+
vendor_part_id='content', content=content, extract_think_tags=True
1009+
)
1010+
if maybe_event is not None: # pragma: no branch
1011+
yield maybe_event
10081012

10091013
# Handle reasoning part of the response, present in DeepSeek models
10101014
if reasoning_content := getattr(choice.delta, 'reasoning_content', None):

0 commit comments

Comments
 (0)