diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs index 0a8673dc91d..9c0506a2307 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -205,6 +205,16 @@ public int MaximumConsecutiveErrorsPerRequest set => _maximumConsecutiveErrorsPerRequest = Throw.IfLessThan(value, 0); } + /// Gets or sets a collection of additional tools the client is able to invoke. + /// + /// These will not impact the requests sent by the , which will pass through the + /// unmodified. However, if the inner client requests the invocation of a tool + /// that was not in , this collection will also be consulted + /// to look for a corresponding tool to invoke. This is useful when the service may have been pre-configured to be aware + /// of certain tools that aren't also sent on each individual request. + /// + public IList? AdditionalTools { get; set; } + /// Gets or sets a delegate used to invoke instances. /// /// By default, the protected method is called for each to be invoked, @@ -250,7 +260,7 @@ public override async Task GetResponseAsync( // Any function call work to do? If yes, ensure we're tracking that work in functionCallContents. bool requiresFunctionInvocation = - options?.Tools is { Count: > 0 } && + (options?.Tools is { Count: > 0 } || AdditionalTools is { Count: > 0 }) && iteration < MaximumIterationsPerRequest && CopyFunctionCalls(response.Messages, ref functionCallContents); @@ -288,7 +298,7 @@ public override async Task GetResponseAsync( // Add the responses from the function calls into the augmented history and also into the tracked // list of response messages. - var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, consecutiveErrorCount, isStreaming: false, cancellationToken); + var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents!, iteration, consecutiveErrorCount, isStreaming: false, cancellationToken); responseMessages.AddRange(modeAndMessages.MessagesAdded); consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; @@ -297,7 +307,7 @@ public override async Task GetResponseAsync( break; } - UpdateOptionsForNextIteration(ref options!, response.ConversationId); + UpdateOptionsForNextIteration(ref options, response.ConversationId); } Debug.Assert(responseMessages is not null, "Expected to only be here if we have response messages."); @@ -367,7 +377,7 @@ public override async IAsyncEnumerable GetStreamingResponseA // If there are no tools to call, or for any other reason we should stop, return the response. if (functionCallContents is not { Count: > 0 } || - options?.Tools is not { Count: > 0 } || + (options?.Tools is not { Count: > 0 } && AdditionalTools is not { Count: > 0 }) || iteration >= _maximumIterationsPerRequest) { break; @@ -535,9 +545,16 @@ private static bool CopyFunctionCalls( return any; } - private static void UpdateOptionsForNextIteration(ref ChatOptions options, string? conversationId) + private static void UpdateOptionsForNextIteration(ref ChatOptions? options, string? conversationId) { - if (options.ToolMode is RequiredChatToolMode) + if (options is null) + { + if (conversationId is not null) + { + options = new() { ConversationId = conversationId }; + } + } + else if (options.ToolMode is RequiredChatToolMode) { // We have to reset the tool mode to be non-required after the first iteration, // as otherwise we'll be in an infinite loop. @@ -566,7 +583,7 @@ private static void UpdateOptionsForNextIteration(ref ChatOptions options, strin /// The to monitor for cancellation requests. /// A value indicating how the caller should proceed. private async Task<(bool ShouldTerminate, int NewConsecutiveErrorCount, IList MessagesAdded)> ProcessFunctionCallsAsync( - List messages, ChatOptions options, List functionCallContents, int iteration, int consecutiveErrorCount, + List messages, ChatOptions? options, List functionCallContents, int iteration, int consecutiveErrorCount, bool isStreaming, CancellationToken cancellationToken) { // We must add a response for every tool call, regardless of whether we successfully executed it or not. @@ -695,13 +712,13 @@ private void ThrowIfNoFunctionResultsAdded(IList? messages) /// The to monitor for cancellation requests. /// A value indicating how the caller should proceed. private async Task ProcessFunctionCallAsync( - List messages, ChatOptions options, List callContents, + List messages, ChatOptions? options, List callContents, int iteration, int functionCallIndex, bool captureExceptions, bool isStreaming, CancellationToken cancellationToken) { var callContent = callContents[functionCallIndex]; // Look up the AIFunction for the function call. If the requested function isn't available, send back an error. - AIFunction? aiFunction = options.Tools!.OfType().FirstOrDefault(t => t.Name == callContent.Name); + AIFunction? aiFunction = FindAIFunction(options?.Tools, callContent.Name) ?? FindAIFunction(AdditionalTools, callContent.Name); if (aiFunction is null) { return new(terminate: false, FunctionInvocationStatus.NotFound, callContent, result: null, exception: null); @@ -746,6 +763,23 @@ private async Task ProcessFunctionCallAsync( callContent, result, exception: null); + + static AIFunction? FindAIFunction(IList? tools, string functionName) + { + if (tools is not null) + { + int count = tools.Count; + for (int i = 0; i < count; i++) + { + if (tools[i] is AIFunction function && function.Name == functionName) + { + return function; + } + } + } + + return null; + } } /// Creates one or more response messages for function invocation results. diff --git a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.json b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.json index 59ed3d32fab..3e3f0426dd1 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.json +++ b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.json @@ -515,6 +515,10 @@ } ], "Properties": [ + { + "Member": "System.Collections.Generic.IList? Microsoft.Extensions.AI.FunctionInvokingChatClient.AdditionalTools { get; set; }", + "Stage": "Stable" + }, { "Member": "bool Microsoft.Extensions.AI.FunctionInvokingChatClient.AllowConcurrentInvocation { get; set; }", "Stage": "Stable" diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs index b4ce2f1546c..08cb5ee5760 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs @@ -39,6 +39,7 @@ public void Ctor_HasExpectedDefaults() Assert.Equal(40, client.MaximumIterationsPerRequest); Assert.Equal(3, client.MaximumConsecutiveErrorsPerRequest); Assert.Null(client.FunctionInvoker); + Assert.Null(client.AdditionalTools); } [Fact] @@ -67,6 +68,11 @@ public void Properties_Roundtrip() Func> invoker = (ctx, ct) => new ValueTask("test"); client.FunctionInvoker = invoker; Assert.Same(invoker, client.FunctionInvoker); + + Assert.Null(client.AdditionalTools); + IList additionalTools = [AIFunctionFactory.Create(() => "Additional Tool")]; + client.AdditionalTools = additionalTools; + Assert.Same(additionalTools, client.AdditionalTools); } [Fact] @@ -99,6 +105,73 @@ public async Task SupportsSingleFunctionCallPerRequestAsync() await InvokeAndAssertStreamingAsync(options, plan); } + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task SupportsToolsProvidedByAdditionalTools(bool provideOptions) + { + ChatOptions? options = provideOptions ? + new() { Tools = [AIFunctionFactory.Create(() => "Shouldn't be invoked", "ChatOptionsFunc")] } : + null; + + Func configure = builder => + builder.UseFunctionInvocation(configure: c => c.AdditionalTools = + [ + AIFunctionFactory.Create(() => "Result 1", "Func1"), + AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"), + AIFunctionFactory.Create((int i) => { }, "VoidReturn"), + ]); + + List plan = + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 2: 42")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", result: "Success: Function completed.")]), + new ChatMessage(ChatRole.Assistant, "world"), + ]; + + await InvokeAndAssertAsync(options, plan, configurePipeline: configure); + + await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure); + } + + [Fact] + public async Task PrefersToolsProvidedByChatOptions() + { + ChatOptions options = new() + { + Tools = [AIFunctionFactory.Create(() => "Result 1", "Func1")] + }; + + Func configure = builder => + builder.UseFunctionInvocation(configure: c => c.AdditionalTools = + [ + AIFunctionFactory.Create(() => "Should never be invoked", "Func1"), + AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"), + AIFunctionFactory.Create((int i) => { }, "VoidReturn"), + ]); + + List plan = + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 2: 42")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", result: "Success: Function completed.")]), + new ChatMessage(ChatRole.Assistant, "world"), + ]; + + await InvokeAndAssertAsync(options, plan, configurePipeline: configure); + + await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure); + } + [Theory] [InlineData(false)] [InlineData(true)] @@ -1002,7 +1075,7 @@ public override void Post(SendOrPostCallback d, object? state) } private static async Task> InvokeAndAssertAsync( - ChatOptions options, + ChatOptions? options, List plan, List? expected = null, Func? configurePipeline = null, @@ -1102,7 +1175,7 @@ private static UsageDetails CreateRandomUsage() } private static async Task> InvokeAndAssertStreamingAsync( - ChatOptions options, + ChatOptions? options, List plan, List? expected = null, Func? configurePipeline = null, diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Microsoft.Extensions.AI.Tests.csproj b/test/Libraries/Microsoft.Extensions.AI.Tests/Microsoft.Extensions.AI.Tests.csproj index e4f17abb179..c07f3056054 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/Microsoft.Extensions.AI.Tests.csproj +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Microsoft.Extensions.AI.Tests.csproj @@ -5,7 +5,7 @@ - $(NoWarn);CA1063;CA1861;SA1130;VSTHRD003 + $(NoWarn);CA1063;CA1861;S104;SA1130;VSTHRD003 $(NoWarn);MEAI001 true