diff --git a/pkg/github/actions.go b/pkg/github/actions.go index 8c7b08a85..1c4d8613d 100644 --- a/pkg/github/actions.go +++ b/pkg/github/actions.go @@ -2,7 +2,6 @@ package github import ( "context" - "encoding/json" "fmt" "io" "net/http" @@ -29,29 +28,15 @@ func ListWorkflows(getClient GetClientFn, t translations.TranslationHelperFunc) Title: t("TOOL_LIST_WORKFLOWS_USER_TITLE", "List workflows"), ReadOnlyHint: ToBoolPtr(true), }), - mcp.WithString("owner", - mcp.Required(), - mcp.Description(DescriptionRepositoryOwner), - ), - mcp.WithString("repo", - mcp.Required(), - mcp.Description(DescriptionRepositoryName), - ), - mcp.WithNumber("per_page", - mcp.Description("The number of results per page (max 100)"), - ), - mcp.WithNumber("page", - mcp.Description("The page number of the results to fetch"), - ), + withOwnerParam(), + withRepoParam(), + withPerPageParam(), + withPageParam(), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := RequiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := RequiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil + owner, repo, errResult := parseOwnerRepo(request) + if errResult != nil { + return errResult, nil } // Get optional pagination parameters @@ -66,7 +51,7 @@ func ListWorkflows(getClient GetClientFn, t translations.TranslationHelperFunc) client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } // Set up list options @@ -81,12 +66,7 @@ func ListWorkflows(getClient GetClientFn, t translations.TranslationHelperFunc) } defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(workflows) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(workflows) } } @@ -208,7 +188,7 @@ func ListWorkflowRuns(getClient GetClientFn, t translations.TranslationHelperFun client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } // Set up list options @@ -229,12 +209,7 @@ func ListWorkflowRuns(getClient GetClientFn, t translations.TranslationHelperFun } defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(workflowRuns) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(workflowRuns) } } @@ -294,7 +269,7 @@ func RunWorkflow(getClient GetClientFn, t translations.TranslationHelperFunc) (t client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } event := github.CreateWorkflowDispatchEventRequest{ @@ -328,12 +303,7 @@ func RunWorkflow(getClient GetClientFn, t translations.TranslationHelperFunc) (t "status_code": resp.StatusCode, } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(result) } } @@ -375,7 +345,7 @@ func GetWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFunc) client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } workflowRun, resp, err := client.Actions.GetWorkflowRunByID(ctx, owner, repo, runID) @@ -384,12 +354,7 @@ func GetWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFunc) } defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(workflowRun) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(workflowRun) } } @@ -431,7 +396,7 @@ func GetWorkflowRunLogs(getClient GetClientFn, t translations.TranslationHelperF client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } // Get the download URL for the logs @@ -450,12 +415,7 @@ func GetWorkflowRunLogs(getClient GetClientFn, t translations.TranslationHelperF "optimization_tip": "Use: get_job_logs with parameters {run_id: " + fmt.Sprintf("%d", runID) + ", failed_only: true} for more efficient failed job debugging", } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(result) } } @@ -523,7 +483,7 @@ func ListWorkflowJobs(getClient GetClientFn, t translations.TranslationHelperFun client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } // Set up list options @@ -547,12 +507,7 @@ func ListWorkflowJobs(getClient GetClientFn, t translations.TranslationHelperFun "optimization_tip": "For debugging failed jobs, consider using get_job_logs with failed_only=true and run_id=" + fmt.Sprintf("%d", runID) + " to get logs directly without needing to list jobs first", } - r, err := json.Marshal(response) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(response) } } @@ -627,7 +582,7 @@ func GetJobLogs(getClient GetClientFn, t translations.TranslationHelperFunc) (to client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } // Validate parameters @@ -676,8 +631,7 @@ func handleFailedJobLogs(ctx context.Context, client *github.Client, owner, repo "total_jobs": len(jobs.Jobs), "failed_jobs": 0, } - r, _ := json.Marshal(result) - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(result) } // Collect logs for all failed jobs @@ -707,12 +661,7 @@ func handleFailedJobLogs(ctx context.Context, client *github.Client, owner, repo "return_format": map[string]bool{"content": returnContent, "urls": !returnContent}, } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(result) } // handleSingleJobLogs gets logs for a single job @@ -722,12 +671,7 @@ func handleSingleJobLogs(ctx context.Context, client *github.Client, owner, repo return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get job logs", resp, err), nil } - r, err := json.Marshal(jobResult) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(jobResult) } // getJobLogData retrieves log data for a single job, either as URL or content @@ -851,7 +795,7 @@ func RerunWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFun client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } resp, err := client.Actions.RerunWorkflowByID(ctx, owner, repo, runID) @@ -867,12 +811,7 @@ func RerunWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFun "status_code": resp.StatusCode, } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(result) } } @@ -914,7 +853,7 @@ func RerunFailedJobs(getClient GetClientFn, t translations.TranslationHelperFunc client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } resp, err := client.Actions.RerunFailedJobsByID(ctx, owner, repo, runID) @@ -930,12 +869,7 @@ func RerunFailedJobs(getClient GetClientFn, t translations.TranslationHelperFunc "status_code": resp.StatusCode, } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(result) } } @@ -977,7 +911,7 @@ func CancelWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFu client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } resp, err := client.Actions.CancelWorkflowRunByID(ctx, owner, repo, runID) @@ -993,12 +927,7 @@ func CancelWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFu "status_code": resp.StatusCode, } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(result) } } @@ -1056,7 +985,7 @@ func ListWorkflowRunArtifacts(getClient GetClientFn, t translations.TranslationH client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } // Set up list options @@ -1071,12 +1000,7 @@ func ListWorkflowRunArtifacts(getClient GetClientFn, t translations.TranslationH } defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(artifacts) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(artifacts) } } @@ -1118,7 +1042,7 @@ func DownloadWorkflowRunArtifact(getClient GetClientFn, t translations.Translati client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } // Get the download URL for the artifact @@ -1136,12 +1060,7 @@ func DownloadWorkflowRunArtifact(getClient GetClientFn, t translations.Translati "artifact_id": artifactID, } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(result) } } @@ -1184,7 +1103,7 @@ func DeleteWorkflowRunLogs(getClient GetClientFn, t translations.TranslationHelp client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } resp, err := client.Actions.DeleteWorkflowRunLogs(ctx, owner, repo, runID) @@ -1200,12 +1119,7 @@ func DeleteWorkflowRunLogs(getClient GetClientFn, t translations.TranslationHelp "status_code": resp.StatusCode, } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(result) } } @@ -1247,7 +1161,7 @@ func GetWorkflowRunUsage(getClient GetClientFn, t translations.TranslationHelper client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } usage, resp, err := client.Actions.GetWorkflowRunUsageByID(ctx, owner, repo, runID) @@ -1256,11 +1170,6 @@ func GetWorkflowRunUsage(getClient GetClientFn, t translations.TranslationHelper } defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(usage) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(usage) } } diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index 3b07692c0..1b0eb5bf0 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -2,12 +2,8 @@ package github import ( "context" - "encoding/json" "fmt" - "io" - "net/http" - ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v72/github" "github.com/mark3labs/mcp-go/mcp" @@ -50,33 +46,19 @@ func GetCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelpe client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } alert, resp, err := client.CodeScanning.GetAlert(ctx, owner, repo, int64(alertNumber)) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get alert", - resp, - err, - ), nil - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + if errResult, hasError := handleAPIResponse(ctx, resp, err, "failed to get alert"); hasError { + if errResult != nil { + return errResult, nil } - return mcp.NewToolResultError(fmt.Sprintf("failed to get alert: %s", string(body))), nil - } - - r, err := json.Marshal(alert) - if err != nil { - return nil, fmt.Errorf("failed to marshal alert: %w", err) + return nil, fmt.Errorf(ErrReadResponseBody, err) } - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(alert) } } @@ -139,31 +121,17 @@ func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHel client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } alerts, resp, err := client.CodeScanning.ListAlertsForRepo(ctx, owner, repo, &github.AlertListOptions{Ref: ref, State: state, Severity: severity, ToolName: toolName}) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to list alerts", - resp, - err, - ), nil - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + if errResult, hasError := handleAPIResponse(ctx, resp, err, "failed to list alerts"); hasError { + if errResult != nil { + return errResult, nil } - return mcp.NewToolResultError(fmt.Sprintf("failed to list alerts: %s", string(body))), nil - } - - r, err := json.Marshal(alerts) - if err != nil { - return nil, fmt.Errorf("failed to marshal alerts: %w", err) + return nil, fmt.Errorf(ErrReadResponseBody, err) } - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(alerts) } } diff --git a/pkg/github/errors_constants.go b/pkg/github/errors_constants.go new file mode 100644 index 000000000..530edc9d6 --- /dev/null +++ b/pkg/github/errors_constants.go @@ -0,0 +1,8 @@ +package github + +const ( + ErrMarshalResponse = "failed to marshal response: %w" + ErrGetGitHubClient = "failed to get GitHub client: %w" + ErrReadResponseBody = "failed to read response body: %w" + RepoURIPrefix = "repo://" +) diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 6121786d2..79afae7a5 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -2,7 +2,6 @@ package github import ( "context" - "encoding/json" "fmt" "io" "net/http" @@ -54,7 +53,7 @@ func GetIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (tool client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } issue, resp, err := client.Issues.Get(ctx, owner, repo, issueNumber) if err != nil { @@ -70,12 +69,7 @@ func GetIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (tool return mcp.NewToolResultError(fmt.Sprintf("failed to get issue: %s", string(body))), nil } - r, err := json.Marshal(issue) - if err != nil { - return nil, fmt.Errorf("failed to marshal issue: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(issue) } } @@ -128,7 +122,7 @@ func AddIssueComment(getClient GetClientFn, t translations.TranslationHelperFunc client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } createdComment, resp, err := client.Issues.CreateComment(ctx, owner, repo, issueNumber, comment) if err != nil { @@ -144,12 +138,7 @@ func AddIssueComment(getClient GetClientFn, t translations.TranslationHelperFunc return mcp.NewToolResultError(fmt.Sprintf("failed to create comment: %s", string(body))), nil } - r, err := json.Marshal(createdComment) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(createdComment) } } @@ -295,7 +284,7 @@ func CreateIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (t client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } issue, resp, err := client.Issues.Create(ctx, owner, repo, issueRequest) if err != nil { @@ -311,12 +300,7 @@ func CreateIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (t return mcp.NewToolResultError(fmt.Sprintf("failed to create issue: %s", string(body))), nil } - r, err := json.Marshal(issue) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(issue) } } @@ -417,7 +401,7 @@ func ListIssues(getClient GetClientFn, t translations.TranslationHelperFunc) (to client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } issues, resp, err := client.Issues.ListByRepo(ctx, owner, repo, opts) if err != nil { @@ -433,12 +417,7 @@ func ListIssues(getClient GetClientFn, t translations.TranslationHelperFunc) (to return mcp.NewToolResultError(fmt.Sprintf("failed to list issues: %s", string(body))), nil } - r, err := json.Marshal(issues) - if err != nil { - return nil, fmt.Errorf("failed to marshal issues: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(issues) } } @@ -563,7 +542,7 @@ func UpdateIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (t client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } updatedIssue, resp, err := client.Issues.Edit(ctx, owner, repo, issueNumber, issueRequest) if err != nil { @@ -579,12 +558,7 @@ func UpdateIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (t return mcp.NewToolResultError(fmt.Sprintf("failed to update issue: %s", string(body))), nil } - r, err := json.Marshal(updatedIssue) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(updatedIssue) } } @@ -646,7 +620,7 @@ func GetIssueComments(getClient GetClientFn, t translations.TranslationHelperFun client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } comments, resp, err := client.Issues.ListComments(ctx, owner, repo, issueNumber, opts) if err != nil { @@ -662,12 +636,7 @@ func GetIssueComments(getClient GetClientFn, t translations.TranslationHelperFun return mcp.NewToolResultError(fmt.Sprintf("failed to get issue comments: %s", string(body))), nil } - r, err := json.Marshal(comments) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(comments) } } @@ -744,7 +713,7 @@ func AssignCopilotToIssue(getGQLClient GetGQLClientFn, t translations.Translatio client, err := getGQLClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } // Firstly, we try to find the copilot bot in the suggested actors for the repository. diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go index b6b6bfd79..79ded0685 100644 --- a/pkg/github/notifications.go +++ b/pkg/github/notifications.go @@ -51,7 +51,7 @@ func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFu func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } filter, err := OptionalParam[string](request, "filter") @@ -138,7 +138,7 @@ func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFu // Marshal response to JSON r, err := json.Marshal(notifications) if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, fmt.Errorf(ErrMarshalResponse, err) } return mcp.NewToolResultText(string(r)), nil @@ -162,7 +162,7 @@ func DismissNotification(getclient GetClientFn, t translations.TranslationHelper func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { client, err := getclient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } threadID, err := RequiredParam[string](request, "threadID") @@ -233,7 +233,7 @@ func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationH func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } lastReadAt, err := OptionalParam[string](request, "lastReadAt") @@ -307,7 +307,7 @@ func GetNotificationDetails(getClient GetClientFn, t translations.TranslationHel func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } notificationID, err := RequiredParam[string](request, "notificationID") @@ -335,7 +335,7 @@ func GetNotificationDetails(getClient GetClientFn, t translations.TranslationHel r, err := json.Marshal(thread) if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, fmt.Errorf(ErrMarshalResponse, err) } return mcp.NewToolResultText(string(r)), nil @@ -370,7 +370,7 @@ func ManageNotificationSubscription(getClient GetClientFn, t translations.Transl func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } notificationID, err := RequiredParam[string](request, "notificationID") @@ -422,7 +422,7 @@ func ManageNotificationSubscription(getClient GetClientFn, t translations.Transl r, err := json.Marshal(result) if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, fmt.Errorf(ErrMarshalResponse, err) } return mcp.NewToolResultText(string(r)), nil } @@ -459,7 +459,7 @@ func ManageRepositoryNotificationSubscription(getClient GetClientFn, t translati func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } owner, err := RequiredParam[string](request, "owner") diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index bad822b13..96a949cbc 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -2,7 +2,6 @@ package github import ( "context" - "encoding/json" "fmt" "io" "net/http" @@ -25,27 +24,17 @@ func GetPullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) Title: t("TOOL_GET_PULL_REQUEST_USER_TITLE", "Get pull request details"), ReadOnlyHint: ToBoolPtr(true), }), - mcp.WithString("owner", - mcp.Required(), - mcp.Description("Repository owner"), - ), - mcp.WithString("repo", - mcp.Required(), - mcp.Description("Repository name"), - ), + withOwnerParam(), + withRepoParam(), mcp.WithNumber("pullNumber", mcp.Required(), mcp.Description("Pull request number"), ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := RequiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := RequiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil + owner, repo, errResult := parseOwnerRepo(request) + if errResult != nil { + return errResult, nil } pullNumber, err := RequiredInt(request, "pullNumber") if err != nil { @@ -54,32 +43,18 @@ func GetPullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get pull request", - resp, - err, - ), nil - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + if errResult, hasError := handleAPIResponse(ctx, resp, err, "failed to get pull request"); hasError { + if errResult != nil { + return errResult, nil } - return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request: %s", string(body))), nil - } - - r, err := json.Marshal(pr) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, fmt.Errorf(ErrReadResponseBody, err) } - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(pr) } } @@ -91,14 +66,8 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu Title: t("TOOL_CREATE_PULL_REQUEST_USER_TITLE", "Open new pull request"), ReadOnlyHint: ToBoolPtr(false), }), - mcp.WithString("owner", - mcp.Required(), - mcp.Description("Repository owner"), - ), - mcp.WithString("repo", - mcp.Required(), - mcp.Description("Repository name"), - ), + withOwnerParam(), + withRepoParam(), mcp.WithString("title", mcp.Required(), mcp.Description("PR title"), @@ -122,13 +91,9 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := RequiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := RequiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil + owner, repo, errResult := parseOwnerRepo(request) + if errResult != nil { + return errResult, nil } title, err := RequiredParam[string](request, "title") if err != nil { @@ -173,7 +138,7 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } pr, resp, err := client.PullRequests.Create(ctx, owner, repo, newPR) if err != nil { @@ -193,12 +158,7 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu return mcp.NewToolResultError(fmt.Sprintf("failed to create pull request: %s", string(body))), nil } - r, err := json.Marshal(pr) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(pr) } } @@ -210,14 +170,8 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu Title: t("TOOL_UPDATE_PULL_REQUEST_USER_TITLE", "Edit pull request"), ReadOnlyHint: ToBoolPtr(false), }), - mcp.WithString("owner", - mcp.Required(), - mcp.Description("Repository owner"), - ), - mcp.WithString("repo", - mcp.Required(), - mcp.Description("Repository name"), - ), + withOwnerParam(), + withRepoParam(), mcp.WithNumber("pullNumber", mcp.Required(), mcp.Description("Pull request number to update"), @@ -240,13 +194,9 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := RequiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := RequiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil + owner, repo, errResult := parseOwnerRepo(request) + if errResult != nil { + return errResult, nil } pullNumber, err := RequiredInt(request, "pullNumber") if err != nil { @@ -298,7 +248,7 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) if err != nil { @@ -318,12 +268,7 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil } - r, err := json.Marshal(pr) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(pr) } } @@ -335,14 +280,8 @@ func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFun Title: t("TOOL_LIST_PULL_REQUESTS_USER_TITLE", "List pull requests"), ReadOnlyHint: ToBoolPtr(true), }), - mcp.WithString("owner", - mcp.Required(), - mcp.Description("Repository owner"), - ), - mcp.WithString("repo", - mcp.Required(), - mcp.Description("Repository name"), - ), + withOwnerParam(), + withRepoParam(), mcp.WithString("state", mcp.Description("Filter by state"), mcp.Enum("open", "closed", "all"), @@ -364,13 +303,9 @@ func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFun WithPagination(), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := RequiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := RequiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil + owner, repo, errResult := parseOwnerRepo(request) + if errResult != nil { + return errResult, nil } state, err := OptionalParam[string](request, "state") if err != nil { @@ -411,7 +346,7 @@ func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFun client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } prs, resp, err := client.PullRequests.List(ctx, owner, repo, opts) if err != nil { @@ -431,12 +366,7 @@ func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFun return mcp.NewToolResultError(fmt.Sprintf("failed to list pull requests: %s", string(body))), nil } - r, err := json.Marshal(prs) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(prs) } } @@ -448,14 +378,8 @@ func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFun Title: t("TOOL_MERGE_PULL_REQUEST_USER_TITLE", "Merge pull request"), ReadOnlyHint: ToBoolPtr(false), }), - mcp.WithString("owner", - mcp.Required(), - mcp.Description("Repository owner"), - ), - mcp.WithString("repo", - mcp.Required(), - mcp.Description("Repository name"), - ), + withOwnerParam(), + withRepoParam(), mcp.WithNumber("pullNumber", mcp.Required(), mcp.Description("Pull request number"), @@ -472,13 +396,9 @@ func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFun ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := RequiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := RequiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil + owner, repo, errResult := parseOwnerRepo(request) + if errResult != nil { + return errResult, nil } pullNumber, err := RequiredInt(request, "pullNumber") if err != nil { @@ -504,7 +424,7 @@ func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFun client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } result, resp, err := client.PullRequests.Merge(ctx, owner, repo, pullNumber, commitMessage, options) if err != nil { @@ -524,12 +444,7 @@ func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFun return mcp.NewToolResultError(fmt.Sprintf("failed to merge pull request: %s", string(body))), nil } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(result) } } @@ -586,14 +501,8 @@ func GetPullRequestFiles(getClient GetClientFn, t translations.TranslationHelper Title: t("TOOL_GET_PULL_REQUEST_FILES_USER_TITLE", "Get pull request files"), ReadOnlyHint: ToBoolPtr(true), }), - mcp.WithString("owner", - mcp.Required(), - mcp.Description("Repository owner"), - ), - mcp.WithString("repo", - mcp.Required(), - mcp.Description("Repository name"), - ), + withOwnerParam(), + withRepoParam(), mcp.WithNumber("pullNumber", mcp.Required(), mcp.Description("Pull request number"), @@ -601,13 +510,9 @@ func GetPullRequestFiles(getClient GetClientFn, t translations.TranslationHelper WithPagination(), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := RequiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := RequiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil + owner, repo, errResult := parseOwnerRepo(request) + if errResult != nil { + return errResult, nil } pullNumber, err := RequiredInt(request, "pullNumber") if err != nil { @@ -620,7 +525,7 @@ func GetPullRequestFiles(getClient GetClientFn, t translations.TranslationHelper client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } opts := &github.ListOptions{ PerPage: pagination.perPage, @@ -644,12 +549,7 @@ func GetPullRequestFiles(getClient GetClientFn, t translations.TranslationHelper return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request files: %s", string(body))), nil } - r, err := json.Marshal(files) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(files) } } @@ -661,27 +561,17 @@ func GetPullRequestStatus(getClient GetClientFn, t translations.TranslationHelpe Title: t("TOOL_GET_PULL_REQUEST_STATUS_USER_TITLE", "Get pull request status checks"), ReadOnlyHint: ToBoolPtr(true), }), - mcp.WithString("owner", - mcp.Required(), - mcp.Description("Repository owner"), - ), - mcp.WithString("repo", - mcp.Required(), - mcp.Description("Repository name"), - ), + withOwnerParam(), + withRepoParam(), mcp.WithNumber("pullNumber", mcp.Required(), mcp.Description("Pull request number"), ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := RequiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := RequiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil + owner, repo, errResult := parseOwnerRepo(request) + if errResult != nil { + return errResult, nil } pullNumber, err := RequiredInt(request, "pullNumber") if err != nil { @@ -690,7 +580,7 @@ func GetPullRequestStatus(getClient GetClientFn, t translations.TranslationHelpe // First get the PR to find the head SHA client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { @@ -729,12 +619,7 @@ func GetPullRequestStatus(getClient GetClientFn, t translations.TranslationHelpe return mcp.NewToolResultError(fmt.Sprintf("failed to get combined status: %s", string(body))), nil } - r, err := json.Marshal(status) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(status) } } @@ -746,14 +631,8 @@ func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHe Title: t("TOOL_UPDATE_PULL_REQUEST_BRANCH_USER_TITLE", "Update pull request branch"), ReadOnlyHint: ToBoolPtr(false), }), - mcp.WithString("owner", - mcp.Required(), - mcp.Description("Repository owner"), - ), - mcp.WithString("repo", - mcp.Required(), - mcp.Description("Repository name"), - ), + withOwnerParam(), + withRepoParam(), mcp.WithNumber("pullNumber", mcp.Required(), mcp.Description("Pull request number"), @@ -763,13 +642,9 @@ func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHe ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := RequiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := RequiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil + owner, repo, errResult := parseOwnerRepo(request) + if errResult != nil { + return errResult, nil } pullNumber, err := RequiredInt(request, "pullNumber") if err != nil { @@ -786,7 +661,7 @@ func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHe client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } result, resp, err := client.PullRequests.UpdateBranch(ctx, owner, repo, pullNumber, opts) if err != nil { @@ -811,12 +686,7 @@ func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHe return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request branch: %s", string(body))), nil } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(result) } } @@ -828,27 +698,17 @@ func GetPullRequestComments(getClient GetClientFn, t translations.TranslationHel Title: t("TOOL_GET_PULL_REQUEST_COMMENTS_USER_TITLE", "Get pull request comments"), ReadOnlyHint: ToBoolPtr(true), }), - mcp.WithString("owner", - mcp.Required(), - mcp.Description("Repository owner"), - ), - mcp.WithString("repo", - mcp.Required(), - mcp.Description("Repository name"), - ), + withOwnerParam(), + withRepoParam(), mcp.WithNumber("pullNumber", mcp.Required(), mcp.Description("Pull request number"), ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := RequiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := RequiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil + owner, repo, errResult := parseOwnerRepo(request) + if errResult != nil { + return errResult, nil } pullNumber, err := RequiredInt(request, "pullNumber") if err != nil { @@ -863,7 +723,7 @@ func GetPullRequestComments(getClient GetClientFn, t translations.TranslationHel client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } comments, resp, err := client.PullRequests.ListComments(ctx, owner, repo, pullNumber, opts) if err != nil { @@ -883,12 +743,7 @@ func GetPullRequestComments(getClient GetClientFn, t translations.TranslationHel return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request comments: %s", string(body))), nil } - r, err := json.Marshal(comments) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(comments) } } @@ -900,27 +755,17 @@ func GetPullRequestReviews(getClient GetClientFn, t translations.TranslationHelp Title: t("TOOL_GET_PULL_REQUEST_REVIEWS_USER_TITLE", "Get pull request reviews"), ReadOnlyHint: ToBoolPtr(true), }), - mcp.WithString("owner", - mcp.Required(), - mcp.Description("Repository owner"), - ), - mcp.WithString("repo", - mcp.Required(), - mcp.Description("Repository name"), - ), + withOwnerParam(), + withRepoParam(), mcp.WithNumber("pullNumber", mcp.Required(), mcp.Description("Pull request number"), ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := RequiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := RequiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil + owner, repo, errResult := parseOwnerRepo(request) + if errResult != nil { + return errResult, nil } pullNumber, err := RequiredInt(request, "pullNumber") if err != nil { @@ -929,7 +774,7 @@ func GetPullRequestReviews(getClient GetClientFn, t translations.TranslationHelp client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } reviews, resp, err := client.PullRequests.ListReviews(ctx, owner, repo, pullNumber, nil) if err != nil { @@ -949,12 +794,7 @@ func GetPullRequestReviews(getClient GetClientFn, t translations.TranslationHelp return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request reviews: %s", string(body))), nil } - r, err := json.Marshal(reviews) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(reviews) } } @@ -968,14 +808,8 @@ func CreateAndSubmitPullRequestReview(getGQLClient GetGQLClientFn, t translation // Either we need the PR GQL Id directly, or we need owner, repo and PR number to look it up. // Since our other Pull Request tools are working with the REST Client, will handle the lookup // internally for now. - mcp.WithString("owner", - mcp.Required(), - mcp.Description("Repository owner"), - ), - mcp.WithString("repo", - mcp.Required(), - mcp.Description("Repository name"), - ), + withOwnerParam(), + withRepoParam(), mcp.WithNumber("pullNumber", mcp.Required(), mcp.Description("Pull request number"), @@ -1071,14 +905,8 @@ func CreatePendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. // Either we need the PR GQL Id directly, or we need owner, repo and PR number to look it up. // Since our other Pull Request tools are working with the REST Client, will handle the lookup // internally for now. - mcp.WithString("owner", - mcp.Required(), - mcp.Description("Repository owner"), - ), - mcp.WithString("repo", - mcp.Required(), - mcp.Description("Repository name"), - ), + withOwnerParam(), + withRepoParam(), mcp.WithNumber("pullNumber", mcp.Required(), mcp.Description("Pull request number"), @@ -1169,14 +997,8 @@ func AddPullRequestReviewCommentToPendingReview(getGQLClient GetGQLClientFn, t t // mcp.Required(), // mcp.Description("The ID of the pull request review to add a comment to"), // ), - mcp.WithString("owner", - mcp.Required(), - mcp.Description("Repository owner"), - ), - mcp.WithString("repo", - mcp.Required(), - mcp.Description("Repository name"), - ), + withOwnerParam(), + withRepoParam(), mcp.WithNumber("pullNumber", mcp.Required(), mcp.Description("Pull request number"), @@ -1330,14 +1152,8 @@ func SubmitPendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. // add a new tool to get that ID for clients that aren't in the same context as the original pending review // creation. So for now, we'll just accept the owner, repo and pull number and assume this is submitting // the latest review from a user, since only one can be active at a time. - mcp.WithString("owner", - mcp.Required(), - mcp.Description("Repository owner"), - ), - mcp.WithString("repo", - mcp.Required(), - mcp.Description("Repository name"), - ), + withOwnerParam(), + withRepoParam(), mcp.WithNumber("pullNumber", mcp.Required(), mcp.Description("Pull request number"), @@ -1464,14 +1280,8 @@ func DeletePendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. // add a new tool to get that ID for clients that aren't in the same context as the original pending review // creation. So for now, we'll just accept the owner, repo and pull number and assume this is deleting // the latest pending review from a user, since only one can be active at a time. - mcp.WithString("owner", - mcp.Required(), - mcp.Description("Repository owner"), - ), - mcp.WithString("repo", - mcp.Required(), - mcp.Description("Repository name"), - ), + withOwnerParam(), + withRepoParam(), mcp.WithNumber("pullNumber", mcp.Required(), mcp.Description("Pull request number"), @@ -1579,14 +1389,8 @@ func GetPullRequestDiff(getClient GetClientFn, t translations.TranslationHelperF Title: t("TOOL_GET_PULL_REQUEST_DIFF_USER_TITLE", "Get pull request diff"), ReadOnlyHint: ToBoolPtr(true), }), - mcp.WithString("owner", - mcp.Required(), - mcp.Description("Repository owner"), - ), - mcp.WithString("repo", - mcp.Required(), - mcp.Description("Repository name"), - ), + withOwnerParam(), + withRepoParam(), mcp.WithNumber("pullNumber", mcp.Required(), mcp.Description("Pull request number"), @@ -1647,14 +1451,8 @@ func RequestCopilotReview(getClient GetClientFn, t translations.TranslationHelpe Title: t("TOOL_REQUEST_COPILOT_REVIEW_USER_TITLE", "Request Copilot review"), ReadOnlyHint: ToBoolPtr(false), }), - mcp.WithString("owner", - mcp.Required(), - mcp.Description("Repository owner"), - ), - mcp.WithString("repo", - mcp.Required(), - mcp.Description("Repository name"), - ), + withOwnerParam(), + withRepoParam(), mcp.WithNumber("pullNumber", mcp.Required(), mcp.Description("Pull request number"), diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index 29f776a05..44c794106 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -65,7 +65,7 @@ func GetCommit(getClient GetClientFn, t translations.TranslationHelperFunc) (too client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } commit, resp, err := client.Repositories.GetCommit(ctx, owner, repo, sha, opts) if err != nil { @@ -87,7 +87,7 @@ func GetCommit(getClient GetClientFn, t translations.TranslationHelperFunc) (too r, err := json.Marshal(commit) if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, fmt.Errorf(ErrMarshalResponse, err) } return mcp.NewToolResultText(string(r)), nil @@ -155,7 +155,7 @@ func ListCommits(getClient GetClientFn, t translations.TranslationHelperFunc) (t client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } commits, resp, err := client.Repositories.ListCommits(ctx, owner, repo, opts) if err != nil { @@ -177,7 +177,7 @@ func ListCommits(getClient GetClientFn, t translations.TranslationHelperFunc) (t r, err := json.Marshal(commits) if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, fmt.Errorf(ErrMarshalResponse, err) } return mcp.NewToolResultText(string(r)), nil @@ -225,7 +225,7 @@ func ListBranches(getClient GetClientFn, t translations.TranslationHelperFunc) ( client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } branches, resp, err := client.Repositories.ListBranches(ctx, owner, repo, opts) @@ -248,7 +248,7 @@ func ListBranches(getClient GetClientFn, t translations.TranslationHelperFunc) ( r, err := json.Marshal(branches) if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, fmt.Errorf(ErrMarshalResponse, err) } return mcp.NewToolResultText(string(r)), nil @@ -339,7 +339,7 @@ func CreateOrUpdateFile(getClient GetClientFn, t translations.TranslationHelperF // Create or update the file client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } fileContent, resp, err := client.Repositories.CreateFile(ctx, owner, repo, path, opts) if err != nil { @@ -361,7 +361,7 @@ func CreateOrUpdateFile(getClient GetClientFn, t translations.TranslationHelperF r, err := json.Marshal(fileContent) if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, fmt.Errorf(ErrMarshalResponse, err) } return mcp.NewToolResultText(string(r)), nil @@ -417,7 +417,7 @@ func CreateRepository(getClient GetClientFn, t translations.TranslationHelperFun client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } createdRepo, resp, err := client.Repositories.Create(ctx, "", repo) if err != nil { @@ -439,13 +439,94 @@ func CreateRepository(getClient GetClientFn, t translations.TranslationHelperFun r, err := json.Marshal(createdRepo) if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, fmt.Errorf(ErrMarshalResponse, err) } return mcp.NewToolResultText(string(r)), nil } } +func resolvePRReference(ctx context.Context, getClient GetClientFn, owner, repo, ref string) (string, error) { + if !strings.HasPrefix(ref, "refs/pull/") { + return "", nil + } + + prNumber := strings.TrimSuffix(strings.TrimPrefix(ref, "refs/pull/"), "/head") + if len(prNumber) == 0 { + return "", nil + } + + githubClient, err := getClient(ctx) + if err != nil { + return "", fmt.Errorf(ErrGetGitHubClient, err) + } + + prNum, err := strconv.Atoi(prNumber) + if err != nil { + return "", fmt.Errorf("invalid pull request number: %w", err) + } + + pr, _, err := githubClient.PullRequests.Get(ctx, owner, repo, prNum) + if err != nil { + return "", fmt.Errorf("failed to get pull request: %w", err) + } + + return pr.GetHead().GetSHA(), nil +} + +func buildResourceURI(owner, repo, path, sha, ref string) (string, error) { + switch { + case sha != "": + return url.JoinPath(RepoURIPrefix, owner, repo, "sha", sha, "contents", path) + case ref != "": + return url.JoinPath(RepoURIPrefix, owner, repo, ref, "contents", path) + default: + return url.JoinPath(RepoURIPrefix, owner, repo, "contents", path) + } +} + +func tryGetRawContent(ctx context.Context, getRawClient raw.GetRawClientFn, owner, repo, path string, rawOpts *raw.ContentOpts) (*mcp.CallToolResult, bool, error) { + rawClient, err := getRawClient(ctx) + if err != nil { + return mcp.NewToolResultError("failed to get GitHub raw content client"), true, nil + } + + resp, err := rawClient.GetRawContent(ctx, owner, repo, path, rawOpts) + if err != nil { + return mcp.NewToolResultError("failed to get raw repository content"), true, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, false, nil + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return mcp.NewToolResultError("failed to read response body"), true, nil + } + + contentType := resp.Header.Get("Content-Type") + resourceURI, err := buildResourceURI(owner, repo, path, rawOpts.SHA, rawOpts.Ref) + if err != nil { + return nil, true, err + } + + if strings.HasPrefix(contentType, "application") || strings.HasPrefix(contentType, "text") { + return mcp.NewToolResultResource("successfully downloaded text file", mcp.TextResourceContents{ + URI: resourceURI, + Text: string(body), + MIMEType: contentType, + }), true, nil + } + + return mcp.NewToolResultResource("successfully downloaded binary file", mcp.BlobResourceContents{ + URI: resourceURI, + Blob: base64.StdEncoding.EncodeToString(body), + MIMEType: contentType, + }), true, nil +} + // GetFileContents creates a tool to get the contents of a file or directory from a GitHub repository. func GetFileContents(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_file_contents", @@ -495,88 +576,27 @@ func GetFileContents(getClient GetClientFn, getRawClient raw.GetRawClientFn, t t return mcp.NewToolResultError(err.Error()), nil } - rawOpts := &raw.ContentOpts{} - - if strings.HasPrefix(ref, "refs/pull/") { - prNumber := strings.TrimSuffix(strings.TrimPrefix(ref, "refs/pull/"), "/head") - if len(prNumber) > 0 { - // fetch the PR from the API to get the latest commit and use SHA - githubClient, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - prNum, err := strconv.Atoi(prNumber) - if err != nil { - return nil, fmt.Errorf("invalid pull request number: %w", err) - } - pr, _, err := githubClient.PullRequests.Get(ctx, owner, repo, prNum) - if err != nil { - return nil, fmt.Errorf("failed to get pull request: %w", err) - } - sha = pr.GetHead().GetSHA() - ref = "" - } + resolvedSHA, err := resolvePRReference(ctx, getClient, owner, repo, ref) + if err != nil { + return nil, err + } + if resolvedSHA != "" { + sha = resolvedSHA + ref = "" } - rawOpts.SHA = sha - rawOpts.Ref = ref + rawOpts := &raw.ContentOpts{ + SHA: sha, + Ref: ref, + } - // If the path is (most likely) not to be a directory, we will first try to get the raw content from the GitHub raw content API. if path != "" && !strings.HasSuffix(path, "/") { - - rawClient, err := getRawClient(ctx) + result, found, err := tryGetRawContent(ctx, getRawClient, owner, repo, path, rawOpts) if err != nil { - return mcp.NewToolResultError("failed to get GitHub raw content client"), nil + return nil, err } - resp, err := rawClient.GetRawContent(ctx, owner, repo, path, rawOpts) - if err != nil { - return mcp.NewToolResultError("failed to get raw repository content"), nil - } - defer func() { - _ = resp.Body.Close() - }() - - if resp.StatusCode == http.StatusOK { - // If the raw content is found, return it directly - body, err := io.ReadAll(resp.Body) - if err != nil { - return mcp.NewToolResultError("failed to read response body"), nil - } - contentType := resp.Header.Get("Content-Type") - - var resourceURI string - switch { - case sha != "": - resourceURI, err = url.JoinPath("repo://", owner, repo, "sha", sha, "contents", path) - if err != nil { - return nil, fmt.Errorf("failed to create resource URI: %w", err) - } - case ref != "": - resourceURI, err = url.JoinPath("repo://", owner, repo, ref, "contents", path) - if err != nil { - return nil, fmt.Errorf("failed to create resource URI: %w", err) - } - default: - resourceURI, err = url.JoinPath("repo://", owner, repo, "contents", path) - if err != nil { - return nil, fmt.Errorf("failed to create resource URI: %w", err) - } - } - - if strings.HasPrefix(contentType, "application") || strings.HasPrefix(contentType, "text") { - return mcp.NewToolResultResource("successfully downloaded text file", mcp.TextResourceContents{ - URI: resourceURI, - Text: string(body), - MIMEType: contentType, - }), nil - } - - return mcp.NewToolResultResource("successfully downloaded binary file", mcp.BlobResourceContents{ - URI: resourceURI, - Blob: base64.StdEncoding.EncodeToString(body), - MIMEType: contentType, - }), nil - + if found { + return result, nil } } @@ -682,7 +702,7 @@ func ForkRepository(getClient GetClientFn, t translations.TranslationHelperFunc) r, err := json.Marshal(forkedRepo) if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, fmt.Errorf(ErrMarshalResponse, err) } return mcp.NewToolResultText(string(r)), nil @@ -694,6 +714,98 @@ func ForkRepository(getClient GetClientFn, t translations.TranslationHelperFunc) // This is because REST file deletion endpoint (and client.Repositories.DeleteFile) don't add commit signing to the deletion commit, // unlike how the endpoint backing the create_or_update_files tool does. This appears to be a quirk of the API. // The approach implemented here gets automatic commit signing when used with either the github-actions user or as an app, +func getBranchReference(ctx context.Context, client *github.Client, owner, repo, branch string) (*github.Reference, error) { + ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+branch) + if err != nil { + return nil, fmt.Errorf("failed to get branch reference: %w", err) + } + defer func() { _ = resp.Body.Close() }() + return ref, nil +} + +func getCommitWithValidation(ctx context.Context, client *github.Client, owner, repo, sha string) (*github.Commit, *github.Response, error) { + commit, resp, err := client.Git.GetCommit(ctx, owner, repo, sha) + if err != nil { + return nil, resp, err + } + + if resp.StatusCode != http.StatusOK { + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return nil, resp, fmt.Errorf(ErrReadResponseBody, readErr) + } + return nil, resp, fmt.Errorf("failed to get commit: %s", string(body)) + } + + return commit, resp, nil +} + +func createTreeForDeletion(ctx context.Context, client *github.Client, owner, repo, path, baseTreeSHA string) (*github.Tree, *github.Response, error) { + treeEntries := []*github.TreeEntry{ + { + Path: github.Ptr(path), + Mode: github.Ptr("100644"), + Type: github.Ptr("blob"), + SHA: nil, + }, + } + + newTree, resp, err := client.Git.CreateTree(ctx, owner, repo, baseTreeSHA, treeEntries) + if err != nil { + return nil, resp, err + } + + if resp.StatusCode != http.StatusCreated { + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return nil, resp, fmt.Errorf(ErrReadResponseBody, readErr) + } + return nil, resp, fmt.Errorf("failed to create tree: %s", string(body)) + } + + return newTree, resp, nil +} + +func createCommitWithValidation(ctx context.Context, client *github.Client, owner, repo, message string, tree *github.Tree, parentSHA string) (*github.Commit, *github.Response, error) { + commit := &github.Commit{ + Message: github.Ptr(message), + Tree: tree, + Parents: []*github.Commit{{SHA: github.Ptr(parentSHA)}}, + } + + newCommit, resp, err := client.Git.CreateCommit(ctx, owner, repo, commit, nil) + if err != nil { + return nil, resp, err + } + + if resp.StatusCode != http.StatusCreated { + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return nil, resp, fmt.Errorf(ErrReadResponseBody, readErr) + } + return nil, resp, fmt.Errorf("failed to create commit: %s", string(body)) + } + + return newCommit, resp, nil +} + +func updateReferenceWithValidation(ctx context.Context, client *github.Client, owner, repo string, ref *github.Reference) (*github.Response, error) { + _, resp, err := client.Git.UpdateRef(ctx, owner, repo, ref, false) + if err != nil { + return resp, err + } + + if resp.StatusCode != http.StatusOK { + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return resp, fmt.Errorf(ErrReadResponseBody, readErr) + } + return resp, fmt.Errorf("failed to update reference: %s", string(body)) + } + + return resp, nil +} + // both of which suit an LLM well. func DeleteFile(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("delete_file", @@ -748,108 +860,39 @@ func DeleteFile(getClient GetClientFn, t translations.TranslationHelperFunc) (to client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } - // Get the reference for the branch - ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+branch) + ref, err := getBranchReference(ctx, client, owner, repo, branch) if err != nil { - return nil, fmt.Errorf("failed to get branch reference: %w", err) + return nil, err } - defer func() { _ = resp.Body.Close() }() - // Get the commit object that the branch points to - baseCommit, resp, err := client.Git.GetCommit(ctx, owner, repo, *ref.Object.SHA) + baseCommit, resp, err := getCommitWithValidation(ctx, client, owner, repo, *ref.Object.SHA) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get base commit", - resp, - err, - ), nil + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get base commit", resp, err), nil } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - return mcp.NewToolResultError(fmt.Sprintf("failed to get commit: %s", string(body))), nil - } - - // Create a tree entry for the file deletion by setting SHA to nil - treeEntries := []*github.TreeEntry{ - { - Path: github.Ptr(path), - Mode: github.Ptr("100644"), // Regular file mode - Type: github.Ptr("blob"), - SHA: nil, // Setting SHA to nil deletes the file - }, - } - - // Create a new tree with the deletion - newTree, resp, err := client.Git.CreateTree(ctx, owner, repo, *baseCommit.Tree.SHA, treeEntries) + newTree, resp, err := createTreeForDeletion(ctx, client, owner, repo, path, *baseCommit.Tree.SHA) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to create tree", - resp, - err, - ), nil + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to create tree", resp, err), nil } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusCreated { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - return mcp.NewToolResultError(fmt.Sprintf("failed to create tree: %s", string(body))), nil - } - - // Create a new commit with the new tree - commit := &github.Commit{ - Message: github.Ptr(message), - Tree: newTree, - Parents: []*github.Commit{{SHA: baseCommit.SHA}}, - } - newCommit, resp, err := client.Git.CreateCommit(ctx, owner, repo, commit, nil) + newCommit, resp, err := createCommitWithValidation(ctx, client, owner, repo, message, newTree, *baseCommit.SHA) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to create commit", - resp, - err, - ), nil + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to create commit", resp, err), nil } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusCreated { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - return mcp.NewToolResultError(fmt.Sprintf("failed to create commit: %s", string(body))), nil - } - - // Update the branch reference to point to the new commit ref.Object.SHA = newCommit.SHA - _, resp, err = client.Git.UpdateRef(ctx, owner, repo, ref, false) + resp, err = updateReferenceWithValidation(ctx, client, owner, repo, ref) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to update reference", - resp, - err, - ), nil + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to update reference", resp, err), nil } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - return mcp.NewToolResultError(fmt.Sprintf("failed to update reference: %s", string(body))), nil - } - // Create a response similar to what the DeleteFile API would return response := map[string]interface{}{ "commit": newCommit, @@ -858,7 +901,7 @@ func DeleteFile(getClient GetClientFn, t translations.TranslationHelperFunc) (to r, err := json.Marshal(response) if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, fmt.Errorf(ErrMarshalResponse, err) } return mcp.NewToolResultText(string(r)), nil @@ -959,7 +1002,7 @@ func CreateBranch(getClient GetClientFn, t translations.TranslationHelperFunc) ( r, err := json.Marshal(createdRef) if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, fmt.Errorf(ErrMarshalResponse, err) } return mcp.NewToolResultText(string(r)), nil @@ -1131,7 +1174,7 @@ func PushFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (too r, err := json.Marshal(updatedRef) if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, fmt.Errorf(ErrMarshalResponse, err) } return mcp.NewToolResultText(string(r)), nil @@ -1200,7 +1243,7 @@ func ListTags(getClient GetClientFn, t translations.TranslationHelperFunc) (tool r, err := json.Marshal(tags) if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, fmt.Errorf(ErrMarshalResponse, err) } return mcp.NewToolResultText(string(r)), nil @@ -1287,7 +1330,7 @@ func GetTag(getClient GetClientFn, t translations.TranslationHelperFunc) (tool m r, err := json.Marshal(tagObj) if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, fmt.Errorf(ErrMarshalResponse, err) } return mcp.NewToolResultText(string(r)), nil diff --git a/pkg/github/repository_resource.go b/pkg/github/repository_resource.go index a454db630..3a2be9f5a 100644 --- a/pkg/github/repository_resource.go +++ b/pkg/github/repository_resource.go @@ -113,7 +113,7 @@ func RepositoryResourceContentsHandler(getClient GetClientFn, getRawClient raw.G // fetch the PR from the API to get the latest commit and use SHA githubClient, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } prNum, err := strconv.Atoi(prNumber[0]) if err != nil { diff --git a/pkg/github/search.go b/pkg/github/search.go index 5106b84d8..9371f5c6b 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -2,11 +2,8 @@ package github import ( "context" - "encoding/json" "fmt" - "io" - ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v72/github" "github.com/mark3labs/mcp-go/mcp" @@ -46,32 +43,18 @@ func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperF client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } result, resp, err := client.Search.Repositories(ctx, query, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to search repositories with query '%s'", query), - resp, - err, - ), nil - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + if errResult, hasError := handleAPIResponse(ctx, resp, err, fmt.Sprintf("failed to search repositories with query '%s'", query)); hasError { + if errResult != nil { + return errResult, nil } - return mcp.NewToolResultError(fmt.Sprintf("failed to search repositories: %s", string(body))), nil - } - - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return nil, fmt.Errorf(ErrReadResponseBody, err) } - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(result) } } @@ -125,33 +108,19 @@ func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (to client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } result, resp, err := client.Search.Code(ctx, query, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to search code with query '%s'", query), - resp, - err, - ), nil - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + if errResult, hasError := handleAPIResponse(ctx, resp, err, fmt.Sprintf("failed to search code with query '%s'", query)); hasError { + if errResult != nil { + return errResult, nil } - return mcp.NewToolResultError(fmt.Sprintf("failed to search code: %s", string(body))), nil + return nil, fmt.Errorf(ErrReadResponseBody, err) } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(result) } } @@ -198,26 +167,17 @@ func userOrOrgHandler(accountType string, getClient GetClientFn) server.ToolHand client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } searchQuery := "type:" + accountType + " " + query result, resp, err := client.Search.Users(ctx, searchQuery, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to search %ss with query '%s'", accountType, query), - resp, - err, - ), nil - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + if errResult, hasError := handleAPIResponse(ctx, resp, err, fmt.Sprintf("failed to search %ss with query '%s'", accountType, query)); hasError { + if errResult != nil { + return errResult, nil } - return mcp.NewToolResultError(fmt.Sprintf("failed to search %ss: %s", accountType, string(body))), nil + return nil, fmt.Errorf(ErrReadResponseBody, err) } minimalUsers := make([]MinimalUser, 0, len(result.Users)) @@ -249,11 +209,7 @@ func userOrOrgHandler(accountType string, getClient GetClientFn) server.ToolHand minimalResp.IncompleteResults = *result.IncompleteResults } - r, err := json.Marshal(minimalResp) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(minimalResp) } } diff --git a/pkg/github/search_utils.go b/pkg/github/search_utils.go index 6642dad8f..0ec8b99dc 100644 --- a/pkg/github/search_utils.go +++ b/pkg/github/search_utils.go @@ -63,7 +63,7 @@ func searchHandler( client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("%s: failed to get GitHub client: %w", errorPrefix, err) + return nil, fmt.Errorf("%s: "+ErrGetGitHubClient, errorPrefix, err) } result, resp, err := client.Search.Issues(ctx, query, opts) if err != nil { @@ -74,14 +74,14 @@ func searchHandler( if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("%s: failed to read response body: %w", errorPrefix, err) + return nil, fmt.Errorf("%s: "+ErrReadResponseBody, errorPrefix, err) } return mcp.NewToolResultError(fmt.Sprintf("%s: %s", errorPrefix, string(body))), nil } r, err := json.Marshal(result) if err != nil { - return nil, fmt.Errorf("%s: failed to marshal response: %w", errorPrefix, err) + return nil, fmt.Errorf("%s: "+ErrMarshalResponse, errorPrefix, err) } return mcp.NewToolResultText(string(r)), nil diff --git a/pkg/github/secret_scanning.go b/pkg/github/secret_scanning.go index bea6df2ae..e51b65b34 100644 --- a/pkg/github/secret_scanning.go +++ b/pkg/github/secret_scanning.go @@ -2,12 +2,8 @@ package github import ( "context" - "encoding/json" "fmt" - "io" - "net/http" - ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v72/github" "github.com/mark3labs/mcp-go/mcp" @@ -51,33 +47,19 @@ func GetSecretScanningAlert(getClient GetClientFn, t translations.TranslationHel client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } alert, resp, err := client.SecretScanning.GetAlert(ctx, owner, repo, int64(alertNumber)) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to get alert with number '%d'", alertNumber), - resp, - err, - ), nil - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + if errResult, hasError := handleAPIResponse(ctx, resp, err, fmt.Sprintf("failed to get alert with number '%d'", alertNumber)); hasError { + if errResult != nil { + return errResult, nil } - return mcp.NewToolResultError(fmt.Sprintf("failed to get alert: %s", string(body))), nil - } - - r, err := json.Marshal(alert) - if err != nil { - return nil, fmt.Errorf("failed to marshal alert: %w", err) + return nil, fmt.Errorf(ErrReadResponseBody, err) } - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(alert) } } @@ -133,31 +115,17 @@ func ListSecretScanningAlerts(getClient GetClientFn, t translations.TranslationH client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return nil, fmt.Errorf(ErrGetGitHubClient, err) } alerts, resp, err := client.SecretScanning.ListAlertsForRepo(ctx, owner, repo, &github.SecretScanningAlertListOptions{State: state, SecretType: secretType, Resolution: resolution}) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to list alerts for repository '%s/%s'", owner, repo), - resp, - err, - ), nil - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + if errResult, hasError := handleAPIResponse(ctx, resp, err, fmt.Sprintf("failed to list alerts for repository '%s/%s'", owner, repo)); hasError { + if errResult != nil { + return errResult, nil } - return mcp.NewToolResultError(fmt.Sprintf("failed to list alerts: %s", string(body))), nil - } - - r, err := json.Marshal(alerts) - if err != nil { - return nil, fmt.Errorf("failed to marshal alerts: %w", err) + return nil, fmt.Errorf(ErrReadResponseBody, err) } - return mcp.NewToolResultText(string(r)), nil + return marshalAndReturn(alerts) } } diff --git a/pkg/github/server.go b/pkg/github/server.go index 85d078f1b..a88f05ac2 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -1,17 +1,83 @@ package github import ( + "context" "encoding/json" "errors" "fmt" + "io" + "net/http" + ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/google/go-github/v72/github" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" ) +const ( + DescriptionPerPage = "Number of results per page" + DescriptionPage = "Page number for pagination" +) + // NewServer creates a new GitHub MCP server with the specified GH client and logger. +func handleAPIResponse(ctx context.Context, resp *github.Response, err error, operation string) (*mcp.CallToolResult, bool) { + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, operation, resp, err), true + } + if resp == nil { + return nil, false + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return nil, true + } + return mcp.NewToolResultError(fmt.Sprintf("%s: %s", operation, string(body))), true + } + return nil, false +} + +func marshalAndReturn(result interface{}) (*mcp.CallToolResult, error) { + r, err := json.Marshal(result) + if err != nil { + return nil, fmt.Errorf(ErrMarshalResponse, err) + } + return mcp.NewToolResultText(string(r)), nil +} + +func withOwnerParam() mcp.ToolOption { + return mcp.WithString("owner", mcp.Required(), mcp.Description(DescriptionRepositoryOwner)) +} + +func withRepoParam() mcp.ToolOption { + return mcp.WithString("repo", mcp.Required(), mcp.Description(DescriptionRepositoryName)) +} + +func withPerPageParam() mcp.ToolOption { + return mcp.WithNumber("per_page", mcp.Description(DescriptionPerPage)) +} + +func withPageParam() mcp.ToolOption { + return mcp.WithNumber("page", mcp.Description(DescriptionPage)) +} + +func parseOwnerRepo(request mcp.CallToolRequest) (owner, repo string, result *mcp.CallToolResult) { + owner, err := RequiredParam[string](request, "owner") + if err != nil { + return "", "", mcp.NewToolResultError(err.Error()) + } + + repo, err = RequiredParam[string](request, "repo") + if err != nil { + return "", "", mcp.NewToolResultError(err.Error()) + } + + return owner, repo, nil +} + func NewServer(version string, opts ...server.ServerOption) *server.MCPServer { // Add default options defaultOpts := []server.ServerOption{