From 4c8f24d16278aa6abe477414c6a2cb04db5b508a Mon Sep 17 00:00:00 2001 From: Eirik Tsarpalis Date: Sat, 15 Feb 2025 01:09:49 +0000 Subject: [PATCH 1/6] Reinstate caching in schema generation --- .../Utilities/AIJsonUtilities.Schema.cs | 173 +++++++++++++----- 1 file changed, 123 insertions(+), 50 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs index 805c121b326..9027fc1b151 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Collections.Concurrent; using System.ComponentModel; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; @@ -19,6 +20,7 @@ #pragma warning disable S1075 // URIs should not be hardcoded #pragma warning disable SA1118 // Parameter should not span multiple lines #pragma warning disable S109 // Magic numbers should not be used +#pragma warning disable S1067 // Expressions should not be too complex namespace Microsoft.Extensions.AI; @@ -41,6 +43,12 @@ public static partial class AIJsonUtilities /// The uri used when populating the $schema keyword in inferred schemas. private const string SchemaKeywordUri = "https://json-schema.org/draft/2020-12/schema"; + /// The maximum number of schema entries to cache per JsonSerializerOptions instance. + private const int InnerCacheSoftLimit = 512; + + /// A global cache for generated schemas, weakly keyed on JsonSerializerOptions instances. + private static readonly ConditionalWeakTable> _schemaCache = new(); + // List of keywords used by JsonSchemaExporter but explicitly disallowed by some AI vendors. // cf. https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported private static readonly string[] _schemaKeywordsDisallowedByAIVendors = ["minLength", "maxLength", "pattern", "format"]; @@ -65,58 +73,64 @@ public static JsonElement CreateFunctionJsonSchema( serializerOptions ??= DefaultOptions; inferenceOptions ??= AIJsonSchemaCreateOptions.Default; title ??= method.Name; - description ??= method.GetCustomAttribute()?.Description; - JsonObject parameterSchemas = new(); - JsonArray? requiredProperties = null; - foreach (ParameterInfo parameter in method.GetParameters()) + JsonSchemaCacheKey cacheKey = new(member: method, title, description, hasDefaultValue: false, defaultValue: null, inferenceOptions); + return GetOrAddSchema(serializerOptions, cacheKey, CreateSchema); + + static JsonElement CreateSchema(JsonSchemaCacheKey key, JsonSerializerOptions serializerOptions) { - if (string.IsNullOrWhiteSpace(parameter.Name)) + JsonObject parameterSchemas = new(); + JsonArray? requiredProperties = null; + foreach (ParameterInfo parameter in ((MethodBase)key.Member!).GetParameters()) { - Throw.ArgumentException(nameof(parameter), "Parameter is missing a name."); + if (string.IsNullOrWhiteSpace(parameter.Name)) + { + Throw.ArgumentException(nameof(parameter), "Parameter is missing a name."); + } + + JsonNode parameterSchema = CreateJsonSchemaCore( + type: parameter.ParameterType, + parameterName: parameter.Name, + description: parameter.GetCustomAttribute(inherit: true)?.Description, + hasDefaultValue: parameter.HasDefaultValue, + defaultValue: parameter.HasDefaultValue ? parameter.DefaultValue : null, + serializerOptions, + key.Options); + + parameterSchemas.Add(parameter.Name, parameterSchema); + if (!parameter.IsOptional) + { + (requiredProperties ??= []).Add((JsonNode)parameter.Name); + } } - JsonNode parameterSchema = CreateJsonSchemaCore( - type: parameter.ParameterType, - parameterName: parameter.Name, - description: parameter.GetCustomAttribute(inherit: true)?.Description, - hasDefaultValue: parameter.HasDefaultValue, - defaultValue: parameter.HasDefaultValue ? parameter.DefaultValue : null, - serializerOptions, - inferenceOptions); - - parameterSchemas.Add(parameter.Name, parameterSchema); - if (!parameter.IsOptional) + JsonObject schema = new(); + if (key.Options.IncludeSchemaKeyword) { - (requiredProperties ??= []).Add((JsonNode)parameter.Name); + schema[SchemaPropertyName] = SchemaKeywordUri; } - } - JsonObject schema = new(); - if (inferenceOptions.IncludeSchemaKeyword) - { - schema[SchemaPropertyName] = SchemaKeywordUri; - } + if (!string.IsNullOrWhiteSpace(key.Title)) + { + schema[TitlePropertyName] = key.Title; + } - if (!string.IsNullOrWhiteSpace(title)) - { - schema[TitlePropertyName] = title; - } + string? description = key.Description ?? key.Member.GetCustomAttribute()?.Description; + if (!string.IsNullOrWhiteSpace(description)) + { + schema[DescriptionPropertyName] = description; + } - if (!string.IsNullOrWhiteSpace(description)) - { - schema[DescriptionPropertyName] = description; - } + schema[TypePropertyName] = "object"; // Method schemas always hardcode the type as "object". + schema[PropertiesPropertyName] = parameterSchemas; - schema[TypePropertyName] = "object"; // Method schemas always hardcode the type as "object". - schema[PropertiesPropertyName] = parameterSchemas; + if (requiredProperties is not null) + { + schema[RequiredPropertyName] = requiredProperties; + } - if (requiredProperties is not null) - { - schema[RequiredPropertyName] = requiredProperties; + return JsonSerializer.SerializeToElement(schema, JsonContext.Default.JsonNode); } - - return JsonSerializer.SerializeToElement(schema, JsonContext.Default.JsonNode); } /// Creates a JSON schema for the specified type. @@ -137,22 +151,19 @@ public static JsonElement CreateJsonSchema( { serializerOptions ??= DefaultOptions; inferenceOptions ??= AIJsonSchemaCreateOptions.Default; - JsonNode schema = CreateJsonSchemaCore(type, parameterName: null, description, hasDefaultValue, defaultValue, serializerOptions, inferenceOptions); - return JsonSerializer.SerializeToElement(schema, JsonContext.Default.JsonNode); - } - - /// Gets the default JSON schema to be used by types or functions. - internal static JsonElement DefaultJsonSchema { get; } = ParseJsonElement("{}"u8); - /// Validates the provided JSON schema document. - internal static void ValidateSchemaDocument(JsonElement document, [CallerArgumentExpression("document")] string? paramName = null) - { - if (document.ValueKind is not JsonValueKind.Object or JsonValueKind.False or JsonValueKind.True) + JsonSchemaCacheKey cacheKey = new(member: type, title: null, description, hasDefaultValue, defaultValue, inferenceOptions); + return GetOrAddSchema(serializerOptions, cacheKey, CreateSchema); + static JsonElement CreateSchema(JsonSchemaCacheKey key, JsonSerializerOptions serializerOptions) { - Throw.ArgumentException(paramName ?? "schema", "The schema document must be an object or a boolean value."); + JsonNode schema = CreateJsonSchemaCore((Type?)key.Member, parameterName: null, key.Description, key.HasDefaultValue, key.DefaultValue, serializerOptions, key.Options); + return JsonSerializer.SerializeToElement(schema, JsonContext.Default.JsonNode); } } + /// Gets the default JSON schema to be used by types or functions. + internal static JsonElement DefaultJsonSchema { get; } = ParseJsonElement("{}"u8); + #if !NET9_0_OR_GREATER [UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access", Justification = "Pre STJ-9 schema extraction can fail with a runtime exception if certain reflection metadata have been trimmed. " + @@ -411,6 +422,68 @@ private static int IndexOf(this JsonObject jsonObject, string key) return -1; } #endif + + private static JsonElement GetOrAddSchema(JsonSerializerOptions serializerOptions, JsonSchemaCacheKey cacheKey, Func schemaFactory) + { + ConcurrentDictionary innerCache = _schemaCache.GetOrCreateValue(serializerOptions); + if (innerCache.TryGetValue(cacheKey, out JsonElement schema)) + { + return schema; + } + + if (innerCache.Count >= InnerCacheSoftLimit) + { + return schemaFactory(cacheKey, serializerOptions); + } + +#if NET + return innerCache.GetOrAdd(cacheKey, schemaFactory, serializerOptions); +#else + return innerCache.GetOrAdd(cacheKey, cacheKey => schemaFactory(cacheKey, serializerOptions)); +#endif + } + + private readonly struct JsonSchemaCacheKey : IEquatable + { + public JsonSchemaCacheKey(MemberInfo? member, string? title, string? description, bool hasDefaultValue, object? defaultValue, AIJsonSchemaCreateOptions options) + { + Debug.Assert(member is Type or MethodBase or null, "Must be type or method"); + Member = member; + Title = title; + Description = description; + HasDefaultValue = hasDefaultValue; + DefaultValue = defaultValue; + Options = options; + } + + public MemberInfo? Member { get; } + public string? Title { get; } + public string? Description { get; } + public bool HasDefaultValue { get; } + public object? DefaultValue { get; } + public AIJsonSchemaCreateOptions Options { get; } + + public override bool Equals(object? obj) => obj is JsonSchemaCacheKey key && Equals(key); + public bool Equals(JsonSchemaCacheKey other) => + Member == other.Member && + Title == other.Title && + Description == other.Description && + HasDefaultValue == other.HasDefaultValue && + Equals(DefaultValue, other.DefaultValue) && + Options.TransformSchemaNode == other.Options.TransformSchemaNode && + Options.IncludeTypeInEnumSchemas == other.Options.IncludeTypeInEnumSchemas && + Options.DisallowAdditionalProperties == other.Options.DisallowAdditionalProperties && + Options.IncludeSchemaKeyword == other.Options.IncludeSchemaKeyword && + Options.RequireAllProperties == other.Options.RequireAllProperties; + + public override int GetHashCode() => + (Member, Title, Description, HasDefaultValue, DefaultValue, + Options.TransformSchemaNode, Options.IncludeTypeInEnumSchemas, + Options.DisallowAdditionalProperties, Options.IncludeSchemaKeyword, + Options.RequireAllProperties) + .GetHashCode(); + } + private static JsonElement ParseJsonElement(ReadOnlySpan utf8Json) { Utf8JsonReader reader = new(utf8Json); From 568bf8fb396c330f99e98fb50d220cf2768b6b1c Mon Sep 17 00:00:00 2001 From: Eirik Tsarpalis Date: Sat, 15 Feb 2025 15:25:12 +0000 Subject: [PATCH 2/6] Address feedback. --- .../Utilities/AIJsonUtilities.Schema.cs | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs index 9027fc1b151..fea8d7148a9 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs @@ -426,21 +426,16 @@ private static int IndexOf(this JsonObject jsonObject, string key) private static JsonElement GetOrAddSchema(JsonSerializerOptions serializerOptions, JsonSchemaCacheKey cacheKey, Func schemaFactory) { ConcurrentDictionary innerCache = _schemaCache.GetOrCreateValue(serializerOptions); - if (innerCache.TryGetValue(cacheKey, out JsonElement schema)) + if (!innerCache.TryGetValue(cacheKey, out JsonElement schema)) { - return schema; - } - - if (innerCache.Count >= InnerCacheSoftLimit) - { - return schemaFactory(cacheKey, serializerOptions); + schema = schemaFactory(cacheKey, serializerOptions); + if (innerCache.Count < InnerCacheSoftLimit) + { + _ = innerCache.TryAdd(cacheKey, schema); + } } -#if NET - return innerCache.GetOrAdd(cacheKey, schemaFactory, serializerOptions); -#else - return innerCache.GetOrAdd(cacheKey, cacheKey => schemaFactory(cacheKey, serializerOptions)); -#endif + return schema; } private readonly struct JsonSchemaCacheKey : IEquatable From 84c9c441e5d1d93ecc8d47190f5d62cfc676442d Mon Sep 17 00:00:00 2001 From: Eirik Tsarpalis Date: Tue, 18 Feb 2025 19:29:08 +0000 Subject: [PATCH 3/6] Only cache on the AIFunctionFactory level. --- .../Utilities/AIJsonSchemaCreateOptions.cs | 21 +- .../Utilities/AIJsonUtilities.Schema.cs | 168 +++------ .../Functions/AIFunctionFactory.Utilities.cs | 22 ++ .../Functions/AIFunctionFactory.cs | 330 +++++++++--------- .../Utilities/AIJsonUtilitiesTests.cs | 47 +++ 5 files changed, 312 insertions(+), 276 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs index ea1f393f7e5..3a9c99c2e72 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs @@ -4,12 +4,14 @@ using System; using System.Text.Json.Nodes; +#pragma warning disable S1067 // Expressions should not be too complex + namespace Microsoft.Extensions.AI; /// /// Provides options for configuring the behavior of JSON schema creation functionality. /// -public sealed class AIJsonSchemaCreateOptions +public sealed class AIJsonSchemaCreateOptions : IEquatable { /// /// Gets the default options instance. @@ -40,4 +42,21 @@ public sealed class AIJsonSchemaCreateOptions /// Gets a value indicating whether to mark all properties as required in the schema. /// public bool RequireAllProperties { get; init; } = true; + + /// + public bool Equals(AIJsonSchemaCreateOptions? other) + { + return other is not null && + TransformSchemaNode == other.TransformSchemaNode && + IncludeTypeInEnumSchemas == other.IncludeTypeInEnumSchemas && + DisallowAdditionalProperties == other.DisallowAdditionalProperties && + IncludeSchemaKeyword == other.IncludeSchemaKeyword && + RequireAllProperties == other.RequireAllProperties; + } + + /// + public override bool Equals(object? obj) => obj is AIJsonSchemaCreateOptions other && Equals(other); + + /// + public override int GetHashCode() => (TransformSchemaNode, IncludeTypeInEnumSchemas, DisallowAdditionalProperties, IncludeSchemaKeyword, RequireAllProperties).GetHashCode(); } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs index fea8d7148a9..805c121b326 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Collections.Concurrent; using System.ComponentModel; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; @@ -20,7 +19,6 @@ #pragma warning disable S1075 // URIs should not be hardcoded #pragma warning disable SA1118 // Parameter should not span multiple lines #pragma warning disable S109 // Magic numbers should not be used -#pragma warning disable S1067 // Expressions should not be too complex namespace Microsoft.Extensions.AI; @@ -43,12 +41,6 @@ public static partial class AIJsonUtilities /// The uri used when populating the $schema keyword in inferred schemas. private const string SchemaKeywordUri = "https://json-schema.org/draft/2020-12/schema"; - /// The maximum number of schema entries to cache per JsonSerializerOptions instance. - private const int InnerCacheSoftLimit = 512; - - /// A global cache for generated schemas, weakly keyed on JsonSerializerOptions instances. - private static readonly ConditionalWeakTable> _schemaCache = new(); - // List of keywords used by JsonSchemaExporter but explicitly disallowed by some AI vendors. // cf. https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported private static readonly string[] _schemaKeywordsDisallowedByAIVendors = ["minLength", "maxLength", "pattern", "format"]; @@ -73,64 +65,58 @@ public static JsonElement CreateFunctionJsonSchema( serializerOptions ??= DefaultOptions; inferenceOptions ??= AIJsonSchemaCreateOptions.Default; title ??= method.Name; + description ??= method.GetCustomAttribute()?.Description; - JsonSchemaCacheKey cacheKey = new(member: method, title, description, hasDefaultValue: false, defaultValue: null, inferenceOptions); - return GetOrAddSchema(serializerOptions, cacheKey, CreateSchema); - - static JsonElement CreateSchema(JsonSchemaCacheKey key, JsonSerializerOptions serializerOptions) + JsonObject parameterSchemas = new(); + JsonArray? requiredProperties = null; + foreach (ParameterInfo parameter in method.GetParameters()) { - JsonObject parameterSchemas = new(); - JsonArray? requiredProperties = null; - foreach (ParameterInfo parameter in ((MethodBase)key.Member!).GetParameters()) + if (string.IsNullOrWhiteSpace(parameter.Name)) { - if (string.IsNullOrWhiteSpace(parameter.Name)) - { - Throw.ArgumentException(nameof(parameter), "Parameter is missing a name."); - } - - JsonNode parameterSchema = CreateJsonSchemaCore( - type: parameter.ParameterType, - parameterName: parameter.Name, - description: parameter.GetCustomAttribute(inherit: true)?.Description, - hasDefaultValue: parameter.HasDefaultValue, - defaultValue: parameter.HasDefaultValue ? parameter.DefaultValue : null, - serializerOptions, - key.Options); - - parameterSchemas.Add(parameter.Name, parameterSchema); - if (!parameter.IsOptional) - { - (requiredProperties ??= []).Add((JsonNode)parameter.Name); - } + Throw.ArgumentException(nameof(parameter), "Parameter is missing a name."); } - JsonObject schema = new(); - if (key.Options.IncludeSchemaKeyword) + JsonNode parameterSchema = CreateJsonSchemaCore( + type: parameter.ParameterType, + parameterName: parameter.Name, + description: parameter.GetCustomAttribute(inherit: true)?.Description, + hasDefaultValue: parameter.HasDefaultValue, + defaultValue: parameter.HasDefaultValue ? parameter.DefaultValue : null, + serializerOptions, + inferenceOptions); + + parameterSchemas.Add(parameter.Name, parameterSchema); + if (!parameter.IsOptional) { - schema[SchemaPropertyName] = SchemaKeywordUri; + (requiredProperties ??= []).Add((JsonNode)parameter.Name); } + } - if (!string.IsNullOrWhiteSpace(key.Title)) - { - schema[TitlePropertyName] = key.Title; - } + JsonObject schema = new(); + if (inferenceOptions.IncludeSchemaKeyword) + { + schema[SchemaPropertyName] = SchemaKeywordUri; + } - string? description = key.Description ?? key.Member.GetCustomAttribute()?.Description; - if (!string.IsNullOrWhiteSpace(description)) - { - schema[DescriptionPropertyName] = description; - } + if (!string.IsNullOrWhiteSpace(title)) + { + schema[TitlePropertyName] = title; + } - schema[TypePropertyName] = "object"; // Method schemas always hardcode the type as "object". - schema[PropertiesPropertyName] = parameterSchemas; + if (!string.IsNullOrWhiteSpace(description)) + { + schema[DescriptionPropertyName] = description; + } - if (requiredProperties is not null) - { - schema[RequiredPropertyName] = requiredProperties; - } + schema[TypePropertyName] = "object"; // Method schemas always hardcode the type as "object". + schema[PropertiesPropertyName] = parameterSchemas; - return JsonSerializer.SerializeToElement(schema, JsonContext.Default.JsonNode); + if (requiredProperties is not null) + { + schema[RequiredPropertyName] = requiredProperties; } + + return JsonSerializer.SerializeToElement(schema, JsonContext.Default.JsonNode); } /// Creates a JSON schema for the specified type. @@ -151,19 +137,22 @@ public static JsonElement CreateJsonSchema( { serializerOptions ??= DefaultOptions; inferenceOptions ??= AIJsonSchemaCreateOptions.Default; - - JsonSchemaCacheKey cacheKey = new(member: type, title: null, description, hasDefaultValue, defaultValue, inferenceOptions); - return GetOrAddSchema(serializerOptions, cacheKey, CreateSchema); - static JsonElement CreateSchema(JsonSchemaCacheKey key, JsonSerializerOptions serializerOptions) - { - JsonNode schema = CreateJsonSchemaCore((Type?)key.Member, parameterName: null, key.Description, key.HasDefaultValue, key.DefaultValue, serializerOptions, key.Options); - return JsonSerializer.SerializeToElement(schema, JsonContext.Default.JsonNode); - } + JsonNode schema = CreateJsonSchemaCore(type, parameterName: null, description, hasDefaultValue, defaultValue, serializerOptions, inferenceOptions); + return JsonSerializer.SerializeToElement(schema, JsonContext.Default.JsonNode); } /// Gets the default JSON schema to be used by types or functions. internal static JsonElement DefaultJsonSchema { get; } = ParseJsonElement("{}"u8); + /// Validates the provided JSON schema document. + internal static void ValidateSchemaDocument(JsonElement document, [CallerArgumentExpression("document")] string? paramName = null) + { + if (document.ValueKind is not JsonValueKind.Object or JsonValueKind.False or JsonValueKind.True) + { + Throw.ArgumentException(paramName ?? "schema", "The schema document must be an object or a boolean value."); + } + } + #if !NET9_0_OR_GREATER [UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access", Justification = "Pre STJ-9 schema extraction can fail with a runtime exception if certain reflection metadata have been trimmed. " + @@ -422,63 +411,6 @@ private static int IndexOf(this JsonObject jsonObject, string key) return -1; } #endif - - private static JsonElement GetOrAddSchema(JsonSerializerOptions serializerOptions, JsonSchemaCacheKey cacheKey, Func schemaFactory) - { - ConcurrentDictionary innerCache = _schemaCache.GetOrCreateValue(serializerOptions); - if (!innerCache.TryGetValue(cacheKey, out JsonElement schema)) - { - schema = schemaFactory(cacheKey, serializerOptions); - if (innerCache.Count < InnerCacheSoftLimit) - { - _ = innerCache.TryAdd(cacheKey, schema); - } - } - - return schema; - } - - private readonly struct JsonSchemaCacheKey : IEquatable - { - public JsonSchemaCacheKey(MemberInfo? member, string? title, string? description, bool hasDefaultValue, object? defaultValue, AIJsonSchemaCreateOptions options) - { - Debug.Assert(member is Type or MethodBase or null, "Must be type or method"); - Member = member; - Title = title; - Description = description; - HasDefaultValue = hasDefaultValue; - DefaultValue = defaultValue; - Options = options; - } - - public MemberInfo? Member { get; } - public string? Title { get; } - public string? Description { get; } - public bool HasDefaultValue { get; } - public object? DefaultValue { get; } - public AIJsonSchemaCreateOptions Options { get; } - - public override bool Equals(object? obj) => obj is JsonSchemaCacheKey key && Equals(key); - public bool Equals(JsonSchemaCacheKey other) => - Member == other.Member && - Title == other.Title && - Description == other.Description && - HasDefaultValue == other.HasDefaultValue && - Equals(DefaultValue, other.DefaultValue) && - Options.TransformSchemaNode == other.Options.TransformSchemaNode && - Options.IncludeTypeInEnumSchemas == other.Options.IncludeTypeInEnumSchemas && - Options.DisallowAdditionalProperties == other.Options.DisallowAdditionalProperties && - Options.IncludeSchemaKeyword == other.Options.IncludeSchemaKeyword && - Options.RequireAllProperties == other.Options.RequireAllProperties; - - public override int GetHashCode() => - (Member, Title, Description, HasDefaultValue, DefaultValue, - Options.TransformSchemaNode, Options.IncludeTypeInEnumSchemas, - Options.DisallowAdditionalProperties, Options.IncludeSchemaKeyword, - Options.RequireAllProperties) - .GetHashCode(); - } - private static JsonElement ParseJsonElement(ReadOnlySpan utf8Json) { Utf8JsonReader reader = new(utf8Json); diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.Utilities.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.Utilities.cs index 251059035db..2094f2f1886 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.Utilities.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.Utilities.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Reflection; using System.Text.RegularExpressions; using Microsoft.Shared.Diagnostics; @@ -30,4 +31,25 @@ internal static string SanitizeMemberName(string memberName) private static Regex InvalidNameCharsRegex() => _invalidNameCharsRegex; private static readonly Regex _invalidNameCharsRegex = new("[^0-9A-Za-z_]", RegexOptions.Compiled); #endif + + /// Invokes the MethodInfo with the specified target object and arguments. + private static object? ReflectionInvoke(MethodInfo method, object? target, object?[]? arguments) + { +#if NET + return method.Invoke(target, BindingFlags.DoNotWrapExceptions, binder: null, arguments, culture: null); +#else + try + { + return method.Invoke(target, BindingFlags.Default, binder: null, arguments, culture: null); + } + catch (TargetInvocationException e) when (e.InnerException is not null) + { + // If we're targeting .NET Framework, such that BindingFlags.DoNotWrapExceptions + // is ignored, the original exception will be wrapped in a TargetInvocationException. + // Unwrap it and throw that original exception, maintaining its stack information. + System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(e.InnerException).Throw(); + return null; + } +#endif + } } diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs index d0d3385749e..30c691d38e1 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs @@ -2,12 +2,14 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.ComponentModel; using System.Diagnostics; using System.IO; using System.Linq; using System.Reflection; +using System.Runtime.CompilerServices; using System.Text.Json; using System.Text.Json.Nodes; using System.Text.Json.Serialization.Metadata; @@ -42,7 +44,7 @@ public static AIFunction Create(Delegate method, AIFunctionFactoryOptions? optio { _ = Throw.IfNull(method); - return new ReflectionAIFunction(method.Method, method.Target, options ?? _defaultOptions); + return ReflectionAIFunction.Build(method.Method, method.Target, options ?? _defaultOptions); } /// Creates an instance for a method, specified via a delegate. @@ -68,12 +70,12 @@ public static AIFunction Create(Delegate method, string? name = null, string? de ? _defaultOptions : new() { - SerializerOptions = serializerOptions ?? _defaultOptions.SerializerOptions, Name = name, - Description = description + Description = description, + SerializerOptions = serializerOptions, }; - return new ReflectionAIFunction(method.Method, method.Target, createOptions); + return ReflectionAIFunction.Build(method.Method, method.Target, createOptions); } /// @@ -100,7 +102,7 @@ public static AIFunction Create(Delegate method, string? name = null, string? de public static AIFunction Create(MethodInfo method, object? target, AIFunctionFactoryOptions? options) { _ = Throw.IfNull(method); - return new ReflectionAIFunction(method, target, options ?? _defaultOptions); + return ReflectionAIFunction.Build(method, target, options ?? _defaultOptions); } /// @@ -129,44 +131,23 @@ public static AIFunction Create(MethodInfo method, object? target, string? name { _ = Throw.IfNull(method); - AIFunctionFactoryOptions? createOptions = serializerOptions is null && name is null && description is null + AIFunctionFactoryOptions createOptions = serializerOptions is null && name is null && description is null ? _defaultOptions : new() { - SerializerOptions = serializerOptions ?? _defaultOptions.SerializerOptions, Name = name, - Description = description + Description = description, + SerializerOptions = serializerOptions, }; - return new ReflectionAIFunction(method, target, createOptions); + return ReflectionAIFunction.Build(method, target, createOptions); } private sealed class ReflectionAIFunction : AIFunction { - private readonly MethodInfo _method; - private readonly object? _target; - private readonly Func, AIFunctionContext?, object?>[] _parameterMarshallers; - private readonly Func> _returnMarshaller; - private readonly JsonTypeInfo? _returnTypeInfo; - private readonly bool _needsAIFunctionContext; - - /// - /// Initializes a new instance of the class for a method, specified via an instance - /// and an optional target object if the method is an instance method. - /// - /// The method to be represented via the created . - /// - /// The target object for the if it represents an instance method. - /// This should be if and only if is a static method. - /// - /// Function creation options. - public ReflectionAIFunction(MethodInfo method, object? target, AIFunctionFactoryOptions options) + public static ReflectionAIFunction Build(MethodInfo method, object? target, AIFunctionFactoryOptions options) { _ = Throw.IfNull(method); - _ = Throw.IfNull(options); - - JsonSerializerOptions serializerOptions = options.SerializerOptions ?? AIJsonUtilities.DefaultOptions; - serializerOptions.MakeReadOnly(); if (method.ContainsGenericParameters) { @@ -178,86 +159,37 @@ public ReflectionAIFunction(MethodInfo method, object? target, AIFunctionFactory Throw.ArgumentNullException(nameof(target), "Target must not be null for an instance method."); } - _method = method; - _target = target; - - // Get the function name to use. - string? functionName = options.Name; - if (functionName is null) - { - functionName = SanitizeMemberName(method.Name!); - - const string AsyncSuffix = "Async"; - if (IsAsyncMethod(method) && - functionName.EndsWith(AsyncSuffix, StringComparison.Ordinal) && - functionName.Length > AsyncSuffix.Length) - { - functionName = functionName.Substring(0, functionName.Length - AsyncSuffix.Length); - } - - static bool IsAsyncMethod(MethodInfo method) - { - Type t = method.ReturnType; - - if (t == typeof(Task) || t == typeof(ValueTask)) - { - return true; - } - - if (t.IsGenericType) - { - t = t.GetGenericTypeDefinition(); - if (t == typeof(Task<>) || t == typeof(ValueTask<>) || t == typeof(IAsyncEnumerable<>)) - { - return true; - } - } - - return false; - } - } + ReflectionAIFunctionDescriptor functionDescriptor = ReflectionAIFunctionDescriptor.GetOrCreate(method, options); - // Get marshaling delegates for parameters. - ParameterInfo[] parameters = method.GetParameters(); - _parameterMarshallers = new Func, AIFunctionContext?, object?>[parameters.Length]; - bool sawAIContextParameter = false; - for (int i = 0; i < parameters.Length; i++) + if (target is null && options.AdditionalProperties is null) { - _parameterMarshallers[i] = GetParameterMarshaller(serializerOptions, parameters[i], ref sawAIContextParameter); + // We can use a cached value for static methods not specifying additional properties. + return functionDescriptor.CachedDefaultInstance ??= new(functionDescriptor, target, options); } - _needsAIFunctionContext = sawAIContextParameter; - - // Get the return type and a marshaling func for the return value. - _returnMarshaller = GetReturnMarshaller(method, out Type returnType); - _returnTypeInfo = returnType != typeof(void) ? serializerOptions.GetTypeInfo(returnType) : null; + return new(functionDescriptor, target, options); + } - Name = functionName; - Description = options.Description ?? method.GetCustomAttribute(inherit: true)?.Description ?? string.Empty; - UnderlyingMethod = method; + private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor, object? target, AIFunctionFactoryOptions options) + { + FunctionDescriptor = functionDescriptor; + Target = target; AdditionalProperties = options.AdditionalProperties ?? EmptyReadOnlyDictionary.Instance; - JsonSerializerOptions = serializerOptions; - JsonSchema = AIJsonUtilities.CreateFunctionJsonSchema( - method, - title: Name, - description: Description, - options.SerializerOptions, - options.JsonSchemaCreateOptions); } - public override string Name { get; } - public override string Description { get; } - public override MethodInfo? UnderlyingMethod { get; } + public ReflectionAIFunctionDescriptor FunctionDescriptor { get; } + public object? Target { get; } + public override string Name => FunctionDescriptor.Name; + public override string Description => FunctionDescriptor.Description; + public override MethodInfo UnderlyingMethod => FunctionDescriptor.Method; + public override JsonElement JsonSchema => FunctionDescriptor.JsonSchema; + public override JsonSerializerOptions JsonSerializerOptions => FunctionDescriptor.JsonSerializerOptions; public override IReadOnlyDictionary AdditionalProperties { get; } - public override JsonSerializerOptions JsonSerializerOptions { get; } - public override JsonElement JsonSchema { get; } - - /// protected override async Task InvokeCoreAsync( IEnumerable>? arguments, CancellationToken cancellationToken) { - var paramMarshallers = _parameterMarshallers; + var paramMarshallers = FunctionDescriptor.ParameterMarshallers; object?[] args = paramMarshallers.Length != 0 ? new object?[paramMarshallers.Length] : []; IReadOnlyDictionary argDict = @@ -269,7 +201,7 @@ static bool IsAsyncMethod(MethodInfo method) #else ToDictionary(kvp => kvp.Key, kvp => kvp.Value); #endif - AIFunctionContext? context = _needsAIFunctionContext ? + AIFunctionContext? context = FunctionDescriptor.RequiresAIFunctionContext ? new() { CancellationToken = cancellationToken } : null; @@ -278,30 +210,111 @@ static bool IsAsyncMethod(MethodInfo method) args[i] = paramMarshallers[i](argDict, context); } - object? result = await _returnMarshaller(ReflectionInvoke(_method, _target, args)).ConfigureAwait(false); + return await FunctionDescriptor.ReturnParameterMarshaller(ReflectionInvoke(FunctionDescriptor.Method, Target, args), cancellationToken).ConfigureAwait(false); + } + } - switch (_returnTypeInfo) + /// + /// A descriptor for a .NET method-backed AIFunction that precomputes its marshalling delegates and JSON schema. + /// + private sealed class ReflectionAIFunctionDescriptor + { + private const int InnerCacheSoftLimit = 512; + private static readonly ConditionalWeakTable> _descriptorCache = new(); + + /// + /// Gets or creates a descriptors using the specified method and options. + /// + public static ReflectionAIFunctionDescriptor GetOrCreate(MethodInfo method, AIFunctionFactoryOptions options) + { + JsonSerializerOptions serializerOptions = options.SerializerOptions ?? AIJsonUtilities.DefaultOptions; + AIJsonSchemaCreateOptions schemaOptions = options.JsonSchemaCreateOptions ?? AIJsonSchemaCreateOptions.Default; + serializerOptions.MakeReadOnly(); + ConcurrentDictionary innerCache = _descriptorCache.GetOrCreateValue(serializerOptions); + + DescriptorKey key = new(method, options.Name, options.Description, schemaOptions); + if (innerCache.TryGetValue(key, out ReflectionAIFunctionDescriptor? descriptor)) { - case null: - Debug.Assert( - UnderlyingMethod?.ReturnType == typeof(void) || - UnderlyingMethod?.ReturnType == typeof(Task) || - UnderlyingMethod?.ReturnType == typeof(ValueTask), "The return parameter should be void or non-generic task."); + return descriptor; + } - return null; + descriptor = new(key, serializerOptions); + return innerCache.Count < InnerCacheSoftLimit + ? innerCache.GetOrAdd(key, descriptor) + : descriptor; + } - case { Kind: JsonTypeInfoKind.None }: - // Special-case trivial contracts to avoid the more expensive general-purpose serialization path. - return JsonSerializer.SerializeToElement(result, _returnTypeInfo); + private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions serializerOptions) + { + // Get marshaling delegates for parameters. + ParameterInfo[] parameters = key.Method.GetParameters(); + ParameterMarshallers = new Func, AIFunctionContext?, object?>[parameters.Length]; + bool foundAIFunctionContextParameter = false; + for (int i = 0; i < parameters.Length; i++) + { + ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, parameters[i], ref foundAIFunctionContextParameter); + } + + // Get a marshaling delegate for the return value. + ReturnParameterMarshaller = GetReturnParameterMarshaller(key.Method, serializerOptions); - default: + Method = key.Method; + Name = key.Name ?? GetFunctionName(key.Method); + Description = key.Description ?? key.Method.GetCustomAttribute(inherit: true)?.Description ?? string.Empty; + RequiresAIFunctionContext = foundAIFunctionContextParameter; + JsonSerializerOptions = serializerOptions; + JsonSchema = AIJsonUtilities.CreateFunctionJsonSchema( + key.Method, + Name, + Description, + serializerOptions, + key.SchemaOptions); + } + + public string Name { get; } + public string Description { get; } + public MethodInfo Method { get; } + public JsonSerializerOptions JsonSerializerOptions { get; } + public JsonElement JsonSchema { get; } + public Func, AIFunctionContext?, object?>[] ParameterMarshallers { get; } + public Func> ReturnParameterMarshaller { get; } + public bool RequiresAIFunctionContext { get; } + public ReflectionAIFunction? CachedDefaultInstance { get; set; } + + private static string GetFunctionName(MethodInfo method) + { + // Get the function name to use. + string name = SanitizeMemberName(method.Name); + + const string AsyncSuffix = "Async"; + if (IsAsyncMethod(method) && + name.EndsWith(AsyncSuffix, StringComparison.Ordinal) && + name.Length > AsyncSuffix.Length) + { + name = name.Substring(0, name.Length - AsyncSuffix.Length); + } + + return name; + + static bool IsAsyncMethod(MethodInfo method) + { + Type t = method.ReturnType; + + if (t == typeof(Task) || t == typeof(ValueTask)) { - // Serialize asynchronously to support potential IAsyncEnumerable responses. - using MemoryStream stream = new(); - await JsonSerializer.SerializeAsync(stream, result, _returnTypeInfo, cancellationToken).ConfigureAwait(false); - Utf8JsonReader reader = new(stream.GetBuffer().AsSpan(0, (int)stream.Length)); - return JsonElement.ParseValue(ref reader); + return true; } + + if (t.IsGenericType) + { + t = t.GetGenericTypeDefinition(); + if (t == typeof(Task<>) || t == typeof(ValueTask<>) || t == typeof(IAsyncEnumerable<>)) + { + return true; + } + } + + return false; } } @@ -311,7 +324,7 @@ static bool IsAsyncMethod(MethodInfo method) private static Func, AIFunctionContext?, object?> GetParameterMarshaller( JsonSerializerOptions serializerOptions, ParameterInfo parameter, - ref bool sawAIFunctionContext) + ref bool foundAIFunctionContextParameter) { if (string.IsNullOrWhiteSpace(parameter.Name)) { @@ -321,12 +334,12 @@ static bool IsAsyncMethod(MethodInfo method) // Special-case an AIFunctionContext parameter. if (parameter.ParameterType == typeof(AIFunctionContext)) { - if (sawAIFunctionContext) + if (foundAIFunctionContextParameter) { Throw.ArgumentException(nameof(parameter), $"Only one {nameof(AIFunctionContext)} parameter is permitted."); } - sawAIFunctionContext = true; + foundAIFunctionContextParameter = true; return static (_, ctx) => { @@ -386,16 +399,21 @@ static bool IsAsyncMethod(MethodInfo method) /// /// Gets a delegate for handling the result value of a method, converting it into the to return from the invocation. /// - private static Func> GetReturnMarshaller(MethodInfo method, out Type returnType) + private static Func> GetReturnParameterMarshaller(MethodInfo method, JsonSerializerOptions serializerOptions) { - // Handle each known return type for the method - returnType = method.ReturnType; + Type returnType = method.ReturnType; + JsonTypeInfo returnTypeInfo; + + // Void + if (returnType == typeof(void)) + { + return static (_, _) => default; + } // Task if (returnType == typeof(Task)) { - returnType = typeof(void); - return async static result => + return async static (result, _) => { await ((Task)ThrowIfNullResult(result)).ConfigureAwait(false); return null; @@ -405,8 +423,7 @@ static bool IsAsyncMethod(MethodInfo method) // ValueTask if (returnType == typeof(ValueTask)) { - returnType = typeof(void); - return async static result => + return async static (result, _) => { await ((ValueTask)ThrowIfNullResult(result)).ConfigureAwait(false); return null; @@ -419,11 +436,12 @@ static bool IsAsyncMethod(MethodInfo method) if (returnType.GetGenericTypeDefinition() == typeof(Task<>)) { MethodInfo taskResultGetter = GetMethodFromGenericMethodDefinition(returnType, _taskGetResult); - returnType = taskResultGetter.ReturnType; - return async result => + returnTypeInfo = serializerOptions.GetTypeInfo(taskResultGetter.ReturnType); + return async (taskObj, cancellationToken) => { - await ((Task)ThrowIfNullResult(result)).ConfigureAwait(false); - return ReflectionInvoke(taskResultGetter, result, null); + await ((Task)ThrowIfNullResult(taskObj)).ConfigureAwait(false); + object? result = ReflectionInvoke(taskResultGetter, taskObj, null); + return await SerializeAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false); }; } @@ -432,42 +450,38 @@ static bool IsAsyncMethod(MethodInfo method) { MethodInfo valueTaskAsTask = GetMethodFromGenericMethodDefinition(returnType, _valueTaskAsTask); MethodInfo asTaskResultGetter = GetMethodFromGenericMethodDefinition(valueTaskAsTask.ReturnType, _taskGetResult); - returnType = asTaskResultGetter.ReturnType; - return async result => + returnTypeInfo = serializerOptions.GetTypeInfo(asTaskResultGetter.ReturnType); + return async (taskObj, cancellationToken) => { - var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(result), null)!; + var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(taskObj), null)!; await task.ConfigureAwait(false); - return ReflectionInvoke(asTaskResultGetter, task, null); + object? result = ReflectionInvoke(asTaskResultGetter, task, null); + return await SerializeAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false); }; } } - // For everything else, just use the result as-is. - return result => new ValueTask(result); - - // Throws an exception if a result is found to be null unexpectedly - static object ThrowIfNullResult(object? result) => result ?? throw new InvalidOperationException("Function returned null unexpectedly."); - } + // For everything else, just serialize the result as-is. + returnTypeInfo = serializerOptions.GetTypeInfo(returnType); + return (result, cancellationToken) => SerializeAsync(result, returnTypeInfo, cancellationToken); - /// Invokes the MethodInfo with the specified target object and arguments. - private static object? ReflectionInvoke(MethodInfo method, object? target, object?[]? arguments) - { -#if NET - return method.Invoke(target, BindingFlags.DoNotWrapExceptions, binder: null, arguments, culture: null); -#else - try + static async ValueTask SerializeAsync(object? result, JsonTypeInfo returnTypeInfo, CancellationToken cancellationToken) { - return method.Invoke(target, BindingFlags.Default, binder: null, arguments, culture: null); - } - catch (TargetInvocationException e) when (e.InnerException is not null) - { - // If we're targeting .NET Framework, such that BindingFlags.DoNotWrapExceptions - // is ignored, the original exception will be wrapped in a TargetInvocationException. - // Unwrap it and throw that original exception, maintaining its stack information. - System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(e.InnerException).Throw(); - return null; + if (returnTypeInfo.Kind is JsonTypeInfoKind.None) + { + // Special-case trivial contracts to avoid the more expensive general-purpose serialization path. + return JsonSerializer.SerializeToElement(result, returnTypeInfo); + } + + // Serialize asynchronously to support potential IAsyncEnumerable responses. + using MemoryStream stream = new(); + await JsonSerializer.SerializeAsync(stream, result, returnTypeInfo, cancellationToken).ConfigureAwait(false); + Utf8JsonReader reader = new(stream.GetBuffer().AsSpan(0, (int)stream.Length)); + return JsonElement.ParseValue(ref reader); } -#endif + + // Throws an exception if a result is found to be null unexpectedly + static object ThrowIfNullResult(object? result) => result ?? throw new InvalidOperationException("Function returned null unexpectedly."); } private static readonly MethodInfo _taskGetResult = typeof(Task<>).GetProperty(nameof(Task.Result), BindingFlags.Instance | BindingFlags.Public)!.GetMethod!; @@ -485,5 +499,7 @@ private static MethodInfo GetMethodFromGenericMethodDefinition(Type specializedT return specializedType.GetMethods(All).First(m => m.MetadataToken == genericMethodDefinition.MetadataToken); #endif } + + private record struct DescriptorKey(MethodInfo Method, string? Name, string? Description, AIJsonSchemaCreateOptions SchemaOptions); } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs index a0804a0451f..05084c102ab 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs @@ -64,6 +64,53 @@ public static void AIJsonSchemaCreateOptions_DefaultInstance_ReturnsExpectedValu Assert.Null(options.TransformSchemaNode); } + [Fact] + public static void AIJsonSchemaCreateOptions_UsesStructuralEquality() + { + AssertEqual(new AIJsonSchemaCreateOptions(), new AIJsonSchemaCreateOptions()); + + foreach (PropertyInfo property in typeof(AIJsonSchemaCreateOptions).GetProperties(BindingFlags.Instance | BindingFlags.Public)) + { + AIJsonSchemaCreateOptions options1 = new AIJsonSchemaCreateOptions(); + AIJsonSchemaCreateOptions options2 = new AIJsonSchemaCreateOptions(); + switch (property.GetValue(AIJsonSchemaCreateOptions.Default)) + { + case bool booleanFlag: + property.SetValue(options1, !booleanFlag); + property.SetValue(options2, !booleanFlag); + break; + + case null when property.PropertyType == typeof(Func): + Func transformer = static (context, schema) => (JsonNode)true; + property.SetValue(options1, transformer); + property.SetValue(options2, transformer); + break; + + default: + Assert.Fail($"Unexpected property type: {property.PropertyType}"); + break; + } + + AssertEqual(options1, options2); + AssertNotEqual(AIJsonSchemaCreateOptions.Default, options1); + } + + static void AssertEqual(AIJsonSchemaCreateOptions x, AIJsonSchemaCreateOptions y) + { + Assert.Equal(x.GetHashCode(), y.GetHashCode()); + Assert.Equal(x, x); + Assert.Equal(y, y); + Assert.Equal(x, y); + Assert.Equal(y, x); + } + + static void AssertNotEqual(AIJsonSchemaCreateOptions x, AIJsonSchemaCreateOptions y) + { + Assert.NotEqual(x, y); + Assert.NotEqual(y, x); + } + } + [Fact] public static void CreateJsonSchema_DefaultParameters_GeneratesExpectedJsonSchema() { From fb3d1d7b5baa73d3fab053e7490f9a14360711e0 Mon Sep 17 00:00:00 2001 From: Eirik Tsarpalis Date: Tue, 18 Feb 2025 19:33:04 +0000 Subject: [PATCH 4/6] Reorder properties, --- .../Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs index 30c691d38e1..e9fde9d9940 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs @@ -179,12 +179,12 @@ private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor, public ReflectionAIFunctionDescriptor FunctionDescriptor { get; } public object? Target { get; } + public override IReadOnlyDictionary AdditionalProperties { get; } public override string Name => FunctionDescriptor.Name; public override string Description => FunctionDescriptor.Description; public override MethodInfo UnderlyingMethod => FunctionDescriptor.Method; public override JsonElement JsonSchema => FunctionDescriptor.JsonSchema; public override JsonSerializerOptions JsonSerializerOptions => FunctionDescriptor.JsonSerializerOptions; - public override IReadOnlyDictionary AdditionalProperties { get; } protected override async Task InvokeCoreAsync( IEnumerable>? arguments, CancellationToken cancellationToken) From d09aa78b83a6384037c0b43e6f676412d046fe90 Mon Sep 17 00:00:00 2001 From: Eirik Tsarpalis Date: Tue, 18 Feb 2025 19:37:05 +0000 Subject: [PATCH 5/6] Rename serialization helper. --- .../Functions/AIFunctionFactory.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs index e9fde9d9940..a544a13202f 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs @@ -441,7 +441,7 @@ static bool IsAsyncMethod(MethodInfo method) { await ((Task)ThrowIfNullResult(taskObj)).ConfigureAwait(false); object? result = ReflectionInvoke(taskResultGetter, taskObj, null); - return await SerializeAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false); + return await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false); }; } @@ -456,16 +456,16 @@ static bool IsAsyncMethod(MethodInfo method) var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(taskObj), null)!; await task.ConfigureAwait(false); object? result = ReflectionInvoke(asTaskResultGetter, task, null); - return await SerializeAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false); + return await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false); }; } } // For everything else, just serialize the result as-is. returnTypeInfo = serializerOptions.GetTypeInfo(returnType); - return (result, cancellationToken) => SerializeAsync(result, returnTypeInfo, cancellationToken); + return (result, cancellationToken) => SerializeResultAsync(result, returnTypeInfo, cancellationToken); - static async ValueTask SerializeAsync(object? result, JsonTypeInfo returnTypeInfo, CancellationToken cancellationToken) + static async ValueTask SerializeResultAsync(object? result, JsonTypeInfo returnTypeInfo, CancellationToken cancellationToken) { if (returnTypeInfo.Kind is JsonTypeInfoKind.None) { From 33c9f8c09af68410e26c0a62f2e213ca69180a7b Mon Sep 17 00:00:00 2001 From: Eirik Tsarpalis Date: Wed, 19 Feb 2025 16:11:49 +0000 Subject: [PATCH 6/6] Address feedback --- .../Functions/AIFunctionFactory.Utilities.cs | 84 ++++++++++++++++++- .../Functions/AIFunctionFactory.cs | 17 ++-- 2 files changed, 91 insertions(+), 10 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.Utilities.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.Utilities.cs index 2094f2f1886..cbafe78e5d3 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.Utilities.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.Utilities.cs @@ -1,6 +1,9 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; +using System.Buffers; +using System.IO; using System.Reflection; using System.Text.RegularExpressions; using Microsoft.Shared.Diagnostics; @@ -48,8 +51,87 @@ internal static string SanitizeMemberName(string memberName) // is ignored, the original exception will be wrapped in a TargetInvocationException. // Unwrap it and throw that original exception, maintaining its stack information. System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(e.InnerException).Throw(); - return null; + throw; } #endif } + + /// + /// Implements a simple write-only memory stream that uses pooled buffers. + /// + private sealed class PooledMemoryStream : Stream + { + private const int DefaultBufferSize = 4096; + private byte[] _buffer; + private int _position; + + public PooledMemoryStream(int initialCapacity = DefaultBufferSize) + { + _buffer = ArrayPool.Shared.Rent(initialCapacity); + _position = 0; + } + + public ReadOnlySpan GetBuffer() => _buffer.AsSpan(0, _position); + public override bool CanWrite => true; + public override bool CanRead => false; + public override bool CanSeek => false; + public override long Length => _position; + public override long Position + { + get => _position; + set => throw new NotSupportedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + EnsureNotDisposed(); + EnsureCapacity(_position + count); + + Buffer.BlockCopy(buffer, offset, _buffer, _position, count); + _position += count; + } + + public override void Flush() + { + } + + public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + + protected override void Dispose(bool disposing) + { + if (_buffer is not null) + { + ArrayPool.Shared.Return(_buffer); + _buffer = null!; + } + + base.Dispose(disposing); + } + + private void EnsureCapacity(int requiredCapacity) + { + if (requiredCapacity <= _buffer.Length) + { + return; + } + + int newCapacity = Math.Max(requiredCapacity, _buffer.Length * 2); + byte[] newBuffer = ArrayPool.Shared.Rent(newCapacity); + Buffer.BlockCopy(_buffer, 0, newBuffer, 0, _position); + + ArrayPool.Shared.Return(_buffer); + _buffer = newBuffer; + } + + private void EnsureNotDisposed() + { + if (_buffer is null) + { + Throw(); + static void Throw() => throw new ObjectDisposedException(nameof(PooledMemoryStream)); + } + } + } } diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs index f2dc0989987..0aff0901c7a 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs @@ -6,7 +6,6 @@ using System.Collections.Generic; using System.ComponentModel; using System.Diagnostics; -using System.IO; using System.Linq; using System.Reflection; using System.Runtime.CompilerServices; @@ -185,7 +184,7 @@ private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor, public override MethodInfo UnderlyingMethod => FunctionDescriptor.Method; public override JsonElement JsonSchema => FunctionDescriptor.JsonSchema; public override JsonSerializerOptions JsonSerializerOptions => FunctionDescriptor.JsonSerializerOptions; - protected override async Task InvokeCoreAsync( + protected override Task InvokeCoreAsync( IEnumerable>? arguments, CancellationToken cancellationToken) { @@ -210,7 +209,7 @@ private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor, args[i] = paramMarshallers[i](argDict, context); } - return await FunctionDescriptor.ReturnParameterMarshaller(ReflectionInvoke(FunctionDescriptor.Method, Target, args), cancellationToken).ConfigureAwait(false); + return FunctionDescriptor.ReturnParameterMarshaller(ReflectionInvoke(FunctionDescriptor.Method, Target, args), cancellationToken); } } @@ -277,7 +276,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions public JsonSerializerOptions JsonSerializerOptions { get; } public JsonElement JsonSchema { get; } public Func, AIFunctionContext?, object?>[] ParameterMarshallers { get; } - public Func> ReturnParameterMarshaller { get; } + public Func> ReturnParameterMarshaller { get; } public bool RequiresAIFunctionContext { get; } public ReflectionAIFunction? CachedDefaultInstance { get; set; } @@ -399,7 +398,7 @@ static bool IsAsyncMethod(MethodInfo method) /// /// Gets a delegate for handling the result value of a method, converting it into the to return from the invocation. /// - private static Func> GetReturnParameterMarshaller(MethodInfo method, JsonSerializerOptions serializerOptions) + private static Func> GetReturnParameterMarshaller(MethodInfo method, JsonSerializerOptions serializerOptions) { Type returnType = method.ReturnType; JsonTypeInfo returnTypeInfo; @@ -407,7 +406,7 @@ static bool IsAsyncMethod(MethodInfo method) // Void if (returnType == typeof(void)) { - return static (_, _) => default; + return static (_, _) => Task.FromResult(null); } // Task @@ -465,7 +464,7 @@ static bool IsAsyncMethod(MethodInfo method) returnTypeInfo = serializerOptions.GetTypeInfo(returnType); return (result, cancellationToken) => SerializeResultAsync(result, returnTypeInfo, cancellationToken); - static async ValueTask SerializeResultAsync(object? result, JsonTypeInfo returnTypeInfo, CancellationToken cancellationToken) + static async Task SerializeResultAsync(object? result, JsonTypeInfo returnTypeInfo, CancellationToken cancellationToken) { if (returnTypeInfo.Kind is JsonTypeInfoKind.None) { @@ -474,9 +473,9 @@ static bool IsAsyncMethod(MethodInfo method) } // Serialize asynchronously to support potential IAsyncEnumerable responses. - using MemoryStream stream = new(); + using PooledMemoryStream stream = new(); await JsonSerializer.SerializeAsync(stream, result, returnTypeInfo, cancellationToken).ConfigureAwait(false); - Utf8JsonReader reader = new(stream.GetBuffer().AsSpan(0, (int)stream.Length)); + Utf8JsonReader reader = new(stream.GetBuffer()); return JsonElement.ParseValue(ref reader); }