Skip to content

Commit

Permalink
Merge pull request #80 from cadifyai/72-fix-pyright-errors
Browse files Browse the repository at this point in the history
Fix pyright errors
  • Loading branch information
siliconlad authored May 1, 2024
2 parents 1d551b4 + 587b789 commit 648f993
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 51 deletions.
20 changes: 1 addition & 19 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
with:
python-version: "3.11"
- name: Install pyright
run: python -m pip install pyright
run: python3 -m pip install pyright
- name: Run pyright
run: pyright
coverage:
Expand All @@ -44,39 +44,24 @@ jobs:
python-version: [ "3.9", "3.10", "3.11", "3.12" ]
runs-on: "ubuntu-latest"
steps:
#----------------------------------------------
# check-out repo and set-up python
#----------------------------------------------
- name: Check out repository
uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
id: setup-python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
#----------------------------------------------
# ----- install & configure poetry -----
#----------------------------------------------
- name: Install Poetry
uses: snok/install-poetry@v1
with:
virtualenvs-create: true
virtualenvs-in-project: true
#----------------------------------------------
# --------- install dependencies ---------
#----------------------------------------------
- name: Install dependencies
run: poetry install --no-interaction --no-root
#----------------------------------------------
# add matrix specifics and run test suite
#----------------------------------------------
- name: Run tests
run: |
poetry run coverage run -m pytest tests/
poetry run coverage xml
#----------------------------------------------
# create coverage summary
#----------------------------------------------
- name: Code Coverage Summary Report
uses: irongut/CodeCoverageSummary@v1.3.0
with:
Expand All @@ -86,9 +71,6 @@ jobs:
badge: true
fail_below_min: true
thresholds: '90 95'
#----------------------------------------------
# write job summary
#----------------------------------------------
- name: Write job summary
run: |
cat code-coverage-results.md >> $GITHUB_STEP_SUMMARY
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@ reportMissingImports = true
reportMissingModuleSource = false
pythonVersion = "3.11"
pythonPlatform = "Windows"
executionEnvironments = [
{ root = "tool2schema" }
]

[tool.poetry]
name = "tool2schema"
Expand Down
40 changes: 30 additions & 10 deletions tool2schema/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,55 +28,75 @@ def schema_type(self) -> SchemaType:
"""
Type of the schema to create.
"""
return self._get_setting(Config.schema_type.fget.__name__, SchemaType.OPENAI_API)
default_value = SchemaType.OPENAI_API
if (fget := Config.schema_type.fget) is not None:
return self._get_setting(fget.__name__, default_value)
return default_value

@schema_type.setter
def schema_type(self, value: SchemaType):
self._set_setting(Config.schema_type.fget.__name__, value)
if (fget := Config.schema_type.fget) is not None:
self._set_setting(fget.__name__, value)

@property
def ignore_parameters(self) -> list[str]:
"""
List of parameter names to ignore when creating a schema.
"""
return self._get_setting(Config.ignore_parameters.fget.__name__, ["self", "args", "kwargs"])
default_value = ["self", "args", "kwargs"]
if (fget := Config.ignore_parameters.fget) is not None:
return self._get_setting(fget.__name__, default_value)
return default_value

@ignore_parameters.setter
def ignore_parameters(self, value: list[str]):
self._set_setting(Config.ignore_parameters.fget.__name__, value)
if (fget := Config.ignore_parameters.fget) is not None:
self._set_setting(fget.__name__, value)

@property
def ignore_function_description(self) -> bool:
"""
When true, omit the function description from the schema.
"""
return self._get_setting(Config.ignore_function_description.fget.__name__, False)
default_value = False
if (fget := Config.ignore_function_description.fget) is not None:
return self._get_setting(fget.__name__, default_value)
return default_value

@ignore_function_description.setter
def ignore_function_description(self, value: bool):
self._set_setting(Config.ignore_function_description.fget.__name__, value)
if (fget := Config.ignore_function_description.fget) is not None:
self._set_setting(fget.__name__, value)

@property
def ignore_parameter_descriptions(self) -> bool:
"""
When true, omit the parameter descriptions from the schema.
"""
return self._get_setting(Config.ignore_parameter_descriptions.fget.__name__, False)
default_value = False
if (fget := Config.ignore_parameter_descriptions.fget) is not None:
return self._get_setting(fget.__name__, default_value)
return default_value

@ignore_parameter_descriptions.setter
def ignore_parameter_descriptions(self, value: bool):
self._set_setting(Config.ignore_parameter_descriptions.fget.__name__, value)
if (fget := Config.ignore_parameter_descriptions.fget) is not None:
self._set_setting(fget.__name__, value)

@property
def ignore_all_parameters(self) -> bool:
"""
When true, omit all parameters from the schema.
"""
return self._get_setting(Config.ignore_all_parameters.fget.__name__, False)
default_value = False
if (fget := Config.ignore_all_parameters.fget) is not None:
return self._get_setting(fget.__name__, default_value)
return default_value

@ignore_all_parameters.setter
def ignore_all_parameters(self, value: bool):
self._set_setting(Config.ignore_all_parameters.fget.__name__, value)
if (fget := Config.ignore_all_parameters.fget) is not None:
self._set_setting(fget.__name__, value)

def reset_default(self):
"""
Expand Down
5 changes: 4 additions & 1 deletion tool2schema/parameter_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ def create(
if type_schema := TypeSchema.create(parameter.annotation):
return ParameterSchema(type_schema, parameter, index, config, docstring)

def _get_description(self) -> Union[str, Parameter.empty]:
def _test(self) -> type[Parameter.empty]:
return Parameter.empty

def _get_description(self) -> Union[str, type[Parameter.empty]]:
"""
Get the description of this parameter, extracted from the function docstring,
to be added to the JSON schema. Return `Parameter.empty` to omit the description
Expand Down
16 changes: 9 additions & 7 deletions tool2schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sys
from inspect import Parameter
from types import ModuleType
from typing import Any, Callable, Generic, Optional, TypeVar, overload
from typing import Any, Callable, Generic, Literal, Optional, TypeVar, Union, overload

import tool2schema
from tool2schema.config import Config, SchemaType
Expand Down Expand Up @@ -48,6 +48,7 @@ def FindToolEnabledByName(module: ModuleType, name: str) -> Optional[ToolEnabled
:param module: Module to search for ToolEnabled functions
:param name: Name of the function to find
"""
func: ToolEnabled
for func in FindToolEnabled(module):
if func.__name__ == name:
return func
Expand Down Expand Up @@ -224,24 +225,25 @@ def __init__(self, func: Callable[P, T], **kwargs) -> None:
self.tags = kwargs.pop("tags", [])
self.config = Config(tool2schema.CONFIG, **kwargs)
self.schema = FunctionSchema(func, self.config)
self.__name__ = func.__name__
functools.update_wrapper(self, func)

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:

args = list(args) # Tuple is immutable, thus convert to list
args_list = list(args) # Tuple is immutable, thus convert to list

for i, arg in enumerate(args):
for i, arg in enumerate(args_list):
for p in self.schema.parameter_schemas.values():
if p.index == i:
# Convert the JSON value to the type expected by the method
args[i] = p.type_schema.decode(arg)
args_list[i] = p.type_schema.decode(arg)

for key in kwargs:
if key in self.schema.parameter_schemas:
# Convert the JSON value to the type expected by the method
kwargs[key] = self.schema.parameter_schemas[key].type_schema.decode(kwargs[key])

return self.func(*args, **kwargs)
return self.func(*args_list, **kwargs) # type: ignore

def tool_enabled(self) -> bool:
return True
Expand All @@ -264,10 +266,10 @@ def EnableTool(func: Callable[P, T], **kwargs) -> ToolEnabled[P, T]: ...


@overload
def EnableTool(**kwargs) -> Callable[[Callable[P, T]], ToolEnabled[P, T]]: ...
def EnableTool(func: Literal[None] = None, **kwargs) -> Callable[[Callable[P, T]], ToolEnabled[P, T]]: ...


def EnableTool(func=None, **kwargs):
def EnableTool(func: Optional[Callable[P, T]] = None, **kwargs) -> Union[ToolEnabled[P, T], Callable[[Callable[P, T]], ToolEnabled[P, T]]]:
"""Decorator to generate a function schema for OpenAI."""
if func is not None:
return ToolEnabled(func, **kwargs)
Expand Down
26 changes: 15 additions & 11 deletions tool2schema/type_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,14 @@ def _get_type(self) -> dict:
"""
return {}

def _get_items(self) -> Union[dict, Parameter.empty]:
def _get_items(self) -> Union[dict, type[Parameter.empty]]:
"""
Get the items property to be added to the JSON schema.
Return `Parameter.empty` to omit the items from the schema.
"""
return Parameter.empty

def _get_enum(self) -> Union[list, Parameter.empty]:
def _get_enum(self) -> Union[list, type[Parameter.empty]]:
"""
Get the enum property to be added to the JSON schema.
Return `Parameter.empty` to omit the enum from the schema.
Expand Down Expand Up @@ -145,7 +145,9 @@ def validate(self, value) -> bool:
return self.type == type(value)

def _get_type(self) -> dict:
return {"type": self.TYPE_MAP.get(self.type.__name__, "object")}
if self.type is not None:
return {"type": self.TYPE_MAP.get(self.type.__name__, "object")}
return {"type": "null"}


class GenericTypeSchema(TypeSchema):
Expand All @@ -157,7 +159,7 @@ def _get_sub_types(self) -> list[TypeSchema]:
"""
:return: A list of type schemas corresponding to the generic type arguments.
"""
return [TypeSchema.create(arg) for arg in typing.get_args(self.type)]
return [t for arg in typing.get_args(self.type) if (t := TypeSchema.create(arg)) is not None]

def _get_sub_type(self) -> Optional[TypeSchema]:
"""
Expand All @@ -182,7 +184,7 @@ def matches(p_type: Type) -> bool:
def _get_type(self) -> dict:
return {"type": "array"}

def _get_items(self) -> Union[dict, Parameter.empty]:
def _get_items(self) -> Union[dict, type[Parameter.empty]]:
if sub_type := self._get_sub_type():
return sub_type.to_json()

Expand Down Expand Up @@ -260,9 +262,11 @@ def __init__(self, enum_values, type: Optional[Type] = None):
self.enum_values = enum_values

def _get_type(self) -> dict:
return TypeSchema.create(type(self.enum_values[0]))._get_type()
if (t := TypeSchema.create(type(self.enum_values[0]))) is not None:
return t._get_type()
return {"type": "object"}

def _get_enum(self) -> Union[list, Parameter.empty]:
def _get_enum(self) -> Union[list, type[Parameter.empty]]:
return self.enum_values

def validate(self, value) -> bool:
Expand All @@ -275,12 +279,12 @@ class EnumClassTypeSchema(EnumTypeSchema):
Type schema for enum.Enum types.
"""

def __init__(self, p_type: Enum):
def __init__(self, p_type: Type[Enum]):
super().__init__([e.name for e in p_type], p_type)

@staticmethod
def matches(type_p: Type) -> bool:
return type_p != Parameter.empty and isclass(type_p) and issubclass(type_p, Enum)
def matches(p_type: Type) -> bool:
return p_type != Parameter.empty and isclass(p_type) and issubclass(p_type, Enum)

def encode(self, value):
"""
Expand All @@ -296,7 +300,7 @@ def decode(self, value):
:param value: The enum name to be converted
"""
if value in self.enum_values:
if value in self.enum_values and self.type is not None:
# Convert to an enum instance
return self.type[value]

Expand Down

0 comments on commit 648f993

Please sign in to comment.