Skip to content

Commit

Permalink
Merge pull request #79 from cadifyai/67-parameter-values-of-type-int-…
Browse files Browse the repository at this point in the history
…should-also-work-with-type-float

Allow implicit conversion from int to float
  • Loading branch information
siliconlad authored Apr 30, 2024
2 parents a3898ce + 86d4627 commit c68d06d
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 3 deletions.
10 changes: 10 additions & 0 deletions tests/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@ def function_no_params():
return None


@EnableTool
def function_float(a: float):
"""
This is a test function.
:param a: This is a parameter
"""
return a


@EnableTool
def function(a: int, b: str, c: bool = False, d: list[int] = [1, 2, 3]):
"""
Expand Down
9 changes: 9 additions & 0 deletions tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,12 @@ def test_load_invalid_argument_values(function, arguments):
f, args = LoadToolEnabled(functions, f_obj, validate=False)
assert f == function
assert args == arguments

################################################
# Test implicit conversion from int to float #
################################################

def test_function_float():
f, args = LoadToolEnabled(functions, get_function_dict(functions.function_float, {"a": 1}))
assert f == functions.function_float
assert args == {"a": 1}
9 changes: 6 additions & 3 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
def test_FindToolEnabled():
tools = FindToolEnabled(functions)
# Check that the function is found
assert len(tools) == 7
assert len(tools) == 8
assert functions.function in tools
assert functions.function_float in tools
assert functions.function_tags in tools
assert functions.function_no_params in tools
assert functions.function_no_params in tools
Expand All @@ -43,8 +44,9 @@ def test_FindToolEnabled():
def test_FindToolEnabledSchemas():
tool_schemas = FindToolEnabledSchemas(functions)
# Check that the function is found
assert len(tool_schemas) == 7
assert len(tool_schemas) == 8
assert functions.function.to_json() in tool_schemas
assert functions.function_float.to_json() in tool_schemas
assert functions.function_tags.to_json() in tool_schemas
assert functions.function_no_params.to_json() in tool_schemas
assert functions.function_literal.to_json() in tool_schemas
Expand All @@ -57,8 +59,9 @@ def test_FindToolEnabledSchemas():
def test_FindToolEnabledSchemas_with_type(schema_type):
# Check that the function is found
tool_schemas = FindToolEnabledSchemas(functions, schema_type=schema_type)
assert len(tool_schemas) == 7
assert len(tool_schemas) == 8
assert functions.function.to_json(schema_type) in tool_schemas
assert functions.function_float.to_json(schema_type) in tool_schemas
assert functions.function_tags.to_json(schema_type) in tool_schemas
assert functions.function_no_params.to_json(schema_type) in tool_schemas
assert functions.function_literal.to_json(schema_type) in tool_schemas
Expand Down
3 changes: 3 additions & 0 deletions tool2schema/type_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ def matches(p_type) -> bool:
return True

def validate(self, value) -> bool:
# Allow implicit conversion from int to float
if self.type is float and type(value) is int:
return True
return self.type == type(value)

def _get_type(self) -> dict:
Expand Down

0 comments on commit c68d06d

Please sign in to comment.