Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/support enumeration type hints #37

Merged
merged 20 commits into from
Feb 17, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
f3ee51f
Add support for typing.Literal #10
lorenzocelli Feb 8, 2024
aa8da6b
Add support for enum.Enum #10
lorenzocelli Feb 8, 2024
0347fa9
Merge branch 'main' into feature/support-enumeration-type-hints
lorenzocelli Feb 10, 2024
15ecfdf
Add LiteralParameterSchema and EnumParameterSchema
lorenzocelli Feb 10, 2024
b8f27c9
Merge branch 'main' into feature/support-enumeration-type-hints
lorenzocelli Feb 11, 2024
f61c784
Use ReferenceSchema for enum tests
lorenzocelli Feb 11, 2024
0ceab9d
Store ParameterSchema instances in FunctionSchema and convert enum va…
lorenzocelli Feb 11, 2024
2471860
Convert enum to names rather than values
lorenzocelli Feb 11, 2024
cb80329
Merge branch 'main' into feature/support-enumeration-type-hints
lorenzocelli Feb 11, 2024
cda7b51
Merge branch 'main' into feature/support-enumeration-type-hints
lorenzocelli Feb 12, 2024
3a9fd4a
Avoid populating the schema before calling to_json
lorenzocelli Feb 12, 2024
4f42b4c
Refactor ParameterSchema add-to logic to get logic
lorenzocelli Feb 15, 2024
6057e6f
Support positional arguments value parsing
lorenzocelli Feb 15, 2024
a420a7d
Support enum.Enum parameter with default value
lorenzocelli Feb 15, 2024
1354c64
Fix remove_param
lorenzocelli Feb 15, 2024
e72566f
Fix get_required_parameters type hint
lorenzocelli Feb 17, 2024
ad025df
Use dictionary comprehension and check value type
lorenzocelli Feb 17, 2024
4f5b663
Update README
lorenzocelli Feb 17, 2024
4c21510
Simplify dictionary comprehension
lorenzocelli Feb 17, 2024
6d659b4
Add comment in ParameterSchema._get_default
lorenzocelli Feb 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 110 additions & 11 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
from typing import Callable, List, Optional
from enum import Enum
from typing import List, Optional, Literal, Callable

from tool2schema import (
FindGPTEnabled,
Expand Down Expand Up @@ -149,8 +150,8 @@ class ReferenceSchema:
def __init__(self, f: Callable, reference_schema: Optional[dict] = None):
"""
Initialize the schema.
:param f: The function to create the schema for.
:param reference_schema: The schema to start with, defaults to DEFAULT_SCHEMA.
:param f: The function to create the schema for
:param reference_schema: The schema to start with, defaults to DEFAULT_SCHEMA
"""
self.schema = copy.deepcopy(reference_schema or DEFAULT_SCHEMA)
self.schema["function"]["name"] = f.__name__
Expand All @@ -166,7 +167,7 @@ def remove_param(self, param: str) -> None:
"""
Remove a parameter from the schema.

:param param: Name of the parameter to remove.
:param param: Name of the parameter to remove
"""
self.schema["function"]["parameters"]["properties"].pop(param)
self.schema["function"]["parameters"]["required"].pop(param, None)
Expand All @@ -175,17 +176,17 @@ def get_param(self, param: str) -> dict:
"""
Get a parameter dictionary from the schema.

:param param: Name of the parameter.
:return: The parameter dictionary.
:param param: Name of the parameter
:return: The parameter dictionary
"""
return self.schema["function"]["parameters"]["properties"][param]

def set_param(self, param, value: dict) -> None:
"""
Set a parameter dictionary.

:param param: Name of the parameter.
:param value: The new parameter dictionary.
:param param: Name of the parameter
:param value: The new parameter dictionary
"""
self.schema["function"]["parameters"]["properties"][param] = value

Expand Down Expand Up @@ -250,9 +251,9 @@ def test_function_tags_tune():
assert function_tags.tags == ["test"]


########################################
# Example function to test with enum #
########################################
#########################################################
# Example function to test with enum (using add_enum) #
#########################################################


@GPTEnabled
Expand Down Expand Up @@ -577,3 +578,101 @@ def test_function_docstring():

assert function_docstring.schema.to_json() == rf.schema
assert function_docstring.tags == []


######################################################
# Example functions with typing.Literal annotation #
######################################################


@GPTEnabled
def function_typing_literal_int(
a: Literal[1, 2, 3], b: str, c: bool = False, d: list[int] = [1, 2, 3]
):
"""
This is a test function.

:param a: This is a parameter
:param b: This is another parameter
:param c: This is a boolean parameter
:param d: This is a list parameter
"""
return a, b, c, d


def test_function_typing_literal_int():
# Check schema
rf = ReferenceSchema(function_typing_literal_int)
rf.get_param("a")["enum"] = [1, 2, 3]
assert function_typing_literal_int.schema.to_json() == rf.schema
assert function_typing_literal_int.tags == []


@GPTEnabled
def function_typing_literal_string(
a: Literal["a", "b", "c"], b: str, c: bool = False, d: list[int] = [1, 2, 3]
):
"""
This is a test function.

:param a: This is a parameter
:param b: This is another parameter
:param c: This is a boolean parameter
:param d: This is a list parameter
"""
return a, b, c, d


def test_function_typing_literal_string():
# Check schema
rf = ReferenceSchema(function_typing_literal_string)
rf.get_param("a")["enum"] = ["a", "b", "c"]
rf.get_param("a")["type"] = "string"
assert function_typing_literal_string.schema.to_json() == rf.schema
assert function_typing_literal_string.tags == []


#################################################
# Example functions with enum.Enum annotation #
#################################################


class CustomEnum(Enum):
A = 1
B = 2
C = 3


@GPTEnabled
def function_custom_enum(
a: CustomEnum, b: str, c: bool = False, d: list[int] = [1, 2, 3]
):
"""
This is a test function.

:param a: This is a parameter
:param b: This is another parameter
:param c: This is a boolean parameter
:param d: This is a list parameter
"""
return a, b, c, d


def test_function_custom_enum():
rf = ReferenceSchema(function_custom_enum)
rf.get_param("a")["type"] = "string"
rf.get_param("a")["enum"] = [x.name for x in CustomEnum]
assert function_custom_enum.schema.to_json() == rf.schema
assert function_custom_enum.tags == []

# Try invoking the function to verify that "A" is converted to CustomEnum.A
a, _, _, _ = function_custom_enum(a=CustomEnum.A.name, b="", c=False, d=[])
assert a == CustomEnum.A

# Verify it is possible to invoke the function with the Enum instance
a, _, _, _ = function_custom_enum(a=CustomEnum.A, b="", c=False, d=[])
assert a == CustomEnum.A

# Verify it is possible to invoke the function with positional args
a, _, _, _ = function_custom_enum(CustomEnum.A, "", False, [])
assert a == CustomEnum.A
77 changes: 76 additions & 1 deletion tool2schema/parameter_schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re
import typing
from inspect import Parameter
from enum import Enum
from inspect import Parameter, isclass

TYPE_MAP = {
"int": "integer",
Expand Down Expand Up @@ -93,6 +94,18 @@ def to_json(self) -> dict:
self._add_enum(json)
return json

def parse_value(self, value):
"""
Convert the given value from the JSON representation to an instance
that can be passed to the original method as a parameter. Overriding
methods should check whether the value needs to be converted, and return
it as is if no conversion is necessary.

:param value: The value to be converted
:return: An instance of the type required by the original method
"""
return value


class ValueTypeSchema(ParameterSchema):
"""
Expand Down Expand Up @@ -162,11 +175,73 @@ def _add_type(self, schema: dict):
schema["type"] = sub_type


class EnumParameterSchema(ParameterSchema):
"""
Parameter schema for Enum types.
"""

def __init__(self, parameter: Parameter, docstring: str = None):
super().__init__(parameter, docstring)
self.enum_names = [e.name for e in parameter.annotation]

@staticmethod
def matches(parameter: Parameter) -> bool:
return (
parameter.annotation != parameter.empty
and isclass(parameter.annotation)
and issubclass(parameter.annotation, Enum)
)

def _add_type(self, schema: dict):
schema["type"] = TYPE_MAP["str"]

def _add_enum(self, schema: dict):
schema["enum"] = self.enum_names

def parse_value(self, value):
"""
Convert an enum name to an instance of the enum type.

:param value: The enum name to be converted
"""
if value in self.enum_names:
# Convert to an enum instance
return self.parameter.annotation[value]

# The user is invoking the method directly
return value


class LiteralParameterSchema(ParameterSchema):
"""
Parameter schema for typing.Literal types.
"""

def __init__(self, parameter: Parameter, docstring: str = None):
super().__init__(parameter, docstring)
self.enum_values = list(typing.get_args(parameter.annotation))

@staticmethod
def matches(parameter: Parameter) -> bool:
return (
parameter.annotation != parameter.empty
and typing.get_origin(parameter.annotation) is typing.Literal
)

def _add_type(self, schema: dict):
schema["type"] = TYPE_MAP.get(type(self.enum_values[0]).__name__, "object")

def _add_enum(self, schema: dict):
schema["enum"] = self.enum_values


# Order matters: specific classes should appear before more generic ones;
# for example, ListParameterSchema must precede ValueTypeSchema,
# as they both match list types
PARAMETER_SCHEMAS = [
OptionalParameterSchema,
LiteralParameterSchema,
EnumParameterSchema,
ListParameterSchema,
ValueTypeSchema,
]
Loading
Loading