Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 119 additions & 25 deletions fastapi_mcp/openapi/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Set, Optional


def get_single_param_type_from_schema(param_schema: Dict[str, Any]) -> str:
Expand All @@ -16,54 +16,144 @@ def get_single_param_type_from_schema(param_schema: Dict[str, Any]) -> str:
return param_schema.get("type", "string")


def resolve_schema_references(schema_part: Dict[str, Any], reference_schema: Dict[str, Any]) -> Dict[str, Any]:
def resolve_schema_references(
schema_part: Dict[str, Any],
reference_schema: Dict[str, Any],
visited_refs: Optional[Set[str]] = None,
skip_components: bool = True,
) -> Dict[str, Any]:
"""
Resolve schema references in OpenAPI schemas.
Resolve schema references with cycle detection and performance optimization.

Args:
schema_part: The part of the schema being processed that may contain references
reference_schema: The complete schema used to resolve references from
visited_refs: Set of currently being resolved references (for cycle detection)
skip_components: Whether to skip processing the components section

Returns:
The schema with references resolved
"""
# Make a copy to avoid modifying the input schema
schema_part = schema_part.copy()

# Handle $ref directly in the schema
if "$ref" in schema_part:
ref_path = schema_part["$ref"]
# Standard OpenAPI references are in the format "#/components/schemas/ModelName"
if ref_path.startswith("#/components/schemas/"):
model_name = ref_path.split("/")[-1]
if "components" in reference_schema and "schemas" in reference_schema["components"]:
if model_name in reference_schema["components"]["schemas"]:
# Replace with the resolved schema
ref_schema = reference_schema["components"]["schemas"][model_name].copy()
# Remove the $ref key and merge with the original schema
schema_part.pop("$ref")
schema_part.update(ref_schema)
if visited_refs is None:
visited_refs = set()

if not isinstance(schema_part, dict):
return schema_part

part = schema_part.copy()

if "$ref" in part:
ref_path = part["$ref"]
if ref_path in visited_refs:
return {"$ref": ref_path}
visited_refs.add(ref_path)
try:
if ref_path.startswith("#/components/schemas/"):
model_name = ref_path.split("/")[-1]
comps = reference_schema.get("components", {}).get("schemas", {})
if model_name in comps:
ref_schema = comps[model_name]
resolved_ref = resolve_schema_references(
ref_schema, reference_schema, visited_refs, skip_components
)
part.pop("$ref", None)
if isinstance(resolved_ref, dict) and "$ref" not in resolved_ref:
part.update(resolved_ref)
finally:
# Cleanup
visited_refs.discard(ref_path)

return part

# Recursively resolve references in all dictionary values
for key, value in schema_part.items():
for key, value in list(part.items()):
if skip_components and key == "components":
continue

if isinstance(value, dict):
schema_part[key] = resolve_schema_references(value, reference_schema)
part[key] = resolve_schema_references(value, reference_schema, visited_refs, skip_components)
elif isinstance(value, list):
# Only process list items that are dictionaries since only they can contain refs
schema_part[key] = [
resolve_schema_references(item, reference_schema) if isinstance(item, dict) else item for item in value
part[key] = [
resolve_schema_references(item, reference_schema, visited_refs, skip_components)
if isinstance(item, dict)
else item
for item in value
]
return part


def resolve_schema_for_display(
schema: Dict[str, Any],
components: Dict[str, Any],
cache: Optional[Dict[str, Dict[str, Any]]] = None,
stack: Optional[Set[str]] = None,
) -> Dict[str, Any]:
"""
Resolve a specific schema for display with caching and cycle detection.

This function is optimized for just-in-time resolution of specific schemas
without processing the entire components tree.

Args:
schema: The schema to resolve
components: The components section containing schema definitions
cache: Cache for memoizing resolved schemas
stack: Stack for cycle detection

Returns:
The resolved schema
"""
if not isinstance(schema, dict):
return schema

return schema_part
cache = cache or {}
stack = stack or set()

# Handle direct $ref first
if "$ref" in schema:
ref = schema["$ref"]

# Break the cycle
if ref in stack:
return {"$ref": ref}

if ref in cache:
return cache[ref]

if not ref.startswith("#/components/schemas/"):
return schema

name = ref.split("/")[-1]
target = components.get("schemas", {}).get(name, {})
stack.add(ref)
resolved = resolve_schema_for_display(target, components, cache, stack)
stack.remove(ref)
cache[ref] = resolved if isinstance(resolved, dict) else target
return cache[ref]

# Recursively resolve the schema but don't descend into the components subtree
out = schema.copy()
for k, v in list(out.items()):
if k == "components":
continue
if isinstance(v, dict):
out[k] = resolve_schema_for_display(v, components, cache, stack)
elif isinstance(v, list):
out[k] = [resolve_schema_for_display(i, components, cache, stack) if isinstance(i, dict) else i for i in v]
return out


def clean_schema_for_display(schema: Dict[str, Any]) -> Dict[str, Any]:
"""
Clean up a schema for display by removing internal fields.



Args:
schema: The schema to clean



Returns:
The cleaned schema
"""
Expand Down Expand Up @@ -104,9 +194,13 @@ def generate_example_from_schema(schema: Dict[str, Any]) -> Any:
"""
Generate a simple example response from a JSON schema.



Args:
schema: The JSON schema to generate an example from



Returns:
An example object based on the schema
"""
Expand Down
Loading