diff --git a/internal/mcp/connection.go b/internal/mcp/connection.go index 477143f0..d6b457b2 100644 --- a/internal/mcp/connection.go +++ b/internal/mcp/connection.go @@ -104,6 +104,28 @@ func createJSONRPCRequest(requestID uint64, method string, params interface{}) m } } +// ensureToolCallArguments ensures that the arguments field exists in tools/call params +// The MCP protocol requires the arguments field to always be present, even if empty +func ensureToolCallArguments(params interface{}) interface{} { + // Convert params to map if it isn't already + paramsMap, ok := params.(map[string]interface{}) + if !ok { + // If params isn't a map, return as-is (this shouldn't happen for tools/call) + return params + } + + // Check if arguments field exists + if _, hasArgs := paramsMap["arguments"]; !hasArgs { + // Add empty arguments map if missing + paramsMap["arguments"] = make(map[string]interface{}) + } else if paramsMap["arguments"] == nil { + // Replace nil with empty map + paramsMap["arguments"] = make(map[string]interface{}) + } + + return paramsMap +} + // setupHTTPRequest creates and configures an HTTP request with standard headers func setupHTTPRequest(ctx context.Context, url string, requestBody []byte, headers map[string]string) (*http.Request, error) { httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(requestBody)) @@ -576,6 +598,11 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params // Generate unique request ID using atomic counter requestID := atomic.AddUint64(&requestIDCounter, 1) + // For tools/call, ensure arguments field always exists (MCP protocol requirement) + if method == "tools/call" { + params = ensureToolCallArguments(params) + } + // Create JSON-RPC request request := createJSONRPCRequest(requestID, method, params) @@ -726,11 +753,24 @@ func (c *Connection) callTool(params interface{}) (*Response, error) { return nil, fmt.Errorf("SDK session not available for plain JSON-RPC transport") } var callParams CallToolParams - paramsJSON, _ := json.Marshal(params) + paramsJSON, err := json.Marshal(params) + if err != nil { + return nil, fmt.Errorf("failed to marshal params: %w", err) + } + logConn.Printf("callTool: marshaled params=%s", string(paramsJSON)) + if err := json.Unmarshal(paramsJSON, &callParams); err != nil { return nil, fmt.Errorf("invalid params: %w", err) } + // Ensure arguments is never nil - default to empty map + // This is required by the MCP protocol which expects arguments to always be present + if callParams.Arguments == nil { + callParams.Arguments = make(map[string]interface{}) + } + + logConn.Printf("callTool: parsed name=%s, arguments=%+v", callParams.Name, callParams.Arguments) + result, err := c.session.CallTool(c.ctx, &sdk.CallToolParams{ Name: callParams.Name, Arguments: callParams.Arguments, diff --git a/internal/mcp/connection_arguments_test.go b/internal/mcp/connection_arguments_test.go new file mode 100644 index 00000000..c5bc01f8 --- /dev/null +++ b/internal/mcp/connection_arguments_test.go @@ -0,0 +1,282 @@ +package mcp + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestCallTool_ArgumentsPassed tests that tool arguments are correctly passed to the backend +func TestCallTool_ArgumentsPassed(t *testing.T) { + tests := []struct { + name string + inputParams map[string]interface{} + expectedArguments map[string]interface{} + }{ + { + name: "simple string argument", + inputParams: map[string]interface{}{ + "name": "test_tool", + "arguments": map[string]interface{}{ + "query": "test query", + }, + }, + expectedArguments: map[string]interface{}{ + "query": "test query", + }, + }, + { + name: "multiple arguments", + inputParams: map[string]interface{}{ + "name": "list_issues", + "arguments": map[string]interface{}{ + "owner": "githubnext", + "repo": "gh-aw-mcpg", + "state": "open", + }, + }, + expectedArguments: map[string]interface{}{ + "owner": "githubnext", + "repo": "gh-aw-mcpg", + "state": "open", + }, + }, + { + name: "nested object arguments", + inputParams: map[string]interface{}{ + "name": "complex_tool", + "arguments": map[string]interface{}{ + "config": map[string]interface{}{ + "timeout": 30, + "retry": true, + }, + "filters": []string{"tag1", "tag2"}, + }, + }, + expectedArguments: map[string]interface{}{ + "config": map[string]interface{}{ + "timeout": float64(30), // JSON numbers are float64 + "retry": true, + }, + "filters": []interface{}{"tag1", "tag2"}, + }, + }, + { + name: "empty arguments", + inputParams: map[string]interface{}{ + "name": "no_args_tool", + "arguments": map[string]interface{}{}, + }, + expectedArguments: map[string]interface{}{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Track what arguments the backend received + var receivedArguments map[string]interface{} + + // Create a mock backend server + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Read the request body + bodyBytes, err := io.ReadAll(r.Body) + require.NoError(t, err, "Failed to read request body") + + var request map[string]interface{} + err = json.Unmarshal(bodyBytes, &request) + require.NoError(t, err, "Failed to parse request JSON") + + method, _ := request["method"].(string) + + if method == "initialize" { + // Return success for initialize + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Mcp-Session-Id", "test-session-123") + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": request["id"], + "result": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "serverInfo": map[string]interface{}{ + "name": "test-server", + "version": "1.0.0", + }, + }, + }) + return + } + + if method == "tools/call" { + // Extract the arguments from the params + params, ok := request["params"].(map[string]interface{}) + require.True(t, ok, "params should be a map") + + // Store the arguments we received + if args, ok := params["arguments"].(map[string]interface{}); ok { + receivedArguments = args + } else { + receivedArguments = map[string]interface{}{} + } + + // Return success + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": request["id"], + "result": map[string]interface{}{ + "content": []interface{}{ + map[string]interface{}{ + "type": "text", + "text": "Success", + }, + }, + }, + }) + return + } + + // Unknown method + http.Error(w, "Unknown method", http.StatusBadRequest) + })) + defer testServer.Close() + + // Create connection + conn, err := NewHTTPConnection(context.Background(), testServer.URL, map[string]string{ + "Authorization": "test-token", + }) + require.NoError(t, err, "Failed to create HTTP connection") + defer conn.Close() + + // Send the tool call request + _, err = conn.SendRequestWithServerID(context.Background(), "tools/call", tt.inputParams, "test-server") + require.NoError(t, err, "Tool call should succeed") + + // Verify the arguments were passed correctly + assert.Equal(t, tt.expectedArguments, receivedArguments, "Arguments should match expected values") + }) + } +} + +// TestCallTool_MissingArguments tests behavior when arguments field is missing +func TestCallTool_MissingArguments(t *testing.T) { + var receivedParams map[string]interface{} + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + bodyBytes, _ := io.ReadAll(r.Body) + var request map[string]interface{} + json.Unmarshal(bodyBytes, &request) + + method, _ := request["method"].(string) + + if method == "initialize" { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Mcp-Session-Id", "test-session-123") + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": request["id"], + "result": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "serverInfo": map[string]interface{}{ + "name": "test-server", + "version": "1.0.0", + }, + }, + }) + return + } + + if method == "tools/call" { + params, _ := request["params"].(map[string]interface{}) + receivedParams = params + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": request["id"], + "result": map[string]interface{}{ + "content": []interface{}{ + map[string]interface{}{ + "type": "text", + "text": "Success", + }, + }, + }, + }) + } + })) + defer testServer.Close() + + conn, err := NewHTTPConnection(context.Background(), testServer.URL, map[string]string{ + "Authorization": "test-token", + }) + require.NoError(t, err) + defer conn.Close() + + // Test 1: Send request with arguments field omitted entirely + t.Run("omitted arguments field", func(t *testing.T) { + receivedParams = nil + params := map[string]interface{}{ + "name": "test_tool", + // No "arguments" field + } + + _, err = conn.SendRequestWithServerID(context.Background(), "tools/call", params, "test-server") + require.NoError(t, err) + + // Verify arguments field exists in the request sent to backend + assert.NotNil(t, receivedParams, "Params should be sent to backend") + assert.Contains(t, receivedParams, "name", "Name should be present") + + // The arguments field should be present (even if empty) + // This is the key: the MCP spec requires arguments to be present + _, hasArguments := receivedParams["arguments"] + assert.True(t, hasArguments, "Arguments field should be present in backend request even if not provided by client") + }) + + // Test 2: Send request with explicit null arguments + t.Run("null arguments", func(t *testing.T) { + receivedParams = nil + params := map[string]interface{}{ + "name": "test_tool", + "arguments": nil, + } + + _, err = conn.SendRequestWithServerID(context.Background(), "tools/call", params, "test-server") + require.NoError(t, err) + + assert.NotNil(t, receivedParams, "Params should be sent to backend") + + // Arguments should be present, even if nil/empty + _, hasArguments := receivedParams["arguments"] + assert.True(t, hasArguments, "Arguments field should be present even if nil") + }) + + // Test 3: Send request with empty arguments object + t.Run("empty arguments object", func(t *testing.T) { + receivedParams = nil + params := map[string]interface{}{ + "name": "test_tool", + "arguments": map[string]interface{}{}, + } + + _, err = conn.SendRequestWithServerID(context.Background(), "tools/call", params, "test-server") + require.NoError(t, err) + + assert.NotNil(t, receivedParams, "Params should be sent to backend") + + // Arguments should be present as an object + arguments, hasArguments := receivedParams["arguments"] + assert.True(t, hasArguments, "Arguments field should be present") + + // It should be an empty map + if argsMap, ok := arguments.(map[string]interface{}); ok { + assert.Empty(t, argsMap, "Arguments should be an empty map") + } + }) +} diff --git a/internal/server/tool_call_arguments_test.go b/internal/server/tool_call_arguments_test.go new file mode 100644 index 00000000..e42b287d --- /dev/null +++ b/internal/server/tool_call_arguments_test.go @@ -0,0 +1,181 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/githubnext/gh-aw-mcpg/internal/config" +) + +// TestUnifiedServer_ToolCallArguments tests that tool arguments are correctly passed through the gateway +func TestUnifiedServer_ToolCallArguments(t *testing.T) { + // Track what the mock backend received + var receivedToolCalls []map[string]interface{} + var mu sync.Mutex + + // Create a mock MCP backend server + mockBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + t.Logf("Failed to read request body: %v", err) + http.Error(w, "Internal error", http.StatusInternalServerError) + return + } + + var request map[string]interface{} + if err := json.Unmarshal(bodyBytes, &request); err != nil { + t.Logf("Failed to unmarshal request: %v", err) + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + method, _ := request["method"].(string) + requestID := request["id"] + + t.Logf("Backend received: method=%s, id=%v", method, requestID) + + if method == "initialize" { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Mcp-Session-Id", "backend-session-123") + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": requestID, + "result": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{}, + "serverInfo": map[string]interface{}{ + "name": "test-backend", + "version": "1.0.0", + }, + }, + }) + return + } + + if method == "tools/list" { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": requestID, + "result": map[string]interface{}{ + "tools": []map[string]interface{}{ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "param1": map[string]interface{}{"type": "string"}, + "param2": map[string]interface{}{"type": "number"}, + }, + }, + }, + }, + }, + }) + return + } + + if method == "tools/call" { + params, _ := request["params"].(map[string]interface{}) + + // Log the entire params structure + mu.Lock() + receivedToolCalls = append(receivedToolCalls, params) + mu.Unlock() + + paramsJSON, _ := json.Marshal(params) + t.Logf("Backend received tools/call params: %s", string(paramsJSON)) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": requestID, + "result": map[string]interface{}{ + "content": []interface{}{ + map[string]interface{}{ + "type": "text", + "text": "Success", + }, + }, + }, + }) + return + } + + http.Error(w, fmt.Sprintf("Unknown method: %s", method), http.StatusBadRequest) + })) + defer mockBackend.Close() + + // Create gateway configuration with the mock backend + cfg := &config.Config{ + Servers: map[string]*config.ServerConfig{ + "testserver": { + Type: "http", + URL: mockBackend.URL, + Headers: map[string]string{ + "Authorization": "test-auth", + }, + }, + }, + } + + ctx := context.Background() + us, err := NewUnified(ctx, cfg) + require.NoError(t, err, "Failed to create unified server") + defer us.Close() + + // Simulate a tool call with arguments + testArgs := map[string]interface{}{ + "param1": "test_value", + "param2": float64(42), + "param3": map[string]interface{}{ + "nested": "value", + }, + } + + // Call the tool through the unified server + var callErr error + _, _, callErr = us.callBackendTool(ctx, "testserver", "test_tool", testArgs) + + // Verify the backend received the tool call first (this is the critical test) + mu.Lock() + defer mu.Unlock() + + require.GreaterOrEqual(t, len(receivedToolCalls), 1, "Backend should have received at least one tool call") + + // Check the most recent tool call + lastCall := receivedToolCalls[len(receivedToolCalls)-1] + t.Logf("Last tool call received by backend: %+v", lastCall) + + // Verify the tool name was passed + assert.Equal(t, "test_tool", lastCall["name"], "Tool name should match") + + // Verify the arguments were passed and not empty + arguments, ok := lastCall["arguments"].(map[string]interface{}) + require.True(t, ok, "Arguments should be a map") + require.NotEmpty(t, arguments, "Arguments should not be empty") + + // Verify specific argument values + assert.Equal(t, "test_value", arguments["param1"], "param1 should match") + assert.Equal(t, float64(42), arguments["param2"], "param2 should match") + + nestedMap, ok := arguments["param3"].(map[string]interface{}) + require.True(t, ok, "param3 should be a nested map") + assert.Equal(t, "value", nestedMap["nested"], "Nested value should match") + + // Now check the result + if callErr != nil { + t.Logf("Error calling tool: %v", callErr) + } + // Note: We don't require no error since we've already verified arguments were passed correctly +} diff --git a/internal/server/unified.go b/internal/server/unified.go index 1cc8eea8..6e65833c 100644 --- a/internal/server/unified.go +++ b/internal/server/unified.go @@ -233,9 +233,21 @@ func (us *UnifiedServer) registerToolsFromBackend(serverID string) error { // Create the handler function handler := func(ctx context.Context, req *sdk.CallToolRequest, args interface{}) (*sdk.CallToolResult, interface{}, error) { + // Extract arguments from the request params (not the args parameter which is SDK internal state) + var toolArgs map[string]interface{} + if req.Params.Arguments != nil { + if err := json.Unmarshal(req.Params.Arguments, &toolArgs); err != nil { + logger.LogError("client", "Failed to unmarshal tool arguments, tool=%s, error=%v", toolNameCopy, err) + return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("failed to parse arguments: %w", err) + } + } else { + // No arguments provided, use empty map + toolArgs = make(map[string]interface{}) + } + // Log the MCP tool call request sessionID := us.getSessionID(ctx) - argsJSON, _ := json.Marshal(args) + argsJSON, _ := json.Marshal(toolArgs) logger.LogInfo("client", "MCP tool call request, session=%s, tool=%s, args=%s", sessionID, toolNameCopy, string(argsJSON)) // Check session is initialized @@ -244,7 +256,7 @@ func (us *UnifiedServer) registerToolsFromBackend(serverID string) error { return &sdk.CallToolResult{IsError: true}, nil, err } - result, data, err := us.callBackendTool(ctx, serverIDCopy, toolNameCopy, args) + result, data, err := us.callBackendTool(ctx, serverIDCopy, toolNameCopy, toolArgs) // Log the MCP tool call response if err != nil { @@ -282,12 +294,21 @@ func (us *UnifiedServer) registerToolsFromBackend(serverID string) error { func (us *UnifiedServer) registerSysTools() error { // Create sys_init handler sysInitHandler := func(ctx context.Context, req *sdk.CallToolRequest, args interface{}) (*sdk.CallToolResult, interface{}, error) { + // Extract arguments from the request params + var toolArgs map[string]interface{} + if req.Params.Arguments != nil { + if err := json.Unmarshal(req.Params.Arguments, &toolArgs); err != nil { + logger.LogError("client", "Failed to unmarshal sys_init arguments, error=%v", err) + return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("failed to parse arguments: %w", err) + } + } else { + toolArgs = make(map[string]interface{}) + } + // Extract token from args token := "" - if argsMap, ok := args.(map[string]interface{}); ok { - if t, ok := argsMap["token"].(string); ok { - token = t - } + if t, ok := toolArgs["token"].(string); ok { + token = t } // Get session ID from context