Skip to content
70 changes: 50 additions & 20 deletions client/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,34 @@ import (
"os/exec"
"sync"
"sync/atomic"
"time"

"github.com/mark3labs/mcp-go/mcp"
)

const (
readyTimeout = 5 * time.Second
readyCheckTimeout = 1 * time.Second
)

// StdioMCPClient implements the MCPClient interface using stdio communication.
// It launches a subprocess and communicates with it via standard input/output streams
// using JSON-RPC messages. The client handles message routing between requests and
// responses, and supports asynchronous notifications.
type StdioMCPClient struct {
cmd *exec.Cmd
stdin io.WriteCloser
stdout *bufio.Reader
stderr io.ReadCloser
requestID atomic.Int64
responses map[int64]chan RPCResponse
mu sync.RWMutex
done chan struct{}
initialized bool
notifications []func(mcp.JSONRPCNotification)
notifyMu sync.RWMutex
capabilities mcp.ServerCapabilities
cmd *exec.Cmd
stdin io.WriteCloser
stdout *bufio.Reader
stderr io.ReadCloser
requestID atomic.Int64
responses map[int64]chan RPCResponse
mu sync.RWMutex
done chan struct{}
initialized bool
notifications []func(mcp.JSONRPCNotification)
notifyMu sync.RWMutex
capabilities mcp.ServerCapabilities
processExitErr chan error
}

// NewStdioMCPClient creates a new stdio-based MCP client that communicates with a subprocess.
Expand Down Expand Up @@ -65,29 +72,52 @@ func NewStdioMCPClient(
}

client := &StdioMCPClient{
cmd: cmd,
stdin: stdin,
stderr: stderr,
stdout: bufio.NewReader(stdout),
responses: make(map[int64]chan RPCResponse),
done: make(chan struct{}),
cmd: cmd,
stdin: stdin,
stderr: stderr,
stdout: bufio.NewReader(stdout),
responses: make(map[int64]chan RPCResponse),
done: make(chan struct{}),
processExitErr: make(chan error, 1),
}

if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("failed to start command: %w", err)
}

go func() {
client.processExitErr <- cmd.Wait()
}()

// Start reading responses in a goroutine and wait for it to be ready
ready := make(chan struct{})
go func() {
close(ready)
client.readResponses()
}()
<-ready

if err := waitUntilReadyOrExit(ready, client.processExitErr, readyTimeout); err != nil {
return nil, err
}
return client, nil
}

func waitUntilReadyOrExit(ready <-chan struct{}, waitErr <-chan error, timeout time.Duration) error {
select {
case err := <-waitErr:
return fmt.Errorf("process exited early: %w", err)
case <-ready:
select {
case err := <-waitErr:
return fmt.Errorf("process exited after ready: %w", err)
case <-time.After(readyCheckTimeout):
return nil
}
case <-time.After(timeout):
return errors.New("timeout waiting for process ready")
}
}

// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit.
// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate.
func (c *StdioMCPClient) Close() error {
Expand All @@ -98,7 +128,7 @@ func (c *StdioMCPClient) Close() error {
if err := c.stderr.Close(); err != nil {
return fmt.Errorf("failed to close stderr: %w", err)
}
return c.cmd.Wait()
return <-c.processExitErr
}

// Stderr returns a reader for the stderr output of the subprocess.
Expand Down