From ded60733d180385d2c8d9f17f648b53d0d914311 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Thu, 9 Oct 2025 12:21:05 +0100 Subject: [PATCH] Fix bug to yield remaining buffered FCC --- .../FunctionInvokingChatClient.cs | 10 +- ...unctionInvokingChatClientApprovalsTests.cs | 115 ++++++++++++++++-- 2 files changed, 116 insertions(+), 9 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs index fdc5ef7a204..febf0a1336e 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -529,7 +529,7 @@ public override async IAsyncEnumerable GetStreamingResponseA .ToArray(); } - if (approvalRequiredFunctions is not { Length: > 0 }) + if (approvalRequiredFunctions is not { Length: > 0 } || functionCallContents is not { Count: > 0 }) { // If there are no function calls to make yet, or if none of the functions require approval at all, // we can yield the update as-is. @@ -574,6 +574,14 @@ public override async IAsyncEnumerable GetStreamingResponseA // or when we reach the end of the updates stream. } + // We need to yield any remaining updates that were not yielded while looping through the streamed updates. + for (; lastYieldedUpdateIndex < updates.Count; lastYieldedUpdateIndex++) + { + var updateToYield = updates[lastYieldedUpdateIndex]; + yield return updateToYield; + Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 + } + // If there's nothing more to do, break out of the loop and allow the handling at the // end to configure the response with aggregated data from previous requests. if (iteration >= MaximumIterationsPerRequest || diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientApprovalsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientApprovalsTests.cs index a0836a864c4..7c42c0edaf9 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientApprovalsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientApprovalsTests.cs @@ -485,6 +485,69 @@ public async Task AlreadyExecutedApprovalsAreIgnoredAsync() await InvokeAndAssertStreamingAsync(options, input, downstreamClientOutput, output, expectedDownstreamClientInput); } + /// + /// This verifies the following scenario: + /// 1. We are streaming (also including non-streaming in the test for completeness). + /// 2. There is one function that requires approval and one that does not. + /// 3. We only get back FCC for the function that does not require approval. + /// 4. This means that once we receive this FCC, we need to buffer all updates until the end, because we might receive more FCCs and some may require approval. + /// 5. We then need to verify that we will still stream all updates once we reach the end, including the buffered FCC. + /// + [Fact] + public async Task MixedApprovalRequiredToolsWithNonApprovalRequiringFunctionCallAsync() + { + var options = new ChatOptions + { + Tools = + [ + new ApprovalRequiredAIFunction(AIFunctionFactory.Create(() => "Result 1", "Func1")), + AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"), + ] + }; + + List input = + [ + new ChatMessage(ChatRole.User, "hello"), + ]; + + Func>> expectedDownstreamClientInput = () => new Queue>( + [ + new List + { + new ChatMessage(ChatRole.User, "hello"), + }, + new List + { + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 2: 42")]) + } + ]); + + Func>> downstreamClientOutput = () => new Queue>( + [ + new List + { + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), + }, + new List + { + new ChatMessage(ChatRole.Assistant, "World again"), + } + ]); + + List output = + [ + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 2: 42")]), + new ChatMessage(ChatRole.Assistant, "World again"), + ]; + + await InvokeAndAssertMultiRoundAsync(options, input, downstreamClientOutput(), output, expectedDownstreamClientInput()); + + await InvokeAndAssertStreamingMultiRoundAsync(options, input, downstreamClientOutput(), output, expectedDownstreamClientInput()); + } + [Fact] public async Task ApprovalRequestWithoutApprovalResponseThrowsAsync() { @@ -781,7 +844,7 @@ async IAsyncEnumerable YieldInnerClientUpdates( } } - private static async Task> InvokeAndAssertAsync( + private static Task> InvokeAndAssertAsync( ChatOptions? options, List input, List downstreamClientOutput, @@ -789,6 +852,23 @@ private static async Task> InvokeAndAssertAsync( List? expectedDownstreamClientInput = null, Func? configurePipeline = null, AITool[]? additionalTools = null) + => InvokeAndAssertMultiRoundAsync( + options, + input, + new Queue>(new[] { downstreamClientOutput }), + expectedOutput, + expectedDownstreamClientInput is null ? null : new Queue>(new[] { expectedDownstreamClientInput }), + configurePipeline, + additionalTools); + + private static async Task> InvokeAndAssertMultiRoundAsync( + ChatOptions? options, + List input, + Queue> downstreamClientOutput, + List expectedOutput, + Queue>? expectedDownstreamClientInput = null, + Func? configurePipeline = null, + AITool[]? additionalTools = null) { Assert.NotEmpty(input); @@ -804,7 +884,7 @@ private static async Task> InvokeAndAssertAsync( Assert.Equal(cts.Token, actualCancellationToken); if (expectedDownstreamClientInput is not null) { - AssertExtensions.EqualMessageLists(expectedDownstreamClientInput, contents.ToList()); + AssertExtensions.EqualMessageLists(expectedDownstreamClientInput.Dequeue(), contents.ToList()); } await Task.Yield(); @@ -812,8 +892,9 @@ private static async Task> InvokeAndAssertAsync( var usage = CreateRandomUsage(); expectedTotalTokenCounts += usage.InputTokenCount!.Value; - downstreamClientOutput.ForEach(m => m.MessageId = Guid.NewGuid().ToString("N")); - return new ChatResponse(downstreamClientOutput) { Usage = usage }; + var output = downstreamClientOutput.Dequeue(); + output.ForEach(m => m.MessageId = Guid.NewGuid().ToString("N")); + return new ChatResponse(output) { Usage = usage }; } }; @@ -851,7 +932,7 @@ private static UsageDetails CreateRandomUsage() }; } - private static async Task> InvokeAndAssertStreamingAsync( + private static Task> InvokeAndAssertStreamingAsync( ChatOptions? options, List input, List downstreamClientOutput, @@ -859,6 +940,23 @@ private static async Task> InvokeAndAssertStreamingAsync( List? expectedDownstreamClientInput = null, Func? configurePipeline = null, AITool[]? additionalTools = null) + => InvokeAndAssertStreamingMultiRoundAsync( + options, + input, + new Queue>(new[] { downstreamClientOutput }), + expectedOutput, + expectedDownstreamClientInput is null ? null : new Queue>(new[] { expectedDownstreamClientInput }), + configurePipeline, + additionalTools); + + private static async Task> InvokeAndAssertStreamingMultiRoundAsync( + ChatOptions? options, + List input, + Queue> downstreamClientOutput, + List expectedOutput, + Queue>? expectedDownstreamClientInput = null, + Func? configurePipeline = null, + AITool[]? additionalTools = null) { Assert.NotEmpty(input); @@ -873,11 +971,12 @@ private static async Task> InvokeAndAssertStreamingAsync( Assert.Equal(cts.Token, actualCancellationToken); if (expectedDownstreamClientInput is not null) { - AssertExtensions.EqualMessageLists(expectedDownstreamClientInput, contents.ToList()); + AssertExtensions.EqualMessageLists(expectedDownstreamClientInput.Dequeue(), contents.ToList()); } - downstreamClientOutput.ForEach(m => m.MessageId = Guid.NewGuid().ToString("N")); - return YieldAsync(new ChatResponse(downstreamClientOutput).ToChatResponseUpdates()); + var output = downstreamClientOutput.Dequeue(); + output.ForEach(m => m.MessageId = Guid.NewGuid().ToString("N")); + return YieldAsync(new ChatResponse(output).ToChatResponseUpdates()); } };