@@ -3,12 +3,16 @@ package client
33import (
44 "context"
55 "fmt"
6- "github.com/mark3labs/mcp-go/mcp"
7- "github.com/mark3labs/mcp-go/server"
6+ "sync"
87 "testing"
98 "time"
9+
10+ "github.com/mark3labs/mcp-go/client/transport"
11+ "github.com/mark3labs/mcp-go/mcp"
12+ "github.com/mark3labs/mcp-go/server"
1013)
1114
15+
1216func TestHTTPClient (t * testing.T ) {
1317 hooks := & server.Hooks {}
1418 hooks .AddAfterCallTool (func (ctx context.Context , id any , message * mcp.CallToolRequest , result * mcp.CallToolResult ) {
@@ -47,30 +51,46 @@ func TestHTTPClient(t *testing.T) {
4751 return nil , fmt .Errorf ("failed to send notification: %w" , err )
4852 }
4953
50- return & mcp.CallToolResult {
51- Content : []mcp.Content {
52- mcp.TextContent {
53- Type : "text" ,
54- Text : "notification sent successfully" ,
55- },
56- },
57- }, nil
54+ return mcp .NewToolResultText ("notification sent successfully" ), nil
5855 },
5956 )
6057
58+ addServerToolfunc := func (name string ) {
59+ mcpServer .AddTool (
60+ mcp .NewTool (name ),
61+ func (ctx context.Context , request mcp.CallToolRequest ) (* mcp.CallToolResult , error ) {
62+ server := server .ServerFromContext (ctx )
63+ server .SendNotificationToAllClients ("helloToEveryone" , map [string ]any {
64+ "message" : "hello" ,
65+ })
66+ return mcp .NewToolResultText ("done" ), nil
67+ },
68+ )
69+ }
70+
6171 testServer := server .NewTestStreamableHTTPServer (mcpServer )
6272 defer testServer .Close ()
6373
74+ initRequest := mcp.InitializeRequest {
75+ Params : mcp.InitializeParams {
76+ ProtocolVersion : mcp .LATEST_PROTOCOL_VERSION ,
77+ ClientInfo : mcp.Implementation {
78+ Name : "test-client2" ,
79+ Version : "1.0.0" ,
80+ },
81+ },
82+ }
83+
6484 t .Run ("Can receive notification from server" , func (t * testing.T ) {
6585 client , err := NewStreamableHttpClient (testServer .URL )
6686 if err != nil {
6787 t .Fatalf ("create client failed %v" , err )
6888 return
6989 }
7090
71- notificationNum := 0
91+ notificationNum := NewSafeMap ()
7292 client .OnNotification (func (notification mcp.JSONRPCNotification ) {
73- notificationNum += 1
93+ notificationNum . Increment ( notification . Method )
7494 })
7595
7696 ctx := context .Background ()
@@ -81,31 +101,122 @@ func TestHTTPClient(t *testing.T) {
81101 }
82102
83103 // Initialize
84- initRequest := mcp.InitializeRequest {}
85- initRequest .Params .ProtocolVersion = mcp .LATEST_PROTOCOL_VERSION
86- initRequest .Params .ClientInfo = mcp.Implementation {
87- Name : "test-client" ,
88- Version : "1.0.0" ,
89- }
90-
91104 _ , err = client .Initialize (ctx , initRequest )
92105 if err != nil {
93106 t .Fatalf ("Failed to initialize: %v\n " , err )
94107 }
95108
96- request := mcp.CallToolRequest {}
97- request .Params .Name = "notify"
98- result , err := client .CallTool (ctx , request )
99- if err != nil {
100- t .Fatalf ("CallTool failed: %v" , err )
101- }
109+ t .Run ("Can receive notifications related to the request" , func (t * testing.T ) {
110+ request := mcp.CallToolRequest {}
111+ request .Params .Name = "notify"
112+ result , err := client .CallTool (ctx , request )
113+ if err != nil {
114+ t .Fatalf ("CallTool failed: %v" , err )
115+ }
102116
103- if len (result .Content ) != 1 {
104- t .Errorf ("Expected 1 content item, got %d" , len (result .Content ))
105- }
117+ if len (result .Content ) != 1 {
118+ t .Errorf ("Expected 1 content item, got %d" , len (result .Content ))
119+ }
120+
121+ if n := notificationNum .Get ("notifications/progress" ); n != 1 {
122+ t .Errorf ("Expected 1 progross notification item, got %d" , n )
123+ }
124+ if n := notificationNum .Len (); n != 1 {
125+ t .Errorf ("Expected 1 type of notification, got %d" , n )
126+ }
127+ })
128+
129+ t .Run ("Can not receive global notifications from server by default" , func (t * testing.T ) {
130+ addServerToolfunc ("hello1" )
131+ time .Sleep (time .Millisecond * 50 )
132+
133+ helloNotifications := notificationNum .Get ("hello1" )
134+ if helloNotifications != 0 {
135+ t .Errorf ("Expected 0 notification item, got %d" , helloNotifications )
136+ }
137+ })
138+
139+ t .Run ("Can receive global notifications from server when WithContinuousListening enabled" , func (t * testing.T ) {
140+
141+ client , err := NewStreamableHttpClient (testServer .URL ,
142+ transport .WithContinuousListening ())
143+ if err != nil {
144+ t .Fatalf ("create client failed %v" , err )
145+ return
146+ }
147+ defer client .Close ()
148+
149+ notificationNum := NewSafeMap ()
150+ client .OnNotification (func (notification mcp.JSONRPCNotification ) {
151+ notificationNum .Increment (notification .Method )
152+ })
153+
154+ ctx := context .Background ()
155+
156+ if err := client .Start (ctx ); err != nil {
157+ t .Fatalf ("Failed to start client: %v" , err )
158+ return
159+ }
160+
161+ // Initialize
162+ _ , err = client .Initialize (ctx , initRequest )
163+ if err != nil {
164+ t .Fatalf ("Failed to initialize: %v\n " , err )
165+ }
166+
167+ // can receive normal notification
168+ request := mcp.CallToolRequest {}
169+ request .Params .Name = "notify"
170+ _ , err = client .CallTool (ctx , request )
171+ if err != nil {
172+ t .Fatalf ("CallTool failed: %v" , err )
173+ }
174+
175+ if n := notificationNum .Get ("notifications/progress" ); n != 1 {
176+ t .Errorf ("Expected 1 progross notification item, got %d" , n )
177+ }
178+ if n := notificationNum .Len (); n != 1 {
179+ t .Errorf ("Expected 1 type of notification, got %d" , n )
180+ }
181+
182+ // can receive global notification
183+ addServerToolfunc ("hello2" )
184+ time .Sleep (time .Millisecond * 50 ) // wait for the notification to be sent as upper action is async
185+
186+ n := notificationNum .Get ("notifications/tools/list_changed" )
187+ if n != 1 {
188+ t .Errorf ("Expected 1 notification item, got %d, %v" , n , notificationNum )
189+ }
190+ })
106191
107- if notificationNum != 1 {
108- t .Errorf ("Expected 1 notification item, got %d" , notificationNum )
109- }
110192 })
111193}
194+
195+ type SafeMap struct {
196+ mu sync.RWMutex
197+ data map [string ]int
198+ }
199+
200+ func NewSafeMap () * SafeMap {
201+ return & SafeMap {
202+ data : make (map [string ]int ),
203+ }
204+ }
205+
206+ func (sm * SafeMap ) Increment (key string ) {
207+ sm .mu .Lock ()
208+ defer sm .mu .Unlock ()
209+ sm .data [key ]++
210+ }
211+
212+ func (sm * SafeMap ) Get (key string ) int {
213+ sm .mu .RLock ()
214+ defer sm .mu .RUnlock ()
215+ return sm .data [key ]
216+ }
217+
218+ func (sm * SafeMap ) Len () int {
219+ sm .mu .RLock ()
220+ defer sm .mu .RUnlock ()
221+ return len (sm .data )
222+ }
0 commit comments