Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 24 additions & 90 deletions pkg/github/actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,17 @@ import (
const (
DescriptionRepositoryOwner = "Repository owner"
DescriptionRepositoryName = "Repository name"
errFailedToMarshalResponse = "failed to marshal response: %w"
)

func marshalToToolResult(data interface{}) (*mcp.CallToolResult, error) {
r, err := json.Marshal(data)
if err != nil {
return nil, fmt.Errorf(errFailedToMarshalResponse, err)
}
return mcp.NewToolResultText(string(r)), nil
}

// ListWorkflows creates a tool to list workflows in a repository
func ListWorkflows(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("list_workflows",
Expand Down Expand Up @@ -81,12 +90,7 @@ func ListWorkflows(getClient GetClientFn, t translations.TranslationHelperFunc)
}
defer func() { _ = resp.Body.Close() }()

r, err := json.Marshal(workflows)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
return marshalToToolResult(workflows)
}
}

Expand Down Expand Up @@ -229,12 +233,7 @@ func ListWorkflowRuns(getClient GetClientFn, t translations.TranslationHelperFun
}
defer func() { _ = resp.Body.Close() }()

r, err := json.Marshal(workflowRuns)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
return marshalToToolResult(workflowRuns)
}
}

Expand Down Expand Up @@ -328,12 +327,7 @@ func RunWorkflow(getClient GetClientFn, t translations.TranslationHelperFunc) (t
"status_code": resp.StatusCode,
}

r, err := json.Marshal(result)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
return marshalToToolResult(result)
}
}

Expand Down Expand Up @@ -384,12 +378,7 @@ func GetWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFunc)
}
defer func() { _ = resp.Body.Close() }()

r, err := json.Marshal(workflowRun)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
return marshalToToolResult(workflowRun)
}
}

Expand Down Expand Up @@ -450,12 +439,7 @@ func GetWorkflowRunLogs(getClient GetClientFn, t translations.TranslationHelperF
"optimization_tip": "Use: get_job_logs with parameters {run_id: " + fmt.Sprintf("%d", runID) + ", failed_only: true} for more efficient failed job debugging",
}

r, err := json.Marshal(result)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
return marshalToToolResult(result)
}
}

Expand Down Expand Up @@ -547,12 +531,7 @@ func ListWorkflowJobs(getClient GetClientFn, t translations.TranslationHelperFun
"optimization_tip": "For debugging failed jobs, consider using get_job_logs with failed_only=true and run_id=" + fmt.Sprintf("%d", runID) + " to get logs directly without needing to list jobs first",
}

r, err := json.Marshal(response)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
return marshalToToolResult(response)
}
}

Expand Down Expand Up @@ -707,12 +686,7 @@ func handleFailedJobLogs(ctx context.Context, client *github.Client, owner, repo
"return_format": map[string]bool{"content": returnContent, "urls": !returnContent},
}

r, err := json.Marshal(result)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
return marshalToToolResult(result)
}

// handleSingleJobLogs gets logs for a single job
Expand All @@ -722,12 +696,7 @@ func handleSingleJobLogs(ctx context.Context, client *github.Client, owner, repo
return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get job logs", resp, err), nil
}

r, err := json.Marshal(jobResult)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
return marshalToToolResult(jobResult)
}

// getJobLogData retrieves log data for a single job, either as URL or content
Expand Down Expand Up @@ -867,12 +836,7 @@ func RerunWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFun
"status_code": resp.StatusCode,
}

r, err := json.Marshal(result)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
return marshalToToolResult(result)
}
}

Expand Down Expand Up @@ -930,12 +894,7 @@ func RerunFailedJobs(getClient GetClientFn, t translations.TranslationHelperFunc
"status_code": resp.StatusCode,
}

r, err := json.Marshal(result)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
return marshalToToolResult(result)
}
}

Expand Down Expand Up @@ -993,12 +952,7 @@ func CancelWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFu
"status_code": resp.StatusCode,
}

r, err := json.Marshal(result)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
return marshalToToolResult(result)
}
}

Expand Down Expand Up @@ -1071,12 +1025,7 @@ func ListWorkflowRunArtifacts(getClient GetClientFn, t translations.TranslationH
}
defer func() { _ = resp.Body.Close() }()

r, err := json.Marshal(artifacts)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
return marshalToToolResult(artifacts)
}
}

Expand Down Expand Up @@ -1136,12 +1085,7 @@ func DownloadWorkflowRunArtifact(getClient GetClientFn, t translations.Translati
"artifact_id": artifactID,
}

r, err := json.Marshal(result)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
return marshalToToolResult(result)
}
}

Expand Down Expand Up @@ -1200,12 +1144,7 @@ func DeleteWorkflowRunLogs(getClient GetClientFn, t translations.TranslationHelp
"status_code": resp.StatusCode,
}

r, err := json.Marshal(result)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
return marshalToToolResult(result)
}
}

Expand Down Expand Up @@ -1256,11 +1195,6 @@ func GetWorkflowRunUsage(getClient GetClientFn, t translations.TranslationHelper
}
defer func() { _ = resp.Body.Close() }()

r, err := json.Marshal(usage)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
return marshalToToolResult(usage)
}
}
47 changes: 31 additions & 16 deletions pkg/github/notifications.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"net/http"
"strconv"
"strings"
"time"

ghErrors "github.com/github/github-mcp-server/pkg/errors"
Expand All @@ -22,6 +23,29 @@ const (
FilterOnlyParticipating = "only_participating"
)

func parseThreadID(threadID string) (int64, error) {
threadIDInt, err := strconv.ParseInt(threadID, 10, 64)
if err != nil {
return 0, fmt.Errorf("invalid threadID format: %v", err)
}
return threadIDInt, nil
}

func markNotificationByState(ctx context.Context, client *github.Client, threadID string, state string) (*github.Response, error) {
switch state {
case "done":
threadIDInt, err := parseThreadID(threadID)
if err != nil {
return nil, err
}
return client.Activity.MarkThreadDone(ctx, threadIDInt)
case "read":
return client.Activity.MarkThreadRead(ctx, threadID)
default:
return nil, fmt.Errorf("invalid state; must be one of: read, done")
}
}

// ListNotifications creates a tool to list notifications for the current user.
func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("list_notifications",
Expand Down Expand Up @@ -175,23 +199,14 @@ func DismissNotification(getclient GetClientFn, t translations.TranslationHelper
return mcp.NewToolResultError(err.Error()), nil
}

var resp *github.Response
switch state {
case "done":
// for some inexplicable reason, the API seems to have threadID as int64 and string depending on the endpoint
var threadIDInt int64
threadIDInt, err = strconv.ParseInt(threadID, 10, 64)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("invalid threadID format: %v", err)), nil
}
resp, err = client.Activity.MarkThreadDone(ctx, threadIDInt)
case "read":
resp, err = client.Activity.MarkThreadRead(ctx, threadID)
default:
return mcp.NewToolResultError("Invalid state. Must be one of: read, done."), nil
}

resp, err := markNotificationByState(ctx, client, threadID, state)
if err != nil {
if strings.Contains(err.Error(), "invalid threadID format") {
return mcp.NewToolResultError(err.Error()), nil
}
if strings.Contains(err.Error(), "invalid state") {
return mcp.NewToolResultError(err.Error()), nil
}
return ghErrors.NewGitHubAPIErrorResponse(ctx,
fmt.Sprintf("failed to mark notification as %s", state),
resp,
Expand Down
2 changes: 1 addition & 1 deletion pkg/github/notifications_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ func Test_DismissNotification(t *testing.T) {
case tc.name == "invalid threadID format":
assert.Contains(t, text, "invalid threadID format")
case tc.name == "invalid state value":
assert.Contains(t, text, "Invalid state. Must be one of: read, done.")
assert.Contains(t, text, "invalid state; must be one of: read, done")
default:
// fallback for other errors
assert.Contains(t, text, "error")
Expand Down