Skip to content

Commit 37d5406

Browse files
committed
Add tests for custom context functions
1 parent d753c31 commit 37d5406

File tree

3 files changed

+314
-1
lines changed

3 files changed

+314
-1
lines changed

server/sse.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,14 @@ func (s *SSEServer) SetContextFunc(fn SSEContextFunc) {
105105
s.contextFunc = fn
106106
}
107107

108+
type sseServerOpt func(sseServer *SSEServer)
109+
108110
// NewTestServer creates a test server for testing purposes
109-
func NewTestServer(server *MCPServer) *httptest.Server {
111+
func NewTestServer(server *MCPServer, opts ...sseServerOpt) *httptest.Server {
110112
sseServer := NewSSEServer(server)
113+
for _, opt := range opts {
114+
opt(sseServer)
115+
}
111116

112117
testServer := httptest.NewServer(sseServer)
113118
sseServer.baseURL = testServer.URL

server/sse_test.go

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ import (
1111
"sync"
1212
"testing"
1313
"time"
14+
15+
"github.com/mark3labs/mcp-go/mcp"
1416
)
1517

1618
func TestSSEServer(t *testing.T) {
@@ -468,4 +470,150 @@ func TestSSEServer(t *testing.T) {
468470
cancel()
469471
<-done
470472
})
473+
474+
t.Run("Can use a custom context function", func(t *testing.T) {
475+
// Use a custom context key to store a test value.
476+
type testContextKey struct{}
477+
testValFromContext := func(ctx context.Context) string {
478+
val := ctx.Value(testContextKey{})
479+
if val == nil {
480+
return ""
481+
}
482+
return val.(string)
483+
}
484+
// Create a context function that sets a test value from the request.
485+
// In real life this could be used to send configuration using headers
486+
// or query parameters.
487+
const testHeader = "X-Test-Header"
488+
setTestValFromRequest := func(ctx context.Context, r *http.Request) context.Context {
489+
return context.WithValue(ctx, testContextKey{}, r.Header.Get(testHeader))
490+
}
491+
492+
mcpServer := NewMCPServer("test", "1.0.0",
493+
WithResourceCapabilities(true, true),
494+
)
495+
// Add a tool which uses the context function.
496+
mcpServer.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
497+
// Note this is agnostic to the transport type i.e. doesn't know about request headers.
498+
testVal := testValFromContext(ctx)
499+
return mcp.NewToolResultText(testVal), nil
500+
})
501+
502+
testServer := NewTestServer(mcpServer, func(sseServer *SSEServer) {
503+
sseServer.contextFunc = setTestValFromRequest
504+
})
505+
defer testServer.Close()
506+
507+
// Connect to SSE endpoint
508+
sseResp, err := http.Get(fmt.Sprintf("%s/sse", testServer.URL))
509+
if err != nil {
510+
t.Fatalf("Failed to connect to SSE endpoint: %v", err)
511+
}
512+
defer sseResp.Body.Close()
513+
514+
// Read the endpoint event
515+
buf := make([]byte, 1024)
516+
n, err := sseResp.Body.Read(buf)
517+
if err != nil {
518+
t.Fatalf("Failed to read SSE response: %v", err)
519+
}
520+
521+
endpointEvent := string(buf[:n])
522+
messageURL := strings.TrimSpace(
523+
strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0],
524+
)
525+
526+
// Send initialize request
527+
initRequest := map[string]interface{}{
528+
"jsonrpc": "2.0",
529+
"id": 1,
530+
"method": "initialize",
531+
"params": map[string]interface{}{
532+
"protocolVersion": "2024-11-05",
533+
"clientInfo": map[string]interface{}{
534+
"name": "test-client",
535+
"version": "1.0.0",
536+
},
537+
},
538+
}
539+
540+
requestBody, err := json.Marshal(initRequest)
541+
if err != nil {
542+
t.Fatalf("Failed to marshal request: %v", err)
543+
}
544+
545+
resp, err := http.Post(
546+
messageURL,
547+
"application/json",
548+
bytes.NewBuffer(requestBody),
549+
)
550+
551+
if err != nil {
552+
t.Fatalf("Failed to send message: %v", err)
553+
}
554+
defer resp.Body.Close()
555+
556+
if resp.StatusCode != http.StatusAccepted {
557+
t.Errorf("Expected status 202, got %d", resp.StatusCode)
558+
}
559+
560+
// Verify response
561+
var response map[string]interface{}
562+
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
563+
t.Fatalf("Failed to decode response: %v", err)
564+
}
565+
566+
if response["jsonrpc"] != "2.0" {
567+
t.Errorf("Expected jsonrpc 2.0, got %v", response["jsonrpc"])
568+
}
569+
if response["id"].(float64) != 1 {
570+
t.Errorf("Expected id 1, got %v", response["id"])
571+
}
572+
573+
// Call the tool.
574+
toolRequest := map[string]interface{}{
575+
"jsonrpc": "2.0",
576+
"id": 2,
577+
"method": "tools/call",
578+
"params": map[string]interface{}{
579+
"name": "test_tool",
580+
},
581+
}
582+
requestBody, err = json.Marshal(toolRequest)
583+
if err != nil {
584+
t.Fatalf("Failed to marshal tool request: %v", err)
585+
}
586+
587+
req, err := http.NewRequest(http.MethodPost, messageURL, bytes.NewBuffer(requestBody))
588+
if err != nil {
589+
t.Fatalf("Failed to create tool request: %v", err)
590+
}
591+
// Set the test header to a custom value.
592+
req.Header.Set(testHeader, "test_value")
593+
594+
resp, err = http.DefaultClient.Do(req)
595+
if err != nil {
596+
t.Fatalf("Failed to call tool: %v", err)
597+
}
598+
defer resp.Body.Close()
599+
600+
response = make(map[string]interface{})
601+
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
602+
t.Fatalf("Failed to decode response: %v", err)
603+
}
604+
605+
if response["jsonrpc"] != "2.0" {
606+
t.Errorf("Expected jsonrpc 2.0, got %v", response["jsonrpc"])
607+
}
608+
if response["id"].(float64) != 2 {
609+
t.Errorf("Expected id 2, got %v", response["id"])
610+
}
611+
if response["result"].(map[string]interface{})["content"].([]interface{})[0].(map[string]interface{})["text"] != "test_value" {
612+
t.Errorf("Expected result 'test_value', got %v", response["result"])
613+
}
614+
if response["error"] != nil {
615+
t.Errorf("Expected no error, got %v", response["error"])
616+
}
617+
})
618+
471619
}

server/stdio_test.go

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ import (
66
"encoding/json"
77
"io"
88
"log"
9+
"os"
910
"testing"
11+
12+
"github.com/mark3labs/mcp-go/mcp"
1013
)
1114

1215
func TestStdioServer(t *testing.T) {
@@ -110,4 +113,161 @@ func TestStdioServer(t *testing.T) {
110113
t.Errorf("unexpected server error: %v", err)
111114
}
112115
})
116+
117+
t.Run("Can use a custom context function", func(t *testing.T) {
118+
// Use a custom context key to store a test value.
119+
type testContextKey struct{}
120+
testValFromContext := func(ctx context.Context) string {
121+
val := ctx.Value(testContextKey{})
122+
if val == nil {
123+
return ""
124+
}
125+
return val.(string)
126+
}
127+
// Create a context function that sets a test value from the environment.
128+
// In real life this could be used to send configuration in a similar way,
129+
// or from a config file.
130+
const testEnvVar = "TEST_ENV_VAR"
131+
setTestValFromEnv := func(ctx context.Context) context.Context {
132+
return context.WithValue(ctx, testContextKey{}, os.Getenv(testEnvVar))
133+
}
134+
os.Setenv(testEnvVar, "test_value")
135+
t.Cleanup(func() {
136+
os.Unsetenv(testEnvVar)
137+
})
138+
139+
// Create pipes for stdin and stdout
140+
stdinReader, stdinWriter := io.Pipe()
141+
stdoutReader, stdoutWriter := io.Pipe()
142+
143+
// Create server
144+
mcpServer := NewMCPServer("test", "1.0.0")
145+
// Add a tool which uses the context function.
146+
mcpServer.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
147+
// Note this is agnostic to the transport type i.e. doesn't know about request headers.
148+
testVal := testValFromContext(ctx)
149+
return mcp.NewToolResultText(testVal), nil
150+
})
151+
stdioServer := NewStdioServer(mcpServer)
152+
stdioServer.SetErrorLogger(log.New(io.Discard, "", 0))
153+
stdioServer.SetContextFunc(setTestValFromEnv)
154+
155+
// Create context with cancel
156+
ctx, cancel := context.WithCancel(context.Background())
157+
defer cancel()
158+
159+
// Create error channel to catch server errors
160+
serverErrCh := make(chan error, 1)
161+
162+
// Start server in goroutine
163+
go func() {
164+
err := stdioServer.Listen(ctx, stdinReader, stdoutWriter)
165+
if err != nil && err != io.EOF && err != context.Canceled {
166+
serverErrCh <- err
167+
}
168+
close(serverErrCh)
169+
}()
170+
171+
// Create test message
172+
initRequest := map[string]interface{}{
173+
"jsonrpc": "2.0",
174+
"id": 1,
175+
"method": "initialize",
176+
"params": map[string]interface{}{
177+
"protocolVersion": "2024-11-05",
178+
"clientInfo": map[string]interface{}{
179+
"name": "test-client",
180+
"version": "1.0.0",
181+
},
182+
},
183+
}
184+
185+
// Send request
186+
requestBytes, err := json.Marshal(initRequest)
187+
if err != nil {
188+
t.Fatal(err)
189+
}
190+
_, err = stdinWriter.Write(append(requestBytes, '\n'))
191+
if err != nil {
192+
t.Fatal(err)
193+
}
194+
195+
// Read response
196+
scanner := bufio.NewScanner(stdoutReader)
197+
if !scanner.Scan() {
198+
t.Fatal("failed to read response")
199+
}
200+
responseBytes := scanner.Bytes()
201+
202+
var response map[string]interface{}
203+
if err := json.Unmarshal(responseBytes, &response); err != nil {
204+
t.Fatalf("failed to unmarshal response: %v", err)
205+
}
206+
207+
// Verify response structure
208+
if response["jsonrpc"] != "2.0" {
209+
t.Errorf("expected jsonrpc version 2.0, got %v", response["jsonrpc"])
210+
}
211+
if response["id"].(float64) != 1 {
212+
t.Errorf("expected id 1, got %v", response["id"])
213+
}
214+
if response["error"] != nil {
215+
t.Errorf("unexpected error in response: %v", response["error"])
216+
}
217+
if response["result"] == nil {
218+
t.Error("expected result in response")
219+
}
220+
221+
// Call the tool.
222+
toolRequest := map[string]interface{}{
223+
"jsonrpc": "2.0",
224+
"id": 2,
225+
"method": "tools/call",
226+
"params": map[string]interface{}{
227+
"name": "test_tool",
228+
},
229+
}
230+
requestBytes, err = json.Marshal(toolRequest)
231+
if err != nil {
232+
t.Fatalf("Failed to marshal tool request: %v", err)
233+
}
234+
235+
_, err = stdinWriter.Write(append(requestBytes, '\n'))
236+
if err != nil {
237+
t.Fatal(err)
238+
}
239+
240+
if !scanner.Scan() {
241+
t.Fatal("failed to read response")
242+
}
243+
responseBytes = scanner.Bytes()
244+
245+
response = map[string]interface{}{}
246+
if err := json.Unmarshal(responseBytes, &response); err != nil {
247+
t.Fatalf("failed to unmarshal response: %v", err)
248+
}
249+
250+
if response["jsonrpc"] != "2.0" {
251+
t.Errorf("Expected jsonrpc 2.0, got %v", response["jsonrpc"])
252+
}
253+
if response["id"].(float64) != 2 {
254+
t.Errorf("Expected id 2, got %v", response["id"])
255+
}
256+
if response["result"].(map[string]interface{})["content"].([]interface{})[0].(map[string]interface{})["text"] != "test_value" {
257+
t.Errorf("Expected result 'test_value', got %v", response["result"])
258+
}
259+
if response["error"] != nil {
260+
t.Errorf("Expected no error, got %v", response["error"])
261+
}
262+
263+
// Clean up
264+
cancel()
265+
stdinWriter.Close()
266+
stdoutWriter.Close()
267+
268+
// Check for server errors
269+
if err := <-serverErrCh; err != nil {
270+
t.Errorf("unexpected server error: %v", err)
271+
}
272+
})
113273
}

0 commit comments

Comments
 (0)