@@ -12,6 +12,8 @@ import (
1212 "github.com/mark3labs/mcp-go/server"
1313)
1414
15+ // SafeMap is a thread-safe map wrapper
16+
1517func TestHTTPClient (t * testing.T ) {
1618 hooks := & server.Hooks {}
1719 hooks .AddAfterCallTool (func (ctx context.Context , id any , message * mcp.CallToolRequest , result * mcp.CallToolResult ) {
@@ -87,12 +89,9 @@ func TestHTTPClient(t *testing.T) {
8789 return
8890 }
8991
90- notificationNum := make (map [string ]int )
91- notificationNumMutex := sync.Mutex {}
92+ notificationNum := NewSafeMap ()
9293 client .OnNotification (func (notification mcp.JSONRPCNotification ) {
93- notificationNumMutex .Lock ()
94- notificationNum [notification .Method ] += 1
95- notificationNumMutex .Unlock ()
94+ notificationNum .Increment (notification .Method )
9695 })
9796
9897 ctx := context .Background ()
@@ -120,19 +119,21 @@ func TestHTTPClient(t *testing.T) {
120119 t .Errorf ("Expected 1 content item, got %d" , len (result .Content ))
121120 }
122121
123- if n := notificationNum [ "notifications/progress" ] ; n != 1 {
122+ if n := notificationNum . Get ( "notifications/progress" ) ; n != 1 {
124123 t .Errorf ("Expected 1 progross notification item, got %d" , n )
125124 }
126- if n := len ( notificationNum ); n != 1 {
125+ if n := notificationNum . Len ( ); n != 1 {
127126 t .Errorf ("Expected 1 type of notification, got %d" , n )
128127 }
129128 })
130129
131- t .Run ("Cannot receive global notifications from server by default" , func (t * testing.T ) {
130+ t .Run ("Can not receive global notifications from server by default" , func (t * testing.T ) {
132131 addServerToolfunc ("hello1" )
133132 time .Sleep (time .Millisecond * 50 )
134- if n := notificationNum ["hello1" ]; n != 0 {
135- t .Errorf ("Expected 0 notification item, got %d" , n )
133+
134+ helloNotifications := notificationNum .Get ("hello1" )
135+ if helloNotifications != 0 {
136+ t .Errorf ("Expected 0 notification item, got %d" , helloNotifications )
136137 }
137138 })
138139
@@ -146,13 +147,9 @@ func TestHTTPClient(t *testing.T) {
146147 }
147148 defer client .Close ()
148149
149- notificationNum := make (map [string ]int )
150- notificationNumMutex := sync.Mutex {}
150+ notificationNum := NewSafeMap ()
151151 client .OnNotification (func (notification mcp.JSONRPCNotification ) {
152- notificationNumMutex .Lock ()
153- println (notification .Method )
154- notificationNum [notification .Method ] += 1
155- notificationNumMutex .Unlock ()
152+ notificationNum .Increment (notification .Method )
156153 })
157154
158155 ctx := context .Background ()
@@ -176,20 +173,51 @@ func TestHTTPClient(t *testing.T) {
176173 t .Fatalf ("CallTool failed: %v" , err )
177174 }
178175
179- if n := notificationNum [ "notifications/progress" ] ; n != 1 {
176+ if n := notificationNum . Get ( "notifications/progress" ) ; n != 1 {
180177 t .Errorf ("Expected 1 progross notification item, got %d" , n )
181178 }
182- if n := len ( notificationNum ); n != 1 {
179+ if n := notificationNum . Len ( ); n != 1 {
183180 t .Errorf ("Expected 1 type of notification, got %d" , n )
184181 }
185182
186183 // can receive global notification
187184 addServerToolfunc ("hello2" )
188185 time .Sleep (time .Millisecond * 50 ) // wait for the notification to be sent as upper action is async
189- if n := notificationNum ["notifications/tools/list_changed" ]; n != 1 {
186+
187+ n := notificationNum .Get ("notifications/tools/list_changed" )
188+ if n != 1 {
190189 t .Errorf ("Expected 1 notification item, got %d, %v" , n , notificationNum )
191190 }
192191 })
193192
194193 })
195194}
195+
196+ type SafeMap struct {
197+ mu sync.RWMutex
198+ data map [string ]int
199+ }
200+
201+ func NewSafeMap () * SafeMap {
202+ return & SafeMap {
203+ data : make (map [string ]int ),
204+ }
205+ }
206+
207+ func (sm * SafeMap ) Increment (key string ) {
208+ sm .mu .Lock ()
209+ defer sm .mu .Unlock ()
210+ sm .data [key ]++
211+ }
212+
213+ func (sm * SafeMap ) Get (key string ) int {
214+ sm .mu .RLock ()
215+ defer sm .mu .RUnlock ()
216+ return sm .data [key ]
217+ }
218+
219+ func (sm * SafeMap ) Len () int {
220+ sm .mu .RLock ()
221+ defer sm .mu .RUnlock ()
222+ return len (sm .data )
223+ }
0 commit comments