diff --git a/eng/packages/General.props b/eng/packages/General.props
index b66f1a4ffa8..441a30afa73 100644
--- a/eng/packages/General.props
+++ b/eng/packages/General.props
@@ -29,6 +29,7 @@
+
diff --git a/eng/packages/TestOnly.props b/eng/packages/TestOnly.props
index 5875491f919..603e31c0e5b 100644
--- a/eng/packages/TestOnly.props
+++ b/eng/packages/TestOnly.props
@@ -21,7 +21,6 @@
-
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ToolReduction/IToolReductionStrategy.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ToolReduction/IToolReductionStrategy.cs
new file mode 100644
index 00000000000..029eeae47a1
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ToolReduction/IToolReductionStrategy.cs
@@ -0,0 +1,41 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Microsoft.Extensions.AI;
+
+///
+/// Represents a strategy capable of selecting a reduced set of tools for a chat request.
+///
+///
+/// A tool reduction strategy is invoked prior to sending a request to an underlying ,
+/// enabling scenarios where a large tool catalog must be trimmed to fit provider limits or to improve model
+/// tool selection quality.
+///
+/// The implementation should return a non- enumerable. Returning the original
+/// instance indicates no change. Returning a different enumerable indicates
+/// the caller may replace the existing tool list.
+///
+///
+[Experimental("MEAI001")]
+public interface IToolReductionStrategy
+{
+ ///
+ /// Selects the tools that should be included for a specific request.
+ ///
+ /// The chat messages for the request. This is an to avoid premature materialization.
+ /// The chat options for the request (may be ).
+ /// A token to observe cancellation.
+ ///
+ /// A (possibly reduced) enumerable of instances. Must never be .
+ /// Returning the same instance referenced by . signals no change.
+ ///
+ Task> SelectToolsForRequestAsync(
+ IEnumerable messages,
+ ChatOptions? options,
+ CancellationToken cancellationToken = default);
+}
diff --git a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj
index 36e6bb00562..54cbcc99754 100644
--- a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj
+++ b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj
@@ -44,6 +44,7 @@
+
diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ChatClientBuilderToolReductionExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ChatClientBuilderToolReductionExtensions.cs
new file mode 100644
index 00000000000..5a644267328
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ChatClientBuilderToolReductionExtensions.cs
@@ -0,0 +1,32 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Diagnostics.CodeAnalysis;
+using Microsoft.Shared.Diagnostics;
+
+namespace Microsoft.Extensions.AI;
+
+/// Extension methods for adding tool reduction middleware to a chat client pipeline.
+[Experimental("MEAI001")]
+public static class ChatClientBuilderToolReductionExtensions
+{
+ ///
+ /// Adds tool reduction to the chat client pipeline using the specified .
+ ///
+ /// The chat client builder.
+ /// The reduction strategy.
+ /// The original builder for chaining.
+ /// If or is .
+ ///
+ /// This should typically appear in the pipeline before function invocation middleware so that only the reduced tools
+ /// are exposed to the underlying provider.
+ ///
+ public static ChatClientBuilder UseToolReduction(this ChatClientBuilder builder, IToolReductionStrategy strategy)
+ {
+ _ = Throw.IfNull(builder);
+ _ = Throw.IfNull(strategy);
+
+ return builder.Use(inner => new ToolReducingChatClient(inner, strategy));
+ }
+}
diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs
new file mode 100644
index 00000000000..f9e4c60995a
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs
@@ -0,0 +1,330 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Numerics.Tensors;
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Shared.Diagnostics;
+
+namespace Microsoft.Extensions.AI;
+
+#pragma warning disable IDE0032 // Use auto property, suppressed until repo updates to C# 14
+
+///
+/// A tool reduction strategy that ranks tools by embedding similarity to the current conversation context.
+///
+///
+/// The strategy embeds each tool (name + description by default) once (cached) and embeds the current
+/// conversation content each request. It then selects the top toolLimit tools by similarity.
+///
+[Experimental("MEAI001")]
+public sealed class EmbeddingToolReductionStrategy : IToolReductionStrategy
+{
+ private readonly ConditionalWeakTable> _toolEmbeddingsCache = new();
+ private readonly IEmbeddingGenerator> _embeddingGenerator;
+ private readonly int _toolLimit;
+
+ private Func _toolEmbeddingTextSelector = static t =>
+ {
+ if (string.IsNullOrWhiteSpace(t.Name))
+ {
+ return t.Description;
+ }
+
+ if (string.IsNullOrWhiteSpace(t.Description))
+ {
+ return t.Name;
+ }
+
+ return t.Name + Environment.NewLine + t.Description;
+ };
+
+ private Func, ValueTask> _messagesEmbeddingTextSelector = static messages =>
+ {
+ var sb = new StringBuilder();
+ foreach (var message in messages)
+ {
+ var contents = message.Contents;
+ for (var i = 0; i < contents.Count; i++)
+ {
+ string text;
+ switch (contents[i])
+ {
+ case TextContent content:
+ text = content.Text;
+ break;
+ case TextReasoningContent content:
+ text = content.Text;
+ break;
+ default:
+ continue;
+ }
+
+ _ = sb.AppendLine(text);
+ }
+ }
+
+ return new ValueTask(sb.ToString());
+ };
+
+ private Func, ReadOnlyMemory, float> _similarity = static (a, b) => TensorPrimitives.CosineSimilarity(a.Span, b.Span);
+
+ private Func _isRequiredTool = static _ => false;
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// Embedding generator used to produce embeddings.
+ /// Maximum number of tools to return, excluding required tools. Must be greater than zero.
+ public EmbeddingToolReductionStrategy(
+ IEmbeddingGenerator> embeddingGenerator,
+ int toolLimit)
+ {
+ _embeddingGenerator = Throw.IfNull(embeddingGenerator);
+ _toolLimit = Throw.IfLessThanOrEqual(toolLimit, min: 0);
+ }
+
+ ///
+ /// Gets or sets the selector used to generate a single text string from a tool.
+ ///
+ ///
+ /// Defaults to: Name + "\n" + Description (omitting empty parts).
+ ///
+ public Func ToolEmbeddingTextSelector
+ {
+ get => _toolEmbeddingTextSelector;
+ set => _toolEmbeddingTextSelector = Throw.IfNull(value);
+ }
+
+ ///
+ /// Gets or sets the selector used to generate a single text string from a collection of chat messages for
+ /// embedding purposes.
+ ///
+ public Func, ValueTask> MessagesEmbeddingTextSelector
+ {
+ get => _messagesEmbeddingTextSelector;
+ set => _messagesEmbeddingTextSelector = Throw.IfNull(value);
+ }
+
+ ///
+ /// Gets or sets a similarity function applied to (query, tool) embedding vectors.
+ ///
+ ///
+ /// Defaults to cosine similarity.
+ ///
+ public Func, ReadOnlyMemory, float> Similarity
+ {
+ get => _similarity;
+ set => _similarity = Throw.IfNull(value);
+ }
+
+ ///
+ /// Gets or sets a function that determines whether a tool is required (always included).
+ ///
+ ///
+ /// If this returns , the tool is included regardless of ranking and does not count against
+ /// the configured non-required tool limit. A tool explicitly named by (when
+ /// is non-null) is also treated as required, independent
+ /// of this delegate's result.
+ ///
+ public Func IsRequiredTool
+ {
+ get => _isRequiredTool;
+ set => _isRequiredTool = Throw.IfNull(value);
+ }
+
+ ///
+ /// Gets or sets a value indicating whether to preserve original ordering of selected tools.
+ /// If (default), tools are ordered by descending similarity.
+ /// If , the top-N tools by similarity are re-emitted in their original order.
+ ///
+ public bool PreserveOriginalOrdering { get; set; }
+
+ ///
+ public async Task> SelectToolsForRequestAsync(
+ IEnumerable messages,
+ ChatOptions? options,
+ CancellationToken cancellationToken = default)
+ {
+ _ = Throw.IfNull(messages);
+
+ if (options?.Tools is not { Count: > 0 } tools)
+ {
+ // Prefer the original tools list reference if possible.
+ // This allows ToolReducingChatClient to avoid unnecessarily copying ChatOptions.
+ // When no reduction is performed.
+ return options?.Tools ?? [];
+ }
+
+ Debug.Assert(_toolLimit > 0, "Expected the tool count limit to be greater than zero.");
+
+ if (tools.Count <= _toolLimit)
+ {
+ // Since the total number of tools doesn't exceed the configured tool limit,
+ // there's no need to determine which tools are optional, i.e., subject to reduction.
+ // We can return the original tools list early.
+ return tools;
+ }
+
+ var toolRankingInfoArray = ArrayPool.Shared.Rent(tools.Count);
+ try
+ {
+ var toolRankingInfoMemory = toolRankingInfoArray.AsMemory(start: 0, length: tools.Count);
+
+ // We allocate tool rankings in a contiguous chunk of memory, but partition them such that
+ // required tools come first and are immediately followed by optional tools.
+ // This allows us to separately rank optional tools by similarity score, but then later re-order
+ // the top N tools (including required tools) to preserve their original relative order.
+ var (requiredTools, optionalTools) = PartitionToolRankings(toolRankingInfoMemory, tools, options.ToolMode);
+
+ if (optionalTools.Length <= _toolLimit)
+ {
+ // There aren't enough optional tools to require reduction, so we'll return the original
+ // tools list.
+ return tools;
+ }
+
+ // Build query text from recent messages.
+ var queryText = await MessagesEmbeddingTextSelector(messages).ConfigureAwait(false);
+ if (string.IsNullOrWhiteSpace(queryText))
+ {
+ // We couldn't build a meaningful query, likely because the message list was empty.
+ // We'll just return the original tools list.
+ return tools;
+ }
+
+ var queryEmbedding = await _embeddingGenerator.GenerateAsync(queryText, cancellationToken: cancellationToken).ConfigureAwait(false);
+
+ // Compute and populate similarity scores in the tool ranking info.
+ await ComputeSimilarityScoresAsync(optionalTools, queryEmbedding, cancellationToken);
+
+ var topTools = toolRankingInfoMemory.Slice(start: 0, length: requiredTools.Length + _toolLimit);
+#if NET
+ optionalTools.Span.Sort(AIToolRankingInfo.CompareByDescendingSimilarityScore);
+ if (PreserveOriginalOrdering)
+ {
+ topTools.Span.Sort(AIToolRankingInfo.CompareByOriginalIndex);
+ }
+#else
+ Array.Sort(toolRankingInfoArray, index: requiredTools.Length, length: optionalTools.Length, AIToolRankingInfo.CompareByDescendingSimilarityScore);
+ if (PreserveOriginalOrdering)
+ {
+ Array.Sort(toolRankingInfoArray, index: 0, length: topTools.Length, AIToolRankingInfo.CompareByOriginalIndex);
+ }
+#endif
+ return ToToolList(topTools.Span);
+
+ static List ToToolList(ReadOnlySpan toolInfo)
+ {
+ var result = new List(capacity: toolInfo.Length);
+ foreach (var info in toolInfo)
+ {
+ result.Add(info.Tool);
+ }
+
+ return result;
+ }
+ }
+ finally
+ {
+ ArrayPool.Shared.Return(toolRankingInfoArray);
+ }
+ }
+
+ private (Memory RequiredTools, Memory OptionalTools) PartitionToolRankings(
+ Memory toolRankingInfo, IList tools, ChatToolMode? toolMode)
+ {
+ // Always include a tool if its name matches the required function name.
+ var requiredFunctionName = (toolMode as RequiredChatToolMode)?.RequiredFunctionName;
+ var nextRequiredToolIndex = 0;
+ var nextOptionalToolIndex = tools.Count - 1;
+ for (var i = 0; i < toolRankingInfo.Length; i++)
+ {
+ var tool = tools[i];
+ var isRequiredByToolMode = requiredFunctionName is not null && string.Equals(requiredFunctionName, tool.Name, StringComparison.Ordinal);
+ var toolIndex = isRequiredByToolMode || IsRequiredTool(tool)
+ ? nextRequiredToolIndex++
+ : nextOptionalToolIndex--;
+ toolRankingInfo.Span[toolIndex] = new AIToolRankingInfo(tool, originalIndex: i);
+ }
+
+ return (
+ RequiredTools: toolRankingInfo.Slice(0, nextRequiredToolIndex),
+ OptionalTools: toolRankingInfo.Slice(nextRequiredToolIndex));
+ }
+
+ private async Task ComputeSimilarityScoresAsync(Memory toolInfo, Embedding queryEmbedding, CancellationToken cancellationToken)
+ {
+ var anyCacheMisses = false;
+ List cacheMissToolEmbeddingTexts = null!;
+ List cacheMissToolInfoIndexes = null!;
+ for (var i = 0; i < toolInfo.Length; i++)
+ {
+ ref var info = ref toolInfo.Span[i];
+ if (_toolEmbeddingsCache.TryGetValue(info.Tool, out var toolEmbedding))
+ {
+ info.SimilarityScore = Similarity(queryEmbedding.Vector, toolEmbedding.Vector);
+ }
+ else
+ {
+ if (!anyCacheMisses)
+ {
+ anyCacheMisses = true;
+ cacheMissToolEmbeddingTexts = [];
+ cacheMissToolInfoIndexes = [];
+ }
+
+ var text = ToolEmbeddingTextSelector(info.Tool);
+ cacheMissToolEmbeddingTexts.Add(text);
+ cacheMissToolInfoIndexes.Add(i);
+ }
+ }
+
+ if (!anyCacheMisses)
+ {
+ // There were no cache misses; no more work to do.
+ return;
+ }
+
+ var uncachedEmbeddings = await _embeddingGenerator.GenerateAsync(cacheMissToolEmbeddingTexts, cancellationToken: cancellationToken).ConfigureAwait(false);
+ if (uncachedEmbeddings.Count != cacheMissToolEmbeddingTexts.Count)
+ {
+ throw new InvalidOperationException($"Expected {cacheMissToolEmbeddingTexts.Count} embeddings, got {uncachedEmbeddings.Count}.");
+ }
+
+ for (var i = 0; i < uncachedEmbeddings.Count; i++)
+ {
+ var toolInfoIndex = cacheMissToolInfoIndexes[i];
+ var toolEmbedding = uncachedEmbeddings[i];
+ ref var info = ref toolInfo.Span[toolInfoIndex];
+ info.SimilarityScore = Similarity(queryEmbedding.Vector, toolEmbedding.Vector);
+ _toolEmbeddingsCache.Add(info.Tool, toolEmbedding);
+ }
+ }
+
+ private struct AIToolRankingInfo(AITool tool, int originalIndex)
+ {
+ public static readonly Comparer CompareByDescendingSimilarityScore
+ = Comparer.Create(static (a, b) =>
+ {
+ var result = b.SimilarityScore.CompareTo(a.SimilarityScore);
+ return result != 0
+ ? result
+ : a.OriginalIndex.CompareTo(b.OriginalIndex); // Stabilize ties.
+ });
+
+ public static readonly Comparer CompareByOriginalIndex
+ = Comparer.Create(static (a, b) => a.OriginalIndex.CompareTo(b.OriginalIndex));
+
+ public AITool Tool { get; } = tool;
+ public int OriginalIndex { get; } = originalIndex;
+ public float SimilarityScore { get; set; }
+ }
+}
diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ToolReducingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ToolReducingChatClient.cs
new file mode 100644
index 00000000000..6a5d6d925fc
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ToolReducingChatClient.cs
@@ -0,0 +1,89 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Shared.Diagnostics;
+
+namespace Microsoft.Extensions.AI;
+
+///
+/// A delegating chat client that applies a tool reduction strategy before invoking the inner client.
+///
+///
+/// Insert this into a pipeline (typically before function invocation middleware) to automatically
+/// reduce the tool list carried on for each request.
+///
+[Experimental("MEAI001")]
+public sealed class ToolReducingChatClient : DelegatingChatClient
+{
+ private readonly IToolReductionStrategy _strategy;
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The inner client.
+ /// The tool reduction strategy to apply.
+ /// Thrown if any argument is .
+ public ToolReducingChatClient(IChatClient innerClient, IToolReductionStrategy strategy)
+ : base(innerClient)
+ {
+ _strategy = Throw.IfNull(strategy);
+ }
+
+ ///
+ public override async Task GetResponseAsync(
+ IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
+ {
+ options = await ApplyReductionAsync(messages, options, cancellationToken).ConfigureAwait(false);
+ return await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false);
+ }
+
+ ///
+ public override async IAsyncEnumerable GetStreamingResponseAsync(
+ IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ options = await ApplyReductionAsync(messages, options, cancellationToken).ConfigureAwait(false);
+
+ await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false))
+ {
+ yield return update;
+ }
+ }
+
+ private async Task ApplyReductionAsync(
+ IEnumerable messages,
+ ChatOptions? options,
+ CancellationToken cancellationToken)
+ {
+ // If there are no options or no tools, skip.
+ if (options?.Tools is not { Count: > 0 })
+ {
+ return options;
+ }
+
+ var reduced = await _strategy.SelectToolsForRequestAsync(messages, options, cancellationToken).ConfigureAwait(false);
+
+ // If strategy returned the same list instance (or reference equality), assume no change.
+ if (ReferenceEquals(reduced, options.Tools))
+ {
+ return options;
+ }
+
+ // Materialize and compare counts; if unchanged and tools have identical ordering and references, keep original.
+ if (reduced is not IList reducedList)
+ {
+ reducedList = reduced.ToList();
+ }
+
+ // Clone options to avoid mutating a possibly shared instance.
+ var cloned = options.Clone();
+ cloned.Tools = reducedList;
+ return cloned;
+ }
+}
diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs
index 448de8d11df..992e86a1184 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs
@@ -41,6 +41,8 @@ protected ChatClientIntegrationTests()
protected IChatClient? ChatClient { get; }
+ protected IEmbeddingGenerator>? EmbeddingGenerator { get; private set; }
+
public void Dispose()
{
ChatClient?.Dispose();
@@ -49,6 +51,13 @@ public void Dispose()
protected abstract IChatClient? CreateChatClient();
+ ///
+ /// Optionally supplies an embedding generator for integration tests that exercise
+ /// embedding-based components (e.g., tool reduction). Default returns null and
+ /// tests depending on embeddings will skip if not overridden.
+ ///
+ protected virtual IEmbeddingGenerator>? CreateEmbeddingGenerator() => null;
+
[ConditionalFact]
public virtual async Task GetResponseAsync_SingleRequestMessage()
{
@@ -1395,6 +1404,343 @@ public void Dispose()
}
}
+ [ConditionalFact]
+ public virtual async Task ToolReduction_DynamicSelection_RespectsConversationHistory()
+ {
+ SkipIfNotEnabled();
+ EnsureEmbeddingGenerator();
+
+ // Limit to 2 so that, once the conversation references both weather and translation,
+ // both tools can be included even if the latest user turn only mentions one of them.
+ var strategy = new EmbeddingToolReductionStrategy(EmbeddingGenerator, toolLimit: 2);
+
+ var weatherTool = AIFunctionFactory.Create(
+ () => "Weather data",
+ new AIFunctionFactoryOptions
+ {
+ Name = "GetWeatherForecast",
+ Description = "Returns weather forecast and temperature for a given city."
+ });
+
+ var translateTool = AIFunctionFactory.Create(
+ () => "Translated text",
+ new AIFunctionFactoryOptions
+ {
+ Name = "TranslateText",
+ Description = "Translates text between human languages."
+ });
+
+ var mathTool = AIFunctionFactory.Create(
+ () => 42,
+ new AIFunctionFactoryOptions
+ {
+ Name = "SolveMath",
+ Description = "Solves basic math problems."
+ });
+
+ var allTools = new List { weatherTool, translateTool, mathTool };
+
+ IList? firstTurnTools = null;
+ IList? secondTurnTools = null;
+
+ using var client = ChatClient!
+ .AsBuilder()
+ .UseToolReduction(strategy)
+ .Use(async (messages, options, next, ct) =>
+ {
+ // Capture the (possibly reduced) tool list for each turn.
+ if (firstTurnTools is null)
+ {
+ firstTurnTools = options?.Tools;
+ }
+ else
+ {
+ secondTurnTools ??= options?.Tools;
+ }
+
+ await next(messages, options, ct);
+ })
+ .UseFunctionInvocation()
+ .Build();
+
+ // Maintain chat history across turns.
+ List history = [];
+
+ // Turn 1: Ask a weather question.
+ history.Add(new ChatMessage(ChatRole.User, "What will the weather be in Seattle tomorrow?"));
+ var firstResponse = await client.GetResponseAsync(history, new ChatOptions { Tools = allTools });
+ history.AddMessages(firstResponse); // Append assistant reply.
+
+ Assert.NotNull(firstTurnTools);
+ Assert.Contains(firstTurnTools, t => t.Name == "GetWeatherForecast");
+
+ // Turn 2: Ask a translation question. Even though only translation is mentioned now,
+ // conversation history still contains a weather request. Expect BOTH weather + translation tools.
+ history.Add(new ChatMessage(ChatRole.User, "Please translate 'good evening' into French."));
+ var secondResponse = await client.GetResponseAsync(history, new ChatOptions { Tools = allTools });
+ history.AddMessages(secondResponse);
+
+ Assert.NotNull(secondTurnTools);
+ Assert.Equal(2, secondTurnTools.Count); // Should have filled both slots with the two relevant domains.
+ Assert.Contains(secondTurnTools, t => t.Name == "GetWeatherForecast");
+ Assert.Contains(secondTurnTools, t => t.Name == "TranslateText");
+
+ // Ensure unrelated tool was excluded.
+ Assert.DoesNotContain(secondTurnTools, t => t.Name == "SolveMath");
+ }
+
+ [ConditionalFact]
+ public virtual async Task ToolReduction_RequireSpecificToolPreservedAndOrdered()
+ {
+ SkipIfNotEnabled();
+ EnsureEmbeddingGenerator();
+
+ // Limit would normally reduce to 1, but required tool plus another should remain.
+ var strategy = new EmbeddingToolReductionStrategy(EmbeddingGenerator, toolLimit: 1);
+
+ var translateTool = AIFunctionFactory.Create(
+ () => "Translated text",
+ new AIFunctionFactoryOptions
+ {
+ Name = "TranslateText",
+ Description = "Translates phrases between languages."
+ });
+
+ var weatherTool = AIFunctionFactory.Create(
+ () => "Weather data",
+ new AIFunctionFactoryOptions
+ {
+ Name = "GetWeatherForecast",
+ Description = "Returns forecast data for a city."
+ });
+
+ var tools = new List { translateTool, weatherTool };
+
+ IList? captured = null;
+
+ using var client = ChatClient!
+ .AsBuilder()
+ .UseToolReduction(strategy)
+ .UseFunctionInvocation()
+ .Use((messages, options, next, ct) =>
+ {
+ captured = options?.Tools;
+ return next(messages, options, ct);
+ })
+ .Build();
+
+ var history = new List
+ {
+ new(ChatRole.User, "What will the weather be like in Redmond next week?")
+ };
+
+ var response = await client.GetResponseAsync(history, new ChatOptions
+ {
+ Tools = tools,
+ ToolMode = ChatToolMode.RequireSpecific(translateTool.Name)
+ });
+ history.AddMessages(response);
+
+ Assert.NotNull(captured);
+ Assert.Equal(2, captured!.Count);
+ Assert.Equal("TranslateText", captured[0].Name); // Required should appear first.
+ Assert.Equal("GetWeatherForecast", captured[1].Name);
+ }
+
+ [ConditionalFact]
+ public virtual async Task ToolReduction_ToolRemovedAfterFirstUse_NotInvokedAgain()
+ {
+ SkipIfNotEnabled();
+ EnsureEmbeddingGenerator();
+
+ int weatherInvocationCount = 0;
+
+ var weatherTool = AIFunctionFactory.Create(
+ () =>
+ {
+ weatherInvocationCount++;
+ return "Sunny and dry.";
+ },
+ new AIFunctionFactoryOptions
+ {
+ Name = "GetWeather",
+ Description = "Gets the weather forecast for a given location."
+ });
+
+ // Strategy exposes tools only on the first request, then removes them.
+ var removalStrategy = new RemoveToolAfterFirstUseStrategy();
+
+ IList? firstTurnTools = null;
+ IList? secondTurnTools = null;
+
+ using var client = ChatClient!
+ .AsBuilder()
+ // Place capture immediately after reduction so it's invoked exactly once per user request.
+ .UseToolReduction(removalStrategy)
+ .Use((messages, options, next, ct) =>
+ {
+ if (firstTurnTools is null)
+ {
+ firstTurnTools = options?.Tools;
+ }
+ else
+ {
+ secondTurnTools ??= options?.Tools;
+ }
+
+ return next(messages, options, ct);
+ })
+ .UseFunctionInvocation()
+ .Build();
+
+ List history = [];
+
+ // Turn 1
+ history.Add(new ChatMessage(ChatRole.User, "What's the weather like tomorrow in Seattle?"));
+ var firstResponse = await client.GetResponseAsync(history, new ChatOptions
+ {
+ Tools = [weatherTool],
+ ToolMode = ChatToolMode.RequireAny
+ });
+ history.AddMessages(firstResponse);
+
+ Assert.Equal(1, weatherInvocationCount);
+ Assert.NotNull(firstTurnTools);
+ Assert.Contains(firstTurnTools!, t => t.Name == "GetWeather");
+
+ // Turn 2 (tool removed by strategy even though caller supplies it again)
+ history.Add(new ChatMessage(ChatRole.User, "And what about next week?"));
+ var secondResponse = await client.GetResponseAsync(history, new ChatOptions
+ {
+ Tools = [weatherTool]
+ });
+ history.AddMessages(secondResponse);
+
+ Assert.Equal(1, weatherInvocationCount); // Not invoked again.
+ Assert.NotNull(secondTurnTools);
+ Assert.Empty(secondTurnTools!); // Strategy removed the tool set.
+
+ // Response text shouldn't just echo the tool's stub output.
+ Assert.DoesNotContain("Sunny and dry.", secondResponse.Text, StringComparison.OrdinalIgnoreCase);
+ }
+
+ [ConditionalFact]
+ public virtual async Task ToolReduction_MessagesEmbeddingTextSelector_UsesChatClientToAnalyzeConversation()
+ {
+ SkipIfNotEnabled();
+ EnsureEmbeddingGenerator();
+
+ // Create tools for different domains.
+ var weatherTool = AIFunctionFactory.Create(
+ () => "Weather data",
+ new AIFunctionFactoryOptions
+ {
+ Name = "GetWeatherForecast",
+ Description = "Returns weather forecast and temperature for a given city."
+ });
+
+ var translateTool = AIFunctionFactory.Create(
+ () => "Translated text",
+ new AIFunctionFactoryOptions
+ {
+ Name = "TranslateText",
+ Description = "Translates text between human languages."
+ });
+
+ var mathTool = AIFunctionFactory.Create(
+ () => 42,
+ new AIFunctionFactoryOptions
+ {
+ Name = "SolveMath",
+ Description = "Solves basic math problems."
+ });
+
+ var allTools = new List { weatherTool, translateTool, mathTool };
+
+ // Track the analysis result from the chat client used in the selector.
+ string? capturedAnalysis = null;
+
+ var strategy = new EmbeddingToolReductionStrategy(EmbeddingGenerator, toolLimit: 2)
+ {
+ // Use a chat client to analyze the conversation and extract relevant tool categories.
+ MessagesEmbeddingTextSelector = async messages =>
+ {
+ var conversationText = string.Join("\n", messages.Select(m => $"{m.Role}: {m.Text}"));
+
+ var analysisPrompt = $"""
+ Analyze the following conversation and identify what kinds of tools would be most helpful.
+ Focus on the key topics and tasks being discussed.
+ Respond with a brief summary of the relevant tool categories (e.g., "weather", "translation", "math").
+
+ Conversation:
+ {conversationText}
+
+ Relevant tool categories:
+ """;
+
+ var response = await ChatClient!.GetResponseAsync(analysisPrompt);
+ capturedAnalysis = response.Text;
+
+ // Return the analysis as the query text for embedding-based tool selection.
+ return capturedAnalysis;
+ }
+ };
+
+ IList? selectedTools = null;
+
+ using var client = ChatClient!
+ .AsBuilder()
+ .UseToolReduction(strategy)
+ .Use(async (messages, options, next, ct) =>
+ {
+ selectedTools = options?.Tools;
+ await next(messages, options, ct);
+ })
+ .UseFunctionInvocation()
+ .Build();
+
+ // Conversation that clearly indicates weather-related needs.
+ List history = [];
+ history.Add(new ChatMessage(ChatRole.User, "What will the weather be like in London tomorrow?"));
+
+ var response = await client.GetResponseAsync(history, new ChatOptions { Tools = allTools });
+ history.AddMessages(response);
+
+ // Verify that the chat client was used to analyze the conversation.
+ Assert.NotNull(capturedAnalysis);
+ Assert.True(
+ capturedAnalysis.IndexOf("weather", StringComparison.OrdinalIgnoreCase) >= 0 ||
+ capturedAnalysis.IndexOf("forecast", StringComparison.OrdinalIgnoreCase) >= 0,
+ $"Expected analysis to mention weather or forecast: {capturedAnalysis}");
+
+ // Verify that the tool selection was influenced by the analysis.
+ Assert.NotNull(selectedTools);
+ Assert.True(selectedTools.Count <= 2, $"Expected at most 2 tools, got {selectedTools.Count}");
+ Assert.Contains(selectedTools, t => t.Name == "GetWeatherForecast");
+ }
+
+ // Test-only custom strategy: include tools on first request, then remove them afterward.
+ private sealed class RemoveToolAfterFirstUseStrategy : IToolReductionStrategy
+ {
+ private bool _used;
+
+ public Task> SelectToolsForRequestAsync(
+ IEnumerable messages,
+ ChatOptions? options,
+ CancellationToken cancellationToken = default)
+ {
+ if (!_used && options?.Tools is { Count: > 0 })
+ {
+ _used = true;
+ // Returning the same instance signals no change.
+ return Task.FromResult>(options.Tools);
+ }
+
+ // After first use, remove all tools.
+ return Task.FromResult>(Array.Empty());
+ }
+ }
+
[MemberNotNull(nameof(ChatClient))]
protected void SkipIfNotEnabled()
{
@@ -1405,4 +1751,15 @@ protected void SkipIfNotEnabled()
throw new SkipTestException("Client is not enabled.");
}
}
+
+ [MemberNotNull(nameof(EmbeddingGenerator))]
+ protected void EnsureEmbeddingGenerator()
+ {
+ EmbeddingGenerator ??= CreateEmbeddingGenerator();
+
+ if (EmbeddingGenerator is null)
+ {
+ throw new SkipTestException("Embedding generator is not enabled.");
+ }
+ }
}
diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs
new file mode 100644
index 00000000000..96c9adc6311
--- /dev/null
+++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs
@@ -0,0 +1,663 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using Xunit;
+
+namespace Microsoft.Extensions.AI;
+
+public class ToolReductionTests
+{
+ [Fact]
+ public void EmbeddingToolReductionStrategy_Constructor_ThrowsWhenToolLimitIsLessThanOrEqualToZero()
+ {
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ Assert.Throws(() => new EmbeddingToolReductionStrategy(gen, toolLimit: 0));
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_NoReduction_WhenToolsBelowLimit()
+ {
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 5);
+
+ var tools = CreateTools("Weather", "Math");
+ var options = new ChatOptions { Tools = tools };
+
+ var result = await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "Tell me about weather") },
+ options);
+
+ Assert.Same(tools, result);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_NoReduction_WhenOptionalToolsBelowLimit()
+ {
+ // 1 required + 2 optional, limit = 2 (optional count == limit) => original list returned
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2)
+ {
+ IsRequiredTool = t => t.Name == "Req"
+ };
+
+ var tools = CreateTools("Req", "Opt1", "Opt2");
+ var result = await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "anything") },
+ new ChatOptions { Tools = tools });
+
+ Assert.Same(tools, result);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_Reduces_ToLimit_BySimilarity()
+ {
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2);
+
+ var tools = CreateTools("Weather", "Translate", "Math", "Jokes");
+ var options = new ChatOptions { Tools = tools };
+
+ var messages = new[]
+ {
+ new ChatMessage(ChatRole.User, "Can you do some weather math for forecasting?")
+ };
+
+ var reduced = (await strategy.SelectToolsForRequestAsync(messages, options)).ToList();
+
+ Assert.Equal(2, reduced.Count);
+ Assert.Contains(reduced, t => t.Name == "Weather");
+ Assert.Contains(reduced, t => t.Name == "Math");
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_PreserveOriginalOrdering_ReordersAfterSelection()
+ {
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2)
+ {
+ PreserveOriginalOrdering = true
+ };
+
+ var tools = CreateTools("Math", "Translate", "Weather");
+ var reduced = (await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "Explain weather math please") },
+ new ChatOptions { Tools = tools })).ToList();
+
+ Assert.Equal(2, reduced.Count);
+ Assert.Equal("Math", reduced[0].Name);
+ Assert.Equal("Weather", reduced[1].Name);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_Caching_AvoidsReEmbeddingTools()
+ {
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1);
+
+ var tools = CreateTools("Weather", "Math", "Jokes");
+ var messages = new[] { new ChatMessage(ChatRole.User, "weather") };
+
+ _ = await strategy.SelectToolsForRequestAsync(messages, new ChatOptions { Tools = tools });
+ int afterFirst = gen.TotalValueInputs;
+
+ _ = await strategy.SelectToolsForRequestAsync(messages, new ChatOptions { Tools = tools });
+ int afterSecond = gen.TotalValueInputs;
+
+ // +1 for second query embedding only
+ Assert.Equal(afterFirst + 1, afterSecond);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_OptionsNullOrNoTools_ReturnsEmptyOrOriginal()
+ {
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2);
+
+ var empty = await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "anything") }, null);
+ Assert.Empty(empty);
+
+ var options = new ChatOptions { Tools = [] };
+ var result = await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "weather") }, options);
+ Assert.Same(options.Tools, result);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_CustomSimilarity_InvertsOrdering()
+ {
+ using var gen = new VectorBasedTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1)
+ {
+ Similarity = (q, t) => -t.Span[0]
+ };
+
+ var highTool = new SimpleTool("HighScore", "alpha");
+ var lowTool = new SimpleTool("LowScore", "beta");
+ gen.VectorSelector = text => text.Contains("alpha") ? 10f : 1f;
+
+ var reduced = (await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "Pick something") },
+ new ChatOptions { Tools = [highTool, lowTool] })).ToList();
+
+ Assert.Single(reduced);
+ Assert.Equal("LowScore", reduced[0].Name);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_TieDeterminism_PrefersLowerOriginalIndex()
+ {
+ // Generator returns identical vectors so similarity ties; we expect original order preserved
+ using var gen = new ConstantEmbeddingGenerator(3);
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2);
+
+ var tools = CreateTools("T1", "T2", "T3", "T4");
+ var reduced = (await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "any") },
+ new ChatOptions { Tools = tools })).ToList();
+
+ Assert.Equal(2, reduced.Count);
+ Assert.Equal("T1", reduced[0].Name);
+ Assert.Equal("T2", reduced[1].Name);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_DefaultEmbeddingTextSelector_EmptyDescription_UsesNameOnly()
+ {
+ using var recorder = new RecordingEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(recorder, toolLimit: 1);
+
+ var target = new SimpleTool("ComputeSum", description: "");
+ var filler = new SimpleTool("Other", "Unrelated");
+ _ = await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "math") },
+ new ChatOptions { Tools = [target, filler] });
+
+ Assert.Contains("ComputeSum", recorder.Inputs);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_DefaultEmbeddingTextSelector_EmptyName_UsesDescriptionOnly()
+ {
+ using var recorder = new RecordingEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(recorder, toolLimit: 1);
+
+ var target = new SimpleTool("", description: "Translates between languages.");
+ var filler = new SimpleTool("Other", "Unrelated");
+ _ = await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "translate") },
+ new ChatOptions { Tools = [target, filler] });
+
+ Assert.Contains("Translates between languages.", recorder.Inputs);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_CustomEmbeddingTextSelector_Applied()
+ {
+ using var recorder = new RecordingEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(recorder, toolLimit: 1)
+ {
+ ToolEmbeddingTextSelector = t => $"NAME:{t.Name}|DESC:{t.Description}"
+ };
+
+ var target = new SimpleTool("WeatherTool", "Gets forecast.");
+ var filler = new SimpleTool("Other", "Irrelevant");
+ _ = await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "weather") },
+ new ChatOptions { Tools = [target, filler] });
+
+ Assert.Contains("NAME:WeatherTool|DESC:Gets forecast.", recorder.Inputs);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_MessagesEmbeddingTextSelector_CustomFiltersMessages()
+ {
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1);
+
+ var tools = CreateTools("Weather", "Math", "Translate");
+
+ var messages = new[]
+ {
+ new ChatMessage(ChatRole.User, "Please tell me the weather tomorrow."),
+ new ChatMessage(ChatRole.Assistant, "Sure, I can help."),
+ new ChatMessage(ChatRole.User, "Now instead solve a math problem.")
+ };
+
+ strategy.MessagesEmbeddingTextSelector = msgs => new ValueTask(msgs.LastOrDefault()?.Text ?? string.Empty);
+
+ var reduced = (await strategy.SelectToolsForRequestAsync(
+ messages,
+ new ChatOptions { Tools = tools })).ToList();
+
+ Assert.Single(reduced);
+ Assert.Equal("Math", reduced[0].Name);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_MessagesEmbeddingTextSelector_InvokedOnce()
+ {
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1);
+
+ var tools = CreateTools("Weather", "Math");
+ int invocationCount = 0;
+
+ strategy.MessagesEmbeddingTextSelector = msgs =>
+ {
+ invocationCount++;
+ return new ValueTask(string.Join("\n", msgs.Select(m => m.Text)));
+ };
+
+ _ = await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "weather and math") },
+ new ChatOptions { Tools = tools });
+
+ Assert.Equal(1, invocationCount);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_DefaultMessagesEmbeddingTextSelector_IncludesReasoningContent()
+ {
+ using var recorder = new RecordingEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(recorder, toolLimit: 1);
+ var tools = CreateTools("Weather", "Math");
+
+ var reasoningLine = "Thinking about the best way to get tomorrow's forecast...";
+ var answerLine = "Tomorrow will be sunny.";
+ var userLine = "What's the weather tomorrow?";
+
+ var messages = new[]
+ {
+ new ChatMessage(ChatRole.User, userLine),
+ new ChatMessage(ChatRole.Assistant,
+ [
+ new TextReasoningContent(reasoningLine),
+ new TextContent(answerLine)
+ ])
+ };
+
+ _ = await strategy.SelectToolsForRequestAsync(messages, new ChatOptions { Tools = tools });
+
+ string queryInput = recorder.Inputs[0];
+
+ Assert.Contains(userLine, queryInput);
+ Assert.Contains(reasoningLine, queryInput);
+ Assert.Contains(answerLine, queryInput);
+
+ var userIndex = queryInput.IndexOf(userLine, StringComparison.Ordinal);
+ var reasoningIndex = queryInput.IndexOf(reasoningLine, StringComparison.Ordinal);
+ var answerIndex = queryInput.IndexOf(answerLine, StringComparison.Ordinal);
+ Assert.True(userIndex >= 0 && reasoningIndex > userIndex && answerIndex > reasoningIndex);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_DefaultMessagesEmbeddingTextSelector_SkipsNonTextContent()
+ {
+ using var recorder = new RecordingEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(recorder, toolLimit: 1);
+ var tools = CreateTools("Alpha", "Beta");
+
+ var textOnly = "Provide translation.";
+ var messages = new[]
+ {
+ new ChatMessage(ChatRole.User,
+ [
+ new DataContent(new byte[] { 1, 2, 3 }, "application/octet-stream"),
+ new TextContent(textOnly)
+ ])
+ };
+
+ _ = await strategy.SelectToolsForRequestAsync(messages, new ChatOptions { Tools = tools });
+
+ var queryInput = recorder.Inputs[0];
+ Assert.Contains(textOnly, queryInput);
+ Assert.DoesNotContain("application/octet-stream", queryInput, StringComparison.OrdinalIgnoreCase);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_RequiredToolAlwaysIncluded()
+ {
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1)
+ {
+ IsRequiredTool = t => t.Name == "Core"
+ };
+
+ var tools = CreateTools("Core", "Weather", "Math");
+ var reduced = (await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "math") },
+ new ChatOptions { Tools = tools })).ToList();
+
+ Assert.Equal(2, reduced.Count); // required + one optional (limit=1)
+ Assert.Contains(reduced, t => t.Name == "Core");
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_MultipleRequiredTools_ExceedLimit_AllRequiredIncluded()
+ {
+ // 3 required, limit=1 => expect 3 required + 1 ranked optional = 4 total
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1)
+ {
+ IsRequiredTool = t => t.Name.StartsWith("R", StringComparison.Ordinal)
+ };
+
+ var tools = CreateTools("R1", "R2", "R3", "Weather", "Math");
+ var reduced = (await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "weather math") },
+ new ChatOptions { Tools = tools })).ToList();
+
+ Assert.Equal(4, reduced.Count);
+ Assert.Equal(3, reduced.Count(t => t.Name.StartsWith("R")));
+ }
+
+ [Fact]
+ public async Task ToolReducingChatClient_ReducesTools_ForGetResponseAsync()
+ {
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2);
+ var tools = CreateTools("Weather", "Math", "Translate", "Jokes");
+
+ IList? observedTools = null;
+
+ using var inner = new TestChatClient
+ {
+ GetResponseAsyncCallback = (messages, options, ct) =>
+ {
+ observedTools = options?.Tools;
+ return Task.FromResult(new ChatResponse());
+ }
+ };
+
+ using var client = inner.AsBuilder().UseToolReduction(strategy).Build();
+
+ await client.GetResponseAsync(
+ new[] { new ChatMessage(ChatRole.User, "weather math please") },
+ new ChatOptions { Tools = tools });
+
+ Assert.NotNull(observedTools);
+ Assert.Equal(2, observedTools!.Count);
+ Assert.Contains(observedTools, t => t.Name == "Weather");
+ Assert.Contains(observedTools, t => t.Name == "Math");
+ }
+
+ [Fact]
+ public async Task ToolReducingChatClient_ReducesTools_ForStreaming()
+ {
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1);
+ var tools = CreateTools("Weather", "Math");
+
+ IList? observedTools = null;
+
+ using var inner = new TestChatClient
+ {
+ GetStreamingResponseAsyncCallback = (messages, options, ct) =>
+ {
+ observedTools = options?.Tools;
+ return EmptyAsyncEnumerable();
+ }
+ };
+
+ using var client = inner.AsBuilder().UseToolReduction(strategy).Build();
+
+ await foreach (var _ in client.GetStreamingResponseAsync(
+ new[] { new ChatMessage(ChatRole.User, "math") },
+ new ChatOptions { Tools = tools }))
+ {
+ // Consume
+ }
+
+ Assert.NotNull(observedTools);
+ Assert.Single(observedTools!);
+ Assert.Equal("Math", observedTools![0].Name);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_EmptyQuery_NoReduction()
+ {
+ // Arrange: more tools than limit so we'd normally reduce, but query is empty -> return full list unchanged.
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1);
+
+ var tools = CreateTools("ToolA", "ToolB", "ToolC");
+ var options = new ChatOptions { Tools = tools };
+
+ // Empty / whitespace message text produces empty query.
+ var messages = new[] { new ChatMessage(ChatRole.User, " ") };
+
+ // Act
+ var result = await strategy.SelectToolsForRequestAsync(messages, options);
+
+ // Assert: same reference (no reduction), and generator not invoked at all.
+ Assert.Same(tools, result);
+ Assert.Equal(0, gen.TotalValueInputs);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_EmptyQuery_NoReduction_WithRequiredTool()
+ {
+ // Arrange: required tool + optional tools; still should return original set when query is empty.
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1)
+ {
+ IsRequiredTool = t => t.Name == "Req"
+ };
+
+ var tools = CreateTools("Req", "Optional1", "Optional2");
+ var options = new ChatOptions { Tools = tools };
+
+ var messages = new[] { new ChatMessage(ChatRole.User, " ") };
+
+ // Act
+ var result = await strategy.SelectToolsForRequestAsync(messages, options);
+
+ // Assert
+ Assert.Same(tools, result);
+ Assert.Equal(0, gen.TotalValueInputs);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_EmptyQuery_ViaCustomMessagesSelector_NoReduction()
+ {
+ // Arrange: force empty query through custom selector returning whitespace.
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1)
+ {
+ MessagesEmbeddingTextSelector = _ => new ValueTask(" ")
+ };
+
+ var tools = CreateTools("One", "Two");
+ var messages = new[]
+ {
+ new ChatMessage(ChatRole.User, "This content will be ignored by custom selector.")
+ };
+
+ // Act
+ var result = await strategy.SelectToolsForRequestAsync(messages, new ChatOptions { Tools = tools });
+
+ // Assert: no reduction and no embeddings generated.
+ Assert.Same(tools, result);
+ Assert.Equal(0, gen.TotalValueInputs);
+ }
+
+ private static List CreateTools(params string[] names) =>
+ names.Select(n => (AITool)new SimpleTool(n, $"Description about {n}")).ToList();
+
+#pragma warning disable CS1998
+ private static async IAsyncEnumerable EmptyAsyncEnumerable()
+ {
+ yield break;
+ }
+#pragma warning restore CS1998
+
+ private sealed class SimpleTool : AITool
+ {
+ private readonly string _name;
+ private readonly string _description;
+
+ public SimpleTool(string name, string description)
+ {
+ _name = name;
+ _description = description;
+ }
+
+ public override string Name => _name;
+ public override string Description => _description;
+ }
+
+ ///
+ /// Deterministic embedding generator producing sparse keyword indicator vectors.
+ /// Each dimension corresponds to a known keyword. Cosine similarity then reflects
+ /// pure keyword overlap (non-overlapping keywords contribute nothing), avoiding
+ /// false ties for tools unrelated to the query.
+ ///
+ private sealed class DeterministicTestEmbeddingGenerator : IEmbeddingGenerator>
+ {
+ private static readonly string[] _keywords =
+ [
+ "weather","forecast","temperature","math","calculate","sum","translate","language","joke"
+ ];
+
+ // +1 bias dimension (last) to avoid zero magnitude vectors when no keywords present.
+ private static int VectorLength => _keywords.Length + 1;
+
+ public int TotalValueInputs { get; private set; }
+
+ public Task>> GenerateAsync(
+ IEnumerable values,
+ EmbeddingGenerationOptions? options = null,
+ CancellationToken cancellationToken = default)
+ {
+ var list = new List>();
+
+ foreach (var v in values)
+ {
+ TotalValueInputs++;
+ var vec = new float[VectorLength];
+ if (!string.IsNullOrWhiteSpace(v))
+ {
+ var lower = v.ToLowerInvariant();
+ for (int i = 0; i < _keywords.Length; i++)
+ {
+ if (lower.Contains(_keywords[i]))
+ {
+ vec[i] = 1f;
+ }
+ }
+ }
+
+ vec[^1] = 1f; // bias
+ list.Add(new Embedding(vec));
+ }
+
+ return Task.FromResult(new GeneratedEmbeddings>(list));
+ }
+
+ public object? GetService(Type serviceType, object? serviceKey = null) => null;
+
+ public void Dispose()
+ {
+ // No-op
+ }
+ }
+
+ private sealed class RecordingEmbeddingGenerator : IEmbeddingGenerator>
+ {
+ public List Inputs { get; } = new();
+
+ public Task>> GenerateAsync(
+ IEnumerable values,
+ EmbeddingGenerationOptions? options = null,
+ CancellationToken cancellationToken = default)
+ {
+ var list = new List>();
+ foreach (var v in values)
+ {
+ Inputs.Add(v);
+
+ // Basic 2-dim vector (length encodes a bit of variability)
+ list.Add(new Embedding(new float[] { v.Length, 1f }));
+ }
+
+ return Task.FromResult(new GeneratedEmbeddings>(list));
+ }
+
+ public object? GetService(Type serviceType, object? serviceKey = null) => null;
+ public void Dispose()
+ {
+ // No-op
+ }
+ }
+
+ private sealed class VectorBasedTestEmbeddingGenerator : IEmbeddingGenerator>
+ {
+ public Func VectorSelector { get; set; } = _ => 1f;
+ public Task>> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default)
+ {
+ var list = new List>();
+ foreach (var v in values)
+ {
+ list.Add(new Embedding(new float[] { VectorSelector(v), 1f }));
+ }
+
+ return Task.FromResult(new GeneratedEmbeddings>(list));
+ }
+
+ public object? GetService(Type serviceType, object? serviceKey = null) => null;
+ public void Dispose()
+ {
+ // No-op
+ }
+ }
+
+ private sealed class ConstantEmbeddingGenerator : IEmbeddingGenerator>
+ {
+ private readonly float[] _vector;
+ public ConstantEmbeddingGenerator(int dims)
+ {
+ _vector = Enumerable.Repeat(1f, dims).ToArray();
+ }
+
+ public Task>> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default)
+ {
+ var list = new List>();
+ foreach (var _ in values)
+ {
+ list.Add(new Embedding(_vector));
+ }
+
+ return Task.FromResult(new GeneratedEmbeddings>(list));
+ }
+
+ public object? GetService(Type serviceType, object? serviceKey = null) => null;
+ public void Dispose()
+ {
+ // No-op
+ }
+ }
+
+ private sealed class TestChatClient : IChatClient
+ {
+ public Func, ChatOptions?, CancellationToken, Task>? GetResponseAsyncCallback { get; set; }
+ public Func, ChatOptions?, CancellationToken, IAsyncEnumerable>? GetStreamingResponseAsyncCallback { get; set; }
+
+ public Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) =>
+ (GetResponseAsyncCallback ?? throw new InvalidOperationException())(messages, options, cancellationToken);
+
+ public IAsyncEnumerable GetStreamingResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) =>
+ (GetStreamingResponseAsyncCallback ?? throw new InvalidOperationException())(messages, options, cancellationToken);
+
+ public object? GetService(Type serviceType, object? serviceKey = null) => null;
+ public void Dispose()
+ {
+ // No-op
+ }
+ }
+}
diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientIntegrationTests.cs
index 6322e3d6b64..a9e08a58e52 100644
--- a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientIntegrationTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientIntegrationTests.cs
@@ -8,4 +8,8 @@ public class OpenAIChatClientIntegrationTests : ChatClientIntegrationTests
protected override IChatClient? CreateChatClient() =>
IntegrationTestHelpers.GetOpenAIClient()
?.GetChatClient(TestRunnerConfiguration.Instance["OpenAI:ChatModel"] ?? "gpt-4o-mini").AsIChatClient();
+
+ protected override IEmbeddingGenerator>? CreateEmbeddingGenerator() =>
+ IntegrationTestHelpers.GetOpenAIClient()
+ ?.GetEmbeddingClient(TestRunnerConfiguration.Instance["OpenAI:EmbeddingModel"] ?? "text-embedding-3-small").AsIEmbeddingGenerator();
}