diff --git a/src/google/adk/flows/llm_flows/request_confirmation.py b/src/google/adk/flows/llm_flows/request_confirmation.py index f7b7f7f6dd..16cc80efea 100644 --- a/src/google/adk/flows/llm_flows/request_confirmation.py +++ b/src/google/adk/flows/llm_flows/request_confirmation.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING from google.genai import types +from pydantic import ValidationError from typing_extensions import override from . import functions @@ -81,13 +82,35 @@ async def run_async( # ADK client must send a resuming run request with a function response # that always encapsulate the confirmation result with a 'response' # key - tool_confirmation = ToolConfirmation.model_validate( - json.loads(function_response.response['response']) - ) + try: + tool_confirmation = ToolConfirmation.model_validate( + json.loads(function_response.response['response']) + ) + except ( + json.JSONDecodeError, + TypeError, + ValidationError, + ) as parse_err: + logger.warning( + 'Malformed tool confirmation payload for' + ' function_response_id=%s: %s', + function_response.id, + parse_err, + ) + tool_confirmation = ToolConfirmation(confirmed=False) else: - tool_confirmation = ToolConfirmation.model_validate( - function_response.response - ) + try: + tool_confirmation = ToolConfirmation.model_validate( + function_response.response + ) + except ValidationError as parse_err: + logger.warning( + 'Malformed tool confirmation payload for' + ' function_response_id=%s: %s', + function_response.id, + parse_err, + ) + tool_confirmation = ToolConfirmation(confirmed=False) request_confirmation_function_responses[function_response.id] = ( tool_confirmation ) diff --git a/tests/unittests/runners/test_run_tool_confirmation.py b/tests/unittests/runners/test_run_tool_confirmation.py index 08dfdd6fac..08da2719da 100644 --- a/tests/unittests/runners/test_run_tool_confirmation.py +++ b/tests/unittests/runners/test_run_tool_confirmation.py @@ -224,6 +224,37 @@ async def test_confirmation_flow( == expected_parts_final ) + @pytest.mark.asyncio + async def test_malformed_confirmation_payload_is_rejected_fail_closed( + self, + runner: testing_utils.InMemoryRunner, + agent: LlmAgent, + ): + """Malformed confirmation payloads must fail closed and reject tool calls.""" + initial_events = await runner.run_async(testing_utils.UserContent("test")) + ask_for_confirmation_function_call_id = ( + initial_events[1].content.parts[0].function_call.id + ) + + malformed_confirmation = testing_utils.UserContent( + Part( + function_response=FunctionResponse( + id=ask_for_confirmation_function_call_id, + name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + response={"response": "{not-json"}, + ) + ) + ) + + events = await runner.run_async(malformed_confirmation) + simplified = testing_utils.simplify_events(copy.deepcopy(events)) + assert simplified[0][1] == Part( + function_response=FunctionResponse( + name=agent.tools[0].name, + response={"error": "This tool call is rejected."}, + ) + ) + class TestHITLConfirmationFlowWithCustomPayloadSchema(BaseHITLTest): """Tests the HITL confirmation flow with a single agent, for custom confirmation payload schema."""