Skip to content
4 changes: 2 additions & 2 deletions client/stdio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ func TestStdioMCPClient(t *testing.T) {
}
tempFile.Close()
mockServerPath := tempFile.Name()

// Add .exe suffix on Windows
if runtime.GOOS == "windows" {
os.Remove(mockServerPath) // Remove the empty file first
mockServerPath += ".exe"
}

if compileErr := compileTestServer(mockServerPath); compileErr != nil {
t.Fatalf("Failed to compile mock server: %v", compileErr)
}
Expand Down
1 change: 0 additions & 1 deletion client/transport/sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,5 +501,4 @@ func TestSSEErrors(t *testing.T) {
t.Errorf("Expected error when sending request after close, got nil")
}
})

}
76 changes: 51 additions & 25 deletions client/transport/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,23 @@ import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"os/exec"
"sync"
"sync/atomic"
"time"

"github.com/mark3labs/mcp-go/mcp"
)

const (
readyTimeout = 5 * time.Second
readyCheckTimeout = 1 * time.Second
)

// Stdio implements the transport layer of the MCP protocol using stdio communication.
// It launches a subprocess and communicates with it via standard input/output streams
// using JSON-RPC messages. The client handles message routing between requests and
Expand All @@ -31,6 +39,8 @@ type Stdio struct {
done chan struct{}
onNotification func(mcp.JSONRPCNotification)
notifyMu sync.RWMutex
processExited chan struct{}
exitErr atomic.Value
}

// NewIO returns a new stdio-based transport using existing input, output, and
Expand Down Expand Up @@ -61,8 +71,9 @@ func NewStdio(
args: args,
env: env,

responses: make(map[int64]chan *JSONRPCResponse),
done: make(chan struct{}),
responses: make(map[int64]chan *JSONRPCResponse),
done: make(chan struct{}),
processExited: make(chan struct{}),
}

return client
Expand All @@ -72,14 +83,6 @@ func (c *Stdio) Start(ctx context.Context) error {
if err := c.spawnCommand(ctx); err != nil {
return err
}

ready := make(chan struct{})
go func() {
close(ready)
c.readResponses()
}()
<-ready

return nil
}

Expand All @@ -95,7 +98,6 @@ func (c *Stdio) spawnCommand(ctx context.Context) error {
mergedEnv = append(mergedEnv, c.env...)

cmd.Env = mergedEnv

stdin, err := cmd.StdinPipe()
if err != nil {
return fmt.Errorf("failed to create stdin pipe: %w", err)
Expand All @@ -119,37 +121,62 @@ func (c *Stdio) spawnCommand(ctx context.Context) error {
if err := cmd.Start(); err != nil {
return fmt.Errorf("failed to start command: %w", err)
}
go func() {
err := cmd.Wait()
if err != nil {
c.exitErr.Store(err)
}
close(c.processExited)
}()

// Start reading responses in a goroutine and wait for it to be ready
ready := make(chan struct{})
go func() {
close(ready)
c.readResponses()
}()

if err := waitUntilReadyOrExit(ready, c.processExited, readyTimeout); err != nil {
return err
}
return nil
}

func waitUntilReadyOrExit(ready <-chan struct{}, exited <-chan struct{}, timeout time.Duration) error {
select {
case <-exited:
return errors.New("process exited before signalling readiness")
case <-ready:
select {
case <-exited:
return errors.New("process exited after readiness")
case <-time.After(readyCheckTimeout):
return nil
}
case <-time.After(timeout):
return errors.New("timeout waiting for process ready")
}
}

// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit.
// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate.
func (c *Stdio) Close() error {
select {
case <-c.done:
return nil
default:
}
// cancel all in-flight request
close(c.done)

if err := c.stdin.Close(); err != nil {
return fmt.Errorf("failed to close stdin: %w", err)
}
if err := c.stderr.Close(); err != nil {
return fmt.Errorf("failed to close stderr: %w", err)
}

if c.cmd != nil {
return c.cmd.Wait()
<-c.processExited
if err, ok := c.exitErr.Load().(error); ok && err != nil {
return err
}

return nil
}

// SetNotificationHandler sets the handler function to be called when a notification is received.
// Only one handler can be set at a time; setting a new one replaces the previous handler.
// OnNotification registers a handler function to be called when notifications are received.
// Multiple handlers can be registered and will be called in the order they were added.
func (c *Stdio) SetNotificationHandler(
handler func(notification mcp.JSONRPCNotification),
) {
Expand Down Expand Up @@ -243,7 +270,6 @@ func (c *Stdio) SendRequest(
deleteResponseChan()
return nil, fmt.Errorf("failed to write request: %w", err)
}

select {
case <-ctx.Done():
deleteResponseChan()
Expand Down
41 changes: 35 additions & 6 deletions client/transport/stdio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ func TestStdio(t *testing.T) {
}
tempFile.Close()
mockServerPath := tempFile.Name()

// Add .exe suffix on Windows
if runtime.GOOS == "windows" {
os.Remove(mockServerPath) // Remove the empty file first
mockServerPath += ".exe"
}

if compileErr := compileTestServer(mockServerPath); compileErr != nil {
t.Fatalf("Failed to compile mock server: %v", compileErr)
}
Expand Down Expand Up @@ -329,13 +329,13 @@ func TestStdioErrors(t *testing.T) {
}
tempFile.Close()
mockServerPath := tempFile.Name()

// Add .exe suffix on Windows
if runtime.GOOS == "windows" {
os.Remove(mockServerPath) // Remove the empty file first
mockServerPath += ".exe"
}

if compileErr := compileTestServer(mockServerPath); compileErr != nil {
t.Fatalf("Failed to compile mock server: %v", compileErr)
}
Expand Down Expand Up @@ -368,13 +368,13 @@ func TestStdioErrors(t *testing.T) {
}
tempFile.Close()
mockServerPath := tempFile.Name()

// Add .exe suffix on Windows
if runtime.GOOS == "windows" {
os.Remove(mockServerPath) // Remove the empty file first
mockServerPath += ".exe"
}

if compileErr := compileTestServer(mockServerPath); compileErr != nil {
t.Fatalf("Failed to compile mock server: %v", compileErr)
}
Expand Down Expand Up @@ -407,5 +407,34 @@ func TestStdioErrors(t *testing.T) {
t.Errorf("Expected error when sending request after close, got nil")
}
})
t.Run("SubprocessStartsAndExitsImmediately", func(t *testing.T) {
// Create a temporary file for the mock server
tempFile, err := os.CreateTemp("", "mockstdio_server")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
tempFile.Close()
mockServerPath := tempFile.Name()

// Add .exe suffix on Windows
if runtime.GOOS == "windows" {
os.Remove(mockServerPath) // Remove the empty file first
mockServerPath += ".exe"
}

if compileErr := compileTestServer(mockServerPath); compileErr != nil {
t.Fatalf("Failed to compile mock server: %v", compileErr)
}
//defer os.Remove(mockServerPath)

// Create a new Stdio transport
stdio := NewStdio(mockServerPath, nil)
stdio.env = append(stdio.env, "MOCK_FAIL_IMMEDIATELY=1")
defer stdio.Close()
// Start the transport
ctx := context.Background()
if startErr := stdio.Start(ctx); startErr == nil {
t.Fatalf("Expected error when starting Stdio transport, got nil")
}
})
}
4 changes: 0 additions & 4 deletions server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -655,10 +655,6 @@ func (s *SSEServer) MessageHandler() http.Handler {

// ServeHTTP implements the http.Handler interface.
func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if s.dynamicBasePathFunc != nil {
http.Error(w, (&ErrDynamicPathConfig{Method: "ServeHTTP"}).Error(), http.StatusInternalServerError)
return
}
path := r.URL.Path
// Use exact path matching rather than Contains
ssePath := s.CompleteSsePath()
Expand Down
4 changes: 4 additions & 0 deletions testdata/mockstdio_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ type JSONRPCResponse struct {
}

func main() {
if os.Getenv("MOCK_FAIL_IMMEDIATELY") == "1" {
fmt.Fprintln(os.Stderr, "mock server: simulated startup failure")
os.Exit(1)
}
logger := slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{}))
logger.Info("launch successful")
scanner := bufio.NewScanner(os.Stdin)
Expand Down
Loading