From c9120fc23d2060a8f0fc9905265bae4dee108ccf Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Fri, 10 Oct 2025 10:10:19 +0100 Subject: [PATCH 1/4] feat: add support for proxying to Tempo datasources' MCP servers This commit adds support for 'proxied' datasource tools. These are tools that are hosted on a remote MCP server (e.g., a Tempo datasource). Before the first MCP method is called, provided a session is available (i.e. not during an 'initialize' request), the MCP server will iterate over all suitable available datasources in the configured Grafana instance and attempt to discover and register tools from them, by calling the remote MCP server's 'list_tools' method then using dynamic session-based tools to register the discovered tools. Tools are named with the format: <datasourceType>_<toolName>. For example, a Tempo traceql-search tool would be named: tempo_traceql-search. Each added tool also has a 'datasourceUid' parameter added to its input schema, which is used to identify the datasource to query. The `--disable-proxied` flag can be used to disable this feature. Fixes #222. Supersedes #226. 
--- Makefile | 6 +- cmd/mcp-grafana/main.go | 90 +++- docker-compose.yaml | 16 + go.mod | 2 +- mcpgrafana.go | 3 +- proxied_client.go | 143 ++++++ proxied_handler.go | 72 +++ proxied_tools.go | 273 ++++++++++++ session.go | 114 +++++ session_test.go | 419 ++++++++++++++++++ .../provisioning/datasources/datasources.yaml | 14 + testdata/tempo-config-2.yaml | 29 ++ testdata/tempo-config.yaml | 29 ++ tests/tempo_test.py | 191 ++++++++ tools/datasources_test.go | 4 +- 15 files changed, 1375 insertions(+), 30 deletions(-) create mode 100644 proxied_client.go create mode 100644 proxied_handler.go create mode 100644 proxied_tools.go create mode 100644 session.go create mode 100644 session_test.go create mode 100644 testdata/tempo-config-2.yaml create mode 100644 testdata/tempo-config.yaml create mode 100644 tests/tempo_test.py diff --git a/Makefile b/Makefile index 7f50cc35..4d774714 100644 --- a/Makefile +++ b/Makefile @@ -53,15 +53,15 @@ test-python-e2e: ## Run Python E2E tests (requires docker-compose services and S .PHONY: run run: ## Run the MCP server in stdio mode. - go run ./cmd/mcp-grafana + GRAFANA_USERNAME=admin GRAFANA_PASSWORD=admin go run ./cmd/mcp-grafana .PHONY: run-sse run-sse: ## Run the MCP server in SSE mode. - go run ./cmd/mcp-grafana --transport sse --log-level debug --debug + GRAFANA_USERNAME=admin GRAFANA_PASSWORD=admin go run ./cmd/mcp-grafana --transport sse --log-level debug --debug PHONY: run-streamable-http run-streamable-http: ## Run the MCP server in StreamableHTTP mode. - go run ./cmd/mcp-grafana --transport streamable-http --log-level debug --debug + GRAFANA_USERNAME=admin GRAFANA_PASSWORD=admin go run ./cmd/mcp-grafana --transport streamable-http --log-level debug --debug .PHONY: run-test-services run-test-services: ## Run the docker-compose services required for the unit and integration tests. 
diff --git a/cmd/mcp-grafana/main.go b/cmd/mcp-grafana/main.go index 51a09c61..c31ef2ab 100644 --- a/cmd/mcp-grafana/main.go +++ b/cmd/mcp-grafana/main.go @@ -14,6 +14,7 @@ import ( "syscall" "time" + "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" mcpgrafana "github.com/grafana/mcp-grafana" @@ -40,7 +41,7 @@ type disabledTools struct { search, datasource, incident, prometheus, loki, alerting, dashboard, folder, oncall, asserts, sift, admin, - pyroscope, navigation bool + pyroscope, navigation, proxied bool } // Configuration for the Grafana client. @@ -56,8 +57,7 @@ type grafanaConfig struct { } func (dt *disabledTools) addFlags() { - flag.StringVar(&dt.enabledTools, "enabled-tools", "search,datasource,incident,prometheus,loki,alerting,dashboard,folder,oncall,asserts,sift,admin,pyroscope,navigation", "A comma separated list of tools enabled for this server. Can be overwritten entirely or by disabling specific components, e.g. --disable-search.") - + flag.StringVar(&dt.enabledTools, "enabled-tools", "search,datasource,incident,prometheus,loki,alerting,dashboard,folder,oncall,asserts,sift,admin,pyroscope,navigation,proxied", "A comma separated list of tools enabled for this server. Can be overwritten entirely or by disabling specific components, e.g. 
--disable-search.") flag.BoolVar(&dt.search, "disable-search", false, "Disable search tools") flag.BoolVar(&dt.datasource, "disable-datasource", false, "Disable datasource tools") flag.BoolVar(&dt.incident, "disable-incident", false, "Disable incident tools") @@ -72,6 +72,7 @@ func (dt *disabledTools) addFlags() { flag.BoolVar(&dt.admin, "disable-admin", false, "Disable admin tools") flag.BoolVar(&dt.pyroscope, "disable-pyroscope", false, "Disable pyroscope tools") flag.BoolVar(&dt.navigation, "disable-navigation", false, "Disable navigation tools") + flag.BoolVar(&dt.proxied, "disable-proxied", false, "Disable proxied tools (tools from external MCP servers)") } func (gc *grafanaConfig) addFlags() { @@ -103,21 +104,66 @@ func (dt *disabledTools) addTools(s *server.MCPServer) { } func newServer(dt disabledTools) *server.MCPServer { - s := server.NewMCPServer("mcp-grafana", mcpgrafana.Version(), server.WithInstructions(` - This server provides access to your Grafana instance and the surrounding ecosystem. - - Available Capabilities: - - Dashboards: Search, retrieve, update, and create dashboards. Extract panel queries and datasource information. - - Datasources: List and fetch details for datasources. - - Prometheus & Loki: Run PromQL and LogQL queries, retrieve metric/log metadata, and explore label names/values. - - Incidents: Search, create, update, and resolve incidents in Grafana Incident. - - Sift Investigations: Start and manage Sift investigations, analyze logs/traces, find error patterns, and detect slow requests. - - Alerting: List and fetch alert rules and notification contact points. - - OnCall: View and manage on-call schedules, shifts, teams, and users. - - Admin: List teams and perform administrative tasks. - - Pyroscope: Profile applications and fetch profiling data. - - Navigation: Generate deeplink URLs for Grafana resources like dashboards, panels, and Explore queries. 
- `)) + sm := mcpgrafana.NewSessionManager() + + // Declare variable for SessionToolManager that will be initialized after server creation + var stm *mcpgrafana.SessionToolManager + + // Create hooks + hooks := &server.Hooks{ + OnRegisterSession: []server.OnRegisterSessionHookFunc{sm.CreateSession}, + OnUnregisterSession: []server.OnUnregisterSessionHookFunc{sm.RemoveSession}, + } + + // Add proxied tools hooks if enabled + if !dt.proxied { + // OnBeforeListTools: Discover, connect, and register tools + hooks.OnBeforeListTools = []server.OnBeforeListToolsFunc{ + func(ctx context.Context, id any, request *mcp.ListToolsRequest) { + if stm != nil { + if session := server.ClientSessionFromContext(ctx); session != nil { + stm.InitializeAndRegisterProxiedTools(ctx, session) + } + } + }, + } + + // OnBeforeCallTool: Fallback in case client calls tool without listing first + hooks.OnBeforeCallTool = []server.OnBeforeCallToolFunc{ + func(ctx context.Context, id any, request *mcp.CallToolRequest) { + if stm != nil { + if session := server.ClientSessionFromContext(ctx); session != nil { + stm.InitializeAndRegisterProxiedTools(ctx, session) + } + } + }, + } + } + s := server.NewMCPServer("mcp-grafana", mcpgrafana.Version(), + server.WithInstructions(` +This server provides access to your Grafana instance and the surrounding ecosystem. + +Available Capabilities: +- Dashboards: Search, retrieve, update, and create dashboards. Extract panel queries and datasource information. +- Datasources: List and fetch details for datasources. +- Prometheus & Loki: Run PromQL and LogQL queries, retrieve metric/log metadata, and explore label names/values. +- Incidents: Search, create, update, and resolve incidents in Grafana Incident. +- Sift Investigations: Start and manage Sift investigations, analyze logs/traces, find error patterns, and detect slow requests. +- Alerting: List and fetch alert rules and notification contact points. 
+- OnCall: View and manage on-call schedules, shifts, teams, and users. +- Admin: List teams and perform administrative tasks. +- Pyroscope: Profile applications and fetch profiling data. +- Navigation: Generate deeplink URLs for Grafana resources like dashboards, panels, and Explore queries. +- Proxied Tools: Access tools from external MCP servers (like Tempo) through dynamic discovery. + +Note that some of these capabilities may be disabled. Do not try to use features that are not available via tools. +`), + server.WithHooks(hooks), + ) + + // Initialize SessionToolManager now that server is created + stm = mcpgrafana.NewSessionToolManager(sm, s, mcpgrafana.WithProxiedTools(!dt.proxied)) + dt.addTools(s) return s } @@ -162,6 +208,7 @@ func runHTTPServer(ctx context.Context, srv httpServer, addr, transportName stri if err := srv.Shutdown(shutdownCtx); err != nil { return fmt.Errorf("shutdown error: %v", err) } + slog.Debug("Shutdown called, waiting for connections to close...") // Wait for server to finish select { @@ -227,7 +274,7 @@ func run(transport, addr, basePath, endpointPath string, logLevel slog.Level, dt case "streamable-http": opts := []server.StreamableHTTPOption{ server.WithHTTPContextFunc(mcpgrafana.ComposedHTTPContextFunc(gc)), - server.WithStateLess(true), + server.WithStateLess(dt.proxied), // Stateful when proxied tools enabled (requires sessions) server.WithEndpointPath(endpointPath), } if tls.certFile != "" || tls.keyFile != "" { @@ -238,10 +285,7 @@ func run(transport, addr, basePath, endpointPath string, logLevel slog.Level, dt "version", mcpgrafana.Version(), "address", addr, "endpointPath", endpointPath) return runHTTPServer(ctx, srv, addr, "StreamableHTTP") default: - return fmt.Errorf( - "invalid transport type: %s. Must be 'stdio', 'sse' or 'streamable-http'", - transport, - ) + return fmt.Errorf("invalid transport type: %s. 
Must be 'stdio', 'sse' or 'streamable-http'", transport) } } diff --git a/docker-compose.yaml b/docker-compose.yaml index 9fa8b79d..4965e7c1 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -43,3 +43,19 @@ services: image: grafana/pyroscope:1.13.4 ports: - 4040:4040 + + tempo: + image: grafana/tempo:2.9.0-rc.0 + command: ["-config.file=/etc/tempo/tempo-config.yaml"] + volumes: + - ./testdata/tempo-config.yaml:/etc/tempo/tempo-config.yaml + ports: + - "3200:3200" # tempo + + tempo2: + image: grafana/tempo:2.9.0-rc.0 + command: ["-config.file=/etc/tempo/tempo-config.yaml"] + volumes: + - ./testdata/tempo-config-2.yaml:/etc/tempo/tempo-config.yaml + ports: + - "3201:3201" # tempo instance 2 diff --git a/go.mod b/go.mod index 8c8c872d..e65fbb9c 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( connectrpc.com/connect v1.19.0 github.com/PaesslerAG/gval v1.2.4 github.com/PaesslerAG/jsonpath v0.1.1 + github.com/cenkalti/backoff/v5 v5.0.3 github.com/go-openapi/runtime v0.29.0 github.com/go-openapi/strfmt v0.24.0 github.com/google/uuid v1.6.0 @@ -31,7 +32,6 @@ require ( github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/buger/jsonparser v1.1.1 // indirect - github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cheekybits/genny v1.0.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect diff --git a/mcpgrafana.go b/mcpgrafana.go index 58f31cfa..82eee3e9 100644 --- a/mcpgrafana.go +++ b/mcpgrafana.go @@ -435,7 +435,7 @@ var ExtractGrafanaClientFromEnv server.StdioContextFunc = func(ctx context.Conte } auth := userAndPassFromEnv() grafanaClient := NewGrafanaClient(ctx, grafanaURL, apiKey, auth) - return context.WithValue(ctx, grafanaClientKey{}, grafanaClient) + return WithGrafanaClient(ctx, grafanaClient) } // ExtractGrafanaClientFromHeaders is a HTTPContextFunc that creates and injects a Grafana client into the 
context. @@ -443,6 +443,7 @@ var ExtractGrafanaClientFromEnv server.StdioContextFunc = func(ctx context.Conte var ExtractGrafanaClientFromHeaders httpContextFunc = func(ctx context.Context, req *http.Request) context.Context { // Extract transport config from request headers, and set it on the context. u, apiKey, basicAuth := extractKeyGrafanaInfoFromReq(req) + slog.Debug("Creating Grafana client", "url", u, "api_key_set", apiKey != "", "basic_auth_set", basicAuth != nil) grafanaClient := NewGrafanaClient(ctx, u, apiKey, basicAuth) return WithGrafanaClient(ctx, grafanaClient) diff --git a/proxied_client.go b/proxied_client.go new file mode 100644 index 00000000..45571b16 --- /dev/null +++ b/proxied_client.go @@ -0,0 +1,143 @@ +package mcpgrafana + +import ( + "context" + "encoding/base64" + "fmt" + "log/slog" + "sync" + + mcp_client "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +// ProxiedClient represents a connection to a remote MCP server (e.g., Tempo datasource) +type ProxiedClient struct { + DatasourceUID string + DatasourceName string + DatasourceType string + Client *mcp_client.Client + Tools []mcp.Tool + mutex sync.RWMutex +} + +// NewProxiedClient creates a new connection to a remote MCP server +func NewProxiedClient(ctx context.Context, datasourceUID, datasourceName, datasourceType, mcpEndpoint string) (*ProxiedClient, error) { + // Get Grafana config for authentication + config := GrafanaConfigFromContext(ctx) + + // Build headers for authentication + headers := make(map[string]string) + if config.APIKey != "" { + headers["Authorization"] = "Bearer " + config.APIKey + } else if config.BasicAuth != nil { + auth := config.BasicAuth.String() + headers["Authorization"] = "Basic " + base64.StdEncoding.EncodeToString([]byte(auth)) + } + + // Create HTTP transport with authentication headers + slog.DebugContext(ctx, "connecting to MCP server", "datasource", datasourceUID, "url", 
mcpEndpoint) + httpTransport, err := transport.NewStreamableHTTP( + mcpEndpoint, + transport.WithHTTPHeaders(headers), + ) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP transport: %w", err) + } + + // Create MCP client + mcpClient := mcp_client.NewClient(httpTransport) + + // Initialize the connection + initReq := mcp.InitializeRequest{} + initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initReq.Params.ClientInfo = mcp.Implementation{ + Name: "mcp-grafana-proxy", + Version: Version(), + } + + _, err = mcpClient.Initialize(ctx, initReq) + if err != nil { + _ = mcpClient.Close() + return nil, fmt.Errorf("failed to initialize MCP client: %w", err) + } + + // List available tools from the remote server + listReq := mcp.ListToolsRequest{} + toolsResult, err := mcpClient.ListTools(ctx, listReq) + if err != nil { + _ = mcpClient.Close() + return nil, fmt.Errorf("failed to list tools from remote MCP server: %w", err) + } + + slog.DebugContext(ctx, "connected to proxied MCP server", + "datasource", datasourceUID, + "type", datasourceType, + "tools", len(toolsResult.Tools)) + + return &ProxiedClient{ + DatasourceUID: datasourceUID, + DatasourceName: datasourceName, + DatasourceType: datasourceType, + Client: mcpClient, + Tools: toolsResult.Tools, + }, nil +} + +// CallTool forwards a tool call to the remote MCP server +func (pc *ProxiedClient) CallTool(ctx context.Context, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) { + pc.mutex.RLock() + defer pc.mutex.RUnlock() + + // Validate the tool exists + var toolExists bool + for _, tool := range pc.Tools { + if tool.Name == toolName { + toolExists = true + break + } + } + if !toolExists { + return nil, fmt.Errorf("tool %s not found in remote MCP server", toolName) + } + + // Create the call tool request + req := mcp.CallToolRequest{} + req.Params.Name = toolName + req.Params.Arguments = arguments + + // Forward the call to the remote server + result, err := 
pc.Client.CallTool(ctx, req) + if err != nil { + return nil, fmt.Errorf("failed to call tool on remote MCP server: %w", err) + } + + return result, nil +} + +// ListTools returns the tools available from this remote server +// Note: This method doesn't take a context parameter as the tools are cached locally +func (pc *ProxiedClient) ListTools() []mcp.Tool { + pc.mutex.RLock() + defer pc.mutex.RUnlock() + + // Return a copy to prevent external modification + result := make([]mcp.Tool, len(pc.Tools)) + copy(result, pc.Tools) + return result +} + +// Close closes the connection to the remote MCP server +func (pc *ProxiedClient) Close() error { + pc.mutex.Lock() + defer pc.mutex.Unlock() + + if pc.Client != nil { + if err := pc.Client.Close(); err != nil { + return fmt.Errorf("failed to close MCP client: %w", err) + } + } + + return nil +} diff --git a/proxied_handler.go b/proxied_handler.go new file mode 100644 index 00000000..77a5ee39 --- /dev/null +++ b/proxied_handler.go @@ -0,0 +1,72 @@ +package mcpgrafana + +import ( + "context" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// ProxiedToolHandler implements the CallToolHandler interface for proxied tools +type ProxiedToolHandler struct { + sessionManager *SessionManager + toolName string +} + +// NewProxiedToolHandler creates a new handler for a proxied tool +func NewProxiedToolHandler(sm *SessionManager, toolName string) *ProxiedToolHandler { + return &ProxiedToolHandler{ + sessionManager: sm, + toolName: toolName, + } +} + +// Handle forwards the tool call to the appropriate remote MCP server +func (h *ProxiedToolHandler) Handle(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Check if session is in context + session := server.ClientSessionFromContext(ctx) + if session == nil { + return nil, fmt.Errorf("session not found in context") + } + + // Extract arguments + args, ok := request.Params.Arguments.(map[string]any) + if !ok { + 
return nil, fmt.Errorf("invalid arguments type") + } + + // Extract required datasourceUid parameter + datasourceUidRaw, ok := args["datasourceUid"] + if !ok { + return nil, fmt.Errorf("datasourceUid parameter is required") + } + datasourceUID, ok := datasourceUidRaw.(string) + if !ok { + return nil, fmt.Errorf("datasourceUid must be a string") + } + + // Parse the tool name to get datasource type and original tool name + // Format: datasourceType_originalToolName (e.g., "tempo_traceql-search") + datasourceType, originalToolName, err := parseProxiedToolName(h.toolName) + if err != nil { + return nil, fmt.Errorf("failed to parse tool name: %w", err) + } + + // Get the proxied client for this datasource + client, err := h.sessionManager.GetProxiedClient(ctx, datasourceType, datasourceUID) + if err != nil { + return nil, fmt.Errorf("datasource '%s' not found or not accessible. Ensure the datasource exists and you have permission to access it", datasourceUID) + } + + // Remove datasourceUid from args before forwarding to remote server + forwardArgs := make(map[string]any) + for k, v := range args { + if k != "datasourceUid" { + forwardArgs[k] = v + } + } + + // Forward the call to the remote MCP server + return client.CallTool(ctx, originalToolName, forwardArgs) +} diff --git a/proxied_tools.go b/proxied_tools.go new file mode 100644 index 00000000..7ea7a936 --- /dev/null +++ b/proxied_tools.go @@ -0,0 +1,273 @@ +package mcpgrafana + +import ( + "context" + "fmt" + "log/slog" + "net/http" + "strings" + + "github.com/go-openapi/runtime" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// MCPDatasourceConfig defines configuration for a datasource type that supports MCP +type MCPDatasourceConfig struct { + Type string + EndpointPath string // e.g., "/api/mcp" +} + +// mcpEnabledDatasources is a registry of datasource types that support MCP +var mcpEnabledDatasources = map[string]MCPDatasourceConfig{ + "tempo": {Type: "tempo", EndpointPath: 
"/api/mcp"}, + // Future: add other datasource types here +} + +// DiscoveredDatasource represents a datasource that supports MCP +type DiscoveredDatasource struct { + UID string + Name string + Type string + MCPURL string // The MCP endpoint URL +} + +// discoverMCPDatasources discovers datasources that support MCP +// Returns a list of datasources with MCP endpoints +func discoverMCPDatasources(ctx context.Context) ([]DiscoveredDatasource, error) { + gc := GrafanaClientFromContext(ctx) + if gc == nil { + return nil, fmt.Errorf("grafana client not found in context") + } + + var discovered []DiscoveredDatasource + + // List all datasources + resp, err := gc.Datasources.GetDataSources() + if err != nil { + return nil, fmt.Errorf("failed to list datasources: %w", err) + } + + // Get the Grafana base URL from context + config := GrafanaConfigFromContext(ctx) + if config.URL == "" { + return nil, fmt.Errorf("grafana url not found in context") + } + grafanaBaseURL := config.URL + + // Filter for datasources that support MCP + for _, ds := range resp.Payload { + // Check if this datasource type supports MCP + dsConfig, supported := mcpEnabledDatasources[ds.Type] + if !supported { + continue + } + + // Check if the datasource instance has MCP enabled + // We use a DELETE request to probe the MCP endpoint since: + // - GET would start an event stream and hang + // - POST doesn't work with the Grafana OpenAPI client + // - DELETE returns 200 if MCP is enabled, 404 if not + _, err := gc.Datasources.DatasourceProxyDELETEByUIDcalls(ds.UID, strings.TrimPrefix(dsConfig.EndpointPath, "/")) + if err == nil { + // Something strange happened - the server should never return a 202 for this really. Skip. 
+ continue + } + if apiErr, ok := err.(*runtime.APIError); !ok || (ok && !apiErr.IsCode(http.StatusOK)) { + // Not a 200 response, MCP not enabled + continue + } + + // Build the MCP endpoint URL using Grafana's datasource proxy API + // Format: /api/datasources/proxy/uid/ + mcpURL := fmt.Sprintf("%s/api/datasources/proxy/uid/%s%s", grafanaBaseURL, ds.UID, dsConfig.EndpointPath) + + discovered = append(discovered, DiscoveredDatasource{ + UID: ds.UID, + Name: ds.Name, + Type: ds.Type, + MCPURL: mcpURL, + }) + } + + slog.DebugContext(ctx, "discovered MCP datasources", "count", len(discovered)) + return discovered, nil +} + +// addDatasourceUidParameter adds a required datasourceUid parameter to a tool's input schema +func addDatasourceUidParameter(tool mcp.Tool, datasourceType string) mcp.Tool { + modifiedTool := tool + // Prefix tool name with datasource type (e.g., "tempo_traceql-search") + modifiedTool.Name = datasourceType + "_" + tool.Name + + // Add datasourceUid to the input schema + if modifiedTool.InputSchema.Properties == nil { + modifiedTool.InputSchema.Properties = make(map[string]any) + } + + modifiedTool.InputSchema.Properties["datasourceUid"] = map[string]any{ + "type": "string", + "description": "UID of the " + datasourceType + " datasource to query", + } + + // Add to required fields + modifiedTool.InputSchema.Required = append(modifiedTool.InputSchema.Required, "datasourceUid") + + return modifiedTool +} + +// parseProxiedToolName extracts datasource type and original tool name from a proxied tool name +// Format: _ +// Returns: datasourceType, originalToolName, error +func parseProxiedToolName(toolName string) (string, string, error) { + parts := strings.SplitN(toolName, "_", 2) + if len(parts) != 2 { + return "", "", fmt.Errorf("invalid proxied tool name format: %s", toolName) + } + return parts[0], parts[1], nil +} + +// SessionToolManager manages session-specific proxied tools +type SessionToolManager struct { + sm *SessionManager + server 
*server.MCPServer + + // Whether to enable proxied tools. + enableProxiedTools bool +} + +// NewSessionToolManager creates a new SessionToolManager +func NewSessionToolManager(sm *SessionManager, mcpServer *server.MCPServer, opts ...sessionToolManagerOption) *SessionToolManager { + stm := &SessionToolManager{ + sm: sm, + server: mcpServer, + } + for _, opt := range opts { + opt(stm) + } + return stm +} + +type sessionToolManagerOption func(*SessionToolManager) + +// WithProxiedTools sets whether proxied tools are enabled +func WithProxiedTools(enabled bool) sessionToolManagerOption { + return func(stm *SessionToolManager) { + stm.enableProxiedTools = enabled + } +} + +// InitializeAndRegisterProxiedTools discovers datasources, creates clients, and registers tools +// This should be called in OnBeforeListTools and OnBeforeCallTool hooks +func (stm *SessionToolManager) InitializeAndRegisterProxiedTools(ctx context.Context, session server.ClientSession) { + if !stm.enableProxiedTools { + return + } + + sessionID := session.SessionID() + state, exists := stm.sm.GetSession(sessionID) + if !exists { + // Session exists in server context but not in our SessionManager yet + stm.sm.CreateSession(ctx, session) + state, exists = stm.sm.GetSession(sessionID) + if !exists { + slog.Error("failed to create session in SessionManager", "sessionID", sessionID) + return + } + } + + state.mutex.Lock() + // Check if already initialized and registered + if state.proxiedToolsInitialized && len(state.proxiedTools) > 0 { + state.mutex.Unlock() + return + } + + // If already initialized but not registered, skip discovery + alreadyDiscovered := state.proxiedToolsInitialized + state.mutex.Unlock() + + // Step 1: Discover and connect (if not already done) + if !alreadyDiscovered { + // Discover datasources with MCP support + discovered, err := discoverMCPDatasources(ctx) + if err != nil { + slog.Error("failed to discover MCP datasources", "error", err) + state.mutex.Lock() + 
state.proxiedToolsInitialized = true + state.mutex.Unlock() + return + } + + state.mutex.Lock() + // For each discovered datasource, create a proxied client + for _, ds := range discovered { + client, err := NewProxiedClient(ctx, ds.UID, ds.Name, ds.Type, ds.MCPURL) + if err != nil { + slog.Error("failed to create proxied client", "datasource", ds.UID, "error", err) + continue + } + + // Store the client + key := ds.Type + "_" + ds.UID + state.proxiedClients[key] = client + } + state.proxiedToolsInitialized = true + state.mutex.Unlock() + + slog.Info("connected to proxied MCP servers", "session", sessionID, "datasources", len(state.proxiedClients)) + } + + // Step 2: Register tools with the MCP server + state.mutex.Lock() + defer state.mutex.Unlock() + + // Check if tools already registered + if len(state.proxiedTools) > 0 { + return + } + + // Check if we have any clients (discovery should have happened above) + if len(state.proxiedClients) == 0 { + return + } + + // First pass: collect all unique tools and track which datasources support them + toolMap := make(map[string]mcp.Tool) // unique tools by name + + for key, client := range state.proxiedClients { + remoteTools := client.ListTools() + + for _, tool := range remoteTools { + // Tool name format: datasourceType_originalToolName (e.g., "tempo_traceql-search") + toolName := client.DatasourceType + "_" + tool.Name + + // Store the tool if we haven't seen it yet + if _, exists := toolMap[toolName]; !exists { + // Add datasourceUid parameter to the tool + modifiedTool := addDatasourceUidParameter(tool, client.DatasourceType) + toolMap[toolName] = modifiedTool + } + + // Track which datasources support this tool + state.toolToDatasources[toolName] = append(state.toolToDatasources[toolName], key) + } + } + + // Second pass: register all unique tools at once (reduces listChanged notifications) + var serverTools []server.ServerTool + for toolName, tool := range toolMap { + handler := NewProxiedToolHandler(stm.sm, 
toolName) + serverTools = append(serverTools, server.ServerTool{ + Tool: tool, + Handler: handler.Handle, + }) + state.proxiedTools = append(state.proxiedTools, tool) + } + + if err := stm.server.AddSessionTools(sessionID, serverTools...); err != nil { + slog.Warn("failed to add session tools", "session", sessionID, "error", err) + } else { + slog.Info("registered proxied tools", "session", sessionID, "tools", len(state.proxiedTools)) + } +} diff --git a/session.go b/session.go new file mode 100644 index 00000000..e1acfa67 --- /dev/null +++ b/session.go @@ -0,0 +1,114 @@ +package mcpgrafana + +import ( + "context" + "fmt" + "log/slog" + "sync" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// SessionState holds the state for a single client session +type SessionState struct { + // Proxied tools state + proxiedToolsInitialized bool + proxiedTools []mcp.Tool + proxiedClients map[string]*ProxiedClient // key: datasourceType_datasourceUID + toolToDatasources map[string][]string // key: toolName, value: list of datasource keys that support it + mutex sync.RWMutex +} + +func newSessionState() *SessionState { + return &SessionState{ + proxiedClients: make(map[string]*ProxiedClient), + toolToDatasources: make(map[string][]string), + } +} + +// SessionManager manages client sessions and their state +type SessionManager struct { + sessions map[string]*SessionState + mutex sync.RWMutex +} + +func NewSessionManager() *SessionManager { + return &SessionManager{ + sessions: make(map[string]*SessionState), + } +} + +func (sm *SessionManager) CreateSession(ctx context.Context, session server.ClientSession) { + sm.mutex.Lock() + defer sm.mutex.Unlock() + + sessionID := session.SessionID() + if _, exists := sm.sessions[sessionID]; !exists { + sm.sessions[sessionID] = newSessionState() + } +} + +func (sm *SessionManager) GetSession(sessionID string) (*SessionState, bool) { + sm.mutex.RLock() + defer sm.mutex.RUnlock() + + session, exists := 
sm.sessions[sessionID] + return session, exists +} + +func (sm *SessionManager) RemoveSession(ctx context.Context, session server.ClientSession) { + sm.mutex.Lock() + sessionID := session.SessionID() + state, exists := sm.sessions[sessionID] + delete(sm.sessions, sessionID) + sm.mutex.Unlock() + + if !exists { + return + } + + // Clean up proxied clients outside of the main lock + state.mutex.Lock() + defer state.mutex.Unlock() + + for key, client := range state.proxiedClients { + if err := client.Close(); err != nil { + slog.Error("failed to close proxied client", "key", key, "error", err) + } + } +} + +// GetProxiedClient retrieves a proxied client for the given datasource +func (sm *SessionManager) GetProxiedClient(ctx context.Context, datasourceType, datasourceUID string) (*ProxiedClient, error) { + session := server.ClientSessionFromContext(ctx) + if session == nil { + return nil, fmt.Errorf("session not found in context") + } + + state, exists := sm.GetSession(session.SessionID()) + if !exists { + return nil, fmt.Errorf("session not found") + } + + state.mutex.RLock() + defer state.mutex.RUnlock() + + key := datasourceType + "_" + datasourceUID + client, exists := state.proxiedClients[key] + if !exists { + // List available datasources to help with debugging + var availableUIDs []string + for _, c := range state.proxiedClients { + if c.DatasourceType == datasourceType { + availableUIDs = append(availableUIDs, c.DatasourceUID) + } + } + if len(availableUIDs) > 0 { + return nil, fmt.Errorf("datasource '%s' not found. Available %s datasources: %v", datasourceUID, datasourceType, availableUIDs) + } + return nil, fmt.Errorf("datasource '%s' not found. 
No %s datasources with MCP support are configured", datasourceUID, datasourceType) + } + + return client, nil +} diff --git a/session_test.go b/session_test.go new file mode 100644 index 00000000..6625bb93 --- /dev/null +++ b/session_test.go @@ -0,0 +1,419 @@ +//go:build integration + +// Integration tests for proxied MCP tools functionality. +// Requires docker-compose to be running with Grafana and Tempo instances. +// Run with: go test -tags integration -v ./... + +package mcpgrafana + +import ( + "context" + "fmt" + "net/url" + "os" + "strings" + "testing" + + "github.com/go-openapi/strfmt" + grafana_client "github.com/grafana/grafana-openapi-client-go/client" + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newProxiedToolsTestContext creates a test context with Grafana client and config +func newProxiedToolsTestContext(t *testing.T) context.Context { + cfg := grafana_client.DefaultTransportConfig() + cfg.Host = "localhost:3000" + cfg.Schemes = []string{"http"} + + // Extract transport config from env vars, and set it on the context. + if u, ok := os.LookupEnv("GRAFANA_URL"); ok { + parsedURL, err := url.Parse(u) + require.NoError(t, err, "invalid GRAFANA_URL") + cfg.Host = parsedURL.Host + // The Grafana client will always prefer HTTPS even if the URL is HTTP, + // so we need to limit the schemes to HTTP if the URL is HTTP. 
+ if parsedURL.Scheme == "http" { + cfg.Schemes = []string{"http"} + } + } + + // Check for the new service account token environment variable first + if apiKey := os.Getenv("GRAFANA_SERVICE_ACCOUNT_TOKEN"); apiKey != "" { + cfg.APIKey = apiKey + } else if apiKey := os.Getenv("GRAFANA_API_KEY"); apiKey != "" { + // Fall back to the deprecated API key environment variable + cfg.APIKey = apiKey + } else { + cfg.BasicAuth = url.UserPassword("admin", "admin") + } + + grafanaClient := grafana_client.NewHTTPClientWithConfig(strfmt.Default, cfg) + + grafanaCfg := GrafanaConfig{ + Debug: true, + URL: "http://localhost:3000", + APIKey: cfg.APIKey, + BasicAuth: cfg.BasicAuth, + } + + ctx := WithGrafanaConfig(context.Background(), grafanaCfg) + return WithGrafanaClient(ctx, grafanaClient) +} + +func TestDiscoverMCPDatasources(t *testing.T) { + ctx := newProxiedToolsTestContext(t) + + t.Run("discovers tempo datasources", func(t *testing.T) { + discovered, err := discoverMCPDatasources(ctx) + require.NoError(t, err) + + // Should find two Tempo datasources from docker-compose + assert.GreaterOrEqual(t, len(discovered), 2, "Should discover at least 2 Tempo datasources") + + // Check that we found the expected datasources + uids := make([]string, len(discovered)) + for i, ds := range discovered { + uids[i] = ds.UID + assert.Equal(t, "tempo", ds.Type, "All discovered datasources should be tempo type") + assert.NotEmpty(t, ds.Name, "Datasource should have a name") + assert.NotEmpty(t, ds.MCPURL, "Datasource should have MCP URL") + + // Verify URL format + expectedURLPattern := fmt.Sprintf("http://localhost:3000/api/datasources/proxy/uid/%s/api/mcp", ds.UID) + assert.Equal(t, expectedURLPattern, ds.MCPURL, "MCP URL should follow proxy pattern") + } + + // Should contain our expected UIDs + assert.Contains(t, uids, "tempo", "Should discover 'tempo' datasource") + assert.Contains(t, uids, "tempo-secondary", "Should discover 'tempo-secondary' datasource") + }) + + t.Run("returns error 
when grafana client not in context", func(t *testing.T) { + emptyCtx := context.Background() + discovered, err := discoverMCPDatasources(emptyCtx) + assert.Error(t, err) + assert.Nil(t, discovered) + assert.Contains(t, err.Error(), "grafana client not found in context") + }) + + t.Run("returns error when auth is missing", func(t *testing.T) { + // Context with client but no auth credentials + cfg := grafana_client.DefaultTransportConfig() + cfg.Host = "localhost:3000" + cfg.Schemes = []string{"http"} + grafanaClient := grafana_client.NewHTTPClientWithConfig(strfmt.Default, cfg) + + grafanaCfg := GrafanaConfig{ + URL: "http://localhost:3000", + // No APIKey or BasicAuth set + } + ctx := WithGrafanaConfig(context.Background(), grafanaCfg) + ctx = WithGrafanaClient(ctx, grafanaClient) + + discovered, err := discoverMCPDatasources(ctx) + assert.Error(t, err) + assert.Nil(t, discovered) + assert.Contains(t, err.Error(), "Unauthorized") + }) +} + +func TestToolNamespacing(t *testing.T) { + t.Run("parse proxied tool name", func(t *testing.T) { + datasourceType, toolName, err := parseProxiedToolName("tempo_traceql-search") + require.NoError(t, err) + assert.Equal(t, "tempo", datasourceType) + assert.Equal(t, "traceql-search", toolName) + }) + + t.Run("parse proxied tool name with multiple underscores", func(t *testing.T) { + datasourceType, toolName, err := parseProxiedToolName("tempo_get-attribute-values") + require.NoError(t, err) + assert.Equal(t, "tempo", datasourceType) + assert.Equal(t, "get-attribute-values", toolName) + }) + + t.Run("parse proxied tool name with invalid format", func(t *testing.T) { + _, _, err := parseProxiedToolName("invalid") + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid proxied tool name format") + }) + + t.Run("add datasourceUid parameter to tool", func(t *testing.T) { + originalTool := mcp.Tool{ + Name: "query_traces", + Description: "Query traces from Tempo", + InputSchema: mcp.ToolInputSchema{ + Properties: 
map[string]any{ + "query": map[string]any{ + "type": "string", + }, + }, + Required: []string{"query"}, + }, + } + + modifiedTool := addDatasourceUidParameter(originalTool, "tempo") + + assert.Equal(t, "tempo_query_traces", modifiedTool.Name) + assert.Equal(t, "Query traces from Tempo", modifiedTool.Description) + assert.NotNil(t, modifiedTool.InputSchema.Properties["datasourceUid"]) + assert.Contains(t, modifiedTool.InputSchema.Required, "datasourceUid") + assert.Contains(t, modifiedTool.InputSchema.Required, "query") + }) + + t.Run("add datasourceUid parameter with empty description", func(t *testing.T) { + originalTool := mcp.Tool{ + Name: "test_tool", + Description: "", + InputSchema: mcp.ToolInputSchema{ + Properties: make(map[string]any), + }, + } + + modifiedTool := addDatasourceUidParameter(originalTool, "tempo") + + assert.Equal(t, "tempo_test_tool", modifiedTool.Name) + assert.Equal(t, "", modifiedTool.Description, "Should not modify empty description") + assert.NotNil(t, modifiedTool.InputSchema.Properties["datasourceUid"]) + }) +} + +func TestSessionStateLifecycle(t *testing.T) { + t.Run("create and get session", func(t *testing.T) { + sm := NewSessionManager() + + // Create mock session + mockSession := &mockClientSession{id: "test-session-123"} + + sm.CreateSession(context.Background(), mockSession) + + state, exists := sm.GetSession("test-session-123") + assert.True(t, exists) + assert.NotNil(t, state) + assert.NotNil(t, state.proxiedClients) + assert.False(t, state.proxiedToolsInitialized) + }) + + t.Run("remove session cleans up clients", func(t *testing.T) { + sm := NewSessionManager() + + mockSession := &mockClientSession{id: "test-session-456"} + sm.CreateSession(context.Background(), mockSession) + + state, _ := sm.GetSession("test-session-456") + + // Add a mock proxied client + mockClient := &ProxiedClient{ + DatasourceUID: "test-uid", + DatasourceName: "Test Datasource", + DatasourceType: "tempo", + } + state.proxiedClients["tempo_test-uid"] 
= mockClient + + // Remove session + sm.RemoveSession(context.Background(), mockSession) + + // Session should be gone + _, exists := sm.GetSession("test-session-456") + assert.False(t, exists) + }) + + t.Run("get non-existent session", func(t *testing.T) { + sm := NewSessionManager() + + state, exists := sm.GetSession("non-existent") + assert.False(t, exists) + assert.Nil(t, state) + }) +} + +func TestProxiedClientLifecycle(t *testing.T) { + ctx := newProxiedToolsTestContext(t) + + t.Run("list tools returns copy", func(t *testing.T) { + pc := &ProxiedClient{ + DatasourceUID: "test-uid", + DatasourceName: "Test", + DatasourceType: "tempo", + Tools: []mcp.Tool{ + {Name: "tool1", Description: "First tool"}, + {Name: "tool2", Description: "Second tool"}, + }, + } + + tools1 := pc.ListTools() + tools2 := pc.ListTools() + + // Should return same content + assert.Equal(t, tools1, tools2) + + // But different slice instances (copy) + assert.NotSame(t, &tools1[0], &tools2[0]) + }) + + t.Run("call tool validates tool exists", func(t *testing.T) { + pc := &ProxiedClient{ + DatasourceUID: "test-uid", + DatasourceName: "Test", + DatasourceType: "tempo", + Tools: []mcp.Tool{ + {Name: "valid_tool", Description: "Valid tool"}, + }, + } + + // Call non-existent tool + result, err := pc.CallTool(ctx, "non_existent_tool", map[string]any{}) + assert.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "not found in remote MCP server") + }) +} + +func TestEndToEndProxiedToolsFlow(t *testing.T) { + ctx := newProxiedToolsTestContext(t) + + t.Run("full flow from discovery to tool call", func(t *testing.T) { + // Step 1: Discover MCP datasources + discovered, err := discoverMCPDatasources(ctx) + require.NoError(t, err) + require.GreaterOrEqual(t, len(discovered), 1, "Should discover at least one Tempo datasource") + + // Use the first discovered datasource + ds := discovered[0] + t.Logf("Testing with datasource: %s (UID: %s, URL: %s)", ds.Name, ds.UID, ds.MCPURL) + + // 
Step 2: Create a proxied client connection + client, err := NewProxiedClient(ctx, ds.UID, ds.Name, ds.Type, ds.MCPURL) + if err != nil { + t.Skipf("Skipping end-to-end test: Tempo MCP endpoint not available: %v", err) + return + } + defer func() { + _ = client.Close() + }() + + // Step 3: Verify we got tools from the remote server + tools := client.ListTools() + require.Greater(t, len(tools), 0, "Should have at least one tool from Tempo MCP server") + t.Logf("Discovered %d tools from Tempo MCP server", len(tools)) + + // Log the available tools + for _, tool := range tools { + t.Logf(" - Tool: %s - %s", tool.Name, tool.Description) + } + + // Step 4: Test tool modification with datasourceUid parameter + firstTool := tools[0] + modifiedTool := addDatasourceUidParameter(firstTool, ds.Type) + + expectedName := ds.Type + "_" + firstTool.Name + assert.Equal(t, expectedName, modifiedTool.Name, "Modified tool should have prefixed name") + assert.Contains(t, modifiedTool.InputSchema.Required, "datasourceUid", "Modified tool should require datasourceUid") + + // Step 5: Test session integration + sm := NewSessionManager() + mockSession := &mockClientSession{id: "e2e-test-session"} + sm.CreateSession(ctx, mockSession) + + state, exists := sm.GetSession("e2e-test-session") + require.True(t, exists) + + // Store the proxied client in session state + key := ds.Type + "_" + ds.UID + state.proxiedClients[key] = client + + // Step 6: Verify client is stored correctly in session + retrievedClient, exists := state.proxiedClients[key] + require.True(t, exists, "Client should be stored in session state") + assert.Equal(t, client, retrievedClient, "Should retrieve the same client from session") + + // Step 7: Test ProxiedToolHandler flow + handler := NewProxiedToolHandler(sm, modifiedTool.Name) + assert.NotNil(t, handler) + + // Note: We can't actually call the tool without knowing what arguments it expects + // and without the context having the proper session, but we've validated the 
setup + t.Logf("Successfully validated end-to-end proxied tools flow") + }) + + t.Run("multiple datasources in single session", func(t *testing.T) { + discovered, err := discoverMCPDatasources(ctx) + require.NoError(t, err) + + if len(discovered) < 2 { + t.Skip("Need at least 2 Tempo datasources for this test") + } + + sm := NewSessionManager() + mockSession := &mockClientSession{id: "multi-ds-test-session"} + sm.CreateSession(ctx, mockSession) + + state, _ := sm.GetSession("multi-ds-test-session") + + // Try to connect to multiple datasources + connectedCount := 0 + for i, ds := range discovered { + if i >= 2 { + break // Test with first 2 datasources + } + + client, err := NewProxiedClient(ctx, ds.UID, ds.Name, ds.Type, ds.MCPURL) + if err != nil { + t.Logf("Could not connect to datasource %s: %v", ds.UID, err) + continue + } + defer func() { + _ = client.Close() + }() + + key := ds.Type + "_" + ds.UID + state.proxiedClients[key] = client + connectedCount++ + + t.Logf("Connected to datasource %s with %d tools", ds.UID, len(client.Tools)) + } + + if connectedCount == 0 { + t.Skip("Could not connect to any Tempo datasources") + } + + // Verify each client is stored correctly + for key, client := range state.proxiedClients { + parts := strings.Split(key, "_") + require.Len(t, parts, 2, "Key should have format type_uid") + assert.NotNil(t, client, "Client should not be nil") + assert.Equal(t, parts[0], client.DatasourceType, "Client type should match key") + assert.Equal(t, parts[1], client.DatasourceUID, "Client UID should match key") + } + + t.Logf("Successfully managed %d datasources in single session", connectedCount) + }) +} + +// mockClientSession implements server.ClientSession for testing +type mockClientSession struct { + id string + notifChannel chan mcp.JSONRPCNotification + isInitialized bool +} + +func (m *mockClientSession) SessionID() string { + return m.id +} + +func (m *mockClientSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + if 
m.notifChannel == nil { + m.notifChannel = make(chan mcp.JSONRPCNotification, 10) + } + return m.notifChannel +} + +func (m *mockClientSession) Initialize() { + m.isInitialized = true +} + +func (m *mockClientSession) Initialized() bool { + return m.isInitialized +} diff --git a/testdata/provisioning/datasources/datasources.yaml b/testdata/provisioning/datasources/datasources.yaml index 81a47c23..687772b3 100644 --- a/testdata/provisioning/datasources/datasources.yaml +++ b/testdata/provisioning/datasources/datasources.yaml @@ -27,3 +27,17 @@ datasources: access: proxy url: http://pyroscope:4040 isDefault: false + - name: Tempo + id: 4 + uid: tempo + type: tempo + access: proxy + url: http://tempo:3200 + isDefault: false + - name: Tempo Secondary + id: 5 + uid: tempo-secondary + type: tempo + access: proxy + url: http://tempo2:3201 + isDefault: false diff --git a/testdata/tempo-config-2.yaml b/testdata/tempo-config-2.yaml new file mode 100644 index 00000000..6706e340 --- /dev/null +++ b/testdata/tempo-config-2.yaml @@ -0,0 +1,29 @@ +server: + http_listen_port: 3201 + log_level: debug + +query_frontend: + mcp_server: + enabled: true + +distributor: + receivers: + otlp: + protocols: + http: + grpc: + +ingester: + max_block_duration: 5m + +compactor: + compaction: + block_retention: 1h + +storage: + trace: + backend: local + local: + path: /tmp/tempo2/blocks + wal: + path: /tmp/tempo2/wal diff --git a/testdata/tempo-config.yaml b/testdata/tempo-config.yaml new file mode 100644 index 00000000..dcdef2d7 --- /dev/null +++ b/testdata/tempo-config.yaml @@ -0,0 +1,29 @@ +server: + http_listen_port: 3200 + log_level: debug + +query_frontend: + mcp_server: + enabled: true + +distributor: + receivers: + otlp: + protocols: + http: + grpc: + +ingester: + max_block_duration: 5m + +compactor: + compaction: + block_retention: 1h + +storage: + trace: + backend: local + local: + path: /tmp/tempo/blocks + wal: + path: /tmp/tempo/wal diff --git a/tests/tempo_test.py 
b/tests/tempo_test.py new file mode 100644 index 00000000..00895096 --- /dev/null +++ b/tests/tempo_test.py @@ -0,0 +1,191 @@ +from mcp import ClientSession +import pytest + + +class TestTempoProxiedTools: + """Test Tempo proxied MCP tools functionality. + + These tests verify that Tempo datasources with MCP support are discovered + per-session and their tools are registered with a datasourceUid parameter + for multi-datasource support. + + Requires: + - Docker compose services running (includes 2 Tempo instances) + - GRAFANA_USERNAME and GRAFANA_PASSWORD environment variables + - MCP server running + """ + + @pytest.mark.anyio + async def test_tempo_tools_discovered_and_registered( + self, mcp_client: ClientSession + ): + """Test that Tempo tools are discovered and registered with datasourceUid parameter.""" + + # List all tools + list_response = await mcp_client.list_tools() + all_tool_names = [tool.name for tool in list_response.tools] + + # Find tempo-prefixed tools (should preserve hyphens from original tool names) + tempo_tools = [name for name in all_tool_names if name.startswith("tempo_")] + + # Expected tools from Tempo MCP server + expected_tempo_tools = [ + "tempo_traceql-search", + "tempo_traceql-metrics-instant", + "tempo_traceql-metrics-range", + "tempo_get-trace", + "tempo_get-attribute-names", + "tempo_get-attribute-values", + "tempo_docs-traceql", + ] + + assert len(tempo_tools) == len(expected_tempo_tools), ( + f"Expected {len(expected_tempo_tools)} unique tempo tools, found {len(tempo_tools)}: {tempo_tools}" + ) + + for expected_tool in expected_tempo_tools: + assert expected_tool in tempo_tools, ( + f"Tool {expected_tool} should be available" + ) + + @pytest.mark.anyio + async def test_tempo_tools_have_datasourceUid_parameter(self, mcp_client): + """Test that all tempo tools have a required datasourceUid parameter.""" + + list_response = await mcp_client.list_tools() + tempo_tools = [ + tool for tool in list_response.tools if 
tool.name.startswith("tempo_") + ] + + assert len(tempo_tools) > 0, "Should have at least one tempo tool" + + for tool in tempo_tools: + # Verify the tool has input schema + assert hasattr(tool, "inputSchema"), ( + f"Tool {tool.name} should have inputSchema" + ) + assert isinstance(tool.inputSchema, dict), ( + f"Tool {tool.name} inputSchema should be a dict" + ) + + # Verify datasourceUid parameter exists (camelCase) + properties = tool.inputSchema.get("properties", {}) + assert "datasourceUid" in properties, ( + f"Tool {tool.name} should have datasourceUid parameter (camelCase)" + ) + + # Verify it's required + required = tool.inputSchema.get("required", []) + assert "datasourceUid" in required, ( + f"Tool {tool.name} should require datasourceUid parameter" + ) + + # Verify parameter has proper description + datasource_uid_prop = properties["datasourceUid"] + assert "type" in datasource_uid_prop, ( + f"datasourceUid should have type defined" + ) + assert datasource_uid_prop["type"] == "string", ( + f"datasourceUid should be type string" + ) + + @pytest.mark.anyio + async def test_tempo_tool_call_with_valid_datasource(self, mcp_client): + """Test calling a tempo tool with a valid datasourceUid.""" + + # Call docs-traceql which should return documentation (doesn't require data) + try: + call_response = await mcp_client.call_tool( + "tempo_docs-traceql", + arguments={"datasourceUid": "tempo", "name": "basic"}, + ) + + # Verify we got a response + assert call_response.content, "Tool should return content" + + # Should have text content (documentation) + response_text = call_response.content[0].text + assert len(response_text) > 0, "Response should have content" + assert "traceql" in response_text.lower(), ( + "Response should contain TraceQL documentation" + ) + print(response_text) + + except Exception as e: + # If this fails, it might be because Tempo doesn't have data yet + # but at least verify the error isn't about missing datasourceUid + error_msg = 
str(e).lower() + assert "datasourceuid" not in error_msg, ( + f"Should not fail due to datasourceUid parameter: {e}" + ) + print(error_msg) + + @pytest.mark.anyio + async def test_tempo_tool_call_missing_datasourceUid(self, mcp_client): + """Test that calling a tempo tool without datasourceUid fails appropriately.""" + + with pytest.raises(Exception) as exc_info: + await mcp_client.call_tool( + "tempo_docs-traceql", + arguments={"name": "basic"}, # Missing datasourceUid + ) + + error_msg = str(exc_info.value).lower() + assert "datasourceuid" in error_msg and "required" in error_msg, ( + f"Should require datasourceUid parameter: {exc_info.value}" + ) + + @pytest.mark.anyio + async def test_tempo_tool_call_invalid_datasourceUid(self, mcp_client): + """Test that calling a tempo tool with invalid datasourceUid returns helpful error.""" + + with pytest.raises(Exception) as exc_info: + await mcp_client.call_tool( + "tempo_docs-traceql", + arguments={"datasourceUid": "nonexistent-tempo", "name": "basic"}, + ) + + error_msg = str(exc_info.value).lower() + # Should mention that datasource wasn't found + assert "not found" in error_msg or "not accessible" in error_msg, ( + f"Should indicate datasource not found: {exc_info.value}" + ) + + # Should mention available datasources to help user + assert "tempo" in error_msg or "available" in error_msg, ( + f"Error should be helpful and mention available datasources: {exc_info.value}" + ) + + @pytest.mark.anyio + async def test_tempo_tool_works_with_multiple_datasources(self, mcp_client): + """Test that the same tool works with different datasources via datasourceUid.""" + + # Both tempo and tempo-secondary should be available in our test environment + datasources = ["tempo", "tempo-secondary"] + + for datasource_uid in datasources: + try: + # Call the same tool with different datasources + call_response = await mcp_client.call_tool( + "tempo_get-attribute-names", + arguments={"datasourceUid": datasource_uid}, + ) + + # Verify we 
got a response + assert call_response.content, ( + f"Tool should return content for datasource {datasource_uid}" + ) + + # Response should be valid JSON or text + response_text = call_response.content[0].text + assert len(response_text) > 0, ( + f"Response should have content for datasource {datasource_uid}" + ) + + except Exception as e: + # If this fails, it's acceptable if Tempo doesn't have trace data yet + # But verify it's not a routing/config error + error_msg = str(e).lower() + assert ( + "not found" not in error_msg or datasource_uid not in error_msg + ), f"Datasource {datasource_uid} should be accessible: {e}" diff --git a/tools/datasources_test.go b/tools/datasources_test.go index d58d0722..f4fdb041 100644 --- a/tools/datasources_test.go +++ b/tools/datasources_test.go @@ -68,8 +68,8 @@ func TestDatasourcesTools(t *testing.T) { ctx := newTestContext() result, err := listDatasources(ctx, ListDatasourcesParams{}) require.NoError(t, err) - // Four datasources are provisioned in the test environment (Prometheus, Loki, and Pyroscope). - assert.Len(t, result, 4) + // Six datasources are provisioned in the test environment (Prometheus, Prometheus Demo, Loki, Pyroscope, Tempo, and Tempo Secondary). 
+ assert.Len(t, result, 6) }) t.Run("list datasources for type", func(t *testing.T) { From b6fa3a2eca72c246b7731bc25f2780b4d0510567 Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Fri, 10 Oct 2025 13:14:02 +0100 Subject: [PATCH 2/4] Improve e2e tests --- tests/tempo_test.py | 61 +++++++++++++++++++++++++- tests/utils.py | 101 ++++++++++++++++++++++++++++++++++++++------ 2 files changed, 147 insertions(+), 15 deletions(-) diff --git a/tests/tempo_test.py b/tests/tempo_test.py index 00895096..dd6c9927 100644 --- a/tests/tempo_test.py +++ b/tests/tempo_test.py @@ -1,8 +1,23 @@ from mcp import ClientSession import pytest +from langevals import expect +from langevals_langevals.llm_boolean import ( + CustomLLMBooleanEvaluator, + CustomLLMBooleanSettings, +) +from litellm import Message, acompletion +from mcp import ClientSession + +from conftest import models +from utils import ( + get_converted_tools, + llm_tool_call_sequence, +) +pytestmark = pytest.mark.anyio -class TestTempoProxiedTools: + +class TestTempoProxiedToolsBasic: """Test Tempo proxied MCP tools functionality. These tests verify that Tempo datasources with MCP support are discovered @@ -189,3 +204,47 @@ async def test_tempo_tool_works_with_multiple_datasources(self, mcp_client): assert ( "not found" not in error_msg or datasource_uid not in error_msg ), f"Datasource {datasource_uid} should be accessible: {e}" + + +class TestTempoProxiedToolsWithLLM: + """LLM integration tests for Tempo proxied tools.""" + + @pytest.mark.parametrize("model", models) + @pytest.mark.flaky(max_runs=3) + async def test_llm_can_list_trace_attributes( + self, model: str, mcp_client: ClientSession + ): + """Test that an LLM can list available trace attributes from Tempo.""" + tools = await get_converted_tools(mcp_client) + # prompt = ( + # "Use the tempo tools to get a list of all available trace attribute names " + # "from the datasource with UID 'tempo'. I want to know what attributes " + # "I can use in my TraceQL queries." 
+ # ) + prompt = "what trace attributes are available in the tempo datasource?" + + messages = [ + Message(role="system", content="You are a helpful assistant."), + Message(role="user", content=prompt), + ] + + # LLM should call tempo_get-attribute-names with datasourceUid + messages = await llm_tool_call_sequence( + model, + messages, + tools, + mcp_client, + "tempo_get-attribute-names", + {"datasourceUid": "tempo"}, + ) + + # Final LLM response should mention attributes + response = await acompletion(model=model, messages=messages, tools=tools) + content = response.choices[0].message.content + + attributes_checker = CustomLLMBooleanEvaluator( + settings=CustomLLMBooleanSettings( + prompt="Does the response list or describe trace attributes that are available for querying?", + ) + ) + expect(input=prompt, output=content).to_pass(attributes_checker) diff --git a/tests/utils.py b/tests/utils.py index 39c9927d..0e8a3fe1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,6 +3,7 @@ from litellm import acompletion, Choices, Message from mcp.types import TextContent, Tool + async def assert_and_handle_tool_call( response: ModelResponse, mcp_client, @@ -15,9 +16,40 @@ async def assert_and_handle_tool_call( assert isinstance(c, Choices) tool_calls.extend(c.message.tool_calls or []) messages.append(c.message) - assert len(tool_calls) == 1 + + # Better error message if wrong number of tool calls + if len(tool_calls) != 1: + actual_calls = [tc.function.name for tc in tool_calls] if tool_calls else [] + assert len(tool_calls) == 1, ( + f"\nāŒ Expected exactly 1 tool call, got {len(tool_calls)}\n" + f"Expected tool: {expected_tool}\n" + f"Actual tools called: {actual_calls}\n" + f"LLM response: {response.choices[0].message.content if response.choices else 'N/A'}" + ) + for tool_call in tool_calls: - assert tool_call.function.name == expected_tool + actual_tool = tool_call.function.name + if actual_tool != expected_tool: + # Parse arguments to understand what LLM was trying 
to do + try: + actual_args = ( + json.loads(tool_call.function.arguments) + if tool_call.function.arguments + else {} + ) + except: + actual_args = tool_call.function.arguments + + assert False, ( + f"\nāŒ LLM called wrong tool!\n" + f"Expected: {expected_tool}\n" + f"Got: {actual_tool}\n" + f"With args: {json.dumps(actual_args, indent=2)}\n" + f"\nšŸ’” Debugging tips:\n" + f" - Check if the prompt clearly indicates which tool to use\n" + f" - Verify the expected tool exists in the available tools\n" + f" - Consider if the tool description is clear enough\n" + ) arguments = ( {} if len(tool_call.function.arguments) == 0 @@ -25,11 +57,26 @@ async def assert_and_handle_tool_call( ) if expected_args: for key, value in expected_args.items(): - assert key in arguments - assert arguments[key] == value + if key not in arguments: + assert False, ( + f"\nāŒ Missing expected parameter '{key}'\n" + f"Expected args: {json.dumps(expected_args, indent=2)}\n" + f"Actual args: {json.dumps(arguments, indent=2)}\n" + ) + if arguments[key] != value: + assert False, ( + f"\nāŒ Wrong value for parameter '{key}'\n" + f"Expected: {value}\n" + f"Got: {arguments[key]}\n" + f"Full args: {json.dumps(arguments, indent=2)}\n" + ) result = await mcp_client.call_tool(tool_call.function.name, arguments) - assert len(result.content) == 1 - assert isinstance(result.content[0], TextContent) + assert len(result.content) == 1, ( + f"Expected one result for tool {tool_call.function.name}, got {len(result.content)}" + ) + assert isinstance(result.content[0], TextContent), ( + f"Expected TextContent for tool {tool_call.function.name}, got {type(result.content[0])}" + ) messages.append( Message( role="tool", tool_call_id=tool_call.id, content=result.content[0].text @@ -37,6 +84,7 @@ async def assert_and_handle_tool_call( ) return messages + def convert_tool(tool: Tool) -> dict: return { "type": "function", @@ -50,15 +98,27 @@ def convert_tool(tool: Tool) -> dict: }, } + async def 
llm_tool_call_sequence( model, messages, tools, mcp_client, tool_name, tool_args=None ): + print(f"\nšŸ¤– Calling LLM ({model}) and expecting tool: {tool_name}") + print(f"šŸ“ Last message: {messages[-1].get('content', messages[-1])[:200]}...") + response = await acompletion( model=model, messages=messages, tools=tools, ) assert isinstance(response, ModelResponse) + + # Print what tool was actually called for debugging + if response.choices and response.choices[0].message.tool_calls: + actual_tool = response.choices[0].message.tool_calls[0].function.name + print(f"āœ… LLM called: {actual_tool}") + if actual_tool != tool_name: + print(f"āš ļø WARNING: Expected {tool_name} but got {actual_tool}") + messages.extend( await assert_and_handle_tool_call( response, mcp_client, tool_name, tool_args or {} @@ -66,12 +126,15 @@ async def llm_tool_call_sequence( ) return messages + async def get_converted_tools(mcp_client): tools = await mcp_client.list_tools() return [convert_tool(t) for t in tools.tools] -async def flexible_tool_call(model, messages, tools, mcp_client, expected_tool_name, required_params=None): +async def flexible_tool_call( + model, messages, tools, mcp_client, expected_tool_name, required_params=None +): """ Make a flexible tool call that only checks essential parameters. Returns updated messages list. 
@@ -90,11 +153,17 @@ async def flexible_tool_call(model, messages, tools, mcp_client, expected_tool_n response = await acompletion(model=model, messages=messages, tools=tools) # Check that a tool call was made - assert response.choices[0].message.tool_calls is not None, f"Expected tool call for {expected_tool_name}" - assert len(response.choices[0].message.tool_calls) >= 1, f"Expected at least one tool call for {expected_tool_name}" + assert response.choices[0].message.tool_calls is not None, ( + f"Expected tool call for {expected_tool_name}" + ) + assert len(response.choices[0].message.tool_calls) >= 1, ( + f"Expected at least one tool call for {expected_tool_name}" + ) tool_call = response.choices[0].message.tool_calls[0] - assert tool_call.function.name == expected_tool_name, f"Expected {expected_tool_name} tool, got {tool_call.function.name}" + assert tool_call.function.name == expected_tool_name, ( + f"Expected {expected_tool_name} tool, got {tool_call.function.name}" + ) arguments = json.loads(tool_call.function.arguments) @@ -103,7 +172,9 @@ async def flexible_tool_call(model, messages, tools, mcp_client, expected_tool_n for key, expected_value in required_params.items(): assert key in arguments, f"Expected parameter '{key}' in tool arguments" if expected_value is not None: - assert arguments[key] == expected_value, f"Expected {key}='{expected_value}', got {key}='{arguments.get(key)}'" + assert arguments[key] == expected_value, ( + f"Expected {key}='{expected_value}', got {key}='{arguments.get(key)}'" + ) # Call the tool to verify it works result = await mcp_client.call_tool(tool_call.function.name, arguments) @@ -112,6 +183,8 @@ async def flexible_tool_call(model, messages, tools, mcp_client, expected_tool_n # Add both the tool call and result to message history messages.append(response.choices[0].message) - messages.append(Message(role="tool", tool_call_id=tool_call.id, content=result.content[0].text)) - - return messages \ No newline at end of file + 
messages.append( + Message(role="tool", tool_call_id=tool_call.id, content=result.content[0].text) + ) + + return messages From a2eac79a4fc5f3e89424e9a429b5866710979ba2 Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Mon, 13 Oct 2025 08:43:51 +0100 Subject: [PATCH 3/4] Don't try to use per-session tools for stdio The stdio transport is single-session by default so 'per-session' methods don't work. Instead, just register the tools on the server at startup in this mode. --- cmd/mcp-grafana/main.go | 32 ++++++--- proxied_handler.go | 19 +++++- proxied_tools.go | 143 +++++++++++++++++++++++++++++++++------- session_test.go | 2 +- tests/tempo_test.py | 11 ++-- 5 files changed, 166 insertions(+), 41 deletions(-) diff --git a/cmd/mcp-grafana/main.go b/cmd/mcp-grafana/main.go index c31ef2ab..0a3deff5 100644 --- a/cmd/mcp-grafana/main.go +++ b/cmd/mcp-grafana/main.go @@ -103,11 +103,11 @@ func (dt *disabledTools) addTools(s *server.MCPServer) { maybeAddTools(s, tools.AddNavigationTools, enabledTools, dt.navigation, "navigation") } -func newServer(dt disabledTools) *server.MCPServer { +func newServer(transport string, dt disabledTools) (*server.MCPServer, *mcpgrafana.ToolManager) { sm := mcpgrafana.NewSessionManager() - // Declare variable for SessionToolManager that will be initialized after server creation - var stm *mcpgrafana.SessionToolManager + // Declare variable for ToolManager that will be initialized after server creation + var stm *mcpgrafana.ToolManager // Create hooks hooks := &server.Hooks{ @@ -115,8 +115,10 @@ func newServer(dt disabledTools) *server.MCPServer { OnUnregisterSession: []server.OnUnregisterSessionHookFunc{sm.RemoveSession}, } - // Add proxied tools hooks if enabled - if !dt.proxied { + // Add proxied tools hooks if enabled and we're not running in stdio mode. + // (stdio mode is handled by InitializeAndRegisterServerTools; per-session tools + // are not supported). 
+ if transport != "stdio" && !dt.proxied { // OnBeforeListTools: Discover, connect, and register tools hooks.OnBeforeListTools = []server.OnBeforeListToolsFunc{ func(ctx context.Context, id any, request *mcp.ListToolsRequest) { @@ -161,11 +163,11 @@ Note that some of these capabilities may be disabled. Do not try to use features server.WithHooks(hooks), ) - // Initialize SessionToolManager now that server is created - stm = mcpgrafana.NewSessionToolManager(sm, s, mcpgrafana.WithProxiedTools(!dt.proxied)) + // Initialize ToolManager now that server is created + stm = mcpgrafana.NewToolManager(sm, s, mcpgrafana.WithProxiedTools(!dt.proxied)) dt.addTools(s) - return s + return s, stm } type tlsConfig struct { @@ -227,7 +229,7 @@ func runHTTPServer(ctx context.Context, srv httpServer, addr, transportName stri func run(transport, addr, basePath, endpointPath string, logLevel slog.Level, dt disabledTools, gc mcpgrafana.GrafanaConfig, tls tlsConfig) error { slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: logLevel}))) - s := newServer(dt) + s, tm := newServer(transport, dt) // Create a context that will be cancelled on shutdown ctx, cancel := context.WithCancel(context.Background()) @@ -254,7 +256,17 @@ func run(transport, addr, basePath, endpointPath string, logLevel slog.Level, dt switch transport { case "stdio": srv := server.NewStdioServer(s) - srv.SetContextFunc(mcpgrafana.ComposedStdioContextFunc(gc)) + cf := mcpgrafana.ComposedStdioContextFunc(gc) + srv.SetContextFunc(cf) + + // For stdio (single-tenant), initialize proxied tools on the server directly + if !dt.proxied { + stdioCtx := cf(ctx) + if err := tm.InitializeAndRegisterServerTools(stdioCtx); err != nil { + slog.Error("failed to initialize proxied tools for stdio", "error", err) + } + } + slog.Info("Starting Grafana MCP server using stdio transport", "version", mcpgrafana.Version()) err := srv.Listen(ctx, os.Stdin, os.Stdout) diff --git a/proxied_handler.go 
b/proxied_handler.go index 77a5ee39..06fe14ee 100644 --- a/proxied_handler.go +++ b/proxied_handler.go @@ -11,13 +11,15 @@ import ( // ProxiedToolHandler implements the CallToolHandler interface for proxied tools type ProxiedToolHandler struct { sessionManager *SessionManager + toolManager *ToolManager toolName string } // NewProxiedToolHandler creates a new handler for a proxied tool -func NewProxiedToolHandler(sm *SessionManager, toolName string) *ProxiedToolHandler { +func NewProxiedToolHandler(sm *SessionManager, tm *ToolManager, toolName string) *ProxiedToolHandler { return &ProxiedToolHandler{ sessionManager: sm, + toolManager: tm, toolName: toolName, } } @@ -54,7 +56,20 @@ func (h *ProxiedToolHandler) Handle(ctx context.Context, request mcp.CallToolReq } // Get the proxied client for this datasource - client, err := h.sessionManager.GetProxiedClient(ctx, datasourceType, datasourceUID) + var client *ProxiedClient + + if h.toolManager.serverMode { + // Server mode (stdio): clients stored at manager level + client, err = h.toolManager.GetServerClient(datasourceType, datasourceUID) + } else { + // Session mode (HTTP/SSE): clients stored per-session + client, err = h.sessionManager.GetProxiedClient(ctx, datasourceType, datasourceUID) + if err != nil { + // Fallback to server-level in case of mixed mode + client, err = h.toolManager.GetServerClient(datasourceType, datasourceUID) + } + } + if err != nil { return nil, fmt.Errorf("datasource '%s' not found or not accessible. 
Ensure the datasource exists and you have permission to access it", datasourceUID) } diff --git a/proxied_tools.go b/proxied_tools.go index 7ea7a936..158a1abb 100644 --- a/proxied_tools.go +++ b/proxied_tools.go @@ -6,6 +6,7 @@ import ( "log/slog" "net/http" "strings" + "sync" "github.com/go-openapi/runtime" "github.com/mark3labs/mcp-go/mcp" @@ -127,49 +128,122 @@ func parseProxiedToolName(toolName string) (string, string, error) { return parts[0], parts[1], nil } -// SessionToolManager manages session-specific proxied tools -type SessionToolManager struct { +// ToolManager manages proxied tools (either per-session or server-wide) +type ToolManager struct { sm *SessionManager server *server.MCPServer // Whether to enable proxied tools. enableProxiedTools bool + + // For stdio transport: store clients at manager level (single-tenant). + // These will be unused for HTTP/SSE transports. + serverMode bool // true if using server-wide tools (stdio), false for per-session (HTTP/SSE) + serverClients map[string]*ProxiedClient + clientsMutex sync.RWMutex } -// NewSessionToolManager creates a new SessionToolManager -func NewSessionToolManager(sm *SessionManager, mcpServer *server.MCPServer, opts ...sessionToolManagerOption) *SessionToolManager { - stm := &SessionToolManager{ - sm: sm, - server: mcpServer, +// NewToolManager creates a new ToolManager +func NewToolManager(sm *SessionManager, mcpServer *server.MCPServer, opts ...toolManagerOption) *ToolManager { + tm := &ToolManager{ + sm: sm, + server: mcpServer, + serverClients: make(map[string]*ProxiedClient), } for _, opt := range opts { - opt(stm) + opt(tm) } - return stm + return tm } -type sessionToolManagerOption func(*SessionToolManager) +type toolManagerOption func(*ToolManager) // WithProxiedTools sets whether proxied tools are enabled -func WithProxiedTools(enabled bool) sessionToolManagerOption { - return func(stm *SessionToolManager) { - stm.enableProxiedTools = enabled +func WithProxiedTools(enabled bool) 
toolManagerOption { + return func(tm *ToolManager) { + tm.enableProxiedTools = enabled + } +} + +// InitializeAndRegisterServerTools discovers datasources and registers tools on the server (for stdio transport) +// This should be called once at server startup for single-tenant stdio servers +func (tm *ToolManager) InitializeAndRegisterServerTools(ctx context.Context) error { + if !tm.enableProxiedTools { + return nil + } + + // Mark as server mode (stdio transport) + tm.serverMode = true + + // Discover datasources with MCP support + discovered, err := discoverMCPDatasources(ctx) + if err != nil { + return fmt.Errorf("failed to discover MCP datasources: %w", err) + } + + if len(discovered) == 0 { + slog.Info("no MCP datasources discovered") + return nil + } + + // Connect to each datasource and store in manager + tm.clientsMutex.Lock() + for _, ds := range discovered { + client, err := NewProxiedClient(ctx, ds.UID, ds.Name, ds.Type, ds.MCPURL) + if err != nil { + slog.Error("failed to create proxied client", "datasource", ds.UID, "error", err) + continue + } + key := ds.Type + "_" + ds.UID + tm.serverClients[key] = client + } + clientCount := len(tm.serverClients) + tm.clientsMutex.Unlock() + + if clientCount == 0 { + slog.Warn("no proxied clients created") + return nil + } + + slog.Info("connected to proxied MCP servers", "datasources", clientCount) + + // Collect and register all unique tools + tm.clientsMutex.RLock() + toolMap := make(map[string]mcp.Tool) + for _, client := range tm.serverClients { + for _, tool := range client.ListTools() { + toolName := client.DatasourceType + "_" + tool.Name + if _, exists := toolMap[toolName]; !exists { + modifiedTool := addDatasourceUidParameter(tool, client.DatasourceType) + toolMap[toolName] = modifiedTool + } + } } + tm.clientsMutex.RUnlock() + + // Register tools on the server (not per-session) + for toolName, tool := range toolMap { + handler := NewProxiedToolHandler(tm.sm, tm, toolName) + tm.server.AddTool(tool, 
handler.Handle) + } + + slog.Info("registered proxied tools on server", "tools", len(toolMap)) + return nil } -// InitializeAndRegisterProxiedTools discovers datasources, creates clients, and registers tools -// This should be called in OnBeforeListTools and OnBeforeCallTool hooks -func (stm *SessionToolManager) InitializeAndRegisterProxiedTools(ctx context.Context, session server.ClientSession) { - if !stm.enableProxiedTools { +// InitializeAndRegisterProxiedTools discovers datasources, creates clients, and registers tools per-session +// This should be called in OnBeforeListTools and OnBeforeCallTool hooks for HTTP/SSE transports +func (tm *ToolManager) InitializeAndRegisterProxiedTools(ctx context.Context, session server.ClientSession) { + if !tm.enableProxiedTools { return } sessionID := session.SessionID() - state, exists := stm.sm.GetSession(sessionID) + state, exists := tm.sm.GetSession(sessionID) if !exists { // Session exists in server context but not in our SessionManager yet - stm.sm.CreateSession(ctx, session) - state, exists = stm.sm.GetSession(sessionID) + tm.sm.CreateSession(ctx, session) + state, exists = tm.sm.GetSession(sessionID) if !exists { slog.Error("failed to create session in SessionManager", "sessionID", sessionID) return @@ -257,7 +331,7 @@ func (stm *SessionToolManager) InitializeAndRegisterProxiedTools(ctx context.Con // Second pass: register all unique tools at once (reduces listChanged notifications) var serverTools []server.ServerTool for toolName, tool := range toolMap { - handler := NewProxiedToolHandler(stm.sm, toolName) + handler := NewProxiedToolHandler(tm.sm, tm, toolName) serverTools = append(serverTools, server.ServerTool{ Tool: tool, Handler: handler.Handle, @@ -265,9 +339,34 @@ func (stm *SessionToolManager) InitializeAndRegisterProxiedTools(ctx context.Con state.proxiedTools = append(state.proxiedTools, tool) } - if err := stm.server.AddSessionTools(sessionID, serverTools...); err != nil { + if err := 
tm.server.AddSessionTools(sessionID, serverTools...); err != nil { slog.Warn("failed to add session tools", "session", sessionID, "error", err) } else { slog.Info("registered proxied tools", "session", sessionID, "tools", len(state.proxiedTools)) } } + +// GetServerClient retrieves a proxied client from server-level storage (for stdio transport) +func (tm *ToolManager) GetServerClient(datasourceType, datasourceUID string) (*ProxiedClient, error) { + tm.clientsMutex.RLock() + defer tm.clientsMutex.RUnlock() + + key := datasourceType + "_" + datasourceUID + client, exists := tm.serverClients[key] + if !exists { + // List available datasources to help with debugging + var availableUIDs []string + for _, c := range tm.serverClients { + if c.DatasourceType == datasourceType { + availableUIDs = append(availableUIDs, c.DatasourceUID) + } + } + + if len(availableUIDs) > 0 { + return nil, fmt.Errorf("datasource '%s' not found. Available %s datasources: %v", datasourceUID, datasourceType, availableUIDs) + } + return nil, fmt.Errorf("datasource '%s' not found. 
No %s datasources with MCP support are configured", datasourceUID, datasourceType) + } + + return client, nil +} diff --git a/session_test.go b/session_test.go index 6625bb93..3b0b8e4d 100644 --- a/session_test.go +++ b/session_test.go @@ -330,7 +330,7 @@ func TestEndToEndProxiedToolsFlow(t *testing.T) { assert.Equal(t, client, retrievedClient, "Should retrieve the same client from session") // Step 7: Test ProxiedToolHandler flow - handler := NewProxiedToolHandler(sm, modifiedTool.Name) + handler := NewProxiedToolHandler(sm, nil, modifiedTool.Name) assert.NotNil(t, handler) // Note: We can't actually call the tool without knowing what arguments it expects diff --git a/tests/tempo_test.py b/tests/tempo_test.py index dd6c9927..521ec0bb 100644 --- a/tests/tempo_test.py +++ b/tests/tempo_test.py @@ -216,12 +216,11 @@ async def test_llm_can_list_trace_attributes( ): """Test that an LLM can list available trace attributes from Tempo.""" tools = await get_converted_tools(mcp_client) - # prompt = ( - # "Use the tempo tools to get a list of all available trace attribute names " - # "from the datasource with UID 'tempo'. I want to know what attributes " - # "I can use in my TraceQL queries." - # ) - prompt = "what trace attributes are available in the tempo datasource?" + prompt = ( + "Use the tempo tools to get a list of all available trace attribute names " + "from the datasource with UID 'tempo'. I want to know what attributes " + "I can use in my TraceQL queries." 
+ ) messages = [ Message(role="system", content="You are a helpful assistant."), From 95d36b872fb21a1fa4dedb6d4f2266f13d20fd6e Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Mon, 20 Oct 2025 08:14:17 +0100 Subject: [PATCH 4/4] Use sync.Once to handle proxied tool initialization --- proxied_tools.go | 17 +- proxied_tools_test.go | 465 ++++++++++++++++++++++++++++++++++++++++++ session.go | 1 + session_test.go | 108 +++++++--- 4 files changed, 551 insertions(+), 40 deletions(-) create mode 100644 proxied_tools_test.go diff --git a/proxied_tools.go b/proxied_tools.go index 158a1abb..47b653ca 100644 --- a/proxied_tools.go +++ b/proxied_tools.go @@ -250,19 +250,8 @@ func (tm *ToolManager) InitializeAndRegisterProxiedTools(ctx context.Context, se } } - state.mutex.Lock() - // Check if already initialized and registered - if state.proxiedToolsInitialized && len(state.proxiedTools) > 0 { - state.mutex.Unlock() - return - } - - // If already initialized but not registered, skip discovery - alreadyDiscovered := state.proxiedToolsInitialized - state.mutex.Unlock() - - // Step 1: Discover and connect (if not already done) - if !alreadyDiscovered { + // Step 1: Discover and connect (guaranteed to run exactly once per session) + state.initOnce.Do(func() { // Discover datasources with MCP support discovered, err := discoverMCPDatasources(ctx) if err != nil { @@ -290,7 +279,7 @@ func (tm *ToolManager) InitializeAndRegisterProxiedTools(ctx context.Context, se state.mutex.Unlock() slog.Info("connected to proxied MCP servers", "session", sessionID, "datasources", len(state.proxiedClients)) - } + }) // Step 2: Register tools with the MCP server state.mutex.Lock() diff --git a/proxied_tools_test.go b/proxied_tools_test.go new file mode 100644 index 00000000..9dcdc3fd --- /dev/null +++ b/proxied_tools_test.go @@ -0,0 +1,465 @@ +package mcpgrafana + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + 
"github.com/stretchr/testify/assert" +) + +func TestSessionStateRaceConditions(t *testing.T) { + t.Run("concurrent initialization with sync.Once is safe", func(t *testing.T) { + state := newSessionState() + + var initCounter int32 + var wg sync.WaitGroup + + // Launch 100 goroutines that all try to initialize at once + const numGoroutines = 100 + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + state.initOnce.Do(func() { + // Simulate initialization work + atomic.AddInt32(&initCounter, 1) + time.Sleep(10 * time.Millisecond) // Simulate some work + state.mutex.Lock() + state.proxiedToolsInitialized = true + state.mutex.Unlock() + }) + }() + } + + wg.Wait() + + // Verify initialization happened exactly once + assert.Equal(t, int32(1), atomic.LoadInt32(&initCounter), + "Initialization should run exactly once despite 100 concurrent calls") + assert.True(t, state.proxiedToolsInitialized) + }) + + t.Run("concurrent reads and writes with mutex protection", func(t *testing.T) { + state := newSessionState() + var wg sync.WaitGroup + + // Writer goroutines + for i := 0; i < 10; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + state.mutex.Lock() + key := "tempo_" + string(rune('a'+id)) + state.proxiedClients[key] = &ProxiedClient{ + DatasourceUID: key, + DatasourceName: "Test " + key, + DatasourceType: "tempo", + } + state.mutex.Unlock() + }(i) + } + + // Reader goroutines + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + state.mutex.RLock() + _ = len(state.proxiedClients) + state.mutex.RUnlock() + }() + } + + wg.Wait() + + // Verify all writes succeeded + state.mutex.RLock() + count := len(state.proxiedClients) + state.mutex.RUnlock() + + assert.Equal(t, 10, count, "All 10 clients should be stored") + }) + + t.Run("concurrent tool registration is safe", func(t *testing.T) { + state := newSessionState() + var wg sync.WaitGroup + + // Multiple goroutines trying to register tools + const 
numGoroutines = 50 + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + state.mutex.Lock() + toolName := "tempo_tool-" + string(rune('a'+id%26)) + if state.toolToDatasources[toolName] == nil { + state.toolToDatasources[toolName] = []string{} + } + state.toolToDatasources[toolName] = append( + state.toolToDatasources[toolName], + "datasource_"+string(rune('a'+id%26)), + ) + state.mutex.Unlock() + }(i) + } + + wg.Wait() + + // Verify the tool mappings exist + state.mutex.RLock() + defer state.mutex.RUnlock() + assert.Greater(t, len(state.toolToDatasources), 0, "Should have tool mappings") + }) +} + +func TestSessionManagerConcurrency(t *testing.T) { + t.Run("concurrent session creation is safe", func(t *testing.T) { + sm := NewSessionManager() + var wg sync.WaitGroup + + // Create many sessions concurrently + const numSessions = 100 + wg.Add(numSessions) + + for i := 0; i < numSessions; i++ { + go func(id int) { + defer wg.Done() + sessionID := "session-" + string(rune('a'+id%26)) + "-" + string(rune('0'+id/26)) + mockSession := &mockClientSession{id: sessionID} + sm.CreateSession(context.Background(), mockSession) + }(i) + } + + wg.Wait() + + // Verify all sessions were created + sm.mutex.RLock() + count := len(sm.sessions) + sm.mutex.RUnlock() + + assert.Equal(t, numSessions, count, "All sessions should be created") + }) + + t.Run("concurrent get and remove is safe", func(t *testing.T) { + sm := NewSessionManager() + + // Pre-populate sessions + for i := 0; i < 50; i++ { + sessionID := "session-" + string(rune('a'+i%26)) + mockSession := &mockClientSession{id: sessionID} + sm.CreateSession(context.Background(), mockSession) + } + + var wg sync.WaitGroup + + // Readers + for i := 0; i < 50; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + sessionID := "session-" + string(rune('a'+id%26)) + _, _ = sm.GetSession(sessionID) + }(i) + } + + // Writers (removers) + for i := 0; i < 25; i++ { + wg.Add(1) + go 
func(id int) { + defer wg.Done() + sessionID := "session-" + string(rune('a'+id%26)) + mockSession := &mockClientSession{id: sessionID} + sm.RemoveSession(context.Background(), mockSession) + }(i) + } + + wg.Wait() + + // Test passed if no race conditions occurred + }) +} + +func TestInitOncePattern(t *testing.T) { + t.Run("verify sync.Once guarantees single execution", func(t *testing.T) { + var once sync.Once + var counter int32 + var wg sync.WaitGroup + + // Simulate what happens in InitializeAndRegisterProxiedTools + initFunc := func() { + atomic.AddInt32(&counter, 1) + // Simulate expensive initialization + time.Sleep(50 * time.Millisecond) + } + + // Launch many concurrent calls + for i := 0; i < 1000; i++ { + wg.Add(1) + go func() { + defer wg.Done() + once.Do(initFunc) + }() + } + + wg.Wait() + + assert.Equal(t, int32(1), atomic.LoadInt32(&counter), + "sync.Once should guarantee function runs exactly once") + }) + + t.Run("sync.Once with different functions only runs first", func(t *testing.T) { + var once sync.Once + var result string + var mu sync.Mutex + + once.Do(func() { + mu.Lock() + result = "first" + mu.Unlock() + }) + + once.Do(func() { + mu.Lock() + result = "second" + mu.Unlock() + }) + + mu.Lock() + finalResult := result + mu.Unlock() + + assert.Equal(t, "first", finalResult, "Only first function should execute") + }) +} + +func TestProxiedToolsInitializationFlow(t *testing.T) { + t.Run("initialization state transitions are correct", func(t *testing.T) { + state := newSessionState() + + // Initial state + assert.False(t, state.proxiedToolsInitialized) + assert.Empty(t, state.proxiedClients) + assert.Empty(t, state.proxiedTools) + + // Simulate initialization + state.initOnce.Do(func() { + state.mutex.Lock() + state.proxiedToolsInitialized = true + state.proxiedClients["tempo_test"] = &ProxiedClient{ + DatasourceUID: "test", + DatasourceName: "Test", + DatasourceType: "tempo", + } + state.mutex.Unlock() + }) + + // Verify state after 
initialization + state.mutex.RLock() + initialized := state.proxiedToolsInitialized + clientCount := len(state.proxiedClients) + state.mutex.RUnlock() + + assert.True(t, initialized) + assert.Equal(t, 1, clientCount) + }) + + t.Run("multiple sessions maintain separate state", func(t *testing.T) { + sm := NewSessionManager() + + // Create two sessions + session1 := &mockClientSession{id: "session-1"} + session2 := &mockClientSession{id: "session-2"} + + sm.CreateSession(context.Background(), session1) + sm.CreateSession(context.Background(), session2) + + state1, _ := sm.GetSession("session-1") + state2, _ := sm.GetSession("session-2") + + // Initialize only session1 + state1.initOnce.Do(func() { + state1.mutex.Lock() + state1.proxiedToolsInitialized = true + state1.mutex.Unlock() + }) + + // Verify states are independent + assert.True(t, state1.proxiedToolsInitialized) + assert.False(t, state2.proxiedToolsInitialized) + assert.NotSame(t, state1, state2) + }) +} + +func TestRaceConditionDemonstration(t *testing.T) { + t.Run("old pattern WITHOUT sync.Once would have race condition", func(t *testing.T) { + // This test demonstrates what WOULD happen with the old mutex-check pattern + state := newSessionState() + + var discoveryCallCount int32 + var wg sync.WaitGroup + + // Simulate the OLD pattern (mutex check, unlock, then do work) + oldPatternInitialize := func() { + state.mutex.Lock() + // Check if already initialized + if state.proxiedToolsInitialized { + state.mutex.Unlock() + return + } + alreadyDiscovered := state.proxiedToolsInitialized + state.mutex.Unlock() // āŒ OLD PATTERN: Unlock before expensive work + + if !alreadyDiscovered { + // Simulate discovery work that should only happen once + atomic.AddInt32(&discoveryCallCount, 1) + time.Sleep(10 * time.Millisecond) // Simulate expensive operation + + state.mutex.Lock() + state.proxiedToolsInitialized = true + state.mutex.Unlock() + } + } + + // Launch concurrent initializations + const numGoroutines = 10 + 
wg.Add(numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + oldPatternInitialize() + }() + } + wg.Wait() + + // With the old pattern, multiple goroutines can get past the check + // and call discovery multiple times + count := atomic.LoadInt32(&discoveryCallCount) + if count > 1 { + t.Logf("OLD PATTERN: Discovery called %d times (race condition!)", count) + } + // We can't assert > 1 reliably because timing matters, but this demonstrates the problem + }) + + t.Run("new pattern WITH sync.Once prevents race condition", func(t *testing.T) { + // This test demonstrates the NEW pattern with sync.Once + state := newSessionState() + + var discoveryCallCount int32 + var wg sync.WaitGroup + + // NEW pattern: sync.Once guarantees single execution + newPatternInitialize := func() { + state.initOnce.Do(func() { + // Simulate discovery work that should only happen once + atomic.AddInt32(&discoveryCallCount, 1) + time.Sleep(10 * time.Millisecond) // Simulate expensive operation + + state.mutex.Lock() + state.proxiedToolsInitialized = true + state.mutex.Unlock() + }) + } + + // Launch concurrent initializations + const numGoroutines = 10 + wg.Add(numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + newPatternInitialize() + }() + } + wg.Wait() + + // With sync.Once, discovery is guaranteed to run exactly once + count := atomic.LoadInt32(&discoveryCallCount) + assert.Equal(t, int32(1), count, "NEW PATTERN: Discovery must be called exactly once") + }) +} + +func TestRaceDetector(t *testing.T) { + // This test is primarily valuable when run with -race flag + t.Run("stress test with race detector", func(t *testing.T) { + + sm := NewSessionManager() + var wg sync.WaitGroup + + // Create a mix of operations happening concurrently + for i := 0; i < 20; i++ { + sessionID := "stress-session-" + string(rune('a'+i%10)) + + // Create session + wg.Add(1) + go func(sid string) { + defer wg.Done() + mockSession := 
&mockClientSession{id: sid} + sm.CreateSession(context.Background(), mockSession) + }(sessionID) + + // Initialize session state + wg.Add(1) + go func(sid string) { + defer wg.Done() + time.Sleep(time.Millisecond) // Let creation happen first + state, exists := sm.GetSession(sid) + if exists { + state.initOnce.Do(func() { + state.mutex.Lock() + state.proxiedToolsInitialized = true + state.mutex.Unlock() + }) + } + }(sessionID) + + // Read session state + wg.Add(1) + go func(sid string) { + defer wg.Done() + time.Sleep(2 * time.Millisecond) + state, exists := sm.GetSession(sid) + if exists { + state.mutex.RLock() + _ = state.proxiedToolsInitialized + state.mutex.RUnlock() + } + }(sessionID) + } + + wg.Wait() + + // If we get here without race detector warnings, we're good + t.Log("Stress test completed without race conditions") + }) +} + +// mockClientSession implements server.ClientSession for testing +type mockClientSession struct { + id string + notifChannel chan mcp.JSONRPCNotification + isInitialized bool +} + +func (m *mockClientSession) SessionID() string { + return m.id +} + +func (m *mockClientSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + if m.notifChannel == nil { + m.notifChannel = make(chan mcp.JSONRPCNotification, 10) + } + return m.notifChannel +} + +func (m *mockClientSession) Initialize() { + m.isInitialized = true +} + +func (m *mockClientSession) Initialized() bool { + return m.isInitialized +} diff --git a/session.go b/session.go index e1acfa67..d1005e37 100644 --- a/session.go +++ b/session.go @@ -13,6 +13,7 @@ import ( // SessionState holds the state for a single client session type SessionState struct { // Proxied tools state + initOnce sync.Once proxiedToolsInitialized bool proxiedTools []mcp.Tool proxiedClients map[string]*ProxiedClient // key: datasourceType_datasourceUID diff --git a/session_test.go b/session_test.go index 3b0b8e4d..040e28a8 100644 --- a/session_test.go +++ b/session_test.go @@ -12,6 +12,7 @@ import ( 
"net/url" "os" "strings" + "sync" "testing" "github.com/go-openapi/strfmt" @@ -229,6 +230,87 @@ func TestSessionStateLifecycle(t *testing.T) { }) } +func TestConcurrentInitializationRaceCondition(t *testing.T) { + t.Run("concurrent initialization calls should be safe", func(t *testing.T) { + sm := NewSessionManager() + mockSession := &mockClientSession{id: "race-test-session"} + sm.CreateSession(context.Background(), mockSession) + + state, exists := sm.GetSession("race-test-session") + require.True(t, exists) + + // Track how many times the initialization logic runs + var initCount int + var initCountMutex sync.Mutex + + // Create a custom initOnce to track calls + state.initOnce = sync.Once{} + + // Simulate the initialization work that should run exactly once + initWork := func() { + initCountMutex.Lock() + initCount++ + initCountMutex.Unlock() + // Simulate some work + state.mutex.Lock() + state.proxiedToolsInitialized = true + state.proxiedClients["tempo_test"] = &ProxiedClient{ + DatasourceUID: "test", + DatasourceName: "Test", + DatasourceType: "tempo", + } + state.mutex.Unlock() + } + + // Launch multiple goroutines that all try to initialize concurrently + const numGoroutines = 10 + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + // This should be the pattern used in InitializeAndRegisterProxiedTools + state.initOnce.Do(initWork) + }() + } + + wg.Wait() + + // Verify initialization ran exactly once + assert.Equal(t, 1, initCount, "Initialization should run exactly once despite concurrent calls") + assert.True(t, state.proxiedToolsInitialized, "State should be initialized") + assert.Len(t, state.proxiedClients, 1, "Should have exactly one client") + }) + + t.Run("sync.Once prevents double initialization", func(t *testing.T) { + sm := NewSessionManager() + mockSession := &mockClientSession{id: "double-init-test"} + sm.CreateSession(context.Background(), mockSession) + + state, _ := 
sm.GetSession("double-init-test") + + callCount := 0 + + // First call + state.initOnce.Do(func() { + callCount++ + }) + + // Second call should not execute + state.initOnce.Do(func() { + callCount++ + }) + + // Third call should also not execute + state.initOnce.Do(func() { + callCount++ + }) + + assert.Equal(t, 1, callCount, "sync.Once should ensure function runs exactly once") + }) +} + func TestProxiedClientLifecycle(t *testing.T) { ctx := newProxiedToolsTestContext(t) @@ -391,29 +473,3 @@ func TestEndToEndProxiedToolsFlow(t *testing.T) { t.Logf("Successfully managed %d datasources in single session", connectedCount) }) } - -// mockClientSession implements server.ClientSession for testing -type mockClientSession struct { - id string - notifChannel chan mcp.JSONRPCNotification - isInitialized bool -} - -func (m *mockClientSession) SessionID() string { - return m.id -} - -func (m *mockClientSession) NotificationChannel() chan<- mcp.JSONRPCNotification { - if m.notifChannel == nil { - m.notifChannel = make(chan mcp.JSONRPCNotification, 10) - } - return m.notifChannel -} - -func (m *mockClientSession) Initialize() { - m.isInitialized = true -} - -func (m *mockClientSession) Initialized() bool { - return m.isInitialized -}