From 54722a624a78874068df257f6523d10a4b2a5e49 Mon Sep 17 00:00:00 2001 From: rakdutta Date: Fri, 17 Oct 2025 10:51:13 +0530 Subject: [PATCH 1/6] rest pass changed tool service Signed-off-by: rakdutta --- mcpgateway/db.py | 12 + mcpgateway/schemas.py | 344 ++++++++++++++++++++++++++++ mcpgateway/services/tool_service.py | 10 + mcpgateway/static/admin.js | 90 +++++++- mcpgateway/templates/admin.html | 11 + 5 files changed, 466 insertions(+), 1 deletion(-) diff --git a/mcpgateway/db.py b/mcpgateway/db.py index 91dd02e1d..3e6fca1e3 100644 --- a/mcpgateway/db.py +++ b/mcpgateway/db.py @@ -1603,6 +1603,18 @@ class Tool(Base): custom_name_slug: Mapped[Optional[str]] = mapped_column(String(255), nullable=False) display_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + # Passthrough REST fields + base_url: Mapped[Optional[str]] = mapped_column(String, nullable=True) + path_template: Mapped[Optional[str]] = mapped_column(String, nullable=True) + query_mapping: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + header_mapping: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + timeout_ms: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, default=None) + expose_passthrough: Mapped[bool] = mapped_column(Boolean, default=True) + allowlist: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True) + plugin_chain_pre: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True) + plugin_chain_post: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True) + + # Federation relationship with a local gateway gateway_id: Mapped[Optional[str]] = mapped_column(ForeignKey("gateways.id")) # gateway_slug: Mapped[Optional[str]] = mapped_column(ForeignKey("gateways.slug")) diff --git a/mcpgateway/schemas.py b/mcpgateway/schemas.py index 037a04db1..9d3d98570 100644 --- a/mcpgateway/schemas.py +++ b/mcpgateway/schemas.py @@ -27,6 +27,7 @@ import logging import re from typing import Any, Dict, List, Literal, Optional, Self, Union +from urllib.parse import urlparse # Third-Party from pydantic import AnyHttpUrl, BaseModel, ConfigDict, EmailStr, Field, field_serializer, field_validator, model_validator, ValidationInfo @@ -421,6 +422,21 @@ class ToolCreate(BaseModel): team_id: Optional[str] = Field(None, description="Team ID for resource organization") owner_email: Optional[str] = Field(None, description="Email of the tool owner") visibility: Optional[str] = Field(default="public", description="Visibility level (private, team, public)") + + # Passthrough REST fields + base_url: Optional[str] = Field(None, description="Base URL for REST passthrough") + path_template: Optional[str] = Field(None, description="Path template for REST passthrough") + query_mapping: Optional[Dict[str, Any]] = Field(None, description="Query mapping for REST passthrough") + header_mapping: Optional[Dict[str, Any]] = Field(None, description="Header mapping for REST passthrough") + timeout_ms: Optional[int] = Field( + default=None, + description="Timeout in milliseconds for REST passthrough (20000 if integration_type='REST', else None)" + ) + expose_passthrough: Optional[bool] = Field(True, description="Expose passthrough endpoint for this tool") + allowlist: Optional[List[str]] = Field(None, description="Allowed upstream hosts/schemes for passthrough") + plugin_chain_pre: Optional[List[str]] = Field(None, description="Pre-plugin chain for passthrough") + plugin_chain_post: Optional[List[str]] = Field(None, description="Post-plugin chain for passthrough") + @field_validator("tags") @classmethod @@ -752,7 +768,172 @@ def prevent_manual_mcp_creation(cls, values: Dict[str, Any]) -> Dict[str, Any]: if integration_type == "A2A" and not allow_auto: raise ValueError("Cannot manually create A2A tools. Add A2A agents via the A2A interface - tools will be auto-created when agents are associated with servers.") return values + + @model_validator(mode="before") + @classmethod + def enforce_passthrough_fields_for_rest(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """ + Enforce that passthrough REST fields are only set for integration_type 'REST'. + If any passthrough field is set for non-REST, raise ValueError. + + Args: + values (Dict[str, Any]): The input values to validate. + + Returns: + Dict[str, Any]: The validated values. + + Raises: + ValueError: If passthrough fields are set for non-REST integration_type. + """ + passthrough_fields = ["base_url", "path_template", "query_mapping", "header_mapping", "timeout_ms", "expose_passthrough", "allowlist", "plugin_chain_pre", "plugin_chain_post"] + integration_type = values.get("integration_type") + if integration_type != "REST": + for field in passthrough_fields: + if field in values and values[field] not in (None, [], {}): + raise ValueError(f"Field '{field}' is only allowed for integration_type 'REST'.") + return values + + @model_validator(mode="before") + @classmethod + def extract_base_url_and_path_template(cls, values: dict) -> dict: + """ + Only for integration_type 'REST': + If 'url' is provided, extract 'base_url' and 'path_template'. + Ensures path_template starts with a single '/'. + + Args: + values (dict): The input values to process. + + Returns: + dict: The updated values with base_url and path_template if applicable. + """ + integration_type = values.get("integration_type") + if integration_type != "REST": + # Only process for REST, skip for others + return values + url = values.get("url") + if url: + parsed = urlparse(str(url)) + base_url = f"{parsed.scheme}://{parsed.netloc}" + path_template = parsed.path + # Ensure path_template starts with a single '/' + if path_template: + path_template = "/" + path_template.lstrip("/") + if not values.get("base_url"): + values["base_url"] = base_url + if not values.get("path_template"): + values["path_template"] = path_template + return values + + @field_validator("base_url") + @classmethod + def validate_base_url(cls, v): + """ + Validate that base_url is a valid URL with scheme and netloc. + + Args: + v (str): The base_url value to validate. + + Returns: + str: The validated base_url value. + + Raises: + ValueError: If base_url is not a valid URL. + """ + if v is None: + return v + parsed = urlparse(str(v)) + if not parsed.scheme or not parsed.netloc: + raise ValueError("base_url must be a valid URL with scheme and netloc") + return v + + @field_validator("path_template") + @classmethod + def validate_path_template(cls, v): + """ + Validate that path_template starts with '/'. + + Args: + v (str): The path_template value to validate. + + Returns: + str: The validated path_template value. + + Raises: + ValueError: If path_template does not start with '/'. + """ + if v and not str(v).startswith("/"): + raise ValueError("path_template must start with '/'") + return v + + @field_validator("timeout_ms") + @classmethod + def validate_timeout_ms(cls, v): + """ + Validate that timeout_ms is a positive integer. + + Args: + v (int): The timeout_ms value to validate. + + Returns: + int: The validated timeout_ms value. + + Raises: + ValueError: If timeout_ms is not a positive integer. + """ + if v is not None and v <= 0: + raise ValueError("timeout_ms must be a positive integer") + return v + + @field_validator("allowlist") + @classmethod + def validate_allowlist(cls, v): + """ + Validate that allowlist is a list and each entry is a valid host or scheme string. + + Args: + v (List[str]): The allowlist to validate. + + Returns: + List[str]: The validated allowlist. + + Raises: + ValueError: If allowlist is not a list or any entry is not a valid host/scheme string. + """ + if v is None: + return None + if not isinstance(v, list): + raise ValueError("allowlist must be a list of host/scheme strings") + hostname_regex = re.compile(r"^(https?://)?([a-zA-Z0-9.-]+)(:[0-9]+)?$") + for host in v: + if not isinstance(host, str): + raise ValueError(f"Invalid type in allowlist: {host} (must be str)") + if not hostname_regex.match(host): + raise ValueError(f"Invalid host/scheme in allowlist: {host}") + return v + + @field_validator("plugin_chain_pre", "plugin_chain_post") + @classmethod + def validate_plugin_chain(cls, v): + """ + Validate that each plugin in the chain is allowed. + + Args: + v (List[str]): The plugin chain to validate. + + Returns: + List[str]: The validated plugin chain. + Raises: + ValueError: If any plugin is not in the allowed set. + """ + allowed_plugins = {"deny_filter", "rate_limit", "pii_filter", "response_shape", "regex_filter", "resource_filter"} + if v is None: + return v + for plugin in v: + if plugin not in allowed_plugins: + raise ValueError(f"Unknown plugin: {plugin}") + return v class ToolUpdate(BaseModelWithConfigDict): """Schema for updating an existing tool. @@ -777,6 +958,20 @@ class ToolUpdate(BaseModelWithConfigDict): tags: Optional[List[str]] = Field(None, description="Tags for categorizing the tool") visibility: Optional[str] = Field(default="public", description="Visibility level: private, team, or public") + # Passthrough REST fields + base_url: Optional[str] = Field(None, description="Base URL for REST passthrough") + path_template: Optional[str] = Field(None, description="Path template for REST passthrough") + query_mapping: Optional[Dict[str, Any]] = Field(None, description="Query mapping for REST passthrough") + header_mapping: Optional[Dict[str, Any]] = Field(None, description="Header mapping for REST passthrough") + timeout_ms: Optional[int] = Field( + default=None, + description="Timeout in milliseconds for REST passthrough (20000 if integration_type='REST', else None)" + ) + expose_passthrough: Optional[bool] = Field(True, description="Expose passthrough endpoint for this tool") + allowlist: Optional[List[str]] = Field(None, description="Allowed upstream hosts/schemes for passthrough") + plugin_chain_pre: Optional[List[str]] = Field(None, description="Pre-plugin chain for passthrough") + plugin_chain_post: Optional[List[str]] = Field(None, description="Post-plugin chain for passthrough") + @field_validator("tags") @classmethod def validate_tags(cls, v: Optional[List[str]]) -> List[str]: @@ -1011,6 +1206,144 @@ def prevent_manual_mcp_update(cls, values: Dict[str, Any]) -> Dict[str, Any]: if integration_type == "A2A": raise ValueError("Cannot update tools to A2A integration type. A2A tools are managed by the A2A service.") return values + @model_validator(mode="before") + @classmethod + def extract_base_url_and_path_template(cls, values: dict) -> dict: + """ + If 'integration_type' is 'REST' and 'url' is provided, extract 'base_url' and 'path_template'. + Ensures path_template starts with a single '/'. + + Args: + values (dict): The input values to process. + + Returns: + dict: The updated values with base_url and path_template if applicable. + """ + integration_type = values.get("integration_type") + url = values.get("url") + if integration_type == "REST" and url: + parsed = urlparse(str(url)) + base_url = f"{parsed.scheme}://{parsed.netloc}" + path_template = parsed.path + # Ensure path_template starts with a single '/' + if path_template and not path_template.startswith("/"): + path_template = "/" + path_template.lstrip("/") + elif path_template: + path_template = "/" + path_template.lstrip("/") + if not values.get("base_url"): + values["base_url"] = base_url + if not values.get("path_template"): + values["path_template"] = path_template + return values + + @field_validator("base_url") + @classmethod + def validate_base_url(cls, v): + """ + Validate that base_url is a valid URL with scheme and netloc. + + Args: + v (str): The base_url value to validate. + + Returns: + str: The validated base_url value. + + Raises: + ValueError: If base_url is not a valid URL. + """ + if v is None: + return v + parsed = urlparse(str(v)) + if not parsed.scheme or not parsed.netloc: + raise ValueError("base_url must be a valid URL with scheme and netloc") + return v + + @field_validator("path_template") + @classmethod + def validate_path_template(cls, v): + """ + Validate that path_template starts with '/'. + + Args: + v (str): The path_template value to validate. + + Returns: + str: The validated path_template value. + + Raises: + ValueError: If path_template does not start with '/'. + """ + if v and not str(v).startswith("/"): + raise ValueError("path_template must start with '/'") + return v + + @field_validator("timeout_ms") + @classmethod + def validate_timeout_ms(cls, v): + """ + Validate that timeout_ms is a positive integer. + + Args: + v (int): The timeout_ms value to validate. + + Returns: + int: The validated timeout_ms value. + + Raises: + ValueError: If timeout_ms is not a positive integer. + """ + if v is not None and v <= 0: + raise ValueError("timeout_ms must be a positive integer") + return v + + @field_validator("allowlist") + @classmethod + def validate_allowlist(cls, v): + """ + Validate that allowlist is a list and each entry is a valid host or scheme string. + + Args: + v (List[str]): The allowlist to validate. + + Returns: + List[str]: The validated allowlist. + + Raises: + ValueError: If allowlist is not a list or any entry is not a valid host/scheme string. + """ + if v is None: + return None + if not isinstance(v, list): + raise ValueError("allowlist must be a list of host/scheme strings") + hostname_regex = re.compile(r"^(https?://)?([a-zA-Z0-9.-]+)(:[0-9]+)?$") + for host in v: + if not isinstance(host, str): + raise ValueError(f"Invalid type in allowlist: {host} (must be str)") + if not hostname_regex.match(host): + raise ValueError(f"Invalid host/scheme in allowlist: {host}") + return v + @field_validator("plugin_chain_pre", "plugin_chain_post") + @classmethod + def validate_plugin_chain(cls, v): + """ + Validate that each plugin in the chain is allowed. + + Args: + v (List[str]): The plugin chain to validate. + + Returns: + List[str]: The validated plugin chain. + + Raises: + ValueError: If any plugin is not in the allowed set. + """ + allowed_plugins = {"deny_filter", "rate_limit", "pii_filter", "response_shape", "regex_filter", "resource_filter"} + if v is None: + return v + for plugin in v: + if plugin not in allowed_plugins: + raise ValueError(f"Unknown plugin: {plugin}") + return v class ToolRead(BaseModelWithConfigDict): @@ -1074,6 +1407,17 @@ class ToolRead(BaseModelWithConfigDict): owner_email: Optional[str] = Field(None, description="Email of the user who owns this resource") visibility: Optional[str] = Field(default="public", description="Visibility level: private, team, or public") + # Passthrough REST fields + base_url: Optional[str] = Field(None, description="Base URL for REST passthrough") + path_template: Optional[str] = Field(None, description="Path template for REST passthrough") + query_mapping: Optional[Dict[str, Any]] = Field(None, description="Query mapping for REST passthrough") + header_mapping: Optional[Dict[str, Any]] = Field(None, description="Header mapping for REST passthrough") + timeout_ms: Optional[int] = Field(20000, description="Timeout in milliseconds for REST passthrough") + expose_passthrough: Optional[bool] = Field(True, description="Expose passthrough endpoint for this tool") + allowlist: Optional[List[str]] = Field(None, description="Allowed upstream hosts/schemes for passthrough") + plugin_chain_pre: Optional[List[str]] = Field(None, description="Pre-plugin chain for passthrough") + plugin_chain_post: Optional[List[str]] = Field(None, description="Post-plugin chain for passthrough") + class ToolInvocation(BaseModelWithConfigDict): """Schema for tool invocation requests. diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 05414b2b6..90c806a3d 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -486,6 +486,16 @@ async def register_tool( team_id=team_id, owner_email=owner_email or created_by, visibility=visibility, + # passthrough REST tools fields + base_url=tool.base_url if tool.integration_type == "REST" else None, + path_template=tool.path_template if tool.integration_type == "REST" else None, + query_mapping=tool.query_mapping if tool.integration_type == "REST" else None, + header_mapping=tool.header_mapping if tool.integration_type == "REST" else None, + timeout_ms=tool.timeout_ms if tool.integration_type == "REST" else None, + expose_passthrough=(tool.expose_passthrough if tool.integration_type == "REST" and tool.expose_passthrough is not None else True) if tool.integration_type == "REST" else None, + allowlist=tool.allowlist if tool.integration_type == "REST" else None, + plugin_chain_pre=tool.plugin_chain_pre if tool.integration_type == "REST" else None, + plugin_chain_post=tool.plugin_chain_post if tool.integration_type == "REST" else None, ) db.add(db_tool) diff --git a/mcpgateway/static/admin.js b/mcpgateway/static/admin.js index 1c0fd1746..e38b028c4 100644 --- a/mcpgateway/static/admin.js +++ b/mcpgateway/static/admin.js @@ -1,3 +1,86 @@ +// Add three fields to passthrough section on Advanced button click +function handleAddPassthrough() { + const passthroughContainer = safeGetElement("passthrough-container"); + if (!passthroughContainer) { + console.error("Passthrough container not found"); + return; + } + + // Toggle visibility + if (passthroughContainer.style.display === "none" || passthroughContainer.style.display === "") { + passthroughContainer.style.display = "block"; + // Add fields only if not already present + if (!document.getElementById("query-mapping-field")) { + const queryDiv = document.createElement("div"); + queryDiv.className = "mb-4"; + queryDiv.innerHTML = ` + + + `; + passthroughContainer.appendChild(queryDiv); + } + if (!document.getElementById("header-mapping-field")) { + const headerDiv = document.createElement("div"); + headerDiv.className = "mb-4"; + headerDiv.innerHTML = ` + + + `; + passthroughContainer.appendChild(headerDiv); + } + if (!document.getElementById("timeout-ms-field")) { + const timeoutDiv = document.createElement("div"); + timeoutDiv.className = "mb-4"; + timeoutDiv.innerHTML = ` + + + `; + passthroughContainer.appendChild(timeoutDiv); + } + if (!document.getElementById("expose-passthrough-field")) { + const exposeDiv = document.createElement("div"); + exposeDiv.className = "mb-4"; + exposeDiv.innerHTML = ` + + + `; + passthroughContainer.appendChild(exposeDiv); + } + if (!document.getElementById("allowlist-field")) { + const allowlistDiv = document.createElement("div"); + allowlistDiv.className = "mb-4"; + allowlistDiv.innerHTML = ` + + + `; + passthroughContainer.appendChild(allowlistDiv); + } + if (!document.getElementById("plugin-chain-pre-field")) { + const pluginPreDiv = document.createElement("div"); + pluginPreDiv.className = "mb-4"; + pluginPreDiv.innerHTML = ` + + + `; + passthroughContainer.appendChild(pluginPreDiv); + } + if (!document.getElementById("plugin-chain-post-field")) { + const pluginPostDiv = document.createElement("div"); + pluginPostDiv.className = "mb-4"; + pluginPostDiv.innerHTML = ` + + + `; + passthroughContainer.appendChild(pluginPostDiv); + } + } else { + passthroughContainer.style.display = "none"; + } +} + // Make URL field read-only for integration type MCP function updateEditToolUrl() { const editTypeField = document.getElementById("edit-tool-type"); @@ -9098,7 +9181,12 @@ function setupFormHandlers() { if (paramButton) { paramButton.addEventListener("click", handleAddParameter); } - + + const passthroughButton = safeGetElement("add-passthrough-btn"); + if (passthroughButton) { + passthroughButton.addEventListener("click", handleAddPassthrough); + } + const serverForm = safeGetElement("add-server-form"); if (serverForm) { serverForm.addEventListener("submit", handleServerFormSubmit); diff --git a/mcpgateway/templates/admin.html b/mcpgateway/templates/admin.html index 904b2e606..44e52a5fe 100644 --- a/mcpgateway/templates/admin.html +++ b/mcpgateway/templates/admin.html @@ -3079,6 +3079,17 @@

> Add New Parameter + + +