Skip to content

Commit 3c15087

Browse files
committed
Add example of customising context
1 parent d48ee14 commit 3c15087

File tree

1 file changed

+164
-0
lines changed

1 file changed

+164
-0
lines changed

examples/custom_context/main.go

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"flag"
7+
"fmt"
8+
"io"
9+
"log"
10+
"net/http"
11+
"os"
12+
13+
"github.com/mark3labs/mcp-go/mcp"
14+
"github.com/mark3labs/mcp-go/server"
15+
)
16+
17+
// authKey is a custom context key for storing the auth token.
18+
type authKey struct{}
19+
20+
// withAuthKey adds an auth key to the context.
21+
func withAuthKey(ctx context.Context, auth string) context.Context {
22+
return context.WithValue(ctx, authKey{}, auth)
23+
}
24+
25+
// authFromRequest extracts the auth token from the request headers.
26+
func authFromRequest(ctx context.Context, r *http.Request) context.Context {
27+
return withAuthKey(ctx, r.Header.Get("Authorization"))
28+
}
29+
30+
// authFromEnv extracts the auth token from the environment
31+
func authFromEnv(ctx context.Context) context.Context {
32+
return withAuthKey(ctx, os.Getenv("API_KEY"))
33+
}
34+
35+
// tokenFromContext extracts the auth token from the context.
36+
// This can be used by tools to extract the token regardless of the
37+
// transport being used by the server.
38+
func tokenFromContext(ctx context.Context) (string, error) {
39+
auth, ok := ctx.Value(authKey{}).(string)
40+
if !ok {
41+
return "", fmt.Errorf("missing auth")
42+
}
43+
return auth, nil
44+
}
45+
46+
type response struct {
47+
Args map[string]interface{} `json:"args"`
48+
Headers map[string]string `json:"headers"`
49+
}
50+
51+
// makeRequest makes a request to httpbin.org including the auth token in the request
52+
// headers and the message in the query string.
53+
func makeRequest(ctx context.Context, message, token string) (*response, error) {
54+
req, err := http.NewRequestWithContext(ctx, "GET", "https://httpbin.org/anything", nil)
55+
if err != nil {
56+
return nil, err
57+
}
58+
req.Header.Set("Authorization", token)
59+
req.URL.Query().Add("message", message)
60+
resp, err := http.DefaultClient.Do(req)
61+
if err != nil {
62+
return nil, err
63+
}
64+
defer resp.Body.Close()
65+
body, err := io.ReadAll(resp.Body)
66+
if err != nil {
67+
return nil, err
68+
}
69+
var r *response
70+
if err := json.Unmarshal(body, r); err != nil {
71+
return nil, err
72+
}
73+
return r, nil
74+
}
75+
76+
// handleMakeAuthenticatedRequestTool is a tool that makes an authenticated request
77+
// using the token from the context.
78+
func handleMakeAuthenticatedRequestTool(
79+
ctx context.Context,
80+
request mcp.CallToolRequest,
81+
) (*mcp.CallToolResult, error) {
82+
message, ok := request.Params.Arguments["message"].(string)
83+
if !ok {
84+
return nil, fmt.Errorf("missing message")
85+
}
86+
token, err := tokenFromContext(ctx)
87+
if err != nil {
88+
return nil, fmt.Errorf("missing token: %v", err)
89+
}
90+
// Now our tool can make a request with the token, irrespective of where it came from.
91+
resp, err := makeRequest(ctx, message, token)
92+
if err != nil {
93+
return nil, err
94+
}
95+
return mcp.NewToolResultText(fmt.Sprintf("%+v", resp)), nil
96+
}
97+
98+
type MCPServer struct {
99+
server *server.MCPServer
100+
}
101+
102+
func NewMCPServer() *MCPServer {
103+
mcpServer := server.NewMCPServer(
104+
"example-server",
105+
"1.0.0",
106+
server.WithResourceCapabilities(true, true),
107+
server.WithPromptCapabilities(true),
108+
server.WithToolCapabilities(true),
109+
)
110+
mcpServer.AddTool(mcp.NewTool("make_authenticated_request",
111+
mcp.WithDescription("Makes an authenticated request"),
112+
mcp.WithString("message",
113+
mcp.Description("Message to echo"),
114+
mcp.Required(),
115+
),
116+
), handleMakeAuthenticatedRequestTool)
117+
118+
return &MCPServer{
119+
server: mcpServer,
120+
}
121+
}
122+
123+
func (s *MCPServer) ServeSSE(addr string) *server.SSEServer {
124+
return server.NewSSEServer(s.server,
125+
server.WithBaseURL(fmt.Sprintf("http://%s", addr)),
126+
server.WithSSEContextFunc(authFromRequest),
127+
)
128+
}
129+
130+
func (s *MCPServer) ServeStdio() error {
131+
return server.ServeStdio(s.server, server.WithStdioContextFunc(authFromEnv))
132+
}
133+
134+
func main() {
135+
var transport string
136+
flag.StringVar(&transport, "t", "stdio", "Transport type (stdio or sse)")
137+
flag.StringVar(
138+
&transport,
139+
"transport",
140+
"stdio",
141+
"Transport type (stdio or sse)",
142+
)
143+
flag.Parse()
144+
145+
s := NewMCPServer()
146+
147+
switch transport {
148+
case "stdio":
149+
if err := s.ServeStdio(); err != nil {
150+
log.Fatalf("Server error: %v", err)
151+
}
152+
case "sse":
153+
sseServer := s.ServeSSE("localhost:8080")
154+
log.Printf("SSE server listening on :8080")
155+
if err := sseServer.Start(":8080"); err != nil {
156+
log.Fatalf("Server error: %v", err)
157+
}
158+
default:
159+
log.Fatalf(
160+
"Invalid transport type: %s. Must be 'stdio' or 'sse'",
161+
transport,
162+
)
163+
}
164+
}

0 commit comments

Comments
 (0)