@@ -11,6 +11,8 @@ import (
1111 "sync"
1212 "testing"
1313 "time"
14+
15+ "github.com/mark3labs/mcp-go/mcp"
1416)
1517
1618func TestSSEServer (t * testing.T ) {
@@ -468,4 +470,150 @@ func TestSSEServer(t *testing.T) {
468470 cancel ()
469471 <- done
470472 })
473+
474+ t .Run ("Can use a custom context function" , func (t * testing.T ) {
475+ // Use a custom context key to store a test value.
476+ type testContextKey struct {}
477+ testValFromContext := func (ctx context.Context ) string {
478+ val := ctx .Value (testContextKey {})
479+ if val == nil {
480+ return ""
481+ }
482+ return val .(string )
483+ }
484+ // Create a context function that sets a test value from the request.
485+ // In real life this could be used to send configuration using headers
486+ // or query parameters.
487+ const testHeader = "X-Test-Header"
488+ setTestValFromRequest := func (ctx context.Context , r * http.Request ) context.Context {
489+ return context .WithValue (ctx , testContextKey {}, r .Header .Get (testHeader ))
490+ }
491+
492+ mcpServer := NewMCPServer ("test" , "1.0.0" ,
493+ WithResourceCapabilities (true , true ),
494+ )
495+ // Add a tool which uses the context function.
496+ mcpServer .AddTool (mcp .NewTool ("test_tool" ), func (ctx context.Context , request mcp.CallToolRequest ) (* mcp.CallToolResult , error ) {
497+ // Note this is agnostic to the transport type i.e. doesn't know about request headers.
498+ testVal := testValFromContext (ctx )
499+ return mcp .NewToolResultText (testVal ), nil
500+ })
501+
502+ testServer := NewTestServer (mcpServer , func (sseServer * SSEServer ) {
503+ sseServer .contextFunc = setTestValFromRequest
504+ })
505+ defer testServer .Close ()
506+
507+ // Connect to SSE endpoint
508+ sseResp , err := http .Get (fmt .Sprintf ("%s/sse" , testServer .URL ))
509+ if err != nil {
510+ t .Fatalf ("Failed to connect to SSE endpoint: %v" , err )
511+ }
512+ defer sseResp .Body .Close ()
513+
514+ // Read the endpoint event
515+ buf := make ([]byte , 1024 )
516+ n , err := sseResp .Body .Read (buf )
517+ if err != nil {
518+ t .Fatalf ("Failed to read SSE response: %v" , err )
519+ }
520+
521+ endpointEvent := string (buf [:n ])
522+ messageURL := strings .TrimSpace (
523+ strings .Split (strings .Split (endpointEvent , "data: " )[1 ], "\n " )[0 ],
524+ )
525+
526+ // Send initialize request
527+ initRequest := map [string ]interface {}{
528+ "jsonrpc" : "2.0" ,
529+ "id" : 1 ,
530+ "method" : "initialize" ,
531+ "params" : map [string ]interface {}{
532+ "protocolVersion" : "2024-11-05" ,
533+ "clientInfo" : map [string ]interface {}{
534+ "name" : "test-client" ,
535+ "version" : "1.0.0" ,
536+ },
537+ },
538+ }
539+
540+ requestBody , err := json .Marshal (initRequest )
541+ if err != nil {
542+ t .Fatalf ("Failed to marshal request: %v" , err )
543+ }
544+
545+ resp , err := http .Post (
546+ messageURL ,
547+ "application/json" ,
548+ bytes .NewBuffer (requestBody ),
549+ )
550+
551+ if err != nil {
552+ t .Fatalf ("Failed to send message: %v" , err )
553+ }
554+ defer resp .Body .Close ()
555+
556+ if resp .StatusCode != http .StatusAccepted {
557+ t .Errorf ("Expected status 202, got %d" , resp .StatusCode )
558+ }
559+
560+ // Verify response
561+ var response map [string ]interface {}
562+ if err := json .NewDecoder (resp .Body ).Decode (& response ); err != nil {
563+ t .Fatalf ("Failed to decode response: %v" , err )
564+ }
565+
566+ if response ["jsonrpc" ] != "2.0" {
567+ t .Errorf ("Expected jsonrpc 2.0, got %v" , response ["jsonrpc" ])
568+ }
569+ if response ["id" ].(float64 ) != 1 {
570+ t .Errorf ("Expected id 1, got %v" , response ["id" ])
571+ }
572+
573+ // Call the tool.
574+ toolRequest := map [string ]interface {}{
575+ "jsonrpc" : "2.0" ,
576+ "id" : 2 ,
577+ "method" : "tools/call" ,
578+ "params" : map [string ]interface {}{
579+ "name" : "test_tool" ,
580+ },
581+ }
582+ requestBody , err = json .Marshal (toolRequest )
583+ if err != nil {
584+ t .Fatalf ("Failed to marshal tool request: %v" , err )
585+ }
586+
587+ req , err := http .NewRequest (http .MethodPost , messageURL , bytes .NewBuffer (requestBody ))
588+ if err != nil {
589+ t .Fatalf ("Failed to create tool request: %v" , err )
590+ }
591+ // Set the test header to a custom value.
592+ req .Header .Set (testHeader , "test_value" )
593+
594+ resp , err = http .DefaultClient .Do (req )
595+ if err != nil {
596+ t .Fatalf ("Failed to call tool: %v" , err )
597+ }
598+ defer resp .Body .Close ()
599+
600+ response = make (map [string ]interface {})
601+ if err := json .NewDecoder (resp .Body ).Decode (& response ); err != nil {
602+ t .Fatalf ("Failed to decode response: %v" , err )
603+ }
604+
605+ if response ["jsonrpc" ] != "2.0" {
606+ t .Errorf ("Expected jsonrpc 2.0, got %v" , response ["jsonrpc" ])
607+ }
608+ if response ["id" ].(float64 ) != 2 {
609+ t .Errorf ("Expected id 2, got %v" , response ["id" ])
610+ }
611+ if response ["result" ].(map [string ]interface {})["content" ].([]interface {})[0 ].(map [string ]interface {})["text" ] != "test_value" {
612+ t .Errorf ("Expected result 'test_value', got %v" , response ["result" ])
613+ }
614+ if response ["error" ] != nil {
615+ t .Errorf ("Expected no error, got %v" , response ["error" ])
616+ }
617+ })
618+
471619}
0 commit comments