Skip to content

Commit 98db010

Browse files
authored
feat: allow customising context via callback on transport servers (#32)
* feat: allow customising context via callback on transport servers This commit adds an callback function to the transport servers that allows developers to inject context values into the server context. This can be used to inject context values extracted from environment variables (in stdio mode) or from headers (in sse mode), and access them in tools using the provided context. * Add tests for custom context functions * Use t.Setenv instead of manual os.Setenv * Use existing Option type from main branch for configuring SSE server with context func * Add StdioOption to configure context func on Stdio server * Add example of customising context
1 parent 468b381 commit 98db010

File tree

5 files changed

+542
-12
lines changed

5 files changed

+542
-12
lines changed

examples/custom_context/main.go

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"flag"
7+
"fmt"
8+
"io"
9+
"log"
10+
"net/http"
11+
"os"
12+
13+
"github.com/mark3labs/mcp-go/mcp"
14+
"github.com/mark3labs/mcp-go/server"
15+
)
16+
17+
// authKey is a custom context key for storing the auth token.
18+
type authKey struct{}
19+
20+
// withAuthKey adds an auth key to the context.
21+
func withAuthKey(ctx context.Context, auth string) context.Context {
22+
return context.WithValue(ctx, authKey{}, auth)
23+
}
24+
25+
// authFromRequest extracts the auth token from the request headers.
26+
func authFromRequest(ctx context.Context, r *http.Request) context.Context {
27+
return withAuthKey(ctx, r.Header.Get("Authorization"))
28+
}
29+
30+
// authFromEnv extracts the auth token from the environment
31+
func authFromEnv(ctx context.Context) context.Context {
32+
return withAuthKey(ctx, os.Getenv("API_KEY"))
33+
}
34+
35+
// tokenFromContext extracts the auth token from the context.
36+
// This can be used by tools to extract the token regardless of the
37+
// transport being used by the server.
38+
func tokenFromContext(ctx context.Context) (string, error) {
39+
auth, ok := ctx.Value(authKey{}).(string)
40+
if !ok {
41+
return "", fmt.Errorf("missing auth")
42+
}
43+
return auth, nil
44+
}
45+
46+
type response struct {
47+
Args map[string]interface{} `json:"args"`
48+
Headers map[string]string `json:"headers"`
49+
}
50+
51+
// makeRequest makes a request to httpbin.org including the auth token in the request
52+
// headers and the message in the query string.
53+
func makeRequest(ctx context.Context, message, token string) (*response, error) {
54+
req, err := http.NewRequestWithContext(ctx, "GET", "https://httpbin.org/anything", nil)
55+
if err != nil {
56+
return nil, err
57+
}
58+
req.Header.Set("Authorization", token)
59+
req.URL.Query().Add("message", message)
60+
resp, err := http.DefaultClient.Do(req)
61+
if err != nil {
62+
return nil, err
63+
}
64+
defer resp.Body.Close()
65+
body, err := io.ReadAll(resp.Body)
66+
if err != nil {
67+
return nil, err
68+
}
69+
var r *response
70+
if err := json.Unmarshal(body, r); err != nil {
71+
return nil, err
72+
}
73+
return r, nil
74+
}
75+
76+
// handleMakeAuthenticatedRequestTool is a tool that makes an authenticated request
77+
// using the token from the context.
78+
func handleMakeAuthenticatedRequestTool(
79+
ctx context.Context,
80+
request mcp.CallToolRequest,
81+
) (*mcp.CallToolResult, error) {
82+
message, ok := request.Params.Arguments["message"].(string)
83+
if !ok {
84+
return nil, fmt.Errorf("missing message")
85+
}
86+
token, err := tokenFromContext(ctx)
87+
if err != nil {
88+
return nil, fmt.Errorf("missing token: %v", err)
89+
}
90+
// Now our tool can make a request with the token, irrespective of where it came from.
91+
resp, err := makeRequest(ctx, message, token)
92+
if err != nil {
93+
return nil, err
94+
}
95+
return mcp.NewToolResultText(fmt.Sprintf("%+v", resp)), nil
96+
}
97+
98+
type MCPServer struct {
99+
server *server.MCPServer
100+
}
101+
102+
func NewMCPServer() *MCPServer {
103+
mcpServer := server.NewMCPServer(
104+
"example-server",
105+
"1.0.0",
106+
server.WithResourceCapabilities(true, true),
107+
server.WithPromptCapabilities(true),
108+
server.WithToolCapabilities(true),
109+
)
110+
mcpServer.AddTool(mcp.NewTool("make_authenticated_request",
111+
mcp.WithDescription("Makes an authenticated request"),
112+
mcp.WithString("message",
113+
mcp.Description("Message to echo"),
114+
mcp.Required(),
115+
),
116+
), handleMakeAuthenticatedRequestTool)
117+
118+
return &MCPServer{
119+
server: mcpServer,
120+
}
121+
}
122+
123+
func (s *MCPServer) ServeSSE(addr string) *server.SSEServer {
124+
return server.NewSSEServer(s.server,
125+
server.WithBaseURL(fmt.Sprintf("http://%s", addr)),
126+
server.WithSSEContextFunc(authFromRequest),
127+
)
128+
}
129+
130+
func (s *MCPServer) ServeStdio() error {
131+
return server.ServeStdio(s.server, server.WithStdioContextFunc(authFromEnv))
132+
}
133+
134+
func main() {
135+
var transport string
136+
flag.StringVar(&transport, "t", "stdio", "Transport type (stdio or sse)")
137+
flag.StringVar(
138+
&transport,
139+
"transport",
140+
"stdio",
141+
"Transport type (stdio or sse)",
142+
)
143+
flag.Parse()
144+
145+
s := NewMCPServer()
146+
147+
switch transport {
148+
case "stdio":
149+
if err := s.ServeStdio(); err != nil {
150+
log.Fatalf("Server error: %v", err)
151+
}
152+
case "sse":
153+
sseServer := s.ServeSSE("localhost:8080")
154+
log.Printf("SSE server listening on :8080")
155+
if err := sseServer.Start(":8080"); err != nil {
156+
log.Fatalf("Server error: %v", err)
157+
}
158+
default:
159+
log.Fatalf(
160+
"Invalid transport type: %s. Must be 'stdio' or 'sse'",
161+
transport,
162+
)
163+
}
164+
}

server/sse.go

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ type sseSession struct {
2121
eventQueue chan string // Channel for queuing events
2222
}
2323

24+
// SSEContextFunc is a function that takes an existing context and the current
25+
// request and returns a potentially modified context based on the request
26+
// content. This can be used to inject context values from headers, for example.
27+
type SSEContextFunc func(ctx context.Context, r *http.Request) context.Context
28+
2429
// SSEServer implements a Server-Sent Events (SSE) based MCP server.
2530
// It provides real-time communication capabilities over HTTP using the SSE protocol.
2631
type SSEServer struct {
@@ -31,20 +36,21 @@ type SSEServer struct {
3136
sseEndpoint string
3237
sessions sync.Map
3338
srv *http.Server
39+
contextFunc SSEContextFunc
3440
}
3541

36-
// Option defines a function type for configuring SSEServer
37-
type Option func(*SSEServer)
42+
// SSEOption defines a function type for configuring SSEServer
43+
type SSEOption func(*SSEServer)
3844

3945
// WithBaseURL sets the base URL for the SSE server
40-
func WithBaseURL(baseURL string) Option {
46+
func WithBaseURL(baseURL string) SSEOption {
4147
return func(s *SSEServer) {
4248
s.baseURL = baseURL
4349
}
4450
}
4551

4652
// Add a new option for setting base path
47-
func WithBasePath(basePath string) Option {
53+
func WithBasePath(basePath string) SSEOption {
4854
return func(s *SSEServer) {
4955
// Ensure the path starts with / and doesn't end with /
5056
if !strings.HasPrefix(basePath, "/") {
@@ -56,28 +62,36 @@ func WithBasePath(basePath string) Option {
5662
}
5763

5864
// WithMessageEndpoint sets the message endpoint path
59-
func WithMessageEndpoint(endpoint string) Option {
65+
func WithMessageEndpoint(endpoint string) SSEOption {
6066
return func(s *SSEServer) {
6167
s.messageEndpoint = endpoint
6268
}
6369
}
6470

6571
// WithSSEEndpoint sets the SSE endpoint path
66-
func WithSSEEndpoint(endpoint string) Option {
72+
func WithSSEEndpoint(endpoint string) SSEOption {
6773
return func(s *SSEServer) {
6874
s.sseEndpoint = endpoint
6975
}
7076
}
7177

7278
// WithHTTPServer sets the HTTP server instance
73-
func WithHTTPServer(srv *http.Server) Option {
79+
func WithHTTPServer(srv *http.Server) SSEOption {
7480
return func(s *SSEServer) {
7581
s.srv = srv
7682
}
7783
}
7884

85+
// WithContextFunc sets a function that will be called to customise the context
86+
// to the server using the incoming request.
87+
func WithSSEContextFunc(fn SSEContextFunc) SSEOption {
88+
return func(s *SSEServer) {
89+
s.contextFunc = fn
90+
}
91+
}
92+
7993
// NewSSEServer creates a new SSE server instance with the given MCP server and options.
80-
func NewSSEServer(server *MCPServer, opts ...Option) *SSEServer {
94+
func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer {
8195
s := &SSEServer{
8296
server: server,
8397
sseEndpoint: "/sse",
@@ -94,8 +108,11 @@ func NewSSEServer(server *MCPServer, opts ...Option) *SSEServer {
94108
}
95109

96110
// NewTestServer creates a test server for testing purposes
97-
func NewTestServer(server *MCPServer) *httptest.Server {
111+
func NewTestServer(server *MCPServer, opts ...SSEOption) *httptest.Server {
98112
sseServer := NewSSEServer(server)
113+
for _, opt := range opts {
114+
opt(sseServer)
115+
}
99116

100117
testServer := httptest.NewServer(sseServer)
101118
sseServer.baseURL = testServer.URL
@@ -230,6 +247,10 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
230247
SessionID: sessionID,
231248
})
232249

250+
if s.contextFunc != nil {
251+
ctx = s.contextFunc(ctx, r)
252+
}
253+
233254
sessionI, ok := s.sessions.Load(sessionID)
234255
if !ok {
235256
s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Invalid session ID")

0 commit comments

Comments
 (0)