diff --git a/server/sse.go b/server/sse.go index 7994c606e..94dee1926 100644 --- a/server/sse.go +++ b/server/sse.go @@ -457,7 +457,6 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { go func() { // Process message through MCPServer response := s.server.HandleMessage(ctx, rawMessage) - // Only send response if there is one (not for notifications) if response != nil { var message string @@ -465,7 +464,6 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { // If there is an error marshalling the response, send a generic error response log.Printf("failed to marshal response: %v", err) message = fmt.Sprintf("event: message\ndata: {\"error\": \"internal error\",\"jsonrpc\": \"2.0\", \"id\": null}\n\n") - return } else { message = fmt.Sprintf("event: message\ndata: %s\n\n", eventData) } diff --git a/server/sse_test.go b/server/sse_test.go index 75da1eac4..937dc2744 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -62,7 +62,7 @@ func TestSSEServer(t *testing.T) { defer sseResp.Body.Close() // Read the endpoint event - endpointEvent, err := readSeeEvent(sseResp) + endpointEvent, err := readSSEEvent(sseResp) if err != nil { t.Fatalf("Failed to read SSE response: %v", err) } @@ -195,7 +195,7 @@ func TestSSEServer(t *testing.T) { } defer resp.Body.Close() - endpointEvent, err = readSeeEvent(sseResp) + endpointEvent, err = readSSEEvent(sseResp) if err != nil { t.Fatalf("Failed to read SSE response: %v", err) } @@ -590,7 +590,7 @@ func TestSSEServer(t *testing.T) { defer sseResp.Body.Close() // Read the endpoint event - endpointEvent, err := readSeeEvent(sseResp) + endpointEvent, err := readSSEEvent(sseResp) if err != nil { t.Fatalf("Failed to read SSE response: %v", err) } @@ -632,16 +632,16 @@ func TestSSEServer(t *testing.T) { } // Verify response - endpointEvent, err = readSeeEvent(sseResp) + endpointEvent, err = readSSEEvent(sseResp) if err != nil { t.Fatalf("Failed to read SSE response: %v", err) } - respFromSee := strings.TrimSpace( + respFromSSE := strings.TrimSpace( strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0], ) var response map[string]interface{} - if err := json.NewDecoder(strings.NewReader(respFromSee)).Decode(&response); err != nil { + if err := json.NewDecoder(strings.NewReader(respFromSSE)).Decode(&response); err != nil { t.Fatalf("Failed to decode response: %v", err) } @@ -680,17 +680,17 @@ func TestSSEServer(t *testing.T) { } defer resp.Body.Close() - endpointEvent, err = readSeeEvent(sseResp) + endpointEvent, err = readSSEEvent(sseResp) if err != nil { t.Fatalf("Failed to read SSE response: %v", err) } - respFromSee = strings.TrimSpace( + respFromSSE = strings.TrimSpace( strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0], ) response = make(map[string]interface{}) - if err := json.NewDecoder(strings.NewReader(respFromSee)).Decode(&response); err != nil { + if err := json.NewDecoder(strings.NewReader(respFromSSE)).Decode(&response); err != nil { t.Fatalf("Failed to decode response: %v", err) } @@ -1140,7 +1140,7 @@ func TestSSEServer(t *testing.T) { registeredSession = s } }) - + mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks)) testServer := NewTestServer(mcpServer) defer testServer.Close() @@ -1153,7 +1153,7 @@ func TestSSEServer(t *testing.T) { defer sseResp.Body.Close() // Read the endpoint event to ensure session is established - _, err = readSeeEvent(sseResp) + _, err = readSSEEvent(sseResp) if err != nil { t.Fatalf("Failed to read SSE response: %v", err) } @@ -1240,9 +1240,87 @@ func TestSSEServer(t *testing.T) { t.Error("Expected final_tool to exist") } }) + + t.Run("TestServerResponseMarshalError", func(t *testing.T) { + mcpServer := NewMCPServer("test", "1.0.0", + WithResourceCapabilities(true, true), + WithHooks(&Hooks{ + OnAfterInitialize: []OnAfterInitializeFunc{ + func(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) { + result.Result.Meta = map[string]interface{}{"invalid": func() {}} // marshal will fail + }, + }, + }), + ) + testServer := NewTestServer(mcpServer) + defer testServer.Close() + + // Connect to SSE endpoint + sseResp, err := http.Get(fmt.Sprintf("%s/sse", testServer.URL)) + if err != nil { + t.Fatalf("Failed to connect to SSE endpoint: %v", err) + } + defer sseResp.Body.Close() + + // Read the endpoint event + endpointEvent, err := readSSEEvent(sseResp) + if err != nil { + t.Fatalf("Failed to read SSE response: %v", err) + } + if !strings.Contains(endpointEvent, "event: endpoint") { + t.Fatalf("Expected endpoint event, got: %s", endpointEvent) + } + + // Extract message endpoint URL + messageURL := strings.TrimSpace( + strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0], + ) + + // Send initialize request + initRequest := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "clientInfo": map[string]interface{}{ + "name": "test-client", + "version": "1.0.0", + }, + }, + } + + requestBody, err := json.Marshal(initRequest) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + + resp, err := http.Post( + messageURL, + "application/json", + bytes.NewBuffer(requestBody), + ) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusAccepted { + t.Errorf("Expected status 202, got %d", resp.StatusCode) + } + + endpointEvent, err = readSSEEvent(sseResp) + if err != nil { + t.Fatalf("Failed to read SSE response: %v", err) + } + + if !strings.Contains(endpointEvent, "\"id\": null") { + t.Errorf("Expected id to be null") + } + }) } -func readSeeEvent(sseResp *http.Response) (string, error) { +func readSSEEvent(sseResp *http.Response) (string, error) { buf := make([]byte, 1024) n, err := sseResp.Body.Read(buf) if err != nil {