diff --git a/sentry_sdk/consts.py b/sentry_sdk/consts.py index 33900acd50..a5fc8e5a99 100644 --- a/sentry_sdk/consts.py +++ b/sentry_sdk/consts.py @@ -706,6 +706,12 @@ class SPANDATA: Example: 6379 """ + NETWORK_TRANSPORT = "network.transport" + """ + The transport protocol used for the network connection. + Example: "tcp", "udp", "unix" + """ + PROFILER_ID = "profiler_id" """ Label identifying the profiler id that the span occurred in. This should be a string. @@ -824,7 +830,7 @@ class SPANDATA: MCP_TRANSPORT = "mcp.transport" """ The transport method used for MCP communication. - Example: "pipe" (stdio), "tcp" (HTTP/WebSocket/SSE) + Example: "http", "sse", "stdio" """ MCP_SESSION_ID = "mcp.session.id" diff --git a/sentry_sdk/integrations/mcp.py b/sentry_sdk/integrations/mcp.py index 2a2d440616..7b72aa4763 100644 --- a/sentry_sdk/integrations/mcp.py +++ b/sentry_sdk/integrations/mcp.py @@ -56,17 +56,17 @@ def setup_once(): def _get_request_context_data(): # type: () -> tuple[Optional[str], Optional[str], str] """ - Extract request ID, session ID, and transport type from the MCP request context. + Extract request ID, session ID, and MCP transport type from the request context. Returns: - Tuple of (request_id, session_id, transport). + Tuple of (request_id, session_id, mcp_transport). - request_id: May be None if not available - session_id: May be None if not available - - transport: "tcp" for HTTP-based, "pipe" for stdio + - mcp_transport: "http", "sse", "stdio" """ request_id = None # type: Optional[str] session_id = None # type: Optional[str] - transport = "pipe" # type: str + mcp_transport = "stdio" # type: str try: ctx = request_ctx.get() @@ -74,16 +74,26 @@ def _get_request_context_data(): if ctx is not None: request_id = ctx.request_id if hasattr(ctx, "request") and ctx.request is not None: - transport = "tcp" request = ctx.request - if hasattr(request, "headers"): + # Detect transport type by checking request characteristics + if hasattr(request, "query_params") and request.query_params.get( + "session_id" + ): + # SSE transport uses query parameter + mcp_transport = "sse" + session_id = request.query_params.get("session_id") + elif hasattr(request, "headers") and request.headers.get( + "mcp-session-id" + ): + # StreamableHTTP transport uses header + mcp_transport = "http" session_id = request.headers.get("mcp-session-id") except LookupError: - # No request context available - default to pipe + # No request context available - default to stdio pass - return request_id, session_id, transport + return request_id, session_id, mcp_transport def _get_span_config(handler_type, item_name): @@ -120,16 +130,20 @@ def _set_span_input_data( arguments, request_id, session_id, - transport, + mcp_transport, ): # type: (Any, str, str, str, dict[str, Any], Optional[str], Optional[str], str) -> None """Set input span data for MCP handlers.""" + # Set handler identifier span.set_data(span_data_key, handler_name) span.set_data(SPANDATA.MCP_METHOD_NAME, mcp_method_name) - # Set transport type - span.set_data(SPANDATA.MCP_TRANSPORT, transport) + # Set transport/MCP transport type + span.set_data( + SPANDATA.NETWORK_TRANSPORT, "pipe" if mcp_transport == "stdio" else "tcp" + ) + span.set_data(SPANDATA.MCP_TRANSPORT, mcp_transport) # Set request_id if provided if request_id: @@ -331,7 +345,7 @@ async def _async_handler_wrapper(handler_type, func, original_args): origin=MCPIntegration.origin, ) as span: # Get request ID, session ID, and transport from context - request_id, session_id, transport = _get_request_context_data() + request_id, session_id, mcp_transport = _get_request_context_data() # Set input span data _set_span_input_data( @@ -342,7 +356,7 @@ async def _async_handler_wrapper(handler_type, func, original_args): arguments, request_id, session_id, - transport, + mcp_transport, ) # For resources, extract and set protocol @@ -396,7 +410,7 @@ def _sync_handler_wrapper(handler_type, func, original_args): origin=MCPIntegration.origin, ) as span: # Get request ID, session ID, and transport from context - request_id, session_id, transport = _get_request_context_data() + request_id, session_id, mcp_transport = _get_request_context_data() # Set input span data _set_span_input_data( @@ -407,7 +421,7 @@ def _sync_handler_wrapper(handler_type, func, original_args): arguments, request_id, session_id, - transport, + mcp_transport, ) # For resources, extract and set protocol diff --git a/tests/integrations/mcp/test_mcp.py b/tests/integrations/mcp/test_mcp.py index 738fdedf48..508aea5a3a 100644 --- a/tests/integrations/mcp/test_mcp.py +++ b/tests/integrations/mcp/test_mcp.py @@ -76,21 +76,29 @@ def __str__(self): class MockRequestContext: """Mock MCP request context""" - def __init__(self, request_id=None, session_id=None, transport="pipe"): + def __init__(self, request_id=None, session_id=None, transport="stdio"): self.request_id = request_id - if transport == "tcp": - self.request = MockHTTPRequest(session_id) + if transport in ("http", "sse"): + self.request = MockHTTPRequest(session_id, transport) else: self.request = None class MockHTTPRequest: - """Mock HTTP request for SSE/WebSocket transport""" + """Mock HTTP request for SSE/StreamableHTTP transport""" - def __init__(self, session_id=None): + def __init__(self, session_id=None, transport="http"): self.headers = {} - if session_id: - self.headers["mcp-session-id"] = session_id + self.query_params = {} + + if transport == "sse": + # SSE transport uses query parameter + if session_id: + self.query_params["session_id"] = session_id + else: + # StreamableHTTP transport uses header + if session_id: + self.headers["mcp-session-id"] = session_id class MockTextContent: @@ -151,7 +159,7 @@ def test_tool_handler_sync( server = Server("test-server") # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-123", transport="pipe") + mock_ctx = MockRequestContext(request_id="req-123", transport="stdio") request_ctx.set(mock_ctx) @server.call_tool() @@ -176,7 +184,7 @@ def test_tool(tool_name, arguments): # Check span data assert span["data"][SPANDATA.MCP_TOOL_NAME] == "calculate" assert span["data"][SPANDATA.MCP_METHOD_NAME] == "tools/call" - assert span["data"][SPANDATA.MCP_TRANSPORT] == "pipe" + assert span["data"][SPANDATA.MCP_TRANSPORT] == "stdio" assert span["data"][SPANDATA.MCP_REQUEST_ID] == "req-123" assert span["data"]["mcp.request.argument.x"] == "10" assert span["data"]["mcp.request.argument.y"] == "5" @@ -215,7 +223,7 @@ async def test_tool_handler_async( # Set up mock request context mock_ctx = MockRequestContext( - request_id="req-456", session_id="session-789", transport="tcp" + request_id="req-456", session_id="session-789", transport="http" ) request_ctx.set(mock_ctx) @@ -240,7 +248,7 @@ async def test_tool_async(tool_name, arguments): # Check span data assert span["data"][SPANDATA.MCP_TOOL_NAME] == "process" assert span["data"][SPANDATA.MCP_METHOD_NAME] == "tools/call" - assert span["data"][SPANDATA.MCP_TRANSPORT] == "tcp" + assert span["data"][SPANDATA.MCP_TRANSPORT] == "http" assert span["data"][SPANDATA.MCP_REQUEST_ID] == "req-456" assert span["data"][SPANDATA.MCP_SESSION_ID] == "session-789" assert span["data"]["mcp.request.argument.data"] == '"test"' @@ -265,7 +273,7 @@ def test_tool_handler_with_error(sentry_init, capture_events): server = Server("test-server") # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-error", transport="pipe") + mock_ctx = MockRequestContext(request_id="req-error", transport="stdio") request_ctx.set(mock_ctx) @server.call_tool() @@ -313,7 +321,7 @@ def test_prompt_handler_sync( server = Server("test-server") # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-prompt", transport="pipe") + mock_ctx = MockRequestContext(request_id="req-prompt", transport="stdio") request_ctx.set(mock_ctx) @server.get_prompt() @@ -338,7 +346,7 @@ def test_prompt(name, arguments): # Check span data assert span["data"][SPANDATA.MCP_PROMPT_NAME] == "code_help" assert span["data"][SPANDATA.MCP_METHOD_NAME] == "prompts/get" - assert span["data"][SPANDATA.MCP_TRANSPORT] == "pipe" + assert span["data"][SPANDATA.MCP_TRANSPORT] == "stdio" assert span["data"][SPANDATA.MCP_REQUEST_ID] == "req-prompt" assert span["data"]["mcp.request.argument.name"] == '"code_help"' assert span["data"]["mcp.request.argument.language"] == '"python"' @@ -378,7 +386,7 @@ async def test_prompt_handler_async( # Set up mock request context mock_ctx = MockRequestContext( - request_id="req-async-prompt", session_id="session-abc", transport="tcp" + request_id="req-async-prompt", session_id="session-abc", transport="http" ) request_ctx.set(mock_ctx) @@ -422,7 +430,7 @@ def test_prompt_handler_with_error(sentry_init, capture_events): server = Server("test-server") # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-error-prompt", transport="pipe") + mock_ctx = MockRequestContext(request_id="req-error-prompt", transport="stdio") request_ctx.set(mock_ctx) @server.get_prompt() @@ -452,7 +460,7 @@ def test_resource_handler_sync(sentry_init, capture_events): server = Server("test-server") # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-resource", transport="pipe") + mock_ctx = MockRequestContext(request_id="req-resource", transport="stdio") request_ctx.set(mock_ctx) @server.read_resource() @@ -477,7 +485,7 @@ def test_resource(uri): # Check span data assert span["data"][SPANDATA.MCP_RESOURCE_URI] == "file:///path/to/file.txt" assert span["data"][SPANDATA.MCP_METHOD_NAME] == "resources/read" - assert span["data"][SPANDATA.MCP_TRANSPORT] == "pipe" + assert span["data"][SPANDATA.MCP_TRANSPORT] == "stdio" assert span["data"][SPANDATA.MCP_REQUEST_ID] == "req-resource" assert span["data"][SPANDATA.MCP_RESOURCE_PROTOCOL] == "file" # Resources don't capture result content @@ -497,7 +505,7 @@ async def test_resource_handler_async(sentry_init, capture_events): # Set up mock request context mock_ctx = MockRequestContext( - request_id="req-async-resource", session_id="session-res", transport="tcp" + request_id="req-async-resource", session_id="session-res", transport="http" ) request_ctx.set(mock_ctx) @@ -535,7 +543,7 @@ def test_resource_handler_with_error(sentry_init, capture_events): server = Server("test-server") # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-error-resource", transport="pipe") + mock_ctx = MockRequestContext(request_id="req-error-resource", transport="stdio") request_ctx.set(mock_ctx) @server.read_resource() @@ -573,7 +581,7 @@ def test_tool_result_extraction_tuple( server = Server("test-server") # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-tuple", transport="pipe") + mock_ctx = MockRequestContext(request_id="req-tuple", transport="stdio") request_ctx.set(mock_ctx) @server.call_tool() @@ -621,7 +629,7 @@ def test_tool_result_extraction_unstructured( server = Server("test-server") # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-unstructured", transport="pipe") + mock_ctx = MockRequestContext(request_id="req-unstructured", transport="stdio") request_ctx.set(mock_ctx) @server.call_tool() @@ -678,7 +686,7 @@ def test_tool_no_ctx(tool_name, arguments): span = tx["spans"][0] # Transport defaults to "pipe" when no context - assert span["data"][SPANDATA.MCP_TRANSPORT] == "pipe" + assert span["data"][SPANDATA.MCP_TRANSPORT] == "stdio" # Request ID and Session ID should not be present assert SPANDATA.MCP_REQUEST_ID not in span["data"] assert SPANDATA.MCP_SESSION_ID not in span["data"] @@ -695,7 +703,7 @@ def test_span_origin(sentry_init, capture_events): server = Server("test-server") # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-origin", transport="pipe") + mock_ctx = MockRequestContext(request_id="req-origin", transport="stdio") request_ctx.set(mock_ctx) @server.call_tool() @@ -722,7 +730,7 @@ def test_multiple_handlers(sentry_init, capture_events): server = Server("test-server") # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-multi", transport="pipe") + mock_ctx = MockRequestContext(request_id="req-multi", transport="stdio") request_ctx.set(mock_ctx) @server.call_tool() @@ -774,7 +782,7 @@ def test_prompt_with_dict_result( server = Server("test-server") # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-dict-prompt", transport="pipe") + mock_ctx = MockRequestContext(request_id="req-dict-prompt", transport="stdio") request_ctx.set(mock_ctx) @server.get_prompt() @@ -818,7 +826,7 @@ def test_resource_without_protocol(sentry_init, capture_events): server = Server("test-server") # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-no-proto", transport="pipe") + mock_ctx = MockRequestContext(request_id="req-no-proto", transport="stdio") request_ctx.set(mock_ctx) @server.read_resource() @@ -848,7 +856,7 @@ def test_tool_with_complex_arguments(sentry_init, capture_events): server = Server("test-server") # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-complex", transport="pipe") + mock_ctx = MockRequestContext(request_id="req-complex", transport="stdio") request_ctx.set(mock_ctx) @server.call_tool() @@ -886,7 +894,7 @@ async def test_async_handlers_mixed(sentry_init, capture_events): server = Server("test-server") # Set up mock request context - mock_ctx = MockRequestContext(request_id="req-mixed", transport="pipe") + mock_ctx = MockRequestContext(request_id="req-mixed", transport="stdio") request_ctx.set(mock_ctx) @server.call_tool() @@ -909,3 +917,104 @@ async def async_tool(tool_name, arguments): # Both should be instrumented correctly assert all(span["op"] == OP.MCP_SERVER for span in tx["spans"]) + + +def test_sse_transport_detection(sentry_init, capture_events): + """Test that SSE transport is correctly detected via query parameter""" + sentry_init( + integrations=[MCPIntegration()], + traces_sample_rate=1.0, + ) + events = capture_events() + + server = Server("test-server") + + # Set up mock request context with SSE transport + mock_ctx = MockRequestContext( + request_id="req-sse", session_id="session-sse-123", transport="sse" + ) + request_ctx.set(mock_ctx) + + @server.call_tool() + def test_tool(tool_name, arguments): + return {"result": "success"} + + with start_transaction(name="mcp tx"): + result = test_tool("sse_tool", {}) + + assert result == {"result": "success"} + + (tx,) = events + span = tx["spans"][0] + + # Check that SSE transport is detected + assert span["data"][SPANDATA.MCP_TRANSPORT] == "sse" + assert span["data"][SPANDATA.NETWORK_TRANSPORT] == "tcp" + assert span["data"][SPANDATA.MCP_SESSION_ID] == "session-sse-123" + + +def test_streamable_http_transport_detection(sentry_init, capture_events): + """Test that StreamableHTTP transport is correctly detected via header""" + sentry_init( + integrations=[MCPIntegration()], + traces_sample_rate=1.0, + ) + events = capture_events() + + server = Server("test-server") + + # Set up mock request context with StreamableHTTP transport + mock_ctx = MockRequestContext( + request_id="req-http", session_id="session-http-456", transport="http" + ) + request_ctx.set(mock_ctx) + + @server.call_tool() + def test_tool(tool_name, arguments): + return {"result": "success"} + + with start_transaction(name="mcp tx"): + result = test_tool("http_tool", {}) + + assert result == {"result": "success"} + + (tx,) = events + span = tx["spans"][0] + + # Check that HTTP transport is detected + assert span["data"][SPANDATA.MCP_TRANSPORT] == "http" + assert span["data"][SPANDATA.NETWORK_TRANSPORT] == "tcp" + assert span["data"][SPANDATA.MCP_SESSION_ID] == "session-http-456" + + +def test_stdio_transport_detection(sentry_init, capture_events): + """Test that stdio transport is correctly detected when no HTTP request""" + sentry_init( + integrations=[MCPIntegration()], + traces_sample_rate=1.0, + ) + events = capture_events() + + server = Server("test-server") + + # Set up mock request context with stdio transport (no HTTP request) + mock_ctx = MockRequestContext(request_id="req-stdio", transport="stdio") + request_ctx.set(mock_ctx) + + @server.call_tool() + def test_tool(tool_name, arguments): + return {"result": "success"} + + with start_transaction(name="mcp tx"): + result = test_tool("stdio_tool", {}) + + assert result == {"result": "success"} + + (tx,) = events + span = tx["spans"][0] + + # Check that stdio transport is detected + assert span["data"][SPANDATA.MCP_TRANSPORT] == "stdio" + assert span["data"][SPANDATA.NETWORK_TRANSPORT] == "pipe" + # No session ID for stdio transport + assert SPANDATA.MCP_SESSION_ID not in span["data"]