Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use the same JsonSerializerOptions default in all locations. #5507

Merged
merged 1 commit into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Schema;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
Expand All @@ -21,13 +17,8 @@ namespace Microsoft.Extensions.AI;
/// <summary>
/// Provides extension methods on <see cref="IChatClient"/> that simplify working with structured output.
/// </summary>
public static partial class ChatClientStructuredOutputExtensions
public static class ChatClientStructuredOutputExtensions
{
private const string UsesReflectionJsonSerializerMessage =
"This method uses the reflection-based JsonSerializer which can break in trimmed or AOT applications.";

private static JsonSerializerOptions? _defaultJsonSerializerOptions;

/// <summary>Sends chat messages to the model, requesting a response matching the type <typeparamref name="T"/>.</summary>
/// <param name="chatClient">The <see cref="IChatClient"/>.</param>
/// <param name="chatMessages">The chat content to send.</param>
Expand All @@ -44,16 +35,14 @@ public static partial class ChatClientStructuredOutputExtensions
/// by the client, including any messages for roundtrips to the model as part of the implementation of this request, will be included.
/// </remarks>
/// <typeparam name="T">The type of structured output to request.</typeparam>
[RequiresUnreferencedCode(UsesReflectionJsonSerializerMessage)]
[RequiresDynamicCode(UsesReflectionJsonSerializerMessage)]
public static Task<ChatCompletion<T>> CompleteAsync<T>(
this IChatClient chatClient,
IList<ChatMessage> chatMessages,
ChatOptions? options = null,
bool? useNativeJsonSchema = null,
CancellationToken cancellationToken = default)
where T : class =>
CompleteAsync<T>(chatClient, chatMessages, DefaultJsonSerializerOptions, options, useNativeJsonSchema, cancellationToken);
CompleteAsync<T>(chatClient, chatMessages, JsonDefaults.Options, options, useNativeJsonSchema, cancellationToken);
stephentoub marked this conversation as resolved.
Show resolved Hide resolved

/// <summary>Sends a user chat text message to the model, requesting a response matching the type <typeparamref name="T"/>.</summary>
/// <param name="chatClient">The <see cref="IChatClient"/>.</param>
Expand All @@ -67,10 +56,6 @@ public static Task<ChatCompletion<T>> CompleteAsync<T>(
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>The response messages generated by the client.</returns>
/// <typeparam name="T">The type of structured output to request.</typeparam>
[RequiresDynamicCode("JSON serialization and deserialization might require types that cannot be statically analyzed and might need runtime code generation. "
+ "Use System.Text.Json source generation for native AOT applications.")]
[RequiresUnreferencedCode("JSON serialization and deserialization might require types that cannot be statically analyzed and might need runtime code generation. "
+ "Use System.Text.Json source generation for native AOT applications.")]
public static Task<ChatCompletion<T>> CompleteAsync<T>(
this IChatClient chatClient,
string chatMessage,
Expand Down Expand Up @@ -154,7 +139,7 @@ public static async Task<ChatCompletion<T>> CompleteAsync<T>(
});
schemaNode.Insert(0, "$schema", "https://json-schema.org/draft/2020-12/schema");
schemaNode.Add("additionalProperties", false);
var schema = JsonSerializer.Serialize(schemaNode, JsonNodeContext.Default.JsonNode);
var schema = JsonSerializer.Serialize(schemaNode, JsonDefaults.Options.GetTypeInfo(typeof(JsonNode)));

ChatMessage? promptAugmentation = null;
options = (options ?? new()).Clone();
Expand Down Expand Up @@ -201,28 +186,4 @@ public static async Task<ChatCompletion<T>> CompleteAsync<T>(
}
}
}

private static JsonSerializerOptions DefaultJsonSerializerOptions
{
[RequiresUnreferencedCode(UsesReflectionJsonSerializerMessage)]
[RequiresDynamicCode(UsesReflectionJsonSerializerMessage)]
get => _defaultJsonSerializerOptions ?? GetOrCreateDefaultJsonSerializerOptions();
}

[RequiresUnreferencedCode(UsesReflectionJsonSerializerMessage)]
[RequiresDynamicCode(UsesReflectionJsonSerializerMessage)]
private static JsonSerializerOptions GetOrCreateDefaultJsonSerializerOptions()
{
var options = new JsonSerializerOptions(JsonSerializerDefaults.General)
{
Converters = { new JsonStringEnumConverter() },
TypeInfoResolver = new DefaultJsonTypeInfoResolver(),
WriteIndented = true,
};
return Interlocked.CompareExchange(ref _defaultJsonSerializerOptions, options, null) ?? options;
}

[JsonSerializable(typeof(JsonNode))]
[JsonSourceGenerationOptions(WriteIndented = true)]
private sealed partial class JsonNodeContext : JsonSerializerContext;
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,12 @@ namespace Microsoft.Extensions.AI;
/// <summary>Provides factory methods for creating commonly-used implementations of <see cref="AIFunction"/>.</summary>
public static class AIFunctionFactory
{
internal const string UsesReflectionJsonSerializerMessage =
"This method uses the reflection-based JsonSerializer which can break in trimmed or AOT applications.";

/// <summary>Lazily-initialized default options instance.</summary>
private static AIFunctionFactoryCreateOptions? _defaultOptions;

/// <summary>Creates an <see cref="AIFunction"/> instance for a method, specified via a delegate.</summary>
/// <param name="method">The method to be represented via the created <see cref="AIFunction"/>.</param>
/// <returns>The created <see cref="AIFunction"/> for invoking <paramref name="method"/>.</returns>
[RequiresUnreferencedCode(UsesReflectionJsonSerializerMessage)]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A number of these overloads are now redundant, but I defer to @stephentoub when and when they should be removed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which are redundant / would you want to remove? Do you mean making the AIFunctionFactoryCreateOptions optional on the other overload?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd be ok if you wanted to consolidate them here. I'm more concerned about batching up breaking binary changes to the object model, as those are the things we'd expect nuget packages to be impacted by. AIFunctionFactory is less relevant there.

[RequiresDynamicCode(UsesReflectionJsonSerializerMessage)]
public static AIFunction Create(Delegate method) => Create(method, _defaultOptions ??= new());

/// <summary>Creates an <see cref="AIFunction"/> instance for a method, specified via a delegate.</summary>
Expand All @@ -52,8 +47,6 @@ public static AIFunction Create(Delegate method, AIFunctionFactoryCreateOptions
/// <param name="name">The name to use for the <see cref="AIFunction"/>.</param>
/// <param name="description">The description to use for the <see cref="AIFunction"/>.</param>
/// <returns>The created <see cref="AIFunction"/> for invoking <paramref name="method"/>.</returns>
[RequiresUnreferencedCode("Reflection is used to access types from the supplied Delegate.")]
[RequiresDynamicCode("Reflection is used to access types from the supplied Delegate.")]
public static AIFunction Create(Delegate method, string? name, string? description = null)
=> Create(method, (_defaultOptions ??= new()).SerializerOptions, name, description);

Expand All @@ -80,8 +73,6 @@ public static AIFunction Create(Delegate method, JsonSerializerOptions options,
/// This should be <see langword="null"/> if and only if <paramref name="method"/> is a static method.
/// </param>
/// <returns>The created <see cref="AIFunction"/> for invoking <paramref name="method"/>.</returns>
[RequiresUnreferencedCode("Reflection is used to access types from the supplied MethodInfo.")]
[RequiresDynamicCode("Reflection is used to access types from the supplied MethodInfo.")]
public static AIFunction Create(MethodInfo method, object? target = null)
=> Create(method, target, _defaultOptions ??= new());

Expand All @@ -107,8 +98,8 @@ private sealed class ReflectionAIFunction : AIFunction
{
private readonly MethodInfo _method;
private readonly object? _target;
private readonly Func<IReadOnlyDictionary<string, object?>, AIFunctionContext?, object?>[] _parameterMarshalers;
private readonly Func<object?, ValueTask<object?>> _returnMarshaler;
private readonly Func<IReadOnlyDictionary<string, object?>, AIFunctionContext?, object?>[] _parameterMarshallers;
private readonly Func<object?, ValueTask<object?>> _returnMarshaller;
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
private readonly JsonTypeInfo? _returnTypeInfo;
private readonly bool _needsAIFunctionContext;

Expand Down Expand Up @@ -185,11 +176,11 @@ static bool IsAsyncMethod(MethodInfo method)

// Get marshaling delegates for parameters and build up the parameter metadata.
var parameters = method.GetParameters();
_parameterMarshalers = new Func<IReadOnlyDictionary<string, object?>, AIFunctionContext?, object?>[parameters.Length];
_parameterMarshallers = new Func<IReadOnlyDictionary<string, object?>, AIFunctionContext?, object?>[parameters.Length];
bool sawAIContextParameter = false;
for (int i = 0; i < parameters.Length; i++)
{
if (GetParameterMarshaler(options.SerializerOptions, parameters[i], ref sawAIContextParameter, out _parameterMarshalers[i]) is AIFunctionParameterMetadata parameterView)
if (GetParameterMarshaller(options.SerializerOptions, parameters[i], ref sawAIContextParameter, out _parameterMarshallers[i]) is AIFunctionParameterMetadata parameterView)
{
parameterMetadata?.Add(parameterView);
}
Expand All @@ -198,7 +189,7 @@ static bool IsAsyncMethod(MethodInfo method)
_needsAIFunctionContext = sawAIContextParameter;

// Get the return type and a marshaling func for the return value.
Type returnType = GetReturnMarshaler(method, out _returnMarshaler);
Type returnType = GetReturnMarshaller(method, out _returnMarshaller);
_returnTypeInfo = returnType != typeof(void) ? options.SerializerOptions.GetTypeInfo(returnType) : null;

Metadata = new AIFunctionMetadata(functionName)
Expand All @@ -224,8 +215,8 @@ static bool IsAsyncMethod(MethodInfo method)
IEnumerable<KeyValuePair<string, object?>>? arguments,
CancellationToken cancellationToken)
{
var paramMarshalers = _parameterMarshalers;
object?[] args = paramMarshalers.Length != 0 ? new object?[paramMarshalers.Length] : [];
var paramMarshallers = _parameterMarshallers;
object?[] args = paramMarshallers.Length != 0 ? new object?[paramMarshallers.Length] : [];

IReadOnlyDictionary<string, object?> argDict =
arguments is null || args.Length == 0 ? EmptyReadOnlyDictionary<string, object?>.Instance :
Expand All @@ -242,10 +233,10 @@ static bool IsAsyncMethod(MethodInfo method)

for (int i = 0; i < args.Length; i++)
{
args[i] = paramMarshalers[i](argDict, context);
args[i] = paramMarshallers[i](argDict, context);
}

object? result = await _returnMarshaler(ReflectionInvoke(_method, _target, args)).ConfigureAwait(false);
object? result = await _returnMarshaller(ReflectionInvoke(_method, _target, args)).ConfigureAwait(false);

switch (_returnTypeInfo)
{
Expand All @@ -271,11 +262,11 @@ static bool IsAsyncMethod(MethodInfo method)
/// <summary>
/// Gets a delegate for handling the marshaling of a parameter.
/// </summary>
private static AIFunctionParameterMetadata? GetParameterMarshaler(
private static AIFunctionParameterMetadata? GetParameterMarshaller(
JsonSerializerOptions options,
ParameterInfo parameter,
ref bool sawAIFunctionContext,
out Func<IReadOnlyDictionary<string, object?>, AIFunctionContext?, object?> marshaler)
out Func<IReadOnlyDictionary<string, object?>, AIFunctionContext?, object?> marshaller)
{
if (string.IsNullOrWhiteSpace(parameter.Name))
{
Expand All @@ -292,20 +283,20 @@ static bool IsAsyncMethod(MethodInfo method)

sawAIFunctionContext = true;

marshaler = static (_, ctx) =>
marshaller = static (_, ctx) =>
{
Debug.Assert(ctx is not null, "Expected a non-null context object.");
return ctx;
};
return null;
}

// Resolve the contract used to marshall the value from JSON -- can throw if not supported or not found.
// Resolve the contract used to marshal the value from JSON -- can throw if not supported or not found.
Type parameterType = parameter.ParameterType;
JsonTypeInfo typeInfo = options.GetTypeInfo(parameterType);

// Create a marshaler that simply looks up the parameter by name in the arguments dictionary.
marshaler = (IReadOnlyDictionary<string, object?> arguments, AIFunctionContext? _) =>
// Create a marshaller that simply looks up the parameter by name in the arguments dictionary.
marshaller = (IReadOnlyDictionary<string, object?> arguments, AIFunctionContext? _) =>
{
// If the parameter has an argument specified in the dictionary, return that argument.
if (arguments.TryGetValue(parameter.Name, out object? value))
Expand Down Expand Up @@ -368,15 +359,15 @@ static bool IsAsyncMethod(MethodInfo method)
/// <summary>
/// Gets a delegate for handling the result value of a method, converting it into the <see cref="Task{FunctionResult}"/> to return from the invocation.
/// </summary>
private static Type GetReturnMarshaler(MethodInfo method, out Func<object?, ValueTask<object?>> marshaler)
private static Type GetReturnMarshaller(MethodInfo method, out Func<object?, ValueTask<object?>> marshaller)
{
// Handle each known return type for the method
Type returnType = method.ReturnType;

// Task
if (returnType == typeof(Task))
{
marshaler = async static result =>
marshaller = async static result =>
{
await ((Task)ThrowIfNullResult(result)).ConfigureAwait(false);
return null;
Expand All @@ -387,7 +378,7 @@ private static Type GetReturnMarshaler(MethodInfo method, out Func<object?, Valu
// ValueTask
if (returnType == typeof(ValueTask))
{
marshaler = async static result =>
marshaller = async static result =>
{
await ((ValueTask)ThrowIfNullResult(result)).ConfigureAwait(false);
return null;
Expand All @@ -401,7 +392,7 @@ private static Type GetReturnMarshaler(MethodInfo method, out Func<object?, Valu
if (returnType.GetGenericTypeDefinition() == typeof(Task<>))
{
MethodInfo taskResultGetter = GetMethodFromGenericMethodDefinition(returnType, _taskGetResult);
marshaler = async result =>
marshaller = async result =>
{
await ((Task)ThrowIfNullResult(result)).ConfigureAwait(false);
return ReflectionInvoke(taskResultGetter, result, null);
Expand All @@ -414,7 +405,7 @@ private static Type GetReturnMarshaler(MethodInfo method, out Func<object?, Valu
{
MethodInfo valueTaskAsTask = GetMethodFromGenericMethodDefinition(returnType, _valueTaskAsTask);
MethodInfo asTaskResultGetter = GetMethodFromGenericMethodDefinition(valueTaskAsTask.ReturnType, _taskGetResult);
marshaler = async result =>
marshaller = async result =>
{
var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(result), null)!;
await task.ConfigureAwait(false);
Expand All @@ -425,7 +416,7 @@ private static Type GetReturnMarshaler(MethodInfo method, out Func<object?, Valu
}

// For everything else, just use the result as-is.
marshaler = result => new ValueTask<object?>(result);
marshaller = result => new ValueTask<object?>(result);
return returnType;

// Throws an exception if a result is found to be null unexpectedly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Text.Json;
using Microsoft.Shared.Diagnostics;
Expand All @@ -19,10 +18,8 @@ public sealed class AIFunctionFactoryCreateOptions
/// <summary>
/// Initializes a new instance of the <see cref="AIFunctionFactoryCreateOptions"/> class with default serializer options.
/// </summary>
[RequiresUnreferencedCode(AIFunctionFactory.UsesReflectionJsonSerializerMessage)]
[RequiresDynamicCode(AIFunctionFactory.UsesReflectionJsonSerializerMessage)]
public AIFunctionFactoryCreateOptions()
: this(JsonSerializerOptions.Default)
: this(JsonDefaults.Options)
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
{
}

Expand Down
13 changes: 10 additions & 3 deletions src/Libraries/Microsoft.Extensions.AI/JsonDefaults.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;

Expand All @@ -30,9 +31,10 @@ private static JsonSerializerOptions CreateDefaultOptions()
// Keep in sync with the JsonSourceGenerationOptions on JsonContext below.
var options = new JsonSerializerOptions(JsonSerializerDefaults.Web)
{
WriteIndented = true,
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
TypeInfoResolver = new DefaultJsonTypeInfoResolver(),
Converters = { new JsonStringEnumConverter() },
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
WriteIndented = true,
};

options.MakeReadOnly();
Expand All @@ -45,7 +47,10 @@ private static JsonSerializerOptions CreateDefaultOptions()
}

// Keep in sync with CreateDefaultOptions above.
[JsonSourceGenerationOptions(JsonSerializerDefaults.Web, WriteIndented = true, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull)]
[JsonSourceGenerationOptions(JsonSerializerDefaults.Web,
UseStringEnumConverter = true,
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
WriteIndented = true)]
[JsonSerializable(typeof(IList<ChatMessage>))]
[JsonSerializable(typeof(ChatOptions))]
[JsonSerializable(typeof(EmbeddingGenerationOptions))]
Expand All @@ -57,7 +62,9 @@ private static JsonSerializerOptions CreateDefaultOptions()
[JsonSerializable(typeof(Dictionary<string, object>))]
[JsonSerializable(typeof(IDictionary<int, int>))]
[JsonSerializable(typeof(IDictionary<string, object?>))]
[JsonSerializable(typeof(JsonDocument))]
[JsonSerializable(typeof(JsonElement))]
[JsonSerializable(typeof(JsonNode))]
[JsonSerializable(typeof(IEnumerable<string>))]
[JsonSerializable(typeof(string))]
[JsonSerializable(typeof(int))]
Expand Down
Loading