diff --git a/internal/logger/markdown_logger.go b/internal/logger/markdown_logger.go index 9831c6ce..36db1786 100644 --- a/internal/logger/markdown_logger.go +++ b/internal/logger/markdown_logger.go @@ -149,60 +149,38 @@ func (ml *MarkdownLogger) Log(level LogLevel, category, format string, args ...i // Global logging functions that also write to markdown logger -// LogInfoMd logs to both regular and markdown loggers -func LogInfoMd(category, format string, args ...interface{}) { +// logWithMarkdown is a helper that logs to both regular and markdown loggers +func logWithMarkdown(level LogLevel, regularLogFunc func(string, string, ...interface{}), category, format string, args ...interface{}) { // Log to regular logger - LogInfo(category, format, args...) + regularLogFunc(category, format, args...) // Log to markdown logger globalMarkdownMu.RLock() defer globalMarkdownMu.RUnlock() if globalMarkdownLogger != nil { - globalMarkdownLogger.Log(LogLevelInfo, category, format, args...) + globalMarkdownLogger.Log(level, category, format, args...) } } +// LogInfoMd logs to both regular and markdown loggers +func LogInfoMd(category, format string, args ...interface{}) { + logWithMarkdown(LogLevelInfo, LogInfo, category, format, args...) +} + // LogWarnMd logs to both regular and markdown loggers func LogWarnMd(category, format string, args ...interface{}) { - // Log to regular logger - LogWarn(category, format, args...) - - // Log to markdown logger - globalMarkdownMu.RLock() - defer globalMarkdownMu.RUnlock() - - if globalMarkdownLogger != nil { - globalMarkdownLogger.Log(LogLevelWarn, category, format, args...) - } + logWithMarkdown(LogLevelWarn, LogWarn, category, format, args...) } // LogErrorMd logs to both regular and markdown loggers func LogErrorMd(category, format string, args ...interface{}) { - // Log to regular logger - LogError(category, format, args...) - - // Log to markdown logger - globalMarkdownMu.RLock() - defer globalMarkdownMu.RUnlock() - - if globalMarkdownLogger != nil { - globalMarkdownLogger.Log(LogLevelError, category, format, args...) - } + logWithMarkdown(LogLevelError, LogError, category, format, args...) } // LogDebugMd logs to both regular and markdown loggers func LogDebugMd(category, format string, args ...interface{}) { - // Log to regular logger - LogDebug(category, format, args...) - - // Log to markdown logger - globalMarkdownMu.RLock() - defer globalMarkdownMu.RUnlock() - - if globalMarkdownLogger != nil { - globalMarkdownLogger.Log(LogLevelDebug, category, format, args...) - } + logWithMarkdown(LogLevelDebug, LogDebug, category, format, args...) } // CloseMarkdownLogger closes the global markdown logger diff --git a/internal/logger/rpc_logger.go b/internal/logger/rpc_logger.go index 0b77ad6a..0081a9ac 100644 --- a/internal/logger/rpc_logger.go +++ b/internal/logger/rpc_logger.go @@ -243,51 +243,15 @@ func formatRPCMessageMarkdown(info *RPCMessageInfo) string { return message } -// LogRPCRequest logs an RPC request message to text, markdown, and JSONL logs -func LogRPCRequest(direction RPCMessageDirection, serverID, method string, payload []byte) { +// logRPCMessageToAll is a helper that logs RPC messages to text, markdown, and JSONL logs +func logRPCMessageToAll(direction RPCMessageDirection, messageType RPCMessageType, serverID, method string, payload []byte, err error) { // Create info for text log (with larger payload preview) infoText := &RPCMessageInfo{ Direction: direction, - MessageType: RPCMessageRequest, - ServerID: serverID, - Method: method, - PayloadSize: len(payload), - Payload: truncateAndSanitize(string(payload), MaxPayloadPreviewLengthText), - } - - // Log to text file - LogDebug("rpc", "%s", formatRPCMessage(infoText)) - - // Create info for markdown log (with shorter payload preview) - infoMarkdown := &RPCMessageInfo{ - Direction: direction, - MessageType: RPCMessageRequest, + MessageType: messageType, ServerID: serverID, Method: method, PayloadSize: len(payload), - Payload: truncateAndSanitize(string(payload), MaxPayloadPreviewLengthMarkdown), - } - - // Log to markdown file - globalMarkdownMu.RLock() - defer globalMarkdownMu.RUnlock() - - if globalMarkdownLogger != nil { - globalMarkdownLogger.Log(LogLevelDebug, "rpc", "%s", formatRPCMessageMarkdown(infoMarkdown)) - } - - // Log to JSONL file (full payload, sanitized) - LogRPCMessageJSONL(direction, RPCMessageRequest, serverID, method, payload, nil) -} - -// LogRPCResponse logs an RPC response message to text, markdown, and JSONL logs -func LogRPCResponse(direction RPCMessageDirection, serverID string, payload []byte, err error) { - // Create info for text log (with larger payload preview) - infoText := &RPCMessageInfo{ - Direction: direction, - MessageType: RPCMessageResponse, - ServerID: serverID, - PayloadSize: len(payload), Payload: truncateAndSanitize(string(payload), MaxPayloadPreviewLengthText), } @@ -301,8 +265,9 @@ func LogRPCResponse(direction RPCMessageDirection, serverID string, payload []by // Create info for markdown log (with shorter payload preview) infoMarkdown := &RPCMessageInfo{ Direction: direction, - MessageType: RPCMessageResponse, + MessageType: messageType, ServerID: serverID, + Method: method, PayloadSize: len(payload), Payload: truncateAndSanitize(string(payload), MaxPayloadPreviewLengthMarkdown), } @@ -320,7 +285,17 @@ func LogRPCResponse(direction RPCMessageDirection, serverID string, payload []by } // Log to JSONL file (full payload, sanitized) - LogRPCMessageJSONL(direction, RPCMessageResponse, serverID, "", payload, err) + LogRPCMessageJSONL(direction, messageType, serverID, method, payload, err) +} + +// LogRPCRequest logs an RPC request message to text, markdown, and JSONL logs +func LogRPCRequest(direction RPCMessageDirection, serverID, method string, payload []byte) { + logRPCMessageToAll(direction, RPCMessageRequest, serverID, method, payload, nil) +} + +// LogRPCResponse logs an RPC response message to text, markdown, and JSONL logs +func LogRPCResponse(direction RPCMessageDirection, serverID string, payload []byte, err error) { + logRPCMessageToAll(direction, RPCMessageResponse, serverID, "", payload, err) } // LogRPCMessage logs a generic RPC message with custom info diff --git a/internal/server/unified.go b/internal/server/unified.go index 474b6553..51e33101 100644 --- a/internal/server/unified.go +++ b/internal/server/unified.go @@ -235,15 +235,10 @@ func (us *UnifiedServer) registerToolsFromBackend(serverID string) error { // Create the handler function handler := func(ctx context.Context, req *sdk.CallToolRequest, args interface{}) (*sdk.CallToolResult, interface{}, error) { // Extract arguments from the request params (not the args parameter which is SDK internal state) - var toolArgs map[string]interface{} - if req.Params.Arguments != nil { - if err := json.Unmarshal(req.Params.Arguments, &toolArgs); err != nil { - logger.LogError("client", "Failed to unmarshal tool arguments, tool=%s, error=%v", toolNameCopy, err) - return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("failed to parse arguments: %w", err) - } - } else { - // No arguments provided, use empty map - toolArgs = make(map[string]interface{}) + toolArgs, err := parseToolArguments(req) + if err != nil { + logger.LogError("client", "Failed to unmarshal tool arguments, tool=%s, error=%v", toolNameCopy, err) + return &sdk.CallToolResult{IsError: true}, nil, err } // Log the MCP tool call request @@ -315,14 +310,10 @@ func (us *UnifiedServer) registerSysTools() error { // Create sys_init handler sysInitHandler := func(ctx context.Context, req *sdk.CallToolRequest, args interface{}) (*sdk.CallToolResult, interface{}, error) { // Extract arguments from the request params - var toolArgs map[string]interface{} - if req.Params.Arguments != nil { - if err := json.Unmarshal(req.Params.Arguments, &toolArgs); err != nil { - logger.LogError("client", "Failed to unmarshal sys_init arguments, error=%v", err) - return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("failed to parse arguments: %w", err) - } - } else { - toolArgs = make(map[string]interface{}) + toolArgs, err := parseToolArguments(req) + if err != nil { + logger.LogError("client", "Failed to unmarshal sys_init arguments, error=%v", err) + return &sdk.CallToolResult{IsError: true}, nil, err } // Extract token from args @@ -692,6 +683,21 @@ func (us *UnifiedServer) Run(transport sdk.Transport) error { return us.server.Run(us.ctx, transport) } +// parseToolArguments extracts and unmarshals tool arguments from a CallToolRequest +// Returns the parsed arguments as a map, or an error if parsing fails +func parseToolArguments(req *sdk.CallToolRequest) (map[string]interface{}, error) { + var toolArgs map[string]interface{} + if req.Params.Arguments != nil { + if err := json.Unmarshal(req.Params.Arguments, &toolArgs); err != nil { + return nil, fmt.Errorf("failed to parse arguments: %w", err) + } + } else { + // No arguments provided, use empty map + toolArgs = make(map[string]interface{}) + } + return toolArgs, nil +} + // getSessionID extracts the MCP session ID from the context func (us *UnifiedServer) getSessionID(ctx context.Context) string { if sessionID, ok := ctx.Value(SessionIDContextKey).(string); ok && sessionID != "" {