Skip to content

Commit 46d4bad

Browse files
Fixed Translate streamablehttp (IBM#729)
* Fixed streamable http support for translating stdio servers Signed-off-by: Keval Mahajan <mahajankeval23@gmail.com> * minor changes Signed-off-by: Keval Mahajan <mahajankeval23@gmail.com> * Fix bandit B112 warning: Use specific exception types - Replace generic Exception with specific json.JSONDecodeError and ValueError - Addresses bandit security scan warning about try/except/continue pattern Signed-off-by: Mihai Criveti <crivetimihai@gmail.com> * Fix pre-commit hook to exclude test helper files - Update name-tests-test hook to exclude files in helpers/ directories - Prevents false positives for helper modules like trace_generator.py - Keeps the test naming convention check for actual test files Signed-off-by: Mihai Criveti <crivetimihai@gmail.com> --------- Signed-off-by: Keval Mahajan <mahajankeval23@gmail.com> Signed-off-by: Mihai Criveti <crivetimihai@gmail.com> Co-authored-by: Mihai Criveti <crivetimihai@gmail.com>
1 parent 165eb0f commit 46d4bad

File tree

2 files changed

+112
-34
lines changed

2 files changed

+112
-34
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ repos:
368368
description: Verifies test files in tests/ directories start with `test_`.
369369
language: python
370370
files: (^|/)tests/.+\.py$
371-
exclude: ^tests/.*/pages/.*\.py$ # Exclude page object files
371+
exclude: ^tests/.*/(pages|helpers)/.*\.py$ # Exclude page object and helper files
372372
args: [--pytest-test-first] # `test_.*\.py`
373373

374374
# - repo: https://github.com/pycqa/flake8

mcpgateway/translate.py

Lines changed: 111 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@
127127
# Third-Party
128128
from fastapi import FastAPI, Request, Response, status
129129
from fastapi.middleware.cors import CORSMiddleware
130-
from fastapi.responses import PlainTextResponse
130+
from fastapi.responses import JSONResponse, PlainTextResponse
131131
from sse_starlette.sse import EventSourceResponse
132132
import uvicorn
133133

@@ -1633,13 +1633,13 @@ async def _run_multi_protocol_server( # pylint: disable=too-many-positional-arg
16331633
LOGGER.info(f"Starting multi-protocol server for command: {cmd}")
16341634
LOGGER.info(f"Protocols: SSE={expose_sse}, StreamableHTTP={expose_streamable_http}")
16351635

1636-
# Create the pubsub for SSE if needed
1637-
pubsub = _PubSub() if expose_sse else None
1636+
# Create a shared pubsub whenever either protocol needs stdout observations
1637+
pubsub = _PubSub() if (expose_sse or expose_streamable_http) else None
16381638

16391639
# Create the stdio endpoint
1640-
stdio = StdIOEndpoint(cmd, pubsub) if expose_sse else None
1640+
stdio = StdIOEndpoint(cmd, pubsub) if (expose_sse or expose_streamable_http) else None
16411641

1642-
# Create the FastAPI app
1642+
# Create fastapi app and middleware
16431643
app = FastAPI()
16441644

16451645
# Add CORS middleware if specified
@@ -1652,10 +1652,13 @@ async def _run_multi_protocol_server( # pylint: disable=too-many-positional-arg
16521652
allow_headers=["*"],
16531653
)
16541654

1655-
# Add SSE endpoints if requested
1656-
if expose_sse and stdio:
1655+
# Start stdio if at least one transport requires it
1656+
if stdio:
16571657
await stdio.start()
16581658

1659+
# SSE endpoints
1660+
if expose_sse and stdio and pubsub:
1661+
16591662
@app.get(sse_path)
16601663
async def get_sse(request: Request) -> EventSourceResponse:
16611664
"""SSE endpoint.
@@ -1746,9 +1749,14 @@ async def health() -> Response:
17461749
"""
17471750
return PlainTextResponse("ok")
17481751

1749-
# Add streamable HTTP endpoint if requested
1752+
# Streamable HTTP support
17501753
streamable_server = None
17511754
streamable_manager = None
1755+
streamable_context = None
1756+
1757+
# Keep a reference to the original FastAPI app so we can wrap it with an ASGI
1758+
# layer that delegates `/mcp` scopes to the StreamableHTTPSessionManager if present.
1759+
original_app = app
17521760

17531761
if expose_streamable_http:
17541762
# Create an MCP server instance
@@ -1761,39 +1769,110 @@ async def health() -> Response:
17611769
json_response=json_response,
17621770
)
17631771

1764-
# Store the original app before modifying
1765-
original_app = app
1772+
# Register POST /mcp on the FastAPI app as the canonical client->server POST
1773+
# path for Streamable HTTP. This forwards JSON-RPC notifications/requests to stdio.
1774+
@original_app.post("/mcp")
1775+
async def mcp_post(request: Request) -> Response:
1776+
"""
1777+
Handles POST requests to the `/mcp` endpoint, forwarding JSON payloads to stdio
1778+
and optionally waiting for a correlated response.
17661779
1767-
# Create a custom middleware for handling MCP requests
1768-
async def mcp_middleware(scope, receive, send):
1769-
"""Middleware to handle MCP requests via streamable HTTP.
1780+
The request body is expected to be a JSON object or newline-delimited JSON.
1781+
If the JSON includes an "id" field, the function attempts to match it with
1782+
a response from stdio using a pubsub queue, within a timeout period.
17701783
17711784
Args:
1772-
scope: ASGI scope dictionary.
1773-
receive: ASGI receive callable.
1774-
send: ASGI send callable.
1785+
request (Request): The incoming FastAPI request containing the JSON payload.
17751786
1776-
Examples:
1777-
>>> async def test_middleware():
1778-
... scope = {"type": "http", "path": "/mcp"}
1779-
... async def receive(): return {}
1780-
... async def send(msg): return None
1781-
... # Would route to streamable_manager for /mcp
1782-
... return scope["path"] == "/mcp"
1783-
>>> import asyncio
1784-
>>> asyncio.run(test_middleware())
1787+
Returns:
1788+
Response: A FastAPI Response object.
1789+
- 200 OK with matched JSON response if correlation succeeds.
1790+
- 202 Accepted if no matching response is received in time or for notifications.
1791+
- 400 Bad Request if the payload is not valid JSON.
1792+
1793+
Example:
1794+
>>> import httpx
1795+
>>> response = httpx.post("http://localhost:8000/mcp", json={"id": 123, "method": "ping"})
1796+
>>> response.status_code in (200, 202)
17851797
True
1798+
>>> response.text # May be the matched JSON or "accepted"
1799+
'{"id": 123, "result": "pong"}' # or "accepted"
17861800
"""
1787-
if scope["type"] == "http" and scope["path"] == "/mcp":
1788-
await streamable_manager.handle_request(scope, receive, send)
1801+
# Read and validate JSON
1802+
body = await request.body()
1803+
try:
1804+
obj = json.loads(body)
1805+
except Exception as exc:
1806+
return PlainTextResponse(f"Invalid JSON payload: {exc}", status_code=status.HTTP_400_BAD_REQUEST)
1807+
1808+
# Forward raw newline-delimited JSON to stdio
1809+
await stdio.send(body.decode().rstrip() + "\n")
1810+
1811+
# If it's a request (has an id) -> attempt to correlate response from stdio
1812+
if isinstance(obj, dict) and "id" in obj:
1813+
if not pubsub:
1814+
return PlainTextResponse("accepted", status_code=status.HTTP_202_ACCEPTED)
1815+
1816+
queue = pubsub.subscribe()
1817+
try:
1818+
timeout = 10.0 # seconds; tuneable
1819+
deadline = asyncio.get_event_loop().time() + timeout
1820+
while True:
1821+
remaining = max(0.0, deadline - asyncio.get_event_loop().time())
1822+
if remaining == 0:
1823+
break
1824+
try:
1825+
msg = await asyncio.wait_for(queue.get(), timeout=remaining)
1826+
except asyncio.TimeoutError:
1827+
break
1828+
1829+
# stdio stdout lines may contain JSON objects or arrays
1830+
try:
1831+
parsed = json.loads(msg)
1832+
except (json.JSONDecodeError, ValueError):
1833+
# not JSON -> skip
1834+
continue
1835+
1836+
candidates = parsed if isinstance(parsed, list) else [parsed]
1837+
for candidate in candidates:
1838+
if isinstance(candidate, dict) and candidate.get("id") == obj.get("id"):
1839+
# return the matched response as JSON
1840+
return JSONResponse(candidate)
1841+
1842+
# timeout -> accept and return 202
1843+
return PlainTextResponse("accepted (no response yet)", status_code=status.HTTP_202_ACCEPTED)
1844+
finally:
1845+
pubsub.unsubscribe(queue)
1846+
1847+
# Notification -> return 202
1848+
return PlainTextResponse("accepted", status_code=status.HTTP_202_ACCEPTED)
1849+
1850+
# ASGI wrapper to route GET/other /mcp scopes to streamable_manager.handle_request
1851+
async def mcp_asgi_wrapper(scope, receive, send):
1852+
"""
1853+
ASGI middleware that intercepts HTTP requests to the `/mcp` endpoint.
1854+
1855+
If the request is an HTTP call to `/mcp` and a `streamable_manager` is available,
1856+
it can handle the request (currently commented out). All other requests are
1857+
passed to the original FastAPI application.
1858+
1859+
Args:
1860+
scope (dict): The ASGI scope dictionary containing request metadata.
1861+
receive (Callable): An awaitable that yields incoming ASGI events.
1862+
send (Callable): An awaitable used to send ASGI events.
1863+
"""
1864+
if scope.get("type") == "http" and scope.get("path") == "/mcp" and streamable_manager:
1865+
# Let StreamableHTTPSessionManager handle session-oriented streaming
1866+
# await streamable_manager.handle_request(scope, receive, send)
1867+
await original_app(scope, receive, send)
17891868
else:
1790-
# Pass through to the original app for other routes
1869+
# Delegate everything else to the original FastAPI app
17911870
await original_app(scope, receive, send)
17921871

1793-
# Replace the app with our middleware wrapper
1794-
app = mcp_middleware
1872+
# Replace the app used by uvicorn with the ASGI wrapper
1873+
app = mcp_asgi_wrapper
17951874

1796-
# Run the server
1875+
# ---------------------- Server lifecycle ----------------------
17971876
config = uvicorn.Config(
17981877
app,
17991878
host=host,
@@ -1821,8 +1900,7 @@ async def _shutdown() -> None:
18211900
with suppress(NotImplementedError):
18221901
loop.add_signal_handler(sig, lambda: asyncio.create_task(_shutdown()))
18231902

1824-
# Start streamable HTTP manager if needed
1825-
streamable_context = None
1903+
# If we have a streamable manager, start its context so it can accept ASGI /mcp
18261904
if streamable_manager:
18271905
streamable_context = streamable_manager.run()
18281906
await streamable_context.__aenter__() # pylint: disable=unnecessary-dunder-call,no-member

0 commit comments

Comments
 (0)