Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 82 additions & 73 deletions pkg/github/repositories.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ import (
"github.com/mark3labs/mcp-go/server"
)

const (
errFailedToGetGitHubClient = "failed to get GitHub client: %w"
)

func GetCommit(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("get_commit",
mcp.WithDescription(t("TOOL_GET_COMMITS_DESCRIPTION", "Get details for a commit from a GitHub repository")),
Expand Down Expand Up @@ -65,7 +69,7 @@ func GetCommit(getClient GetClientFn, t translations.TranslationHelperFunc) (too

client, err := getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
return nil, fmt.Errorf(errFailedToGetGitHubClient, err)
}
commit, resp, err := client.Repositories.GetCommit(ctx, owner, repo, sha, opts)
if err != nil {
Expand Down Expand Up @@ -94,6 +98,72 @@ func GetCommit(getClient GetClientFn, t translations.TranslationHelperFunc) (too
}
}

func listCommitsHandler(ctx context.Context, request mcp.CallToolRequest, getClient GetClientFn) (*mcp.CallToolResult, error) {
owner, err := RequiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
repo, err := RequiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
sha, err := OptionalParam[string](request, "sha")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
author, err := OptionalParam[string](request, "author")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
pagination, err := OptionalPaginationParams(request)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

// Set default perPage to 30 if not provided
perPage := pagination.perPage
if perPage == 0 {
perPage = 30
}
opts := &github.CommitsListOptions{
SHA: sha,
Author: author,
ListOptions: github.ListOptions{
Page: pagination.page,
PerPage: perPage,
},
}

client, err := getClient(ctx)
if err != nil {
return nil, fmt.Errorf(errFailedToGetGitHubClient, err)
}
commits, resp, err := client.Repositories.ListCommits(ctx, owner, repo, opts)
if err != nil {
return ghErrors.NewGitHubAPIErrorResponse(ctx,
fmt.Sprintf("failed to list commits: %s", sha),
resp,
err,
), nil
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != 200 {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return mcp.NewToolResultError(fmt.Sprintf("failed to list commits: %s", string(body))), nil
}

r, err := json.Marshal(commits)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
}

// ListCommits creates a tool to get commits of a branch in a repository.
func ListCommits(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("list_commits",
Expand All @@ -119,68 +189,7 @@ func ListCommits(getClient GetClientFn, t translations.TranslationHelperFunc) (t
WithPagination(),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
owner, err := RequiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
repo, err := RequiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
sha, err := OptionalParam[string](request, "sha")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
author, err := OptionalParam[string](request, "author")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
pagination, err := OptionalPaginationParams(request)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
// Set default perPage to 30 if not provided
perPage := pagination.perPage
if perPage == 0 {
perPage = 30
}
opts := &github.CommitsListOptions{
SHA: sha,
Author: author,
ListOptions: github.ListOptions{
Page: pagination.page,
PerPage: perPage,
},
}

client, err := getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
}
commits, resp, err := client.Repositories.ListCommits(ctx, owner, repo, opts)
if err != nil {
return ghErrors.NewGitHubAPIErrorResponse(ctx,
fmt.Sprintf("failed to list commits: %s", sha),
resp,
err,
), nil
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != 200 {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return mcp.NewToolResultError(fmt.Sprintf("failed to list commits: %s", string(body))), nil
}

r, err := json.Marshal(commits)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
return listCommitsHandler(ctx, request, getClient)
}
}

Expand Down Expand Up @@ -225,7 +234,7 @@ func ListBranches(getClient GetClientFn, t translations.TranslationHelperFunc) (

client, err := getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
return nil, fmt.Errorf(errFailedToGetGitHubClient, err)
}

branches, resp, err := client.Repositories.ListBranches(ctx, owner, repo, opts)
Expand Down Expand Up @@ -339,7 +348,7 @@ func CreateOrUpdateFile(getClient GetClientFn, t translations.TranslationHelperF
// Create or update the file
client, err := getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
return nil, fmt.Errorf(errFailedToGetGitHubClient, err)
}
fileContent, resp, err := client.Repositories.CreateFile(ctx, owner, repo, path, opts)
if err != nil {
Expand Down Expand Up @@ -417,7 +426,7 @@ func CreateRepository(getClient GetClientFn, t translations.TranslationHelperFun

client, err := getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
return nil, fmt.Errorf(errFailedToGetGitHubClient, err)
}
createdRepo, resp, err := client.Repositories.Create(ctx, "", repo)
if err != nil {
Expand Down Expand Up @@ -503,7 +512,7 @@ func GetFileContents(getClient GetClientFn, getRawClient raw.GetRawClientFn, t t
// fetch the PR from the API to get the latest commit and use SHA
githubClient, err := getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
return nil, fmt.Errorf(errFailedToGetGitHubClient, err)
}
prNum, err := strconv.Atoi(prNumber)
if err != nil {
Expand Down Expand Up @@ -655,7 +664,7 @@ func ForkRepository(getClient GetClientFn, t translations.TranslationHelperFunc)

client, err := getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
return nil, fmt.Errorf(errFailedToGetGitHubClient, err)
}
forkedRepo, resp, err := client.Repositories.CreateFork(ctx, owner, repo, opts)
if err != nil {
Expand Down Expand Up @@ -748,7 +757,7 @@ func DeleteFile(getClient GetClientFn, t translations.TranslationHelperFunc) (to

client, err := getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
return nil, fmt.Errorf(errFailedToGetGitHubClient, err)
}

// Get the reference for the branch
Expand Down Expand Up @@ -909,7 +918,7 @@ func CreateBranch(getClient GetClientFn, t translations.TranslationHelperFunc) (

client, err := getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
return nil, fmt.Errorf(errFailedToGetGitHubClient, err)
}

// Get the source branch SHA
Expand Down Expand Up @@ -1037,7 +1046,7 @@ func PushFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (too

client, err := getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
return nil, fmt.Errorf(errFailedToGetGitHubClient, err)
}

// Get the reference for the branch
Expand Down Expand Up @@ -1177,7 +1186,7 @@ func ListTags(getClient GetClientFn, t translations.TranslationHelperFunc) (tool

client, err := getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
return nil, fmt.Errorf(errFailedToGetGitHubClient, err)
}

tags, resp, err := client.Repositories.ListTags(ctx, owner, repo, opts)
Expand Down Expand Up @@ -1244,7 +1253,7 @@ func GetTag(getClient GetClientFn, t translations.TranslationHelperFunc) (tool m

client, err := getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
return nil, fmt.Errorf(errFailedToGetGitHubClient, err)
}

// First get the tag reference
Expand Down