Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 12 additions & 34 deletions internal/logger/markdown_logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,60 +149,38 @@ func (ml *MarkdownLogger) Log(level LogLevel, category, format string, args ...i

// Global logging functions that also write to markdown logger

// LogInfoMd logs to both regular and markdown loggers
func LogInfoMd(category, format string, args ...interface{}) {
// logWithMarkdown is a helper that logs to both regular and markdown loggers
func logWithMarkdown(level LogLevel, regularLogFunc func(string, string, ...interface{}), category, format string, args ...interface{}) {
// Log to regular logger
LogInfo(category, format, args...)
regularLogFunc(category, format, args...)

// Log to markdown logger
globalMarkdownMu.RLock()
defer globalMarkdownMu.RUnlock()

if globalMarkdownLogger != nil {
globalMarkdownLogger.Log(LogLevelInfo, category, format, args...)
globalMarkdownLogger.Log(level, category, format, args...)
}
}

// LogInfoMd logs to both regular and markdown loggers
func LogInfoMd(category, format string, args ...interface{}) {
logWithMarkdown(LogLevelInfo, LogInfo, category, format, args...)
}

// LogWarnMd logs to both regular and markdown loggers
func LogWarnMd(category, format string, args ...interface{}) {
// Log to regular logger
LogWarn(category, format, args...)

// Log to markdown logger
globalMarkdownMu.RLock()
defer globalMarkdownMu.RUnlock()

if globalMarkdownLogger != nil {
globalMarkdownLogger.Log(LogLevelWarn, category, format, args...)
}
logWithMarkdown(LogLevelWarn, LogWarn, category, format, args...)
}

// LogErrorMd logs to both regular and markdown loggers
func LogErrorMd(category, format string, args ...interface{}) {
// Log to regular logger
LogError(category, format, args...)

// Log to markdown logger
globalMarkdownMu.RLock()
defer globalMarkdownMu.RUnlock()

if globalMarkdownLogger != nil {
globalMarkdownLogger.Log(LogLevelError, category, format, args...)
}
logWithMarkdown(LogLevelError, LogError, category, format, args...)
}

// LogDebugMd logs to both regular and markdown loggers
func LogDebugMd(category, format string, args ...interface{}) {
// Log to regular logger
LogDebug(category, format, args...)

// Log to markdown logger
globalMarkdownMu.RLock()
defer globalMarkdownMu.RUnlock()

if globalMarkdownLogger != nil {
globalMarkdownLogger.Log(LogLevelDebug, category, format, args...)
}
logWithMarkdown(LogLevelDebug, LogDebug, category, format, args...)
}

// CloseMarkdownLogger closes the global markdown logger
Expand Down
57 changes: 16 additions & 41 deletions internal/logger/rpc_logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,51 +243,15 @@ func formatRPCMessageMarkdown(info *RPCMessageInfo) string {
return message
}

// LogRPCRequest logs an RPC request message to text, markdown, and JSONL logs
func LogRPCRequest(direction RPCMessageDirection, serverID, method string, payload []byte) {
// logRPCMessageToAll is a helper that logs RPC messages to text, markdown, and JSONL logs
func logRPCMessageToAll(direction RPCMessageDirection, messageType RPCMessageType, serverID, method string, payload []byte, err error) {
// Create info for text log (with larger payload preview)
infoText := &RPCMessageInfo{
Direction: direction,
MessageType: RPCMessageRequest,
ServerID: serverID,
Method: method,
PayloadSize: len(payload),
Payload: truncateAndSanitize(string(payload), MaxPayloadPreviewLengthText),
}

// Log to text file
LogDebug("rpc", "%s", formatRPCMessage(infoText))

// Create info for markdown log (with shorter payload preview)
infoMarkdown := &RPCMessageInfo{
Direction: direction,
MessageType: RPCMessageRequest,
MessageType: messageType,
ServerID: serverID,
Method: method,
PayloadSize: len(payload),
Payload: truncateAndSanitize(string(payload), MaxPayloadPreviewLengthMarkdown),
}

// Log to markdown file
globalMarkdownMu.RLock()
defer globalMarkdownMu.RUnlock()

if globalMarkdownLogger != nil {
globalMarkdownLogger.Log(LogLevelDebug, "rpc", "%s", formatRPCMessageMarkdown(infoMarkdown))
}

// Log to JSONL file (full payload, sanitized)
LogRPCMessageJSONL(direction, RPCMessageRequest, serverID, method, payload, nil)
}

// LogRPCResponse logs an RPC response message to text, markdown, and JSONL logs
func LogRPCResponse(direction RPCMessageDirection, serverID string, payload []byte, err error) {
// Create info for text log (with larger payload preview)
infoText := &RPCMessageInfo{
Direction: direction,
MessageType: RPCMessageResponse,
ServerID: serverID,
PayloadSize: len(payload),
Payload: truncateAndSanitize(string(payload), MaxPayloadPreviewLengthText),
}

Expand All @@ -301,8 +265,9 @@ func LogRPCResponse(direction RPCMessageDirection, serverID string, payload []by
// Create info for markdown log (with shorter payload preview)
infoMarkdown := &RPCMessageInfo{
Direction: direction,
MessageType: RPCMessageResponse,
MessageType: messageType,
ServerID: serverID,
Method: method,
PayloadSize: len(payload),
Payload: truncateAndSanitize(string(payload), MaxPayloadPreviewLengthMarkdown),
}
Expand All @@ -320,7 +285,17 @@ func LogRPCResponse(direction RPCMessageDirection, serverID string, payload []by
}

// Log to JSONL file (full payload, sanitized)
LogRPCMessageJSONL(direction, RPCMessageResponse, serverID, "", payload, err)
LogRPCMessageJSONL(direction, messageType, serverID, method, payload, err)
}

// LogRPCRequest logs an RPC request message to text, markdown, and JSONL logs
func LogRPCRequest(direction RPCMessageDirection, serverID, method string, payload []byte) {
logRPCMessageToAll(direction, RPCMessageRequest, serverID, method, payload, nil)
}

// LogRPCResponse logs an RPC response message to text, markdown, and JSONL logs
func LogRPCResponse(direction RPCMessageDirection, serverID string, payload []byte, err error) {
logRPCMessageToAll(direction, RPCMessageResponse, serverID, "", payload, err)
}

// LogRPCMessage logs a generic RPC message with custom info
Expand Down
40 changes: 23 additions & 17 deletions internal/server/unified.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,15 +235,10 @@ func (us *UnifiedServer) registerToolsFromBackend(serverID string) error {
// Create the handler function
handler := func(ctx context.Context, req *sdk.CallToolRequest, args interface{}) (*sdk.CallToolResult, interface{}, error) {
// Extract arguments from the request params (not the args parameter which is SDK internal state)
var toolArgs map[string]interface{}
if req.Params.Arguments != nil {
if err := json.Unmarshal(req.Params.Arguments, &toolArgs); err != nil {
logger.LogError("client", "Failed to unmarshal tool arguments, tool=%s, error=%v", toolNameCopy, err)
return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("failed to parse arguments: %w", err)
}
} else {
// No arguments provided, use empty map
toolArgs = make(map[string]interface{})
toolArgs, err := parseToolArguments(req)
if err != nil {
logger.LogError("client", "Failed to unmarshal tool arguments, tool=%s, error=%v", toolNameCopy, err)
return &sdk.CallToolResult{IsError: true}, nil, err
}

// Log the MCP tool call request
Expand Down Expand Up @@ -315,14 +310,10 @@ func (us *UnifiedServer) registerSysTools() error {
// Create sys_init handler
sysInitHandler := func(ctx context.Context, req *sdk.CallToolRequest, args interface{}) (*sdk.CallToolResult, interface{}, error) {
// Extract arguments from the request params
var toolArgs map[string]interface{}
if req.Params.Arguments != nil {
if err := json.Unmarshal(req.Params.Arguments, &toolArgs); err != nil {
logger.LogError("client", "Failed to unmarshal sys_init arguments, error=%v", err)
return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("failed to parse arguments: %w", err)
}
} else {
toolArgs = make(map[string]interface{})
toolArgs, err := parseToolArguments(req)
if err != nil {
logger.LogError("client", "Failed to unmarshal sys_init arguments, error=%v", err)
return &sdk.CallToolResult{IsError: true}, nil, err
}

// Extract token from args
Expand Down Expand Up @@ -692,6 +683,21 @@ func (us *UnifiedServer) Run(transport sdk.Transport) error {
return us.server.Run(us.ctx, transport)
}

// parseToolArguments extracts and unmarshals tool arguments from a CallToolRequest
// Returns the parsed arguments as a map, or an error if parsing fails
func parseToolArguments(req *sdk.CallToolRequest) (map[string]interface{}, error) {
var toolArgs map[string]interface{}
if req.Params.Arguments != nil {
if err := json.Unmarshal(req.Params.Arguments, &toolArgs); err != nil {
return nil, fmt.Errorf("failed to parse arguments: %w", err)
}
} else {
// No arguments provided, use empty map
toolArgs = make(map[string]interface{})
}
return toolArgs, nil
}

// getSessionID extracts the MCP session ID from the context
func (us *UnifiedServer) getSessionID(ctx context.Context) string {
if sessionID, ok := ctx.Value(SessionIDContextKey).(string); ok && sessionID != "" {
Expand Down