Middleware for tool reduction #6670

@stephentoub

Description

Especially with MCP, but not limited to it, it's becoming increasingly common for clients to have many tools. Too many tools can cause problems, both from hitting explicit limits on the number of tools a service allows in a request and from implicit limits that stem from an LLM's degraded ability to choose the right tool when there are many. Clients end up needing to implement custom schemes for tool reduction. One such scheme is embedding the tool descriptions and choosing which subset of tools to include based on an embedding of the supplied messages. Another is virtual tools, where tools are grouped (possibly by an LLM) and each group is advertised as a single tool that the LLM can invoke to activate all of the tools in that group.
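To make the virtual-tools idea concrete, here is a minimal, dependency-free sketch. Everything in it is hypothetical and for illustration only: `ToolGroup`, `VirtualToolSet`, and the `activate_` naming convention are assumptions, not existing Microsoft.Extensions.AI types. Tools are represented by name only, to show the bookkeeping rather than a real integration.

```csharp
using System.Collections.Generic;
using System.Linq;

// Hypothetical type: a named group of tools that is advertised as one virtual tool.
public sealed record ToolGroup(string Name, string Description, IReadOnlyList<string> ToolNames);

// Hypothetical type: tracks which groups the model has activated so far.
public sealed class VirtualToolSet(IEnumerable<ToolGroup> groups)
{
    private readonly List<ToolGroup> _groups = groups.ToList();
    private readonly HashSet<string> _active = new();

    // Tool names to advertise on the next request: the real tools of every
    // activated group, and a single "activate_<group>" entry for each inactive one.
    public IReadOnlyList<string> VisibleTools() =>
        _groups
            .SelectMany(g => _active.Contains(g.Name)
                ? g.ToolNames
                : new[] { $"activate_{g.Name}" })
            .ToList();

    // Called when the model invokes a group's virtual tool.
    // Returns false for an unknown or already-active group.
    public bool Activate(string groupName) =>
        _groups.Any(g => g.Name == groupName) && _active.Add(groupName);
}
```

With two groups, say "github" and "files", only two virtual tools are advertised at first. Once the model invokes the "github" one, the next request carries that group's real tools plus the remaining virtual tool for "files".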

Such schemes can be implemented as middleware, so that a developer can just plug a scheme into their IChatClient pipeline, e.g.

IChatClient client = ...;
client = client
    .AsBuilder()
    // UseToolReduction is the proposed extension; it does not exist yet.
    .UseToolReduction(new EmbeddingToolReductionStrategy(embeddingGenerator, 15))
    .UseFunctionInvocation()
    .UseOpenTelemetry()
    .Build();

Different strategies could be handled via parameterization of that IChatClient implementation, e.g.

public interface IToolReductionStrategy
{
    ValueTask<IEnumerable<AITool>> SelectToolsForRequest(IEnumerable<ChatMessage> messages, ChatOptions? options);
}

and implementations of that strategy, e.g.

public sealed class EmbeddingToolReductionStrategy(
    IEmbeddingGenerator<string, Embedding<float>> generator, int toolLimit) : IToolReductionStrategy
{
    public async ValueTask<IEnumerable<AITool>> SelectToolsForRequest(IEnumerable<ChatMessage> messages, ChatOptions? options)
    {
        IList<AITool>? tools = options?.Tools;
        if (tools is null) return [];

        // Embed each tool's name and description.
        var toolEmbeddings = await generator.GenerateAsync(tools.Select(t => $"{t.Name}\n{t.Description}"));

        // Embed the concatenated message text as the query.
        Embedding<float> queryEmbedding = await generator.GenerateAsync(string.Concat(messages));

        // Keep the toolLimit tools whose embeddings are most similar to the query.
        return tools.Zip(toolEmbeddings)
            .OrderByDescending(f => TensorPrimitives.CosineSimilarity(queryEmbedding.Vector.Span, f.Second.Vector.Span))
            .Take(toolLimit)
            .Select(f => f.First)
            .ToArray();
    }
}

That's just a rough sketch, but something along those lines.
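For the middleware side, one possible shape is a DelegatingChatClient that asks the strategy for a tool subset and forwards a cloned ChatOptions to the inner client. This is only a sketch under assumptions: `ToolReducingChatClient` and the `UseToolReduction` extension are hypothetical names, the interface is repeated from above so the block stands alone, and cloning the options (so the caller's tool list is left untouched) is a design choice rather than a requirement.

```csharp
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;

// The strategy interface as proposed above.
public interface IToolReductionStrategy
{
    ValueTask<IEnumerable<AITool>> SelectToolsForRequest(IEnumerable<ChatMessage> messages, ChatOptions? options);
}

// Hypothetical middleware; not an existing Microsoft.Extensions.AI type.
public sealed class ToolReducingChatClient(IChatClient innerClient, IToolReductionStrategy strategy)
    : DelegatingChatClient(innerClient)
{
    public override async Task<ChatResponse> GetResponseAsync(
        IEnumerable<ChatMessage> messages, ChatOptions? options = null,
        CancellationToken cancellationToken = default)
    {
        options = await ReduceAsync(messages, options);
        return await base.GetResponseAsync(messages, options, cancellationToken);
    }

    public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
        IEnumerable<ChatMessage> messages, ChatOptions? options = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        options = await ReduceAsync(messages, options);
        await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken))
        {
            yield return update;
        }
    }

    // Returns the options unchanged when there are no tools; otherwise returns
    // a clone whose Tools list holds only the strategy's selection.
    private async ValueTask<ChatOptions?> ReduceAsync(IEnumerable<ChatMessage> messages, ChatOptions? options)
    {
        if (options is null || options.Tools is not { Count: > 0 })
        {
            return options;
        }

        ChatOptions reduced = options.Clone();
        reduced.Tools = (await strategy.SelectToolsForRequest(messages, options)).ToList();
        return reduced;
    }
}

public static class ToolReductionChatClientBuilderExtensions
{
    // Hypothetical builder extension matching the pipeline example above.
    public static ChatClientBuilder UseToolReduction(
        this ChatClientBuilder builder, IToolReductionStrategy strategy) =>
        builder.Use(inner => new ToolReducingChatClient(inner, strategy));
}
```

Because the strategy runs on every call, an embedding-based strategy would likely want to cache tool embeddings across requests rather than regenerate them each time.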

Labels

area-ai (Microsoft.Extensions.AI libraries)