diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.Create.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.Create.cs index a9d3ac3e3ee..c77e7dffb5b 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.Create.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.Create.cs @@ -289,24 +289,49 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext schemaExporterContext, Js objSchema.InsertAtStart(TypePropertyName, "string"); } - // Include the type keyword in nullable enum types - if (Nullable.GetUnderlyingType(ctx.TypeInfo.Type)?.IsEnum is true && objSchema.ContainsKey(EnumPropertyName) && !objSchema.ContainsKey(TypePropertyName)) - { - objSchema.InsertAtStart(TypePropertyName, new JsonArray { (JsonNode)"string", (JsonNode)"null" }); - } - // Some consumers of the JSON schema, including Ollama as of v0.3.13, don't understand // schemas with "type": [...], and only understand "type" being a single value. // In certain configurations STJ represents .NET numeric types as ["string", "number"], which will then lead to an error. - if (TypeIsIntegerWithStringNumberHandling(ctx, objSchema, out string? numericType)) + if (TypeIsIntegerWithStringNumberHandling(ctx, objSchema, out string? numericType, out bool isNullable)) { // We don't want to emit any array for "type". In this case we know it contains "integer" or "number", // so reduce the type to that alone, assuming it's the most specific type. // This makes schemas for Int32 (etc) work with Ollama. JsonObject obj = ConvertSchemaToObject(ref schema); - obj[TypePropertyName] = numericType; + if (isNullable) + { + // If the type is nullable, we still need use a type array + obj[TypePropertyName] = new JsonArray { (JsonNode)numericType, (JsonNode)"null" }; + } + else + { + obj[TypePropertyName] = (JsonNode)numericType; + } + _ = obj.Remove(PatternPropertyName); } + + if (Nullable.GetUnderlyingType(ctx.TypeInfo.Type) is Type nullableElement) + { + // Account for bug https://github.com/dotnet/runtime/issues/117493 + // To be removed once System.Text.Json v10 becomes the lowest supported version. + // null not inserted in the type keyword for root-level Nullable types. + if (objSchema.TryGetPropertyValue(TypePropertyName, out JsonNode? typeKeyWord) && + typeKeyWord?.GetValueKind() is JsonValueKind.String) + { + string typeValue = typeKeyWord.GetValue()!; + if (typeValue is not "null") + { + objSchema[TypePropertyName] = new JsonArray { (JsonNode)typeValue, (JsonNode)"null" }; + } + } + + // Include the type keyword in nullable enum types + if (nullableElement.IsEnum && objSchema.ContainsKey(EnumPropertyName) && !objSchema.ContainsKey(TypePropertyName)) + { + objSchema.InsertAtStart(TypePropertyName, new JsonArray { (JsonNode)"string", (JsonNode)"null" }); + } + } } if (ctx.Path.IsEmpty && hasDefaultValue) @@ -601,11 +626,12 @@ static JsonArray CreateJsonArray(object?[] values, JsonSerializerOptions seriali } } - private static bool TypeIsIntegerWithStringNumberHandling(AIJsonSchemaCreateContext ctx, JsonObject schema, [NotNullWhen(true)] out string? numericType) + private static bool TypeIsIntegerWithStringNumberHandling(AIJsonSchemaCreateContext ctx, JsonObject schema, [NotNullWhen(true)] out string? numericType, out bool isNullable) { numericType = null; + isNullable = false; - if (ctx.TypeInfo.NumberHandling is not JsonNumberHandling.Strict && schema["type"] is JsonArray { Count: 2 } typeArray) + if (ctx.TypeInfo.NumberHandling is not JsonNumberHandling.Strict && schema["type"] is JsonArray typeArray) { bool allowString = false; @@ -617,11 +643,23 @@ private static bool TypeIsIntegerWithStringNumberHandling(AIJsonSchemaCreateCont switch (type) { case "integer" or "number": + if (numericType is not null) + { + // Conflicting numeric type + return false; + } + numericType = type; break; case "string": allowString = true; break; + case "null": + isNullable = true; + break; + default: + // keyword is not valid in the context of numeric types. + return false; } } } @@ -665,7 +703,7 @@ private static JsonElement ParseJsonElement(ReadOnlySpan utf8Json) if (defaultValue is null || (defaultValue == DBNull.Value && parameterType != typeof(DBNull))) { - return parameterType.IsValueType + return parameterType.IsValueType && Nullable.GetUnderlyingType(parameterType) is null #if NET ? RuntimeHelpers.GetUninitializedObject(parameterType) #else diff --git a/src/Shared/JsonSchemaExporter/JsonSchemaExporter.ReflectionHelpers.cs b/src/Shared/JsonSchemaExporter/JsonSchemaExporter.ReflectionHelpers.cs index 481e5f75753..6d350dab026 100644 --- a/src/Shared/JsonSchemaExporter/JsonSchemaExporter.ReflectionHelpers.cs +++ b/src/Shared/JsonSchemaExporter/JsonSchemaExporter.ReflectionHelpers.cs @@ -31,8 +31,6 @@ private static class ReflectionHelpers public static bool IsBuiltInConverter(JsonConverter converter) => converter.GetType().Assembly == typeof(JsonConverter).Assembly; - public static bool CanBeNull(Type type) => !type.IsValueType || Nullable.GetUnderlyingType(type) is not null; - public static Type GetElementType(JsonTypeInfo typeInfo) { Debug.Assert(typeInfo.Kind is JsonTypeInfoKind.Enumerable or JsonTypeInfoKind.Dictionary, "TypeInfo must be of collection type"); diff --git a/src/Shared/JsonSchemaExporter/JsonSchemaExporter.cs b/src/Shared/JsonSchemaExporter/JsonSchemaExporter.cs index 2d8ffc5497c..d651ce6a727 100644 --- a/src/Shared/JsonSchemaExporter/JsonSchemaExporter.cs +++ b/src/Shared/JsonSchemaExporter/JsonSchemaExporter.cs @@ -452,20 +452,24 @@ JsonSchema CompleteSchema(ref GenerationState state, JsonSchema schema) bool IsNullableSchema(ref GenerationState state) { - // A schema is marked as nullable if either + // A schema is marked as nullable if either: // 1. We have a schema for a property where either the getter or setter are marked as nullable. - // 2. We have a schema for a reference type, unless we're explicitly treating null-oblivious types as non-nullable + // 2. We have a schema for a Nullable type. + // 3. We have a schema for a reference type, unless we're explicitly treating null-oblivious types as non-nullable. if (propertyInfo != null || parameterInfo != null) { return !isNonNullableType; } - else + + if (Nullable.GetUnderlyingType(typeInfo.Type) is not null) { - return ReflectionHelpers.CanBeNull(typeInfo.Type) && - !parentPolymorphicTypeIsNonNullable && - !state.ExporterOptions.TreatNullObliviousAsNonNullable; + return true; } + + return !typeInfo.Type.IsValueType && + !parentPolymorphicTypeIsNonNullable && + !state.ExporterOptions.TreatNullObliviousAsNonNullable; } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AssertExtensions.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AssertExtensions.cs index 72985108c6e..6361fe7817e 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AssertExtensions.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AssertExtensions.cs @@ -53,21 +53,29 @@ public static void EqualFunctionCallParameters( public static void EqualFunctionCallResults(object? expected, object? actual, JsonSerializerOptions? options = null) => AreJsonEquivalentValues(expected, actual, options); - private static void AreJsonEquivalentValues(object? expected, object? actual, JsonSerializerOptions? options, string? propertyName = null) + /// + /// Asserts that the two JSON values are equal. + /// + public static void EqualJsonValues(JsonElement expectedJson, JsonElement actualJson, string? propertyName = null) { - options ??= AIJsonUtilities.DefaultOptions; - JsonElement expectedElement = NormalizeToElement(expected, options); - JsonElement actualElement = NormalizeToElement(actual, options); if (!JsonNode.DeepEquals( - JsonSerializer.SerializeToNode(expectedElement, AIJsonUtilities.DefaultOptions), - JsonSerializer.SerializeToNode(actualElement, AIJsonUtilities.DefaultOptions))) + JsonSerializer.SerializeToNode(expectedJson, AIJsonUtilities.DefaultOptions), + JsonSerializer.SerializeToNode(actualJson, AIJsonUtilities.DefaultOptions))) { string message = propertyName is null - ? $"Function result does not match expected JSON.\r\nExpected: {expectedElement.GetRawText()}\r\nActual: {actualElement.GetRawText()}" - : $"Parameter '{propertyName}' does not match expected JSON.\r\nExpected: {expectedElement.GetRawText()}\r\nActual: {actualElement.GetRawText()}"; + ? $"JSON result does not match expected JSON.\r\nExpected: {expectedJson.GetRawText()}\r\nActual: {actualJson.GetRawText()}" + : $"Parameter '{propertyName}' does not match expected JSON.\r\nExpected: {expectedJson.GetRawText()}\r\nActual: {actualJson.GetRawText()}"; throw new XunitException(message); } + } + + private static void AreJsonEquivalentValues(object? expected, object? actual, JsonSerializerOptions? options, string? propertyName = null) + { + options ??= AIJsonUtilities.DefaultOptions; + JsonElement expectedElement = NormalizeToElement(expected, options); + JsonElement actualElement = NormalizeToElement(actual, options); + EqualJsonValues(expectedElement, actualElement, propertyName); static JsonElement NormalizeToElement(object? value, JsonSerializerOptions options) => value is JsonElement e ? e : JsonSerializer.SerializeToElement(value, options); diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs index 19b2fc8bb48..c2177486fea 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs @@ -354,13 +354,21 @@ public static void CreateFunctionJsonSchema_TreatsIntegralTypesAsInteger_EvenWit int i = 0; foreach (JsonProperty property in schemaParameters.EnumerateObject()) { - string numericType = Type.GetTypeCode(parameters[i].ParameterType) is TypeCode.Double or TypeCode.Single or TypeCode.Decimal - ? "number" - : "integer"; + bool isNullable = false; + Type type = parameters[i].ParameterType; + if (Nullable.GetUnderlyingType(type) is { } elementType) + { + type = elementType; + isNullable = true; + } + + string numericType = Type.GetTypeCode(type) is TypeCode.Double or TypeCode.Single or TypeCode.Decimal + ? "\"number\"" + : "\"integer\""; JsonElement expected = JsonDocument.Parse($$""" { - "type": "{{numericType}}" + "type": {{(isNullable ? $"[{numericType}, \"null\"]" : numericType)}} } """).RootElement; diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs index 8c0c7d057a6..69787dc868b 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.ComponentModel; +using System.Linq; using System.Reflection; using System.Text.Json; using System.Text.Json.Nodes; @@ -854,6 +855,71 @@ public async Task AIFunctionFactory_DefaultDefaultParameter() Assert.Contains("00000000-0000-0000-0000-000000000000,0", result?.ToString()); } + [Fact] + public async Task AIFunctionFactory_NullableParameters() + { + Assert.NotEqual(new StructWithDefaultCtor().Value, default(StructWithDefaultCtor).Value); + + AIFunction f = AIFunctionFactory.Create( + (int? limit = null, DateTime? from = null) => Enumerable.Repeat(from ?? default, limit ?? 4).Select(d => d.Year).ToArray(), + serializerOptions: JsonContext.Default.Options); + + JsonElement expectedSchema = JsonDocument.Parse(""" + { + "type": "object", + "properties": { + "limit": { + "type": ["integer", "null"], + "default": null + }, + "from": { + "type": ["string", "null"], + "format": "date-time", + "default": null + } + } + } + """).RootElement; + + AssertExtensions.EqualJsonValues(expectedSchema, f.JsonSchema); + + object? result = await f.InvokeAsync(); + Assert.Contains("[1,1,1,1]", result?.ToString()); + } + + [Fact] + public async Task AIFunctionFactory_NullableParameters_AllowReadingFromString() + { + JsonSerializerOptions options = new(JsonContext.Default.Options) { NumberHandling = JsonNumberHandling.AllowReadingFromString }; + Assert.NotEqual(new StructWithDefaultCtor().Value, default(StructWithDefaultCtor).Value); + + AIFunction f = AIFunctionFactory.Create( + (int? limit = null, DateTime? from = null) => Enumerable.Repeat(from ?? default, limit ?? 4).Select(d => d.Year).ToArray(), + serializerOptions: options); + + JsonElement expectedSchema = JsonDocument.Parse(""" + { + "type": "object", + "properties": { + "limit": { + "type": ["integer", "null"], + "default": null + }, + "from": { + "type": ["string", "null"], + "format": "date-time", + "default": null + } + } + } + """).RootElement; + + AssertExtensions.EqualJsonValues(expectedSchema, f.JsonSchema); + + object? result = await f.InvokeAsync(); + Assert.Contains("[1,1,1,1]", result?.ToString()); + } + [Fact] public void AIFunctionFactory_ReturnTypeWithDescriptionAttribute() { @@ -959,5 +1025,7 @@ private static AIFunctionFactoryOptions CreateKeyedServicesSupportOptions() => [JsonSerializable(typeof(Guid))] [JsonSerializable(typeof(StructWithDefaultCtor))] [JsonSerializable(typeof(B))] + [JsonSerializable(typeof(int?))] + [JsonSerializable(typeof(DateTime?))] private partial class JsonContext : JsonSerializerContext; } diff --git a/test/Shared/JsonSchemaExporter/TestData.cs b/test/Shared/JsonSchemaExporter/TestData.cs index 26902bfe0db..7c7cc7fc9a7 100644 --- a/test/Shared/JsonSchemaExporter/TestData.cs +++ b/test/Shared/JsonSchemaExporter/TestData.cs @@ -13,7 +13,9 @@ internal sealed record TestData( T? Value, [StringSyntax(StringSyntaxAttribute.Json)] string ExpectedJsonSchema, IEnumerable? AdditionalValues = null, - object? ExporterOptions = null, +#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL + System.Text.Json.Schema.JsonSchemaExporterOptions? ExporterOptions = null, +#endif JsonSerializerOptions? Options = null, bool WritesNumbersAsStrings = false) : ITestData @@ -22,7 +24,9 @@ internal sealed record TestData( public Type Type => typeof(T); object? ITestData.Value => Value; +#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL object? ITestData.ExporterOptions => ExporterOptions; +#endif JsonNode ITestData.ExpectedJsonSchema { get; } = JsonNode.Parse(ExpectedJsonSchema, documentOptions: _schemaParseOptions) ?? throw new ArgumentNullException("schema must not be null"); @@ -32,7 +36,7 @@ IEnumerable ITestData.GetTestDataForAllValues() yield return this; if (default(T) is null && -#if NET9_0_OR_GREATER +#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL ExporterOptions is System.Text.Json.Schema.JsonSchemaExporterOptions { TreatNullObliviousAsNonNullable: false } && #endif Value is not null) @@ -58,7 +62,9 @@ public interface ITestData JsonNode ExpectedJsonSchema { get; } +#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL object? ExporterOptions { get; } +#endif JsonSerializerOptions? Options { get; } diff --git a/test/Shared/JsonSchemaExporter/TestTypes.cs b/test/Shared/JsonSchemaExporter/TestTypes.cs index 7cfd0ce45be..794e58fa2b8 100644 --- a/test/Shared/JsonSchemaExporter/TestTypes.cs +++ b/test/Shared/JsonSchemaExporter/TestTypes.cs @@ -9,12 +9,9 @@ using System.ComponentModel.DataAnnotations; using System.Diagnostics.CodeAnalysis; using System.Linq; -#if NET9_0_OR_GREATER -using System.Reflection; -#endif using System.Text.Json; using System.Text.Json.Nodes; -#if NET9_0_OR_GREATER +#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL using System.Text.Json.Schema; #endif using System.Text.Json.Serialization; @@ -135,6 +132,21 @@ public static IEnumerable GetTestDataCore() } """); +#if !NET9_0 && TESTS_JSON_SCHEMA_EXPORTER_POLYFILL + // Regression test for https://github.com/dotnet/runtime/issues/117493 + yield return new TestData( + Value: 42, + AdditionalValues: [null], + ExpectedJsonSchema: """{"type":["integer","null"]}""", + ExporterOptions: new() { TreatNullObliviousAsNonNullable = true }); + + yield return new TestData( + Value: DateTimeOffset.MinValue, + AdditionalValues: [null], + ExpectedJsonSchema: """{"type":["string","null"],"format":"date-time"}""", + ExporterOptions: new() { TreatNullObliviousAsNonNullable = true }); +#endif + // User-defined POCOs yield return new TestData( Value: new() { String = "string", StringNullable = "string", Int = 42, Double = 3.14, Boolean = true }, @@ -152,7 +164,7 @@ public static IEnumerable GetTestDataCore() } """); -#if NET9_0_OR_GREATER +#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL // Same as above but with nullable types set to non-nullable yield return new TestData( Value: new() { String = "string", StringNullable = "string", Int = 42, Double = 3.14, Boolean = true }, @@ -311,7 +323,7 @@ public static IEnumerable GetTestDataCore() } """); -#if NET9_0_OR_GREATER +#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL // Same as above but with non-nullable reference types by default. yield return new TestData( Value: new() { Value = 1, Next = new() { Value = 2, Next = new() { Value = 3 } } }, @@ -761,7 +773,7 @@ of the type which points to the first occurrence. */ } """); -#if NET9_0_OR_GREATER +#if TEST yield return new TestData( Value: new("string", -1), ExpectedJsonSchema: """ @@ -1164,7 +1176,7 @@ public readonly struct StructDictionary(IEnumerable _dictionary.Count; public bool ContainsKey(TKey key) => _dictionary.ContainsKey(key); public IEnumerator> GetEnumerator() => _dictionary.GetEnumerator(); -#if NETCOREAPP +#if NET public bool TryGetValue(TKey key, [MaybeNullWhen(false)] out TValue value) => _dictionary.TryGetValue(key, out value); #else public bool TryGetValue(TKey key, out TValue value) => _dictionary.TryGetValue(key, out value); @@ -1249,6 +1261,7 @@ public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions [JsonSerializable(typeof(IntEnum?))] [JsonSerializable(typeof(StringEnum?))] [JsonSerializable(typeof(SimpleRecordStruct?))] + [JsonSerializable(typeof(DateTimeOffset?))] // User-defined POCOs [JsonSerializable(typeof(SimplePoco))] [JsonSerializable(typeof(SimpleRecord))] @@ -1299,22 +1312,4 @@ public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions [JsonSerializable(typeof(StructDictionary))] [JsonSerializable(typeof(XElement))] public partial class TestTypesContext : JsonSerializerContext; - -#if NET9_0_OR_GREATER - private static TAttribute? ResolveAttribute(this JsonSchemaExporterContext ctx) - where TAttribute : Attribute - { - // Resolve attributes from locations in the following order: - // 1. Property-level attributes - // 2. Parameter-level attributes and - // 3. Type-level attributes. - return - GetAttrs(ctx.PropertyInfo?.AttributeProvider) ?? - GetAttrs(ctx.PropertyInfo?.AssociatedParameter?.AttributeProvider) ?? - GetAttrs(ctx.TypeInfo.Type); - - static TAttribute? GetAttrs(ICustomAttributeProvider? provider) => - (TAttribute?)provider?.GetCustomAttributes(typeof(TAttribute), inherit: false).FirstOrDefault(); - } -#endif }