Skip to content

Commit dd706bd

Browse files
wuliang229 authored and copybara-github committed
feat: Update conformance test CLI to handle long-running tool calls
Capture the function call ID from events and use it to populate the `function_response.id` in subsequent user messages, enabling recording and replaying of test cases involving long-running tools.

Co-authored-by: Liang Wu <wuliang@google.com>
PiperOrigin-RevId: 829042356
1 parent e511eb1 commit dd706bd

File tree

2 files changed

+60
-4
lines changed

2 files changed

+60
-4
lines changed

src/google/adk/cli/conformance/cli_record.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,35 @@ async def _create_conformance_test_files(
5252
)
5353

5454
# Run the agent with the user messages
55+
function_call_name_to_id_map = {}
5556
for user_message_index, user_message in enumerate(
5657
test_case.test_spec.user_messages
5758
):
5859
# Create content from UserMessage object
5960
if user_message.content is not None:
6061
content = user_message.content
62+
63+
# If the user provides a function response, it means this is for
64+
# long-running tool. Replace the function call ID with the actual
65+
# function call ID. This is needed because the function call ID is not
66+
# known when writing the test case.
67+
if (
68+
user_message.content.parts
69+
and user_message.content.parts[0].function_response
70+
and user_message.content.parts[0].function_response.name
71+
):
72+
if (
73+
user_message.content.parts[0].function_response.name
74+
not in function_call_name_to_id_map
75+
):
76+
raise ValueError(
77+
"Function response for"
78+
f" {user_message.content.parts[0].function_response.name} does"
79+
" not match any pending function call."
80+
)
81+
content.parts[0].function_response.id = function_call_name_to_id_map[
82+
user_message.content.parts[0].function_response.name
83+
]
6184
elif user_message.text is not None:
6285
content = types.UserContent(parts=[types.Part(text=user_message.text)])
6386
else:
@@ -66,7 +89,7 @@ async def _create_conformance_test_files(
6689
" content"
6790
)
6891

69-
async for _ in client.run_agent(
92+
async for event in client.run_agent(
7093
RunAgentRequest(
7194
app_name=test_case.test_spec.agent,
7295
user_id=user_id,
@@ -78,7 +101,12 @@ async def _create_conformance_test_files(
78101
test_case_dir=str(test_case_dir),
79102
user_message_index=user_message_index,
80103
):
81-
pass
104+
if event.content and event.content.parts:
105+
for part in event.content.parts:
106+
if part.function_call:
107+
function_call_name_to_id_map[part.function_call.name] = (
108+
part.function_call.id
109+
)
82110

83111
# Retrieve the updated session
84112
updated_session = await client.get_session(

src/google/adk/cli/conformance/cli_test.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,35 @@ async def _run_user_messages(
117117
self, session_id: str, test_case: TestCase
118118
) -> None:
119119
"""Run all user messages for a test case."""
120+
function_call_name_to_id_map = {}
120121
for user_message_index, user_message in enumerate(
121122
test_case.test_spec.user_messages
122123
):
123124
# Create content from UserMessage object
124125
if user_message.content is not None:
125126
content = user_message.content
127+
128+
# If the user provides a function response, it means this is for
129+
# long-running tool. Replace the function call ID with the actual
130+
# function call ID. This is needed because the function call ID is not
131+
# known when writing the test case.
132+
if (
133+
user_message.content.parts
134+
and user_message.content.parts[0].function_response
135+
and user_message.content.parts[0].function_response.name
136+
):
137+
if (
138+
user_message.content.parts[0].function_response.name
139+
not in function_call_name_to_id_map
140+
):
141+
raise ValueError(
142+
"Function response for"
143+
f" {user_message.content.parts[0].function_response.name} does"
144+
" not match any pending function call."
145+
)
146+
content.parts[0].function_response.id = function_call_name_to_id_map[
147+
user_message.content.parts[0].function_response.name
148+
]
126149
elif user_message.text is not None:
127150
content = types.UserContent(parts=[types.Part(text=user_message.text)])
128151
else:
@@ -141,13 +164,18 @@ async def _run_user_messages(
141164
)
142165

143166
# Run the agent but don't collect events here
144-
async for _ in self.client.run_agent(
167+
async for event in self.client.run_agent(
145168
request,
146169
mode="replay",
147170
test_case_dir=str(test_case.dir),
148171
user_message_index=user_message_index,
149172
):
150-
pass
173+
if event.content and event.content.parts:
174+
for part in event.content.parts:
175+
if part.function_call:
176+
function_call_name_to_id_map[part.function_call.name] = (
177+
part.function_call.id
178+
)
151179

152180
async def _validate_test_results(
153181
self, session_id: str, test_case: TestCase

0 commit comments

Comments
 (0)