Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,54 @@ var openai = new OpenAIClient(
OpenAIClientOptions.Observable(requests.Add, responses.Add));
```

## Tool Results

Given the following tool:

```csharp
MyResult RunTool(string name, string description, string content) { ... }
```

You can use the `ToolFactory` and `FindCall<MyResult>` extension method to
locate the function invocation, its outcome and the typed result for inspection:

```csharp
AIFunction tool = ToolFactory.Create(RunTool);
var options = new ChatOptions
{
ToolMode = ChatToolMode.RequireSpecific(tool.Name), // 👈 forces the tool to be used
Tools = [tool]
};

var response = await client.GetResponseAsync(chat, options);
var result = response.FindCalls<MyResult>(tool).FirstOrDefault();

if (result != null)
{
// Successful tool call
Console.WriteLine($"Args: '{result.Call.Arguments.Count}'");
MyResult typed = result.Result;
}
else
{
Console.WriteLine("Tool call not found in response.");
}
```

If the typed result is not found, you can also inspect the raw outcomes by finding
untyped calls to the tool and checking their `Outcome.Exception` property:

```csharp
var result = response.FindCalls(tool).FirstOrDefault();
if (result.Outcome.Exception is not null)
{
Console.WriteLine($"Tool call failed: {result.Outcome.Exception.Message}");
}
else
{
Console.WriteLine($"Tool call succeeded: {result.Outcome.Result}");
}
```

## Console Logging

Expand Down
137 changes: 137 additions & 0 deletions src/AI.Tests/ToolsTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
using System.ComponentModel;
using Microsoft.Extensions.AI;
using static ConfigurationExtensions;

namespace Devlooped.Extensions.AI;

public class ToolsTests(ITestOutputHelper output)
{
public record ToolResult(string Name, string Description, string Content);

[SecretsFact("OPENAI_API_KEY")]
public async Task RunToolResult()
{
var chat = new Chat()
{
{ "system", "You make up a tool run by making up a name, description and content based on whatever the user says." },
{ "user", "I want to create an order for a dozen eggs" },
};

var client = new OpenAIChatClient(Configuration["OPENAI_API_KEY"]!, "gpt-4.1",
OpenAI.OpenAIClientOptions.WriteTo(output))
.AsBuilder()
.UseFunctionInvocation()
.Build();

var tool = ToolFactory.Create(RunTool);
var options = new ChatOptions
{
ToolMode = ChatToolMode.RequireSpecific(tool.Name),
Tools = [tool]
};

var response = await client.GetResponseAsync(chat, options);
var result = response.FindCalls<ToolResult>(tool).FirstOrDefault();

Assert.NotNull(result);
Assert.NotNull(result.Call);
Assert.Equal(tool.Name, result.Call.Name);
Assert.NotNull(result.Outcome);
Assert.Null(result.Outcome.Exception);
}

[SecretsFact("OPENAI_API_KEY")]
public async Task RunToolTerminateResult()
{
var chat = new Chat()
{
{ "system", "You make up a tool run by making up a name, description and content based on whatever the user says." },
{ "user", "I want to create an order for a dozen eggs" },
};

var client = new OpenAIChatClient(Configuration["OPENAI_API_KEY"]!, "gpt-4.1",
OpenAI.OpenAIClientOptions.WriteTo(output))
.AsBuilder()
.UseFunctionInvocation()
.Build();

var tool = ToolFactory.Create(RunToolTerminate);
var options = new ChatOptions
{
ToolMode = ChatToolMode.RequireSpecific(tool.Name),
Tools = [tool]
};

var response = await client.GetResponseAsync(chat, options);
var result = response.FindCalls<ToolResult>(tool).FirstOrDefault();

Assert.NotNull(result);
Assert.NotNull(result.Call);
Assert.Equal(tool.Name, result.Call.Name);
Assert.NotNull(result.Outcome);
Assert.Null(result.Outcome.Exception);
}

[SecretsFact("OPENAI_API_KEY")]
public async Task RunToolExceptionOutcome()
{
var chat = new Chat()
{
{ "system", "You make up a tool run by making up a name, description and content based on whatever the user says." },
{ "user", "I want to create an order for a dozen eggs" },
};

var client = new OpenAIChatClient(Configuration["OPENAI_API_KEY"]!, "gpt-4.1",
OpenAI.OpenAIClientOptions.WriteTo(output))
.AsBuilder()
.UseFunctionInvocation()
.Build();

var tool = ToolFactory.Create(RunToolThrows);
var options = new ChatOptions
{
ToolMode = ChatToolMode.RequireSpecific(tool.Name),
Tools = [tool]
};

var response = await client.GetResponseAsync(chat, options);
var result = response.FindCalls(tool).FirstOrDefault();

Assert.NotNull(result);
Assert.NotNull(result.Call);
Assert.Equal(tool.Name, result.Call.Name);
Assert.NotNull(result.Outcome);
Assert.NotNull(result.Outcome.Exception);
}

[Description("Runs a tool to provide a result based on user input.")]
ToolResult RunTool(
[Description("The name")] string name,
[Description("The description")] string description,
[Description("The content")] string content)
{
// Simulate running a tool and returning a result
return new ToolResult(name, description, content);
}

[Description("Runs a tool to provide a result based on user input.")]
ToolResult RunToolTerminate(
[Description("The name")] string name,
[Description("The description")] string description,
[Description("The content")] string content)
{
FunctionInvokingChatClient.CurrentContext?.Terminate = true;
// Simulate running a tool and returning a result
return new ToolResult(name, description, content);
}

[Description("Runs a tool to provide a result based on user input.")]
ToolResult RunToolThrows(
[Description("The name")] string name,
[Description("The description")] string description,
[Description("The content")] string content)
{
FunctionInvokingChatClient.CurrentContext?.Terminate = true;
throw new ArgumentException("BOOM");
}
}
105 changes: 105 additions & 0 deletions src/AI/ToolExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
using System.Text.Json;
using Microsoft.Extensions.AI;

namespace Devlooped.Extensions.AI;

/// <summary>
/// Represents a tool call made by the AI, including the function call content and the result of the function execution.
/// </summary>
public record ToolCall(FunctionCallContent Call, FunctionResultContent Outcome);

/// <summary>
/// Represents a tool call made by the AI, including the function call content, the result of the function execution,
/// and the deserialized result of type <typeparamref name="TResult"/>.
/// </summary>
public record ToolCall<TResult>(FunctionCallContent Call, FunctionResultContent Outcome, TResult Result);

/// <summary>
/// Extensions for inspecting chat messages and responses for tool
/// usage and processing responses.
/// </summary>
public static class ToolExtensions
{
/// <summary>
/// Looks for calls to a tool and their outcome.
/// </summary>
public static IEnumerable<ToolCall> FindCalls(this ChatResponse response, AIFunction tool)
=> FindCalls(response.Messages, tool.Name);

/// <summary>
/// Looks for calls to a tool and their outcome.
/// </summary>
public static IEnumerable<ToolCall> FindCalls(this IEnumerable<ChatMessage> messages, AIFunction tool)
=> FindCalls(messages, tool.Name);

/// <summary>
/// Looks for calls to a tool and their outcome.
/// </summary>
public static IEnumerable<ToolCall> FindCalls(this IEnumerable<ChatMessage> messages, string tool)
{
var calls = messages
.Where(x => x.Role == ChatRole.Assistant)
.SelectMany(x => x.Contents)
.OfType<FunctionCallContent>()
.Where(x => x.Name == tool)
.ToDictionary(x => x.CallId);

var results = messages
.Where(x => x.Role == ChatRole.Tool)
.SelectMany(x => x.Contents)
.OfType<FunctionResultContent>()
.Where(x => calls.TryGetValue(x.CallId, out var call) && call.Name == tool)
.Select(x => new ToolCall(calls[x.CallId], x));

return results;
}

/// <summary>
/// Looks for a user prompt in the chat response messages.
/// </summary>
/// <remarks>
/// In order for this to work, the <see cref="AIFunctionFactory"/> must have been invoked using
/// the <see cref="ToolJsonOptions.Default"/> or with a <see cref="JsonSerializerOptions"/> configured
/// with <see cref="TypeInjectingResolverExtensions.WithTypeInjection(JsonSerializerOptions)"/> so
/// that the tool result type can be properly inspected.
/// </remarks>
public static IEnumerable<ToolCall<TResult>> FindCalls<TResult>(this ChatResponse response, AIFunction tool)
=> FindCalls<TResult>(response.Messages, tool.Name);

/// <summary>
/// Looks for a user prompt in the chat response messages.
/// </summary>
/// <remarks>
/// In order for this to work, the <see cref="AIFunctionFactory"/> must have been invoked using
/// the <see cref="ToolJsonOptions.Default"/> or with a <see cref="JsonSerializerOptions"/> configured
/// with <see cref="TypeInjectingResolverExtensions.WithTypeInjection(JsonSerializerOptions)"/> so
/// that the tool result type can be properly inspected.
/// </remarks>
public static IEnumerable<ToolCall<TResult>> FindCalls<TResult>(this IEnumerable<ChatMessage> messages, AIFunction tool)
=> FindCalls<TResult>(messages, tool.Name);

/// <summary>
/// Looks for a user prompt in the chat response messages.
/// </summary>
/// <remarks>
/// In order for this to work, the <see cref="AIFunctionFactory"/> must have been invoked using
/// the <see cref="ToolJsonOptions.Default"/> or with a <see cref="JsonSerializerOptions"/> configured
/// with <see cref="TypeInjectingResolverExtensions.WithTypeInjection(JsonSerializerOptions)"/> so
/// that the tool result type can be properly inspected.
/// </remarks>
public static IEnumerable<ToolCall<TResult>> FindCalls<TResult>(this IEnumerable<ChatMessage> messages, string tool)
{
var calls = FindCalls(messages, tool)
.Where(x => x.Outcome.Result is JsonElement element &&
element.ValueKind == JsonValueKind.Object &&
element.TryGetProperty("$type", out var type) &&
type.GetString() == typeof(TResult).FullName)
.Select(x => new ToolCall<TResult>(
Call: x.Call,
Outcome: x.Outcome,
Result: JsonSerializer.Deserialize<TResult>((JsonElement)x.Outcome.Result!, ToolJsonOptions.Default) ??
throw new InvalidOperationException($"Failed to deserialize result for tool '{tool}' to {typeof(TResult).FullName}.")));

return calls;
}
}
19 changes: 19 additions & 0 deletions src/AI/ToolFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using Microsoft.Extensions.AI;

namespace Devlooped.Extensions.AI;

/// <summary>
/// Creates tools for function calling that can leverage the <see cref="ToolExtensions"/>
/// extension methods for locating invocations and their results.
/// </summary>
public static class ToolFactory
{
/// <summary>
/// Invokes <see cref="AIFunctionFactory.Create(Delegate, string?, string?, System.Text.Json.JsonSerializerOptions?)"/>
/// using the method name following the naming convention and serialization options from <see cref="ToolJsonOptions.Default"/>.
/// </summary>
public static AIFunction Create(Delegate method)
=> AIFunctionFactory.Create(method,
ToolJsonOptions.Default.PropertyNamingPolicy!.ConvertName(method.Method.Name),
serializerOptions: ToolJsonOptions.Default);
}
34 changes: 34 additions & 0 deletions src/AI/ToolJsonOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
using System.Diagnostics;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;

namespace Devlooped.Extensions.AI;

/// <summary>
/// Provides a <see cref="JsonSerializerOptions"/> optimized for use with
/// function calling and tools.
/// </summary>
public static class ToolJsonOptions
{
static ToolJsonOptions() => Default.MakeReadOnly();

/// <summary>
/// Default <see cref="JsonSerializerOptions"/> for function calling and tools.
/// </summary>
public static JsonSerializerOptions Default { get; } = new(JsonSerializerDefaults.Web)
{
Converters =
{
new AdditionalPropertiesDictionaryConverter(),
new JsonStringEnumConverter(),
},
DefaultIgnoreCondition =
JsonIgnoreCondition.WhenWritingDefault |
JsonIgnoreCondition.WhenWritingNull,
Encoder = System.Text.Encodings.Web.JavaScriptEncoder.UnsafeRelaxedJsonEscaping,
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
WriteIndented = Debugger.IsAttached,
TypeInfoResolver = new TypeInjectingResolver(new DefaultJsonTypeInfoResolver())
};
}
Loading
Loading