diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactory.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactory.cs index 6534e041a7c..d5274186645 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactory.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactory.cs @@ -7,7 +7,6 @@ using System.Collections.Generic; using System.ComponentModel; using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; using System.IO; #if !NET using System.Linq; @@ -374,16 +373,15 @@ public static AIFunction Create(MethodInfo method, object? target, string? name } /// - /// Creates an instance for a method, specified via an for - /// and instance method, along with a representing the type of the target object to - /// instantiate each time the method is invoked. + /// Creates an instance for a method, specified via a for + /// an instance method and a for constructing an instance of + /// the receiver object each time the is invoked. /// /// The instance method to be represented via the created . - /// - /// The to construct an instance of on which to invoke when - /// the resulting is invoked. is used, - /// utilizing the type's public parameterless constructor. If an instance can't be constructed, an exception is - /// thrown during the function's invocation. + /// + /// Callback used on each function invocation to create an instance of the type on which the instance method + /// will be invoked. If the returned instance is or , it will be disposed of + /// after completes its invocation. /// /// Metadata to use to override defaults inferred from . /// The created for invoking . @@ -457,22 +455,16 @@ public static AIFunction Create(MethodInfo method, object? target, string? name /// /// /// is . - /// is . + /// is . /// represents a static method. /// represents an open generic method. /// contains a parameter without a parameter name. - /// is not assignable to 's declaring type. /// A parameter to or its return type is not serializable. public static AIFunction Create( MethodInfo method, - [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType, - AIFunctionFactoryOptions? options = null) - { - _ = Throw.IfNull(method); - _ = Throw.IfNull(targetType); - - return ReflectionAIFunction.Build(method, targetType, options ?? _defaultOptions); - } + Func createInstanceFunc, + AIFunctionFactoryOptions? options = null) => + ReflectionAIFunction.Build(method, createInstanceFunc, options ?? _defaultOptions); private sealed class ReflectionAIFunction : AIFunction { @@ -503,10 +495,11 @@ public static ReflectionAIFunction Build(MethodInfo method, object? target, AIFu public static ReflectionAIFunction Build( MethodInfo method, - [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType, + Func createInstanceFunc, AIFunctionFactoryOptions options) { _ = Throw.IfNull(method); + _ = Throw.IfNull(createInstanceFunc); if (method.ContainsGenericParameters) { @@ -518,13 +511,7 @@ public static ReflectionAIFunction Build( Throw.ArgumentException(nameof(method), "The method must be an instance method."); } - if (method.DeclaringType is { } declaringType && - !declaringType.IsAssignableFrom(targetType)) - { - Throw.ArgumentException(nameof(targetType), "The target type must be assignable to the method's declaring type."); - } - - return new(ReflectionAIFunctionDescriptor.GetOrCreate(method, options), targetType, options); + return new(ReflectionAIFunctionDescriptor.GetOrCreate(method, options), createInstanceFunc, options); } private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor, object? target, AIFunctionFactoryOptions options) @@ -536,20 +523,17 @@ private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor, private ReflectionAIFunction( ReflectionAIFunctionDescriptor functionDescriptor, - [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType, + Func createInstanceFunc, AIFunctionFactoryOptions options) { FunctionDescriptor = functionDescriptor; - TargetType = targetType; - CreateInstance = options.CreateInstance; + CreateInstanceFunc = createInstanceFunc; AdditionalProperties = options.AdditionalProperties ?? EmptyReadOnlyDictionary.Instance; } public ReflectionAIFunctionDescriptor FunctionDescriptor { get; } public object? Target { get; } - [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] - public Type? TargetType { get; } - public Func? CreateInstance { get; } + public Func? CreateInstanceFunc { get; } public override IReadOnlyDictionary AdditionalProperties { get; } public override string Name => FunctionDescriptor.Name; @@ -566,14 +550,12 @@ private ReflectionAIFunction( object? target = Target; try { - if (TargetType is { } targetType) + if (CreateInstanceFunc is { } func) { Debug.Assert(target is null, "Expected target to be null when we have a non-null target type"); Debug.Assert(!FunctionDescriptor.Method.IsStatic, "Expected an instance method"); - target = CreateInstance is not null ? - CreateInstance(targetType, arguments) : - Activator.CreateInstance(targetType); + target = func(arguments); if (target is null) { Throw.InvalidOperationException("Unable to create an instance of the target type."); diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactoryOptions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactoryOptions.cs index 80ff394359d..e71a4687422 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactoryOptions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactoryOptions.cs @@ -106,24 +106,6 @@ public AIFunctionFactoryOptions() /// public Func>? MarshalResult { get; set; } - /// - /// Gets or sets a delegate used with to create the receiver instance. - /// - /// - /// - /// creates instances that invoke an - /// instance method on the specified . This delegate is used to create the instance of the type that will be used to invoke the method. - /// By default if is , is used. If - /// is non-, the delegate is invoked with the to be instantiated and the - /// provided to the method. - /// - /// - /// Each created instance will be used for a single invocation. If the object is or , it will - /// be disposed of after the invocation completes. - /// - /// - public Func? CreateInstance { get; set; } - /// Provides configuration options produced by the delegate. public readonly record struct ParameterBindingOptions { diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs index 4f5037fc92d..6d448efb710 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs @@ -29,7 +29,8 @@ public void InvalidArguments_Throw() Assert.Throws("method", () => AIFunctionFactory.Create(method: null!, target: new object())); Assert.Throws("method", () => AIFunctionFactory.Create(method: null!, target: new object(), name: "myAiFunk")); Assert.Throws("target", () => AIFunctionFactory.Create(typeof(AIFunctionFactoryTest).GetMethod(nameof(InvalidArguments_Throw))!, (object?)null)); - Assert.Throws("targetType", () => AIFunctionFactory.Create(typeof(AIFunctionFactoryTest).GetMethod(nameof(InvalidArguments_Throw))!, (Type)null!)); + Assert.Throws("createInstanceFunc", () => + AIFunctionFactory.Create(typeof(AIFunctionFactoryTest).GetMethod(nameof(InvalidArguments_Throw))!, (Func)null!)); Assert.Throws("method", () => AIFunctionFactory.Create(typeof(List<>).GetMethod("Add")!, new List())); } @@ -312,16 +313,12 @@ public async Task Create_NoInstance_UsesActivatorUtilitiesWhenServicesAvailable( AIFunction func = AIFunctionFactory.Create( typeof(MyFunctionTypeWithOneArg).GetMethod(nameof(MyFunctionTypeWithOneArg.InstanceMethod))!, - typeof(MyFunctionTypeWithOneArg), - new() + static arguments => { - CreateInstance = (type, arguments) => - { - Assert.NotNull(arguments.Services); - return ActivatorUtilities.CreateInstance(arguments.Services, type); - }, - MarshalResult = (result, type, cancellationToken) => new ValueTask(result), - }); + Assert.NotNull(arguments.Services); + return ActivatorUtilities.CreateInstance(arguments.Services, typeof(MyFunctionTypeWithOneArg)); + }, + new() { MarshalResult = (result, type, cancellationToken) => new ValueTask(result) }); Assert.NotNull(func); var result = (Tuple?)await func.InvokeAsync(new() { Services = sp }); @@ -330,31 +327,25 @@ public async Task Create_NoInstance_UsesActivatorUtilitiesWhenServicesAvailable( } [Fact] - public async Task Create_NoInstance_UsesActivatorWhenServicesUnavailable() + public async Task Create_CreateInstanceReturnsNull_ThrowsDuringInvocation() { AIFunction func = AIFunctionFactory.Create( - typeof(MyFunctionTypeWithNoArgs).GetMethod(nameof(MyFunctionTypeWithNoArgs.InstanceMethod))!, - typeof(MyFunctionTypeWithNoArgs), - new() - { - MarshalResult = (result, type, cancellationToken) => new ValueTask(result), - }); + typeof(MyFunctionTypeWithOneArg).GetMethod(nameof(MyFunctionTypeWithOneArg.InstanceMethod))!, + static _ => null!); Assert.NotNull(func); - Assert.Equal("42", await func.InvokeAsync()); + await Assert.ThrowsAsync(async () => await func.InvokeAsync()); } [Fact] - public async Task Create_NoInstance_ThrowsWhenCantConstructInstance() + public async Task Create_WrongConstructedType_ThrowsDuringInvocation() { - var sp = new ServiceCollection().BuildServiceProvider(); - AIFunction func = AIFunctionFactory.Create( typeof(MyFunctionTypeWithOneArg).GetMethod(nameof(MyFunctionTypeWithOneArg.InstanceMethod))!, - typeof(MyFunctionTypeWithOneArg)); + static _ => new MyFunctionTypeWithNoArgs()); Assert.NotNull(func); - await Assert.ThrowsAsync(async () => await func.InvokeAsync(new() { Services = sp })); + await Assert.ThrowsAsync(async () => await func.InvokeAsync()); } [Fact] @@ -362,15 +353,7 @@ public void Create_NoInstance_ThrowsForStaticMethod() { Assert.Throws("method", () => AIFunctionFactory.Create( typeof(MyFunctionTypeWithNoArgs).GetMethod(nameof(MyFunctionTypeWithNoArgs.StaticMethod))!, - typeof(MyFunctionTypeWithNoArgs))); - } - - [Fact] - public void Create_NoInstance_ThrowsForMismatchedMethod() - { - Assert.Throws("targetType", () => AIFunctionFactory.Create( - typeof(MyFunctionTypeWithNoArgs).GetMethod(nameof(MyFunctionTypeWithNoArgs.InstanceMethod))!, - typeof(MyFunctionTypeWithOneArg))); + static _ => new MyFunctionTypeWithNoArgs())); } [Fact] @@ -378,7 +361,7 @@ public async Task Create_NoInstance_DisposableInstanceCreatedDisposedEachInvocat { AIFunction func = AIFunctionFactory.Create( typeof(DisposableService).GetMethod(nameof(DisposableService.GetThis))!, - typeof(DisposableService), + static _ => new DisposableService(), new() { MarshalResult = (result, type, cancellationToken) => new ValueTask(result), @@ -397,7 +380,7 @@ public async Task Create_NoInstance_AsyncDisposableInstanceCreatedDisposedEachIn { AIFunction func = AIFunctionFactory.Create( typeof(AsyncDisposableService).GetMethod(nameof(AsyncDisposableService.GetThis))!, - typeof(AsyncDisposableService), + static _ => new AsyncDisposableService(), new() { MarshalResult = (result, type, cancellationToken) => new ValueTask(result), @@ -416,7 +399,7 @@ public async Task Create_NoInstance_DisposableAndAsyncDisposableInstanceCreatedD { AIFunction func = AIFunctionFactory.Create( typeof(DisposableAndAsyncDisposableService).GetMethod(nameof(DisposableAndAsyncDisposableService.GetThis))!, - typeof(DisposableAndAsyncDisposableService), + static _ => new DisposableAndAsyncDisposableService(), new() { MarshalResult = (result, type, cancellationToken) => new ValueTask(result), @@ -821,11 +804,7 @@ public ValueTask DisposeAsync() private sealed class MyFunctionTypeWithNoArgs { - private string _value = "42"; - public static void StaticMethod() => throw new NotSupportedException(); - - public string InstanceMethod() => _value; } private sealed class MyFunctionTypeWithOneArg(MyArgumentType arg)