From a5d93376d4f7ed7f02174a7a53efb76d0e1dded9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Velv=C3=A1rt=20Andr=C3=A1s?= Date: Sun, 6 Aug 2023 14:57:55 +0200 Subject: [PATCH] Fixed Type override for ParameterDescriptionAttribute Added InvalidFunctionCallException Added / updated relevant tests --- .../FunctionCallingHelperTests.cs | 84 ++++++++++++++++++- OpenAI.Utilities/FunctionCallingHelper.cs | 79 ++++++++++------- OpenAI.lutconfig | 6 ++ 3 files changed, 138 insertions(+), 31 deletions(-) create mode 100644 OpenAI.lutconfig diff --git a/OpenAI.Utilities.Tests/FunctionCallingHelperTests.cs b/OpenAI.Utilities.Tests/FunctionCallingHelperTests.cs index 50099f73..cab2056e 100644 --- a/OpenAI.Utilities.Tests/FunctionCallingHelperTests.cs +++ b/OpenAI.Utilities.Tests/FunctionCallingHelperTests.cs @@ -47,13 +47,22 @@ public void VerifyGetFunctionDefinition() functionDefinition.Parameters.Properties.ShouldContainKey("OverriddenName"); } + [Fact] + public void VerifyTypeOverride() + { + var functionDefinition = FunctionCallingHelper.GetFunctionDefinition(typeof(FunctionCallingTestClass).GetMethod("ThirdFunction")!); + + var overriddenNameParameter = functionDefinition.Parameters.Properties["overriddenTypeParameter"]; + overriddenNameParameter.Type.ShouldBe("string"); + overriddenNameParameter.Description.ShouldBe("Overridden type parameter"); + } + [Fact] public void VerifyGetFunctionDefinitions() { - var obj = new FunctionCallingTestClass(); - var functionDefinitions = FunctionCallingHelper.GetFunctionDefinitions(obj); + var functionDefinitions = FunctionCallingHelper.GetFunctionDefinitions(); - functionDefinitions.Count.ShouldBe(2); + functionDefinitions.Count.ShouldBe(3); var functionDefinition = functionDefinitions.First(x => x.Name == "TestFunction"); functionDefinition.Description.ShouldBe("Test Function"); @@ -64,6 +73,11 @@ public void VerifyGetFunctionDefinitions() functionDefinition2.Description.ShouldBe("Second Function"); functionDefinition2.Parameters.ShouldNotBeNull(); functionDefinition2.Parameters.Properties!.Count.ShouldBe(0); + + var functionDefinition3 = functionDefinitions.First(x => x.Name == "ThirdFunction"); + functionDefinition3.Description.ShouldBe("Third Function"); + functionDefinition3.Parameters.ShouldNotBeNull(); + functionDefinition3.Parameters.Properties!.Count.ShouldBe(1); } [Fact] @@ -121,6 +135,62 @@ public void VerifyCallFunction_ArgumentsDoNotMatch() Should.Throw(() => FunctionCallingHelper.CallFunction(functionCall, obj)); } + [Fact] + public void CallFunctionShouldThrowIfObjIsNull() + { + var functionCall = new FunctionCall + { + Name = "SecondFunction", + }; + + Should.Throw(() => FunctionCallingHelper.CallFunction(functionCall, null!)); + } + + [Fact] + public void CallFunctionShouldThrowIfFunctionCallIsNull() + { + var obj = new FunctionCallingTestClass(); + + Should.Throw(() => FunctionCallingHelper.CallFunction(null!, obj)); + } + + [Fact] + public void CallFunctionShouldThrowIfFunctionCallNameIsNotSet() + { + var obj = new FunctionCallingTestClass(); + + var functionCall = new FunctionCall + { + Name = null!, + }; + + Should.Throw(() => FunctionCallingHelper.CallFunction(functionCall, obj)); + } + + [Fact] + public void CallFunctionShouldThrowIfFunctionCallNameIsNotValid() + { + var obj = new FunctionCallingTestClass(); + + var functionCall = new FunctionCall + { + Name = "NonExistentFunction", + }; + + Should.Throw(() => FunctionCallingHelper.CallFunction(functionCall, obj)); + } + + [Fact] + public void CallFunctionShouldThrowIfInvalidReturnType() + { + var obj = new FunctionCallingTestClass(); + var functionCall = new FunctionCall() + { + Name = "SecondFunction", + }; + + Should.Throw(() => FunctionCallingHelper.CallFunction(functionCall, obj)); + } } internal class FunctionCallingTestClass @@ -134,6 +204,7 @@ internal class FunctionCallingTestClass public int RequiredIntParameter; public int? NotRequiredIntParameter; public int OverriddenNameParameter; + public string OverriddenTypeParameter = null!; [FunctionDescription("Test Function")] public int TestFunction( @@ -146,6 +217,7 @@ public int TestFunction( [ParameterDescription(Description = "Required Int Parameter", Required= true)] int requiredIntParameter, [ParameterDescription(Description = "Not required Int Parameter", Required = false)] int notRequiredIntParameter, [ParameterDescription(Name = "OverriddenName", Description = "Overridden")] int overriddenNameParameter) + { IntParameter = intParameter; FloatParameter = floatParameter; @@ -165,6 +237,12 @@ public string SecondFunction() { return "Hello"; } + + [FunctionDescription("Third Function")] + public void ThirdFunction([ParameterDescription(Type = "string", Description = "Overridden type parameter")] int overriddenTypeParameter) + { + OverriddenTypeParameter = overriddenTypeParameter.ToString(); + } } public enum TestEnum diff --git a/OpenAI.Utilities/FunctionCallingHelper.cs b/OpenAI.Utilities/FunctionCallingHelper.cs index 1e73e081..90e088a6 100644 --- a/OpenAI.Utilities/FunctionCallingHelper.cs +++ b/OpenAI.Utilities/FunctionCallingHelper.cs @@ -32,32 +32,40 @@ public static FunctionDefinition GetFunctionDefinition(MethodInfo methodInfo) PropertyDefinition definition; - switch (parameter.ParameterType) + switch (parameter.ParameterType, parameterDescriptionAttribute?.Type == null) { - case { } t when t.IsAssignableFrom(typeof(int)): - definition = PropertyDefinition.DefineInteger(description); - break; - case { } t when t.IsAssignableFrom(typeof(float)): - definition = PropertyDefinition.DefineNumber(description); - break; - case { } t when t.IsAssignableFrom(typeof(bool)): - definition = PropertyDefinition.DefineBoolean(description); - break; - case { } t when t.IsAssignableFrom(typeof(string)): - definition = PropertyDefinition.DefineString(description); - break; - case { IsEnum: true }: - - var enumValues = string.IsNullOrEmpty(parameterDescriptionAttribute?.Enum) - ? Enum.GetNames(parameter.ParameterType).ToList() - : parameterDescriptionAttribute.Enum.Split(",").Select(x => x.Trim()).ToList(); - - - definition = - PropertyDefinition.DefineEnum(enumValues, description); - break; - default: - throw new Exception($"Parameter type '{parameter.ParameterType}' not supported"); + case (_, false): + definition = new PropertyDefinition() + { + Type = parameterDescriptionAttribute!.Type!, + Description = description, + }; + + break; + case ({ } t, _) when t.IsAssignableFrom(typeof(int)): + definition = PropertyDefinition.DefineInteger(description); + break; + case ({ } t, _) when t.IsAssignableFrom(typeof(float)): + definition = PropertyDefinition.DefineNumber(description); + break; + case ({ } t, _) when t.IsAssignableFrom(typeof(bool)): + definition = PropertyDefinition.DefineBoolean(description); + break; + case ({ } t, _) when t.IsAssignableFrom(typeof(string)): + definition = PropertyDefinition.DefineString(description); + break; + case ({ IsEnum: true }, _): + + var enumValues = string.IsNullOrEmpty(parameterDescriptionAttribute?.Enum) + ? Enum.GetNames(parameter.ParameterType).ToList() + : parameterDescriptionAttribute.Enum.Split(",").Select(x => x.Trim()).ToList(); + + definition = + PropertyDefinition.DefineEnum(enumValues, description); + + break; + default: + throw new Exception($"Parameter type '{parameter.ParameterType}' not supported"); } result.AddParameter( @@ -122,15 +130,18 @@ public static List GetFunctionDefinitions(Type type) throw new ArgumentNullException(nameof(functionCall)); if (functionCall.Name == null) - throw new Exception("Function name is null"); + throw new InvalidFunctionCallException("Function Name is null"); + + if (obj == null) + throw new ArgumentNullException(nameof(obj)); var methodInfo = obj.GetType().GetMethod(functionCall.Name); if (methodInfo == null) - throw new Exception($"Method '{functionCall.Name}' on type '{obj.GetType()}' not found"); + throw new InvalidFunctionCallException($"Method '{functionCall.Name}' on type '{obj.GetType()}' not found"); if (!methodInfo.ReturnType.IsAssignableTo(typeof(T))) - throw new Exception( + throw new InvalidFunctionCallException( $"Method '{functionCall.Name}' on type '{obj.GetType()}' has return type '{methodInfo.ReturnType}' but expected '{typeof(T)}'"); var parameters = methodInfo.GetParameters().ToList(); @@ -160,6 +171,18 @@ public static List GetFunctionDefinitions(Type type) } } +/// +/// Exception thrown when a function call is invalid +/// +public class InvalidFunctionCallException : Exception +{ + /// + /// Creates a new instance of the with the provided message + /// + public InvalidFunctionCallException(string message) : base(message) + { } +} + /// /// Attribute to mark a method as a function, and provide a description for the function. Can also be used to override the Name of the function /// diff --git a/OpenAI.lutconfig b/OpenAI.lutconfig new file mode 100644 index 00000000..596a8603 --- /dev/null +++ b/OpenAI.lutconfig @@ -0,0 +1,6 @@ + + + true + true + 180000 + \ No newline at end of file