diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs index ea1f393f7e5..3a9c99c2e72 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs @@ -4,12 +4,14 @@ using System; using System.Text.Json.Nodes; +#pragma warning disable S1067 // Expressions should not be too complex + namespace Microsoft.Extensions.AI; /// /// Provides options for configuring the behavior of JSON schema creation functionality. /// -public sealed class AIJsonSchemaCreateOptions +public sealed class AIJsonSchemaCreateOptions : IEquatable { /// /// Gets the default options instance. @@ -40,4 +42,21 @@ public sealed class AIJsonSchemaCreateOptions /// Gets a value indicating whether to mark all properties as required in the schema. /// public bool RequireAllProperties { get; init; } = true; + + /// + public bool Equals(AIJsonSchemaCreateOptions? other) + { + return other is not null && + TransformSchemaNode == other.TransformSchemaNode && + IncludeTypeInEnumSchemas == other.IncludeTypeInEnumSchemas && + DisallowAdditionalProperties == other.DisallowAdditionalProperties && + IncludeSchemaKeyword == other.IncludeSchemaKeyword && + RequireAllProperties == other.RequireAllProperties; + } + + /// + public override bool Equals(object? obj) => obj is AIJsonSchemaCreateOptions other && Equals(other); + + /// + public override int GetHashCode() => (TransformSchemaNode, IncludeTypeInEnumSchemas, DisallowAdditionalProperties, IncludeSchemaKeyword, RequireAllProperties).GetHashCode(); } diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.Utilities.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.Utilities.cs index 251059035db..cbafe78e5d3 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.Utilities.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.Utilities.cs @@ -1,6 +1,10 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; +using System.Buffers; +using System.IO; +using System.Reflection; using System.Text.RegularExpressions; using Microsoft.Shared.Diagnostics; @@ -30,4 +34,104 @@ internal static string SanitizeMemberName(string memberName) private static Regex InvalidNameCharsRegex() => _invalidNameCharsRegex; private static readonly Regex _invalidNameCharsRegex = new("[^0-9A-Za-z_]", RegexOptions.Compiled); #endif + + /// Invokes the MethodInfo with the specified target object and arguments. + private static object? ReflectionInvoke(MethodInfo method, object? target, object?[]? arguments) + { +#if NET + return method.Invoke(target, BindingFlags.DoNotWrapExceptions, binder: null, arguments, culture: null); +#else + try + { + return method.Invoke(target, BindingFlags.Default, binder: null, arguments, culture: null); + } + catch (TargetInvocationException e) when (e.InnerException is not null) + { + // If we're targeting .NET Framework, such that BindingFlags.DoNotWrapExceptions + // is ignored, the original exception will be wrapped in a TargetInvocationException. + // Unwrap it and throw that original exception, maintaining its stack information. + System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(e.InnerException).Throw(); + throw; + } +#endif + } + + /// + /// Implements a simple write-only memory stream that uses pooled buffers. + /// + private sealed class PooledMemoryStream : Stream + { + private const int DefaultBufferSize = 4096; + private byte[] _buffer; + private int _position; + + public PooledMemoryStream(int initialCapacity = DefaultBufferSize) + { + _buffer = ArrayPool.Shared.Rent(initialCapacity); + _position = 0; + } + + public ReadOnlySpan GetBuffer() => _buffer.AsSpan(0, _position); + public override bool CanWrite => true; + public override bool CanRead => false; + public override bool CanSeek => false; + public override long Length => _position; + public override long Position + { + get => _position; + set => throw new NotSupportedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + EnsureNotDisposed(); + EnsureCapacity(_position + count); + + Buffer.BlockCopy(buffer, offset, _buffer, _position, count); + _position += count; + } + + public override void Flush() + { + } + + public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + + protected override void Dispose(bool disposing) + { + if (_buffer is not null) + { + ArrayPool.Shared.Return(_buffer); + _buffer = null!; + } + + base.Dispose(disposing); + } + + private void EnsureCapacity(int requiredCapacity) + { + if (requiredCapacity <= _buffer.Length) + { + return; + } + + int newCapacity = Math.Max(requiredCapacity, _buffer.Length * 2); + byte[] newBuffer = ArrayPool.Shared.Rent(newCapacity); + Buffer.BlockCopy(_buffer, 0, newBuffer, 0, _position); + + ArrayPool.Shared.Return(_buffer); + _buffer = newBuffer; + } + + private void EnsureNotDisposed() + { + if (_buffer is null) + { + Throw(); + static void Throw() => throw new ObjectDisposedException(nameof(PooledMemoryStream)); + } + } + } } diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs index 50a5afd14e7..0aff0901c7a 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs @@ -2,12 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.ComponentModel; using System.Diagnostics; -using System.IO; using System.Linq; using System.Reflection; +using System.Runtime.CompilerServices; using System.Text.Json; using System.Text.Json.Nodes; using System.Text.Json.Serialization.Metadata; @@ -42,7 +43,7 @@ public static AIFunction Create(Delegate method, AIFunctionFactoryOptions? optio { _ = Throw.IfNull(method); - return new ReflectionAIFunction(method.Method, method.Target, options ?? _defaultOptions); + return ReflectionAIFunction.Build(method.Method, method.Target, options ?? _defaultOptions); } /// Creates an instance for a method, specified via a delegate. @@ -68,12 +69,12 @@ public static AIFunction Create(Delegate method, string? name = null, string? de ? _defaultOptions : new() { - SerializerOptions = serializerOptions ?? _defaultOptions.SerializerOptions, Name = name, - Description = description + Description = description, + SerializerOptions = serializerOptions, }; - return new ReflectionAIFunction(method.Method, method.Target, createOptions); + return ReflectionAIFunction.Build(method.Method, method.Target, createOptions); } /// @@ -100,7 +101,7 @@ public static AIFunction Create(Delegate method, string? name = null, string? de public static AIFunction Create(MethodInfo method, object? target, AIFunctionFactoryOptions? options) { _ = Throw.IfNull(method); - return new ReflectionAIFunction(method, target, options ?? _defaultOptions); + return ReflectionAIFunction.Build(method, target, options ?? _defaultOptions); } /// @@ -129,44 +130,23 @@ public static AIFunction Create(MethodInfo method, object? target, string? name { _ = Throw.IfNull(method); - AIFunctionFactoryOptions? createOptions = serializerOptions is null && name is null && description is null + AIFunctionFactoryOptions createOptions = serializerOptions is null && name is null && description is null ? _defaultOptions : new() { - SerializerOptions = serializerOptions ?? _defaultOptions.SerializerOptions, Name = name, - Description = description + Description = description, + SerializerOptions = serializerOptions, }; - return new ReflectionAIFunction(method, target, createOptions); + return ReflectionAIFunction.Build(method, target, createOptions); } private sealed class ReflectionAIFunction : AIFunction { - private readonly MethodInfo _method; - private readonly object? _target; - private readonly Func, AIFunctionContext?, object?>[] _parameterMarshallers; - private readonly Func> _returnMarshaller; - private readonly JsonTypeInfo? _returnTypeInfo; - private readonly bool _needsAIFunctionContext; - - /// - /// Initializes a new instance of the class for a method, specified via an instance - /// and an optional target object if the method is an instance method. - /// - /// The method to be represented via the created . - /// - /// The target object for the if it represents an instance method. - /// This should be if and only if is a static method. - /// - /// Function creation options. - public ReflectionAIFunction(MethodInfo method, object? target, AIFunctionFactoryOptions options) + public static ReflectionAIFunction Build(MethodInfo method, object? target, AIFunctionFactoryOptions options) { _ = Throw.IfNull(method); - _ = Throw.IfNull(options); - - JsonSerializerOptions serializerOptions = options.SerializerOptions ?? AIJsonUtilities.DefaultOptions; - serializerOptions.MakeReadOnly(); if (method.ContainsGenericParameters) { @@ -178,86 +158,37 @@ public ReflectionAIFunction(MethodInfo method, object? target, AIFunctionFactory Throw.ArgumentNullException(nameof(target), "Target must not be null for an instance method."); } - _method = method; - _target = target; - - // Get the function name to use. - string? functionName = options.Name; - if (functionName is null) - { - functionName = SanitizeMemberName(method.Name!); - - const string AsyncSuffix = "Async"; - if (IsAsyncMethod(method) && - functionName.EndsWith(AsyncSuffix, StringComparison.Ordinal) && - functionName.Length > AsyncSuffix.Length) - { - functionName = functionName.Substring(0, functionName.Length - AsyncSuffix.Length); - } - - static bool IsAsyncMethod(MethodInfo method) - { - Type t = method.ReturnType; - - if (t == typeof(Task) || t == typeof(ValueTask)) - { - return true; - } - - if (t.IsGenericType) - { - t = t.GetGenericTypeDefinition(); - if (t == typeof(Task<>) || t == typeof(ValueTask<>) || t == typeof(IAsyncEnumerable<>)) - { - return true; - } - } - - return false; - } - } + ReflectionAIFunctionDescriptor functionDescriptor = ReflectionAIFunctionDescriptor.GetOrCreate(method, options); - // Get marshaling delegates for parameters. - ParameterInfo[] parameters = method.GetParameters(); - _parameterMarshallers = new Func, AIFunctionContext?, object?>[parameters.Length]; - bool sawAIContextParameter = false; - for (int i = 0; i < parameters.Length; i++) + if (target is null && options.AdditionalProperties is null) { - _parameterMarshallers[i] = GetParameterMarshaller(serializerOptions, parameters[i], ref sawAIContextParameter); + // We can use a cached value for static methods not specifying additional properties. + return functionDescriptor.CachedDefaultInstance ??= new(functionDescriptor, target, options); } - _needsAIFunctionContext = sawAIContextParameter; - - // Get the return type and a marshaling func for the return value. - _returnMarshaller = GetReturnMarshaller(method, out Type returnType); - _returnTypeInfo = returnType != typeof(void) ? serializerOptions.GetTypeInfo(returnType) : null; + return new(functionDescriptor, target, options); + } - Name = functionName; - Description = options.Description ?? method.GetCustomAttribute(inherit: true)?.Description ?? string.Empty; - UnderlyingMethod = method; + private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor, object? target, AIFunctionFactoryOptions options) + { + FunctionDescriptor = functionDescriptor; + Target = target; AdditionalProperties = options.AdditionalProperties ?? EmptyReadOnlyDictionary.Instance; - JsonSerializerOptions = serializerOptions; - JsonSchema = AIJsonUtilities.CreateFunctionJsonSchema( - method, - title: Name, - description: Description, - options.SerializerOptions, - options.JsonSchemaCreateOptions); } - public override string Name { get; } - public override string Description { get; } - public override MethodInfo? UnderlyingMethod { get; } + public ReflectionAIFunctionDescriptor FunctionDescriptor { get; } + public object? Target { get; } public override IReadOnlyDictionary AdditionalProperties { get; } - public override JsonSerializerOptions JsonSerializerOptions { get; } - public override JsonElement JsonSchema { get; } - - /// - protected override async Task InvokeCoreAsync( + public override string Name => FunctionDescriptor.Name; + public override string Description => FunctionDescriptor.Description; + public override MethodInfo UnderlyingMethod => FunctionDescriptor.Method; + public override JsonElement JsonSchema => FunctionDescriptor.JsonSchema; + public override JsonSerializerOptions JsonSerializerOptions => FunctionDescriptor.JsonSerializerOptions; + protected override Task InvokeCoreAsync( IEnumerable>? arguments, CancellationToken cancellationToken) { - var paramMarshallers = _parameterMarshallers; + var paramMarshallers = FunctionDescriptor.ParameterMarshallers; object?[] args = paramMarshallers.Length != 0 ? new object?[paramMarshallers.Length] : []; IReadOnlyDictionary argDict = @@ -269,7 +200,7 @@ static bool IsAsyncMethod(MethodInfo method) #else ToDictionary(kvp => kvp.Key, kvp => kvp.Value); #endif - AIFunctionContext? context = _needsAIFunctionContext ? + AIFunctionContext? context = FunctionDescriptor.RequiresAIFunctionContext ? new() { CancellationToken = cancellationToken } : null; @@ -278,30 +209,111 @@ static bool IsAsyncMethod(MethodInfo method) args[i] = paramMarshallers[i](argDict, context); } - object? result = await _returnMarshaller(ReflectionInvoke(_method, _target, args)).ConfigureAwait(false); + return FunctionDescriptor.ReturnParameterMarshaller(ReflectionInvoke(FunctionDescriptor.Method, Target, args), cancellationToken); + } + } - switch (_returnTypeInfo) + /// + /// A descriptor for a .NET method-backed AIFunction that precomputes its marshalling delegates and JSON schema. + /// + private sealed class ReflectionAIFunctionDescriptor + { + private const int InnerCacheSoftLimit = 512; + private static readonly ConditionalWeakTable> _descriptorCache = new(); + + /// + /// Gets or creates a descriptors using the specified method and options. + /// + public static ReflectionAIFunctionDescriptor GetOrCreate(MethodInfo method, AIFunctionFactoryOptions options) + { + JsonSerializerOptions serializerOptions = options.SerializerOptions ?? AIJsonUtilities.DefaultOptions; + AIJsonSchemaCreateOptions schemaOptions = options.JsonSchemaCreateOptions ?? AIJsonSchemaCreateOptions.Default; + serializerOptions.MakeReadOnly(); + ConcurrentDictionary innerCache = _descriptorCache.GetOrCreateValue(serializerOptions); + + DescriptorKey key = new(method, options.Name, options.Description, schemaOptions); + if (innerCache.TryGetValue(key, out ReflectionAIFunctionDescriptor? descriptor)) { - case null: - Debug.Assert( - UnderlyingMethod?.ReturnType == typeof(void) || - UnderlyingMethod?.ReturnType == typeof(Task) || - UnderlyingMethod?.ReturnType == typeof(ValueTask), "The return parameter should be void or non-generic task."); + return descriptor; + } - return null; + descriptor = new(key, serializerOptions); + return innerCache.Count < InnerCacheSoftLimit + ? innerCache.GetOrAdd(key, descriptor) + : descriptor; + } - case { Kind: JsonTypeInfoKind.None }: - // Special-case trivial contracts to avoid the more expensive general-purpose serialization path. - return JsonSerializer.SerializeToElement(result, _returnTypeInfo); + private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions serializerOptions) + { + // Get marshaling delegates for parameters. + ParameterInfo[] parameters = key.Method.GetParameters(); + ParameterMarshallers = new Func, AIFunctionContext?, object?>[parameters.Length]; + bool foundAIFunctionContextParameter = false; + for (int i = 0; i < parameters.Length; i++) + { + ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, parameters[i], ref foundAIFunctionContextParameter); + } + + // Get a marshaling delegate for the return value. + ReturnParameterMarshaller = GetReturnParameterMarshaller(key.Method, serializerOptions); - default: + Method = key.Method; + Name = key.Name ?? GetFunctionName(key.Method); + Description = key.Description ?? key.Method.GetCustomAttribute(inherit: true)?.Description ?? string.Empty; + RequiresAIFunctionContext = foundAIFunctionContextParameter; + JsonSerializerOptions = serializerOptions; + JsonSchema = AIJsonUtilities.CreateFunctionJsonSchema( + key.Method, + Name, + Description, + serializerOptions, + key.SchemaOptions); + } + + public string Name { get; } + public string Description { get; } + public MethodInfo Method { get; } + public JsonSerializerOptions JsonSerializerOptions { get; } + public JsonElement JsonSchema { get; } + public Func, AIFunctionContext?, object?>[] ParameterMarshallers { get; } + public Func> ReturnParameterMarshaller { get; } + public bool RequiresAIFunctionContext { get; } + public ReflectionAIFunction? CachedDefaultInstance { get; set; } + + private static string GetFunctionName(MethodInfo method) + { + // Get the function name to use. + string name = SanitizeMemberName(method.Name); + + const string AsyncSuffix = "Async"; + if (IsAsyncMethod(method) && + name.EndsWith(AsyncSuffix, StringComparison.Ordinal) && + name.Length > AsyncSuffix.Length) + { + name = name.Substring(0, name.Length - AsyncSuffix.Length); + } + + return name; + + static bool IsAsyncMethod(MethodInfo method) + { + Type t = method.ReturnType; + + if (t == typeof(Task) || t == typeof(ValueTask)) { - // Serialize asynchronously to support potential IAsyncEnumerable responses. - using MemoryStream stream = new(); - await JsonSerializer.SerializeAsync(stream, result, _returnTypeInfo, cancellationToken).ConfigureAwait(false); - Utf8JsonReader reader = new(stream.GetBuffer().AsSpan(0, (int)stream.Length)); - return JsonElement.ParseValue(ref reader); + return true; } + + if (t.IsGenericType) + { + t = t.GetGenericTypeDefinition(); + if (t == typeof(Task<>) || t == typeof(ValueTask<>) || t == typeof(IAsyncEnumerable<>)) + { + return true; + } + } + + return false; } } @@ -311,7 +323,7 @@ static bool IsAsyncMethod(MethodInfo method) private static Func, AIFunctionContext?, object?> GetParameterMarshaller( JsonSerializerOptions serializerOptions, ParameterInfo parameter, - ref bool sawAIFunctionContext) + ref bool foundAIFunctionContextParameter) { if (string.IsNullOrWhiteSpace(parameter.Name)) { @@ -321,12 +333,12 @@ static bool IsAsyncMethod(MethodInfo method) // Special-case an AIFunctionContext parameter. if (parameter.ParameterType == typeof(AIFunctionContext)) { - if (sawAIFunctionContext) + if (foundAIFunctionContextParameter) { Throw.ArgumentException(nameof(parameter), $"Only one {nameof(AIFunctionContext)} parameter is permitted."); } - sawAIFunctionContext = true; + foundAIFunctionContextParameter = true; return static (_, ctx) => { @@ -386,16 +398,21 @@ static bool IsAsyncMethod(MethodInfo method) /// /// Gets a delegate for handling the result value of a method, converting it into the to return from the invocation. /// - private static Func> GetReturnMarshaller(MethodInfo method, out Type returnType) + private static Func> GetReturnParameterMarshaller(MethodInfo method, JsonSerializerOptions serializerOptions) { - // Handle each known return type for the method - returnType = method.ReturnType; + Type returnType = method.ReturnType; + JsonTypeInfo returnTypeInfo; + + // Void + if (returnType == typeof(void)) + { + return static (_, _) => Task.FromResult(null); + } // Task if (returnType == typeof(Task)) { - returnType = typeof(void); - return async static result => + return async static (result, _) => { await ((Task)ThrowIfNullResult(result)).ConfigureAwait(false); return null; @@ -405,8 +422,7 @@ static bool IsAsyncMethod(MethodInfo method) // ValueTask if (returnType == typeof(ValueTask)) { - returnType = typeof(void); - return async static result => + return async static (result, _) => { await ((ValueTask)ThrowIfNullResult(result)).ConfigureAwait(false); return null; @@ -419,11 +435,12 @@ static bool IsAsyncMethod(MethodInfo method) if (returnType.GetGenericTypeDefinition() == typeof(Task<>)) { MethodInfo taskResultGetter = GetMethodFromGenericMethodDefinition(returnType, _taskGetResult); - returnType = taskResultGetter.ReturnType; - return async result => + returnTypeInfo = serializerOptions.GetTypeInfo(taskResultGetter.ReturnType); + return async (taskObj, cancellationToken) => { - await ((Task)ThrowIfNullResult(result)).ConfigureAwait(false); - return ReflectionInvoke(taskResultGetter, result, null); + await ((Task)ThrowIfNullResult(taskObj)).ConfigureAwait(false); + object? result = ReflectionInvoke(taskResultGetter, taskObj, null); + return await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false); }; } @@ -432,42 +449,38 @@ static bool IsAsyncMethod(MethodInfo method) { MethodInfo valueTaskAsTask = GetMethodFromGenericMethodDefinition(returnType, _valueTaskAsTask); MethodInfo asTaskResultGetter = GetMethodFromGenericMethodDefinition(valueTaskAsTask.ReturnType, _taskGetResult); - returnType = asTaskResultGetter.ReturnType; - return async result => + returnTypeInfo = serializerOptions.GetTypeInfo(asTaskResultGetter.ReturnType); + return async (taskObj, cancellationToken) => { - var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(result), null)!; + var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(taskObj), null)!; await task.ConfigureAwait(false); - return ReflectionInvoke(asTaskResultGetter, task, null); + object? result = ReflectionInvoke(asTaskResultGetter, task, null); + return await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false); }; } } - // For everything else, just use the result as-is. - return result => new ValueTask(result); - - // Throws an exception if a result is found to be null unexpectedly - static object ThrowIfNullResult(object? result) => result ?? throw new InvalidOperationException("Function returned null unexpectedly."); - } + // For everything else, just serialize the result as-is. + returnTypeInfo = serializerOptions.GetTypeInfo(returnType); + return (result, cancellationToken) => SerializeResultAsync(result, returnTypeInfo, cancellationToken); - /// Invokes the MethodInfo with the specified target object and arguments. - private static object? ReflectionInvoke(MethodInfo method, object? target, object?[]? arguments) - { -#if NET - return method.Invoke(target, BindingFlags.DoNotWrapExceptions, binder: null, arguments, culture: null); -#else - try + static async Task SerializeResultAsync(object? result, JsonTypeInfo returnTypeInfo, CancellationToken cancellationToken) { - return method.Invoke(target, BindingFlags.Default, binder: null, arguments, culture: null); - } - catch (TargetInvocationException e) when (e.InnerException is not null) - { - // If we're targeting .NET Framework, such that BindingFlags.DoNotWrapExceptions - // is ignored, the original exception will be wrapped in a TargetInvocationException. - // Unwrap it and throw that original exception, maintaining its stack information. - System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(e.InnerException).Throw(); - return null; + if (returnTypeInfo.Kind is JsonTypeInfoKind.None) + { + // Special-case trivial contracts to avoid the more expensive general-purpose serialization path. + return JsonSerializer.SerializeToElement(result, returnTypeInfo); + } + + // Serialize asynchronously to support potential IAsyncEnumerable responses. + using PooledMemoryStream stream = new(); + await JsonSerializer.SerializeAsync(stream, result, returnTypeInfo, cancellationToken).ConfigureAwait(false); + Utf8JsonReader reader = new(stream.GetBuffer()); + return JsonElement.ParseValue(ref reader); } -#endif + + // Throws an exception if a result is found to be null unexpectedly + static object ThrowIfNullResult(object? result) => result ?? throw new InvalidOperationException("Function returned null unexpectedly."); } private static readonly MethodInfo _taskGetResult = typeof(Task<>).GetProperty(nameof(Task.Result), BindingFlags.Instance | BindingFlags.Public)!.GetMethod!; @@ -485,5 +498,7 @@ private static MethodInfo GetMethodFromGenericMethodDefinition(Type specializedT return specializedType.GetMethods(All).First(m => m.MetadataToken == genericMethodDefinition.MetadataToken); #endif } + + private record struct DescriptorKey(MethodInfo Method, string? Name, string? Description, AIJsonSchemaCreateOptions SchemaOptions); } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs index a0804a0451f..05084c102ab 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs @@ -64,6 +64,53 @@ public static void AIJsonSchemaCreateOptions_DefaultInstance_ReturnsExpectedValu Assert.Null(options.TransformSchemaNode); } + [Fact] + public static void AIJsonSchemaCreateOptions_UsesStructuralEquality() + { + AssertEqual(new AIJsonSchemaCreateOptions(), new AIJsonSchemaCreateOptions()); + + foreach (PropertyInfo property in typeof(AIJsonSchemaCreateOptions).GetProperties(BindingFlags.Instance | BindingFlags.Public)) + { + AIJsonSchemaCreateOptions options1 = new AIJsonSchemaCreateOptions(); + AIJsonSchemaCreateOptions options2 = new AIJsonSchemaCreateOptions(); + switch (property.GetValue(AIJsonSchemaCreateOptions.Default)) + { + case bool booleanFlag: + property.SetValue(options1, !booleanFlag); + property.SetValue(options2, !booleanFlag); + break; + + case null when property.PropertyType == typeof(Func): + Func transformer = static (context, schema) => (JsonNode)true; + property.SetValue(options1, transformer); + property.SetValue(options2, transformer); + break; + + default: + Assert.Fail($"Unexpected property type: {property.PropertyType}"); + break; + } + + AssertEqual(options1, options2); + AssertNotEqual(AIJsonSchemaCreateOptions.Default, options1); + } + + static void AssertEqual(AIJsonSchemaCreateOptions x, AIJsonSchemaCreateOptions y) + { + Assert.Equal(x.GetHashCode(), y.GetHashCode()); + Assert.Equal(x, x); + Assert.Equal(y, y); + Assert.Equal(x, y); + Assert.Equal(y, x); + } + + static void AssertNotEqual(AIJsonSchemaCreateOptions x, AIJsonSchemaCreateOptions y) + { + Assert.NotEqual(x, y); + Assert.NotEqual(y, x); + } + } + [Fact] public static void CreateJsonSchema_DefaultParameters_GeneratesExpectedJsonSchema() {