Skip to content

Commit 5e1fa6f

Browse files
committed
feat: allow customising context via callback on transport servers
This commit adds an callback function to the transport servers that allows developers to inject context values into the server context. This can be used to inject context values extracted from environment variables (in stdio mode) or from headers (in sse mode), and access them in tools using the provided context.
1 parent e8d90a0 commit 5e1fa6f

File tree

2 files changed

+41
-6
lines changed

2 files changed

+41
-6
lines changed

server/sse.go

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,19 @@ import (
1212
"github.com/mark3labs/mcp-go/mcp"
1313
)
1414

15+
// SSEContextFunc is a function that takes an existing context and the current
16+
// request and returns a potentially modified context based on the request
17+
// content. This can be used to inject context values from headers, for example.
18+
type SSEContextFunc func(ctx context.Context, r *http.Request) context.Context
19+
1520
// SSEServer implements a Server-Sent Events (SSE) based MCP server.
1621
// It provides real-time communication capabilities over HTTP using the SSE protocol.
1722
type SSEServer struct {
18-
server *MCPServer
19-
baseURL string
20-
sessions sync.Map
21-
srv *http.Server
23+
server *MCPServer
24+
baseURL string
25+
sessions sync.Map
26+
srv *http.Server
27+
contextFunc SSEContextFunc
2228
}
2329

2430
// sseSession represents an active SSE connection.
@@ -36,6 +42,12 @@ func NewSSEServer(server *MCPServer, baseURL string) *SSEServer {
3642
}
3743
}
3844

45+
// SetContextFunc sets a function that will be called to customise the context
46+
// to the server using the incoming request.
47+
func (s *SSEServer) SetContextFunc(fn SSEContextFunc) {
48+
s.contextFunc = fn
49+
}
50+
3951
// NewTestServer creates a test server for testing purposes
4052
func NewTestServer(server *MCPServer) *httptest.Server {
4153
sseServer := &SSEServer{
@@ -172,6 +184,10 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
172184
SessionID: sessionID,
173185
})
174186

187+
if s.contextFunc != nil {
188+
ctx = s.contextFunc(ctx, r)
189+
}
190+
175191
sessionI, ok := s.sessions.Load(sessionID)
176192
if !ok {
177193
s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Invalid session ID")

server/stdio.go

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,19 @@ import (
1414
"github.com/mark3labs/mcp-go/mcp"
1515
)
1616

17+
// StdioContextFunc is a function that takes an existing context and returns
18+
// a potentially modified context.
19+
// This can be used to inject context values from environment variables,
20+
// for example.
21+
type StdioContextFunc func(ctx context.Context) context.Context
22+
1723
// StdioServer wraps a MCPServer and handles stdio communication.
1824
// It provides a simple way to create command-line MCP servers that
1925
// communicate via standard input/output streams using JSON-RPC messages.
2026
type StdioServer struct {
21-
server *MCPServer
22-
errLogger *log.Logger
27+
server *MCPServer
28+
errLogger *log.Logger
29+
contextFunc StdioContextFunc
2330
}
2431

2532
// NewStdioServer creates a new stdio server wrapper around an MCPServer.
@@ -41,6 +48,13 @@ func (s *StdioServer) SetErrorLogger(logger *log.Logger) {
4148
s.errLogger = logger
4249
}
4350

51+
// SetContextFunc sets a function that will be called to customise the context
52+
// to the server. Note that the stdio server uses the same context for all requests,
53+
// so this function will only be called once per server instance.
54+
func (s *StdioServer) SetContextFunc(fn StdioContextFunc) {
55+
s.contextFunc = fn
56+
}
57+
4458
// Listen starts listening for JSON-RPC messages on the provided input and writes responses to the provided output.
4559
// It runs until the context is cancelled or an error occurs.
4660
// Returns an error if there are issues with reading input or writing output.
@@ -55,6 +69,11 @@ func (s *StdioServer) Listen(
5569
SessionID: "stdio",
5670
})
5771

72+
// Add in any custom context.
73+
if s.contextFunc != nil {
74+
ctx = s.contextFunc(ctx)
75+
}
76+
5877
reader := bufio.NewReader(stdin)
5978

6079
// Start notification handler

0 commit comments

Comments
 (0)