Skip to content

Commit e7f084c

Browse files
JBUinfosemistrictmiguelb-gk
authored
feat: implement MCP elicitation support (#413) (#548)
* feat: implement MCP elicitation support (#413) * Add ElicitationRequest, ElicitationResult, and related types to mcp/types.go * Implement server-side RequestElicitation method with session support * Add client-side ElicitationHandler interface and request handling * Implement elicitation in stdio and in-process transports * Add comprehensive tests following sampling patterns * Create elicitation example demonstrating usage patterns * Use 'Elicitation' prefix for type names to maintain clarity * Address review comments and auto-format * Address further minor review comments * Add sentinel errors * Revert sampling formatting changes * Update elicitation response to match spec Updating elicitation response to match MCP spec document https://modelcontextprotocol.io/specification/draft/client/elicitation * feat(streamable_http): elicitation request Author: Ghosthell --------- Co-authored-by: Ramon Nogueira <ramon@echophase.com> Co-authored-by: Miguel <miguel.bautista@gitkraken.com>
1 parent 3288753 commit e7f084c

File tree

21 files changed

+1490
-107
lines changed

21 files changed

+1490
-107
lines changed

client/client.go

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ type Client struct {
2525
serverCapabilities mcp.ServerCapabilities
2626
protocolVersion string
2727
samplingHandler SamplingHandler
28+
elicitationHandler ElicitationHandler
2829
}
2930

3031
type ClientOption func(*Client)
@@ -44,6 +45,14 @@ func WithSamplingHandler(handler SamplingHandler) ClientOption {
4445
}
4546
}
4647

48+
// WithElicitationHandler sets the elicitation handler for the client.
49+
// When set, the client will declare elicitation capability during initialization.
50+
func WithElicitationHandler(handler ElicitationHandler) ClientOption {
51+
return func(c *Client) {
52+
c.elicitationHandler = handler
53+
}
54+
}
55+
4756
// WithSession assumes a MCP Session has already been initialized
4857
func WithSession() ClientOption {
4958
return func(c *Client) {
@@ -174,6 +183,10 @@ func (c *Client) Initialize(
174183
if c.samplingHandler != nil {
175184
capabilities.Sampling = &struct{}{}
176185
}
186+
// Add elicitation capability if handler is configured
187+
if c.elicitationHandler != nil {
188+
capabilities.Elicitation = &struct{}{}
189+
}
177190

178191
// Ensure we send a params object with all required fields
179192
params := struct {
@@ -458,11 +471,15 @@ func (c *Client) Complete(
458471
}
459472

460473
// handleIncomingRequest processes incoming requests from the server.
461-
// This is the main entry point for server-to-client requests like sampling.
474+
// This is the main entry point for server-to-client requests like sampling and elicitation.
462475
func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
463476
switch request.Method {
464477
case string(mcp.MethodSamplingCreateMessage):
465478
return c.handleSamplingRequestTransport(ctx, request)
479+
case string(mcp.MethodElicitationCreate):
480+
return c.handleElicitationRequestTransport(ctx, request)
481+
case string(mcp.MethodPing):
482+
return c.handlePingRequestTransport(ctx, request)
466483
default:
467484
return nil, fmt.Errorf("unsupported request method: %s", request.Method)
468485
}
@@ -515,6 +532,64 @@ func (c *Client) handleSamplingRequestTransport(ctx context.Context, request tra
515532

516533
return response, nil
517534
}
535+
536+
// handleElicitationRequestTransport handles elicitation requests at the transport level.
537+
func (c *Client) handleElicitationRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
538+
if c.elicitationHandler == nil {
539+
return nil, fmt.Errorf("no elicitation handler configured")
540+
}
541+
542+
// Parse the request parameters
543+
var params mcp.ElicitationParams
544+
if request.Params != nil {
545+
paramsBytes, err := json.Marshal(request.Params)
546+
if err != nil {
547+
return nil, fmt.Errorf("failed to marshal params: %w", err)
548+
}
549+
if err := json.Unmarshal(paramsBytes, &params); err != nil {
550+
return nil, fmt.Errorf("failed to unmarshal params: %w", err)
551+
}
552+
}
553+
554+
// Create the MCP request
555+
mcpRequest := mcp.ElicitationRequest{
556+
Request: mcp.Request{
557+
Method: string(mcp.MethodElicitationCreate),
558+
},
559+
Params: params,
560+
}
561+
562+
// Call the elicitation handler
563+
result, err := c.elicitationHandler.Elicit(ctx, mcpRequest)
564+
if err != nil {
565+
return nil, err
566+
}
567+
568+
// Marshal the result
569+
resultBytes, err := json.Marshal(result)
570+
if err != nil {
571+
return nil, fmt.Errorf("failed to marshal result: %w", err)
572+
}
573+
574+
// Create the transport response
575+
response := &transport.JSONRPCResponse{
576+
JSONRPC: mcp.JSONRPC_VERSION,
577+
ID: request.ID,
578+
Result: json.RawMessage(resultBytes),
579+
}
580+
581+
return response, nil
582+
}
583+
584+
func (c *Client) handlePingRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
585+
b, _ := json.Marshal(&mcp.EmptyResult{})
586+
return &transport.JSONRPCResponse{
587+
JSONRPC: mcp.JSONRPC_VERSION,
588+
ID: request.ID,
589+
Result: b,
590+
}, nil
591+
}
592+
518593
func listByPage[T any](
519594
ctx context.Context,
520595
client *Client,

client/elicitation.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package client
2+
3+
import (
4+
"context"
5+
6+
"github.com/mark3labs/mcp-go/mcp"
7+
)
8+
9+
// ElicitationHandler defines the interface for handling elicitation requests from servers.
10+
// Clients can implement this interface to request additional information from users.
11+
type ElicitationHandler interface {
12+
// Elicit handles an elicitation request from the server and returns the user's response.
13+
// The implementation should:
14+
// 1. Present the request message to the user
15+
// 2. Validate input against the requested schema
16+
// 3. Allow the user to accept, decline, or cancel
17+
// 4. Return the appropriate response
18+
Elicit(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error)
19+
}

client/elicitation_test.go

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
package client
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"testing"
8+
9+
"github.com/mark3labs/mcp-go/client/transport"
10+
"github.com/mark3labs/mcp-go/mcp"
11+
)
12+
13+
// mockElicitationHandler implements ElicitationHandler for testing
14+
type mockElicitationHandler struct {
15+
result *mcp.ElicitationResult
16+
err error
17+
}
18+
19+
func (m *mockElicitationHandler) Elicit(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) {
20+
if m.err != nil {
21+
return nil, m.err
22+
}
23+
return m.result, nil
24+
}
25+
26+
func TestClient_HandleElicitationRequest(t *testing.T) {
27+
tests := []struct {
28+
name string
29+
handler ElicitationHandler
30+
expectedError string
31+
}{
32+
{
33+
name: "no handler configured",
34+
handler: nil,
35+
expectedError: "no elicitation handler configured",
36+
},
37+
{
38+
name: "successful elicitation - accept",
39+
handler: &mockElicitationHandler{
40+
result: &mcp.ElicitationResult{
41+
ElicitationResponse: mcp.ElicitationResponse{
42+
Action: mcp.ElicitationResponseActionAccept,
43+
Content: map[string]any{
44+
"name": "test-project",
45+
"framework": "react",
46+
},
47+
},
48+
},
49+
},
50+
},
51+
{
52+
name: "successful elicitation - decline",
53+
handler: &mockElicitationHandler{
54+
result: &mcp.ElicitationResult{
55+
ElicitationResponse: mcp.ElicitationResponse{
56+
Action: mcp.ElicitationResponseActionDecline,
57+
},
58+
},
59+
},
60+
},
61+
{
62+
name: "successful elicitation - cancel",
63+
handler: &mockElicitationHandler{
64+
result: &mcp.ElicitationResult{
65+
ElicitationResponse: mcp.ElicitationResponse{
66+
Action: mcp.ElicitationResponseActionCancel,
67+
},
68+
},
69+
},
70+
},
71+
{
72+
name: "handler returns error",
73+
handler: &mockElicitationHandler{
74+
err: fmt.Errorf("user interaction failed"),
75+
},
76+
expectedError: "user interaction failed",
77+
},
78+
}
79+
80+
for _, tt := range tests {
81+
t.Run(tt.name, func(t *testing.T) {
82+
client := &Client{elicitationHandler: tt.handler}
83+
84+
request := transport.JSONRPCRequest{
85+
ID: mcp.NewRequestId(1),
86+
Method: string(mcp.MethodElicitationCreate),
87+
Params: map[string]any{
88+
"message": "Please provide project details",
89+
"requestedSchema": map[string]any{
90+
"type": "object",
91+
"properties": map[string]any{
92+
"name": map[string]any{"type": "string"},
93+
"framework": map[string]any{"type": "string"},
94+
},
95+
},
96+
},
97+
}
98+
99+
result, err := client.handleElicitationRequestTransport(context.Background(), request)
100+
101+
if tt.expectedError != "" {
102+
if err == nil {
103+
t.Errorf("expected error %q, got nil", tt.expectedError)
104+
} else if err.Error() != tt.expectedError {
105+
t.Errorf("expected error %q, got %q", tt.expectedError, err.Error())
106+
}
107+
} else {
108+
if err != nil {
109+
t.Errorf("unexpected error: %v", err)
110+
}
111+
if result == nil {
112+
t.Error("expected result, got nil")
113+
} else {
114+
// Verify the response is properly formatted
115+
var elicitationResult mcp.ElicitationResult
116+
if err := json.Unmarshal(result.Result, &elicitationResult); err != nil {
117+
t.Errorf("failed to unmarshal result: %v", err)
118+
}
119+
}
120+
}
121+
})
122+
}
123+
}
124+
125+
func TestWithElicitationHandler(t *testing.T) {
126+
handler := &mockElicitationHandler{}
127+
client := &Client{}
128+
129+
option := WithElicitationHandler(handler)
130+
option(client)
131+
132+
if client.elicitationHandler != handler {
133+
t.Error("elicitation handler not set correctly")
134+
}
135+
}
136+
137+
func TestClient_Initialize_WithElicitationHandler(t *testing.T) {
138+
mockTransport := &mockElicitationTransport{
139+
sendRequestFunc: func(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
140+
// Verify that elicitation capability is included
141+
// The client internally converts the typed params to a map for transport
142+
// So we check if we're getting the initialize request
143+
if request.Method != "initialize" {
144+
t.Fatalf("expected initialize method, got %s", request.Method)
145+
}
146+
147+
// Return successful initialization response
148+
result := mcp.InitializeResult{
149+
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
150+
ServerInfo: mcp.Implementation{
151+
Name: "test-server",
152+
Version: "1.0.0",
153+
},
154+
Capabilities: mcp.ServerCapabilities{},
155+
}
156+
157+
resultBytes, _ := json.Marshal(result)
158+
return &transport.JSONRPCResponse{
159+
ID: request.ID,
160+
Result: json.RawMessage(resultBytes),
161+
}, nil
162+
},
163+
sendNotificationFunc: func(ctx context.Context, notification mcp.JSONRPCNotification) error {
164+
return nil
165+
},
166+
}
167+
168+
handler := &mockElicitationHandler{}
169+
client := NewClient(mockTransport, WithElicitationHandler(handler))
170+
171+
err := client.Start(context.Background())
172+
if err != nil {
173+
t.Fatalf("failed to start client: %v", err)
174+
}
175+
176+
_, err = client.Initialize(context.Background(), mcp.InitializeRequest{
177+
Params: mcp.InitializeParams{
178+
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
179+
ClientInfo: mcp.Implementation{
180+
Name: "test-client",
181+
Version: "1.0.0",
182+
},
183+
Capabilities: mcp.ClientCapabilities{},
184+
},
185+
})
186+
187+
if err != nil {
188+
t.Fatalf("failed to initialize: %v", err)
189+
}
190+
}
191+
192+
// mockElicitationTransport implements transport.Interface for testing
193+
type mockElicitationTransport struct {
194+
sendRequestFunc func(context.Context, transport.JSONRPCRequest) (*transport.JSONRPCResponse, error)
195+
sendNotificationFunc func(context.Context, mcp.JSONRPCNotification) error
196+
}
197+
198+
func (m *mockElicitationTransport) Start(ctx context.Context) error {
199+
return nil
200+
}
201+
202+
func (m *mockElicitationTransport) Close() error {
203+
return nil
204+
}
205+
206+
func (m *mockElicitationTransport) SendRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
207+
if m.sendRequestFunc != nil {
208+
return m.sendRequestFunc(ctx, request)
209+
}
210+
return nil, nil
211+
}
212+
213+
func (m *mockElicitationTransport) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
214+
if m.sendNotificationFunc != nil {
215+
return m.sendNotificationFunc(ctx, notification)
216+
}
217+
return nil
218+
}
219+
220+
func (m *mockElicitationTransport) SetNotificationHandler(handler func(mcp.JSONRPCNotification)) {
221+
}
222+
223+
func (m *mockElicitationTransport) GetSessionId() string {
224+
return "mock-session"
225+
}

0 commit comments

Comments
 (0)