Skip to content

Commit 80a7a1e

Browse files
committed
Update tests
1 parent fbecaf4 commit 80a7a1e

File tree

1 file changed

+26
-7
lines changed

1 file changed

+26
-7
lines changed

server/session_test.go

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package server
33
import (
44
"context"
55
"errors"
6+
"sync"
67
"testing"
78
"time"
89

@@ -361,7 +362,7 @@ func TestMCPServer_SendNotificationToSpecificClient(t *testing.T) {
361362
session3 := &sessionTestClient{
362363
sessionID: "session-3",
363364
notificationChannel: make(chan mcp.JSONRPCNotification, 10),
364-
initialized: false, // Not initialized
365+
initialized: false, // Not initialized - deliberately not calling Initialize()
365366
}
366367

367368
// Register sessions
@@ -408,12 +409,16 @@ func TestMCPServer_SendNotificationToSpecificClient(t *testing.T) {
408409

409410
func TestMCPServer_NotificationChannelBlocked(t *testing.T) {
410411
// Set up a hooks object to capture error notifications
412+
var mu sync.Mutex
411413
errorCaptured := false
412414
errorSessionID := ""
413415
errorMethod := ""
414416

415417
hooks := &Hooks{}
416418
hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) {
419+
mu.Lock()
420+
defer mu.Unlock()
421+
417422
errorCaptured = true
418423
// Extract session ID and method from the error message metadata
419424
if msgMap, ok := message.(map[string]interface{}); ok {
@@ -455,15 +460,23 @@ func TestMCPServer_NotificationChannelBlocked(t *testing.T) {
455460
time.Sleep(10 * time.Millisecond)
456461

457462
// Verify the error was logged via hooks
458-
assert.True(t, errorCaptured, "Error hook should have been called")
459-
assert.Equal(t, "blocked-session", errorSessionID, "Session ID should be captured in the error hook")
460-
assert.Equal(t, "blocked-message", errorMethod, "Method should be captured in the error hook")
463+
mu.Lock()
464+
localErrorCaptured := errorCaptured
465+
localErrorSessionID := errorSessionID
466+
localErrorMethod := errorMethod
467+
mu.Unlock()
468+
469+
assert.True(t, localErrorCaptured, "Error hook should have been called")
470+
assert.Equal(t, "blocked-session", localErrorSessionID, "Session ID should be captured in the error hook")
471+
assert.Equal(t, "blocked-message", localErrorMethod, "Method should be captured in the error hook")
461472

462473
// Also test SendNotificationToAllClients with a blocked channel
463474
// Reset the captured data
475+
mu.Lock()
464476
errorCaptured = false
465477
errorSessionID = ""
466478
errorMethod = ""
479+
mu.Unlock()
467480

468481
// Send to all clients (which includes our blocked one)
469482
server.SendNotificationToAllClients("broadcast-message", nil)
@@ -472,7 +485,13 @@ func TestMCPServer_NotificationChannelBlocked(t *testing.T) {
472485
time.Sleep(10 * time.Millisecond)
473486

474487
// Verify the error was logged via hooks
475-
assert.True(t, errorCaptured, "Error hook should have been called for broadcast")
476-
assert.Equal(t, "blocked-session", errorSessionID, "Session ID should be captured in the error hook")
477-
assert.Equal(t, "broadcast-message", errorMethod, "Method should be captured in the error hook")
488+
mu.Lock()
489+
localErrorCaptured = errorCaptured
490+
localErrorSessionID = errorSessionID
491+
localErrorMethod = errorMethod
492+
mu.Unlock()
493+
494+
assert.True(t, localErrorCaptured, "Error hook should have been called for broadcast")
495+
assert.Equal(t, "blocked-session", localErrorSessionID, "Session ID should be captured in the error hook")
496+
assert.Equal(t, "broadcast-message", localErrorMethod, "Method should be captured in the error hook")
478497
}

0 commit comments

Comments
 (0)