From 8e02f80e8bf1b384db3e2220c76fbbbabfd44eab Mon Sep 17 00:00:00 2001 From: Michael Jurasovic Date: Sun, 1 Dec 2024 22:37:27 +1100 Subject: [PATCH] Draft handling of complex inputs --- examples/complex_inputs.py | 77 ++++++++++++ src/fastmcp/tools/base.py | 27 ++-- src/fastmcp/utilities/func_to_pyd.py | 180 +++++++++++++++++++++++++++ tests/test_tool_manager.py | 67 +++++++++- 4 files changed, 337 insertions(+), 14 deletions(-) create mode 100644 examples/complex_inputs.py create mode 100644 src/fastmcp/utilities/func_to_pyd.py diff --git a/examples/complex_inputs.py b/examples/complex_inputs.py new file mode 100644 index 0000000..e6c1d4b --- /dev/null +++ b/examples/complex_inputs.py @@ -0,0 +1,77 @@ +""" +FastMCP Complex inputs Example + +This can be used to verify that the JSON schema and data +parsing work with complex inputs like lists, nulls, and sub-models. + +Suggested usage: +- Set up claude desktop with it +- Ask it to print out the schema for you +- Ask it to try out the endpoint without supplying any of the optional args +- Confirm it's okay +- Ask it to try out the endpoint with the optional args explicitly set +- Confirm it's okay +""" + +from pydantic import BaseModel, Field +from typing import Annotated +from fastmcp.server import FastMCP, Context + +mcp = FastMCP("Demo") + + +class TestInputModelA(BaseModel): + pass + + +class TestInputModelB(BaseModel): + class InnerModel(BaseModel): + x: int + + how_many_shrimp: Annotated[int, Field(description="How many shrimp in the tank???")] + ok: InnerModel + + +@mcp.tool() +def complex_inputs( + ctx: Context, + an_int: int, + must_be_none: None, + list_of_ints: list[int], + # list[str] | str is an interesting case because if it comes in as JSON like + # "[\"a\", \"b\"]" then it will be naively parsed as a string. + list_str_or_str: list[str] | str, + an_int_annotated_with_field: Annotated[ + int, Field(description="An int with a field") + ], + # TODO: handle this case too + # field_with_default_via_field_annotation_before_nondefault_arg: Annotated[ + # int, Field(1) + # ], + my_model_a: TestInputModelA, + my_model_b: TestInputModelB, + an_int_annotated_with_field_default: Annotated[ + int, + Field(1, description="An int with a field"), + ], + my_model_a_with_default: TestInputModelA = TestInputModelA(), # noqa: B008 + an_int_with_default: int = 1, + an_int_with_equals_field: int = Field(1, ge=0), + int_annotated_with_default: Annotated[int, Field(description="hey")] = 5, +) -> str: + _ = ( + ctx, + an_int, + must_be_none, + list_of_ints, + list_str_or_str, + an_int_annotated_with_field, + an_int_annotated_with_field_default, + my_model_a, + my_model_b, + my_model_a_with_default, + an_int_with_default, + an_int_with_equals_field, + int_annotated_with_default, + ) + return "ok!" diff --git a/src/fastmcp/tools/base.py b/src/fastmcp/tools/base.py index 21c188c..c4d18ab 100644 --- a/src/fastmcp/tools/base.py +++ b/src/fastmcp/tools/base.py @@ -1,8 +1,8 @@ import fastmcp from fastmcp.exceptions import ToolError - -from pydantic import BaseModel, Field, TypeAdapter, validate_call +from fastmcp.utilities.func_to_pyd import func_to_pyd, ArgModelBase +from pydantic import BaseModel, Field import inspect @@ -19,6 +19,9 @@ class Tool(BaseModel): name: str = Field(description="Name of the tool") description: str = Field(description="Description of what the tool does") parameters: dict = Field(description="JSON schema for tool parameters") + arg_model: type[ArgModelBase] = Field( + description="Pydantic model for tool arguments" + ) is_async: bool = Field(description="Whether the tool is async") context_kwarg: Optional[str] = Field( None, description="Name of the kwarg that should receive context" @@ -41,9 +44,6 @@ def from_function( func_doc = description or fn.__doc__ or "" is_async = inspect.iscoroutinefunction(fn) - # Get schema from TypeAdapter - will fail if function isn't properly typed - parameters = TypeAdapter(fn).json_schema() - # Find context parameter if it exists if context_kwarg is None: sig = inspect.signature(fn) @@ -52,14 +52,18 @@ def from_function( context_kwarg = param_name break - # ensure the arguments are properly cast - fn = validate_call(fn) + arg_model = func_to_pyd( + fn, + skip_names=[context_kwarg] if context_kwarg is not None else [], + ) + parameters = arg_model.model_json_schema() return cls( fn=fn, name=func_name, description=func_doc, parameters=parameters, + arg_model=arg_model, is_async=is_async, context_kwarg=context_kwarg, ) @@ -67,13 +71,16 @@ def from_function( async def run(self, arguments: dict, context: Optional["Context"] = None) -> Any: """Run the tool with arguments.""" try: + arguments_pre_parsed = self.arg_model.pre_parse_json(arguments) + arguments_parsed_model = self.arg_model.model_validate(arguments_pre_parsed) + arguments_parsed_dict = arguments_parsed_model.model_dump_one_level() # Inject context if needed if self.context_kwarg: - arguments[self.context_kwarg] = context + arguments_parsed_dict[self.context_kwarg] = context # Call function with proper async handling if self.is_async: - return await self.fn(**arguments) - return self.fn(**arguments) + return await self.fn(**arguments_parsed_dict) + return self.fn(**arguments_parsed_dict) except Exception as e: raise ToolError(f"Error executing tool {self.name}: {e}") from e diff --git a/src/fastmcp/utilities/func_to_pyd.py b/src/fastmcp/utilities/func_to_pyd.py new file mode 100644 index 0000000..a965e75 --- /dev/null +++ b/src/fastmcp/utilities/func_to_pyd.py @@ -0,0 +1,180 @@ +import inspect +from collections.abc import Callable, Sequence +from copy import deepcopy +from typing import ( + Annotated, + Any, + get_args, + get_origin, +) + +from pydantic import BaseModel, ConfigDict, TypeAdapter, ValidationError, create_model +from pydantic.fields import FieldInfo +from pydantic import WithJsonSchema +from fastmcp.utilities.logging import get_logger + +logger = get_logger(__name__) + + +class IgnoredType(str): + # See https://docs.pydantic.dev/2.10/errors/usage_errors/#model-field-missing-annotation + pass + + +class ArgModelBase(BaseModel): + @classmethod + def pre_parse_json(cls, data: dict[str, Any]) -> dict[str, Any]: + """Pre-parse data from JSON. + + Go through and first try parsing as JSON, then as Python. + This is to handle cases like `["a", "b", "c"]` being passed in as JSON inside + a string rather than an actual list. Claude desktop is prone to this - in fact + it seems incapable of NOT doing this. + + This is still a WIP - for example, a string like '"a"' will be parsed as "a" + (a single char) rather than an "a" with two quotes (three chars). + """ + new_data: dict[str, Any] = {} + for field_name, field_info in cls.model_fields.items(): + if field_name not in data.keys(): + continue + ta: TypeAdapter[Any] = TypeAdapter(field_info.annotation) + # Try JSON first as it's generally more specific + try: + parsed_item = ta.validate_json(data[field_name]) + new_data[field_name] = parsed_item + logger.debug( + f"Parsed {field_name} as {field_info.annotation} from JSON" + ) + continue + except ValidationError: + pass + # Try Python next + try: + parsed_item = ta.validate_python(data[field_name]) + new_data[field_name] = parsed_item + continue + except ValidationError: + pass + assert new_data.keys() == data.keys() + return new_data + + def model_dump_one_level(self) -> dict[str, Any]: + """Return a dict of the model's fields, one level deep. + + That is, sub-models etc are not dumped - they are kept as pydantic models. + """ + kwargs: dict[str, Any] = {} + for field_name in self.model_fields.keys(): + kwargs[field_name] = getattr(self, field_name) + return kwargs + + model_config = ConfigDict( + arbitrary_types_allowed=True, ignored_types=(IgnoredType,) + ) + + +def func_to_pyd(func: Callable, skip_names: Sequence[str] = ()) -> BaseModel: + """Given a function, return a pydantic model representing its signature. + + The use case fot this is + ``` + arg_model = func_to_pyd(func) + validated_args = arg_model.model_validate(some_raw_data_dict) + return func(**validated_args.model_dump_one_level()) + ``` + + **critically** it also provides pre-parse helper to attempt to parse things from JSON. + There is a chance this may not do what you actually want. + TODO: discuss this a lot more. + """ + sig = inspect.signature(func) + params = sig.parameters + dynamic_pydantic_model_params: dict[str, Any] = {} + for param in params.values(): + if param.name.startswith("_"): + raise ValueError( + f"Parameter {param.name} must not start with an underscore" + ) + + if param.name in skip_names: + continue + annotation = param.annotation + + # TODO: test annotations like `x: Any` + if annotation is inspect.Parameter.empty: + # For untyped parameters like `x` or `lambda x: ...` + default = ( + param.default if param.default is not inspect.Parameter.empty else ... + ) + dynamic_pydantic_model_params[param.name] = ( + Annotated[IgnoredType, WithJsonSchema(None)], + default, + ) + elif get_origin(annotation) is Annotated: + # Cases like + # - `x: Annotated[str, Field(description="pure red line")]` + # - `x: Annotated[str, Field("hey", description="bloody mary")]` + # - `x: Annotated[str, Field(description="blue dream")] = "hey"` + + # Annotated[int, 'a', 'b', 'c'].__metadata__ == ('a', 'b', 'c') + annotated_args = annotation.__metadata__ + if len(annotated_args) != 1: + raise ValueError( + f"Only one annotation is supported for Annotated. " + f"Got {annotated_args} for param {param.name}", + ) + assert len(annotated_args) == 1 + field_info = annotated_args[0] + if not isinstance(field_info, FieldInfo): + raise ValueError( + f"The only annotation supported is pydantic's FieldInfo " + f"(via pydantic.Field). Got {type(field_info)} for param {param.name}", + ) + assert isinstance(field_info, FieldInfo) + if param.default is inspect.Parameter.empty: + # Like `x: Annotated[str, Field...]` + # If there's no default we can just throw the whole + # annotated thing at the dynamic model creator + dynamic_pydantic_model_params[param.name] = annotation + else: + # We've got `x: Annotated[str, Field...] = x`. + # We need to make sure that we respect the default in the + # function signature (`= x`). To do this, we need to confirm + # that only one of field_info.default, field_info.default_factory, + # or a default defined in the function signature is set. + # If a default is defined in the function signature, we have no + # way to pass it into the dynamic pydantic model creation, so we + # modify field_info.default to match it 👍 + if not field_info.is_required(): + raise ValueError( + f"{param.name} has a default in the function signature " + f"but is not required in the pydantic FieldInfo. Have you " + f"set a default/default_factory at the same time as in " + f"the function signature?", + ) + # Copy to avoid mutating the original, which is still in + # use in the function signature + field_info = deepcopy(field_info) + field_info.default = param.default + dynamic_pydantic_model_params[param.name] = Annotated[ + get_args(annotation)[0], + field_info, + ] + else: + # Cases like + # - `x: str` + # - `x: int = 1` + # - `x: int = pydantic.Field...` + # - `x: MyPydanticModel` + default = ( + param.default if param.default is not inspect.Parameter.empty else ... + ) + dynamic_pydantic_model_params[param.name] = (annotation, default) + + arguments_model = create_model( + f"{func.__name__}Arguments", + **dynamic_pydantic_model_params, + __base__=ArgModelBase, + ) + return arguments_model diff --git a/tests/test_tool_manager.py b/tests/test_tool_manager.py index 3192454..90ae362 100644 --- a/tests/test_tool_manager.py +++ b/tests/test_tool_manager.py @@ -3,7 +3,8 @@ import pytest from pydantic import BaseModel - +from fastmcp import FastMCP, Context +import json from fastmcp.exceptions import ToolError from fastmcp.tools import ToolManager @@ -156,6 +157,64 @@ async def test_call_unknown_tool(self): with pytest.raises(ToolError): await manager.call_tool("unknown", {"a": 1}) + async def test_call_tool_with_list_int_input(self): + def sum_vals(vals: list[int]) -> int: + return sum(vals) + + manager = ToolManager() + manager.add_tool(sum_vals) + result = await manager.call_tool("sum_vals", {"vals": "[1, 2, 3]"}) + assert result == 6 + + async def test_call_tool_with_list_str_or_str_input(self): + def concat_strs(vals: list[str] | str) -> str: + return vals if isinstance(vals, str) else "".join(vals) + + manager = ToolManager() + manager.add_tool(concat_strs) + result = await manager.call_tool("concat_strs", {"vals": ["a", "b", "c"]}) + assert result == "abc" + result = await manager.call_tool("concat_strs", {"vals": '["a", "b", "c"]'}) + assert result == "abc" + result = await manager.call_tool("concat_strs", {"vals": "a"}) + assert result == "a" + + # THIS SHOULD PASS # TODO: fix + # result = await manager.call_tool("concat_strs", {"vals": '"a"'}) + # assert result == '"a"' + + async def test_call_tool_with_complex_model(self): + class MyShrimpTank(BaseModel): + class Shrimp(BaseModel): + name: str + + shrimp: list[Shrimp] + + def name_shrimp(tank: MyShrimpTank, ctx: Context) -> list[str]: + return [x.name for x in tank.shrimp] + + manager = ToolManager() + manager.add_tool(name_shrimp) + result = await manager.call_tool( + "name_shrimp", {"tank": {"shrimp": [{"name": "rex"}, {"name": "gertrude"}]}} + ) + assert result == ["rex", "gertrude"] + result = await manager.call_tool( + "name_shrimp", + {"tank": '{"shrimp": [{"name": "rex"}, {"name": "gertrude"}]}'}, + ) + assert result == ["rex", "gertrude"] + + async def test_context_arg_excluded(self): + def something(a: int, ctx: Context) -> int: + return a + + manager = ToolManager() + tool = manager.add_tool(something) + assert "ctx" not in json.dumps(tool.parameters) + assert "Context" not in json.dumps(tool.parameters) + assert "ctx" not in tool.arg_model.model_fields + class TestContextHandling: """Test context handling in the tool manager.""" @@ -179,7 +238,7 @@ def tool_without_context(x: int) -> str: async def test_context_injection(self): """Test that context is properly injected during tool execution.""" - from fastmcp import Context, FastMCP + from fastmcp import Context def tool_with_context(x: int, ctx: Context) -> str: assert isinstance(ctx, Context) @@ -195,7 +254,7 @@ def tool_with_context(x: int, ctx: Context) -> str: async def test_context_injection_async(self): """Test that context is properly injected in async tools.""" - from fastmcp import Context, FastMCP + from fastmcp import Context async def async_tool(x: int, ctx: Context) -> str: assert isinstance(ctx, Context) @@ -224,7 +283,7 @@ def tool_with_context(x: int, ctx: Optional[Context] = None) -> str: async def test_context_error_handling(self): """Test error handling when context injection fails.""" - from fastmcp import Context, FastMCP + from fastmcp import Context def tool_with_context(x: int, ctx: Context) -> str: raise ValueError("Test error")