diff --git a/pkg/github/actions.go b/pkg/github/actions.go index 8c7b08a85..611674125 100644 --- a/pkg/github/actions.go +++ b/pkg/github/actions.go @@ -19,8 +19,17 @@ import ( const ( DescriptionRepositoryOwner = "Repository owner" DescriptionRepositoryName = "Repository name" + errFailedToMarshalResponse = "failed to marshal response: %w" ) +func marshalToToolResult(data interface{}) (*mcp.CallToolResult, error) { + r, err := json.Marshal(data) + if err != nil { + return nil, fmt.Errorf(errFailedToMarshalResponse, err) + } + return mcp.NewToolResultText(string(r)), nil +} + // ListWorkflows creates a tool to list workflows in a repository func ListWorkflows(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("list_workflows", @@ -81,12 +90,7 @@ func ListWorkflows(getClient GetClientFn, t translations.TranslationHelperFunc) } defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(workflows) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalToToolResult(workflows) } } @@ -229,12 +233,7 @@ func ListWorkflowRuns(getClient GetClientFn, t translations.TranslationHelperFun } defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(workflowRuns) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalToToolResult(workflowRuns) } } @@ -328,12 +327,7 @@ func RunWorkflow(getClient GetClientFn, t translations.TranslationHelperFunc) (t "status_code": resp.StatusCode, } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalToToolResult(result) } } @@ -384,12 +378,7 @@ func GetWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFunc) } defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(workflowRun) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalToToolResult(workflowRun) } } @@ -450,12 +439,7 @@ func GetWorkflowRunLogs(getClient GetClientFn, t translations.TranslationHelperF "optimization_tip": "Use: get_job_logs with parameters {run_id: " + fmt.Sprintf("%d", runID) + ", failed_only: true} for more efficient failed job debugging", } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalToToolResult(result) } } @@ -547,12 +531,7 @@ func ListWorkflowJobs(getClient GetClientFn, t translations.TranslationHelperFun "optimization_tip": "For debugging failed jobs, consider using get_job_logs with failed_only=true and run_id=" + fmt.Sprintf("%d", runID) + " to get logs directly without needing to list jobs first", } - r, err := json.Marshal(response) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalToToolResult(response) } } @@ -707,12 +686,7 @@ func handleFailedJobLogs(ctx context.Context, client *github.Client, owner, repo "return_format": map[string]bool{"content": returnContent, "urls": !returnContent}, } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalToToolResult(result) } // handleSingleJobLogs gets logs for a single job @@ -722,12 +696,7 @@ func handleSingleJobLogs(ctx context.Context, client *github.Client, owner, repo return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get job logs", resp, err), nil } - r, err := json.Marshal(jobResult) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalToToolResult(jobResult) } // getJobLogData retrieves log data for a single job, either as URL or content @@ -867,12 +836,7 @@ func RerunWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFun "status_code": resp.StatusCode, } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalToToolResult(result) } } @@ -930,12 +894,7 @@ func RerunFailedJobs(getClient GetClientFn, t translations.TranslationHelperFunc "status_code": resp.StatusCode, } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalToToolResult(result) } } @@ -993,12 +952,7 @@ func CancelWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFu "status_code": resp.StatusCode, } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalToToolResult(result) } } @@ -1071,12 +1025,7 @@ func ListWorkflowRunArtifacts(getClient GetClientFn, t translations.TranslationH } defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(artifacts) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalToToolResult(artifacts) } } @@ -1136,12 +1085,7 @@ func DownloadWorkflowRunArtifact(getClient GetClientFn, t translations.Translati "artifact_id": artifactID, } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalToToolResult(result) } } @@ -1200,12 +1144,7 @@ func DeleteWorkflowRunLogs(getClient GetClientFn, t translations.TranslationHelp "status_code": resp.StatusCode, } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalToToolResult(result) } } @@ -1256,11 +1195,6 @@ func GetWorkflowRunUsage(getClient GetClientFn, t translations.TranslationHelper } defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(usage) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalToToolResult(usage) } } diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go index b6b6bfd79..69bf4d8f3 100644 --- a/pkg/github/notifications.go +++ b/pkg/github/notifications.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "strconv" + "strings" "time" ghErrors "github.com/github/github-mcp-server/pkg/errors" @@ -22,6 +23,29 @@ const ( FilterOnlyParticipating = "only_participating" ) +func parseThreadID(threadID string) (int64, error) { + threadIDInt, err := strconv.ParseInt(threadID, 10, 64) + if err != nil { + return 0, fmt.Errorf("invalid threadID format: %v", err) + } + return threadIDInt, nil +} + +func markNotificationByState(ctx context.Context, client *github.Client, threadID string, state string) (*github.Response, error) { + switch state { + case "done": + threadIDInt, err := parseThreadID(threadID) + if err != nil { + return nil, err + } + return client.Activity.MarkThreadDone(ctx, threadIDInt) + case "read": + return client.Activity.MarkThreadRead(ctx, threadID) + default: + return nil, fmt.Errorf("invalid state; must be one of: read, done") + } +} + // ListNotifications creates a tool to list notifications for the current user. func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("list_notifications", @@ -175,23 +199,14 @@ func DismissNotification(getclient GetClientFn, t translations.TranslationHelper return mcp.NewToolResultError(err.Error()), nil } - var resp *github.Response - switch state { - case "done": - // for some inexplicable reason, the API seems to have threadID as int64 and string depending on the endpoint - var threadIDInt int64 - threadIDInt, err = strconv.ParseInt(threadID, 10, 64) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("invalid threadID format: %v", err)), nil - } - resp, err = client.Activity.MarkThreadDone(ctx, threadIDInt) - case "read": - resp, err = client.Activity.MarkThreadRead(ctx, threadID) - default: - return mcp.NewToolResultError("Invalid state. Must be one of: read, done."), nil - } - + resp, err := markNotificationByState(ctx, client, threadID, state) if err != nil { + if strings.Contains(err.Error(), "invalid threadID format") { + return mcp.NewToolResultError(err.Error()), nil + } + if strings.Contains(err.Error(), "invalid state") { + return mcp.NewToolResultError(err.Error()), nil + } return ghErrors.NewGitHubAPIErrorResponse(ctx, fmt.Sprintf("failed to mark notification as %s", state), resp, diff --git a/pkg/github/notifications_test.go b/pkg/github/notifications_test.go index a83df3ed8..679c50770 100644 --- a/pkg/github/notifications_test.go +++ b/pkg/github/notifications_test.go @@ -559,7 +559,7 @@ func Test_DismissNotification(t *testing.T) { case tc.name == "invalid threadID format": assert.Contains(t, text, "invalid threadID format") case tc.name == "invalid state value": - assert.Contains(t, text, "Invalid state. Must be one of: read, done.") + assert.Contains(t, text, "invalid state; must be one of: read, done") default: // fallback for other errors assert.Contains(t, text, "error")