Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,30 +42,18 @@ public static partial class AIJsonUtilities
[UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026", Justification = "DefaultJsonTypeInfoResolver is only used when reflection-based serialization is enabled")]
private static JsonSerializerOptions CreateDefaultOptions()
{
// If reflection-based serialization is enabled by default, use it, as it's the most permissive in terms of what it can serialize,
// and we want to be flexible in terms of what can be put into the various collections in the object model.
// Otherwise, use the source-generated options to enable trimming and Native AOT.
JsonSerializerOptions options;
// Copy configuration from the source generated context.
JsonSerializerOptions options = new(JsonContext.Default.Options)
{
Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping,
};

if (JsonSerializer.IsReflectionEnabledByDefault)
{
// Keep in sync with the JsonSourceGenerationOptions attribute on JsonContext below.
options = new(JsonSerializerDefaults.Web)
{
TypeInfoResolver = new DefaultJsonTypeInfoResolver(),
Converters = { new JsonStringEnumConverter() },
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping,
WriteIndented = true,
};
}
else
{
options = new(JsonContext.Default.Options)
{
// Compile-time encoder setting not yet available
Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping,
};
// If reflection-based serialization is enabled by default, use it as a fallback for all other types.
// Also turn on string-based enum serialization for all unknown enums.
options.TypeInfoResolverChain.Add(new DefaultJsonTypeInfoResolver());
options.Converters.Add(new JsonStringEnumConverter());
}

options.MakeReadOnly();
Expand All @@ -83,6 +71,8 @@ private static JsonSerializerOptions CreateDefaultOptions()
[JsonSerializable(typeof(SpeechToTextResponseUpdate))]
[JsonSerializable(typeof(IReadOnlyList<SpeechToTextResponseUpdate>))]
[JsonSerializable(typeof(IList<ChatMessage>))]
[JsonSerializable(typeof(IEnumerable<ChatMessage>))]
[JsonSerializable(typeof(ChatMessage[]))]
[JsonSerializable(typeof(ChatOptions))]
[JsonSerializable(typeof(EmbeddingGenerationOptions))]
[JsonSerializable(typeof(ChatClientMetadata))]
Expand All @@ -95,14 +85,24 @@ private static JsonSerializerOptions CreateDefaultOptions()
[JsonSerializable(typeof(JsonDocument))]
[JsonSerializable(typeof(JsonElement))]
[JsonSerializable(typeof(JsonNode))]
[JsonSerializable(typeof(JsonObject))]
[JsonSerializable(typeof(JsonValue))]
[JsonSerializable(typeof(JsonArray))]
[JsonSerializable(typeof(IEnumerable<string>))]
[JsonSerializable(typeof(char))]
[JsonSerializable(typeof(string))]
[JsonSerializable(typeof(int))]
[JsonSerializable(typeof(short))]
[JsonSerializable(typeof(long))]
[JsonSerializable(typeof(uint))]
[JsonSerializable(typeof(ushort))]
[JsonSerializable(typeof(ulong))]
[JsonSerializable(typeof(float))]
[JsonSerializable(typeof(double))]
[JsonSerializable(typeof(decimal))]
[JsonSerializable(typeof(bool))]
[JsonSerializable(typeof(TimeSpan))]
[JsonSerializable(typeof(DateTime))]
[JsonSerializable(typeof(DateTimeOffset))]
[JsonSerializable(typeof(Embedding))]
[JsonSerializable(typeof(Embedding<byte>))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -487,9 +487,7 @@ static bool IsAsyncMethod(MethodInfo method)
Throw.ArgumentException(nameof(parameter), "Parameter is missing a name.");
}

// Resolve the contract used to marshal the value from JSON -- can throw if not supported or not found.
Type parameterType = parameter.ParameterType;
JsonTypeInfo typeInfo = serializerOptions.GetTypeInfo(parameterType);

// For CancellationToken parameters, we always bind to the token passed directly to InvokeAsync.
if (parameterType == typeof(CancellationToken))
Expand Down Expand Up @@ -530,6 +528,8 @@ static bool IsAsyncMethod(MethodInfo method)
}

// For all other parameters, create a marshaller that tries to extract the value from the arguments dictionary.
// Resolve the contract used to marshal the value from JSON -- can throw if not supported or not found.
JsonTypeInfo typeInfo = serializerOptions.GetTypeInfo(parameterType);
return (arguments, _) =>
{
// If the parameter has an argument specified in the dictionary, return that argument.
Expand Down Expand Up @@ -636,14 +636,22 @@ static bool IsAsyncMethod(MethodInfo method)
if (returnType.GetGenericTypeDefinition() == typeof(Task<>))
{
MethodInfo taskResultGetter = GetMethodFromGenericMethodDefinition(returnType, _taskGetResult);
if (marshalResult is not null)
{
return async (taskObj, cancellationToken) =>
{
await ((Task)ThrowIfNullResult(taskObj)).ConfigureAwait(false);
object? result = ReflectionInvoke(taskResultGetter, taskObj, null);
return await marshalResult(result, taskResultGetter.ReturnType, cancellationToken).ConfigureAwait(false);
};
}

returnTypeInfo = serializerOptions.GetTypeInfo(taskResultGetter.ReturnType);
return async (taskObj, cancellationToken) =>
{
await ((Task)ThrowIfNullResult(taskObj)).ConfigureAwait(false);
object? result = ReflectionInvoke(taskResultGetter, taskObj, null);
return marshalResult is not null ?
await marshalResult(result, returnTypeInfo.Type, cancellationToken).ConfigureAwait(false) :
await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false);
return await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false);
};
}

Expand All @@ -652,24 +660,37 @@ await marshalResult(result, returnTypeInfo.Type, cancellationToken).ConfigureAwa
{
MethodInfo valueTaskAsTask = GetMethodFromGenericMethodDefinition(returnType, _valueTaskAsTask);
MethodInfo asTaskResultGetter = GetMethodFromGenericMethodDefinition(valueTaskAsTask.ReturnType, _taskGetResult);

if (marshalResult is not null)
{
return async (taskObj, cancellationToken) =>
{
var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(taskObj), null)!;
await task.ConfigureAwait(false);
object? result = ReflectionInvoke(asTaskResultGetter, task, null);
return await marshalResult(result, asTaskResultGetter.ReturnType, cancellationToken).ConfigureAwait(false);
};
}

returnTypeInfo = serializerOptions.GetTypeInfo(asTaskResultGetter.ReturnType);
return async (taskObj, cancellationToken) =>
{
var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(taskObj), null)!;
await task.ConfigureAwait(false);
object? result = ReflectionInvoke(asTaskResultGetter, task, null);
return marshalResult is not null ?
await marshalResult(result, returnTypeInfo.Type, cancellationToken).ConfigureAwait(false) :
await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false);
return await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false);
};
}
}

// For everything else, just serialize the result as-is.
if (marshalResult is not null)
{
return (result, cancellationToken) => marshalResult(result, returnType, cancellationToken);
}

returnTypeInfo = serializerOptions.GetTypeInfo(returnType);
return marshalResult is not null ?
(result, cancellationToken) => marshalResult(result, returnTypeInfo.Type, cancellationToken) :
(result, cancellationToken) => SerializeResultAsync(result, returnTypeInfo, cancellationToken);
return (result, cancellationToken) => SerializeResultAsync(result, returnTypeInfo, cancellationToken);

static async ValueTask<object?> SerializeResultAsync(object? result, JsonTypeInfo returnTypeInfo, CancellationToken cancellationToken)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ public static void EqualFunctionCallResults(object? expected, object? actual, Js

private static void AreJsonEquivalentValues(object? expected, object? actual, JsonSerializerOptions? options, string? propertyName = null)
{
options ??= JsonSerializerOptions.Default;
options ??= AIJsonUtilities.DefaultOptions;
JsonElement expectedElement = NormalizeToElement(expected, options);
JsonElement actualElement = NormalizeToElement(actual, options);
if (!JsonNode.DeepEquals(
JsonSerializer.SerializeToNode(expectedElement),
JsonSerializer.SerializeToNode(actualElement)))
JsonSerializer.SerializeToNode(expectedElement, AIJsonUtilities.DefaultOptions),
JsonSerializer.SerializeToNode(actualElement, AIJsonUtilities.DefaultOptions)))
{
string message = propertyName is null
? $"Function result does not match expected JSON.\r\nExpected: {expectedElement.GetRawText()}\r\nActual: {actualElement.GetRawText()}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public void Serialization_JsonRoundtrips()
public void Serialization_ForJsonSchemaRoundtrips()
{
string json = JsonSerializer.Serialize(
ChatResponseFormat.ForJsonSchema(JsonSerializer.Deserialize<JsonElement>("[1,2,3]"), "name", "description"),
ChatResponseFormat.ForJsonSchema(JsonSerializer.Deserialize<JsonElement>("[1,2,3]", AIJsonUtilities.DefaultOptions), "name", "description"),
TestJsonSerializerContext.Default.ChatResponseFormat);
Assert.Equal("""{"$type":"json","schema":[1,2,3],"schemaName":"name","schemaDescription":"description"}""", json);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public void JsonSerialization_ShouldSerializeAndDeserializeCorrectly()
ErrorCode = "ERR001",
Details = "Something went wrong"
};
var options = new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase };
JsonSerializerOptions options = new(AIJsonUtilities.DefaultOptions) { PropertyNamingPolicy = JsonNamingPolicy.CamelCase };

// Act
var json = JsonSerializer.Serialize(errorContent, options);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ public static void CreateFromParsedArguments_ObjectJsonInput_ReturnsElementArgum
"""{"Key1":{}, "Key2":null, "Key3" : [], "Key4" : 42, "Key5" : true }""",
"callId",
"functionName",
argumentParser: static json => JsonSerializer.Deserialize<Dictionary<string, object?>>(json));
argumentParser: static json => JsonSerializer.Deserialize<Dictionary<string, object?>>(json, AIJsonUtilities.DefaultOptions));

Assert.NotNull(content);
Assert.Null(content.Exception);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
</PropertyGroup>

<PropertyGroup Condition="$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net8.0'))">
<JsonSerializerIsReflectionEnabledByDefault>false</JsonSerializerIsReflectionEnabledByDefault>
</PropertyGroup>

<PropertyGroup>
<InjectDiagnosticAttributesOnLegacy>true</InjectDiagnosticAttributesOnLegacy>
<InjectCompilerFeatureRequiredOnLegacy>true</InjectCompilerFeatureRequiredOnLegacy>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

namespace Microsoft.Extensions.AI;

public static class AIJsonUtilitiesTests
public static partial class AIJsonUtilitiesTests
{
[Fact]
public static void DefaultOptions_HasExpectedConfiguration()
Expand Down Expand Up @@ -53,6 +53,18 @@ public static void DefaultOptions_UsesExpectedEscaping(string input, string expe
Assert.Equal($@"""{expectedJsonString}""", json);
}

[Fact]
public static void DefaultOptions_UsesReflectionWhenDefault()
{
// Reflection is only turned off in .NET Core test environments.
bool isDotnetCore = Type.GetType("System.Half") is not null;
var options = AIJsonUtilities.DefaultOptions;
Type anonType = new { Name = 42 }.GetType();

Assert.Equal(!isDotnetCore, JsonSerializer.IsReflectionEnabledByDefault);
Assert.Equal(JsonSerializer.IsReflectionEnabledByDefault, AIJsonUtilities.DefaultOptions.TryGetTypeInfo(anonType, out _));
}

[Theory]
[InlineData(false)]
[InlineData(true)]
Expand Down Expand Up @@ -145,7 +157,7 @@ public static void CreateJsonSchema_DefaultParameters_GeneratesExpectedJsonSchem
}
""").RootElement;

JsonElement actual = AIJsonUtilities.CreateJsonSchema(typeof(MyPoco), serializerOptions: JsonSerializerOptions.Default);
JsonElement actual = AIJsonUtilities.CreateJsonSchema(typeof(MyPoco), serializerOptions: JsonContext.Default.Options);

Assert.True(DeepEquals(expected, actual));
}
Expand Down Expand Up @@ -189,7 +201,7 @@ public static void CreateJsonSchema_OverriddenParameters_GeneratesExpectedJsonSc
description: "alternative description",
hasDefaultValue: true,
defaultValue: null,
serializerOptions: JsonSerializerOptions.Default,
serializerOptions: JsonContext.Default.Options,
inferenceOptions: inferenceOptions);

Assert.True(DeepEquals(expected, actual));
Expand Down Expand Up @@ -235,7 +247,7 @@ public static void CreateJsonSchema_UserDefinedTransformer()
}
};

JsonElement actual = AIJsonUtilities.CreateJsonSchema(typeof(MyPoco), serializerOptions: JsonSerializerOptions.Default, inferenceOptions: inferenceOptions);
JsonElement actual = AIJsonUtilities.CreateJsonSchema(typeof(MyPoco), serializerOptions: JsonContext.Default.Options, inferenceOptions: inferenceOptions);

Assert.True(DeepEquals(expected, actual));
}
Expand Down Expand Up @@ -263,7 +275,7 @@ public static void CreateJsonSchema_FiltersDisallowedKeywords()
}
""").RootElement;

JsonElement actual = AIJsonUtilities.CreateJsonSchema(typeof(PocoWithTypesWithOpenAIUnsupportedKeywords), serializerOptions: JsonSerializerOptions.Default);
JsonElement actual = AIJsonUtilities.CreateJsonSchema(typeof(PocoWithTypesWithOpenAIUnsupportedKeywords), serializerOptions: JsonContext.Default.Options);

Assert.True(DeepEquals(expected, actual));
}
Expand All @@ -283,7 +295,7 @@ public class PocoWithTypesWithOpenAIUnsupportedKeywords
[Fact]
public static void CreateFunctionJsonSchema_ReturnsExpectedValue()
{
JsonSerializerOptions options = new(JsonSerializerOptions.Default);
JsonSerializerOptions options = new(AIJsonUtilities.DefaultOptions);
AIFunction func = AIFunctionFactory.Create((int x, int y) => x + y, serializerOptions: options);

Assert.NotNull(func.UnderlyingMethod);
Expand All @@ -295,7 +307,7 @@ public static void CreateFunctionJsonSchema_ReturnsExpectedValue()
[Fact]
public static void CreateFunctionJsonSchema_TreatsIntegralTypesAsInteger_EvenWithAllowReadingFromString()
{
JsonSerializerOptions options = new(JsonSerializerOptions.Default) { NumberHandling = JsonNumberHandling.AllowReadingFromString };
JsonSerializerOptions options = new(AIJsonUtilities.DefaultOptions) { NumberHandling = JsonNumberHandling.AllowReadingFromString };
AIFunction func = AIFunctionFactory.Create((int a, int? b, long c, short d, float e, double f, decimal g) => { }, serializerOptions: options);

JsonElement schemaParameters = func.JsonSchema.GetProperty("properties");
Expand Down Expand Up @@ -376,7 +388,11 @@ public static void CreateJsonSchema_ValidateWithTestData(ITestData testData)
[Fact]
public static void AddAIContentType_DerivedAIContent()
{
JsonSerializerOptions options = new();
JsonSerializerOptions options = new()
{
TypeInfoResolver = JsonTypeInfoResolver.Combine(AIJsonUtilities.DefaultOptions.TypeInfoResolver, JsonContext.Default),
};

options.AddAIContentType<DerivedAIContent>("derivativeContent");

AIContent c = new DerivedAIContent { DerivedValue = 42 };
Expand Down Expand Up @@ -465,7 +481,7 @@ public static void CreateFunctionJsonSchema_InvokesIncludeParameterCallbackForEv
{
names.Add(p.Name);
return p.Name is "first" or "fifth";
}
},
});

Assert.Equal(["first", "second", "third", "fifth"], names);
Expand All @@ -483,14 +499,19 @@ private class DerivedAIContent : AIContent
public int DerivedValue { get; set; }
}

[JsonSerializable(typeof(DerivedAIContent))]
[JsonSerializable(typeof(MyPoco))]
[JsonSerializable(typeof(PocoWithTypesWithOpenAIUnsupportedKeywords))]
private partial class JsonContext : JsonSerializerContext;

private static bool DeepEquals(JsonElement element1, JsonElement element2)
{
#if NET9_0_OR_GREATER
return JsonElement.DeepEquals(element1, element2);
#else
return JsonNode.DeepEquals(
JsonSerializer.SerializeToNode(element1),
JsonSerializer.SerializeToNode(element2));
JsonSerializer.SerializeToNode(element1, AIJsonUtilities.DefaultOptions),
JsonSerializer.SerializeToNode(element2, AIJsonUtilities.DefaultOptions));
#endif
}
}
Loading
Loading