Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 34 additions & 54 deletions mcpgateway/cache/session_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from mcpgateway.services.logging_service import LoggingService
from mcpgateway.transports import SSETransport
from mcpgateway.utils.retry_manager import ResilientHttpClient
from mcpgateway.validation.jsonrpc import JSONRPCError

# Initialize logging service first
logging_service: LoggingService = LoggingService()
Expand Down Expand Up @@ -1276,19 +1277,43 @@ async def generate_response(self, message: Dict[str, Any], transport: SSETranspo
>>> # Response: {}
"""
result = {}

if "method" in message and "id" in message:
method = message["method"]
params = message.get("params", {})
req_id = message["id"]
db = next(get_db())
if method == "initialize":
init_result = await self.handle_initialize_logic(params)
response = {
try:
method = message["method"]
params = message.get("params", {})
params["server_id"] = server_id
req_id = message["id"]

rpc_input = {
"jsonrpc": "2.0",
"result": init_result.model_dump(by_alias=True, exclude_none=True),
"method": method,
"params": params,
"id": req_id,
}
await transport.send_message(response)
headers = {"Authorization": f"Bearer {user['token']}", "Content-Type": "application/json"}
rpc_url = base_url + "/rpc"
async with ResilientHttpClient(client_args={"timeout": settings.federation_timeout, "verify": not settings.skip_ssl_verify}) as client:
rpc_response = await client.post(
url=rpc_url,
json=rpc_input,
headers=headers,
)
result = rpc_response.json()
result = result.get("result", {})

response = {"jsonrpc": "2.0", "result": result, "id": req_id}
except JSONRPCError as e:
result = e.to_dict()
response = {"jsonrpc": "2.0", "error": result["error"], "id": req_id}
except Exception as e:
result = {"code": -32000, "message": "Internal error", "data": str(e)}
response = {"jsonrpc": "2.0", "error": result, "id": req_id}

logging.debug(f"Sending sse message:{response}")
await transport.send_message(response)

if message["method"] == "initialize":
await transport.send_message(
{
"jsonrpc": "2.0",
Expand All @@ -1309,48 +1334,3 @@ async def generate_response(self, message: Dict[str, Any], transport: SSETranspo
"params": {},
}
)
elif method == "tools/list":
if server_id:
tools = await tool_service.list_server_tools(db, server_id=server_id)
else:
tools = await tool_service.list_tools(db)
result = {"tools": [t.model_dump(by_alias=True, exclude_none=True) for t in tools]}
elif method == "resources/list":
if server_id:
resources = await resource_service.list_server_resources(db, server_id=server_id)
else:
resources = await resource_service.list_resources(db)
result = {"resources": [r.model_dump(by_alias=True, exclude_none=True) for r in resources]}
elif method == "prompts/list":
if server_id:
prompts = await prompt_service.list_server_prompts(db, server_id=server_id)
else:
prompts = await prompt_service.list_prompts(db)
result = {"prompts": [p.model_dump(by_alias=True, exclude_none=True) for p in prompts]}
elif method == "prompts/get":
prompts = await prompt_service.get_prompt(db, name=params.get("name"), arguments=params.get("arguments", {}))
result = prompts.model_dump(by_alias=True, exclude_none=True)
elif method == "ping":
result = {}
elif method == "tools/call":
rpc_input = {
"jsonrpc": "2.0",
"method": message["params"]["name"],
"params": message["params"]["arguments"],
"id": 1,
}
headers = {"Authorization": f"Bearer {user['token']}", "Content-Type": "application/json"}
rpc_url = base_url + "/rpc"
async with ResilientHttpClient(client_args={"timeout": settings.federation_timeout, "verify": not settings.skip_ssl_verify}) as client:
rpc_response = await client.post(
url=rpc_url,
json=rpc_input,
headers=headers,
)
result = rpc_response.json()
else:
result = {}

response = {"jsonrpc": "2.0", "result": result, "id": req_id}
logging.info(f"Sending sse message:{response}")
await transport.send_message(response)
95 changes: 67 additions & 28 deletions mcpgateway/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
from mcpgateway.db import Prompt as DbPrompt
from mcpgateway.db import PromptMetric, refresh_slugs_on_startup, SessionLocal
from mcpgateway.handlers.sampling import SamplingHandler
from mcpgateway.models import InitializeRequest, InitializeResult, ListResourceTemplatesResult, LogLevel, ResourceContent, Root
from mcpgateway.models import InitializeResult, ListResourceTemplatesResult, LogLevel, ResourceContent, Root
from mcpgateway.observability import init_telemetry
from mcpgateway.plugins import PluginManager, PluginViolationError
from mcpgateway.schemas import (
Expand Down Expand Up @@ -2225,39 +2225,57 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user: str
logger.debug(f"User {user} made an RPC request")
body = await request.json()
method = body["method"]
# rpc_id = body.get("id")
req_id = body.get("id") if "body" in locals() else None
params = body.get("params", {})
server_id = params.get("server_id", None)
cursor = params.get("cursor") # Extract cursor parameter

RPCRequest(jsonrpc="2.0", method=method, params=params) # Validate the request body against the RPCRequest model

if method == "tools/list":
tools = await tool_service.list_tools(db, cursor=cursor)
result = [t.model_dump(by_alias=True, exclude_none=True) for t in tools]
if method == "initialize":
result = await session_registry.handle_initialize_logic(body.get("params", {}))
if hasattr(result, "model_dump"):
result = result.model_dump(by_alias=True, exclude_none=True)
elif method == "tools/list":
if server_id:
tools = await tool_service.list_server_tools(db, server_id, cursor=cursor)
else:
tools = await tool_service.list_tools(db, cursor=cursor)
result = {"tools": [t.model_dump(by_alias=True, exclude_none=True) for t in tools]}
elif method == "list_tools": # Legacy endpoint
tools = await tool_service.list_tools(db, cursor=cursor)
result = [t.model_dump(by_alias=True, exclude_none=True) for t in tools]
elif method == "initialize":
result = initialize(
InitializeRequest(
protocol_version=params.get("protocolVersion") or params.get("protocol_version", ""),
capabilities=params.get("capabilities", {}),
client_info=params.get("clientInfo") or params.get("client_info", {}),
),
user,
).model_dump(by_alias=True, exclude_none=True)
if server_id:
tools = await tool_service.list_server_tools(db, server_id, cursor=cursor)
else:
tools = await tool_service.list_tools(db, cursor=cursor)
result = {"tools": [t.model_dump(by_alias=True, exclude_none=True) for t in tools]}
elif method == "list_gateways":
gateways = await gateway_service.list_gateways(db, include_inactive=False)
result = [g.model_dump(by_alias=True, exclude_none=True) for g in gateways]
result = {"gateways": [g.model_dump(by_alias=True, exclude_none=True) for g in gateways]}
elif method == "list_roots":
roots = await root_service.list_roots()
result = [r.model_dump(by_alias=True, exclude_none=True) for r in roots]
result = {"roots": [r.model_dump(by_alias=True, exclude_none=True) for r in roots]}
elif method == "resources/list":
resources = await resource_service.list_resources(db)
result = [r.model_dump(by_alias=True, exclude_none=True) for r in resources]
if server_id:
resources = await resource_service.list_server_resources(db, server_id)
else:
resources = await resource_service.list_resources(db)
result = {"resources": [r.model_dump(by_alias=True, exclude_none=True) for r in resources]}
elif method == "resources/read":
uri = params.get("uri")
request_id = params.get("requestId", None)
if not uri:
raise JSONRPCError(-32602, "Missing resource URI in parameters", params)
result = await resource_service.read_resource(db, uri, request_id=request_id, user=user)
if hasattr(result, "model_dump"):
result = {"contents": [result.model_dump(by_alias=True, exclude_none=True)]}
else:
result = {"contents": [result]}
elif method == "prompts/list":
prompts = await prompt_service.list_prompts(db, cursor=cursor)
result = [p.model_dump(by_alias=True, exclude_none=True) for p in prompts]
if server_id:
prompts = await prompt_service.list_server_prompts(db, server_id, cursor=cursor)
else:
prompts = await prompt_service.list_prompts(db, cursor=cursor)
result = {"prompts": [p.model_dump(by_alias=True, exclude_none=True) for p in prompts]}
elif method == "prompts/get":
name = params.get("name")
arguments = params.get("arguments", {})
Expand All @@ -2269,31 +2287,52 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user: str
elif method == "ping":
# Per the MCP spec, a ping returns an empty result.
result = {}
else:
elif method == "tools/call":
# Get request headers
headers = {k.lower(): v for k, v in request.headers.items()}
name = params.get("name")
arguments = params.get("arguments", {})
if not name:
raise JSONRPCError(-32602, "Missing tool name in parameters", params)
try:
result = await tool_service.invoke_tool(db=db, name=method, arguments=params, request_headers=headers)
result = await tool_service.invoke_tool(db=db, name=name, arguments=arguments, request_headers=headers)
if hasattr(result, "model_dump"):
result = result.model_dump(by_alias=True, exclude_none=True)
except ValueError:
result = await gateway_service.forward_request(db, method, params)
if hasattr(result, "model_dump"):
result = result.model_dump(by_alias=True, exclude_none=True)
# TODO: Implement methods
elif method == "resources/templates/list":
result = {}
elif method.startswith("roots/"):
result = {}
elif method.startswith("notifications/"):
result = {}
elif method.startswith("sampling/"):
result = {}
elif method.startswith("elicitation/"):
result = {}
elif method.startswith("completion/"):
result = {}
elif method.startswith("logging/"):
result = {}
else:
raise JSONRPCError(-32000, "Invalid method", params)

response = result
return response
return {"jsonrpc": "2.0", "result": result, "id": req_id}

except JSONRPCError as e:
return e.to_dict()
error = e.to_dict()
return {"jsonrpc": "2.0", "error": error["error"], "id": req_id}
except Exception as e:
if isinstance(e, ValueError):
return JSONResponse(content={"message": "Method invalid"}, status_code=422)
logger.error(f"RPC error: {str(e)}")
return {
"jsonrpc": "2.0",
"error": {"code": -32000, "message": "Internal error", "data": str(e)},
"id": body.get("id") if "body" in locals() else None,
"id": req_id,
}


Expand Down
4 changes: 2 additions & 2 deletions tests/e2e/test_main_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,7 +1324,7 @@ async def test_rpc_ping(self, client: AsyncClient, mock_auth):

assert response.status_code == 200
result = response.json()
assert result == {} # ping returns empty result
assert result == {"jsonrpc": "2.0", "result": {}, "id": "test-123"} # ping returns empty result

async def test_rpc_list_tools(self, client: AsyncClient, mock_auth):
"""Test POST /rpc - tools/list method."""
Expand All @@ -1334,7 +1334,7 @@ async def test_rpc_list_tools(self, client: AsyncClient, mock_auth):

assert response.status_code == 200
result = response.json()
assert isinstance(result, list)
assert isinstance(result.get("result", {}).get("tools"), list)

async def test_rpc_invalid_method(self, client: AsyncClient, mock_auth):
"""Test POST /rpc with invalid method."""
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,10 @@ def test_rpc_tool_invocation_flow(
"is_error": False,
}

rpc_body = {"jsonrpc": "2.0", "id": 7, "method": "test_tool", "params": {"foo": "bar"}}
rpc_body = {"jsonrpc": "2.0", "id": 7, "method": "tools/call", "params": {"name": "test_tool", "arguments": {"foo": "bar"}}}
resp = test_client.post("/rpc/", json=rpc_body, headers=auth_headers)
assert resp.status_code == 200
assert resp.json()["content"][0]["text"] == "ok"
assert resp.json()["result"]["content"][0]["text"] == "ok"
mock_invoke.assert_awaited_once_with(db=ANY, name="test_tool", arguments={"foo": "bar"}, request_headers=ANY)

# --------------------------------------------------------------------- #
Expand Down
Loading
Loading