Skip to content

Commit 1eddde7

Browse files
authored
feat: client-side streamable-http transport supports continuously listening (#317)
feat: client-side streamable-http transport supports continuously listening (#317)
1 parent 8f5b048 commit 1eddde7

File tree

4 files changed

+623
-145
lines changed

4 files changed

+623
-145
lines changed

client/http_test.go

Lines changed: 142 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,16 @@ package client
33
import (
44
"context"
55
"fmt"
6-
"github.com/mark3labs/mcp-go/mcp"
7-
"github.com/mark3labs/mcp-go/server"
6+
"sync"
87
"testing"
98
"time"
9+
10+
"github.com/mark3labs/mcp-go/client/transport"
11+
"github.com/mark3labs/mcp-go/mcp"
12+
"github.com/mark3labs/mcp-go/server"
1013
)
1114

15+
1216
func TestHTTPClient(t *testing.T) {
1317
hooks := &server.Hooks{}
1418
hooks.AddAfterCallTool(func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) {
@@ -47,30 +51,46 @@ func TestHTTPClient(t *testing.T) {
4751
return nil, fmt.Errorf("failed to send notification: %w", err)
4852
}
4953

50-
return &mcp.CallToolResult{
51-
Content: []mcp.Content{
52-
mcp.TextContent{
53-
Type: "text",
54-
Text: "notification sent successfully",
55-
},
56-
},
57-
}, nil
54+
return mcp.NewToolResultText("notification sent successfully"), nil
5855
},
5956
)
6057

58+
addServerToolfunc := func(name string) {
59+
mcpServer.AddTool(
60+
mcp.NewTool(name),
61+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
62+
server := server.ServerFromContext(ctx)
63+
server.SendNotificationToAllClients("helloToEveryone", map[string]any{
64+
"message": "hello",
65+
})
66+
return mcp.NewToolResultText("done"), nil
67+
},
68+
)
69+
}
70+
6171
testServer := server.NewTestStreamableHTTPServer(mcpServer)
6272
defer testServer.Close()
6373

74+
initRequest := mcp.InitializeRequest{
75+
Params: mcp.InitializeParams{
76+
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
77+
ClientInfo: mcp.Implementation{
78+
Name: "test-client2",
79+
Version: "1.0.0",
80+
},
81+
},
82+
}
83+
6484
t.Run("Can receive notification from server", func(t *testing.T) {
6585
client, err := NewStreamableHttpClient(testServer.URL)
6686
if err != nil {
6787
t.Fatalf("create client failed %v", err)
6888
return
6989
}
7090

71-
notificationNum := 0
91+
notificationNum := NewSafeMap()
7292
client.OnNotification(func(notification mcp.JSONRPCNotification) {
73-
notificationNum += 1
93+
notificationNum.Increment(notification.Method)
7494
})
7595

7696
ctx := context.Background()
@@ -81,31 +101,122 @@ func TestHTTPClient(t *testing.T) {
81101
}
82102

83103
// Initialize
84-
initRequest := mcp.InitializeRequest{}
85-
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
86-
initRequest.Params.ClientInfo = mcp.Implementation{
87-
Name: "test-client",
88-
Version: "1.0.0",
89-
}
90-
91104
_, err = client.Initialize(ctx, initRequest)
92105
if err != nil {
93106
t.Fatalf("Failed to initialize: %v\n", err)
94107
}
95108

96-
request := mcp.CallToolRequest{}
97-
request.Params.Name = "notify"
98-
result, err := client.CallTool(ctx, request)
99-
if err != nil {
100-
t.Fatalf("CallTool failed: %v", err)
101-
}
109+
t.Run("Can receive notifications related to the request", func(t *testing.T) {
110+
request := mcp.CallToolRequest{}
111+
request.Params.Name = "notify"
112+
result, err := client.CallTool(ctx, request)
113+
if err != nil {
114+
t.Fatalf("CallTool failed: %v", err)
115+
}
102116

103-
if len(result.Content) != 1 {
104-
t.Errorf("Expected 1 content item, got %d", len(result.Content))
105-
}
117+
if len(result.Content) != 1 {
118+
t.Errorf("Expected 1 content item, got %d", len(result.Content))
119+
}
120+
121+
if n := notificationNum.Get("notifications/progress"); n != 1 {
122+
t.Errorf("Expected 1 progross notification item, got %d", n)
123+
}
124+
if n := notificationNum.Len(); n != 1 {
125+
t.Errorf("Expected 1 type of notification, got %d", n)
126+
}
127+
})
128+
129+
t.Run("Can not receive global notifications from server by default", func(t *testing.T) {
130+
addServerToolfunc("hello1")
131+
time.Sleep(time.Millisecond * 50)
132+
133+
helloNotifications := notificationNum.Get("hello1")
134+
if helloNotifications != 0 {
135+
t.Errorf("Expected 0 notification item, got %d", helloNotifications)
136+
}
137+
})
138+
139+
t.Run("Can receive global notifications from server when WithContinuousListening enabled", func(t *testing.T) {
140+
141+
client, err := NewStreamableHttpClient(testServer.URL,
142+
transport.WithContinuousListening())
143+
if err != nil {
144+
t.Fatalf("create client failed %v", err)
145+
return
146+
}
147+
defer client.Close()
148+
149+
notificationNum := NewSafeMap()
150+
client.OnNotification(func(notification mcp.JSONRPCNotification) {
151+
notificationNum.Increment(notification.Method)
152+
})
153+
154+
ctx := context.Background()
155+
156+
if err := client.Start(ctx); err != nil {
157+
t.Fatalf("Failed to start client: %v", err)
158+
return
159+
}
160+
161+
// Initialize
162+
_, err = client.Initialize(ctx, initRequest)
163+
if err != nil {
164+
t.Fatalf("Failed to initialize: %v\n", err)
165+
}
166+
167+
// can receive normal notification
168+
request := mcp.CallToolRequest{}
169+
request.Params.Name = "notify"
170+
_, err = client.CallTool(ctx, request)
171+
if err != nil {
172+
t.Fatalf("CallTool failed: %v", err)
173+
}
174+
175+
if n := notificationNum.Get("notifications/progress"); n != 1 {
176+
t.Errorf("Expected 1 progross notification item, got %d", n)
177+
}
178+
if n := notificationNum.Len(); n != 1 {
179+
t.Errorf("Expected 1 type of notification, got %d", n)
180+
}
181+
182+
// can receive global notification
183+
addServerToolfunc("hello2")
184+
time.Sleep(time.Millisecond * 50) // wait for the notification to be sent as upper action is async
185+
186+
n := notificationNum.Get("notifications/tools/list_changed")
187+
if n != 1 {
188+
t.Errorf("Expected 1 notification item, got %d, %v", n, notificationNum)
189+
}
190+
})
106191

107-
if notificationNum != 1 {
108-
t.Errorf("Expected 1 notification item, got %d", notificationNum)
109-
}
110192
})
111193
}
194+
195+
type SafeMap struct {
196+
mu sync.RWMutex
197+
data map[string]int
198+
}
199+
200+
func NewSafeMap() *SafeMap {
201+
return &SafeMap{
202+
data: make(map[string]int),
203+
}
204+
}
205+
206+
func (sm *SafeMap) Increment(key string) {
207+
sm.mu.Lock()
208+
defer sm.mu.Unlock()
209+
sm.data[key]++
210+
}
211+
212+
func (sm *SafeMap) Get(key string) int {
213+
sm.mu.RLock()
214+
defer sm.mu.RUnlock()
215+
return sm.data[key]
216+
}
217+
218+
func (sm *SafeMap) Len() int {
219+
sm.mu.RLock()
220+
defer sm.mu.RUnlock()
221+
return len(sm.data)
222+
}

0 commit comments

Comments
 (0)