Skip to content

Commit

Permalink
feat: enhance commit retrieval with branch & tag prefix support
Browse files Browse the repository at this point in the history
  • Loading branch information
marcsanmi committed Aug 26, 2024
1 parent d657e80 commit 0f4cf2c
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 2 deletions.
34 changes: 32 additions & 2 deletions pkg/querier/vcs/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ type Service struct {
httpClient *http.Client
}

type gitHubCommitGetter interface {
GetCommit(context.Context, string, string, string) (*vcsv1.GetCommitResponse, error)
}

func New(logger log.Logger, reg prometheus.Registerer) *Service {
httpClient := client.InstrumentedHTTPClient(logger, reg)

Expand Down Expand Up @@ -189,13 +193,39 @@ func (q *Service) GetCommit(ctx context.Context, req *connect.Request[vcsv1.GetC
return nil, err
}

commit, err := ghClient.GetCommit(ctx, gitURL.GetOwnerName(), gitURL.GetRepoName(), req.Msg.Ref)
owner := gitURL.GetOwnerName()
repo := gitURL.GetRepoName()
ref := req.Msg.GetRef()

commit, err := q.tryGetCommit(ctx, ghClient, owner, repo, ref)
if err != nil {
return nil, err
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("failed to get commit: %v", err))
}

return connect.NewResponse(commit), nil
}

// tryGetCommit attempts to retrieve a commit using different ref formats (commit hash, branch, tag).
// It tries each format in order and returns the first successful result.
func (q *Service) tryGetCommit(ctx context.Context, client gitHubCommitGetter, owner, repo, ref string) (*vcsv1.GetCommitResponse, error) {
refFormats := []string{
ref, // Try as a commit hash
"heads/" + ref, // Try as a branch
"tags/" + ref, // Try as a tag
}

for _, format := range refFormats {
commit, err := client.GetCommit(ctx, owner, repo, format)
if err == nil {
return commit, nil
}

q.logger.Log("err", err, "msg", "Failed to get commit", "ref", format)
}

return nil, fmt.Errorf("no commit found for %s/%s", owner, repo)
}

func rejectExpiredToken(token *oauth2.Token) error {
if time.Now().After(token.Expiry) {
return connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("token is expired"))
Expand Down
100 changes: 100 additions & 0 deletions pkg/querier/vcs/service_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package vcs

import (
"context"
"testing"

"github.com/go-kit/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"

vcsv1 "github.com/grafana/pyroscope/api/gen/proto/go/vcs/v1"
)

type gitHubCommitGetterMock struct {
mock.Mock
}

func (m *gitHubCommitGetterMock) GetCommit(ctx context.Context, owner, repo, ref string) (*vcsv1.GetCommitResponse, error) {
args := m.Called(ctx, owner, repo, ref)
if args.Get(0) == nil {
return nil, args.Error(1)
}

return args.Get(0).(*vcsv1.GetCommitResponse), args.Error(1)
}

func TestTryGetCommit(t *testing.T) {
svc := Service{logger: log.NewNopLogger()}

tests := []struct {
name string
setupMock func(*gitHubCommitGetterMock)
ref string
wantCommit *vcsv1.GetCommitResponse
wantErr bool
}{
{
name: "Direct commit hash",
setupMock: func(m *gitHubCommitGetterMock) {
m.On("GetCommit", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(&vcsv1.GetCommitResponse{Sha: "abc123"}, nil)
},
ref: "",
wantCommit: &vcsv1.GetCommitResponse{Sha: "abc123"},
wantErr: false,
},
{
name: "Branch reference with 'heads/' prefix",
setupMock: func(m *gitHubCommitGetterMock) {
m.On("GetCommit", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(nil, assert.AnError).Times(1)
m.On("GetCommit", mock.Anything, mock.Anything, mock.Anything, "heads/main").
Return(&vcsv1.GetCommitResponse{Sha: "def456"}, nil).Times(1)
},
ref: "main",
wantCommit: &vcsv1.GetCommitResponse{Sha: "def456"},
wantErr: false,
},
{
name: "Tag reference with 'tags/' prefix",
setupMock: func(m *gitHubCommitGetterMock) {
m.On("GetCommit", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(nil, assert.AnError).Times(2)
m.On("GetCommit", mock.Anything, mock.Anything, mock.Anything, "tags/v1").
Return(&vcsv1.GetCommitResponse{Sha: "def456"}, nil).Times(1)
},
ref: "v1",
wantCommit: &vcsv1.GetCommitResponse{Sha: "def456"},
wantErr: false,
},
{
name: "Nonexistent reference",
setupMock: func(m *gitHubCommitGetterMock) {
m.On("GetCommit", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(nil, assert.AnError).Times(3)
},
ref: "nonexistent",
wantCommit: nil,
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockGetter := new(gitHubCommitGetterMock)
tt.setupMock(mockGetter)

gotCommit, err := svc.tryGetCommit(context.Background(), mockGetter, "owner", "repo", tt.ref)

if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}

assert.Equal(t, tt.wantCommit, gotCommit)
mockGetter.AssertExpectations(t)
})
}
}

0 comments on commit 0f4cf2c

Please sign in to comment.