diff --git a/python/packages/core/agent_framework/_middleware.py b/python/packages/core/agent_framework/_middleware.py index 9bb730ba62..4e36cb764a 100644 --- a/python/packages/core/agent_framework/_middleware.py +++ b/python/packages/core/agent_framework/_middleware.py @@ -1405,13 +1405,17 @@ async def _stream_generator() -> Any: call_middleware = kwargs.pop("middleware", None) instance_middleware = getattr(self, "middleware", None) - # Merge middleware from both sources, filtering for chat middleware only - all_middleware: list[ChatMiddleware | ChatMiddlewareCallable] = _merge_and_filter_chat_middleware( - instance_middleware, call_middleware - ) + # Merge all middleware and separate by type + middleware = categorize_middleware(instance_middleware, call_middleware) + chat_middleware_list = middleware["chat"] + function_middleware_list = middleware["function"] + + # Pass function middleware to function invocation system if present + if function_middleware_list: + kwargs["_function_middleware_pipeline"] = FunctionMiddlewarePipeline(function_middleware_list) - # If no middleware, use original method - if not all_middleware: + # If no chat middleware, use original method + if not chat_middleware_list: async for update in original_get_streaming_response(self, messages, **kwargs): yield update return @@ -1422,7 +1426,7 @@ async def _stream_generator() -> Any: # Extract chat_options or create default chat_options = kwargs.pop("chat_options", ChatOptions()) - pipeline = ChatMiddlewarePipeline(all_middleware) # type: ignore[arg-type] + pipeline = ChatMiddlewarePipeline(chat_middleware_list) # type: ignore[arg-type] context = ChatContext( chat_client=self, messages=prepare_messages(messages), @@ -1536,27 +1540,40 @@ def _merge_and_filter_chat_middleware( return middleware["chat"] # type: ignore[return-value] -def extract_and_merge_function_middleware(chat_client: Any, **kwargs: Any) -> None: +def extract_and_merge_function_middleware( + chat_client: Any, kwargs: dict[str, Any] +) -> "FunctionMiddlewarePipeline | None": """Extract function middleware from chat client and merge with existing pipeline in kwargs. Args: chat_client: The chat client instance to extract middleware from. + kwargs: Dictionary containing middleware and pipeline information. - Keyword Args: - **kwargs: Dictionary containing middleware and pipeline information. + Returns: + A FunctionMiddlewarePipeline if function middleware is found, None otherwise. """ + # Check if a pipeline was already created by use_chat_middleware + existing_pipeline: FunctionMiddlewarePipeline | None = kwargs.get("_function_middleware_pipeline") + # Get middleware sources client_middleware = getattr(chat_client, "middleware", None) if hasattr(chat_client, "middleware") else None run_level_middleware = kwargs.get("middleware") - existing_pipeline = kwargs.get("_function_middleware_pipeline") - # Extract existing pipeline middlewares if present - existing_middlewares = existing_pipeline._middlewares if existing_pipeline else None + # If we have an existing pipeline but no additional middleware sources, return it directly + if existing_pipeline and not client_middleware and not run_level_middleware: + return existing_pipeline + + # If we have an existing pipeline with additional middleware, we need to merge + # Extract existing pipeline middlewares if present - cast to list[Middleware] for type compatibility + existing_middlewares: list[Middleware] | None = list(existing_pipeline._middlewares) if existing_pipeline else None # Create combined pipeline from all sources using existing helper combined_pipeline = create_function_middleware_pipeline( client_middleware, run_level_middleware, existing_middlewares ) - if combined_pipeline: - kwargs["_function_middleware_pipeline"] = combined_pipeline + # If we have an existing pipeline but combined is None (no new middlewares), return existing + if existing_pipeline and combined_pipeline is None: + return existing_pipeline + + return combined_pipeline diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index bc16d9edb9..953553142d 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -1348,6 +1348,35 @@ def __init__( self.include_detailed_errors = include_detailed_errors +class FunctionExecutionResult: + """Internal wrapper pairing function output with loop control signals. + + Function execution produces two distinct concerns: the semantic result (returned to + the LLM as FunctionResultContent) and control flow decisions (whether middleware + requested early termination). This wrapper keeps control signals out of user-facing + content types while allowing _try_execute_function_calls to communicate both. + + Not exposed to users. + + Attributes: + content: The FunctionResultContent or other content from the function execution. + terminate: If True, the function invocation loop should exit immediately without + another LLM call. Set when middleware sets context.terminate=True. + """ + + __slots__ = ("content", "terminate") + + def __init__(self, content: "Contents", terminate: bool = False) -> None: + """Initialize FunctionExecutionResult. + + Args: + content: The content from the function execution. + terminate: Whether to terminate the function calling loop. + """ + self.content = content + self.terminate = terminate + + async def _auto_invoke_function( function_call_content: "FunctionCallContent | FunctionApprovalResponseContent", custom_args: dict[str, Any] | None = None, @@ -1357,7 +1386,7 @@ async def _auto_invoke_function( sequence_index: int | None = None, request_index: int | None = None, middleware_pipeline: Any = None, # Optional MiddlewarePipeline -) -> "Contents": +) -> "FunctionExecutionResult | Contents": """Invoke a function call requested by the agent, applying middleware that is defined. Args: @@ -1372,7 +1401,8 @@ async def _auto_invoke_function( middleware_pipeline: Optional middleware pipeline to apply during execution. Returns: - A FunctionResultContent containing the result or exception. + A FunctionExecutionResult wrapping the content and terminate signal, + or a Contents object for approval/hosted tool scenarios. Raises: KeyError: If the requested function is not found in the tool map. @@ -1392,10 +1422,12 @@ async def _auto_invoke_function( # Tool should exist because _try_execute_function_calls validates this if tool is None: exc = KeyError(f'Function "{function_call_content.name}" not found.') - return FunctionResultContent( - call_id=function_call_content.call_id, - result=f'Error: Requested function "{function_call_content.name}" not found.', - exception=exc, + return FunctionExecutionResult( + content=FunctionResultContent( + call_id=function_call_content.call_id, + result=f'Error: Requested function "{function_call_content.name}" not found.', + exception=exc, + ) ) else: # Note: Unapproved tools (approved=False) are handled in _replace_approval_contents_with_results @@ -1420,7 +1452,9 @@ async def _auto_invoke_function( message = "Error: Argument parsing failed." if config.include_detailed_errors: message = f"{message} Exception: {exc}" - return FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc) + return FunctionExecutionResult( + content=FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc) + ) if not middleware_pipeline or ( not hasattr(middleware_pipeline, "has_middlewares") and not middleware_pipeline.has_middlewares @@ -1432,15 +1466,19 @@ async def _auto_invoke_function( tool_call_id=function_call_content.call_id, **runtime_kwargs if getattr(tool, "_forward_runtime_kwargs", False) else {}, ) - return FunctionResultContent( - call_id=function_call_content.call_id, - result=function_result, + return FunctionExecutionResult( + content=FunctionResultContent( + call_id=function_call_content.call_id, + result=function_result, + ) ) except Exception as exc: message = "Error: Function failed." if config.include_detailed_errors: message = f"{message} Exception: {exc}" - return FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc) + return FunctionExecutionResult( + content=FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc) + ) # Execute through middleware pipeline if available from ._middleware import FunctionInvocationContext @@ -1464,15 +1502,20 @@ async def final_function_handler(context_obj: Any) -> Any: context=middleware_context, final_handler=final_function_handler, ) - return FunctionResultContent( - call_id=function_call_content.call_id, - result=function_result, + return FunctionExecutionResult( + content=FunctionResultContent( + call_id=function_call_content.call_id, + result=function_result, + ), + terminate=middleware_context.terminate, ) except Exception as exc: message = "Error: Function failed." if config.include_detailed_errors: message = f"{message} Exception: {exc}" - return FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc) + return FunctionExecutionResult( + content=FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc) + ) def _get_tool_map( @@ -1503,7 +1546,7 @@ async def _try_execute_function_calls( | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]", config: FunctionInvocationConfiguration, middleware_pipeline: Any = None, # Optional MiddlewarePipeline to avoid circular imports -) -> Sequence["Contents"]: +) -> tuple[Sequence["Contents"], bool]: """Execute multiple function calls concurrently. Args: @@ -1515,9 +1558,11 @@ async def _try_execute_function_calls( middleware_pipeline: Optional middleware pipeline to apply during execution. Returns: - A list of Contents containing the results of each function call, - or the approval requests if any function requires approval, - or the original function calls if any are declaration only. + A tuple of: + - A list of Contents containing the results of each function call, + or the approval requests if any function requires approval, + or the original function calls if any are declaration only. + - A boolean indicating whether to terminate the function calling loop. """ from ._types import FunctionApprovalRequestContent, FunctionCallContent @@ -1540,17 +1585,20 @@ async def _try_execute_function_calls( raise KeyError(f'Error: Requested function "{fcc.name}" not found.') if approval_needed: # approval can only be needed for Function Call Contents, not Approval Responses. - return [ - FunctionApprovalRequestContent(id=fcc.call_id, function_call=fcc) - for fcc in function_calls - if isinstance(fcc, FunctionCallContent) - ] + return ( + [ + FunctionApprovalRequestContent(id=fcc.call_id, function_call=fcc) + for fcc in function_calls + if isinstance(fcc, FunctionCallContent) + ], + False, + ) if declaration_only_flag: # return the declaration only tools to the user, since we cannot execute them. - return [fcc for fcc in function_calls if isinstance(fcc, FunctionCallContent)] + return ([fcc for fcc in function_calls if isinstance(fcc, FunctionCallContent)], False) # Run all function calls concurrently - return await asyncio.gather(*[ + execution_results = await asyncio.gather(*[ _auto_invoke_function( function_call_content=function_call, # type: ignore[arg-type] custom_args=custom_args, @@ -1563,6 +1611,20 @@ async def _try_execute_function_calls( for seq_idx, function_call in enumerate(function_calls) ]) + # Unpack FunctionExecutionResult wrappers and check for terminate signal + contents: list[Contents] = [] + should_terminate = False + for result in execution_results: + if isinstance(result, FunctionExecutionResult): + contents.append(result.content) + if result.terminate: + should_terminate = True + else: + # Direct Contents (e.g., from hosted tools) + contents.append(result) + + return (contents, should_terminate) + def _update_conversation_id(kwargs: dict[str, Any], conversation_id: str | None) -> None: """Update kwargs with conversation id. @@ -1695,12 +1757,8 @@ async def function_invocation_wrapper( prepare_messages, ) - # Extract and merge function middleware from chat client with kwargs pipeline - extract_and_merge_function_middleware(self, **kwargs) - - # Extract the middleware pipeline before calling the underlying function - # because the underlying function may not preserve it in kwargs - stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline") + # Extract and merge function middleware from chat client with kwargs + stored_middleware_pipeline = extract_and_merge_function_middleware(self, kwargs) # Get the config for function invocation (not part of ChatClientProtocol, hence getattr) config: FunctionInvocationConfiguration | None = getattr(self, "function_invocation_configuration", None) @@ -1726,7 +1784,7 @@ async def function_invocation_wrapper( approved_responses = [resp for resp in fcc_todo.values() if resp.approved] approved_function_results: list[Contents] = [] if approved_responses: - approved_function_results = await _try_execute_function_calls( + results, _ = await _try_execute_function_calls( custom_args=kwargs, attempt_idx=attempt_idx, function_calls=approved_responses, @@ -1734,6 +1792,7 @@ async def function_invocation_wrapper( middleware_pipeline=stored_middleware_pipeline, config=config, ) + approved_function_results = list(results) if any( fcr.exception is not None for fcr in approved_function_results @@ -1773,7 +1832,7 @@ async def function_invocation_wrapper( if function_calls and tools: # Use the stored middleware pipeline instead of extracting from kwargs # because kwargs may have been modified by the underlying function - function_call_results: list[Contents] = await _try_execute_function_calls( + function_call_results, should_terminate = await _try_execute_function_calls( custom_args=kwargs, attempt_idx=attempt_idx, function_calls=function_calls, @@ -1798,6 +1857,17 @@ async def function_invocation_wrapper( # the function calls are already in the response, so we just continue return response + # Check if middleware signaled to terminate the loop (context.terminate=True) + # This allows middleware to short-circuit the tool loop without another LLM call + if should_terminate: + # Add tool results to response and return immediately without calling LLM again + result_message = ChatMessage(role="tool", contents=function_call_results) + response.messages.append(result_message) + if fcc_messages: + for msg in reversed(fcc_messages): + response.messages.insert(0, msg) + return response + if any( fcr.exception is not None for fcr in function_call_results @@ -1890,12 +1960,8 @@ async def streaming_function_invocation_wrapper( prepare_messages, ) - # Extract and merge function middleware from chat client with kwargs pipeline - extract_and_merge_function_middleware(self, **kwargs) - - # Extract the middleware pipeline before calling the underlying function - # because the underlying function may not preserve it in kwargs - stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline") + # Extract and merge function middleware from chat client with kwargs + stored_middleware_pipeline = extract_and_merge_function_middleware(self, kwargs) # Get the config for function invocation (not part of ChatClientProtocol, hence getattr) config: FunctionInvocationConfiguration | None = getattr(self, "function_invocation_configuration", None) @@ -1914,7 +1980,7 @@ async def streaming_function_invocation_wrapper( approved_responses = [resp for resp in fcc_todo.values() if resp.approved] approved_function_results: list[Contents] = [] if approved_responses: - approved_function_results = await _try_execute_function_calls( + results, _ = await _try_execute_function_calls( custom_args=kwargs, attempt_idx=attempt_idx, function_calls=approved_responses, @@ -1922,6 +1988,7 @@ async def streaming_function_invocation_wrapper( middleware_pipeline=stored_middleware_pipeline, config=config, ) + approved_function_results = list(results) if any( fcr.exception is not None for fcr in approved_function_results @@ -1976,7 +2043,7 @@ async def streaming_function_invocation_wrapper( if function_calls and tools: # Use the stored middleware pipeline instead of extracting from kwargs # because kwargs may have been modified by the underlying function - function_call_results: list[Contents] = await _try_execute_function_calls( + function_call_results, should_terminate = await _try_execute_function_calls( custom_args=kwargs, attempt_idx=attempt_idx, function_calls=function_calls, @@ -2005,6 +2072,13 @@ async def streaming_function_invocation_wrapper( # the function calls were already yielded. return + # Check if middleware signaled to terminate the loop (context.terminate=True) + # This allows middleware to short-circuit the tool loop without another LLM call + if should_terminate: + # Yield tool results and return immediately without calling LLM again + yield ChatResponseUpdate(contents=function_call_results, role="tool") + return + if any( fcr.exception is not None for fcr in function_call_results diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index 5a0ec5a773..bc96ddcc35 100644 --- a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -1,5 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. +from collections.abc import Awaitable, Callable + import pytest from agent_framework import ( @@ -16,6 +18,7 @@ TextContent, ai_function, ) +from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware async def test_base_client_with_function_calling(chat_client_base: ChatClientProtocol): @@ -2206,3 +2209,175 @@ def sometimes_fails(arg1: str) -> str: assert len(error_results) >= 1 assert len(success_results) >= 1 assert call_count == 2 # Both calls executed + + +class TerminateLoopMiddleware(FunctionMiddleware): + """Middleware that sets terminate=True to exit the function calling loop.""" + + async def process( + self, context: FunctionInvocationContext, next_handler: Callable[[FunctionInvocationContext], Awaitable[None]] + ) -> None: + # Set result to a simple value - the framework will wrap it in FunctionResultContent + context.result = "terminated by middleware" + context.terminate = True + + +async def test_terminate_loop_single_function_call(chat_client_base: ChatClientProtocol): + """Test that terminate_loop=True exits the function calling loop after single function call.""" + exec_counter = 0 + + @ai_function(name="test_function") + def ai_func(arg1: str) -> str: + nonlocal exec_counter + exec_counter += 1 + return f"Processed {arg1}" + + # Queue up two responses: function call, then final text + # If terminate_loop works, only the first response should be consumed + chat_client_base.run_responses = [ + ChatResponse( + messages=ChatMessage( + role="assistant", + contents=[FunctionCallContent(call_id="1", name="test_function", arguments='{"arg1": "value1"}')], + ) + ), + ChatResponse(messages=ChatMessage(role="assistant", text="done")), + ] + + response = await chat_client_base.get_response( + "hello", + tool_choice="auto", + tools=[ai_func], + middleware=[TerminateLoopMiddleware()], + ) + + # Function should NOT have been executed - middleware intercepted it + assert exec_counter == 0 + + # There should be 2 messages: assistant with function call, tool result from middleware + # The loop should NOT have continued to call the LLM again + assert len(response.messages) == 2 + assert response.messages[0].role == Role.ASSISTANT + assert isinstance(response.messages[0].contents[0], FunctionCallContent) + assert response.messages[1].role == Role.TOOL + assert isinstance(response.messages[1].contents[0], FunctionResultContent) + assert response.messages[1].contents[0].result == "terminated by middleware" + + # Verify the second response is still in the queue (wasn't consumed) + assert len(chat_client_base.run_responses) == 1 + + +class SelectiveTerminateMiddleware(FunctionMiddleware): + """Only terminates for terminating_function.""" + + async def process( + self, context: FunctionInvocationContext, next_handler: Callable[[FunctionInvocationContext], Awaitable[None]] + ) -> None: + if context.function.name == "terminating_function": + # Set result to a simple value - the framework will wrap it in FunctionResultContent + context.result = "terminated by middleware" + context.terminate = True + else: + await next_handler(context) + + +async def test_terminate_loop_multiple_function_calls_one_terminates(chat_client_base: ChatClientProtocol): + """Test that any(terminate_loop=True) exits loop even with multiple function calls.""" + normal_call_count = 0 + terminating_call_count = 0 + + @ai_function(name="normal_function") + def normal_func(arg1: str) -> str: + nonlocal normal_call_count + normal_call_count += 1 + return f"Normal {arg1}" + + @ai_function(name="terminating_function") + def terminating_func(arg1: str) -> str: + nonlocal terminating_call_count + terminating_call_count += 1 + return f"Terminating {arg1}" + + # Queue up two responses: parallel function calls, then final text + chat_client_base.run_responses = [ + ChatResponse( + messages=ChatMessage( + role="assistant", + contents=[ + FunctionCallContent(call_id="1", name="normal_function", arguments='{"arg1": "value1"}'), + FunctionCallContent(call_id="2", name="terminating_function", arguments='{"arg1": "value2"}'), + ], + ) + ), + ChatResponse(messages=ChatMessage(role="assistant", text="done")), + ] + + response = await chat_client_base.get_response( + "hello", + tool_choice="auto", + tools=[normal_func, terminating_func], + middleware=[SelectiveTerminateMiddleware()], + ) + + # normal_function should have executed (middleware calls next_handler) + # terminating_function should NOT have executed (middleware intercepts it) + assert normal_call_count == 1 + assert terminating_call_count == 0 + + # There should be 2 messages: assistant with function calls, tool results + # The loop should NOT have continued to call the LLM again + assert len(response.messages) == 2 + assert response.messages[0].role == Role.ASSISTANT + assert len(response.messages[0].contents) == 2 + assert response.messages[1].role == Role.TOOL + # Both function results should be present + assert len(response.messages[1].contents) == 2 + + # Verify the second response is still in the queue (wasn't consumed) + assert len(chat_client_base.run_responses) == 1 + + +async def test_terminate_loop_streaming_single_function_call(chat_client_base: ChatClientProtocol): + """Test that terminate_loop=True exits the streaming function calling loop.""" + exec_counter = 0 + + @ai_function(name="test_function") + def ai_func(arg1: str) -> str: + nonlocal exec_counter + exec_counter += 1 + return f"Processed {arg1}" + + # Queue up two streaming responses + chat_client_base.streaming_responses = [ + [ + ChatResponseUpdate( + contents=[FunctionCallContent(call_id="1", name="test_function", arguments='{"arg1": "value1"}')], + role="assistant", + ), + ], + [ + ChatResponseUpdate( + contents=[TextContent(text="done")], + role="assistant", + ) + ], + ] + + updates = [] + async for update in chat_client_base.get_streaming_response( + "hello", + tool_choice="auto", + tools=[ai_func], + middleware=[TerminateLoopMiddleware()], + ): + updates.append(update) + + # Function should NOT have been executed - middleware intercepted it + assert exec_counter == 0 + + # Should have function call update and function result update + # The loop should NOT have continued to call the LLM again + assert len(updates) == 2 + + # Verify the second streaming response is still in the queue (wasn't consumed) + assert len(chat_client_base.streaming_responses) == 1 diff --git a/python/packages/core/tests/core/test_middleware_with_agent.py b/python/packages/core/tests/core/test_middleware_with_agent.py index 7b280da123..6cb41f674b 100644 --- a/python/packages/core/tests/core/test_middleware_with_agent.py +++ b/python/packages/core/tests/core/test_middleware_with_agent.py @@ -193,7 +193,8 @@ async def process( # Create a message to start the conversation messages = [ChatMessage(role=Role.USER, text="test message")] - # Set up chat client to return a function call + # Set up chat client to return a function call, then a final response + # If terminate works correctly, only the first response should be consumed chat_client.responses = [ ChatResponse( messages=[ @@ -204,7 +205,8 @@ async def process( ], ) ] - ) + ), + ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="this should not be consumed")]), ] # Create the test function with the expected signature @@ -222,7 +224,11 @@ def test_function(text: str) -> str: # Verify that function was not called and only middleware executed assert execution_order == ["middleware_before", "middleware_after"] assert "function_called" not in execution_order - assert execution_order == ["middleware_before", "middleware_after"] + + # Verify the chat client was only called once (no extra LLM call after termination) + assert chat_client.call_count == 1 + # Verify the second response is still in the queue (wasn't consumed) + assert len(chat_client.responses) == 1 async def test_function_middleware_with_post_termination(self, chat_client: "MockChatClient") -> None: """Test that function middleware can terminate execution after calling next().""" @@ -242,7 +248,8 @@ async def process( # Create a message to start the conversation messages = [ChatMessage(role=Role.USER, text="test message")] - # Set up chat client to return a function call + # Set up chat client to return a function call, then a final response + # If terminate works correctly, only the first response should be consumed chat_client.responses = [ ChatResponse( messages=[ @@ -253,7 +260,8 @@ async def process( ], ) ] - ) + ), + ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="this should not be consumed")]), ] # Create the test function with the expected signature @@ -273,6 +281,11 @@ def test_function(text: str) -> str: assert "function_called" in execution_order assert execution_order == ["middleware_before", "function_called", "middleware_after"] + # Verify the chat client was only called once (no extra LLM call after termination) + assert chat_client.call_count == 1 + # Verify the second response is still in the queue (wasn't consumed) + assert len(chat_client.responses) == 1 + async def test_function_based_agent_middleware_with_chat_agent(self, chat_client: "MockChatClient") -> None: """Test function-based agent middleware with ChatAgent.""" execution_order: list[str] = []