From 02cf5cf24b853e91cbbf3dfbae2ebb46bb4dea92 Mon Sep 17 00:00:00 2001 From: leavez Date: Sat, 5 Apr 2025 22:26:56 +0800 Subject: [PATCH 01/22] add transport layer interface --- client/transport/interface.go | 45 +++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 client/transport/interface.go diff --git a/client/transport/interface.go b/client/transport/interface.go new file mode 100644 index 000000000..8ac75d746 --- /dev/null +++ b/client/transport/interface.go @@ -0,0 +1,45 @@ +package transport + +import ( + "context" + "encoding/json" + + "github.com/mark3labs/mcp-go/mcp" +) + +// Interface for the transport layer. +type Interface interface { + // Start the connection. Start should only be called once. + Start(ctx context.Context) error + + // SendRequest sends a json RPC request and returns the response synchronously. + SendRequest(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) + + // SendNotification sends a json RPC Notification to the server. + SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error + + // SetNotificationHandler sets the handler for notifications. + // Any notification before the handler is set will be discarded. + SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) + + // Close the connection. + Close() error +} + +type JSONRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params any `json:"params,omitempty"` +} + +type JSONRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID *int64 `json:"id"` + Result json.RawMessage `json:"result"` + Error *struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + } `json:"error"` +} From c915c3e1e7fc09843728a048628b571285304008 Mon Sep 17 00:00:00 2001 From: leavez Date: Sat, 5 Apr 2025 22:27:08 +0800 Subject: [PATCH 02/22] universal client --- client/impl.go | 314 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 314 insertions(+) create mode 100644 client/impl.go diff --git a/client/impl.go b/client/impl.go new file mode 100644 index 000000000..8b9607d8c --- /dev/null +++ b/client/impl.go @@ -0,0 +1,314 @@ +package client + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sync" + "sync/atomic" + + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +type Client struct { + transport transport.Interface + + initialized bool + notifications []func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex + requestID atomic.Int64 + capabilities mcp.ServerCapabilities +} + +// NewClient creates a new MCP client with the given transport layer. +// Usage: +// +// client, err := NewClient(transport.NewStdio("mcp", nil, "--stdio")) +// if err != nil { +// log.Fatalf("Failed to create client: %v", err) +// } +func NewClient(transport transport.Interface) *Client { + return &Client{ + transport: transport, + } +} + +// Start initiates the transport connection to the server. +// Returns an error if the transport is nil or if the connection fails. +func (c *Client) Start(ctx context.Context) error { + if c.transport == nil { + return fmt.Errorf("transport is nil") + } + err := c.transport.Start(ctx) + if err != nil { + return err + } + + c.transport.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { + c.notifyMu.RLock() + defer c.notifyMu.RUnlock() + for _, handler := range c.notifications { + handler(notification) + } + }) + return nil +} + +// Close shuts down the client and closes the transport. +func (c *Client) Close() error { + return c.transport.Close() +} + +// OnNotification registers a handler function to be called when notifications are received. +// Multiple handlers can be registered and will be called in the order they were added. +func (c *Client) OnNotification( + handler func(notification mcp.JSONRPCNotification), +) { + c.notifyMu.Lock() + defer c.notifyMu.Unlock() + c.notifications = append(c.notifications, handler) +} + +// sendRequest sends a JSON-RPC request to the server and waits for a response. +// Returns the raw JSON response message or an error if the request fails. +func (c *Client) sendRequest( + ctx context.Context, + method string, + params interface{}, +) (*json.RawMessage, error) { + if !c.initialized && method != "initialize" { + return nil, fmt.Errorf("client not initialized") + } + + id := c.requestID.Add(1) + + request := transport.JSONRPCRequest{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Method: method, + Params: params, + } + + response, err := c.transport.SendRequest(ctx, request) + if err != nil { + return nil, fmt.Errorf("transport error: %w", err) + } + + if response.Error != nil { + return nil, errors.New(response.Error.Message) + } + + return &response.Result, nil +} + +func (c *Client) Initialize( + ctx context.Context, + request mcp.InitializeRequest, +) (*mcp.InitializeResult, error) { + // Ensure we send a params object with all required fields + params := struct { + ProtocolVersion string `json:"protocolVersion"` + ClientInfo mcp.Implementation `json:"clientInfo"` + Capabilities mcp.ClientCapabilities `json:"capabilities"` + }{ + ProtocolVersion: request.Params.ProtocolVersion, + ClientInfo: request.Params.ClientInfo, + Capabilities: request.Params.Capabilities, // Will be empty struct if not set + } + + response, err := c.sendRequest(ctx, "initialize", params) + if err != nil { + return nil, err + } + + var result mcp.InitializeResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + // Store capabilities + c.capabilities = result.Capabilities + + // Send initialized notification + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: "notifications/initialized", + }, + } + + err = c.transport.SendNotification(ctx, notification) + if err != nil { + return nil, fmt.Errorf( + "failed to send initialized notification: %w", + err, + ) + } + + c.initialized = true + return &result, nil +} + +func (c *Client) Ping(ctx context.Context) error { + _, err := c.sendRequest(ctx, "ping", nil) + return err +} + +func (c *Client) ListResources( + ctx context.Context, + request mcp.ListResourcesRequest, +) (*mcp.ListResourcesResult, error) { + response, err := c.sendRequest(ctx, "resources/list", request.Params) + if err != nil { + return nil, err + } + + var result mcp.ListResourcesResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +func (c *Client) ListResourceTemplates( + ctx context.Context, + request mcp.ListResourceTemplatesRequest, +) (*mcp.ListResourceTemplatesResult, error) { + response, err := c.sendRequest( + ctx, + "resources/templates/list", + request.Params, + ) + if err != nil { + return nil, err + } + + var result mcp.ListResourceTemplatesResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +func (c *Client) ReadResource( + ctx context.Context, + request mcp.ReadResourceRequest, +) (*mcp.ReadResourceResult, error) { + response, err := c.sendRequest(ctx, "resources/read", request.Params) + if err != nil { + return nil, err + } + + return mcp.ParseReadResourceResult(response) +} + +func (c *Client) Subscribe( + ctx context.Context, + request mcp.SubscribeRequest, +) error { + _, err := c.sendRequest(ctx, "resources/subscribe", request.Params) + return err +} + +func (c *Client) Unsubscribe( + ctx context.Context, + request mcp.UnsubscribeRequest, +) error { + _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params) + return err +} + +func (c *Client) ListPrompts( + ctx context.Context, + request mcp.ListPromptsRequest, +) (*mcp.ListPromptsResult, error) { + response, err := c.sendRequest(ctx, "prompts/list", request.Params) + if err != nil { + return nil, err + } + + var result mcp.ListPromptsResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +func (c *Client) GetPrompt( + ctx context.Context, + request mcp.GetPromptRequest, +) (*mcp.GetPromptResult, error) { + response, err := c.sendRequest(ctx, "prompts/get", request.Params) + if err != nil { + return nil, err + } + + return mcp.ParseGetPromptResult(response) +} + +func (c *Client) ListTools( + ctx context.Context, + request mcp.ListToolsRequest, +) (*mcp.ListToolsResult, error) { + response, err := c.sendRequest(ctx, "tools/list", request.Params) + if err != nil { + return nil, err + } + + var result mcp.ListToolsResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +func (c *Client) CallTool( + ctx context.Context, + request mcp.CallToolRequest, +) (*mcp.CallToolResult, error) { + response, err := c.sendRequest(ctx, "tools/call", request.Params) + if err != nil { + return nil, err + } + + return mcp.ParseCallToolResult(response) +} + +func (c *Client) SetLevel( + ctx context.Context, + request mcp.SetLevelRequest, +) error { + _, err := c.sendRequest(ctx, "logging/setLevel", request.Params) + return err +} + +func (c *Client) Complete( + ctx context.Context, + request mcp.CompleteRequest, +) (*mcp.CompleteResult, error) { + response, err := c.sendRequest(ctx, "completion/complete", request.Params) + if err != nil { + return nil, err + } + + var result mcp.CompleteResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +// Helper methods + +// GetTransport gives access to the underlying transport layer. +// Cast it to the specific transport type and obtain the other helper methods. +func (c *Client) GetTransport() transport.Interface { + return c.transport +} From 67860f707a3ee4e6203094219f405ac31e17da6d Mon Sep 17 00:00:00 2001 From: leavez Date: Sat, 5 Apr 2025 22:27:18 +0800 Subject: [PATCH 03/22] impl sse & stdio transport based on the original client --- client/transport/sse.go | 363 ++++++++++++++++++++++++++++++++++++++ client/transport/stdio.go | 231 ++++++++++++++++++++++++ 2 files changed, 594 insertions(+) create mode 100644 client/transport/sse.go create mode 100644 client/transport/stdio.go diff --git a/client/transport/sse.go b/client/transport/sse.go new file mode 100644 index 000000000..3eafb33eb --- /dev/null +++ b/client/transport/sse.go @@ -0,0 +1,363 @@ +package transport + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +// SSE implements the transport layer of the MCP protocol using Server-Sent Events (SSE). +// It maintains a persistent HTTP connection to receive server-pushed events +// while sending requests over regular HTTP POST calls. The client handles +// automatic reconnection and message routing between requests and responses. +type SSE struct { + baseURL *url.URL + endpoint *url.URL + httpClient *http.Client + responses map[int64]chan *JSONRPCResponse + mu sync.RWMutex + done chan struct{} + onNotification func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex + endpointChan chan struct{} + headers map[string]string + sseReadTimeout time.Duration +} + +type ClientOption func(*SSE) + +func WithHeaders(headers map[string]string) ClientOption { + return func(sc *SSE) { + sc.headers = headers + } +} + +func WithSSEReadTimeout(timeout time.Duration) ClientOption { + return func(sc *SSE) { + sc.sseReadTimeout = timeout + } +} + +// NewSSE creates a new SSE-based MCP client with the given base URL. +// Returns an error if the URL is invalid. +func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) { + parsedURL, err := url.Parse(baseURL) + if err != nil { + return nil, fmt.Errorf("invalid URL: %w", err) + } + + smc := &SSE{ + baseURL: parsedURL, + httpClient: &http.Client{}, + responses: make(map[int64]chan *JSONRPCResponse), + done: make(chan struct{}), + endpointChan: make(chan struct{}), + sseReadTimeout: 30 * time.Second, + headers: make(map[string]string), + } + + for _, opt := range options { + opt(smc) + } + + return smc, nil +} + +// Start initiates the SSE connection to the server and waits for the endpoint information. +// Returns an error if the connection fails or times out waiting for the endpoint. +func (c *SSE) Start(ctx context.Context) error { + + req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL.String(), nil) + + if err != nil { + + return fmt.Errorf("failed to create request: %w", err) + + } + + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Connection", "keep-alive") + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to connect to SSE stream: %w", err) + } + + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + go c.readSSE(resp.Body) + + // Wait for the endpoint to be received + + select { + case <-c.endpointChan: + // Endpoint received, proceed + case <-ctx.Done(): + return fmt.Errorf("context cancelled while waiting for endpoint") + case <-time.After(30 * time.Second): // Add a timeout + return fmt.Errorf("timeout waiting for endpoint") + } + + return nil +} + +// readSSE continuously reads the SSE stream and processes events. +// It runs until the connection is closed or an error occurs. +func (c *SSE) readSSE(reader io.ReadCloser) { + defer reader.Close() + + br := bufio.NewReader(reader) + var event, data string + + ctx, cancel := context.WithTimeout(context.Background(), c.sseReadTimeout) + defer cancel() + + for { + select { + case <-ctx.Done(): + return + default: + line, err := br.ReadString('\n') + if err != nil { + if err == io.EOF { + // Process any pending event before exit + if event != "" && data != "" { + c.handleSSEEvent(event, data) + } + break + } + select { + case <-c.done: + return + default: + fmt.Printf("SSE stream error: %v\n", err) + return + } + } + + // Remove only newline markers + line = strings.TrimRight(line, "\r\n") + if line == "" { + // Empty line means end of event + if event != "" && data != "" { + c.handleSSEEvent(event, data) + event = "" + data = "" + } + continue + } + + if strings.HasPrefix(line, "event:") { + event = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + } else if strings.HasPrefix(line, "data:") { + data = strings.TrimSpace(strings.TrimPrefix(line, "data:")) + } + } + } +} + +// handleSSEEvent processes SSE events based on their type. +// Handles 'endpoint' events for connection setup and 'message' events for JSON-RPC communication. +func (c *SSE) handleSSEEvent(event, data string) { + switch event { + case "endpoint": + endpoint, err := c.baseURL.Parse(data) + if err != nil { + fmt.Printf("Error parsing endpoint URL: %v\n", err) + return + } + if endpoint.Host != c.baseURL.Host { + fmt.Printf("Endpoint origin does not match connection origin\n") + return + } + c.endpoint = endpoint + close(c.endpointChan) + + case "message": + var baseMessage JSONRPCResponse + if err := json.Unmarshal([]byte(data), &baseMessage); err != nil { + fmt.Printf("Error unmarshaling message: %v\n", err) + return + } + + // Handle notification + if baseMessage.ID == nil { + var notification mcp.JSONRPCNotification + if err := json.Unmarshal([]byte(data), ¬ification); err != nil { + return + } + c.notifyMu.RLock() + if c.onNotification != nil { + c.onNotification(notification) + } + c.notifyMu.RUnlock() + return + } + + c.mu.RLock() + ch, ok := c.responses[*baseMessage.ID] + c.mu.RUnlock() + + if ok { + ch <- &baseMessage + c.mu.Lock() + delete(c.responses, *baseMessage.ID) + c.mu.Unlock() + } + } +} + +func (c *SSE) SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) { + c.notifyMu.Lock() + defer c.notifyMu.Unlock() + c.onNotification = handler +} + +// sendRequest sends a JSON-RPC request to the server and waits for a response. +// Returns the raw JSON response message or an error if the request fails. +func (c *SSE) SendRequest( + ctx context.Context, + request JSONRPCRequest, +) (*JSONRPCResponse, error) { + + if c.endpoint == nil { + return nil, fmt.Errorf("endpoint not received") + } + + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + responseChan := make(chan *JSONRPCResponse, 1) + c.mu.Lock() + c.responses[request.ID] = responseChan + c.mu.Unlock() + + req, err := http.NewRequestWithContext( + ctx, + "POST", + c.endpoint.String(), + bytes.NewReader(requestBytes), + ) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + // set custom http headers + for k, v := range c.headers { + req.Header.Set(k, v) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && + resp.StatusCode != http.StatusAccepted { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf( + "request failed with status %d: %s", + resp.StatusCode, + body, + ) + } + + select { + case <-ctx.Done(): + c.mu.Lock() + delete(c.responses, request.ID) + c.mu.Unlock() + return nil, ctx.Err() + case response := <-responseChan: + return response, nil + } +} + +// Close shuts down the SSE client connection and cleans up any pending responses. +// Returns an error if the shutdown process fails. +func (c *SSE) Close() error { + select { + case <-c.done: + return nil // Already closed + default: + close(c.done) + } + + // Clean up any pending responses + c.mu.Lock() + for _, ch := range c.responses { + close(ch) + } + c.responses = make(map[int64]chan *JSONRPCResponse) + c.mu.Unlock() + + return nil +} + +// SendNotification sends a JSON-RPC notification to the server without expecting a response. +func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { + if c.endpoint == nil { + return fmt.Errorf("endpoint not received") + } + + notificationBytes, err := json.Marshal(notification) + if err != nil { + return fmt.Errorf("failed to marshal notification: %w", err) + } + + req, err := http.NewRequestWithContext( + ctx, + "POST", + c.endpoint.String(), + bytes.NewReader(notificationBytes), + ) + if err != nil { + return fmt.Errorf("failed to create notification request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + // Set custom HTTP headers + for k, v := range c.headers { + req.Header.Set(k, v) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send notification: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf( + "notification failed with status %d: %s", + resp.StatusCode, + body, + ) + } + + return nil +} + +// GetEndpoint returns the current endpoint URL for the SSE connection. +func (c *SSE) GetEndpoint() *url.URL { + return c.endpoint +} diff --git a/client/transport/stdio.go b/client/transport/stdio.go new file mode 100644 index 000000000..0e29603b3 --- /dev/null +++ b/client/transport/stdio.go @@ -0,0 +1,231 @@ +package transport + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "os" + "os/exec" + "sync" + + "github.com/mark3labs/mcp-go/mcp" +) + +// Stdio implements the transport layer of the MCP protocol using stdio communication. +// It launches a subprocess and communicates with it via standard input/output streams +// using JSON-RPC messages. The client handles message routing between requests and +// responses, and supports asynchronous notifications. +type Stdio struct { + command string + args []string + env []string + + cmd *exec.Cmd + stdin io.WriteCloser + stdout *bufio.Reader + stderr io.ReadCloser + responses map[int64]chan *JSONRPCResponse + mu sync.RWMutex + done chan struct{} + onNotification func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex +} + +// NewStdio creates a new stdio transport to communicate with a subprocess. +// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. +// Returns an error if the subprocess cannot be started or the pipes cannot be created. +func NewStdio( + command string, + env []string, + args ...string, +) *Stdio { + + client := &Stdio{ + command: command, + args: args, + env: env, + + responses: make(map[int64]chan *JSONRPCResponse), + done: make(chan struct{}), + } + + return client +} + +func (c *Stdio) Start(ctx context.Context) error { + cmd := exec.CommandContext(ctx, c.command, c.args...) + + mergedEnv := os.Environ() + mergedEnv = append(mergedEnv, c.env...) + + cmd.Env = mergedEnv + + stdin, err := cmd.StdinPipe() + if err != nil { + return fmt.Errorf("failed to create stdin pipe: %w", err) + } + + stdout, err := cmd.StdoutPipe() + if err != nil { + return fmt.Errorf("failed to create stdout pipe: %w", err) + } + + stderr, err := cmd.StderrPipe() + if err != nil { + return fmt.Errorf("failed to create stderr pipe: %w", err) + } + + c.cmd = cmd + c.stdin = stdin + c.stderr = stderr + c.stdout = bufio.NewReader(stdout) + + if err := cmd.Start(); err != nil { + return fmt.Errorf("failed to start command: %w", err) + } + + // Start reading responses in a goroutine and wait for it to be ready + ready := make(chan struct{}) + go func() { + close(ready) + c.readResponses() + }() + <-ready + + return nil +} + +// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit. +// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate. +func (c *Stdio) Close() error { + close(c.done) + if err := c.stdin.Close(); err != nil { + return fmt.Errorf("failed to close stdin: %w", err) + } + if err := c.stderr.Close(); err != nil { + return fmt.Errorf("failed to close stderr: %w", err) + } + return c.cmd.Wait() +} + +// OnNotification registers a handler function to be called when notifications are received. +// Multiple handlers can be registered and will be called in the order they were added. +func (c *Stdio) SetNotificationHandler( + handler func(notification mcp.JSONRPCNotification), +) { + c.notifyMu.Lock() + defer c.notifyMu.Unlock() + c.onNotification = handler +} + +// readResponses continuously reads and processes responses from the server's stdout. +// It handles both responses to requests and notifications, routing them appropriately. +// Runs until the done channel is closed or an error occurs reading from stdout. +func (c *Stdio) readResponses() { + for { + select { + case <-c.done: + return + default: + line, err := c.stdout.ReadString('\n') + if err != nil { + if err != io.EOF { + fmt.Printf("Error reading response: %v\n", err) + } + return + } + + var baseMessage JSONRPCResponse + if err := json.Unmarshal([]byte(line), &baseMessage); err != nil { + continue + } + + // Handle notification + if baseMessage.ID == nil { + var notification mcp.JSONRPCNotification + if err := json.Unmarshal([]byte(line), ¬ification); err != nil { + continue + } + c.notifyMu.RLock() + if c.onNotification != nil { + c.onNotification(notification) + } + c.notifyMu.RUnlock() + continue + } + + c.mu.RLock() + ch, ok := c.responses[*baseMessage.ID] + c.mu.RUnlock() + + if ok { + ch <- &baseMessage + c.mu.Lock() + delete(c.responses, *baseMessage.ID) + c.mu.Unlock() + } + } + } +} + +// sendRequest sends a JSON-RPC request to the server and waits for a response. +// It creates a unique request ID, sends the request over stdin, and waits for +// the corresponding response or context cancellation. +// Returns the raw JSON response message or an error if the request fails. +func (c *Stdio) SendRequest( + ctx context.Context, + request JSONRPCRequest, +) (*JSONRPCResponse, error) { + + // Create the complete request structure + responseChan := make(chan *JSONRPCResponse, 1) + c.mu.Lock() + c.responses[request.ID] = responseChan + c.mu.Unlock() + + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + requestBytes = append(requestBytes, '\n') + + if _, err := c.stdin.Write(requestBytes); err != nil { + return nil, fmt.Errorf("failed to write request: %w", err) + } + + select { + case <-ctx.Done(): + c.mu.Lock() + delete(c.responses, request.ID) + c.mu.Unlock() + return nil, ctx.Err() + case response := <-responseChan: + return response, nil + } +} + +// SendNotification sends a json RPC Notification to the server. +func (c *Stdio) SendNotification( + ctx context.Context, + notification mcp.JSONRPCNotification, +) error { + notificationBytes, err := json.Marshal(notification) + if err != nil { + return fmt.Errorf("failed to marshal notification: %w", err) + } + notificationBytes = append(notificationBytes, '\n') + + if _, err := c.stdin.Write(notificationBytes); err != nil { + return fmt.Errorf("failed to write notification: %w", err) + } + + return nil +} + +// Stderr returns a reader for the stderr output of the subprocess. +// This can be used to capture error messages or logs from the subprocess. +func (c *Stdio) Stderr() io.Reader { + return c.stderr +} From e4ec65761b9b7958dfa488fe6fbb55a067d5fcb7 Mon Sep 17 00:00:00 2001 From: leavez Date: Sat, 5 Apr 2025 22:46:48 +0800 Subject: [PATCH 04/22] refactor old client to provide compibility --- client/sse.go | 568 ++-------------------------------------- client/sse_test.go | 2 +- client/stdio.go | 442 ++----------------------------- client/transport/sse.go | 5 + client/types.go | 8 - 5 files changed, 42 insertions(+), 983 deletions(-) delete mode 100644 client/types.go diff --git a/client/sse.go b/client/sse.go index cf4a1028e..c3d978723 100644 --- a/client/sse.go +++ b/client/sse.go @@ -1,588 +1,60 @@ package client import ( - "bufio" - "bytes" - "context" - "encoding/json" - "errors" "fmt" - "io" - "net/http" "net/url" - "strings" - "sync" - "sync/atomic" "time" - "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/client/transport" ) // SSEMCPClient implements the MCPClient interface using Server-Sent Events (SSE). // It maintains a persistent HTTP connection to receive server-pushed events // while sending requests over regular HTTP POST calls. The client handles // automatic reconnection and message routing between requests and responses. +// +// Deprecated: Use Client instead. type SSEMCPClient struct { - baseURL *url.URL - endpoint *url.URL - httpClient *http.Client - requestID atomic.Int64 - responses map[int64]chan RPCResponse - mu sync.RWMutex - done chan struct{} - initialized bool - notifications []func(mcp.JSONRPCNotification) - notifyMu sync.RWMutex - endpointChan chan struct{} - capabilities mcp.ServerCapabilities - headers map[string]string - sseReadTimeout time.Duration + Client } -type ClientOption func(*SSEMCPClient) +type ClientOption = transport.ClientOption func WithHeaders(headers map[string]string) ClientOption { - return func(sc *SSEMCPClient) { - sc.headers = headers - } + return transport.WithHeaders(headers) } func WithSSEReadTimeout(timeout time.Duration) ClientOption { - return func(sc *SSEMCPClient) { - sc.sseReadTimeout = timeout - } + return transport.WithSSEReadTimeout(timeout) } // NewSSEMCPClient creates a new SSE-based MCP client with the given base URL. // Returns an error if the URL is invalid. +// +// Deprecated: Use NewClient instead. func NewSSEMCPClient(baseURL string, options ...ClientOption) (*SSEMCPClient, error) { - parsedURL, err := url.Parse(baseURL) + + sseTransport, err := transport.NewSSE(baseURL, options...) if err != nil { - return nil, fmt.Errorf("invalid URL: %w", err) + return nil, fmt.Errorf("failed to create SSE transport: %w", err) } smc := &SSEMCPClient{ - baseURL: parsedURL, - httpClient: &http.Client{}, - responses: make(map[int64]chan RPCResponse), - done: make(chan struct{}), - endpointChan: make(chan struct{}), - sseReadTimeout: 30 * time.Second, - headers: make(map[string]string), - } - - for _, opt := range options { - opt(smc) + Client: *NewClient(sseTransport), } return smc, nil } -// Start initiates the SSE connection to the server and waits for the endpoint information. -// Returns an error if the connection fails or times out waiting for the endpoint. -func (c *SSEMCPClient) Start(ctx context.Context) error { - - req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL.String(), nil) - - if err != nil { - - return fmt.Errorf("failed to create request: %w", err) - - } - - req.Header.Set("Accept", "text/event-stream") - req.Header.Set("Cache-Control", "no-cache") - req.Header.Set("Connection", "keep-alive") - - resp, err := c.httpClient.Do(req) - if err != nil { - return fmt.Errorf("failed to connect to SSE stream: %w", err) - } - - if resp.StatusCode != http.StatusOK { - resp.Body.Close() - return fmt.Errorf("unexpected status code: %d", resp.StatusCode) - } - - go c.readSSE(resp.Body) - - // Wait for the endpoint to be received - - select { - case <-c.endpointChan: - // Endpoint received, proceed - case <-ctx.Done(): - return fmt.Errorf("context cancelled while waiting for endpoint") - case <-time.After(30 * time.Second): // Add a timeout - return fmt.Errorf("timeout waiting for endpoint") - } - - return nil -} - -// readSSE continuously reads the SSE stream and processes events. -// It runs until the connection is closed or an error occurs. -func (c *SSEMCPClient) readSSE(reader io.ReadCloser) { - defer reader.Close() - - br := bufio.NewReader(reader) - var event, data string - - ctx, cancel := context.WithTimeout(context.Background(), c.sseReadTimeout) - defer cancel() - - for { - select { - case <-ctx.Done(): - return - default: - line, err := br.ReadString('\n') - if err != nil { - if err == io.EOF { - // Process any pending event before exit - if event != "" && data != "" { - c.handleSSEEvent(event, data) - } - break - } - select { - case <-c.done: - return - default: - fmt.Printf("SSE stream error: %v\n", err) - return - } - } - - // Remove only newline markers - line = strings.TrimRight(line, "\r\n") - if line == "" { - // Empty line means end of event - if event != "" && data != "" { - c.handleSSEEvent(event, data) - event = "" - data = "" - } - continue - } - - if strings.HasPrefix(line, "event:") { - event = strings.TrimSpace(strings.TrimPrefix(line, "event:")) - } else if strings.HasPrefix(line, "data:") { - data = strings.TrimSpace(strings.TrimPrefix(line, "data:")) - } - } - } -} - -// handleSSEEvent processes SSE events based on their type. -// Handles 'endpoint' events for connection setup and 'message' events for JSON-RPC communication. -func (c *SSEMCPClient) handleSSEEvent(event, data string) { - switch event { - case "endpoint": - endpoint, err := c.baseURL.Parse(data) - if err != nil { - fmt.Printf("Error parsing endpoint URL: %v\n", err) - return - } - if endpoint.Host != c.baseURL.Host { - fmt.Printf("Endpoint origin does not match connection origin\n") - return - } - c.endpoint = endpoint - close(c.endpointChan) - - case "message": - var baseMessage struct { - JSONRPC string `json:"jsonrpc"` - ID *int64 `json:"id,omitempty"` - Method string `json:"method,omitempty"` - Result json.RawMessage `json:"result,omitempty"` - Error *struct { - Code int `json:"code"` - Message string `json:"message"` - } `json:"error,omitempty"` - } - - if err := json.Unmarshal([]byte(data), &baseMessage); err != nil { - fmt.Printf("Error unmarshaling message: %v\n", err) - return - } - - // Handle notification - if baseMessage.ID == nil { - var notification mcp.JSONRPCNotification - if err := json.Unmarshal([]byte(data), ¬ification); err != nil { - return - } - c.notifyMu.RLock() - for _, handler := range c.notifications { - handler(notification) - } - c.notifyMu.RUnlock() - return - } - - c.mu.RLock() - ch, ok := c.responses[*baseMessage.ID] - c.mu.RUnlock() - - if ok { - if baseMessage.Error != nil { - ch <- RPCResponse{ - Error: &baseMessage.Error.Message, - } - } else { - ch <- RPCResponse{ - Response: &baseMessage.Result, - } - } - c.mu.Lock() - delete(c.responses, *baseMessage.ID) - c.mu.Unlock() - } - } -} - -// OnNotification registers a handler function to be called when notifications are received. -// Multiple handlers can be registered and will be called in the order they were added. -func (c *SSEMCPClient) OnNotification( - handler func(notification mcp.JSONRPCNotification), -) { - c.notifyMu.Lock() - defer c.notifyMu.Unlock() - c.notifications = append(c.notifications, handler) -} - -// sendRequest sends a JSON-RPC request to the server and waits for a response. -// Returns the raw JSON response message or an error if the request fails. -func (c *SSEMCPClient) sendRequest( - ctx context.Context, - method string, - params interface{}, -) (*json.RawMessage, error) { - if !c.initialized && method != "initialize" { - return nil, fmt.Errorf("client not initialized") - } - - if c.endpoint == nil { - return nil, fmt.Errorf("endpoint not received") - } - - id := c.requestID.Add(1) - - request := mcp.JSONRPCRequest{ - JSONRPC: mcp.JSONRPC_VERSION, - ID: id, - Request: mcp.Request{ - Method: method, - }, - Params: params, - } - - requestBytes, err := json.Marshal(request) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - responseChan := make(chan RPCResponse, 1) - c.mu.Lock() - c.responses[id] = responseChan - c.mu.Unlock() - - req, err := http.NewRequestWithContext( - ctx, - "POST", - c.endpoint.String(), - bytes.NewReader(requestBytes), - ) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - // set custom http headers - for k, v := range c.headers { - req.Header.Set(k, v) - } - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK && - resp.StatusCode != http.StatusAccepted { - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf( - "request failed with status %d: %s", - resp.StatusCode, - body, - ) - } - - select { - case <-ctx.Done(): - c.mu.Lock() - delete(c.responses, id) - c.mu.Unlock() - return nil, ctx.Err() - case response := <-responseChan: - if response.Error != nil { - return nil, errors.New(*response.Error) - } - return response.Response, nil - } +func (c *SSEMCPClient) GetBaseUrl() *url.URL { + t := c.GetTransport() + sse := t.(*transport.SSE) + return sse.GetBaseURL() } -func (c *SSEMCPClient) Initialize( - ctx context.Context, - request mcp.InitializeRequest, -) (*mcp.InitializeResult, error) { - // Ensure we send a params object with all required fields - params := struct { - ProtocolVersion string `json:"protocolVersion"` - ClientInfo mcp.Implementation `json:"clientInfo"` - Capabilities mcp.ClientCapabilities `json:"capabilities"` - }{ - ProtocolVersion: request.Params.ProtocolVersion, - ClientInfo: request.Params.ClientInfo, - Capabilities: request.Params.Capabilities, // Will be empty struct if not set - } - - response, err := c.sendRequest(ctx, "initialize", params) - if err != nil { - return nil, err - } - - var result mcp.InitializeResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - // Store capabilities - c.capabilities = result.Capabilities - - // Send initialized notification - notification := mcp.JSONRPCNotification{ - JSONRPC: mcp.JSONRPC_VERSION, - Notification: mcp.Notification{ - Method: "notifications/initialized", - }, - } - - notificationBytes, err := json.Marshal(notification) - if err != nil { - return nil, fmt.Errorf( - "failed to marshal initialized notification: %w", - err, - ) - } - - req, err := http.NewRequestWithContext( - ctx, - "POST", - c.endpoint.String(), - bytes.NewReader(notificationBytes), - ) - if err != nil { - return nil, fmt.Errorf("failed to create notification request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf( - "failed to send initialized notification: %w", - err, - ) - } - resp.Body.Close() - - c.initialized = true - return &result, nil -} - -func (c *SSEMCPClient) Ping(ctx context.Context) error { - _, err := c.sendRequest(ctx, "ping", nil) - return err -} - -func (c *SSEMCPClient) ListResources( - ctx context.Context, - request mcp.ListResourcesRequest, -) (*mcp.ListResourcesResult, error) { - response, err := c.sendRequest(ctx, "resources/list", request.Params) - if err != nil { - return nil, err - } - - var result mcp.ListResourcesResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -func (c *SSEMCPClient) ListResourceTemplates( - ctx context.Context, - request mcp.ListResourceTemplatesRequest, -) (*mcp.ListResourceTemplatesResult, error) { - response, err := c.sendRequest( - ctx, - "resources/templates/list", - request.Params, - ) - if err != nil { - return nil, err - } - - var result mcp.ListResourceTemplatesResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -func (c *SSEMCPClient) ReadResource( - ctx context.Context, - request mcp.ReadResourceRequest, -) (*mcp.ReadResourceResult, error) { - response, err := c.sendRequest(ctx, "resources/read", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseReadResourceResult(response) -} - -func (c *SSEMCPClient) Subscribe( - ctx context.Context, - request mcp.SubscribeRequest, -) error { - _, err := c.sendRequest(ctx, "resources/subscribe", request.Params) - return err -} - -func (c *SSEMCPClient) Unsubscribe( - ctx context.Context, - request mcp.UnsubscribeRequest, -) error { - _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params) - return err -} - -func (c *SSEMCPClient) ListPrompts( - ctx context.Context, - request mcp.ListPromptsRequest, -) (*mcp.ListPromptsResult, error) { - response, err := c.sendRequest(ctx, "prompts/list", request.Params) - if err != nil { - return nil, err - } - - var result mcp.ListPromptsResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -func (c *SSEMCPClient) GetPrompt( - ctx context.Context, - request mcp.GetPromptRequest, -) (*mcp.GetPromptResult, error) { - response, err := c.sendRequest(ctx, "prompts/get", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseGetPromptResult(response) -} - -func (c *SSEMCPClient) ListTools( - ctx context.Context, - request mcp.ListToolsRequest, -) (*mcp.ListToolsResult, error) { - response, err := c.sendRequest(ctx, "tools/list", request.Params) - if err != nil { - return nil, err - } - - var result mcp.ListToolsResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -func (c *SSEMCPClient) CallTool( - ctx context.Context, - request mcp.CallToolRequest, -) (*mcp.CallToolResult, error) { - response, err := c.sendRequest(ctx, "tools/call", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseCallToolResult(response) -} - -func (c *SSEMCPClient) SetLevel( - ctx context.Context, - request mcp.SetLevelRequest, -) error { - _, err := c.sendRequest(ctx, "logging/setLevel", request.Params) - return err -} - -func (c *SSEMCPClient) Complete( - ctx context.Context, - request mcp.CompleteRequest, -) (*mcp.CompleteResult, error) { - response, err := c.sendRequest(ctx, "completion/complete", request.Params) - if err != nil { - return nil, err - } - - var result mcp.CompleteResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -// Helper methods - // GetEndpoint returns the current endpoint URL for the SSE connection. func (c *SSEMCPClient) GetEndpoint() *url.URL { - return c.endpoint -} - -// Close shuts down the SSE client connection and cleans up any pending responses. -// Returns an error if the shutdown process fails. -func (c *SSEMCPClient) Close() error { - select { - case <-c.done: - return nil // Already closed - default: - close(c.done) - } - - // Clean up any pending responses - c.mu.Lock() - for _, ch := range c.responses { - close(ch) - } - c.responses = make(map[int64]chan RPCResponse) - c.mu.Unlock() - - return nil + t := c.GetTransport() + sse := t.(*transport.SSE) + return sse.GetEndpoint() } diff --git a/client/sse_test.go b/client/sse_test.go index 366fbc517..53e3f4ea1 100644 --- a/client/sse_test.go +++ b/client/sse_test.go @@ -46,7 +46,7 @@ func TestSSEMCPClient(t *testing.T) { } defer client.Close() - if client.baseURL == nil { + if client.GetBaseUrl() == nil { t.Error("Base URL should not be nil") } }) diff --git a/client/stdio.go b/client/stdio.go index 8e0845dca..b17c4b75d 100644 --- a/client/stdio.go +++ b/client/stdio.go @@ -1,457 +1,47 @@ package client import ( - "bufio" "context" - "encoding/json" - "errors" - "fmt" "io" - "os" - "os/exec" - "sync" - "sync/atomic" - "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/client/transport" ) // StdioMCPClient implements the MCPClient interface using stdio communication. // It launches a subprocess and communicates with it via standard input/output streams // using JSON-RPC messages. The client handles message routing between requests and // responses, and supports asynchronous notifications. +// +// Deprecated: Use Client instead. type StdioMCPClient struct { - cmd *exec.Cmd - stdin io.WriteCloser - stdout *bufio.Reader - stderr io.ReadCloser - requestID atomic.Int64 - responses map[int64]chan RPCResponse - mu sync.RWMutex - done chan struct{} - initialized bool - notifications []func(mcp.JSONRPCNotification) - notifyMu sync.RWMutex - capabilities mcp.ServerCapabilities + Client } // NewStdioMCPClient creates a new stdio-based MCP client that communicates with a subprocess. // It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. // Returns an error if the subprocess cannot be started or the pipes cannot be created. +// +// NewStdioMCPClient will start the connection automatically. Don't call the Start method manually. +// +// Deprecated: Use NewClient instead. func NewStdioMCPClient( command string, env []string, args ...string, ) (*StdioMCPClient, error) { - cmd := exec.Command(command, args...) - mergedEnv := os.Environ() - mergedEnv = append(mergedEnv, env...) + stdioTransport := transport.NewStdio(command, env, args...) + stdioTransport.Start(context.Background()) - cmd.Env = mergedEnv - - stdin, err := cmd.StdinPipe() - if err != nil { - return nil, fmt.Errorf("failed to create stdin pipe: %w", err) - } - - stdout, err := cmd.StdoutPipe() - if err != nil { - return nil, fmt.Errorf("failed to create stdout pipe: %w", err) - } - - stderr, err := cmd.StderrPipe() - if err != nil { - return nil, fmt.Errorf("failed to create stderr pipe: %w", err) - } - - client := &StdioMCPClient{ - cmd: cmd, - stdin: stdin, - stderr: stderr, - stdout: bufio.NewReader(stdout), - responses: make(map[int64]chan RPCResponse), - done: make(chan struct{}), - } - - if err := cmd.Start(); err != nil { - return nil, fmt.Errorf("failed to start command: %w", err) - } - - // Start reading responses in a goroutine and wait for it to be ready - ready := make(chan struct{}) - go func() { - close(ready) - client.readResponses() - }() - <-ready - - return client, nil -} - -// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit. -// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate. -func (c *StdioMCPClient) Close() error { - close(c.done) - if err := c.stdin.Close(); err != nil { - return fmt.Errorf("failed to close stdin: %w", err) - } - if err := c.stderr.Close(); err != nil { - return fmt.Errorf("failed to close stderr: %w", err) - } - return c.cmd.Wait() + return &StdioMCPClient{ + Client: *NewClient(stdioTransport), + }, nil } // Stderr returns a reader for the stderr output of the subprocess. // This can be used to capture error messages or logs from the subprocess. func (c *StdioMCPClient) Stderr() io.Reader { - return c.stderr -} - -// OnNotification registers a handler function to be called when notifications are received. -// Multiple handlers can be registered and will be called in the order they were added. -func (c *StdioMCPClient) OnNotification( - handler func(notification mcp.JSONRPCNotification), -) { - c.notifyMu.Lock() - defer c.notifyMu.Unlock() - c.notifications = append(c.notifications, handler) -} - -// readResponses continuously reads and processes responses from the server's stdout. -// It handles both responses to requests and notifications, routing them appropriately. -// Runs until the done channel is closed or an error occurs reading from stdout. -func (c *StdioMCPClient) readResponses() { - for { - select { - case <-c.done: - return - default: - line, err := c.stdout.ReadString('\n') - if err != nil { - if err != io.EOF { - fmt.Printf("Error reading response: %v\n", err) - } - return - } - - var baseMessage struct { - JSONRPC string `json:"jsonrpc"` - ID *int64 `json:"id,omitempty"` - Method string `json:"method,omitempty"` - Result json.RawMessage `json:"result,omitempty"` - Error *struct { - Code int `json:"code"` - Message string `json:"message"` - } `json:"error,omitempty"` - } - - if err := json.Unmarshal([]byte(line), &baseMessage); err != nil { - continue - } - - // Handle notification - if baseMessage.ID == nil { - var notification mcp.JSONRPCNotification - if err := json.Unmarshal([]byte(line), ¬ification); err != nil { - continue - } - c.notifyMu.RLock() - for _, handler := range c.notifications { - handler(notification) - } - c.notifyMu.RUnlock() - continue - } - - c.mu.RLock() - ch, ok := c.responses[*baseMessage.ID] - c.mu.RUnlock() - - if ok { - if baseMessage.Error != nil { - ch <- RPCResponse{ - Error: &baseMessage.Error.Message, - } - } else { - ch <- RPCResponse{ - Response: &baseMessage.Result, - } - } - c.mu.Lock() - delete(c.responses, *baseMessage.ID) - c.mu.Unlock() - } - } - } -} - -// sendRequest sends a JSON-RPC request to the server and waits for a response. -// It creates a unique request ID, sends the request over stdin, and waits for -// the corresponding response or context cancellation. -// Returns the raw JSON response message or an error if the request fails. -func (c *StdioMCPClient) sendRequest( - ctx context.Context, - method string, - params interface{}, -) (*json.RawMessage, error) { - if !c.initialized && method != "initialize" { - return nil, fmt.Errorf("client not initialized") - } - - id := c.requestID.Add(1) - - // Create the complete request structure - request := mcp.JSONRPCRequest{ - JSONRPC: mcp.JSONRPC_VERSION, - ID: id, - Request: mcp.Request{ - Method: method, - }, - Params: params, - } - - responseChan := make(chan RPCResponse, 1) - c.mu.Lock() - c.responses[id] = responseChan - c.mu.Unlock() - - requestBytes, err := json.Marshal(request) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - requestBytes = append(requestBytes, '\n') - - if _, err := c.stdin.Write(requestBytes); err != nil { - return nil, fmt.Errorf("failed to write request: %w", err) - } - - select { - case <-ctx.Done(): - c.mu.Lock() - delete(c.responses, id) - c.mu.Unlock() - return nil, ctx.Err() - case response := <-responseChan: - if response.Error != nil { - return nil, errors.New(*response.Error) - } - return response.Response, nil - } -} - -func (c *StdioMCPClient) Ping(ctx context.Context) error { - _, err := c.sendRequest(ctx, "ping", nil) - return err -} - -func (c *StdioMCPClient) Initialize( - ctx context.Context, - request mcp.InitializeRequest, -) (*mcp.InitializeResult, error) { - // This structure ensures Capabilities is always included in JSON - params := struct { - ProtocolVersion string `json:"protocolVersion"` - ClientInfo mcp.Implementation `json:"clientInfo"` - Capabilities mcp.ClientCapabilities `json:"capabilities"` - }{ - ProtocolVersion: request.Params.ProtocolVersion, - ClientInfo: request.Params.ClientInfo, - Capabilities: request.Params.Capabilities, // Will be empty struct if not set - } - - response, err := c.sendRequest(ctx, "initialize", params) - if err != nil { - return nil, err - } - - var result mcp.InitializeResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - // Store capabilities - c.capabilities = result.Capabilities - - // Send initialized notification - notification := mcp.JSONRPCNotification{ - JSONRPC: mcp.JSONRPC_VERSION, - Notification: mcp.Notification{ - Method: "notifications/initialized", - }, - } - - notificationBytes, err := json.Marshal(notification) - if err != nil { - return nil, fmt.Errorf( - "failed to marshal initialized notification: %w", - err, - ) - } - notificationBytes = append(notificationBytes, '\n') - - if _, err := c.stdin.Write(notificationBytes); err != nil { - return nil, fmt.Errorf( - "failed to send initialized notification: %w", - err, - ) - } - - c.initialized = true - return &result, nil -} - -func (c *StdioMCPClient) ListResources( - ctx context.Context, - request mcp.ListResourcesRequest, -) (*mcp. - ListResourcesResult, error) { - response, err := c.sendRequest( - ctx, - "resources/list", - request.Params, - ) - if err != nil { - return nil, err - } - - var result mcp.ListResourcesResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -func (c *StdioMCPClient) ListResourceTemplates( - ctx context.Context, - request mcp.ListResourceTemplatesRequest, -) (*mcp. - ListResourceTemplatesResult, error) { - response, err := c.sendRequest( - ctx, - "resources/templates/list", - request.Params, - ) - if err != nil { - return nil, err - } - - var result mcp.ListResourceTemplatesResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -func (c *StdioMCPClient) ReadResource( - ctx context.Context, - request mcp.ReadResourceRequest, -) (*mcp.ReadResourceResult, - error) { - response, err := c.sendRequest(ctx, "resources/read", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseReadResourceResult(response) -} - -func (c *StdioMCPClient) Subscribe( - ctx context.Context, - request mcp.SubscribeRequest, -) error { - _, err := c.sendRequest(ctx, "resources/subscribe", request.Params) - return err -} - -func (c *StdioMCPClient) Unsubscribe( - ctx context.Context, - request mcp.UnsubscribeRequest, -) error { - _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params) - return err -} - -func (c *StdioMCPClient) ListPrompts( - ctx context.Context, - request mcp.ListPromptsRequest, -) (*mcp.ListPromptsResult, error) { - response, err := c.sendRequest(ctx, "prompts/list", request.Params) - if err != nil { - return nil, err - } - - var result mcp.ListPromptsResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -func (c *StdioMCPClient) GetPrompt( - ctx context.Context, - request mcp.GetPromptRequest, -) (*mcp.GetPromptResult, error) { - response, err := c.sendRequest(ctx, "prompts/get", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseGetPromptResult(response) -} - -func (c *StdioMCPClient) ListTools( - ctx context.Context, - request mcp.ListToolsRequest, -) (*mcp.ListToolsResult, error) { - response, err := c.sendRequest(ctx, "tools/list", request.Params) - if err != nil { - return nil, err - } - - var result mcp.ListToolsResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -func (c *StdioMCPClient) CallTool( - ctx context.Context, - request mcp.CallToolRequest, -) (*mcp.CallToolResult, error) { - response, err := c.sendRequest(ctx, "tools/call", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseCallToolResult(response) -} - -func (c *StdioMCPClient) SetLevel( - ctx context.Context, - request mcp.SetLevelRequest, -) error { - _, err := c.sendRequest(ctx, "logging/setLevel", request.Params) - return err -} - -func (c *StdioMCPClient) Complete( - ctx context.Context, - request mcp.CompleteRequest, -) (*mcp.CompleteResult, error) { - response, err := c.sendRequest(ctx, "completion/complete", request.Params) - if err != nil { - return nil, err - } - - var result mcp.CompleteResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil + t := c.GetTransport() + stdio := t.(*transport.Stdio) + return stdio.Stderr() } diff --git a/client/transport/sse.go b/client/transport/sse.go index 3eafb33eb..43c9e9fb0 100644 --- a/client/transport/sse.go +++ b/client/transport/sse.go @@ -361,3 +361,8 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti func (c *SSE) GetEndpoint() *url.URL { return c.endpoint } + +// GetBaseURL returns the base URL set in the SSE constructor. +func (c *SSE) GetBaseURL() *url.URL { + return c.baseURL +} diff --git a/client/types.go b/client/types.go deleted file mode 100644 index 4402bd024..000000000 --- a/client/types.go +++ /dev/null @@ -1,8 +0,0 @@ -package client - -import "encoding/json" - -type RPCResponse struct { - Error *string - Response *json.RawMessage -} From 3ace207698b8a994869da688d3c599e7c0c3d74a Mon Sep 17 00:00:00 2001 From: leavez Date: Sat, 5 Apr 2025 22:47:44 +0800 Subject: [PATCH 05/22] rename --- client/client.go | 380 +++++++++++++++++++++++++++++++++++--------- client/impl.go | 314 ------------------------------------ client/interface.go | 84 ++++++++++ 3 files changed, 389 insertions(+), 389 deletions(-) delete mode 100644 client/impl.go create mode 100644 client/interface.go diff --git a/client/client.go b/client/client.go index 1d3cb1051..8b9607d8c 100644 --- a/client/client.go +++ b/client/client.go @@ -1,84 +1,314 @@ -// Package client provides MCP (Model Control Protocol) client implementations. package client import ( "context" + "encoding/json" + "errors" + "fmt" + "sync" + "sync/atomic" + "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" ) -// MCPClient represents an MCP client interface -type MCPClient interface { - // Initialize sends the initial connection request to the server - Initialize( - ctx context.Context, - request mcp.InitializeRequest, - ) (*mcp.InitializeResult, error) - - // Ping checks if the server is alive - Ping(ctx context.Context) error - - // ListResources requests a list of available resources from the server - ListResources( - ctx context.Context, - request mcp.ListResourcesRequest, - ) (*mcp.ListResourcesResult, error) - - // ListResourceTemplates requests a list of available resource templates from the server - ListResourceTemplates( - ctx context.Context, - request mcp.ListResourceTemplatesRequest, - ) (*mcp.ListResourceTemplatesResult, - error) - - // ReadResource reads a specific resource from the server - ReadResource( - ctx context.Context, - request mcp.ReadResourceRequest, - ) (*mcp.ReadResourceResult, error) - - // Subscribe requests notifications for changes to a specific resource - Subscribe(ctx context.Context, request mcp.SubscribeRequest) error - - // Unsubscribe cancels notifications for a specific resource - Unsubscribe(ctx context.Context, request mcp.UnsubscribeRequest) error - - // ListPrompts requests a list of available prompts from the server - ListPrompts( - ctx context.Context, - request mcp.ListPromptsRequest, - ) (*mcp.ListPromptsResult, error) - - // GetPrompt retrieves a specific prompt from the server - GetPrompt( - ctx context.Context, - request mcp.GetPromptRequest, - ) (*mcp.GetPromptResult, error) - - // ListTools requests a list of available tools from the server - ListTools( - ctx context.Context, - request mcp.ListToolsRequest, - ) (*mcp.ListToolsResult, error) - - // CallTool invokes a specific tool on the server - CallTool( - ctx context.Context, - request mcp.CallToolRequest, - ) (*mcp.CallToolResult, error) - - // SetLevel sets the logging level for the server - SetLevel(ctx context.Context, request mcp.SetLevelRequest) error - - // Complete requests completion options for a given argument - Complete( - ctx context.Context, - request mcp.CompleteRequest, - ) (*mcp.CompleteResult, error) - - // Close client connection and cleanup resources - Close() error - - // OnNotification registers a handler for notifications - OnNotification(handler func(notification mcp.JSONRPCNotification)) +type Client struct { + transport transport.Interface + + initialized bool + notifications []func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex + requestID atomic.Int64 + capabilities mcp.ServerCapabilities +} + +// NewClient creates a new MCP client with the given transport layer. +// Usage: +// +// client, err := NewClient(transport.NewStdio("mcp", nil, "--stdio")) +// if err != nil { +// log.Fatalf("Failed to create client: %v", err) +// } +func NewClient(transport transport.Interface) *Client { + return &Client{ + transport: transport, + } +} + +// Start initiates the transport connection to the server. +// Returns an error if the transport is nil or if the connection fails. +func (c *Client) Start(ctx context.Context) error { + if c.transport == nil { + return fmt.Errorf("transport is nil") + } + err := c.transport.Start(ctx) + if err != nil { + return err + } + + c.transport.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { + c.notifyMu.RLock() + defer c.notifyMu.RUnlock() + for _, handler := range c.notifications { + handler(notification) + } + }) + return nil +} + +// Close shuts down the client and closes the transport. +func (c *Client) Close() error { + return c.transport.Close() +} + +// OnNotification registers a handler function to be called when notifications are received. +// Multiple handlers can be registered and will be called in the order they were added. +func (c *Client) OnNotification( + handler func(notification mcp.JSONRPCNotification), +) { + c.notifyMu.Lock() + defer c.notifyMu.Unlock() + c.notifications = append(c.notifications, handler) +} + +// sendRequest sends a JSON-RPC request to the server and waits for a response. +// Returns the raw JSON response message or an error if the request fails. +func (c *Client) sendRequest( + ctx context.Context, + method string, + params interface{}, +) (*json.RawMessage, error) { + if !c.initialized && method != "initialize" { + return nil, fmt.Errorf("client not initialized") + } + + id := c.requestID.Add(1) + + request := transport.JSONRPCRequest{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Method: method, + Params: params, + } + + response, err := c.transport.SendRequest(ctx, request) + if err != nil { + return nil, fmt.Errorf("transport error: %w", err) + } + + if response.Error != nil { + return nil, errors.New(response.Error.Message) + } + + return &response.Result, nil +} + +func (c *Client) Initialize( + ctx context.Context, + request mcp.InitializeRequest, +) (*mcp.InitializeResult, error) { + // Ensure we send a params object with all required fields + params := struct { + ProtocolVersion string `json:"protocolVersion"` + ClientInfo mcp.Implementation `json:"clientInfo"` + Capabilities mcp.ClientCapabilities `json:"capabilities"` + }{ + ProtocolVersion: request.Params.ProtocolVersion, + ClientInfo: request.Params.ClientInfo, + Capabilities: request.Params.Capabilities, // Will be empty struct if not set + } + + response, err := c.sendRequest(ctx, "initialize", params) + if err != nil { + return nil, err + } + + var result mcp.InitializeResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + // Store capabilities + c.capabilities = result.Capabilities + + // Send initialized notification + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: "notifications/initialized", + }, + } + + err = c.transport.SendNotification(ctx, notification) + if err != nil { + return nil, fmt.Errorf( + "failed to send initialized notification: %w", + err, + ) + } + + c.initialized = true + return &result, nil +} + +func (c *Client) Ping(ctx context.Context) error { + _, err := c.sendRequest(ctx, "ping", nil) + return err +} + +func (c *Client) ListResources( + ctx context.Context, + request mcp.ListResourcesRequest, +) (*mcp.ListResourcesResult, error) { + response, err := c.sendRequest(ctx, "resources/list", request.Params) + if err != nil { + return nil, err + } + + var result mcp.ListResourcesResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +func (c *Client) ListResourceTemplates( + ctx context.Context, + request mcp.ListResourceTemplatesRequest, +) (*mcp.ListResourceTemplatesResult, error) { + response, err := c.sendRequest( + ctx, + "resources/templates/list", + request.Params, + ) + if err != nil { + return nil, err + } + + var result mcp.ListResourceTemplatesResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +func (c *Client) ReadResource( + ctx context.Context, + request mcp.ReadResourceRequest, +) (*mcp.ReadResourceResult, error) { + response, err := c.sendRequest(ctx, "resources/read", request.Params) + if err != nil { + return nil, err + } + + return mcp.ParseReadResourceResult(response) +} + +func (c *Client) Subscribe( + ctx context.Context, + request mcp.SubscribeRequest, +) error { + _, err := c.sendRequest(ctx, "resources/subscribe", request.Params) + return err +} + +func (c *Client) Unsubscribe( + ctx context.Context, + request mcp.UnsubscribeRequest, +) error { + _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params) + return err +} + +func (c *Client) ListPrompts( + ctx context.Context, + request mcp.ListPromptsRequest, +) (*mcp.ListPromptsResult, error) { + response, err := c.sendRequest(ctx, "prompts/list", request.Params) + if err != nil { + return nil, err + } + + var result mcp.ListPromptsResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +func (c *Client) GetPrompt( + ctx context.Context, + request mcp.GetPromptRequest, +) (*mcp.GetPromptResult, error) { + response, err := c.sendRequest(ctx, "prompts/get", request.Params) + if err != nil { + return nil, err + } + + return mcp.ParseGetPromptResult(response) +} + +func (c *Client) ListTools( + ctx context.Context, + request mcp.ListToolsRequest, +) (*mcp.ListToolsResult, error) { + response, err := c.sendRequest(ctx, "tools/list", request.Params) + if err != nil { + return nil, err + } + + var result mcp.ListToolsResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +func (c *Client) CallTool( + ctx context.Context, + request mcp.CallToolRequest, +) (*mcp.CallToolResult, error) { + response, err := c.sendRequest(ctx, "tools/call", request.Params) + if err != nil { + return nil, err + } + + return mcp.ParseCallToolResult(response) +} + +func (c *Client) SetLevel( + ctx context.Context, + request mcp.SetLevelRequest, +) error { + _, err := c.sendRequest(ctx, "logging/setLevel", request.Params) + return err +} + +func (c *Client) Complete( + ctx context.Context, + request mcp.CompleteRequest, +) (*mcp.CompleteResult, error) { + response, err := c.sendRequest(ctx, "completion/complete", request.Params) + if err != nil { + return nil, err + } + + var result mcp.CompleteResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +// Helper methods + +// GetTransport gives access to the underlying transport layer. +// Cast it to the specific transport type and obtain the other helper methods. +func (c *Client) GetTransport() transport.Interface { + return c.transport } diff --git a/client/impl.go b/client/impl.go deleted file mode 100644 index 8b9607d8c..000000000 --- a/client/impl.go +++ /dev/null @@ -1,314 +0,0 @@ -package client - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "sync" - "sync/atomic" - - "github.com/mark3labs/mcp-go/client/transport" - "github.com/mark3labs/mcp-go/mcp" -) - -type Client struct { - transport transport.Interface - - initialized bool - notifications []func(mcp.JSONRPCNotification) - notifyMu sync.RWMutex - requestID atomic.Int64 - capabilities mcp.ServerCapabilities -} - -// NewClient creates a new MCP client with the given transport layer. -// Usage: -// -// client, err := NewClient(transport.NewStdio("mcp", nil, "--stdio")) -// if err != nil { -// log.Fatalf("Failed to create client: %v", err) -// } -func NewClient(transport transport.Interface) *Client { - return &Client{ - transport: transport, - } -} - -// Start initiates the transport connection to the server. -// Returns an error if the transport is nil or if the connection fails. -func (c *Client) Start(ctx context.Context) error { - if c.transport == nil { - return fmt.Errorf("transport is nil") - } - err := c.transport.Start(ctx) - if err != nil { - return err - } - - c.transport.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { - c.notifyMu.RLock() - defer c.notifyMu.RUnlock() - for _, handler := range c.notifications { - handler(notification) - } - }) - return nil -} - -// Close shuts down the client and closes the transport. -func (c *Client) Close() error { - return c.transport.Close() -} - -// OnNotification registers a handler function to be called when notifications are received. -// Multiple handlers can be registered and will be called in the order they were added. -func (c *Client) OnNotification( - handler func(notification mcp.JSONRPCNotification), -) { - c.notifyMu.Lock() - defer c.notifyMu.Unlock() - c.notifications = append(c.notifications, handler) -} - -// sendRequest sends a JSON-RPC request to the server and waits for a response. -// Returns the raw JSON response message or an error if the request fails. -func (c *Client) sendRequest( - ctx context.Context, - method string, - params interface{}, -) (*json.RawMessage, error) { - if !c.initialized && method != "initialize" { - return nil, fmt.Errorf("client not initialized") - } - - id := c.requestID.Add(1) - - request := transport.JSONRPCRequest{ - JSONRPC: mcp.JSONRPC_VERSION, - ID: id, - Method: method, - Params: params, - } - - response, err := c.transport.SendRequest(ctx, request) - if err != nil { - return nil, fmt.Errorf("transport error: %w", err) - } - - if response.Error != nil { - return nil, errors.New(response.Error.Message) - } - - return &response.Result, nil -} - -func (c *Client) Initialize( - ctx context.Context, - request mcp.InitializeRequest, -) (*mcp.InitializeResult, error) { - // Ensure we send a params object with all required fields - params := struct { - ProtocolVersion string `json:"protocolVersion"` - ClientInfo mcp.Implementation `json:"clientInfo"` - Capabilities mcp.ClientCapabilities `json:"capabilities"` - }{ - ProtocolVersion: request.Params.ProtocolVersion, - ClientInfo: request.Params.ClientInfo, - Capabilities: request.Params.Capabilities, // Will be empty struct if not set - } - - response, err := c.sendRequest(ctx, "initialize", params) - if err != nil { - return nil, err - } - - var result mcp.InitializeResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - // Store capabilities - c.capabilities = result.Capabilities - - // Send initialized notification - notification := mcp.JSONRPCNotification{ - JSONRPC: mcp.JSONRPC_VERSION, - Notification: mcp.Notification{ - Method: "notifications/initialized", - }, - } - - err = c.transport.SendNotification(ctx, notification) - if err != nil { - return nil, fmt.Errorf( - "failed to send initialized notification: %w", - err, - ) - } - - c.initialized = true - return &result, nil -} - -func (c *Client) Ping(ctx context.Context) error { - _, err := c.sendRequest(ctx, "ping", nil) - return err -} - -func (c *Client) ListResources( - ctx context.Context, - request mcp.ListResourcesRequest, -) (*mcp.ListResourcesResult, error) { - response, err := c.sendRequest(ctx, "resources/list", request.Params) - if err != nil { - return nil, err - } - - var result mcp.ListResourcesResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -func (c *Client) ListResourceTemplates( - ctx context.Context, - request mcp.ListResourceTemplatesRequest, -) (*mcp.ListResourceTemplatesResult, error) { - response, err := c.sendRequest( - ctx, - "resources/templates/list", - request.Params, - ) - if err != nil { - return nil, err - } - - var result mcp.ListResourceTemplatesResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -func (c *Client) ReadResource( - ctx context.Context, - request mcp.ReadResourceRequest, -) (*mcp.ReadResourceResult, error) { - response, err := c.sendRequest(ctx, "resources/read", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseReadResourceResult(response) -} - -func (c *Client) Subscribe( - ctx context.Context, - request mcp.SubscribeRequest, -) error { - _, err := c.sendRequest(ctx, "resources/subscribe", request.Params) - return err -} - -func (c *Client) Unsubscribe( - ctx context.Context, - request mcp.UnsubscribeRequest, -) error { - _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params) - return err -} - -func (c *Client) ListPrompts( - ctx context.Context, - request mcp.ListPromptsRequest, -) (*mcp.ListPromptsResult, error) { - response, err := c.sendRequest(ctx, "prompts/list", request.Params) - if err != nil { - return nil, err - } - - var result mcp.ListPromptsResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -func (c *Client) GetPrompt( - ctx context.Context, - request mcp.GetPromptRequest, -) (*mcp.GetPromptResult, error) { - response, err := c.sendRequest(ctx, "prompts/get", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseGetPromptResult(response) -} - -func (c *Client) ListTools( - ctx context.Context, - request mcp.ListToolsRequest, -) (*mcp.ListToolsResult, error) { - response, err := c.sendRequest(ctx, "tools/list", request.Params) - if err != nil { - return nil, err - } - - var result mcp.ListToolsResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -func (c *Client) CallTool( - ctx context.Context, - request mcp.CallToolRequest, -) (*mcp.CallToolResult, error) { - response, err := c.sendRequest(ctx, "tools/call", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseCallToolResult(response) -} - -func (c *Client) SetLevel( - ctx context.Context, - request mcp.SetLevelRequest, -) error { - _, err := c.sendRequest(ctx, "logging/setLevel", request.Params) - return err -} - -func (c *Client) Complete( - ctx context.Context, - request mcp.CompleteRequest, -) (*mcp.CompleteResult, error) { - response, err := c.sendRequest(ctx, "completion/complete", request.Params) - if err != nil { - return nil, err - } - - var result mcp.CompleteResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -// Helper methods - -// GetTransport gives access to the underlying transport layer. -// Cast it to the specific transport type and obtain the other helper methods. -func (c *Client) GetTransport() transport.Interface { - return c.transport -} diff --git a/client/interface.go b/client/interface.go new file mode 100644 index 000000000..1d3cb1051 --- /dev/null +++ b/client/interface.go @@ -0,0 +1,84 @@ +// Package client provides MCP (Model Control Protocol) client implementations. +package client + +import ( + "context" + + "github.com/mark3labs/mcp-go/mcp" +) + +// MCPClient represents an MCP client interface +type MCPClient interface { + // Initialize sends the initial connection request to the server + Initialize( + ctx context.Context, + request mcp.InitializeRequest, + ) (*mcp.InitializeResult, error) + + // Ping checks if the server is alive + Ping(ctx context.Context) error + + // ListResources requests a list of available resources from the server + ListResources( + ctx context.Context, + request mcp.ListResourcesRequest, + ) (*mcp.ListResourcesResult, error) + + // ListResourceTemplates requests a list of available resource templates from the server + ListResourceTemplates( + ctx context.Context, + request mcp.ListResourceTemplatesRequest, + ) (*mcp.ListResourceTemplatesResult, + error) + + // ReadResource reads a specific resource from the server + ReadResource( + ctx context.Context, + request mcp.ReadResourceRequest, + ) (*mcp.ReadResourceResult, error) + + // Subscribe requests notifications for changes to a specific resource + Subscribe(ctx context.Context, request mcp.SubscribeRequest) error + + // Unsubscribe cancels notifications for a specific resource + Unsubscribe(ctx context.Context, request mcp.UnsubscribeRequest) error + + // ListPrompts requests a list of available prompts from the server + ListPrompts( + ctx context.Context, + request mcp.ListPromptsRequest, + ) (*mcp.ListPromptsResult, error) + + // GetPrompt retrieves a specific prompt from the server + GetPrompt( + ctx context.Context, + request mcp.GetPromptRequest, + ) (*mcp.GetPromptResult, error) + + // ListTools requests a list of available tools from the server + ListTools( + ctx context.Context, + request mcp.ListToolsRequest, + ) (*mcp.ListToolsResult, error) + + // CallTool invokes a specific tool on the server + CallTool( + ctx context.Context, + request mcp.CallToolRequest, + ) (*mcp.CallToolResult, error) + + // SetLevel sets the logging level for the server + SetLevel(ctx context.Context, request mcp.SetLevelRequest) error + + // Complete requests completion options for a given argument + Complete( + ctx context.Context, + request mcp.CompleteRequest, + ) (*mcp.CompleteResult, error) + + // Close client connection and cleanup resources + Close() error + + // OnNotification registers a handler for notifications + OnNotification(handler func(notification mcp.JSONRPCNotification)) +} From 090a0ec2ae5afd72bbf602200510ab25b2072918 Mon Sep 17 00:00:00 2001 From: leavez Date: Sat, 5 Apr 2025 23:05:26 +0800 Subject: [PATCH 06/22] remove old client types --- client/sse.go | 36 +++++++----------------------------- client/sse_test.go | 4 +++- client/stdio.go | 30 +++++++++++------------------- client/stdio_test.go | 2 +- 4 files changed, 22 insertions(+), 50 deletions(-) diff --git a/client/sse.go b/client/sse.go index c3d978723..0776d8ac2 100644 --- a/client/sse.go +++ b/client/sse.go @@ -8,52 +8,30 @@ import ( "github.com/mark3labs/mcp-go/client/transport" ) -// SSEMCPClient implements the MCPClient interface using Server-Sent Events (SSE). -// It maintains a persistent HTTP connection to receive server-pushed events -// while sending requests over regular HTTP POST calls. The client handles -// automatic reconnection and message routing between requests and responses. -// -// Deprecated: Use Client instead. -type SSEMCPClient struct { - Client -} - -type ClientOption = transport.ClientOption - -func WithHeaders(headers map[string]string) ClientOption { +func WithHeaders(headers map[string]string) transport.ClientOption { return transport.WithHeaders(headers) } -func WithSSEReadTimeout(timeout time.Duration) ClientOption { +func WithSSEReadTimeout(timeout time.Duration) transport.ClientOption { return transport.WithSSEReadTimeout(timeout) } // NewSSEMCPClient creates a new SSE-based MCP client with the given base URL. // Returns an error if the URL is invalid. -// -// Deprecated: Use NewClient instead. -func NewSSEMCPClient(baseURL string, options ...ClientOption) (*SSEMCPClient, error) { +func NewSSEMCPClient(baseURL string, options ...transport.ClientOption) (*Client, error) { sseTransport, err := transport.NewSSE(baseURL, options...) if err != nil { return nil, fmt.Errorf("failed to create SSE transport: %w", err) } - smc := &SSEMCPClient{ - Client: *NewClient(sseTransport), - } - - return smc, nil -} - -func (c *SSEMCPClient) GetBaseUrl() *url.URL { - t := c.GetTransport() - sse := t.(*transport.SSE) - return sse.GetBaseURL() + return NewClient(sseTransport), nil } // GetEndpoint returns the current endpoint URL for the SSE connection. -func (c *SSEMCPClient) GetEndpoint() *url.URL { +// +// Note: This method only works with SSE transport. +func GetEndpoint(c *Client) *url.URL { t := c.GetTransport() sse := t.(*transport.SSE) return sse.GetEndpoint() diff --git a/client/sse_test.go b/client/sse_test.go index 53e3f4ea1..7308d043f 100644 --- a/client/sse_test.go +++ b/client/sse_test.go @@ -2,6 +2,7 @@ package client import ( "context" + "github.com/mark3labs/mcp-go/client/transport" "testing" "time" @@ -46,7 +47,8 @@ func TestSSEMCPClient(t *testing.T) { } defer client.Close() - if client.GetBaseUrl() == nil { + sseTransport := client.GetTransport().(*transport.SSE) + if sseTransport.GetBaseURL() == nil { t.Error("Base URL should not be nil") } }) diff --git a/client/stdio.go b/client/stdio.go index b17c4b75d..0b0fd9fe9 100644 --- a/client/stdio.go +++ b/client/stdio.go @@ -2,45 +2,37 @@ package client import ( "context" + "fmt" "io" "github.com/mark3labs/mcp-go/client/transport" ) -// StdioMCPClient implements the MCPClient interface using stdio communication. -// It launches a subprocess and communicates with it via standard input/output streams -// using JSON-RPC messages. The client handles message routing between requests and -// responses, and supports asynchronous notifications. -// -// Deprecated: Use Client instead. -type StdioMCPClient struct { - Client -} - // NewStdioMCPClient creates a new stdio-based MCP client that communicates with a subprocess. // It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. // Returns an error if the subprocess cannot be started or the pipes cannot be created. // // NewStdioMCPClient will start the connection automatically. Don't call the Start method manually. -// -// Deprecated: Use NewClient instead. func NewStdioMCPClient( command string, env []string, args ...string, -) (*StdioMCPClient, error) { +) (*Client, error) { stdioTransport := transport.NewStdio(command, env, args...) - stdioTransport.Start(context.Background()) + err := stdioTransport.Start(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to start stdio transport: %w", err) + } - return &StdioMCPClient{ - Client: *NewClient(stdioTransport), - }, nil + return NewClient(stdioTransport), nil } -// Stderr returns a reader for the stderr output of the subprocess. +// GetStderr returns a reader for the stderr output of the subprocess. // This can be used to capture error messages or logs from the subprocess. -func (c *StdioMCPClient) Stderr() io.Reader { +// +// Note: This method only works with stdio transport. +func GetStderr(c *Client) io.Reader { t := c.GetTransport() stdio := t.(*transport.Stdio) return stdio.Stderr() diff --git a/client/stdio_test.go b/client/stdio_test.go index df69b46a3..94da0b541 100644 --- a/client/stdio_test.go +++ b/client/stdio_test.go @@ -47,7 +47,7 @@ func TestStdioMCPClient(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - dec := json.NewDecoder(client.Stderr()) + dec := json.NewDecoder(GetStderr(client)) for { var record map[string]any if err := dec.Decode(&record); err != nil { From 07f052dc1fd25fbe65a1014c04520ac19e79f48f Mon Sep 17 00:00:00 2001 From: leavez Date: Sun, 6 Apr 2025 01:25:38 +0800 Subject: [PATCH 07/22] add test for stdio transport --- client/transport/stdio.go | 3 + client/transport/stdio_test.go | 366 +++++++++++++++++++++++++++++++++ testdata/mockstdio_server.go | 28 ++- 3 files changed, 395 insertions(+), 2 deletions(-) create mode 100644 client/transport/stdio_test.go diff --git a/client/transport/stdio.go b/client/transport/stdio.go index 0e29603b3..85a300a15 100644 --- a/client/transport/stdio.go +++ b/client/transport/stdio.go @@ -178,6 +178,9 @@ func (c *Stdio) SendRequest( ctx context.Context, request JSONRPCRequest, ) (*JSONRPCResponse, error) { + if c.stdin == nil { + return nil, fmt.Errorf("stdio client not started") + } // Create the complete request structure responseChan := make(chan *JSONRPCResponse, 1) diff --git a/client/transport/stdio_test.go b/client/transport/stdio_test.go new file mode 100644 index 000000000..e9922a003 --- /dev/null +++ b/client/transport/stdio_test.go @@ -0,0 +1,366 @@ +package transport + +import ( + "context" + "encoding/json" + "fmt" + "os" + "os/exec" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +func compileTestServer(outputPath string) error { + cmd := exec.Command( + "go", + "build", + "-o", + outputPath, + "../../testdata/mockstdio_server.go", + ) + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("compilation failed: %v\nOutput: %s", err, output) + } + return nil +} + +func TestStdio(t *testing.T) { + // Compile mock server + mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + if err := compileTestServer(mockServerPath); err != nil { + t.Fatalf("Failed to compile mock server: %v", err) + } + defer os.Remove(mockServerPath) + + // Create a new Stdio transport + stdio := NewStdio(mockServerPath, nil) + + // Start the transport + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := stdio.Start(ctx) + if err != nil { + t.Fatalf("Failed to start Stdio transport: %v", err) + } + defer stdio.Close() + + t.Run("SendRequest", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // 使用简单的几种基本类型作为参数 + params := map[string]interface{}{ + "string": "hello world", + "array": []interface{}{1, 2, 3}, + } + + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "debug/echo", + Params: params, + } + + // Send the request + response, err := stdio.SendRequest(ctx, request) + if err != nil { + t.Fatalf("SendRequest failed: %v", err) + } + + // Parse the result to verify echo + var result struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params map[string]interface{} `json:"params"` + } + + if err := json.Unmarshal(response.Result, &result); err != nil { + t.Fatalf("Failed to unmarshal result: %v", err) + } + + // Verify response data matches what was sent + if result.JSONRPC != "2.0" { + t.Errorf("Expected JSONRPC value '2.0', got '%s'", result.JSONRPC) + } + if result.ID != 1 { + t.Errorf("Expected ID 1, got %d", result.ID) + } + if result.Method != "debug/echo" { + t.Errorf("Expected method 'debug/echo', got '%s'", result.Method) + } + + if str, ok := result.Params["string"].(string); !ok || str != "hello world" { + t.Errorf("Expected string 'hello world', got %v", result.Params["string"]) + } + + if arr, ok := result.Params["array"].([]interface{}); !ok || len(arr) != 3 { + t.Errorf("Expected array with 3 items, got %v", result.Params["array"]) + } + }) + + t.Run("SendRequestWithTimeout", func(t *testing.T) { + // Create a context that's already canceled + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel the context immediately + + // Prepare a request + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 3, + Method: "debug/echo", + } + + // The request should fail because the context is canceled + _, err := stdio.SendRequest(ctx, request) + if err == nil { + t.Errorf("Expected context canceled error, got nil") + } else if err != context.Canceled { + t.Errorf("Expected context.Canceled error, got: %v", err) + } + }) + + t.Run("SendNotification & NotificationHandler", func(t *testing.T) { + + var wg sync.WaitGroup + notificationChan := make(chan mcp.JSONRPCNotification, 1) + + // Set notification handler + stdio.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { + notificationChan <- notification + }) + + // Send a notification + // This would trigger a notification from the server + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + notification := mcp.JSONRPCNotification{ + JSONRPC: "2.0", + Notification: mcp.Notification{ + Method: "debug/echo_notification", + Params: mcp.NotificationParams{ + AdditionalFields: map[string]interface{}{"test": "value"}, + }, + }, + } + err := stdio.SendNotification(ctx, notification) + if err != nil { + t.Fatalf("SendNotification failed: %v", err) + } + + wg.Add(1) + go func() { + defer wg.Done() + select { + case nt := <-notificationChan: + // We received a notification + responseJson, _ := json.Marshal(nt.Params.AdditionalFields) + requestJson, _ := json.Marshal(notification) + if string(responseJson) != string(requestJson) { + t.Errorf("Notification handler did not send the expected notification: \ngot %s\nexpect %s", responseJson, requestJson) + } + + case <-time.After(1 * time.Second): + t.Errorf("Expected notification, got none") + } + }() + + wg.Wait() + }) + + t.Run("MultipleRequests", func(t *testing.T) { + var wg sync.WaitGroup + const numRequests = 5 + + // Send multiple requests concurrently + responses := make([]*JSONRPCResponse, numRequests) + errors := make([]error, numRequests) + + for i := 0; i < numRequests; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Each request has a unique ID and payload + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: int64(100 + idx), + Method: "debug/echo", + Params: map[string]interface{}{ + "requestIndex": idx, + "timestamp": time.Now().UnixNano(), + }, + } + + resp, err := stdio.SendRequest(ctx, request) + responses[idx] = resp + errors[idx] = err + }(i) + } + + wg.Wait() + + // Check results + for i := 0; i < numRequests; i++ { + if errors[i] != nil { + t.Errorf("Request %d failed: %v", i, errors[i]) + continue + } + + if responses[i] == nil || responses[i].ID == nil || *responses[i].ID != int64(100+i) { + t.Errorf("Request %d: Expected ID %d, got %v", i, 100+i, responses[i]) + continue + } + + // Parse the result to verify echo + var result struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params map[string]interface{} `json:"params"` + } + + if err := json.Unmarshal(responses[i].Result, &result); err != nil { + t.Errorf("Request %d: Failed to unmarshal result: %v", i, err) + continue + } + + // Verify data matches what was sent + if result.ID != int64(100+i) { + t.Errorf("Request %d: Expected echoed ID %d, got %d", i, 100+i, result.ID) + } + + if result.Method != "debug/echo" { + t.Errorf("Request %d: Expected method 'debug/echo', got '%s'", i, result.Method) + } + + // Verify the requestIndex parameter + if idx, ok := result.Params["requestIndex"].(float64); !ok || int(idx) != i { + t.Errorf("Request %d: Expected requestIndex %d, got %v", i, i, result.Params["requestIndex"]) + } + } + }) + + t.Run("ResponseError", func(t *testing.T) { + + // Prepare a request + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 100, + Method: "debug/echo_error_string", + } + + // The request should fail because the context is canceled + reps, err := stdio.SendRequest(ctx, request) + if err != nil { + t.Errorf("SendRequest failed: %v", err) + } + + if reps.Error == nil { + t.Errorf("Expected error, got nil") + } + + var responseError JSONRPCRequest + if err := json.Unmarshal([]byte(reps.Error.Message), &responseError); err != nil { + t.Errorf("Failed to unmarshal result: %v", err) + } + + if responseError.Method != "debug/echo_error_string" { + t.Errorf("Expected method 'debug/echo_error_string', got '%s'", responseError.Method) + } + if responseError.ID != 100 { + t.Errorf("Expected ID 100, got %d", responseError.ID) + } + if responseError.JSONRPC != "2.0" { + t.Errorf("Expected JSONRPC '2.0', got '%s'", responseError.JSONRPC) + } + }) + +} + +func TestStdioErrors(t *testing.T) { + t.Run("InvalidCommand", func(t *testing.T) { + // Create a new Stdio transport with a non-existent command + stdio := NewStdio("non_existent_command", nil) + + // Start should fail + ctx := context.Background() + err := stdio.Start(ctx) + if err == nil { + t.Errorf("Expected error when starting with invalid command, got nil") + stdio.Close() + } + }) + + t.Run("RequestBeforeStart", func(t *testing.T) { + // 创建一个新的 Stdio 实例但不调用 Start 方法 + mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + if err := compileTestServer(mockServerPath); err != nil { + t.Fatalf("Failed to compile mock server: %v", err) + } + defer os.Remove(mockServerPath) + + uninitiatedStdio := NewStdio(mockServerPath, nil) + + // 准备一个请求 + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 99, + Method: "ping", + } + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + _, err := uninitiatedStdio.SendRequest(ctx, request) + if err == nil { + t.Errorf("Expected SendRequest to panic before Start(), but it didn't") + } else if err.Error() != "stdio client not started" { + t.Errorf("Expected error 'stdio client not started', got: %v", err) + } + }) + + t.Run("RequestAfterClose", func(t *testing.T) { + // Compile mock server + mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + if err := compileTestServer(mockServerPath); err != nil { + t.Fatalf("Failed to compile mock server: %v", err) + } + defer os.Remove(mockServerPath) + + // Create a new Stdio transport + stdio := NewStdio(mockServerPath, nil) + + // Start the transport + ctx := context.Background() + if err := stdio.Start(ctx); err != nil { + t.Fatalf("Failed to start Stdio transport: %v", err) + } + + // Close the transport - ignore errors like "broken pipe" since the process might exit already + stdio.Close() + + // Wait a bit to ensure process has exited + time.Sleep(100 * time.Millisecond) + + // Try to send a request after close + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "ping", + } + + _, err := stdio.SendRequest(ctx, request) + if err == nil { + t.Errorf("Expected error when sending request after close, got nil") + } + }) + +} diff --git a/testdata/mockstdio_server.go b/testdata/mockstdio_server.go index 3100c5a2a..9f13d5547 100644 --- a/testdata/mockstdio_server.go +++ b/testdata/mockstdio_server.go @@ -10,14 +10,14 @@ import ( type JSONRPCRequest struct { JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` + ID *int64 `json:"id,omitempty"` Method string `json:"method"` Params json.RawMessage `json:"params"` } type JSONRPCResponse struct { JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` + ID *int64 `json:"id,omitempty"` Result interface{} `json:"result,omitempty"` Error *struct { Code int `json:"code"` @@ -138,6 +138,30 @@ func handleRequest(request JSONRPCRequest) JSONRPCResponse { "values": []string{"test completion"}, }, } + + // Debug methods for testing transport. + case "debug/echo": + response.Result = request + case "debug/echo_notification": + response.Result = request + + // send notification to client + responseBytes, _ := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "method": "debug/test", + "params": request, + }) + fmt.Fprintf(os.Stdout, "%s\n", responseBytes) + + case "debug/echo_error_string": + all, _ := json.Marshal(request) + response.Error = &struct { + Code int `json:"code"` + Message string `json:"message"` + }{ + Code: -32601, + Message: string(all), + } default: response.Error = &struct { Code int `json:"code"` From 80178e3dfaf78492af10be1411fecae47d8d93f1 Mon Sep 17 00:00:00 2001 From: leavez Date: Sun, 6 Apr 2025 10:33:08 +0800 Subject: [PATCH 08/22] rename 'done' to 'closed', to distinguish with ctx.Done --- client/transport/sse.go | 11 ++++++----- client/transport/stdio_test.go | 3 +-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/client/transport/sse.go b/client/transport/sse.go index 43c9e9fb0..ae0fa2640 100644 --- a/client/transport/sse.go +++ b/client/transport/sse.go @@ -26,12 +26,13 @@ type SSE struct { httpClient *http.Client responses map[int64]chan *JSONRPCResponse mu sync.RWMutex - done chan struct{} onNotification func(mcp.JSONRPCNotification) notifyMu sync.RWMutex endpointChan chan struct{} headers map[string]string sseReadTimeout time.Duration + + closed chan struct{} } type ClientOption func(*SSE) @@ -60,7 +61,7 @@ func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) { baseURL: parsedURL, httpClient: &http.Client{}, responses: make(map[int64]chan *JSONRPCResponse), - done: make(chan struct{}), + closed: make(chan struct{}), endpointChan: make(chan struct{}), sseReadTimeout: 30 * time.Second, headers: make(map[string]string), @@ -141,7 +142,7 @@ func (c *SSE) readSSE(reader io.ReadCloser) { break } select { - case <-c.done: + case <-c.closed: return default: fmt.Printf("SSE stream error: %v\n", err) @@ -295,10 +296,10 @@ func (c *SSE) SendRequest( // Returns an error if the shutdown process fails. func (c *SSE) Close() error { select { - case <-c.done: + case <-c.closed: return nil // Already closed default: - close(c.done) + close(c.closed) } // Clean up any pending responses diff --git a/client/transport/stdio_test.go b/client/transport/stdio_test.go index e9922a003..1e3d310a4 100644 --- a/client/transport/stdio_test.go +++ b/client/transport/stdio_test.go @@ -50,10 +50,9 @@ func TestStdio(t *testing.T) { defer stdio.Close() t.Run("SendRequest", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 5000000000*time.Second) defer cancel() - // 使用简单的几种基本类型作为参数 params := map[string]interface{}{ "string": "hello world", "array": []interface{}{1, 2, 3}, From 4793c5efbb53413a4a376b46340ad05057c6fdb1 Mon Sep 17 00:00:00 2001 From: leavez Date: Sun, 6 Apr 2025 10:38:20 +0800 Subject: [PATCH 09/22] add cancelSSEStream for better handling of close --- client/transport/sse.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/client/transport/sse.go b/client/transport/sse.go index ae0fa2640..6155c2a93 100644 --- a/client/transport/sse.go +++ b/client/transport/sse.go @@ -33,6 +33,7 @@ type SSE struct { sseReadTimeout time.Duration closed chan struct{} + cancelSSEStream context.CancelFunc } type ClientOption func(*SSE) @@ -78,6 +79,9 @@ func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) { // Returns an error if the connection fails or times out waiting for the endpoint. func (c *SSE) Start(ctx context.Context) error { + ctx, cancel := context.WithCancel(ctx) + c.cancelSSEStream = cancel + req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL.String(), nil) if err != nil { @@ -302,6 +306,12 @@ func (c *SSE) Close() error { close(c.closed) } + if c.cancelSSEStream != nil { + // It could stop the sse stream body, to quit the readSSE loop immediately + // Also, it could quit start() immediately if not receiving the endpoint + c.cancelSSEStream() + } + // Clean up any pending responses c.mu.Lock() for _, ch := range c.responses { From 71d1d0d6b2123056620e072950a1412df36be239 Mon Sep 17 00:00:00 2001 From: leavez Date: Sun, 6 Apr 2025 10:41:02 +0800 Subject: [PATCH 10/22] fix connection leak when start timeout --- client/transport/sse.go | 1 + 1 file changed, 1 insertion(+) diff --git a/client/transport/sse.go b/client/transport/sse.go index 6155c2a93..72a36869f 100644 --- a/client/transport/sse.go +++ b/client/transport/sse.go @@ -114,6 +114,7 @@ func (c *SSE) Start(ctx context.Context) error { case <-ctx.Done(): return fmt.Errorf("context cancelled while waiting for endpoint") case <-time.After(30 * time.Second): // Add a timeout + cancel() return fmt.Errorf("timeout waiting for endpoint") } From 2672c1365f3e4aa98c7b987b97b0788859277373 Mon Sep 17 00:00:00 2001 From: leavez Date: Sun, 6 Apr 2025 10:49:28 +0800 Subject: [PATCH 11/22] avoid multiple starting --- client/transport/sse.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/client/transport/sse.go b/client/transport/sse.go index 72a36869f..ddd4c786f 100644 --- a/client/transport/sse.go +++ b/client/transport/sse.go @@ -11,6 +11,7 @@ import ( "net/url" "strings" "sync" + "sync/atomic" "time" "github.com/mark3labs/mcp-go/mcp" @@ -32,6 +33,7 @@ type SSE struct { headers map[string]string sseReadTimeout time.Duration + started atomic.Bool closed chan struct{} cancelSSEStream context.CancelFunc } @@ -79,6 +81,10 @@ func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) { // Returns an error if the connection fails or times out waiting for the endpoint. func (c *SSE) Start(ctx context.Context) error { + if c.started.Load() { + return fmt.Errorf("has already started") + } + ctx, cancel := context.WithCancel(ctx) c.cancelSSEStream = cancel @@ -118,6 +124,7 @@ func (c *SSE) Start(ctx context.Context) error { return fmt.Errorf("timeout waiting for endpoint") } + c.started.Store(true) return nil } @@ -240,6 +247,10 @@ func (c *SSE) SendRequest( request JSONRPCRequest, ) (*JSONRPCResponse, error) { + if !c.started.Load() { + return nil, fmt.Errorf("transport not started") + } + if c.endpoint == nil { return nil, fmt.Errorf("endpoint not received") } From a627a2f25e25fbd5459d558afdb24c71fc2a7b35 Mon Sep 17 00:00:00 2001 From: leavez Date: Sun, 6 Apr 2025 12:38:22 +0800 Subject: [PATCH 12/22] use atomic for closed to be more natural compared to started --- client/transport/sse.go | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/client/transport/sse.go b/client/transport/sse.go index ddd4c786f..2aa920358 100644 --- a/client/transport/sse.go +++ b/client/transport/sse.go @@ -34,7 +34,7 @@ type SSE struct { sseReadTimeout time.Duration started atomic.Bool - closed chan struct{} + closed atomic.Bool cancelSSEStream context.CancelFunc } @@ -64,7 +64,6 @@ func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) { baseURL: parsedURL, httpClient: &http.Client{}, responses: make(map[int64]chan *JSONRPCResponse), - closed: make(chan struct{}), endpointChan: make(chan struct{}), sseReadTimeout: 30 * time.Second, headers: make(map[string]string), @@ -91,9 +90,7 @@ func (c *SSE) Start(ctx context.Context) error { req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL.String(), nil) if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } req.Header.Set("Accept", "text/event-stream") @@ -153,13 +150,10 @@ func (c *SSE) readSSE(reader io.ReadCloser) { } break } - select { - case <-c.closed: - return - default: + if !c.closed.Load() { fmt.Printf("SSE stream error: %v\n", err) - return } + return } // Remove only newline markers @@ -248,9 +242,11 @@ func (c *SSE) SendRequest( ) (*JSONRPCResponse, error) { if !c.started.Load() { - return nil, fmt.Errorf("transport not started") + return nil, fmt.Errorf("transport not started yet") + } + if c.closed.Load() { + return nil, fmt.Errorf("transport has been closed") } - if c.endpoint == nil { return nil, fmt.Errorf("endpoint not received") } @@ -311,11 +307,8 @@ func (c *SSE) SendRequest( // Close shuts down the SSE client connection and cleans up any pending responses. // Returns an error if the shutdown process fails. func (c *SSE) Close() error { - select { - case <-c.closed: + if !c.closed.CompareAndSwap(false, true) { return nil // Already closed - default: - close(c.closed) } if c.cancelSSEStream != nil { From efbf5ffdc6e6c43ac2a279c374199d2912b9e84c Mon Sep 17 00:00:00 2001 From: leavez Date: Sun, 6 Apr 2025 12:39:32 +0800 Subject: [PATCH 13/22] fix leak of timer --- client/transport/sse.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/client/transport/sse.go b/client/transport/sse.go index 2aa920358..ad911fb8b 100644 --- a/client/transport/sse.go +++ b/client/transport/sse.go @@ -110,13 +110,14 @@ func (c *SSE) Start(ctx context.Context) error { go c.readSSE(resp.Body) // Wait for the endpoint to be received - + timeout := time.NewTimer(30 * time.Second) + defer timeout.Stop() select { case <-c.endpointChan: // Endpoint received, proceed case <-ctx.Done(): return fmt.Errorf("context cancelled while waiting for endpoint") - case <-time.After(30 * time.Second): // Add a timeout + case <-timeout.C: // Add a timeout cancel() return fmt.Errorf("timeout waiting for endpoint") } From 84c7c7848e3d826808438048f1d7aa5d22af11c1 Mon Sep 17 00:00:00 2001 From: leavez Date: Sun, 6 Apr 2025 12:39:36 +0800 Subject: [PATCH 14/22] Create sse_test.go --- client/transport/sse_test.go | 469 +++++++++++++++++++++++++++++++++++ 1 file changed, 469 insertions(+) create mode 100644 client/transport/sse_test.go diff --git a/client/transport/sse_test.go b/client/transport/sse_test.go new file mode 100644 index 000000000..02879fdcb --- /dev/null +++ b/client/transport/sse_test.go @@ -0,0 +1,469 @@ +package transport + +import ( + "context" + "encoding/json" + "errors" + "sync" + "testing" + "time" + + "fmt" + "net/http" + "net/http/httptest" + + "github.com/mark3labs/mcp-go/mcp" +) + +// startMockSSEEchoServer starts a test HTTP server that implements +// a minimal SSE-based echo server for testing purposes. +// It returns the server URL and a function to close the server. +func startMockSSEEchoServer() (string, func()) { + // Create handler for SSE endpoint + var sseWriter http.ResponseWriter + var flush http.Flusher + var mu sync.Mutex + sseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Setup SSE headers + defer fmt.Printf("SSEHandler ends: %v\n", r.Context().Err()) + w.Header().Set("Content-Type", "text/event-stream") + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported", http.StatusInternalServerError) + return + } + mu.Lock() + sseWriter = w + flush = flusher + mu.Unlock() + // Send initial endpoint event with message endpoint URL + fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", "/message") + flusher.Flush() + + // Keep connection open + <-r.Context().Done() + }) + + // Create handler for message endpoint + messageHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle only POST requests + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Parse incoming JSON-RPC request + var request map[string]interface{} + decoder := json.NewDecoder(r.Body) + if err := decoder.Decode(&request); err != nil { + http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest) + return + } + + // Echo back the request as the response result + response := map[string]interface{}{ + "jsonrpc": "2.0", + "id": request["id"], + "result": request, + } + + method := request["method"] + switch method { + case "debug/echo": + response["result"] = request + case "debug/echo_notification": + response["result"] = request + // send notification to client + responseBytes, _ := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "method": "debug/test", + "params": request, + }) + mu.Lock() + fmt.Fprintf(sseWriter, "event: message\ndata: %s\n\n", responseBytes) + flush.Flush() + mu.Unlock() + case "debug/echo_error_string": + data, _ := json.Marshal(request) + response["error"] = map[string]interface{}{ + "code": -1, + "message": string(data), + } + } + + // Set response headers + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + + go func() { + data, _ := json.Marshal(response) + mu.Lock() + defer mu.Unlock() + if sseWriter != nil && flush != nil { + fmt.Fprintf(sseWriter, "event: message\ndata: %s\n\n", data) + flush.Flush() + } + }() + + }) + + // Create a router to handle different endpoints + mux := http.NewServeMux() + mux.Handle("/", sseHandler) + mux.Handle("/message", messageHandler) + + // Start test server + testServer := httptest.NewServer(mux) + + return testServer.URL, testServer.Close +} + +func TestSSE(t *testing.T) { + // Compile mock server + url, closeF := startMockSSEEchoServer() + defer closeF() + + trans, err := NewSSE(url) + if err != nil { + t.Fatal(err) + } + + // Start the transport + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + err = trans.Start(ctx) + if err != nil { + t.Fatalf("Failed to start transport: %v", err) + } + defer trans.Close() + + t.Run("SendRequest", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + params := map[string]interface{}{ + "string": "hello world", + "array": []interface{}{1, 2, 3}, + } + + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "debug/echo", + Params: params, + } + + // Send the request + response, err := trans.SendRequest(ctx, request) + if err != nil { + t.Fatalf("SendRequest failed: %v", err) + } + + // Parse the result to verify echo + var result struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params map[string]interface{} `json:"params"` + } + + if err := json.Unmarshal(response.Result, &result); err != nil { + t.Fatalf("Failed to unmarshal result: %v", err) + } + + // Verify response data matches what was sent + if result.JSONRPC != "2.0" { + t.Errorf("Expected JSONRPC value '2.0', got '%s'", result.JSONRPC) + } + if result.ID != 1 { + t.Errorf("Expected ID 1, got %d", result.ID) + } + if result.Method != "debug/echo" { + t.Errorf("Expected method 'debug/echo', got '%s'", result.Method) + } + + if str, ok := result.Params["string"].(string); !ok || str != "hello world" { + t.Errorf("Expected string 'hello world', got %v", result.Params["string"]) + } + + if arr, ok := result.Params["array"].([]interface{}); !ok || len(arr) != 3 { + t.Errorf("Expected array with 3 items, got %v", result.Params["array"]) + } + }) + + t.Run("SendRequestWithTimeout", func(t *testing.T) { + // Create a context that's already canceled + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel the context immediately + + // Prepare a request + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 3, + Method: "debug/echo", + } + + // The request should fail because the context is canceled + _, err := trans.SendRequest(ctx, request) + if err == nil { + t.Errorf("Expected context canceled error, got nil") + } else if !errors.Is(err, context.Canceled) { + t.Errorf("Expected context.Canceled error, got: %v", err) + } + }) + + t.Run("SendNotification & NotificationHandler", func(t *testing.T) { + + var wg sync.WaitGroup + notificationChan := make(chan mcp.JSONRPCNotification, 1) + + // Set notification handler + trans.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { + notificationChan <- notification + }) + + // Send a notification + // This would trigger a notification from the server + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + notification := mcp.JSONRPCNotification{ + JSONRPC: "2.0", + Notification: mcp.Notification{ + Method: "debug/echo_notification", + Params: mcp.NotificationParams{ + AdditionalFields: map[string]interface{}{"test": "value"}, + }, + }, + } + err := trans.SendNotification(ctx, notification) + if err != nil { + t.Fatalf("SendNotification failed: %v", err) + } + + wg.Add(1) + go func() { + defer wg.Done() + select { + case nt := <-notificationChan: + // We received a notification + responseJson, _ := json.Marshal(nt.Params.AdditionalFields) + requestJson, _ := json.Marshal(notification) + if string(responseJson) != string(requestJson) { + t.Errorf("Notification handler did not send the expected notification: \ngot %s\nexpect %s", responseJson, requestJson) + } + + case <-time.After(1 * time.Second): + t.Errorf("Expected notification, got none") + } + }() + + wg.Wait() + }) + + t.Run("MultipleRequests", func(t *testing.T) { + var wg sync.WaitGroup + const numRequests = 5 + + // Send multiple requests concurrently + mu := sync.Mutex{} + responses := make([]*JSONRPCResponse, numRequests) + errors := make([]error, numRequests) + + for i := 0; i < numRequests; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Each request has a unique ID and payload + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: int64(100 + idx), + Method: "debug/echo", + Params: map[string]interface{}{ + "requestIndex": idx, + "timestamp": time.Now().UnixNano(), + }, + } + + resp, err := trans.SendRequest(ctx, request) + mu.Lock() + responses[idx] = resp + errors[idx] = err + mu.Unlock() + }(i) + } + + wg.Wait() + + // Check results + for i := 0; i < numRequests; i++ { + if errors[i] != nil { + t.Errorf("Request %d failed: %v", i, errors[i]) + continue + } + + if responses[i] == nil || responses[i].ID == nil || *responses[i].ID != int64(100+i) { + t.Errorf("Request %d: Expected ID %d, got %v", i, 100+i, responses[i]) + continue + } + + // Parse the result to verify echo + var result struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params map[string]interface{} `json:"params"` + } + + if err := json.Unmarshal(responses[i].Result, &result); err != nil { + t.Errorf("Request %d: Failed to unmarshal result: %v", i, err) + continue + } + + // Verify data matches what was sent + if result.ID != int64(100+i) { + t.Errorf("Request %d: Expected echoed ID %d, got %d", i, 100+i, result.ID) + } + + if result.Method != "debug/echo" { + t.Errorf("Request %d: Expected method 'debug/echo', got '%s'", i, result.Method) + } + + // Verify the requestIndex parameter + if idx, ok := result.Params["requestIndex"].(float64); !ok || int(idx) != i { + t.Errorf("Request %d: Expected requestIndex %d, got %v", i, i, result.Params["requestIndex"]) + } + } + }) + + t.Run("ResponseError", func(t *testing.T) { + + // Prepare a request + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 100, + Method: "debug/echo_error_string", + } + + // The request should fail because the context is canceled + reps, err := trans.SendRequest(ctx, request) + if err != nil { + t.Errorf("SendRequest failed: %v", err) + } + + if reps.Error == nil { + t.Errorf("Expected error, got nil") + } + + var responseError JSONRPCRequest + if err := json.Unmarshal([]byte(reps.Error.Message), &responseError); err != nil { + t.Errorf("Failed to unmarshal result: %v", err) + } + + if responseError.Method != "debug/echo_error_string" { + t.Errorf("Expected method 'debug/echo_error_string', got '%s'", responseError.Method) + } + if responseError.ID != 100 { + t.Errorf("Expected ID 100, got %d", responseError.ID) + } + if responseError.JSONRPC != "2.0" { + t.Errorf("Expected JSONRPC '2.0', got '%s'", responseError.JSONRPC) + } + }) + +} + +func TestSSEErrors(t *testing.T) { + t.Run("InvalidURL", func(t *testing.T) { + // Create a new SSE transport with an invalid URL + _, err := NewSSE("://invalid-url") + if err == nil { + t.Errorf("Expected error when creating with invalid URL, got nil") + } + }) + + t.Run("NonExistentURL", func(t *testing.T) { + // Create a new SSE transport with a non-existent URL + sse, err := NewSSE("http://localhost:1") + if err != nil { + t.Fatalf("Failed to create SSE transport: %v", err) + } + + // Start should fail + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err = sse.Start(ctx) + if err == nil { + t.Errorf("Expected error when starting with non-existent URL, got nil") + sse.Close() + } + }) + + t.Run("RequestBeforeStart", func(t *testing.T) { + url, closeF := startMockSSEEchoServer() + defer closeF() + + // Create a new SSE instance without calling Start method + sse, err := NewSSE(url) + if err != nil { + t.Fatalf("Failed to create SSE transport: %v", err) + } + + // Prepare a request + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 99, + Method: "ping", + } + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + _, err = sse.SendRequest(ctx, request) + if err == nil { + t.Errorf("Expected SendRequest to fail before Start(), but it didn't") + } + }) + + t.Run("RequestAfterClose", func(t *testing.T) { + // Start a mock server + url, closeF := startMockSSEEchoServer() + defer closeF() + + // Create a new SSE transport + sse, err := NewSSE(url) + if err != nil { + t.Fatalf("Failed to create SSE transport: %v", err) + } + + // Start the transport + ctx := context.Background() + if err := sse.Start(ctx); err != nil { + t.Fatalf("Failed to start SSE transport: %v", err) + } + + // Close the transport + sse.Close() + + // Wait a bit to ensure connection has closed + time.Sleep(100 * time.Millisecond) + + // Try to send a request after close + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "ping", + } + + _, err = sse.SendRequest(ctx, request) + if err == nil { + t.Errorf("Expected error when sending request after close, got nil") + } + }) + +} From 70602b5fceb11299a83b53ed134ce42d7feea9e8 Mon Sep 17 00:00:00 2001 From: leavez Date: Sun, 6 Apr 2025 12:40:46 +0800 Subject: [PATCH 15/22] enforce test --- client/transport/stdio_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/client/transport/stdio_test.go b/client/transport/stdio_test.go index 1e3d310a4..445ba07ea 100644 --- a/client/transport/stdio_test.go +++ b/client/transport/stdio_test.go @@ -180,7 +180,7 @@ func TestStdio(t *testing.T) { // Send multiple requests concurrently responses := make([]*JSONRPCResponse, numRequests) errors := make([]error, numRequests) - + mu := sync.Mutex{} for i := 0; i < numRequests; i++ { wg.Add(1) go func(idx int) { @@ -200,8 +200,10 @@ func TestStdio(t *testing.T) { } resp, err := stdio.SendRequest(ctx, request) + mu.Lock() responses[idx] = resp errors[idx] = err + mu.Unlock() }(i) } From 6512ee75e2447b6aed30e621f0acd5e1a160d9ce Mon Sep 17 00:00:00 2001 From: leavez Date: Sun, 6 Apr 2025 12:50:54 +0800 Subject: [PATCH 16/22] add comment --- client/client.go | 9 ++++++--- client/stdio.go | 3 ++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/client/client.go b/client/client.go index 8b9607d8c..e23264d0e 100644 --- a/client/client.go +++ b/client/client.go @@ -12,6 +12,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" ) +// Client implements the MCP client. type Client struct { transport transport.Interface @@ -22,7 +23,7 @@ type Client struct { capabilities mcp.ServerCapabilities } -// NewClient creates a new MCP client with the given transport layer. +// NewClient creates a new MCP client with the given transport. // Usage: // // client, err := NewClient(transport.NewStdio("mcp", nil, "--stdio")) @@ -35,8 +36,8 @@ func NewClient(transport transport.Interface) *Client { } } -// Start initiates the transport connection to the server. -// Returns an error if the transport is nil or if the connection fails. +// Start initiates the connection to the server. +// Must be called before using the client. func (c *Client) Start(ctx context.Context) error { if c.transport == nil { return fmt.Errorf("transport is nil") @@ -103,6 +104,8 @@ func (c *Client) sendRequest( return &response.Result, nil } +// Initialize negotiates with the server. +// Must be called after Start, and before any request methods. func (c *Client) Initialize( ctx context.Context, request mcp.InitializeRequest, diff --git a/client/stdio.go b/client/stdio.go index 0b0fd9fe9..d6867083d 100644 --- a/client/stdio.go +++ b/client/stdio.go @@ -12,7 +12,8 @@ import ( // It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. // Returns an error if the subprocess cannot be started or the pipes cannot be created. // -// NewStdioMCPClient will start the connection automatically. Don't call the Start method manually. +// NOTICE: NewStdioMCPClient will start the connection automatically. Don't call the Start method manually. +// This is for backward compatibility. func NewStdioMCPClient( command string, env []string, From 0372773ab3e3a662aa95886d3015f1711b8998ac Mon Sep 17 00:00:00 2001 From: leavez Date: Wed, 9 Apr 2025 00:21:48 +0800 Subject: [PATCH 17/22] sse: add custom header in start request --- client/transport/sse.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/client/transport/sse.go b/client/transport/sse.go index ad911fb8b..e29aaa834 100644 --- a/client/transport/sse.go +++ b/client/transport/sse.go @@ -97,6 +97,11 @@ func (c *SSE) Start(ctx context.Context) error { req.Header.Set("Cache-Control", "no-cache") req.Header.Set("Connection", "keep-alive") + // set custom http headers + for k, v := range c.headers { + req.Header.Set(k, v) + } + resp, err := c.httpClient.Do(req) if err != nil { return fmt.Errorf("failed to connect to SSE stream: %w", err) From 991c6de64e638439a951a3596aa65aa4ee564262 Mon Sep 17 00:00:00 2001 From: leavez Date: Wed, 9 Apr 2025 00:25:31 +0800 Subject: [PATCH 18/22] update comment --- client/client.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/client/client.go b/client/client.go index e23264d0e..cc94abab0 100644 --- a/client/client.go +++ b/client/client.go @@ -26,9 +26,10 @@ type Client struct { // NewClient creates a new MCP client with the given transport. // Usage: // -// client, err := NewClient(transport.NewStdio("mcp", nil, "--stdio")) +// stdio := transport.NewStdio("./mcp_server", nil, "xxx") +// client, err := NewClient(stdio) // if err != nil { -// log.Fatalf("Failed to create client: %v", err) +// log.Fatalf("Failed to create client: %v", err) // } func NewClient(transport transport.Interface) *Client { return &Client{ From 4f583a405db00debca31710b21bd9001d7fcb1c2 Mon Sep 17 00:00:00 2001 From: leavez Date: Wed, 9 Apr 2025 00:30:07 +0800 Subject: [PATCH 19/22] comment --- client/sse.go | 2 +- client/stdio.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/client/sse.go b/client/sse.go index 0776d8ac2..d998f2cc5 100644 --- a/client/sse.go +++ b/client/sse.go @@ -30,7 +30,7 @@ func NewSSEMCPClient(baseURL string, options ...transport.ClientOption) (*Client // GetEndpoint returns the current endpoint URL for the SSE connection. // -// Note: This method only works with SSE transport. +// Note: This method only works with SSE transport, or it will panic. func GetEndpoint(c *Client) *url.URL { t := c.GetTransport() sse := t.(*transport.SSE) diff --git a/client/stdio.go b/client/stdio.go index d6867083d..a25f6d19d 100644 --- a/client/stdio.go +++ b/client/stdio.go @@ -32,7 +32,7 @@ func NewStdioMCPClient( // GetStderr returns a reader for the stderr output of the subprocess. // This can be used to capture error messages or logs from the subprocess. // -// Note: This method only works with stdio transport. +// Note: This method only works with stdio transport, or it will panic. func GetStderr(c *Client) io.Reader { t := c.GetTransport() stdio := t.(*transport.Stdio) From 2b28205973378001142f39f8c84dc8d012a78443 Mon Sep 17 00:00:00 2001 From: leavez Date: Sat, 12 Apr 2025 01:23:14 +0800 Subject: [PATCH 20/22] cover #88 --- client/sse.go | 8 +---- client/transport/sse.go | 78 +++++++++++++++++------------------------ 2 files changed, 33 insertions(+), 53 deletions(-) diff --git a/client/sse.go b/client/sse.go index d998f2cc5..c26744a39 100644 --- a/client/sse.go +++ b/client/sse.go @@ -2,20 +2,14 @@ package client import ( "fmt" - "net/url" - "time" - "github.com/mark3labs/mcp-go/client/transport" + "net/url" ) func WithHeaders(headers map[string]string) transport.ClientOption { return transport.WithHeaders(headers) } -func WithSSEReadTimeout(timeout time.Duration) transport.ClientOption { - return transport.WithSSEReadTimeout(timeout) -} - // NewSSEMCPClient creates a new SSE-based MCP client with the given base URL. // Returns an error if the URL is invalid. func NewSSEMCPClient(baseURL string, options ...transport.ClientOption) (*Client, error) { diff --git a/client/transport/sse.go b/client/transport/sse.go index e29aaa834..a515ae760 100644 --- a/client/transport/sse.go +++ b/client/transport/sse.go @@ -31,7 +31,6 @@ type SSE struct { notifyMu sync.RWMutex endpointChan chan struct{} headers map[string]string - sseReadTimeout time.Duration started atomic.Bool closed atomic.Bool @@ -46,12 +45,6 @@ func WithHeaders(headers map[string]string) ClientOption { } } -func WithSSEReadTimeout(timeout time.Duration) ClientOption { - return func(sc *SSE) { - sc.sseReadTimeout = timeout - } -} - // NewSSE creates a new SSE-based MCP client with the given base URL. // Returns an error if the URL is invalid. func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) { @@ -61,12 +54,11 @@ func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) { } smc := &SSE{ - baseURL: parsedURL, - httpClient: &http.Client{}, - responses: make(map[int64]chan *JSONRPCResponse), - endpointChan: make(chan struct{}), - sseReadTimeout: 30 * time.Second, - headers: make(map[string]string), + baseURL: parsedURL, + httpClient: &http.Client{}, + responses: make(map[int64]chan *JSONRPCResponse), + endpointChan: make(chan struct{}), + headers: make(map[string]string), } for _, opt := range options { @@ -139,46 +131,40 @@ func (c *SSE) readSSE(reader io.ReadCloser) { br := bufio.NewReader(reader) var event, data string - ctx, cancel := context.WithTimeout(context.Background(), c.sseReadTimeout) - defer cancel() - for { - select { - case <-ctx.Done(): - return - default: - line, err := br.ReadString('\n') - if err != nil { - if err == io.EOF { - // Process any pending event before exit - if event != "" && data != "" { - c.handleSSEEvent(event, data) - } - break - } - if !c.closed.Load() { - fmt.Printf("SSE stream error: %v\n", err) - } - return - } - - // Remove only newline markers - line = strings.TrimRight(line, "\r\n") - if line == "" { - // Empty line means end of event + // when close or start's ctx cancel, the reader will be closed + // and the for loop will break. + line, err := br.ReadString('\n') + if err != nil { + if err == io.EOF { + // Process any pending event before exit if event != "" && data != "" { c.handleSSEEvent(event, data) - event = "" - data = "" } - continue + break + } + if !c.closed.Load() { + fmt.Printf("SSE stream error: %v\n", err) } + return + } - if strings.HasPrefix(line, "event:") { - event = strings.TrimSpace(strings.TrimPrefix(line, "event:")) - } else if strings.HasPrefix(line, "data:") { - data = strings.TrimSpace(strings.TrimPrefix(line, "data:")) + // Remove only newline markers + line = strings.TrimRight(line, "\r\n") + if line == "" { + // Empty line means end of event + if event != "" && data != "" { + c.handleSSEEvent(event, data) + event = "" + data = "" } + continue + } + + if strings.HasPrefix(line, "event:") { + event = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + } else if strings.HasPrefix(line, "data:") { + data = strings.TrimSpace(strings.TrimPrefix(line, "data:")) } } } From 12a31a88345c30d7a56774cecc3f53b491bc56b5 Mon Sep 17 00:00:00 2001 From: leavez Date: Sat, 12 Apr 2025 01:43:27 +0800 Subject: [PATCH 21/22] cover #107 --- client/client.go | 152 +++++++++++++++++++++++++++++++++++--------- client/interface.go | 25 ++++++++ 2 files changed, 146 insertions(+), 31 deletions(-) diff --git a/client/client.go b/client/client.go index cc94abab0..60fe0cbfe 100644 --- a/client/client.go +++ b/client/client.go @@ -160,42 +160,77 @@ func (c *Client) Ping(ctx context.Context) error { return err } -func (c *Client) ListResources( +// ListResourcesByPage manually list resources by page. +func (c *Client) ListResourcesByPage( ctx context.Context, request mcp.ListResourcesRequest, ) (*mcp.ListResourcesResult, error) { - response, err := c.sendRequest(ctx, "resources/list", request.Params) + result, err := listByPage[mcp.ListResourcesResult](ctx, c, request.PaginatedRequest, "resources/list") if err != nil { return nil, err } + return result, nil +} - var result mcp.ListResourcesResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) +func (c *Client) ListResources( + ctx context.Context, + request mcp.ListResourcesRequest, +) (*mcp.ListResourcesResult, error) { + result, err := c.ListResourcesByPage(ctx, request) + if err != nil { + return nil, err + } + for result.NextCursor != "" { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + request.Params.Cursor = result.NextCursor + newPageRes, err := c.ListResourcesByPage(ctx, request) + if err != nil { + return nil, err + } + result.Resources = append(result.Resources, newPageRes.Resources...) + result.NextCursor = newPageRes.NextCursor + } } + return result, nil +} - return &result, nil +func (c *Client) ListResourceTemplatesByPage( + ctx context.Context, + request mcp.ListResourceTemplatesRequest, +) (*mcp.ListResourceTemplatesResult, error) { + result, err := listByPage[mcp.ListResourceTemplatesResult](ctx, c, request.PaginatedRequest, "resources/templates/list") + if err != nil { + return nil, err + } + return result, nil } func (c *Client) ListResourceTemplates( ctx context.Context, request mcp.ListResourceTemplatesRequest, ) (*mcp.ListResourceTemplatesResult, error) { - response, err := c.sendRequest( - ctx, - "resources/templates/list", - request.Params, - ) + result, err := c.ListResourceTemplatesByPage(ctx, request) if err != nil { return nil, err } - - var result mcp.ListResourceTemplatesResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) + for result.NextCursor != "" { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + request.Params.Cursor = result.NextCursor + newPageRes, err := c.ListResourceTemplatesByPage(ctx, request) + if err != nil { + return nil, err + } + result.ResourceTemplates = append(result.ResourceTemplates, newPageRes.ResourceTemplates...) + result.NextCursor = newPageRes.NextCursor + } } - - return &result, nil + return result, nil } func (c *Client) ReadResource( @@ -226,21 +261,40 @@ func (c *Client) Unsubscribe( return err } -func (c *Client) ListPrompts( +func (c *Client) ListPromptsByPage( ctx context.Context, request mcp.ListPromptsRequest, ) (*mcp.ListPromptsResult, error) { - response, err := c.sendRequest(ctx, "prompts/list", request.Params) + result, err := listByPage[mcp.ListPromptsResult](ctx, c, request.PaginatedRequest, "prompts/list") if err != nil { return nil, err } + return result, nil +} - var result mcp.ListPromptsResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) +func (c *Client) ListPrompts( + ctx context.Context, + request mcp.ListPromptsRequest, +) (*mcp.ListPromptsResult, error) { + result, err := c.ListPromptsByPage(ctx, request) + if err != nil { + return nil, err } - - return &result, nil + for result.NextCursor != "" { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + request.Params.Cursor = result.NextCursor + newPageRes, err := c.ListPromptsByPage(ctx, request) + if err != nil { + return nil, err + } + result.Prompts = append(result.Prompts, newPageRes.Prompts...) + result.NextCursor = newPageRes.NextCursor + } + } + return result, nil } func (c *Client) GetPrompt( @@ -255,21 +309,40 @@ func (c *Client) GetPrompt( return mcp.ParseGetPromptResult(response) } -func (c *Client) ListTools( +func (c *Client) ListToolsByPage( ctx context.Context, request mcp.ListToolsRequest, ) (*mcp.ListToolsResult, error) { - response, err := c.sendRequest(ctx, "tools/list", request.Params) + result, err := listByPage[mcp.ListToolsResult](ctx, c, request.PaginatedRequest, "tools/list") if err != nil { return nil, err } + return result, nil +} - var result mcp.ListToolsResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) +func (c *Client) ListTools( + ctx context.Context, + request mcp.ListToolsRequest, +) (*mcp.ListToolsResult, error) { + result, err := c.ListToolsByPage(ctx, request) + if err != nil { + return nil, err } - - return &result, nil + for result.NextCursor != "" { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + request.Params.Cursor = result.NextCursor + newPageRes, err := c.ListToolsByPage(ctx, request) + if err != nil { + return nil, err + } + result.Tools = append(result.Tools, newPageRes.Tools...) + result.NextCursor = newPageRes.NextCursor + } + } + return result, nil } func (c *Client) CallTool( @@ -309,6 +382,23 @@ func (c *Client) Complete( return &result, nil } +func listByPage[T any]( + ctx context.Context, + client *Client, + request mcp.PaginatedRequest, + method string, +) (*T, error) { + response, err := client.sendRequest(ctx, method, request.Params) + if err != nil { + return nil, err + } + var result T + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + return &result, nil +} + // Helper methods // GetTransport gives access to the underlying transport layer. diff --git a/client/interface.go b/client/interface.go index 1d3cb1051..ea7f4d1fb 100644 --- a/client/interface.go +++ b/client/interface.go @@ -18,12 +18,25 @@ type MCPClient interface { // Ping checks if the server is alive Ping(ctx context.Context) error + // ListResourcesByPage manually list resources by page. + ListResourcesByPage( + ctx context.Context, + request mcp.ListResourcesRequest, + ) (*mcp.ListResourcesResult, error) + // ListResources requests a list of available resources from the server ListResources( ctx context.Context, request mcp.ListResourcesRequest, ) (*mcp.ListResourcesResult, error) + // ListResourceTemplatesByPage manually list resource templates by page. + ListResourceTemplatesByPage( + ctx context.Context, + request mcp.ListResourceTemplatesRequest, + ) (*mcp.ListResourceTemplatesResult, + error) + // ListResourceTemplates requests a list of available resource templates from the server ListResourceTemplates( ctx context.Context, @@ -43,6 +56,12 @@ type MCPClient interface { // Unsubscribe cancels notifications for a specific resource Unsubscribe(ctx context.Context, request mcp.UnsubscribeRequest) error + // ListPromptsByPage manually list prompts by page. + ListPromptsByPage( + ctx context.Context, + request mcp.ListPromptsRequest, + ) (*mcp.ListPromptsResult, error) + // ListPrompts requests a list of available prompts from the server ListPrompts( ctx context.Context, @@ -55,6 +74,12 @@ type MCPClient interface { request mcp.GetPromptRequest, ) (*mcp.GetPromptResult, error) + // ListToolsByPage manually list tools by page. + ListToolsByPage( + ctx context.Context, + request mcp.ListToolsRequest, + ) (*mcp.ListToolsResult, error) + // ListTools requests a list of available tools from the server ListTools( ctx context.Context, From 009895a8b08db6aaa7fa694faac7c7e4119a97e4 Mon Sep 17 00:00:00 2001 From: leavez Date: Sat, 12 Apr 2025 11:39:26 +0800 Subject: [PATCH 22/22] fix demo sse server in race test --- client/transport/sse_test.go | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/client/transport/sse_test.go b/client/transport/sse_test.go index 02879fdcb..0c4dff6a2 100644 --- a/client/transport/sse_test.go +++ b/client/transport/sse_test.go @@ -21,24 +21,35 @@ import ( func startMockSSEEchoServer() (string, func()) { // Create handler for SSE endpoint var sseWriter http.ResponseWriter - var flush http.Flusher + var flush func() var mu sync.Mutex sseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Setup SSE headers - defer fmt.Printf("SSEHandler ends: %v\n", r.Context().Err()) + defer func() { + mu.Lock() // for passing race test + sseWriter = nil + flush = nil + mu.Unlock() + fmt.Printf("SSEHandler ends: %v\n", r.Context().Err()) + }() + w.Header().Set("Content-Type", "text/event-stream") flusher, ok := w.(http.Flusher) if !ok { http.Error(w, "Streaming unsupported", http.StatusInternalServerError) return } + mu.Lock() sseWriter = w - flush = flusher + flush = flusher.Flush mu.Unlock() + // Send initial endpoint event with message endpoint URL + mu.Lock() fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", "/message") flusher.Flush() + mu.Unlock() // Keep connection open <-r.Context().Done() @@ -81,7 +92,7 @@ func startMockSSEEchoServer() (string, func()) { }) mu.Lock() fmt.Fprintf(sseWriter, "event: message\ndata: %s\n\n", responseBytes) - flush.Flush() + flush() mu.Unlock() case "debug/echo_error_string": data, _ := json.Marshal(request) @@ -101,7 +112,7 @@ func startMockSSEEchoServer() (string, func()) { defer mu.Unlock() if sseWriter != nil && flush != nil { fmt.Fprintf(sseWriter, "event: message\ndata: %s\n\n", data) - flush.Flush() + flush() } }()