Skip to content

Commit

Permalink
Add schema type parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
siliconlad committed Apr 30, 2024
1 parent 79ab938 commit 05bd7ac
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 22 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ There are a number of setting available for you to tweak.
| `ignore_all_parameters` | A boolean value indicating whether to exclude all parameters from the schema. When set to true, `ignore_parameters` and `ignore_parameter_descriptions` will be ignored. | `False` |
| `ignore_function_description` | A boolean value indicating whether to exclude the function description from the schema. | `False` |
| `ignore_parameter_descriptions` | A boolean value indicating whether to exclude all parameter descriptions from the schema | `False` |
| `schema_type` | Default schema type to use. | `SchemaType.OPENAI_API` |

### Decorator Configuration

Expand Down
9 changes: 9 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,6 +1065,15 @@ def test_global_configuration_ignore_all_parameters_tune(global_config):
assert function.tags == []


def test_global_configuration_schema_type(global_config):
# Change the global configuration
tool2schema.CONFIG.schema_type = SchemaType.OPENAI_TUNE

rf = ReferenceSchema(function)
assert function.to_json() == rf.tune_schema
assert function.tags == []


########################################
# Test global configuration override #
########################################
Expand Down
3 changes: 1 addition & 2 deletions tool2schema/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# flake8: noqa
__version__ = "v1.3.1"

from .config import Config
from .config import Config, SchemaType
from .schema import (
EnableTool,
FindToolEnabled,
Expand All @@ -12,7 +12,6 @@
FindToolEnabledSchemas,
LoadToolEnabled,
SaveToolEnabled,
SchemaType,
)

# Default global configuration
Expand Down
20 changes: 20 additions & 0 deletions tool2schema/config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
from __future__ import annotations

import copy
from enum import Enum
from typing import Optional


class SchemaType(Enum):
"""Enum for schema types."""

OPENAI_API = 0
OPENAI_TUNE = 1
ANTHROPIC_CLAUDE = 2


class Config:
"""
Configuration class for tool2schema.
Expand All @@ -14,6 +23,17 @@ def __init__(self, parent: Optional[Config] = None, **settings):
self._settings = settings
self._initial_settings = copy.deepcopy(settings)

@property
def schema_type(self) -> SchemaType:
"""
Type of the schema to create.
"""
return self._get_setting(Config.schema_type.fget.__name__, SchemaType.OPENAI_API)

@schema_type.setter
def schema_type(self, value: SchemaType):
self._set_setting(Config.schema_type.fget.__name__, value)

@property
def ignore_parameters(self) -> list[str]:
"""
Expand Down
42 changes: 22 additions & 20 deletions tool2schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@
import json
import re
import sys
from enum import Enum
from inspect import Parameter
from types import ModuleType
from typing import Any, Callable, Generic, Optional, TypeVar, overload

import tool2schema
from tool2schema.config import Config
from tool2schema.config import Config, SchemaType
from tool2schema.parameter_schema import ParameterSchema

if sys.version_info < (3, 10):
Expand All @@ -21,14 +20,6 @@
from typing import ParamSpec


class SchemaType(Enum):
"""Enum for schema types."""

OPENAI_API = 0
OPENAI_TUNE = 1
ANTHROPIC_CLAUDE = 2


def FindToolEnabled(module: ModuleType) -> list[ToolEnabled]:
"""
Find all functions with the EnableTool decorator.
Expand All @@ -39,13 +30,13 @@ def FindToolEnabled(module: ModuleType) -> list[ToolEnabled]:


def FindToolEnabledSchemas(
module: ModuleType, schema_type: SchemaType = SchemaType.OPENAI_API
module: ModuleType, schema_type: Optional[SchemaType] = None
) -> list[dict]:
"""
Find all function schemas with the EnableTool decorator.
:param module: Module to search for ToolEnabled functions
:param schema_type: Type of schema to return
:param schema_type: Type of schema to return (None indicates default)
"""
return [x.to_json(schema_type) for x in FindToolEnabled(module)]

Expand All @@ -63,13 +54,15 @@ def FindToolEnabledByName(module: ModuleType, name: str) -> Optional[ToolEnabled
return None


def FindToolEnabledByNameSchema(module: ModuleType, name: str, schema_type: SchemaType = SchemaType.OPENAI_API) -> Optional[dict]:
def FindToolEnabledByNameSchema(
module: ModuleType, name: str, schema_type: Optional[SchemaType] = None
) -> Optional[dict]:
"""
Find a function schema with the EnableTool decorator by name.
:param module: Module to search for ToolEnabled functions
:param name: Name of the function to find
:param schema_type: Type of schema to return
:param schema_type: Type of schema to return (None indicates default)
"""
if (func := FindToolEnabledByName(module, name)) is None:
return None
Expand All @@ -86,24 +79,26 @@ def FindToolEnabledByTag(module: ModuleType, tag: str) -> list[ToolEnabled]:
return [x for x in FindToolEnabled(module) if x.has(tag)]


def FindToolEnabledByTagSchemas(module: ModuleType, tag: str, schema_type: SchemaType = SchemaType.OPENAI_API) -> list[dict]:
def FindToolEnabledByTagSchemas(
module: ModuleType, tag: str, schema_type: Optional[SchemaType] = None
) -> list[dict]:
"""
Find all function schemas with the EnableTool decorator by tag.
:param module: Module to search for ToolEnabled functions
:param tag: Tag to search for
:param schema_type: Type of schema to return
:param schema_type: Type of schema to return (None indicates default)
"""
return [x.to_json(schema_type) for x in FindToolEnabledByTag(module, tag)]


def SaveToolEnabled(module: ModuleType, path: str, schema_type: SchemaType = SchemaType.OPENAI_API):
def SaveToolEnabled(module: ModuleType, path: str, schema_type: Optional[SchemaType] = None):
"""
Save all function schemas with the EnableTool decorator to a file.
:param module: Module to search for ToolEnabled functions
:param path: Path to save the schemas to
:param schema_type: Type of schema to return
:param schema_type: Type of schema to return (None indicates default)
"""
schemas = FindToolEnabledSchemas(module, schema_type)
json.dump(schemas, open(path, "w"))
Expand Down Expand Up @@ -251,7 +246,13 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
def tool_enabled(self) -> bool:
return True

def to_json(self, schema_type: SchemaType = SchemaType.OPENAI_API) -> dict:
def to_json(self, schema_type: Optional[SchemaType] = None) -> dict:
"""
Return JSON schema for the function.
:param schema_type: None indicates default schema type
:return: JSON schema
"""
return self.schema.to_json(schema_type)

def has(self, tag: str) -> bool:
Expand Down Expand Up @@ -292,11 +293,12 @@ def __init__(self, f: Callable, config: Config):
self.config = config
self._all_parameter_schemas: dict[str, ParameterSchema] = self._get_all_parameter_schemas()

def to_json(self, schema_type: SchemaType = SchemaType.OPENAI_API) -> dict:
def to_json(self, schema_type: Optional[SchemaType] = None) -> dict:
"""
Convert schema to JSON.
:param schema_type: Type of schema to return
"""
schema_type = schema_type or self.config.schema_type
if schema_type == SchemaType.OPENAI_TUNE:
return self._get_function_schema(schema_type)
elif schema_type == SchemaType.ANTHROPIC_CLAUDE:
Expand Down

0 comments on commit 05bd7ac

Please sign in to comment.