Skip to content

Commit

Permalink
Merge pull request #59 from cadifyai/feature/schema-to-function
Browse files Browse the repository at this point in the history
Convert from schema to function
  • Loading branch information
siliconlad authored Mar 25, 2024
2 parents d670f7f + c512298 commit f810159
Show file tree
Hide file tree
Showing 6 changed files with 376 additions and 37 deletions.
58 changes: 58 additions & 0 deletions tests/functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import typing
from enum import Enum

from tool2schema import GPTEnabled


Expand Down Expand Up @@ -43,3 +46,58 @@ def function_not_enabled(a: int, b: str) -> None:
:param b: This is another parameter
"""
pass


@GPTEnabled
def function_literal(a: typing.Literal[0, 1, 2], b: typing.Literal["a", "b", "c"] = "a"):
"""
This is a test function.
:param a: This is a parameter
:param b: This is another parameter
"""
return a, b


@GPTEnabled
def function_add_enum(a: str):
"""
This is a test function.
:param a: This is a parameter
"""
return a


function_add_enum.schema.add_enum("a", ["YES", "NO", "MAYBE"])


class CustomEnum(Enum):
"""
Custom enum for testing purposes.
"""

X = 0
Y = 1
Z = 2


@GPTEnabled
def function_enum(a: CustomEnum):
"""
This is a test function.
:param a: This is a parameter
"""
return a


@GPTEnabled
def function_union(a: typing.Optional[bool], b: typing.Union[str, int] = 4):
"""
This is a test function.
:param a: This is a parameter
:param b: This is another parameter
"""
return a, b
141 changes: 141 additions & 0 deletions tests/test_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import json
from typing import Callable, Union

import pytest

from tests import functions
from tool2schema import LoadGPTEnabled
from tool2schema.schema import ParseException

###############################################
# Helper method to get function dictionaries #
###############################################


def get_function_dict(func: Callable, arguments: Union[str, dict]):
"""
Get a dictionary representation of a function, with its name and the
specified argument values converted to a JSON string.
"""
if isinstance(arguments, dict):
arguments = json.dumps(arguments)

return {
"name": func.__name__,
"arguments": arguments,
}


######################################################
# Test function parsing with the 'functions' module #
######################################################


@pytest.mark.parametrize(
"function, arguments",
[
(functions.function, {"a": 1, "b": "test", "c": True, "d": [4, 5]}),
(functions.function, {"a": 1, "b": "test", "c": True, "d": []}),
(functions.function, {"a": 1, "b": "test"}), # Omit args with default values
(functions.function, {"a": 1, "b": "test", "d": [4, 5]}), # Omit c only
(functions.function_no_params, {}),
(functions.function_literal, {"a": 1}), # Omit b
(functions.function_literal, {"a": 1, "b": "b"}),
(functions.function_enum, {"a": functions.CustomEnum.Y.name}),
(functions.function_add_enum, {"a": "MAYBE"}),
(functions.function_union, {"a": True, "b": "x"}),
(functions.function_union, {"a": None, "b": 1}),
],
)
def test_load_function(function, arguments):

# Add hallucinated argument
hall_args = {**arguments, "hallucinated": 4}

f_dict = get_function_dict(function, hall_args)

f, args = LoadGPTEnabled(functions, f_dict)

assert f == function
assert args == arguments # Verify the hallucinated argument has been removed

f(**args) # Verify invoking the function does not throw an exception

# Verify we can pass the arguments as a dictionary
f, args = LoadGPTEnabled(functions, {"name": function.__name__, "arguments": arguments})

assert f == function
assert args == arguments

# Verify an exception is raised if hallucinations are not ignored
with pytest.raises(ParseException):
LoadGPTEnabled(
functions,
f_dict,
validate=True,
ignore_hallucinations=False,
)


def test_load_missing_function():
with pytest.raises(ParseException):
LoadGPTEnabled(functions, get_function_dict(functions.function_not_enabled, "{}"))


def test_load_invalid_arguments_type():
with pytest.raises(ParseException):
LoadGPTEnabled(
functions,
{
"name": functions.function.__name__,
"arguments": 23,
},
)


@pytest.mark.parametrize("arguments", ["", "{", "[]", "23", "null"])
def test_load_invalid_json_arguments(arguments):
with pytest.raises(ParseException):
LoadGPTEnabled(functions, get_function_dict(functions.function, arguments))


def test_load_missing_name():
with pytest.raises(ParseException):
LoadGPTEnabled(functions, {"arguments": "{}"})


def test_load_missing_arguments():
with pytest.raises(ParseException):
LoadGPTEnabled(functions, {"name": "function"})


@pytest.mark.parametrize(
"function, arguments",
[
(functions.function, {"a": 1}), # Missing required argument b
(functions.function, {"a": "x", "b": "test"}), # Invalid value for a
(functions.function, {"a": 1, "b": 0}), # Invalid value for b
(functions.function, {"a": 1, "b": "", "c": "x"}), # Invalid value for c
(functions.function, {"a": 1, "b": "", "c": True, "d": False}), # Invalid value for d
(functions.function, {"a": 1, "b": "", "c": True, "d": [1, "a"]}), # Invalid array value
(functions.function, {"a": 1, "e": 4}), # Missing argument and hallucinated arg
(functions.function_literal, {"a": 3}), # Invalid value for a
(functions.function_literal, {"a": 0, "b": "d"}), # Invalid value for b
(functions.function_enum, {"a": 1}),
(functions.function_enum, {"a": "A"}),
(functions.function_add_enum, {"a": "PERHAPS"}),
(functions.function_add_enum, {"a": "PERHAPS", "b": "MAYBE"}),
(functions.function_union, {"a": None, "b": False}), # Invalid value for b
(functions.function_union, {"a": "x", "b": 1}), # Invalid value for a
],
)
def test_load_invalid_argument_values(function, arguments):
f_obj = get_function_dict(function, arguments)

with pytest.raises(ParseException):
LoadGPTEnabled(functions, f_obj)

# Verify the function and the arguments are returned if validation is disabled
f, args = LoadGPTEnabled(functions, f_obj, validate=False)
assert f == function
assert args == arguments
70 changes: 34 additions & 36 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,20 @@


def test_FindGPTEnabled():
gpt_functions = FindGPTEnabled(functions)
# Check that the function is found
assert len(FindGPTEnabled(functions)) == 3
assert functions.function in FindGPTEnabled(functions)
assert functions.function_tags in FindGPTEnabled(functions)
assert functions.function_no_params in FindGPTEnabled(functions)
assert len(gpt_functions) == 7
assert functions.function in gpt_functions
assert functions.function_tags in gpt_functions
assert functions.function_no_params in gpt_functions
assert functions.function_no_params in gpt_functions
assert functions.function_no_params in gpt_functions
assert functions.function_literal in gpt_functions
assert functions.function_add_enum in gpt_functions
assert functions.function_enum in gpt_functions
assert functions.function_union in gpt_functions
# Check that the function is not found
assert functions.function_not_enabled not in FindGPTEnabled(functions)
assert functions.function_not_enabled not in gpt_functions


################################
Expand All @@ -37,39 +44,30 @@ def test_FindGPTEnabled():


def test_FindGPTEnabledSchemas():
gpt_schemas = FindGPTEnabledSchemas(functions)
# Check that the function is found
assert len(FindGPTEnabledSchemas(functions)) == 3
assert functions.function.schema.to_json() in FindGPTEnabledSchemas(functions)
assert functions.function_tags.schema.to_json() in FindGPTEnabledSchemas(functions)
assert functions.function_no_params.schema.to_json() in FindGPTEnabledSchemas(functions)


def test_FindGPTEnabledSchemas_API():
assert len(gpt_schemas) == 7
assert functions.function.schema.to_json() in gpt_schemas
assert functions.function_tags.schema.to_json() in gpt_schemas
assert functions.function_no_params.schema.to_json() in gpt_schemas
assert functions.function_literal.schema.to_json() in gpt_schemas
assert functions.function_add_enum.schema.to_json() in gpt_schemas
assert functions.function_enum.schema.to_json() in gpt_schemas
assert functions.function_union.schema.to_json() in gpt_schemas


@pytest.mark.parametrize("schema_type", [SchemaType.API, SchemaType.TUNE])
def test_FindGPTEnabledSchemas_with_type(schema_type):
# Check that the function is found
assert len(FindGPTEnabledSchemas(functions, schema_type=SchemaType.API)) == 3
assert functions.function.schema.to_json(SchemaType.API) in FindGPTEnabledSchemas(
functions, schema_type=SchemaType.API
)
assert functions.function_tags.schema.to_json(SchemaType.API) in FindGPTEnabledSchemas(
functions, schema_type=SchemaType.API
)
assert functions.function_no_params.schema.to_json(SchemaType.API) in FindGPTEnabledSchemas(
functions, schema_type=SchemaType.API
)


def test_FindGPTEnabledSchemas_TUNE():
# Check that the function is found
assert len(FindGPTEnabledSchemas(functions, schema_type=SchemaType.TUNE)) == 3
assert functions.function.schema.to_json(SchemaType.TUNE) in FindGPTEnabledSchemas(
functions, schema_type=SchemaType.TUNE
)
assert functions.function_tags.schema.to_json(SchemaType.TUNE) in FindGPTEnabledSchemas(
functions, schema_type=SchemaType.TUNE
)
assert functions.function_no_params.schema.to_json(SchemaType.TUNE) in FindGPTEnabledSchemas(
functions, schema_type=SchemaType.TUNE
)
gpt_schemas = FindGPTEnabledSchemas(functions, schema_type=schema_type)
assert len(gpt_schemas) == 7
assert functions.function.schema.to_json(schema_type) in gpt_schemas
assert functions.function_tags.schema.to_json(schema_type) in gpt_schemas
assert functions.function_no_params.schema.to_json(schema_type) in gpt_schemas
assert functions.function_literal.schema.to_json(schema_type) in gpt_schemas
assert functions.function_add_enum.schema.to_json(schema_type) in gpt_schemas
assert functions.function_enum.schema.to_json(schema_type) in gpt_schemas
assert functions.function_union.schema.to_json(schema_type) in gpt_schemas


###############################
Expand Down
1 change: 1 addition & 0 deletions tool2schema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
FindGPTEnabledByTag,
FindGPTEnabledSchemas,
GPTEnabled,
LoadGPTEnabled,
SaveGPTEnabled,
SchemaType,
)
Expand Down
Loading

0 comments on commit f810159

Please sign in to comment.