diff --git a/Makefile b/Makefile index 7f50cc35..4d774714 100644 --- a/Makefile +++ b/Makefile @@ -53,15 +53,15 @@ test-python-e2e: ## Run Python E2E tests (requires docker-compose services and S .PHONY: run run: ## Run the MCP server in stdio mode. - go run ./cmd/mcp-grafana + GRAFANA_USERNAME=admin GRAFANA_PASSWORD=admin go run ./cmd/mcp-grafana .PHONY: run-sse run-sse: ## Run the MCP server in SSE mode. - go run ./cmd/mcp-grafana --transport sse --log-level debug --debug + GRAFANA_USERNAME=admin GRAFANA_PASSWORD=admin go run ./cmd/mcp-grafana --transport sse --log-level debug --debug PHONY: run-streamable-http run-streamable-http: ## Run the MCP server in StreamableHTTP mode. - go run ./cmd/mcp-grafana --transport streamable-http --log-level debug --debug + GRAFANA_USERNAME=admin GRAFANA_PASSWORD=admin go run ./cmd/mcp-grafana --transport streamable-http --log-level debug --debug .PHONY: run-test-services run-test-services: ## Run the docker-compose services required for the unit and integration tests. diff --git a/cmd/mcp-grafana/main.go b/cmd/mcp-grafana/main.go index 51a09c61..0a3deff5 100644 --- a/cmd/mcp-grafana/main.go +++ b/cmd/mcp-grafana/main.go @@ -14,6 +14,7 @@ import ( "syscall" "time" + "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" mcpgrafana "github.com/grafana/mcp-grafana" @@ -40,7 +41,7 @@ type disabledTools struct { search, datasource, incident, prometheus, loki, alerting, dashboard, folder, oncall, asserts, sift, admin, - pyroscope, navigation bool + pyroscope, navigation, proxied bool } // Configuration for the Grafana client. @@ -56,8 +57,7 @@ type grafanaConfig struct { } func (dt *disabledTools) addFlags() { - flag.StringVar(&dt.enabledTools, "enabled-tools", "search,datasource,incident,prometheus,loki,alerting,dashboard,folder,oncall,asserts,sift,admin,pyroscope,navigation", "A comma separated list of tools enabled for this server. Can be overwritten entirely or by disabling specific components, e.g. --disable-search.") - + flag.StringVar(&dt.enabledTools, "enabled-tools", "search,datasource,incident,prometheus,loki,alerting,dashboard,folder,oncall,asserts,sift,admin,pyroscope,navigation,proxied", "A comma separated list of tools enabled for this server. Can be overwritten entirely or by disabling specific components, e.g. --disable-search.") flag.BoolVar(&dt.search, "disable-search", false, "Disable search tools") flag.BoolVar(&dt.datasource, "disable-datasource", false, "Disable datasource tools") flag.BoolVar(&dt.incident, "disable-incident", false, "Disable incident tools") @@ -72,6 +72,7 @@ func (dt *disabledTools) addFlags() { flag.BoolVar(&dt.admin, "disable-admin", false, "Disable admin tools") flag.BoolVar(&dt.pyroscope, "disable-pyroscope", false, "Disable pyroscope tools") flag.BoolVar(&dt.navigation, "disable-navigation", false, "Disable navigation tools") + flag.BoolVar(&dt.proxied, "disable-proxied", false, "Disable proxied tools (tools from external MCP servers)") } func (gc *grafanaConfig) addFlags() { @@ -102,24 +103,71 @@ func (dt *disabledTools) addTools(s *server.MCPServer) { maybeAddTools(s, tools.AddNavigationTools, enabledTools, dt.navigation, "navigation") } -func newServer(dt disabledTools) *server.MCPServer { - s := server.NewMCPServer("mcp-grafana", mcpgrafana.Version(), server.WithInstructions(` - This server provides access to your Grafana instance and the surrounding ecosystem. - - Available Capabilities: - - Dashboards: Search, retrieve, update, and create dashboards. 
Extract panel queries and datasource information. - - Datasources: List and fetch details for datasources. - - Prometheus & Loki: Run PromQL and LogQL queries, retrieve metric/log metadata, and explore label names/values. - - Incidents: Search, create, update, and resolve incidents in Grafana Incident. - - Sift Investigations: Start and manage Sift investigations, analyze logs/traces, find error patterns, and detect slow requests. - - Alerting: List and fetch alert rules and notification contact points. - - OnCall: View and manage on-call schedules, shifts, teams, and users. - - Admin: List teams and perform administrative tasks. - - Pyroscope: Profile applications and fetch profiling data. - - Navigation: Generate deeplink URLs for Grafana resources like dashboards, panels, and Explore queries. - `)) +func newServer(transport string, dt disabledTools) (*server.MCPServer, *mcpgrafana.ToolManager) { + sm := mcpgrafana.NewSessionManager() + + // Declare variable for ToolManager that will be initialized after server creation + var stm *mcpgrafana.ToolManager + + // Create hooks + hooks := &server.Hooks{ + OnRegisterSession: []server.OnRegisterSessionHookFunc{sm.CreateSession}, + OnUnregisterSession: []server.OnUnregisterSessionHookFunc{sm.RemoveSession}, + } + + // Add proxied tools hooks if enabled and we're not running in stdio mode. + // (stdio mode is handled by InitializeAndRegisterServerTools; per-session tools + // are not supported). + if transport != "stdio" && !dt.proxied { + // OnBeforeListTools: Discover, connect, and register tools + hooks.OnBeforeListTools = []server.OnBeforeListToolsFunc{ + func(ctx context.Context, id any, request *mcp.ListToolsRequest) { + if stm != nil { + if session := server.ClientSessionFromContext(ctx); session != nil { + stm.InitializeAndRegisterProxiedTools(ctx, session) + } + } + }, + } + + // OnBeforeCallTool: Fallback in case client calls tool without listing first + hooks.OnBeforeCallTool = []server.OnBeforeCallToolFunc{ + func(ctx context.Context, id any, request *mcp.CallToolRequest) { + if stm != nil { + if session := server.ClientSessionFromContext(ctx); session != nil { + stm.InitializeAndRegisterProxiedTools(ctx, session) + } + } + }, + } + } + s := server.NewMCPServer("mcp-grafana", mcpgrafana.Version(), + server.WithInstructions(` +This server provides access to your Grafana instance and the surrounding ecosystem. + +Available Capabilities: +- Dashboards: Search, retrieve, update, and create dashboards. Extract panel queries and datasource information. +- Datasources: List and fetch details for datasources. +- Prometheus & Loki: Run PromQL and LogQL queries, retrieve metric/log metadata, and explore label names/values. +- Incidents: Search, create, update, and resolve incidents in Grafana Incident. +- Sift Investigations: Start and manage Sift investigations, analyze logs/traces, find error patterns, and detect slow requests. +- Alerting: List and fetch alert rules and notification contact points. +- OnCall: View and manage on-call schedules, shifts, teams, and users. +- Admin: List teams and perform administrative tasks. +- Pyroscope: Profile applications and fetch profiling data. +- Navigation: Generate deeplink URLs for Grafana resources like dashboards, panels, and Explore queries. +- Proxied Tools: Access tools from external MCP servers (like Tempo) through dynamic discovery. + +Note that some of these capabilities may be disabled. Do not try to use features that are not available via tools. 
+`), + server.WithHooks(hooks), + ) + + // Initialize ToolManager now that server is created + stm = mcpgrafana.NewToolManager(sm, s, mcpgrafana.WithProxiedTools(!dt.proxied)) + dt.addTools(s) - return s + return s, stm } type tlsConfig struct { @@ -162,6 +210,7 @@ func runHTTPServer(ctx context.Context, srv httpServer, addr, transportName stri if err := srv.Shutdown(shutdownCtx); err != nil { return fmt.Errorf("shutdown error: %v", err) } + slog.Debug("Shutdown called, waiting for connections to close...") // Wait for server to finish select { @@ -180,7 +229,7 @@ func runHTTPServer(ctx context.Context, srv httpServer, addr, transportName stri func run(transport, addr, basePath, endpointPath string, logLevel slog.Level, dt disabledTools, gc mcpgrafana.GrafanaConfig, tls tlsConfig) error { slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: logLevel}))) - s := newServer(dt) + s, tm := newServer(transport, dt) // Create a context that will be cancelled on shutdown ctx, cancel := context.WithCancel(context.Background()) @@ -207,7 +256,17 @@ func run(transport, addr, basePath, endpointPath string, logLevel slog.Level, dt switch transport { case "stdio": srv := server.NewStdioServer(s) - srv.SetContextFunc(mcpgrafana.ComposedStdioContextFunc(gc)) + cf := mcpgrafana.ComposedStdioContextFunc(gc) + srv.SetContextFunc(cf) + + // For stdio (single-tenant), initialize proxied tools on the server directly + if !dt.proxied { + stdioCtx := cf(ctx) + if err := tm.InitializeAndRegisterServerTools(stdioCtx); err != nil { + slog.Error("failed to initialize proxied tools for stdio", "error", err) + } + } + slog.Info("Starting Grafana MCP server using stdio transport", "version", mcpgrafana.Version()) err := srv.Listen(ctx, os.Stdin, os.Stdout) @@ -227,7 +286,7 @@ func run(transport, addr, basePath, endpointPath string, logLevel slog.Level, dt case "streamable-http": opts := []server.StreamableHTTPOption{ server.WithHTTPContextFunc(mcpgrafana.ComposedHTTPContextFunc(gc)), - server.WithStateLess(true), + server.WithStateLess(dt.proxied), // Stateful when proxied tools enabled (requires sessions) server.WithEndpointPath(endpointPath), } if tls.certFile != "" || tls.keyFile != "" { @@ -238,10 +297,7 @@ func run(transport, addr, basePath, endpointPath string, logLevel slog.Level, dt "version", mcpgrafana.Version(), "address", addr, "endpointPath", endpointPath) return runHTTPServer(ctx, srv, addr, "StreamableHTTP") default: - return fmt.Errorf( - "invalid transport type: %s. Must be 'stdio', 'sse' or 'streamable-http'", - transport, - ) + return fmt.Errorf("invalid transport type: %s. 
Must be 'stdio', 'sse' or 'streamable-http'", transport) } } diff --git a/docker-compose.yaml b/docker-compose.yaml index 9fa8b79d..4965e7c1 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -43,3 +43,19 @@ services: image: grafana/pyroscope:1.13.4 ports: - 4040:4040 + + tempo: + image: grafana/tempo:2.9.0-rc.0 + command: ["-config.file=/etc/tempo/tempo-config.yaml"] + volumes: + - ./testdata/tempo-config.yaml:/etc/tempo/tempo-config.yaml + ports: + - "3200:3200" # tempo + + tempo2: + image: grafana/tempo:2.9.0-rc.0 + command: ["-config.file=/etc/tempo/tempo-config.yaml"] + volumes: + - ./testdata/tempo-config-2.yaml:/etc/tempo/tempo-config.yaml + ports: + - "3201:3201" # tempo instance 2 diff --git a/go.mod b/go.mod index 8c8c872d..e65fbb9c 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( connectrpc.com/connect v1.19.0 github.com/PaesslerAG/gval v1.2.4 github.com/PaesslerAG/jsonpath v0.1.1 + github.com/cenkalti/backoff/v5 v5.0.3 github.com/go-openapi/runtime v0.29.0 github.com/go-openapi/strfmt v0.24.0 github.com/google/uuid v1.6.0 @@ -31,7 +32,6 @@ require ( github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/buger/jsonparser v1.1.1 // indirect - github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cheekybits/genny v1.0.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect diff --git a/mcpgrafana.go b/mcpgrafana.go index 58f31cfa..82eee3e9 100644 --- a/mcpgrafana.go +++ b/mcpgrafana.go @@ -435,7 +435,7 @@ var ExtractGrafanaClientFromEnv server.StdioContextFunc = func(ctx context.Conte } auth := userAndPassFromEnv() grafanaClient := NewGrafanaClient(ctx, grafanaURL, apiKey, auth) - return context.WithValue(ctx, grafanaClientKey{}, grafanaClient) + return WithGrafanaClient(ctx, grafanaClient) } // ExtractGrafanaClientFromHeaders is a HTTPContextFunc that creates and injects a Grafana client into the context. @@ -443,6 +443,7 @@ var ExtractGrafanaClientFromEnv server.StdioContextFunc = func(ctx context.Conte var ExtractGrafanaClientFromHeaders httpContextFunc = func(ctx context.Context, req *http.Request) context.Context { // Extract transport config from request headers, and set it on the context. 
u, apiKey, basicAuth := extractKeyGrafanaInfoFromReq(req) + slog.Debug("Creating Grafana client", "url", u, "api_key_set", apiKey != "", "basic_auth_set", basicAuth != nil) grafanaClient := NewGrafanaClient(ctx, u, apiKey, basicAuth) return WithGrafanaClient(ctx, grafanaClient) diff --git a/proxied_client.go b/proxied_client.go new file mode 100644 index 00000000..45571b16 --- /dev/null +++ b/proxied_client.go @@ -0,0 +1,143 @@ +package mcpgrafana + +import ( + "context" + "encoding/base64" + "fmt" + "log/slog" + "sync" + + mcp_client "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +// ProxiedClient represents a connection to a remote MCP server (e.g., Tempo datasource) +type ProxiedClient struct { + DatasourceUID string + DatasourceName string + DatasourceType string + Client *mcp_client.Client + Tools []mcp.Tool + mutex sync.RWMutex +} + +// NewProxiedClient creates a new connection to a remote MCP server +func NewProxiedClient(ctx context.Context, datasourceUID, datasourceName, datasourceType, mcpEndpoint string) (*ProxiedClient, error) { + // Get Grafana config for authentication + config := GrafanaConfigFromContext(ctx) + + // Build headers for authentication + headers := make(map[string]string) + if config.APIKey != "" { + headers["Authorization"] = "Bearer " + config.APIKey + } else if config.BasicAuth != nil { + auth := config.BasicAuth.String() + headers["Authorization"] = "Basic " + base64.StdEncoding.EncodeToString([]byte(auth)) + } + + // Create HTTP transport with authentication headers + slog.DebugContext(ctx, "connecting to MCP server", "datasource", datasourceUID, "url", mcpEndpoint) + httpTransport, err := transport.NewStreamableHTTP( + mcpEndpoint, + transport.WithHTTPHeaders(headers), + ) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP transport: %w", err) + } + + // Create MCP client + mcpClient := mcp_client.NewClient(httpTransport) + + // Initialize the connection + initReq := mcp.InitializeRequest{} + initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initReq.Params.ClientInfo = mcp.Implementation{ + Name: "mcp-grafana-proxy", + Version: Version(), + } + + _, err = mcpClient.Initialize(ctx, initReq) + if err != nil { + _ = mcpClient.Close() + return nil, fmt.Errorf("failed to initialize MCP client: %w", err) + } + + // List available tools from the remote server + listReq := mcp.ListToolsRequest{} + toolsResult, err := mcpClient.ListTools(ctx, listReq) + if err != nil { + _ = mcpClient.Close() + return nil, fmt.Errorf("failed to list tools from remote MCP server: %w", err) + } + + slog.DebugContext(ctx, "connected to proxied MCP server", + "datasource", datasourceUID, + "type", datasourceType, + "tools", len(toolsResult.Tools)) + + return &ProxiedClient{ + DatasourceUID: datasourceUID, + DatasourceName: datasourceName, + DatasourceType: datasourceType, + Client: mcpClient, + Tools: toolsResult.Tools, + }, nil +} + +// CallTool forwards a tool call to the remote MCP server +func (pc *ProxiedClient) CallTool(ctx context.Context, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) { + pc.mutex.RLock() + defer pc.mutex.RUnlock() + + // Validate the tool exists + var toolExists bool + for _, tool := range pc.Tools { + if tool.Name == toolName { + toolExists = true + break + } + } + if !toolExists { + return nil, fmt.Errorf("tool %s not found in remote MCP server", toolName) + } + + // Create the call tool request + req := mcp.CallToolRequest{} + 
req.Params.Name = toolName + req.Params.Arguments = arguments + + // Forward the call to the remote server + result, err := pc.Client.CallTool(ctx, req) + if err != nil { + return nil, fmt.Errorf("failed to call tool on remote MCP server: %w", err) + } + + return result, nil +} + +// ListTools returns the tools available from this remote server +// Note: This method doesn't take a context parameter as the tools are cached locally +func (pc *ProxiedClient) ListTools() []mcp.Tool { + pc.mutex.RLock() + defer pc.mutex.RUnlock() + + // Return a copy to prevent external modification + result := make([]mcp.Tool, len(pc.Tools)) + copy(result, pc.Tools) + return result +} + +// Close closes the connection to the remote MCP server +func (pc *ProxiedClient) Close() error { + pc.mutex.Lock() + defer pc.mutex.Unlock() + + if pc.Client != nil { + if err := pc.Client.Close(); err != nil { + return fmt.Errorf("failed to close MCP client: %w", err) + } + } + + return nil +} diff --git a/proxied_handler.go b/proxied_handler.go new file mode 100644 index 00000000..06fe14ee --- /dev/null +++ b/proxied_handler.go @@ -0,0 +1,87 @@ +package mcpgrafana + +import ( + "context" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// ProxiedToolHandler implements the CallToolHandler interface for proxied tools +type ProxiedToolHandler struct { + sessionManager *SessionManager + toolManager *ToolManager + toolName string +} + +// NewProxiedToolHandler creates a new handler for a proxied tool +func NewProxiedToolHandler(sm *SessionManager, tm *ToolManager, toolName string) *ProxiedToolHandler { + return &ProxiedToolHandler{ + sessionManager: sm, + toolManager: tm, + toolName: toolName, + } +} + +// Handle forwards the tool call to the appropriate remote MCP server +func (h *ProxiedToolHandler) Handle(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Check if session is in context + session := server.ClientSessionFromContext(ctx) + if session == nil { + return nil, fmt.Errorf("session not found in context") + } + + // Extract arguments + args, ok := request.Params.Arguments.(map[string]any) + if !ok { + return nil, fmt.Errorf("invalid arguments type") + } + + // Extract required datasourceUid parameter + datasourceUidRaw, ok := args["datasourceUid"] + if !ok { + return nil, fmt.Errorf("datasourceUid parameter is required") + } + datasourceUID, ok := datasourceUidRaw.(string) + if !ok { + return nil, fmt.Errorf("datasourceUid must be a string") + } + + // Parse the tool name to get datasource type and original tool name + // Format: datasourceType_originalToolName (e.g., "tempo_traceql-search") + datasourceType, originalToolName, err := parseProxiedToolName(h.toolName) + if err != nil { + return nil, fmt.Errorf("failed to parse tool name: %w", err) + } + + // Get the proxied client for this datasource + var client *ProxiedClient + + if h.toolManager.serverMode { + // Server mode (stdio): clients stored at manager level + client, err = h.toolManager.GetServerClient(datasourceType, datasourceUID) + } else { + // Session mode (HTTP/SSE): clients stored per-session + client, err = h.sessionManager.GetProxiedClient(ctx, datasourceType, datasourceUID) + if err != nil { + // Fallback to server-level in case of mixed mode + client, err = h.toolManager.GetServerClient(datasourceType, datasourceUID) + } + } + + if err != nil { + return nil, fmt.Errorf("datasource '%s' not found or not accessible. 
Ensure the datasource exists and you have permission to access it", datasourceUID) + } + + // Remove datasourceUid from args before forwarding to remote server + forwardArgs := make(map[string]any) + for k, v := range args { + if k != "datasourceUid" { + forwardArgs[k] = v + } + } + + // Forward the call to the remote MCP server + return client.CallTool(ctx, originalToolName, forwardArgs) +} diff --git a/proxied_tools.go b/proxied_tools.go new file mode 100644 index 00000000..47b653ca --- /dev/null +++ b/proxied_tools.go @@ -0,0 +1,361 @@ +package mcpgrafana + +import ( + "context" + "fmt" + "log/slog" + "net/http" + "strings" + "sync" + + "github.com/go-openapi/runtime" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// MCPDatasourceConfig defines configuration for a datasource type that supports MCP +type MCPDatasourceConfig struct { + Type string + EndpointPath string // e.g., "/api/mcp" +} + +// mcpEnabledDatasources is a registry of datasource types that support MCP +var mcpEnabledDatasources = map[string]MCPDatasourceConfig{ + "tempo": {Type: "tempo", EndpointPath: "/api/mcp"}, + // Future: add other datasource types here +} + +// DiscoveredDatasource represents a datasource that supports MCP +type DiscoveredDatasource struct { + UID string + Name string + Type string + MCPURL string // The MCP endpoint URL +} + +// discoverMCPDatasources discovers datasources that support MCP +// Returns a list of datasources with MCP endpoints +func discoverMCPDatasources(ctx context.Context) ([]DiscoveredDatasource, error) { + gc := GrafanaClientFromContext(ctx) + if gc == nil { + return nil, fmt.Errorf("grafana client not found in context") + } + + var discovered []DiscoveredDatasource + + // List all datasources + resp, err := gc.Datasources.GetDataSources() + if err != nil { + return nil, fmt.Errorf("failed to list datasources: %w", err) + } + + // Get the Grafana base URL from context + config := GrafanaConfigFromContext(ctx) + if config.URL == "" { + return nil, fmt.Errorf("grafana url not found in context") + } + grafanaBaseURL := config.URL + + // Filter for datasources that support MCP + for _, ds := range resp.Payload { + // Check if this datasource type supports MCP + dsConfig, supported := mcpEnabledDatasources[ds.Type] + if !supported { + continue + } + + // Check if the datasource instance has MCP enabled + // We use a DELETE request to probe the MCP endpoint since: + // - GET would start an event stream and hang + // - POST doesn't work with the Grafana OpenAPI client + // - DELETE returns 200 if MCP is enabled, 404 if not + _, err := gc.Datasources.DatasourceProxyDELETEByUIDcalls(ds.UID, strings.TrimPrefix(dsConfig.EndpointPath, "/")) + if err == nil { + // Something strange happened - the server should never return a 202 for this really. Skip. 
+			continue
+		}
+		if apiErr, ok := err.(*runtime.APIError); !ok || (ok && !apiErr.IsCode(http.StatusOK)) {
+			// Not a 200 response, MCP not enabled
+			continue
+		}
+
+		// Build the MCP endpoint URL using Grafana's datasource proxy API
+		// Format: <grafanaBaseURL>/api/datasources/proxy/uid/<uid><endpointPath>
+		mcpURL := fmt.Sprintf("%s/api/datasources/proxy/uid/%s%s", grafanaBaseURL, ds.UID, dsConfig.EndpointPath)
+
+		discovered = append(discovered, DiscoveredDatasource{
+			UID:    ds.UID,
+			Name:   ds.Name,
+			Type:   ds.Type,
+			MCPURL: mcpURL,
+		})
+	}
+
+	slog.DebugContext(ctx, "discovered MCP datasources", "count", len(discovered))
+	return discovered, nil
+}
+
+// addDatasourceUidParameter adds a required datasourceUid parameter to a tool's input schema
+func addDatasourceUidParameter(tool mcp.Tool, datasourceType string) mcp.Tool {
+	modifiedTool := tool
+	// Prefix tool name with datasource type (e.g., "tempo_traceql-search")
+	modifiedTool.Name = datasourceType + "_" + tool.Name
+
+	// Add datasourceUid to the input schema
+	if modifiedTool.InputSchema.Properties == nil {
+		modifiedTool.InputSchema.Properties = make(map[string]any)
+	}
+
+	modifiedTool.InputSchema.Properties["datasourceUid"] = map[string]any{
+		"type":        "string",
+		"description": "UID of the " + datasourceType + " datasource to query",
+	}
+
+	// Add to required fields
+	modifiedTool.InputSchema.Required = append(modifiedTool.InputSchema.Required, "datasourceUid")
+
+	return modifiedTool
+}
+
+// parseProxiedToolName extracts datasource type and original tool name from a proxied tool name
+// Format: <datasourceType>_<originalToolName>
+// Returns: datasourceType, originalToolName, error
+func parseProxiedToolName(toolName string) (string, string, error) {
+	parts := strings.SplitN(toolName, "_", 2)
+	if len(parts) != 2 {
+		return "", "", fmt.Errorf("invalid proxied tool name format: %s", toolName)
+	}
+	return parts[0], parts[1], nil
+}
+
+// ToolManager manages proxied tools (either per-session or server-wide)
+type ToolManager struct {
+	sm     *SessionManager
+	server *server.MCPServer
+
+	// Whether to enable proxied tools.
+	enableProxiedTools bool
+
+	// For stdio transport: store clients at manager level (single-tenant).
+	// These will be unused for HTTP/SSE transports.
+ serverMode bool // true if using server-wide tools (stdio), false for per-session (HTTP/SSE) + serverClients map[string]*ProxiedClient + clientsMutex sync.RWMutex +} + +// NewToolManager creates a new ToolManager +func NewToolManager(sm *SessionManager, mcpServer *server.MCPServer, opts ...toolManagerOption) *ToolManager { + tm := &ToolManager{ + sm: sm, + server: mcpServer, + serverClients: make(map[string]*ProxiedClient), + } + for _, opt := range opts { + opt(tm) + } + return tm +} + +type toolManagerOption func(*ToolManager) + +// WithProxiedTools sets whether proxied tools are enabled +func WithProxiedTools(enabled bool) toolManagerOption { + return func(tm *ToolManager) { + tm.enableProxiedTools = enabled + } +} + +// InitializeAndRegisterServerTools discovers datasources and registers tools on the server (for stdio transport) +// This should be called once at server startup for single-tenant stdio servers +func (tm *ToolManager) InitializeAndRegisterServerTools(ctx context.Context) error { + if !tm.enableProxiedTools { + return nil + } + + // Mark as server mode (stdio transport) + tm.serverMode = true + + // Discover datasources with MCP support + discovered, err := discoverMCPDatasources(ctx) + if err != nil { + return fmt.Errorf("failed to discover MCP datasources: %w", err) + } + + if len(discovered) == 0 { + slog.Info("no MCP datasources discovered") + return nil + } + + // Connect to each datasource and store in manager + tm.clientsMutex.Lock() + for _, ds := range discovered { + client, err := NewProxiedClient(ctx, ds.UID, ds.Name, ds.Type, ds.MCPURL) + if err != nil { + slog.Error("failed to create proxied client", "datasource", ds.UID, "error", err) + continue + } + key := ds.Type + "_" + ds.UID + tm.serverClients[key] = client + } + clientCount := len(tm.serverClients) + tm.clientsMutex.Unlock() + + if clientCount == 0 { + slog.Warn("no proxied clients created") + return nil + } + + slog.Info("connected to proxied MCP servers", "datasources", clientCount) + + // Collect and register all unique tools + tm.clientsMutex.RLock() + toolMap := make(map[string]mcp.Tool) + for _, client := range tm.serverClients { + for _, tool := range client.ListTools() { + toolName := client.DatasourceType + "_" + tool.Name + if _, exists := toolMap[toolName]; !exists { + modifiedTool := addDatasourceUidParameter(tool, client.DatasourceType) + toolMap[toolName] = modifiedTool + } + } + } + tm.clientsMutex.RUnlock() + + // Register tools on the server (not per-session) + for toolName, tool := range toolMap { + handler := NewProxiedToolHandler(tm.sm, tm, toolName) + tm.server.AddTool(tool, handler.Handle) + } + + slog.Info("registered proxied tools on server", "tools", len(toolMap)) + return nil +} + +// InitializeAndRegisterProxiedTools discovers datasources, creates clients, and registers tools per-session +// This should be called in OnBeforeListTools and OnBeforeCallTool hooks for HTTP/SSE transports +func (tm *ToolManager) InitializeAndRegisterProxiedTools(ctx context.Context, session server.ClientSession) { + if !tm.enableProxiedTools { + return + } + + sessionID := session.SessionID() + state, exists := tm.sm.GetSession(sessionID) + if !exists { + // Session exists in server context but not in our SessionManager yet + tm.sm.CreateSession(ctx, session) + state, exists = tm.sm.GetSession(sessionID) + if !exists { + slog.Error("failed to create session in SessionManager", "sessionID", sessionID) + return + } + } + + // Step 1: Discover and connect (guaranteed to run exactly once per 
session) + state.initOnce.Do(func() { + // Discover datasources with MCP support + discovered, err := discoverMCPDatasources(ctx) + if err != nil { + slog.Error("failed to discover MCP datasources", "error", err) + state.mutex.Lock() + state.proxiedToolsInitialized = true + state.mutex.Unlock() + return + } + + state.mutex.Lock() + // For each discovered datasource, create a proxied client + for _, ds := range discovered { + client, err := NewProxiedClient(ctx, ds.UID, ds.Name, ds.Type, ds.MCPURL) + if err != nil { + slog.Error("failed to create proxied client", "datasource", ds.UID, "error", err) + continue + } + + // Store the client + key := ds.Type + "_" + ds.UID + state.proxiedClients[key] = client + } + state.proxiedToolsInitialized = true + state.mutex.Unlock() + + slog.Info("connected to proxied MCP servers", "session", sessionID, "datasources", len(state.proxiedClients)) + }) + + // Step 2: Register tools with the MCP server + state.mutex.Lock() + defer state.mutex.Unlock() + + // Check if tools already registered + if len(state.proxiedTools) > 0 { + return + } + + // Check if we have any clients (discovery should have happened above) + if len(state.proxiedClients) == 0 { + return + } + + // First pass: collect all unique tools and track which datasources support them + toolMap := make(map[string]mcp.Tool) // unique tools by name + + for key, client := range state.proxiedClients { + remoteTools := client.ListTools() + + for _, tool := range remoteTools { + // Tool name format: datasourceType_originalToolName (e.g., "tempo_traceql-search") + toolName := client.DatasourceType + "_" + tool.Name + + // Store the tool if we haven't seen it yet + if _, exists := toolMap[toolName]; !exists { + // Add datasourceUid parameter to the tool + modifiedTool := addDatasourceUidParameter(tool, client.DatasourceType) + toolMap[toolName] = modifiedTool + } + + // Track which datasources support this tool + state.toolToDatasources[toolName] = append(state.toolToDatasources[toolName], key) + } + } + + // Second pass: register all unique tools at once (reduces listChanged notifications) + var serverTools []server.ServerTool + for toolName, tool := range toolMap { + handler := NewProxiedToolHandler(tm.sm, tm, toolName) + serverTools = append(serverTools, server.ServerTool{ + Tool: tool, + Handler: handler.Handle, + }) + state.proxiedTools = append(state.proxiedTools, tool) + } + + if err := tm.server.AddSessionTools(sessionID, serverTools...); err != nil { + slog.Warn("failed to add session tools", "session", sessionID, "error", err) + } else { + slog.Info("registered proxied tools", "session", sessionID, "tools", len(state.proxiedTools)) + } +} + +// GetServerClient retrieves a proxied client from server-level storage (for stdio transport) +func (tm *ToolManager) GetServerClient(datasourceType, datasourceUID string) (*ProxiedClient, error) { + tm.clientsMutex.RLock() + defer tm.clientsMutex.RUnlock() + + key := datasourceType + "_" + datasourceUID + client, exists := tm.serverClients[key] + if !exists { + // List available datasources to help with debugging + var availableUIDs []string + for _, c := range tm.serverClients { + if c.DatasourceType == datasourceType { + availableUIDs = append(availableUIDs, c.DatasourceUID) + } + } + + if len(availableUIDs) > 0 { + return nil, fmt.Errorf("datasource '%s' not found. Available %s datasources: %v", datasourceUID, datasourceType, availableUIDs) + } + return nil, fmt.Errorf("datasource '%s' not found. 
No %s datasources with MCP support are configured", datasourceUID, datasourceType) + } + + return client, nil +} diff --git a/proxied_tools_test.go b/proxied_tools_test.go new file mode 100644 index 00000000..9dcdc3fd --- /dev/null +++ b/proxied_tools_test.go @@ -0,0 +1,465 @@ +package mcpgrafana + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" +) + +func TestSessionStateRaceConditions(t *testing.T) { + t.Run("concurrent initialization with sync.Once is safe", func(t *testing.T) { + state := newSessionState() + + var initCounter int32 + var wg sync.WaitGroup + + // Launch 100 goroutines that all try to initialize at once + const numGoroutines = 100 + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + state.initOnce.Do(func() { + // Simulate initialization work + atomic.AddInt32(&initCounter, 1) + time.Sleep(10 * time.Millisecond) // Simulate some work + state.mutex.Lock() + state.proxiedToolsInitialized = true + state.mutex.Unlock() + }) + }() + } + + wg.Wait() + + // Verify initialization happened exactly once + assert.Equal(t, int32(1), atomic.LoadInt32(&initCounter), + "Initialization should run exactly once despite 100 concurrent calls") + assert.True(t, state.proxiedToolsInitialized) + }) + + t.Run("concurrent reads and writes with mutex protection", func(t *testing.T) { + state := newSessionState() + var wg sync.WaitGroup + + // Writer goroutines + for i := 0; i < 10; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + state.mutex.Lock() + key := "tempo_" + string(rune('a'+id)) + state.proxiedClients[key] = &ProxiedClient{ + DatasourceUID: key, + DatasourceName: "Test " + key, + DatasourceType: "tempo", + } + state.mutex.Unlock() + }(i) + } + + // Reader goroutines + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + state.mutex.RLock() + _ = len(state.proxiedClients) + state.mutex.RUnlock() + }() + } + + wg.Wait() + + // Verify all writes succeeded + state.mutex.RLock() + count := len(state.proxiedClients) + state.mutex.RUnlock() + + assert.Equal(t, 10, count, "All 10 clients should be stored") + }) + + t.Run("concurrent tool registration is safe", func(t *testing.T) { + state := newSessionState() + var wg sync.WaitGroup + + // Multiple goroutines trying to register tools + const numGoroutines = 50 + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + state.mutex.Lock() + toolName := "tempo_tool-" + string(rune('a'+id%26)) + if state.toolToDatasources[toolName] == nil { + state.toolToDatasources[toolName] = []string{} + } + state.toolToDatasources[toolName] = append( + state.toolToDatasources[toolName], + "datasource_"+string(rune('a'+id%26)), + ) + state.mutex.Unlock() + }(i) + } + + wg.Wait() + + // Verify the tool mappings exist + state.mutex.RLock() + defer state.mutex.RUnlock() + assert.Greater(t, len(state.toolToDatasources), 0, "Should have tool mappings") + }) +} + +func TestSessionManagerConcurrency(t *testing.T) { + t.Run("concurrent session creation is safe", func(t *testing.T) { + sm := NewSessionManager() + var wg sync.WaitGroup + + // Create many sessions concurrently + const numSessions = 100 + wg.Add(numSessions) + + for i := 0; i < numSessions; i++ { + go func(id int) { + defer wg.Done() + sessionID := "session-" + string(rune('a'+id%26)) + "-" + string(rune('0'+id/26)) + mockSession := &mockClientSession{id: sessionID} + 
sm.CreateSession(context.Background(), mockSession) + }(i) + } + + wg.Wait() + + // Verify all sessions were created + sm.mutex.RLock() + count := len(sm.sessions) + sm.mutex.RUnlock() + + assert.Equal(t, numSessions, count, "All sessions should be created") + }) + + t.Run("concurrent get and remove is safe", func(t *testing.T) { + sm := NewSessionManager() + + // Pre-populate sessions + for i := 0; i < 50; i++ { + sessionID := "session-" + string(rune('a'+i%26)) + mockSession := &mockClientSession{id: sessionID} + sm.CreateSession(context.Background(), mockSession) + } + + var wg sync.WaitGroup + + // Readers + for i := 0; i < 50; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + sessionID := "session-" + string(rune('a'+id%26)) + _, _ = sm.GetSession(sessionID) + }(i) + } + + // Writers (removers) + for i := 0; i < 25; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + sessionID := "session-" + string(rune('a'+id%26)) + mockSession := &mockClientSession{id: sessionID} + sm.RemoveSession(context.Background(), mockSession) + }(i) + } + + wg.Wait() + + // Test passed if no race conditions occurred + }) +} + +func TestInitOncePattern(t *testing.T) { + t.Run("verify sync.Once guarantees single execution", func(t *testing.T) { + var once sync.Once + var counter int32 + var wg sync.WaitGroup + + // Simulate what happens in InitializeAndRegisterProxiedTools + initFunc := func() { + atomic.AddInt32(&counter, 1) + // Simulate expensive initialization + time.Sleep(50 * time.Millisecond) + } + + // Launch many concurrent calls + for i := 0; i < 1000; i++ { + wg.Add(1) + go func() { + defer wg.Done() + once.Do(initFunc) + }() + } + + wg.Wait() + + assert.Equal(t, int32(1), atomic.LoadInt32(&counter), + "sync.Once should guarantee function runs exactly once") + }) + + t.Run("sync.Once with different functions only runs first", func(t *testing.T) { + var once sync.Once + var result string + var mu sync.Mutex + + once.Do(func() { + mu.Lock() + result = "first" + mu.Unlock() + }) + + once.Do(func() { + mu.Lock() + result = "second" + mu.Unlock() + }) + + mu.Lock() + finalResult := result + mu.Unlock() + + assert.Equal(t, "first", finalResult, "Only first function should execute") + }) +} + +func TestProxiedToolsInitializationFlow(t *testing.T) { + t.Run("initialization state transitions are correct", func(t *testing.T) { + state := newSessionState() + + // Initial state + assert.False(t, state.proxiedToolsInitialized) + assert.Empty(t, state.proxiedClients) + assert.Empty(t, state.proxiedTools) + + // Simulate initialization + state.initOnce.Do(func() { + state.mutex.Lock() + state.proxiedToolsInitialized = true + state.proxiedClients["tempo_test"] = &ProxiedClient{ + DatasourceUID: "test", + DatasourceName: "Test", + DatasourceType: "tempo", + } + state.mutex.Unlock() + }) + + // Verify state after initialization + state.mutex.RLock() + initialized := state.proxiedToolsInitialized + clientCount := len(state.proxiedClients) + state.mutex.RUnlock() + + assert.True(t, initialized) + assert.Equal(t, 1, clientCount) + }) + + t.Run("multiple sessions maintain separate state", func(t *testing.T) { + sm := NewSessionManager() + + // Create two sessions + session1 := &mockClientSession{id: "session-1"} + session2 := &mockClientSession{id: "session-2"} + + sm.CreateSession(context.Background(), session1) + sm.CreateSession(context.Background(), session2) + + state1, _ := sm.GetSession("session-1") + state2, _ := sm.GetSession("session-2") + + // Initialize only session1 + state1.initOnce.Do(func() 
{ + state1.mutex.Lock() + state1.proxiedToolsInitialized = true + state1.mutex.Unlock() + }) + + // Verify states are independent + assert.True(t, state1.proxiedToolsInitialized) + assert.False(t, state2.proxiedToolsInitialized) + assert.NotSame(t, state1, state2) + }) +} + +func TestRaceConditionDemonstration(t *testing.T) { + t.Run("old pattern WITHOUT sync.Once would have race condition", func(t *testing.T) { + // This test demonstrates what WOULD happen with the old mutex-check pattern + state := newSessionState() + + var discoveryCallCount int32 + var wg sync.WaitGroup + + // Simulate the OLD pattern (mutex check, unlock, then do work) + oldPatternInitialize := func() { + state.mutex.Lock() + // Check if already initialized + if state.proxiedToolsInitialized { + state.mutex.Unlock() + return + } + alreadyDiscovered := state.proxiedToolsInitialized + state.mutex.Unlock() // āŒ OLD PATTERN: Unlock before expensive work + + if !alreadyDiscovered { + // Simulate discovery work that should only happen once + atomic.AddInt32(&discoveryCallCount, 1) + time.Sleep(10 * time.Millisecond) // Simulate expensive operation + + state.mutex.Lock() + state.proxiedToolsInitialized = true + state.mutex.Unlock() + } + } + + // Launch concurrent initializations + const numGoroutines = 10 + wg.Add(numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + oldPatternInitialize() + }() + } + wg.Wait() + + // With the old pattern, multiple goroutines can get past the check + // and call discovery multiple times + count := atomic.LoadInt32(&discoveryCallCount) + if count > 1 { + t.Logf("OLD PATTERN: Discovery called %d times (race condition!)", count) + } + // We can't assert > 1 reliably because timing matters, but this demonstrates the problem + }) + + t.Run("new pattern WITH sync.Once prevents race condition", func(t *testing.T) { + // This test demonstrates the NEW pattern with sync.Once + state := newSessionState() + + var discoveryCallCount int32 + var wg sync.WaitGroup + + // NEW pattern: sync.Once guarantees single execution + newPatternInitialize := func() { + state.initOnce.Do(func() { + // Simulate discovery work that should only happen once + atomic.AddInt32(&discoveryCallCount, 1) + time.Sleep(10 * time.Millisecond) // Simulate expensive operation + + state.mutex.Lock() + state.proxiedToolsInitialized = true + state.mutex.Unlock() + }) + } + + // Launch concurrent initializations + const numGoroutines = 10 + wg.Add(numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + newPatternInitialize() + }() + } + wg.Wait() + + // With sync.Once, discovery is guaranteed to run exactly once + count := atomic.LoadInt32(&discoveryCallCount) + assert.Equal(t, int32(1), count, "NEW PATTERN: Discovery must be called exactly once") + }) +} + +func TestRaceDetector(t *testing.T) { + // This test is primarily valuable when run with -race flag + t.Run("stress test with race detector", func(t *testing.T) { + + sm := NewSessionManager() + var wg sync.WaitGroup + + // Create a mix of operations happening concurrently + for i := 0; i < 20; i++ { + sessionID := "stress-session-" + string(rune('a'+i%10)) + + // Create session + wg.Add(1) + go func(sid string) { + defer wg.Done() + mockSession := &mockClientSession{id: sid} + sm.CreateSession(context.Background(), mockSession) + }(sessionID) + + // Initialize session state + wg.Add(1) + go func(sid string) { + defer wg.Done() + time.Sleep(time.Millisecond) // Let creation happen first + state, exists := 
sm.GetSession(sid) + if exists { + state.initOnce.Do(func() { + state.mutex.Lock() + state.proxiedToolsInitialized = true + state.mutex.Unlock() + }) + } + }(sessionID) + + // Read session state + wg.Add(1) + go func(sid string) { + defer wg.Done() + time.Sleep(2 * time.Millisecond) + state, exists := sm.GetSession(sid) + if exists { + state.mutex.RLock() + _ = state.proxiedToolsInitialized + state.mutex.RUnlock() + } + }(sessionID) + } + + wg.Wait() + + // If we get here without race detector warnings, we're good + t.Log("Stress test completed without race conditions") + }) +} + +// mockClientSession implements server.ClientSession for testing +type mockClientSession struct { + id string + notifChannel chan mcp.JSONRPCNotification + isInitialized bool +} + +func (m *mockClientSession) SessionID() string { + return m.id +} + +func (m *mockClientSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + if m.notifChannel == nil { + m.notifChannel = make(chan mcp.JSONRPCNotification, 10) + } + return m.notifChannel +} + +func (m *mockClientSession) Initialize() { + m.isInitialized = true +} + +func (m *mockClientSession) Initialized() bool { + return m.isInitialized +} diff --git a/session.go b/session.go new file mode 100644 index 00000000..d1005e37 --- /dev/null +++ b/session.go @@ -0,0 +1,115 @@ +package mcpgrafana + +import ( + "context" + "fmt" + "log/slog" + "sync" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// SessionState holds the state for a single client session +type SessionState struct { + // Proxied tools state + initOnce sync.Once + proxiedToolsInitialized bool + proxiedTools []mcp.Tool + proxiedClients map[string]*ProxiedClient // key: datasourceType_datasourceUID + toolToDatasources map[string][]string // key: toolName, value: list of datasource keys that support it + mutex sync.RWMutex +} + +func newSessionState() *SessionState { + return &SessionState{ + proxiedClients: make(map[string]*ProxiedClient), + toolToDatasources: make(map[string][]string), + } +} + +// SessionManager manages client sessions and their state +type SessionManager struct { + sessions map[string]*SessionState + mutex sync.RWMutex +} + +func NewSessionManager() *SessionManager { + return &SessionManager{ + sessions: make(map[string]*SessionState), + } +} + +func (sm *SessionManager) CreateSession(ctx context.Context, session server.ClientSession) { + sm.mutex.Lock() + defer sm.mutex.Unlock() + + sessionID := session.SessionID() + if _, exists := sm.sessions[sessionID]; !exists { + sm.sessions[sessionID] = newSessionState() + } +} + +func (sm *SessionManager) GetSession(sessionID string) (*SessionState, bool) { + sm.mutex.RLock() + defer sm.mutex.RUnlock() + + session, exists := sm.sessions[sessionID] + return session, exists +} + +func (sm *SessionManager) RemoveSession(ctx context.Context, session server.ClientSession) { + sm.mutex.Lock() + sessionID := session.SessionID() + state, exists := sm.sessions[sessionID] + delete(sm.sessions, sessionID) + sm.mutex.Unlock() + + if !exists { + return + } + + // Clean up proxied clients outside of the main lock + state.mutex.Lock() + defer state.mutex.Unlock() + + for key, client := range state.proxiedClients { + if err := client.Close(); err != nil { + slog.Error("failed to close proxied client", "key", key, "error", err) + } + } +} + +// GetProxiedClient retrieves a proxied client for the given datasource +func (sm *SessionManager) GetProxiedClient(ctx context.Context, datasourceType, datasourceUID string) 
(*ProxiedClient, error) { + session := server.ClientSessionFromContext(ctx) + if session == nil { + return nil, fmt.Errorf("session not found in context") + } + + state, exists := sm.GetSession(session.SessionID()) + if !exists { + return nil, fmt.Errorf("session not found") + } + + state.mutex.RLock() + defer state.mutex.RUnlock() + + key := datasourceType + "_" + datasourceUID + client, exists := state.proxiedClients[key] + if !exists { + // List available datasources to help with debugging + var availableUIDs []string + for _, c := range state.proxiedClients { + if c.DatasourceType == datasourceType { + availableUIDs = append(availableUIDs, c.DatasourceUID) + } + } + if len(availableUIDs) > 0 { + return nil, fmt.Errorf("datasource '%s' not found. Available %s datasources: %v", datasourceUID, datasourceType, availableUIDs) + } + return nil, fmt.Errorf("datasource '%s' not found. No %s datasources with MCP support are configured", datasourceUID, datasourceType) + } + + return client, nil +} diff --git a/session_test.go b/session_test.go new file mode 100644 index 00000000..040e28a8 --- /dev/null +++ b/session_test.go @@ -0,0 +1,475 @@ +//go:build integration + +// Integration tests for proxied MCP tools functionality. +// Requires docker-compose to be running with Grafana and Tempo instances. +// Run with: go test -tags integration -v ./... + +package mcpgrafana + +import ( + "context" + "fmt" + "net/url" + "os" + "strings" + "sync" + "testing" + + "github.com/go-openapi/strfmt" + grafana_client "github.com/grafana/grafana-openapi-client-go/client" + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newProxiedToolsTestContext creates a test context with Grafana client and config +func newProxiedToolsTestContext(t *testing.T) context.Context { + cfg := grafana_client.DefaultTransportConfig() + cfg.Host = "localhost:3000" + cfg.Schemes = []string{"http"} + + // Extract transport config from env vars, and set it on the context. + if u, ok := os.LookupEnv("GRAFANA_URL"); ok { + parsedURL, err := url.Parse(u) + require.NoError(t, err, "invalid GRAFANA_URL") + cfg.Host = parsedURL.Host + // The Grafana client will always prefer HTTPS even if the URL is HTTP, + // so we need to limit the schemes to HTTP if the URL is HTTP. 
+ if parsedURL.Scheme == "http" { + cfg.Schemes = []string{"http"} + } + } + + // Check for the new service account token environment variable first + if apiKey := os.Getenv("GRAFANA_SERVICE_ACCOUNT_TOKEN"); apiKey != "" { + cfg.APIKey = apiKey + } else if apiKey := os.Getenv("GRAFANA_API_KEY"); apiKey != "" { + // Fall back to the deprecated API key environment variable + cfg.APIKey = apiKey + } else { + cfg.BasicAuth = url.UserPassword("admin", "admin") + } + + grafanaClient := grafana_client.NewHTTPClientWithConfig(strfmt.Default, cfg) + + grafanaCfg := GrafanaConfig{ + Debug: true, + URL: "http://localhost:3000", + APIKey: cfg.APIKey, + BasicAuth: cfg.BasicAuth, + } + + ctx := WithGrafanaConfig(context.Background(), grafanaCfg) + return WithGrafanaClient(ctx, grafanaClient) +} + +func TestDiscoverMCPDatasources(t *testing.T) { + ctx := newProxiedToolsTestContext(t) + + t.Run("discovers tempo datasources", func(t *testing.T) { + discovered, err := discoverMCPDatasources(ctx) + require.NoError(t, err) + + // Should find two Tempo datasources from docker-compose + assert.GreaterOrEqual(t, len(discovered), 2, "Should discover at least 2 Tempo datasources") + + // Check that we found the expected datasources + uids := make([]string, len(discovered)) + for i, ds := range discovered { + uids[i] = ds.UID + assert.Equal(t, "tempo", ds.Type, "All discovered datasources should be tempo type") + assert.NotEmpty(t, ds.Name, "Datasource should have a name") + assert.NotEmpty(t, ds.MCPURL, "Datasource should have MCP URL") + + // Verify URL format + expectedURLPattern := fmt.Sprintf("http://localhost:3000/api/datasources/proxy/uid/%s/api/mcp", ds.UID) + assert.Equal(t, expectedURLPattern, ds.MCPURL, "MCP URL should follow proxy pattern") + } + + // Should contain our expected UIDs + assert.Contains(t, uids, "tempo", "Should discover 'tempo' datasource") + assert.Contains(t, uids, "tempo-secondary", "Should discover 'tempo-secondary' datasource") + }) + + t.Run("returns error when grafana client not in context", func(t *testing.T) { + emptyCtx := context.Background() + discovered, err := discoverMCPDatasources(emptyCtx) + assert.Error(t, err) + assert.Nil(t, discovered) + assert.Contains(t, err.Error(), "grafana client not found in context") + }) + + t.Run("returns error when auth is missing", func(t *testing.T) { + // Context with client but no auth credentials + cfg := grafana_client.DefaultTransportConfig() + cfg.Host = "localhost:3000" + cfg.Schemes = []string{"http"} + grafanaClient := grafana_client.NewHTTPClientWithConfig(strfmt.Default, cfg) + + grafanaCfg := GrafanaConfig{ + URL: "http://localhost:3000", + // No APIKey or BasicAuth set + } + ctx := WithGrafanaConfig(context.Background(), grafanaCfg) + ctx = WithGrafanaClient(ctx, grafanaClient) + + discovered, err := discoverMCPDatasources(ctx) + assert.Error(t, err) + assert.Nil(t, discovered) + assert.Contains(t, err.Error(), "Unauthorized") + }) +} + +func TestToolNamespacing(t *testing.T) { + t.Run("parse proxied tool name", func(t *testing.T) { + datasourceType, toolName, err := parseProxiedToolName("tempo_traceql-search") + require.NoError(t, err) + assert.Equal(t, "tempo", datasourceType) + assert.Equal(t, "traceql-search", toolName) + }) + + t.Run("parse proxied tool name with multiple underscores", func(t *testing.T) { + datasourceType, toolName, err := parseProxiedToolName("tempo_get-attribute-values") + require.NoError(t, err) + assert.Equal(t, "tempo", datasourceType) + assert.Equal(t, "get-attribute-values", toolName) + }) + + 
t.Run("parse proxied tool name with invalid format", func(t *testing.T) { + _, _, err := parseProxiedToolName("invalid") + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid proxied tool name format") + }) + + t.Run("add datasourceUid parameter to tool", func(t *testing.T) { + originalTool := mcp.Tool{ + Name: "query_traces", + Description: "Query traces from Tempo", + InputSchema: mcp.ToolInputSchema{ + Properties: map[string]any{ + "query": map[string]any{ + "type": "string", + }, + }, + Required: []string{"query"}, + }, + } + + modifiedTool := addDatasourceUidParameter(originalTool, "tempo") + + assert.Equal(t, "tempo_query_traces", modifiedTool.Name) + assert.Equal(t, "Query traces from Tempo", modifiedTool.Description) + assert.NotNil(t, modifiedTool.InputSchema.Properties["datasourceUid"]) + assert.Contains(t, modifiedTool.InputSchema.Required, "datasourceUid") + assert.Contains(t, modifiedTool.InputSchema.Required, "query") + }) + + t.Run("add datasourceUid parameter with empty description", func(t *testing.T) { + originalTool := mcp.Tool{ + Name: "test_tool", + Description: "", + InputSchema: mcp.ToolInputSchema{ + Properties: make(map[string]any), + }, + } + + modifiedTool := addDatasourceUidParameter(originalTool, "tempo") + + assert.Equal(t, "tempo_test_tool", modifiedTool.Name) + assert.Equal(t, "", modifiedTool.Description, "Should not modify empty description") + assert.NotNil(t, modifiedTool.InputSchema.Properties["datasourceUid"]) + }) +} + +func TestSessionStateLifecycle(t *testing.T) { + t.Run("create and get session", func(t *testing.T) { + sm := NewSessionManager() + + // Create mock session + mockSession := &mockClientSession{id: "test-session-123"} + + sm.CreateSession(context.Background(), mockSession) + + state, exists := sm.GetSession("test-session-123") + assert.True(t, exists) + assert.NotNil(t, state) + assert.NotNil(t, state.proxiedClients) + assert.False(t, state.proxiedToolsInitialized) + }) + + t.Run("remove session cleans up clients", func(t *testing.T) { + sm := NewSessionManager() + + mockSession := &mockClientSession{id: "test-session-456"} + sm.CreateSession(context.Background(), mockSession) + + state, _ := sm.GetSession("test-session-456") + + // Add a mock proxied client + mockClient := &ProxiedClient{ + DatasourceUID: "test-uid", + DatasourceName: "Test Datasource", + DatasourceType: "tempo", + } + state.proxiedClients["tempo_test-uid"] = mockClient + + // Remove session + sm.RemoveSession(context.Background(), mockSession) + + // Session should be gone + _, exists := sm.GetSession("test-session-456") + assert.False(t, exists) + }) + + t.Run("get non-existent session", func(t *testing.T) { + sm := NewSessionManager() + + state, exists := sm.GetSession("non-existent") + assert.False(t, exists) + assert.Nil(t, state) + }) +} + +func TestConcurrentInitializationRaceCondition(t *testing.T) { + t.Run("concurrent initialization calls should be safe", func(t *testing.T) { + sm := NewSessionManager() + mockSession := &mockClientSession{id: "race-test-session"} + sm.CreateSession(context.Background(), mockSession) + + state, exists := sm.GetSession("race-test-session") + require.True(t, exists) + + // Track how many times the initialization logic runs + var initCount int + var initCountMutex sync.Mutex + + // Create a custom initOnce to track calls + state.initOnce = sync.Once{} + + // Simulate the initialization work that should run exactly once + initWork := func() { + initCountMutex.Lock() + initCount++ + initCountMutex.Unlock() + // Simulate 
some work + state.mutex.Lock() + state.proxiedToolsInitialized = true + state.proxiedClients["tempo_test"] = &ProxiedClient{ + DatasourceUID: "test", + DatasourceName: "Test", + DatasourceType: "tempo", + } + state.mutex.Unlock() + } + + // Launch multiple goroutines that all try to initialize concurrently + const numGoroutines = 10 + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + // This should be the pattern used in InitializeAndRegisterProxiedTools + state.initOnce.Do(initWork) + }() + } + + wg.Wait() + + // Verify initialization ran exactly once + assert.Equal(t, 1, initCount, "Initialization should run exactly once despite concurrent calls") + assert.True(t, state.proxiedToolsInitialized, "State should be initialized") + assert.Len(t, state.proxiedClients, 1, "Should have exactly one client") + }) + + t.Run("sync.Once prevents double initialization", func(t *testing.T) { + sm := NewSessionManager() + mockSession := &mockClientSession{id: "double-init-test"} + sm.CreateSession(context.Background(), mockSession) + + state, _ := sm.GetSession("double-init-test") + + callCount := 0 + + // First call + state.initOnce.Do(func() { + callCount++ + }) + + // Second call should not execute + state.initOnce.Do(func() { + callCount++ + }) + + // Third call should also not execute + state.initOnce.Do(func() { + callCount++ + }) + + assert.Equal(t, 1, callCount, "sync.Once should ensure function runs exactly once") + }) +} + +func TestProxiedClientLifecycle(t *testing.T) { + ctx := newProxiedToolsTestContext(t) + + t.Run("list tools returns copy", func(t *testing.T) { + pc := &ProxiedClient{ + DatasourceUID: "test-uid", + DatasourceName: "Test", + DatasourceType: "tempo", + Tools: []mcp.Tool{ + {Name: "tool1", Description: "First tool"}, + {Name: "tool2", Description: "Second tool"}, + }, + } + + tools1 := pc.ListTools() + tools2 := pc.ListTools() + + // Should return same content + assert.Equal(t, tools1, tools2) + + // But different slice instances (copy) + assert.NotSame(t, &tools1[0], &tools2[0]) + }) + + t.Run("call tool validates tool exists", func(t *testing.T) { + pc := &ProxiedClient{ + DatasourceUID: "test-uid", + DatasourceName: "Test", + DatasourceType: "tempo", + Tools: []mcp.Tool{ + {Name: "valid_tool", Description: "Valid tool"}, + }, + } + + // Call non-existent tool + result, err := pc.CallTool(ctx, "non_existent_tool", map[string]any{}) + assert.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "not found in remote MCP server") + }) +} + +func TestEndToEndProxiedToolsFlow(t *testing.T) { + ctx := newProxiedToolsTestContext(t) + + t.Run("full flow from discovery to tool call", func(t *testing.T) { + // Step 1: Discover MCP datasources + discovered, err := discoverMCPDatasources(ctx) + require.NoError(t, err) + require.GreaterOrEqual(t, len(discovered), 1, "Should discover at least one Tempo datasource") + + // Use the first discovered datasource + ds := discovered[0] + t.Logf("Testing with datasource: %s (UID: %s, URL: %s)", ds.Name, ds.UID, ds.MCPURL) + + // Step 2: Create a proxied client connection + client, err := NewProxiedClient(ctx, ds.UID, ds.Name, ds.Type, ds.MCPURL) + if err != nil { + t.Skipf("Skipping end-to-end test: Tempo MCP endpoint not available: %v", err) + return + } + defer func() { + _ = client.Close() + }() + + // Step 3: Verify we got tools from the remote server + tools := client.ListTools() + require.Greater(t, len(tools), 0, "Should have at least one tool from 
Tempo MCP server") + t.Logf("Discovered %d tools from Tempo MCP server", len(tools)) + + // Log the available tools + for _, tool := range tools { + t.Logf(" - Tool: %s - %s", tool.Name, tool.Description) + } + + // Step 4: Test tool modification with datasourceUid parameter + firstTool := tools[0] + modifiedTool := addDatasourceUidParameter(firstTool, ds.Type) + + expectedName := ds.Type + "_" + firstTool.Name + assert.Equal(t, expectedName, modifiedTool.Name, "Modified tool should have prefixed name") + assert.Contains(t, modifiedTool.InputSchema.Required, "datasourceUid", "Modified tool should require datasourceUid") + + // Step 5: Test session integration + sm := NewSessionManager() + mockSession := &mockClientSession{id: "e2e-test-session"} + sm.CreateSession(ctx, mockSession) + + state, exists := sm.GetSession("e2e-test-session") + require.True(t, exists) + + // Store the proxied client in session state + key := ds.Type + "_" + ds.UID + state.proxiedClients[key] = client + + // Step 6: Verify client is stored correctly in session + retrievedClient, exists := state.proxiedClients[key] + require.True(t, exists, "Client should be stored in session state") + assert.Equal(t, client, retrievedClient, "Should retrieve the same client from session") + + // Step 7: Test ProxiedToolHandler flow + handler := NewProxiedToolHandler(sm, nil, modifiedTool.Name) + assert.NotNil(t, handler) + + // Note: We can't actually call the tool without knowing what arguments it expects + // and without the context having the proper session, but we've validated the setup + t.Logf("Successfully validated end-to-end proxied tools flow") + }) + + t.Run("multiple datasources in single session", func(t *testing.T) { + discovered, err := discoverMCPDatasources(ctx) + require.NoError(t, err) + + if len(discovered) < 2 { + t.Skip("Need at least 2 Tempo datasources for this test") + } + + sm := NewSessionManager() + mockSession := &mockClientSession{id: "multi-ds-test-session"} + sm.CreateSession(ctx, mockSession) + + state, _ := sm.GetSession("multi-ds-test-session") + + // Try to connect to multiple datasources + connectedCount := 0 + for i, ds := range discovered { + if i >= 2 { + break // Test with first 2 datasources + } + + client, err := NewProxiedClient(ctx, ds.UID, ds.Name, ds.Type, ds.MCPURL) + if err != nil { + t.Logf("Could not connect to datasource %s: %v", ds.UID, err) + continue + } + defer func() { + _ = client.Close() + }() + + key := ds.Type + "_" + ds.UID + state.proxiedClients[key] = client + connectedCount++ + + t.Logf("Connected to datasource %s with %d tools", ds.UID, len(client.Tools)) + } + + if connectedCount == 0 { + t.Skip("Could not connect to any Tempo datasources") + } + + // Verify each client is stored correctly + for key, client := range state.proxiedClients { + parts := strings.Split(key, "_") + require.Len(t, parts, 2, "Key should have format type_uid") + assert.NotNil(t, client, "Client should not be nil") + assert.Equal(t, parts[0], client.DatasourceType, "Client type should match key") + assert.Equal(t, parts[1], client.DatasourceUID, "Client UID should match key") + } + + t.Logf("Successfully managed %d datasources in single session", connectedCount) + }) +} diff --git a/testdata/provisioning/datasources/datasources.yaml b/testdata/provisioning/datasources/datasources.yaml index 81a47c23..687772b3 100644 --- a/testdata/provisioning/datasources/datasources.yaml +++ b/testdata/provisioning/datasources/datasources.yaml @@ -27,3 +27,17 @@ datasources: access: proxy url: 
http://pyroscope:4040 isDefault: false + - name: Tempo + id: 4 + uid: tempo + type: tempo + access: proxy + url: http://tempo:3200 + isDefault: false + - name: Tempo Secondary + id: 5 + uid: tempo-secondary + type: tempo + access: proxy + url: http://tempo2:3201 + isDefault: false diff --git a/testdata/tempo-config-2.yaml b/testdata/tempo-config-2.yaml new file mode 100644 index 00000000..6706e340 --- /dev/null +++ b/testdata/tempo-config-2.yaml @@ -0,0 +1,29 @@ +server: + http_listen_port: 3201 + log_level: debug + +query_frontend: + mcp_server: + enabled: true + +distributor: + receivers: + otlp: + protocols: + http: + grpc: + +ingester: + max_block_duration: 5m + +compactor: + compaction: + block_retention: 1h + +storage: + trace: + backend: local + local: + path: /tmp/tempo2/blocks + wal: + path: /tmp/tempo2/wal diff --git a/testdata/tempo-config.yaml b/testdata/tempo-config.yaml new file mode 100644 index 00000000..dcdef2d7 --- /dev/null +++ b/testdata/tempo-config.yaml @@ -0,0 +1,29 @@ +server: + http_listen_port: 3200 + log_level: debug + +query_frontend: + mcp_server: + enabled: true + +distributor: + receivers: + otlp: + protocols: + http: + grpc: + +ingester: + max_block_duration: 5m + +compactor: + compaction: + block_retention: 1h + +storage: + trace: + backend: local + local: + path: /tmp/tempo/blocks + wal: + path: /tmp/tempo/wal diff --git a/tests/tempo_test.py b/tests/tempo_test.py new file mode 100644 index 00000000..521ec0bb --- /dev/null +++ b/tests/tempo_test.py @@ -0,0 +1,249 @@ +from mcp import ClientSession +import pytest +from langevals import expect +from langevals_langevals.llm_boolean import ( + CustomLLMBooleanEvaluator, + CustomLLMBooleanSettings, +) +from litellm import Message, acompletion +from mcp import ClientSession + +from conftest import models +from utils import ( + get_converted_tools, + llm_tool_call_sequence, +) + +pytestmark = pytest.mark.anyio + + +class TestTempoProxiedToolsBasic: + """Test Tempo proxied MCP tools functionality. + + These tests verify that Tempo datasources with MCP support are discovered + per-session and their tools are registered with a datasourceUid parameter + for multi-datasource support. 
+ + Requires: + - Docker compose services running (includes 2 Tempo instances) + - GRAFANA_USERNAME and GRAFANA_PASSWORD environment variables + - MCP server running + """ + + @pytest.mark.anyio + async def test_tempo_tools_discovered_and_registered( + self, mcp_client: ClientSession + ): + """Test that Tempo tools are discovered and registered with datasourceUid parameter.""" + + # List all tools + list_response = await mcp_client.list_tools() + all_tool_names = [tool.name for tool in list_response.tools] + + # Find tempo-prefixed tools (should preserve hyphens from original tool names) + tempo_tools = [name for name in all_tool_names if name.startswith("tempo_")] + + # Expected tools from Tempo MCP server + expected_tempo_tools = [ + "tempo_traceql-search", + "tempo_traceql-metrics-instant", + "tempo_traceql-metrics-range", + "tempo_get-trace", + "tempo_get-attribute-names", + "tempo_get-attribute-values", + "tempo_docs-traceql", + ] + + assert len(tempo_tools) == len(expected_tempo_tools), ( + f"Expected {len(expected_tempo_tools)} unique tempo tools, found {len(tempo_tools)}: {tempo_tools}" + ) + + for expected_tool in expected_tempo_tools: + assert expected_tool in tempo_tools, ( + f"Tool {expected_tool} should be available" + ) + + @pytest.mark.anyio + async def test_tempo_tools_have_datasourceUid_parameter(self, mcp_client): + """Test that all tempo tools have a required datasourceUid parameter.""" + + list_response = await mcp_client.list_tools() + tempo_tools = [ + tool for tool in list_response.tools if tool.name.startswith("tempo_") + ] + + assert len(tempo_tools) > 0, "Should have at least one tempo tool" + + for tool in tempo_tools: + # Verify the tool has input schema + assert hasattr(tool, "inputSchema"), ( + f"Tool {tool.name} should have inputSchema" + ) + assert isinstance(tool.inputSchema, dict), ( + f"Tool {tool.name} inputSchema should be a dict" + ) + + # Verify datasourceUid parameter exists (camelCase) + properties = tool.inputSchema.get("properties", {}) + assert "datasourceUid" in properties, ( + f"Tool {tool.name} should have datasourceUid parameter (camelCase)" + ) + + # Verify it's required + required = tool.inputSchema.get("required", []) + assert "datasourceUid" in required, ( + f"Tool {tool.name} should require datasourceUid parameter" + ) + + # Verify parameter has proper description + datasource_uid_prop = properties["datasourceUid"] + assert "type" in datasource_uid_prop, ( + f"datasourceUid should have type defined" + ) + assert datasource_uid_prop["type"] == "string", ( + f"datasourceUid should be type string" + ) + + @pytest.mark.anyio + async def test_tempo_tool_call_with_valid_datasource(self, mcp_client): + """Test calling a tempo tool with a valid datasourceUid.""" + + # Call docs-traceql which should return documentation (doesn't require data) + try: + call_response = await mcp_client.call_tool( + "tempo_docs-traceql", + arguments={"datasourceUid": "tempo", "name": "basic"}, + ) + + # Verify we got a response + assert call_response.content, "Tool should return content" + + # Should have text content (documentation) + response_text = call_response.content[0].text + assert len(response_text) > 0, "Response should have content" + assert "traceql" in response_text.lower(), ( + "Response should contain TraceQL documentation" + ) + print(response_text) + + except Exception as e: + # If this fails, it might be because Tempo doesn't have data yet + # but at least verify the error isn't about missing datasourceUid + error_msg = str(e).lower() + assert 
"datasourceuid" not in error_msg, ( + f"Should not fail due to datasourceUid parameter: {e}" + ) + print(error_msg) + + @pytest.mark.anyio + async def test_tempo_tool_call_missing_datasourceUid(self, mcp_client): + """Test that calling a tempo tool without datasourceUid fails appropriately.""" + + with pytest.raises(Exception) as exc_info: + await mcp_client.call_tool( + "tempo_docs-traceql", + arguments={"name": "basic"}, # Missing datasourceUid + ) + + error_msg = str(exc_info.value).lower() + assert "datasourceuid" in error_msg and "required" in error_msg, ( + f"Should require datasourceUid parameter: {exc_info.value}" + ) + + @pytest.mark.anyio + async def test_tempo_tool_call_invalid_datasourceUid(self, mcp_client): + """Test that calling a tempo tool with invalid datasourceUid returns helpful error.""" + + with pytest.raises(Exception) as exc_info: + await mcp_client.call_tool( + "tempo_docs-traceql", + arguments={"datasourceUid": "nonexistent-tempo", "name": "basic"}, + ) + + error_msg = str(exc_info.value).lower() + # Should mention that datasource wasn't found + assert "not found" in error_msg or "not accessible" in error_msg, ( + f"Should indicate datasource not found: {exc_info.value}" + ) + + # Should mention available datasources to help user + assert "tempo" in error_msg or "available" in error_msg, ( + f"Error should be helpful and mention available datasources: {exc_info.value}" + ) + + @pytest.mark.anyio + async def test_tempo_tool_works_with_multiple_datasources(self, mcp_client): + """Test that the same tool works with different datasources via datasourceUid.""" + + # Both tempo and tempo-secondary should be available in our test environment + datasources = ["tempo", "tempo-secondary"] + + for datasource_uid in datasources: + try: + # Call the same tool with different datasources + call_response = await mcp_client.call_tool( + "tempo_get-attribute-names", + arguments={"datasourceUid": datasource_uid}, + ) + + # Verify we got a response + assert call_response.content, ( + f"Tool should return content for datasource {datasource_uid}" + ) + + # Response should be valid JSON or text + response_text = call_response.content[0].text + assert len(response_text) > 0, ( + f"Response should have content for datasource {datasource_uid}" + ) + + except Exception as e: + # If this fails, it's acceptable if Tempo doesn't have trace data yet + # But verify it's not a routing/config error + error_msg = str(e).lower() + assert ( + "not found" not in error_msg or datasource_uid not in error_msg + ), f"Datasource {datasource_uid} should be accessible: {e}" + + +class TestTempoProxiedToolsWithLLM: + """LLM integration tests for Tempo proxied tools.""" + + @pytest.mark.parametrize("model", models) + @pytest.mark.flaky(max_runs=3) + async def test_llm_can_list_trace_attributes( + self, model: str, mcp_client: ClientSession + ): + """Test that an LLM can list available trace attributes from Tempo.""" + tools = await get_converted_tools(mcp_client) + prompt = ( + "Use the tempo tools to get a list of all available trace attribute names " + "from the datasource with UID 'tempo'. I want to know what attributes " + "I can use in my TraceQL queries." 
+ ) + + messages = [ + Message(role="system", content="You are a helpful assistant."), + Message(role="user", content=prompt), + ] + + # LLM should call tempo_get-attribute-names with datasourceUid + messages = await llm_tool_call_sequence( + model, + messages, + tools, + mcp_client, + "tempo_get-attribute-names", + {"datasourceUid": "tempo"}, + ) + + # Final LLM response should mention attributes + response = await acompletion(model=model, messages=messages, tools=tools) + content = response.choices[0].message.content + + attributes_checker = CustomLLMBooleanEvaluator( + settings=CustomLLMBooleanSettings( + prompt="Does the response list or describe trace attributes that are available for querying?", + ) + ) + expect(input=prompt, output=content).to_pass(attributes_checker) diff --git a/tests/utils.py b/tests/utils.py index 39c9927d..0e8a3fe1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,6 +3,7 @@ from litellm import acompletion, Choices, Message from mcp.types import TextContent, Tool + async def assert_and_handle_tool_call( response: ModelResponse, mcp_client, @@ -15,9 +16,40 @@ async def assert_and_handle_tool_call( assert isinstance(c, Choices) tool_calls.extend(c.message.tool_calls or []) messages.append(c.message) - assert len(tool_calls) == 1 + + # Better error message if wrong number of tool calls + if len(tool_calls) != 1: + actual_calls = [tc.function.name for tc in tool_calls] if tool_calls else [] + assert len(tool_calls) == 1, ( + f"\nāŒ Expected exactly 1 tool call, got {len(tool_calls)}\n" + f"Expected tool: {expected_tool}\n" + f"Actual tools called: {actual_calls}\n" + f"LLM response: {response.choices[0].message.content if response.choices else 'N/A'}" + ) + for tool_call in tool_calls: - assert tool_call.function.name == expected_tool + actual_tool = tool_call.function.name + if actual_tool != expected_tool: + # Parse arguments to understand what LLM was trying to do + try: + actual_args = ( + json.loads(tool_call.function.arguments) + if tool_call.function.arguments + else {} + ) + except: + actual_args = tool_call.function.arguments + + assert False, ( + f"\nāŒ LLM called wrong tool!\n" + f"Expected: {expected_tool}\n" + f"Got: {actual_tool}\n" + f"With args: {json.dumps(actual_args, indent=2)}\n" + f"\nšŸ’” Debugging tips:\n" + f" - Check if the prompt clearly indicates which tool to use\n" + f" - Verify the expected tool exists in the available tools\n" + f" - Consider if the tool description is clear enough\n" + ) arguments = ( {} if len(tool_call.function.arguments) == 0 @@ -25,11 +57,26 @@ async def assert_and_handle_tool_call( ) if expected_args: for key, value in expected_args.items(): - assert key in arguments - assert arguments[key] == value + if key not in arguments: + assert False, ( + f"\nāŒ Missing expected parameter '{key}'\n" + f"Expected args: {json.dumps(expected_args, indent=2)}\n" + f"Actual args: {json.dumps(arguments, indent=2)}\n" + ) + if arguments[key] != value: + assert False, ( + f"\nāŒ Wrong value for parameter '{key}'\n" + f"Expected: {value}\n" + f"Got: {arguments[key]}\n" + f"Full args: {json.dumps(arguments, indent=2)}\n" + ) result = await mcp_client.call_tool(tool_call.function.name, arguments) - assert len(result.content) == 1 - assert isinstance(result.content[0], TextContent) + assert len(result.content) == 1, ( + f"Expected one result for tool {tool_call.function.name}, got {len(result.content)}" + ) + assert isinstance(result.content[0], TextContent), ( + f"Expected TextContent for tool {tool_call.function.name}, got 
{type(result.content[0])}" + ) messages.append( Message( role="tool", tool_call_id=tool_call.id, content=result.content[0].text @@ -37,6 +84,7 @@ async def assert_and_handle_tool_call( ) return messages + def convert_tool(tool: Tool) -> dict: return { "type": "function", @@ -50,15 +98,27 @@ def convert_tool(tool: Tool) -> dict: }, } + async def llm_tool_call_sequence( model, messages, tools, mcp_client, tool_name, tool_args=None ): + print(f"\nšŸ¤– Calling LLM ({model}) and expecting tool: {tool_name}") + print(f"šŸ“ Last message: {messages[-1].get('content', messages[-1])[:200]}...") + response = await acompletion( model=model, messages=messages, tools=tools, ) assert isinstance(response, ModelResponse) + + # Print what tool was actually called for debugging + if response.choices and response.choices[0].message.tool_calls: + actual_tool = response.choices[0].message.tool_calls[0].function.name + print(f"āœ… LLM called: {actual_tool}") + if actual_tool != tool_name: + print(f"āš ļø WARNING: Expected {tool_name} but got {actual_tool}") + messages.extend( await assert_and_handle_tool_call( response, mcp_client, tool_name, tool_args or {} @@ -66,12 +126,15 @@ async def llm_tool_call_sequence( ) return messages + async def get_converted_tools(mcp_client): tools = await mcp_client.list_tools() return [convert_tool(t) for t in tools.tools] -async def flexible_tool_call(model, messages, tools, mcp_client, expected_tool_name, required_params=None): +async def flexible_tool_call( + model, messages, tools, mcp_client, expected_tool_name, required_params=None +): """ Make a flexible tool call that only checks essential parameters. Returns updated messages list. @@ -90,11 +153,17 @@ async def flexible_tool_call(model, messages, tools, mcp_client, expected_tool_n response = await acompletion(model=model, messages=messages, tools=tools) # Check that a tool call was made - assert response.choices[0].message.tool_calls is not None, f"Expected tool call for {expected_tool_name}" - assert len(response.choices[0].message.tool_calls) >= 1, f"Expected at least one tool call for {expected_tool_name}" + assert response.choices[0].message.tool_calls is not None, ( + f"Expected tool call for {expected_tool_name}" + ) + assert len(response.choices[0].message.tool_calls) >= 1, ( + f"Expected at least one tool call for {expected_tool_name}" + ) tool_call = response.choices[0].message.tool_calls[0] - assert tool_call.function.name == expected_tool_name, f"Expected {expected_tool_name} tool, got {tool_call.function.name}" + assert tool_call.function.name == expected_tool_name, ( + f"Expected {expected_tool_name} tool, got {tool_call.function.name}" + ) arguments = json.loads(tool_call.function.arguments) @@ -103,7 +172,9 @@ async def flexible_tool_call(model, messages, tools, mcp_client, expected_tool_n for key, expected_value in required_params.items(): assert key in arguments, f"Expected parameter '{key}' in tool arguments" if expected_value is not None: - assert arguments[key] == expected_value, f"Expected {key}='{expected_value}', got {key}='{arguments.get(key)}'" + assert arguments[key] == expected_value, ( + f"Expected {key}='{expected_value}', got {key}='{arguments.get(key)}'" + ) # Call the tool to verify it works result = await mcp_client.call_tool(tool_call.function.name, arguments) @@ -112,6 +183,8 @@ async def flexible_tool_call(model, messages, tools, mcp_client, expected_tool_n # Add both the tool call and result to message history messages.append(response.choices[0].message) - 
messages.append(Message(role="tool", tool_call_id=tool_call.id, content=result.content[0].text)) - - return messages \ No newline at end of file + messages.append( + Message(role="tool", tool_call_id=tool_call.id, content=result.content[0].text) + ) + + return messages diff --git a/tools/datasources_test.go b/tools/datasources_test.go index d58d0722..f4fdb041 100644 --- a/tools/datasources_test.go +++ b/tools/datasources_test.go @@ -68,8 +68,8 @@ func TestDatasourcesTools(t *testing.T) { ctx := newTestContext() result, err := listDatasources(ctx, ListDatasourcesParams{}) require.NoError(t, err) - // Four datasources are provisioned in the test environment (Prometheus, Loki, and Pyroscope). - assert.Len(t, result, 4) + // Six datasources are provisioned in the test environment (Prometheus, Prometheus Demo, Loki, Pyroscope, Tempo, and Tempo Secondary). + assert.Len(t, result, 6) }) t.Run("list datasources for type", func(t *testing.T) {
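Note: the helpers these Go tests exercise — parseProxiedToolName and addDatasourceUidParameter — are not part of this diff. Below is a minimal sketch of the behaviour the assertions above imply (prefix the tool name with the datasource type, add a required string datasourceUid parameter, keep the description unchanged); the actual implementation and its exact signatures may differ.

// Sketch only: inferred from the test assertions in this diff, not the shipped implementation.
package mcpgrafana

import (
	"fmt"
	"strings"

	"github.com/mark3labs/mcp-go/mcp"
)

// parseProxiedToolName splits a prefixed name like "tempo_traceql-search" into the
// datasource type ("tempo") and the remote tool name ("traceql-search").
func parseProxiedToolName(name string) (datasourceType, toolName string, err error) {
	parts := strings.SplitN(name, "_", 2)
	if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
		return "", "", fmt.Errorf("invalid proxied tool name format: %q", name)
	}
	return parts[0], parts[1], nil
}

// addDatasourceUidParameter prefixes the tool name with the datasource type and adds a
// required "datasourceUid" string parameter, leaving the original description untouched.
func addDatasourceUidParameter(tool mcp.Tool, datasourceType string) mcp.Tool {
	// Copy the properties map so the remote tool definition is not mutated in place.
	props := make(map[string]any, len(tool.InputSchema.Properties)+1)
	for k, v := range tool.InputSchema.Properties {
		props[k] = v
	}
	// The tests only assert the parameter's type; the description text here is illustrative.
	props["datasourceUid"] = map[string]any{
		"type":        "string",
		"description": "The UID of the datasource to route this tool call to",
	}

	tool.Name = datasourceType + "_" + tool.Name
	tool.InputSchema.Properties = props
	tool.InputSchema.Required = append(append([]string{}, tool.InputSchema.Required...), "datasourceUid")
	return tool
}

Copying the schema map and required slice, rather than mutating the remote tool definition, matches the copy semantics that the "list tools returns copy" test expects from ProxiedClient.ListTools.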