Skip to content

Commit

Permalink
Draft handling of complex inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
jurasofish committed Dec 1, 2024
1 parent 288e80c commit 8e02f80
Show file tree
Hide file tree
Showing 4 changed files with 337 additions and 14 deletions.
77 changes: 77 additions & 0 deletions examples/complex_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
FastMCP Complex inputs Example
This can be used to verify that the JSON schema and data
parsing work with complex inputs like lists, nulls, and sub-models.
Suggested usage:
- Set up claude desktop with it
- Ask it to print out the schema for you
- Ask it to try out the endpoint without supplying any of the optional args
- Confirm it's okay
- Ask it to try out the endpoint with the optional args explicitly set
- Confirm it's okay
"""

from pydantic import BaseModel, Field
from typing import Annotated
from fastmcp.server import FastMCP, Context

mcp = FastMCP("Demo")


class TestInputModelA(BaseModel):
pass


class TestInputModelB(BaseModel):
class InnerModel(BaseModel):
x: int

how_many_shrimp: Annotated[int, Field(description="How many shrimp in the tank???")]
ok: InnerModel


@mcp.tool()
def complex_inputs(
ctx: Context,
an_int: int,
must_be_none: None,
list_of_ints: list[int],
# list[str] | str is an interesting case because if it comes in as JSON like
# "[\"a\", \"b\"]" then it will be naively parsed as a string.
list_str_or_str: list[str] | str,
an_int_annotated_with_field: Annotated[
int, Field(description="An int with a field")
],
# TODO: handle this case too
# field_with_default_via_field_annotation_before_nondefault_arg: Annotated[
# int, Field(1)
# ],
my_model_a: TestInputModelA,
my_model_b: TestInputModelB,
an_int_annotated_with_field_default: Annotated[
int,
Field(1, description="An int with a field"),
],
my_model_a_with_default: TestInputModelA = TestInputModelA(), # noqa: B008
an_int_with_default: int = 1,
an_int_with_equals_field: int = Field(1, ge=0),
int_annotated_with_default: Annotated[int, Field(description="hey")] = 5,
) -> str:
_ = (
ctx,
an_int,
must_be_none,
list_of_ints,
list_str_or_str,
an_int_annotated_with_field,
an_int_annotated_with_field_default,
my_model_a,
my_model_b,
my_model_a_with_default,
an_int_with_default,
an_int_with_equals_field,
int_annotated_with_default,
)
return "ok!"
27 changes: 17 additions & 10 deletions src/fastmcp/tools/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import fastmcp
from fastmcp.exceptions import ToolError


from pydantic import BaseModel, Field, TypeAdapter, validate_call
from fastmcp.utilities.func_to_pyd import func_to_pyd, ArgModelBase
from pydantic import BaseModel, Field


import inspect
Expand All @@ -19,6 +19,9 @@ class Tool(BaseModel):
name: str = Field(description="Name of the tool")
description: str = Field(description="Description of what the tool does")
parameters: dict = Field(description="JSON schema for tool parameters")
arg_model: type[ArgModelBase] = Field(
description="Pydantic model for tool arguments"
)
is_async: bool = Field(description="Whether the tool is async")
context_kwarg: Optional[str] = Field(
None, description="Name of the kwarg that should receive context"
Expand All @@ -41,9 +44,6 @@ def from_function(
func_doc = description or fn.__doc__ or ""
is_async = inspect.iscoroutinefunction(fn)

# Get schema from TypeAdapter - will fail if function isn't properly typed
parameters = TypeAdapter(fn).json_schema()

# Find context parameter if it exists
if context_kwarg is None:
sig = inspect.signature(fn)
Expand All @@ -52,28 +52,35 @@ def from_function(
context_kwarg = param_name
break

# ensure the arguments are properly cast
fn = validate_call(fn)
arg_model = func_to_pyd(
fn,
skip_names=[context_kwarg] if context_kwarg is not None else [],
)
parameters = arg_model.model_json_schema()

return cls(
fn=fn,
name=func_name,
description=func_doc,
parameters=parameters,
arg_model=arg_model,
is_async=is_async,
context_kwarg=context_kwarg,
)

async def run(self, arguments: dict, context: Optional["Context"] = None) -> Any:
"""Run the tool with arguments."""
try:
arguments_pre_parsed = self.arg_model.pre_parse_json(arguments)
arguments_parsed_model = self.arg_model.model_validate(arguments_pre_parsed)
arguments_parsed_dict = arguments_parsed_model.model_dump_one_level()
# Inject context if needed
if self.context_kwarg:
arguments[self.context_kwarg] = context
arguments_parsed_dict[self.context_kwarg] = context

# Call function with proper async handling
if self.is_async:
return await self.fn(**arguments)
return self.fn(**arguments)
return await self.fn(**arguments_parsed_dict)
return self.fn(**arguments_parsed_dict)
except Exception as e:
raise ToolError(f"Error executing tool {self.name}: {e}") from e
180 changes: 180 additions & 0 deletions src/fastmcp/utilities/func_to_pyd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import inspect
from collections.abc import Callable, Sequence
from copy import deepcopy
from typing import (
Annotated,
Any,
get_args,
get_origin,
)

from pydantic import BaseModel, ConfigDict, TypeAdapter, ValidationError, create_model
from pydantic.fields import FieldInfo
from pydantic import WithJsonSchema
from fastmcp.utilities.logging import get_logger

logger = get_logger(__name__)


class IgnoredType(str):
# See https://docs.pydantic.dev/2.10/errors/usage_errors/#model-field-missing-annotation
pass


class ArgModelBase(BaseModel):
@classmethod
def pre_parse_json(cls, data: dict[str, Any]) -> dict[str, Any]:
"""Pre-parse data from JSON.
Go through and first try parsing as JSON, then as Python.
This is to handle cases like `["a", "b", "c"]` being passed in as JSON inside
a string rather than an actual list. Claude desktop is prone to this - in fact
it seems incapable of NOT doing this.
This is still a WIP - for example, a string like '"a"' will be parsed as "a"
(a single char) rather than an "a" with two quotes (three chars).
"""
new_data: dict[str, Any] = {}
for field_name, field_info in cls.model_fields.items():
if field_name not in data.keys():
continue
ta: TypeAdapter[Any] = TypeAdapter(field_info.annotation)
# Try JSON first as it's generally more specific
try:
parsed_item = ta.validate_json(data[field_name])
new_data[field_name] = parsed_item
logger.debug(
f"Parsed {field_name} as {field_info.annotation} from JSON"
)
continue
except ValidationError:
pass
# Try Python next
try:
parsed_item = ta.validate_python(data[field_name])
new_data[field_name] = parsed_item
continue
except ValidationError:
pass
assert new_data.keys() == data.keys()
return new_data

def model_dump_one_level(self) -> dict[str, Any]:
"""Return a dict of the model's fields, one level deep.
That is, sub-models etc are not dumped - they are kept as pydantic models.
"""
kwargs: dict[str, Any] = {}
for field_name in self.model_fields.keys():
kwargs[field_name] = getattr(self, field_name)
return kwargs

model_config = ConfigDict(
arbitrary_types_allowed=True, ignored_types=(IgnoredType,)
)


def func_to_pyd(func: Callable, skip_names: Sequence[str] = ()) -> BaseModel:
"""Given a function, return a pydantic model representing its signature.
The use case fot this is
```
arg_model = func_to_pyd(func)
validated_args = arg_model.model_validate(some_raw_data_dict)
return func(**validated_args.model_dump_one_level())
```
**critically** it also provides pre-parse helper to attempt to parse things from JSON.
There is a chance this may not do what you actually want.
TODO: discuss this a lot more.
"""
sig = inspect.signature(func)
params = sig.parameters
dynamic_pydantic_model_params: dict[str, Any] = {}
for param in params.values():
if param.name.startswith("_"):
raise ValueError(
f"Parameter {param.name} must not start with an underscore"
)

if param.name in skip_names:
continue
annotation = param.annotation

# TODO: test annotations like `x: Any`
if annotation is inspect.Parameter.empty:
# For untyped parameters like `x` or `lambda x: ...`
default = (
param.default if param.default is not inspect.Parameter.empty else ...
)
dynamic_pydantic_model_params[param.name] = (
Annotated[IgnoredType, WithJsonSchema(None)],
default,
)
elif get_origin(annotation) is Annotated:
# Cases like
# - `x: Annotated[str, Field(description="pure red line")]`
# - `x: Annotated[str, Field("hey", description="bloody mary")]`
# - `x: Annotated[str, Field(description="blue dream")] = "hey"`

# Annotated[int, 'a', 'b', 'c'].__metadata__ == ('a', 'b', 'c')
annotated_args = annotation.__metadata__
if len(annotated_args) != 1:
raise ValueError(
f"Only one annotation is supported for Annotated. "
f"Got {annotated_args} for param {param.name}",
)
assert len(annotated_args) == 1
field_info = annotated_args[0]
if not isinstance(field_info, FieldInfo):
raise ValueError(
f"The only annotation supported is pydantic's FieldInfo "
f"(via pydantic.Field). Got {type(field_info)} for param {param.name}",
)
assert isinstance(field_info, FieldInfo)
if param.default is inspect.Parameter.empty:
# Like `x: Annotated[str, Field...]`
# If there's no default we can just throw the whole
# annotated thing at the dynamic model creator
dynamic_pydantic_model_params[param.name] = annotation
else:
# We've got `x: Annotated[str, Field...] = x`.
# We need to make sure that we respect the default in the
# function signature (`= x`). To do this, we need to confirm
# that only one of field_info.default, field_info.default_factory,
# or a default defined in the function signature is set.
# If a default is defined in the function signature, we have no
# way to pass it into the dynamic pydantic model creation, so we
# modify field_info.default to match it 👍
if not field_info.is_required():
raise ValueError(
f"{param.name} has a default in the function signature "
f"but is not required in the pydantic FieldInfo. Have you "
f"set a default/default_factory at the same time as in "
f"the function signature?",
)
# Copy to avoid mutating the original, which is still in
# use in the function signature
field_info = deepcopy(field_info)
field_info.default = param.default
dynamic_pydantic_model_params[param.name] = Annotated[
get_args(annotation)[0],
field_info,
]
else:
# Cases like
# - `x: str`
# - `x: int = 1`
# - `x: int = pydantic.Field...`
# - `x: MyPydanticModel`
default = (
param.default if param.default is not inspect.Parameter.empty else ...
)
dynamic_pydantic_model_params[param.name] = (annotation, default)

arguments_model = create_model(
f"{func.__name__}Arguments",
**dynamic_pydantic_model_params,
__base__=ArgModelBase,
)
return arguments_model
Loading

0 comments on commit 8e02f80

Please sign in to comment.