Skip to content

Commit

Permalink
Fixed Type override for ParameterDescriptionAttribute
Browse files Browse the repository at this point in the history
Added InvalidFunctionCallException
Added / updated relevant tests
  • Loading branch information
Velvárt András committed Aug 6, 2023
1 parent 632ad9f commit a5d9337
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 31 deletions.
84 changes: 81 additions & 3 deletions OpenAI.Utilities.Tests/FunctionCallingHelperTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,22 @@ public void VerifyGetFunctionDefinition()
functionDefinition.Parameters.Properties.ShouldContainKey("OverriddenName");
}

[Fact]
public void VerifyTypeOverride()
{
var functionDefinition = FunctionCallingHelper.GetFunctionDefinition(typeof(FunctionCallingTestClass).GetMethod("ThirdFunction")!);

var overriddenNameParameter = functionDefinition.Parameters.Properties["overriddenTypeParameter"];
overriddenNameParameter.Type.ShouldBe("string");
overriddenNameParameter.Description.ShouldBe("Overridden type parameter");
}

[Fact]
public void VerifyGetFunctionDefinitions()
{
var obj = new FunctionCallingTestClass();
var functionDefinitions = FunctionCallingHelper.GetFunctionDefinitions(obj);
var functionDefinitions = FunctionCallingHelper.GetFunctionDefinitions<FunctionCallingTestClass>();

functionDefinitions.Count.ShouldBe(2);
functionDefinitions.Count.ShouldBe(3);

var functionDefinition = functionDefinitions.First(x => x.Name == "TestFunction");
functionDefinition.Description.ShouldBe("Test Function");
Expand All @@ -64,6 +73,11 @@ public void VerifyGetFunctionDefinitions()
functionDefinition2.Description.ShouldBe("Second Function");
functionDefinition2.Parameters.ShouldNotBeNull();
functionDefinition2.Parameters.Properties!.Count.ShouldBe(0);

var functionDefinition3 = functionDefinitions.First(x => x.Name == "ThirdFunction");
functionDefinition3.Description.ShouldBe("Third Function");
functionDefinition3.Parameters.ShouldNotBeNull();
functionDefinition3.Parameters.Properties!.Count.ShouldBe(1);
}

[Fact]
Expand Down Expand Up @@ -121,6 +135,62 @@ public void VerifyCallFunction_ArgumentsDoNotMatch()
Should.Throw<Exception>(() => FunctionCallingHelper.CallFunction<int>(functionCall, obj));
}

[Fact]
public void CallFunctionShouldThrowIfObjIsNull()
{
var functionCall = new FunctionCall
{
Name = "SecondFunction",
};

Should.Throw<ArgumentNullException>(() => FunctionCallingHelper.CallFunction<string>(functionCall, null!));
}

[Fact]
public void CallFunctionShouldThrowIfFunctionCallIsNull()
{
var obj = new FunctionCallingTestClass();

Should.Throw<ArgumentNullException>(() => FunctionCallingHelper.CallFunction<string>(null!, obj));
}

[Fact]
public void CallFunctionShouldThrowIfFunctionCallNameIsNotSet()
{
var obj = new FunctionCallingTestClass();

var functionCall = new FunctionCall
{
Name = null!,
};

Should.Throw<InvalidFunctionCallException>(() => FunctionCallingHelper.CallFunction<string>(functionCall, obj));
}

[Fact]
public void CallFunctionShouldThrowIfFunctionCallNameIsNotValid()
{
var obj = new FunctionCallingTestClass();

var functionCall = new FunctionCall
{
Name = "NonExistentFunction",
};

Should.Throw<InvalidFunctionCallException>(() => FunctionCallingHelper.CallFunction<string>(functionCall, obj));
}

[Fact]
public void CallFunctionShouldThrowIfInvalidReturnType()
{
var obj = new FunctionCallingTestClass();
var functionCall = new FunctionCall()
{
Name = "SecondFunction",
};

Should.Throw<InvalidFunctionCallException>(() => FunctionCallingHelper.CallFunction<int>(functionCall, obj));
}
}

internal class FunctionCallingTestClass
Expand All @@ -134,6 +204,7 @@ internal class FunctionCallingTestClass
public int RequiredIntParameter;
public int? NotRequiredIntParameter;
public int OverriddenNameParameter;
public string OverriddenTypeParameter = null!;

[FunctionDescription("Test Function")]
public int TestFunction(
Expand All @@ -146,6 +217,7 @@ public int TestFunction(
[ParameterDescription(Description = "Required Int Parameter", Required= true)] int requiredIntParameter,
[ParameterDescription(Description = "Not required Int Parameter", Required = false)] int notRequiredIntParameter,
[ParameterDescription(Name = "OverriddenName", Description = "Overridden")] int overriddenNameParameter)

{
IntParameter = intParameter;
FloatParameter = floatParameter;
Expand All @@ -165,6 +237,12 @@ public string SecondFunction()
{
return "Hello";
}

[FunctionDescription("Third Function")]
public void ThirdFunction([ParameterDescription(Type = "string", Description = "Overridden type parameter")] int overriddenTypeParameter)
{
OverriddenTypeParameter = overriddenTypeParameter.ToString();
}
}

public enum TestEnum
Expand Down
79 changes: 51 additions & 28 deletions OpenAI.Utilities/FunctionCallingHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,32 +32,40 @@ public static FunctionDefinition GetFunctionDefinition(MethodInfo methodInfo)

PropertyDefinition definition;

switch (parameter.ParameterType)
switch (parameter.ParameterType, parameterDescriptionAttribute?.Type == null)
{
case { } t when t.IsAssignableFrom(typeof(int)):
definition = PropertyDefinition.DefineInteger(description);
break;
case { } t when t.IsAssignableFrom(typeof(float)):
definition = PropertyDefinition.DefineNumber(description);
break;
case { } t when t.IsAssignableFrom(typeof(bool)):
definition = PropertyDefinition.DefineBoolean(description);
break;
case { } t when t.IsAssignableFrom(typeof(string)):
definition = PropertyDefinition.DefineString(description);
break;
case { IsEnum: true }:

var enumValues = string.IsNullOrEmpty(parameterDescriptionAttribute?.Enum)
? Enum.GetNames(parameter.ParameterType).ToList()
: parameterDescriptionAttribute.Enum.Split(",").Select(x => x.Trim()).ToList();


definition =
PropertyDefinition.DefineEnum(enumValues, description);
break;
default:
throw new Exception($"Parameter type '{parameter.ParameterType}' not supported");
case (_, false):
definition = new PropertyDefinition()
{
Type = parameterDescriptionAttribute!.Type!,
Description = description,
};

break;
case ({ } t, _) when t.IsAssignableFrom(typeof(int)):
definition = PropertyDefinition.DefineInteger(description);
break;
case ({ } t, _) when t.IsAssignableFrom(typeof(float)):
definition = PropertyDefinition.DefineNumber(description);
break;
case ({ } t, _) when t.IsAssignableFrom(typeof(bool)):
definition = PropertyDefinition.DefineBoolean(description);
break;
case ({ } t, _) when t.IsAssignableFrom(typeof(string)):
definition = PropertyDefinition.DefineString(description);
break;
case ({ IsEnum: true }, _):

var enumValues = string.IsNullOrEmpty(parameterDescriptionAttribute?.Enum)
? Enum.GetNames(parameter.ParameterType).ToList()
: parameterDescriptionAttribute.Enum.Split(",").Select(x => x.Trim()).ToList();

definition =
PropertyDefinition.DefineEnum(enumValues, description);

break;
default:
throw new Exception($"Parameter type '{parameter.ParameterType}' not supported");
}

result.AddParameter(
Expand Down Expand Up @@ -122,15 +130,18 @@ public static List<FunctionDefinition> GetFunctionDefinitions(Type type)
throw new ArgumentNullException(nameof(functionCall));

if (functionCall.Name == null)
throw new Exception("Function name is null");
throw new InvalidFunctionCallException("Function Name is null");

if (obj == null)
throw new ArgumentNullException(nameof(obj));

var methodInfo = obj.GetType().GetMethod(functionCall.Name);

if (methodInfo == null)
throw new Exception($"Method '{functionCall.Name}' on type '{obj.GetType()}' not found");
throw new InvalidFunctionCallException($"Method '{functionCall.Name}' on type '{obj.GetType()}' not found");

if (!methodInfo.ReturnType.IsAssignableTo(typeof(T)))
throw new Exception(
throw new InvalidFunctionCallException(
$"Method '{functionCall.Name}' on type '{obj.GetType()}' has return type '{methodInfo.ReturnType}' but expected '{typeof(T)}'");

var parameters = methodInfo.GetParameters().ToList();
Expand Down Expand Up @@ -160,6 +171,18 @@ public static List<FunctionDefinition> GetFunctionDefinitions(Type type)
}
}

/// <summary>
/// Exception thrown when a function call is invalid
/// </summary>
public class InvalidFunctionCallException : Exception
{
/// <summary>
/// Creates a new instance of the <see cref="InvalidFunctionCallException"/> with the provided message
/// </summary>
public InvalidFunctionCallException(string message) : base(message)
{ }
}

/// <summary>
/// Attribute to mark a method as a function, and provide a description for the function. Can also be used to override the Name of the function
/// </summary>
Expand Down
6 changes: 6 additions & 0 deletions OpenAI.lutconfig
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
<LUTConfig Version="1.0">
<Repository />
<ParallelBuilds>true</ParallelBuilds>
<ParallelTestRuns>true</ParallelTestRuns>
<TestCaseTimeout>180000</TestCaseTimeout>
</LUTConfig>

0 comments on commit a5d9337

Please sign in to comment.