Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from openai.types.responses.response_prompt_param import (
ResponsePromptParam,
)
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
from typing_extensions import NotRequired, TypedDict, Unpack

from ._run_impl import (
Expand Down Expand Up @@ -48,6 +49,7 @@
HandoffCallItem,
ItemHelpers,
ModelResponse,
ReasoningItem,
RunItem,
ToolCallItem,
ToolCallItemTypes,
Expand Down Expand Up @@ -1097,6 +1099,7 @@ async def _run_single_turn_streamed(
server_conversation_tracker: _ServerConversationTracker | None = None,
) -> SingleStepResult:
emitted_tool_call_ids: set[str] = set()
emitted_reasoning_item_ids: set[str] = set()

if should_run_agent_start_hooks:
await asyncio.gather(
Expand Down Expand Up @@ -1178,6 +1181,9 @@ async def _run_single_turn_streamed(
conversation_id=conversation_id,
prompt=prompt_config,
):
# Emit the raw event ASAP
streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))

if isinstance(event, ResponseCompletedEvent):
usage = (
Usage(
Expand Down Expand Up @@ -1217,7 +1223,16 @@ async def _run_single_turn_streamed(
RunItemStreamEvent(item=tool_item, name="tool_called")
)

streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))
elif isinstance(output_item, ResponseReasoningItem):
reasoning_id: str | None = getattr(output_item, "id", None)

if reasoning_id and reasoning_id not in emitted_reasoning_item_ids:
emitted_reasoning_item_ids.add(reasoning_id)

reasoning_item = ReasoningItem(raw_item=output_item, agent=agent)
streamed_result._event_queue.put_nowait(
RunItemStreamEvent(item=reasoning_item, name="reasoning_item_created")
)

# Call hook just after the model response is finalized.
if final_response is not None:
Expand Down Expand Up @@ -1271,6 +1286,18 @@ async def _run_single_turn_streamed(
)
]

if emitted_reasoning_item_ids:
# Filter out reasoning items that were already emitted during streaming
items_to_filter = [
item
for item in items_to_filter
if not (
isinstance(item, ReasoningItem)
and (reasoning_id := getattr(item.raw_item, "id", None))
and reasoning_id in emitted_reasoning_item_ids
)
]

# Filter out HandoffCallItem to avoid duplicates (already sent earlier)
items_to_filter = [
item for item in items_to_filter if not isinstance(item, HandoffCallItem)
Expand Down
173 changes: 170 additions & 3 deletions tests/fake_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,33 @@
from collections.abc import AsyncIterator
from typing import Any

from openai.types.responses import Response, ResponseCompletedEvent, ResponseUsage
from openai.types.responses import (
Response,
ResponseCompletedEvent,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseCreatedEvent,
ResponseFunctionCallArgumentsDeltaEvent,
ResponseFunctionCallArgumentsDoneEvent,
ResponseFunctionToolCall,
ResponseInProgressEvent,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputMessage,
ResponseOutputText,
ResponseReasoningSummaryPartAddedEvent,
ResponseReasoningSummaryPartDoneEvent,
ResponseReasoningSummaryTextDeltaEvent,
ResponseReasoningSummaryTextDoneEvent,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseUsage,
)
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
from openai.types.responses.response_reasoning_summary_part_added_event import (
Part as AddedEventPart,
)
from openai.types.responses.response_reasoning_summary_part_done_event import Part as DoneEventPart
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails

from agents.agent_output import AgentOutputSchemaBase
Expand Down Expand Up @@ -143,10 +169,151 @@ async def stream_response(
)
raise output

response = get_response_obj(output, usage=self.hardcoded_usage)
sequence_number = 0

yield ResponseCreatedEvent(
type="response.created",
response=response,
sequence_number=sequence_number,
)
sequence_number += 1

yield ResponseInProgressEvent(
type="response.in_progress",
response=response,
sequence_number=sequence_number,
)
sequence_number += 1

for output_index, output_item in enumerate(output):
yield ResponseOutputItemAddedEvent(
type="response.output_item.added",
item=output_item,
output_index=output_index,
sequence_number=sequence_number,
)
sequence_number += 1

if isinstance(output_item, ResponseReasoningItem):
if output_item.summary:
for summary_index, summary in enumerate(output_item.summary):
yield ResponseReasoningSummaryPartAddedEvent(
type="response.reasoning_summary_part.added",
item_id=output_item.id,
output_index=output_index,
summary_index=summary_index,
part=AddedEventPart(text=summary.text, type=summary.type),
sequence_number=sequence_number,
)
sequence_number += 1

yield ResponseReasoningSummaryTextDeltaEvent(
type="response.reasoning_summary_text.delta",
item_id=output_item.id,
output_index=output_index,
summary_index=summary_index,
delta=summary.text,
sequence_number=sequence_number,
)
sequence_number += 1

yield ResponseReasoningSummaryTextDoneEvent(
type="response.reasoning_summary_text.done",
item_id=output_item.id,
output_index=output_index,
summary_index=summary_index,
text=summary.text,
sequence_number=sequence_number,
)
sequence_number += 1

yield ResponseReasoningSummaryPartDoneEvent(
type="response.reasoning_summary_part.done",
item_id=output_item.id,
output_index=output_index,
summary_index=summary_index,
part=DoneEventPart(text=summary.text, type=summary.type),
sequence_number=sequence_number,
)
sequence_number += 1

elif isinstance(output_item, ResponseFunctionToolCall):
yield ResponseFunctionCallArgumentsDeltaEvent(
type="response.function_call_arguments.delta",
item_id=output_item.call_id,
output_index=output_index,
delta=output_item.arguments,
sequence_number=sequence_number,
)
sequence_number += 1

yield ResponseFunctionCallArgumentsDoneEvent(
type="response.function_call_arguments.done",
item_id=output_item.call_id,
output_index=output_index,
[PR review comment — Contributor Author] @seratch: Insert name=output_item.name here to fix the mypy error that appeared after upgrading to openai-python v2.

arguments=output_item.arguments,
sequence_number=sequence_number,
)
sequence_number += 1

elif isinstance(output_item, ResponseOutputMessage):
for content_index, content_part in enumerate(output_item.content):
if isinstance(content_part, ResponseOutputText):
yield ResponseContentPartAddedEvent(
type="response.content_part.added",
item_id=output_item.id,
output_index=output_index,
content_index=content_index,
part=content_part,
sequence_number=sequence_number,
)
sequence_number += 1

yield ResponseTextDeltaEvent(
type="response.output_text.delta",
item_id=output_item.id,
output_index=output_index,
content_index=content_index,
delta=content_part.text,
logprobs=[],
sequence_number=sequence_number,
)
sequence_number += 1

yield ResponseTextDoneEvent(
type="response.output_text.done",
item_id=output_item.id,
output_index=output_index,
content_index=content_index,
text=content_part.text,
logprobs=[],
sequence_number=sequence_number,
)
sequence_number += 1

yield ResponseContentPartDoneEvent(
type="response.content_part.done",
item_id=output_item.id,
output_index=output_index,
content_index=content_index,
part=content_part,
sequence_number=sequence_number,
)
sequence_number += 1

yield ResponseOutputItemDoneEvent(
type="response.output_item.done",
item=output_item,
output_index=output_index,
sequence_number=sequence_number,
)
sequence_number += 1

yield ResponseCompletedEvent(
type="response.completed",
response=get_response_obj(output, usage=self.hardcoded_usage),
sequence_number=0,
response=response,
sequence_number=sequence_number,
)


Expand Down
14 changes: 13 additions & 1 deletion tests/fastapi/test_streaming_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,17 @@ async def test_streaming_context():
body = (await r.aread()).decode("utf-8")
lines = [line for line in body.splitlines() if line]
assert lines == snapshot(
["agent_updated_stream_event", "raw_response_event", "run_item_stream_event"]
[
"agent_updated_stream_event",
"raw_response_event", # ResponseCreatedEvent
"raw_response_event", # ResponseInProgressEvent
"raw_response_event", # ResponseOutputItemAddedEvent
"raw_response_event", # ResponseContentPartAddedEvent
"raw_response_event", # ResponseTextDeltaEvent
"raw_response_event", # ResponseTextDoneEvent
"raw_response_event", # ResponseContentPartDoneEvent
"raw_response_event", # ResponseOutputItemDoneEvent
"raw_response_event", # ResponseCompletedEvent
"run_item_stream_event", # MessageOutputItem
]
)
13 changes: 9 additions & 4 deletions tests/test_agent_runner_streamed.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,11 +695,16 @@ async def test_streaming_events():
# Now lets check the events

expected_item_type_map = {
"tool_call": 2,
# 3 tool_call_item events:
# 1. get_function_tool_call("foo", ...)
# 2. get_handoff_tool_call(agent_1) because handoffs are implemented via tool calls too
# 3. get_function_tool_call("bar", ...)
"tool_call": 3,
# Only 2 outputs, handoff tool call doesn't have corresponding tool_call_output event
"tool_call_output": 2,
"message": 2,
"handoff": 1,
"handoff_output": 1,
"message": 2, # get_text_message("a_message") + get_final_output_message(...)
"handoff": 1, # get_handoff_tool_call(agent_1)
"handoff_output": 1, # handoff_output_item
}

total_expected_item_count = sum(expected_item_type_map.values())
Expand Down
Loading