Skip to content

Commit a6be1b4

Browse files
committed
fix: bring coverage back up, addressing edge cases
1 parent 5ce726a commit a6be1b4

File tree

9 files changed

+907
-63
lines changed

9 files changed

+907
-63
lines changed

src/agents/handoffs/history.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,11 @@ def _build_summary_message(transcript: list[TResponseInputItem]) -> TResponseInp
126126
end_marker,
127127
]
128128
content = "\n".join(content_lines)
129-
assistant_message: dict[str, Any] = {
130-
"role": "assistant",
129+
summary_message: dict[str, Any] = {
130+
"role": "system",
131131
"content": content,
132132
}
133-
return cast(TResponseInputItem, assistant_message)
133+
return cast(TResponseInputItem, summary_message)
134134

135135

136136
def _format_transcript_item(item: TResponseInputItem) -> str:

src/agents/items.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import abc
4+
import json
45
import weakref
56
from dataclasses import dataclass, field
67
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, Union, cast
@@ -56,6 +57,44 @@
5657
)
5758
from .usage import Usage
5859

60+
61+
def _dump_function_call_output(value: Any) -> str:
    """Serialize *value* to a JSON string without raising on non-JSON types.

    ``default=str`` is a deliberate fallback: objects that ``json`` cannot encode
    natively (bytes, sets, datetimes, ...) are stringified instead of raising
    ``TypeError`` in the middle of payload normalization.
    """
    return json.dumps(value, default=str)


def normalize_function_call_output_payload(payload: dict[str, Any]) -> dict[str, Any]:
    """Ensure function_call_output payloads conform to Responses API expectations.

    Only payloads whose ``type`` is ``function_call_output`` or
    ``function_call_result`` are touched; anything else is returned unchanged.
    For those payloads the ``output`` field is rewritten as follows:

    * missing or ``None``  -> ``""``
    * ``str``              -> left as-is
    * ``list``             -> left as-is when every entry is a dict whose ``type``
      is in ``_ALLOWED_FUNCTION_CALL_OUTPUT_TYPES``; otherwise JSON-encoded
    * ``dict``             -> wrapped in a one-element list when its ``type`` is
      allowed; otherwise JSON-encoded
    * anything else        -> JSON-encoded

    Args:
        payload: The item dict to normalize. It is modified IN PLACE; pass a copy
            if the original must stay untouched.

    Returns:
        The same ``payload`` object that was passed in.
    """
    if payload.get("type") not in {"function_call_output", "function_call_result"}:
        return payload

    output_value = payload.get("output")

    if output_value is None:
        payload["output"] = ""
    elif isinstance(output_value, str):
        # Plain strings are already in an accepted shape.
        pass
    elif isinstance(output_value, list):
        is_structured = all(
            isinstance(entry, dict) and entry.get("type") in _ALLOWED_FUNCTION_CALL_OUTPUT_TYPES
            for entry in output_value
        )
        # A list mixing allowed and unknown entries is serialized as a whole.
        if not is_structured:
            payload["output"] = _dump_function_call_output(output_value)
    elif isinstance(output_value, dict):
        if output_value.get("type") in _ALLOWED_FUNCTION_CALL_OUTPUT_TYPES:
            # A single structured entry is promoted to the list form.
            payload["output"] = [output_value]
        else:
            payload["output"] = _dump_function_call_output(output_value)
    else:
        # Numbers, booleans and any other object end up as a JSON string.
        payload["output"] = _dump_function_call_output(output_value)

    return payload
96+
97+
5998
if TYPE_CHECKING:
6099
from .agent import Agent
61100

@@ -75,6 +114,15 @@
75114

76115
# Distinguish a missing dict entry from an explicit None value.
77116
_MISSING_ATTR_SENTINEL = object()
117+
# Entry "type" values that normalize_function_call_output_payload keeps as
# structured output entries; outputs with any other entry type are JSON-encoded
# into a plain string instead. Listed alphabetically (set order is irrelevant).
_ALLOWED_FUNCTION_CALL_OUTPUT_TYPES: set[str] = {
    "computer_screenshot",
    "input_file",
    "input_image",
    "input_text",
    "output_text",
    "refusal",
    "summary_text",
}
78126

79127

80128
@dataclass
@@ -220,6 +268,21 @@ def release_agent(self) -> None:
220268
# Preserve dataclass fields for repr/asdict while dropping strong refs.
221269
self.__dict__["target_agent"] = None
222270

271+
def to_input_item(self) -> TResponseInputItem:
    """Convert handoff output into the API format expected by the model.

    When ``raw_item`` is a dict, a shallow copy is converted: a protocol-style
    ``function_call_result`` is retyped to ``function_call_output`` and loses its
    ``name`` and ``status`` keys, then the copy is passed through
    ``normalize_function_call_output_payload``. ``raw_item`` itself is not
    modified. Non-dict raw items are delegated to the parent implementation.
    """
    raw = self.raw_item
    if not isinstance(raw, dict):
        return super().to_input_item()

    # Work on a copy so the stored raw item keeps its original shape.
    api_payload = dict(raw)
    if api_payload.get("type") == "function_call_result":
        api_payload["type"] = "function_call_output"
        for protocol_only_key in ("name", "status"):
            api_payload.pop(protocol_only_key, None)

    return cast(TResponseInputItem, normalize_function_call_output_payload(api_payload))
285+
223286

224287
ToolCallItemTypes: TypeAlias = Union[
225288
ResponseFunctionToolCall,
@@ -273,15 +336,25 @@ def to_input_item(self) -> TResponseInputItem:
273336
Hosted tool outputs (e.g. shell/apply_patch) carry a `status` field for the SDK's
274337
book-keeping, but the Responses API does not yet accept that parameter. Strip it from the
275338
payload we send back to the model while keeping the original raw item intact.
339+
340+
Also converts protocol format (function_call_result) to API format (function_call_output).
276341
"""
277342

278343
if isinstance(self.raw_item, dict):
279344
payload = dict(self.raw_item)
280345
payload_type = payload.get("type")
281-
if payload_type == "shell_call_output":
346+
# Convert protocol format to API format
347+
# Protocol uses function_call_result, API expects function_call_output
348+
if payload_type == "function_call_result":
349+
payload["type"] = "function_call_output"
350+
# Remove fields that are in protocol format but not in API format
351+
payload.pop("name", None)
352+
payload.pop("status", None)
353+
elif payload_type == "shell_call_output":
282354
payload.pop("status", None)
283355
payload.pop("shell_output", None)
284356
payload.pop("provider_data", None)
357+
payload = normalize_function_call_output_payload(payload)
285358
return cast(TResponseInputItem, payload)
286359

287360
return super().to_input_item()
@@ -392,6 +465,17 @@ def arguments(self) -> str | None:
392465
return self.raw_item.arguments
393466
return None
394467

468+
def to_input_item(self) -> TResponseInputItem:
    """Always raise: a ToolApprovalItem has no input-item representation.

    These items represent pending approvals and should be filtered out before
    input is prepared for the API; raising here guards against accidental use.

    Raises:
        AgentsException: Unconditionally, on every call.
    """
    message = (
        "ToolApprovalItem cannot be converted to an input item. "
        "These items should be filtered out before preparing input for the API."
    )
    raise AgentsException(message)
478+
395479

396480
RunItem: TypeAlias = Union[
397481
MessageOutputItem,

src/agents/run.py

Lines changed: 131 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
ToolCallItem,
6060
ToolCallItemTypes,
6161
TResponseInputItem,
62+
normalize_function_call_output_payload,
6263
)
6364
from .lifecycle import AgentHooksBase, RunHooks, RunHooksBase
6465
from .logger import logger
@@ -742,10 +743,15 @@ async def run(
742743
# Resuming from a saved state
743744
run_state = cast(RunState[TContext], input)
744745
original_user_input = run_state._original_input
745-
# Normalize items to remove top-level providerData (API doesn't accept it there)
746+
# Normalize items to remove top-level providerData and convert protocol to API format
747+
# Then filter incomplete function calls to ensure API compatibility
746748
if isinstance(original_user_input, list):
747-
prepared_input: str | list[TResponseInputItem] = AgentRunner._normalize_input_items(
748-
original_user_input
749+
# Normalize first (converts protocol format to API format, normalizes field names)
750+
normalized = AgentRunner._normalize_input_items(original_user_input)
751+
# Filter incomplete function calls after normalization
752+
# This ensures consistent field names (call_id vs callId) for matching
753+
prepared_input: str | list[TResponseInputItem] = (
754+
AgentRunner._filter_incomplete_function_calls(normalized)
749755
)
750756
else:
751757
prepared_input = original_user_input
@@ -787,12 +793,16 @@ async def run(
787793
if is_resumed_state and run_state is not None:
788794
# Restore state from RunState
789795
current_turn = run_state._current_turn
790-
# Normalize original_input to remove top-level providerData
791-
# (API doesn't accept it there)
796+
# Normalize original_input: remove top-level providerData,
797+
# convert protocol to API format, then filter incomplete function calls
792798
raw_original_input = run_state._original_input
793799
if isinstance(raw_original_input, list):
800+
# Normalize first (converts protocol to API format, normalizes field names)
801+
normalized = AgentRunner._normalize_input_items(raw_original_input)
802+
# Filter incomplete function calls after normalization
803+
# This ensures consistent field names (call_id vs callId) for matching
794804
original_input: str | list[TResponseInputItem] = (
795-
AgentRunner._normalize_input_items(raw_original_input)
805+
AgentRunner._filter_incomplete_function_calls(normalized)
796806
)
797807
else:
798808
original_input = raw_original_input
@@ -861,8 +871,40 @@ async def run(
861871
)
862872
in output_call_ids
863873
]
864-
# Save both function_call and function_call_output together
865-
items_to_save = tool_call_items + tool_output_items
874+
# Check which items are already in the session to avoid duplicates
875+
# Get existing items from session and extract their call_ids
876+
existing_items = await session.get_items()
877+
existing_call_ids: set[str] = set()
878+
for existing_item in existing_items:
879+
if isinstance(existing_item, dict):
880+
item_type = existing_item.get("type")
881+
if item_type in ("function_call", "function_call_output"):
882+
existing_call_id = existing_item.get(
883+
"call_id"
884+
) or existing_item.get("callId")
885+
if existing_call_id and isinstance(existing_call_id, str):
886+
existing_call_ids.add(existing_call_id)
887+
888+
# Filter out items that are already in the session
889+
items_to_save: list[RunItem] = []
890+
for item in tool_call_items + tool_output_items:
891+
item_call_id: str | None = None
892+
if isinstance(item.raw_item, dict):
893+
raw_call_id = item.raw_item.get("call_id") or item.raw_item.get(
894+
"callId"
895+
)
896+
item_call_id = (
897+
cast(str | None, raw_call_id) if raw_call_id else None
898+
)
899+
elif hasattr(item.raw_item, "call_id"):
900+
item_call_id = cast(
901+
str | None, getattr(item.raw_item, "call_id", None)
902+
)
903+
904+
# Only save if not already in session
905+
if item_call_id is None or item_call_id not in existing_call_ids:
906+
items_to_save.append(item)
907+
866908
if items_to_save:
867909
await self._save_result_to_session(session, [], items_to_save)
868910
# Clear the current step since we've handled it
@@ -1419,11 +1461,12 @@ async def _start_streaming(
14191461
# Resuming from state - normalize items to remove top-level providerData
14201462
# and filter incomplete function_call pairs
14211463
if isinstance(starting_input, list):
1422-
# Filter incomplete function_call pairs before normalizing
1423-
filtered = AgentRunner._filter_incomplete_function_calls(starting_input)
1424-
prepared_input: str | list[TResponseInputItem] = (
1425-
AgentRunner._normalize_input_items(filtered)
1426-
)
1464+
# Normalize field names first (camelCase -> snake_case) to ensure
1465+
# consistent field names for filtering
1466+
normalized_input = AgentRunner._normalize_input_items(starting_input)
1467+
# Filter incomplete function_call pairs after normalizing
1468+
filtered = AgentRunner._filter_incomplete_function_calls(normalized_input)
1469+
prepared_input: str | list[TResponseInputItem] = filtered
14271470
else:
14281471
prepared_input = starting_input
14291472
else:
@@ -2600,33 +2643,67 @@ def _normalize_input_items(items: list[TResponseInputItem]) -> list[TResponseInp
26002643
"""
26012644
from .run_state import _normalize_field_names
26022645

2646+
def _coerce_to_dict(value: TResponseInputItem) -> dict[str, Any] | None:
2647+
if isinstance(value, dict):
2648+
return dict(value)
2649+
if hasattr(value, "model_dump"):
2650+
try:
2651+
return cast(dict[str, Any], value.model_dump(exclude_unset=True))
2652+
except Exception:
2653+
return None
2654+
return None
2655+
26032656
normalized: list[TResponseInputItem] = []
26042657
for item in items:
2605-
if isinstance(item, dict):
2606-
# Create a copy to avoid modifying the original
2607-
normalized_item = dict(item)
2608-
# Remove top-level providerData/provider_data - these should only be in content
2609-
# The API doesn't accept providerData at the top level of input items
2610-
normalized_item.pop("providerData", None)
2611-
normalized_item.pop("provider_data", None)
2612-
# Normalize item type: API expects 'function_call_output',
2613-
# not 'function_call_result'
2614-
item_type = normalized_item.get("type")
2615-
if item_type == "function_call_result":
2616-
normalized_item["type"] = "function_call_output"
2617-
item_type = "function_call_output"
2618-
# Remove invalid fields based on item type
2619-
# function_call_output items should not have 'name' field
2620-
if item_type == "function_call_output":
2621-
normalized_item.pop("name", None)
2622-
# Normalize field names (callId -> call_id, responseId -> response_id)
2623-
normalized_item = _normalize_field_names(normalized_item)
2624-
normalized.append(cast(TResponseInputItem, normalized_item))
2625-
else:
2626-
# For non-dict items, keep as-is (they should already be in correct format)
2658+
coerced = _coerce_to_dict(item)
2659+
if coerced is None:
26272660
normalized.append(item)
2661+
continue
2662+
2663+
normalized_item = dict(coerced)
2664+
normalized_item.pop("providerData", None)
2665+
normalized_item.pop("provider_data", None)
2666+
item_type = normalized_item.get("type")
2667+
if item_type == "function_call_result":
2668+
normalized_item["type"] = "function_call_output"
2669+
item_type = "function_call_output"
2670+
if item_type == "function_call_output":
2671+
normalized_item.pop("name", None)
2672+
normalized_item.pop("status", None)
2673+
normalized_item = normalize_function_call_output_payload(normalized_item)
2674+
normalized_item = _normalize_field_names(normalized_item)
2675+
normalized.append(cast(TResponseInputItem, normalized_item))
26282676
return normalized
26292677

2678+
@staticmethod
def _ensure_api_input_item(item: TResponseInputItem) -> TResponseInputItem:
    """Return *item* with protocol-style tool results rewritten to API format.

    Dict items are copied; objects exposing ``model_dump`` are dumped with
    ``exclude_unset=True``. Anything else, or an object whose dump raises or
    yields ``None``, is returned unchanged. On the resulting dict:

    * ``function_call_result`` is retyped to ``function_call_output`` and its
      ``name`` / ``status`` keys are dropped;
    * ``function_call_output`` payloads are then passed through
      ``normalize_function_call_output_payload``.

    Key names are otherwise left as they are; no camelCase -> snake_case
    renaming happens in this method.
    """
    dumped: Any
    if isinstance(item, dict):
        dumped = item
    elif hasattr(item, "model_dump"):
        try:
            dumped = item.model_dump(exclude_unset=True)
        except Exception:
            # Best effort: an item that cannot be dumped is passed through as-is.
            dumped = None
    else:
        dumped = None

    if dumped is None:
        return item

    # Copy so the caller's item is never modified by the rewrites below.
    api_item: dict[str, Any] = dict(dumped)

    kind = api_item.get("type")
    if kind == "function_call_result":
        kind = "function_call_output"
        api_item["type"] = kind
        for protocol_only_key in ("name", "status"):
            api_item.pop(protocol_only_key, None)

    if kind == "function_call_output":
        api_item = normalize_function_call_output_payload(api_item)

    return cast(TResponseInputItem, api_item)
2706+
26302707
@classmethod
26312708
async def _prepare_input_with_session(
26322709
cls,
@@ -2651,13 +2728,19 @@ async def _prepare_input_with_session(
26512728
# Get previous conversation history
26522729
history = await session.get_items()
26532730

2731+
# Convert protocol format items from session to API format.
2732+
# TypeScript may save protocol format (function_call_result) to sessions,
2733+
# but the API expects API format (function_call_output).
2734+
converted_history = [cls._ensure_api_input_item(item) for item in history]
2735+
26542736
# Convert input to list format
26552737
new_input_list = ItemHelpers.input_to_new_input_list(input)
2738+
new_input_list = [cls._ensure_api_input_item(item) for item in new_input_list]
26562739

26572740
if session_input_callback is None:
2658-
merged = history + new_input_list
2741+
merged = converted_history + new_input_list
26592742
elif callable(session_input_callback):
2660-
res = session_input_callback(history, new_input_list)
2743+
res = session_input_callback(converted_history, new_input_list)
26612744
if inspect.isawaitable(res):
26622745
merged = await res
26632746
else:
@@ -2711,10 +2794,19 @@ async def _save_result_to_session(
27112794
return
27122795

27132796
# Convert original input to list format if needed
2714-
input_list = ItemHelpers.input_to_new_input_list(original_input)
2797+
input_list = [
2798+
cls._ensure_api_input_item(item)
2799+
for item in ItemHelpers.input_to_new_input_list(original_input)
2800+
]
2801+
2802+
# Filter out tool_approval_item items before converting to input format
2803+
# These items represent pending approvals and shouldn't be sent to the API
2804+
items_to_convert = [item for item in new_items if item.type != "tool_approval_item"]
27152805

27162806
# Convert new items to input format
2717-
new_items_as_input = [item.to_input_item() for item in new_items]
2807+
new_items_as_input = [
2808+
cls._ensure_api_input_item(item.to_input_item()) for item in items_to_convert
2809+
]
27182810

27192811
# Save all items from this turn
27202812
items_to_save = input_list + new_items_as_input

0 commit comments

Comments
 (0)