Skip to content

Commit

Permalink
Modify GetBranch to handle redirects (#1901)
Browse files Browse the repository at this point in the history
Fixes #1895.
  • Loading branch information
n1lesh authored Jul 8, 2021
1 parent 72cc2f6 commit c9fec82
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 15 deletions.
10 changes: 10 additions & 0 deletions github/actions_artifacts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package github

import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
Expand Down Expand Up @@ -294,6 +295,15 @@ func TestActionsSerivice_DownloadArtifact(t *testing.T) {
_, _, err = client.Actions.DownloadArtifact(ctx, "\n", "\n", -1, true)
return err
})

// Add custom round tripper
client.client.Transport = roundTripperFunc(func(r *http.Request) (*http.Response, error) {
return nil, errors.New("failed to download artifact")
})
testBadOptions(t, methodName, func() (err error) {
_, _, err = client.Actions.DownloadArtifact(ctx, "o", "r", 1, true)
return err
})
}

func TestActionsService_DownloadArtifact_invalidOwner(t *testing.T) {
Expand Down
10 changes: 10 additions & 0 deletions github/actions_workflow_jobs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package github

import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
Expand Down Expand Up @@ -154,6 +155,15 @@ func TestActionsService_GetWorkflowJobLogs(t *testing.T) {
_, _, err = client.Actions.GetWorkflowJobLogs(ctx, "\n", "\n", 399444496, true)
return err
})

// Add custom round tripper
client.client.Transport = roundTripperFunc(func(r *http.Request) (*http.Response, error) {
return nil, errors.New("failed to get workflow logs")
})
testBadOptions(t, methodName, func() (err error) {
_, _, err = client.Actions.GetWorkflowJobLogs(ctx, "o", "r", 399444496, true)
return err
})
}

func TestActionsService_GetWorkflowJobLogs_StatusMovedPermanently_dontFollowRedirects(t *testing.T) {
Expand Down
7 changes: 7 additions & 0 deletions github/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1964,3 +1964,10 @@ func TestBareDo_returnsOpenBody(t *testing.T) {
t.Fatalf("resp.Body.Close() returned error: %v", err)
}
}

// roundTripperFunc creates a mock RoundTripper (transport)
type roundTripperFunc func(*http.Request) (*http.Response, error)

func (fn roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return fn(r)
}
40 changes: 35 additions & 5 deletions github/repos.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
)

Expand Down Expand Up @@ -935,20 +936,49 @@ func (s *RepositoriesService) ListBranches(ctx context.Context, owner string, re
// GetBranch gets the specified branch for a repository.
//
// GitHub API docs: https://docs.github.com/en/free-pro-team@latest/rest/reference/repos/#get-a-branch
func (s *RepositoriesService) GetBranch(ctx context.Context, owner, repo, branch string) (*Branch, *Response, error) {
func (s *RepositoriesService) GetBranch(ctx context.Context, owner, repo, branch string, followRedirects bool) (*Branch, *Response, error) {
u := fmt.Sprintf("repos/%v/%v/branches/%v", owner, repo, branch)
req, err := s.client.NewRequest("GET", u, nil)

resp, err := s.getBranchFromURL(ctx, u, followRedirects)
if err != nil {
return nil, nil, err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return nil, newResponse(resp), fmt.Errorf("unexpected status code: %s", resp.Status)
}

b := new(Branch)
resp, err := s.client.Do(ctx, req, b)
err = json.NewDecoder(resp.Body).Decode(b)
return b, newResponse(resp), err
}

func (s *RepositoriesService) getBranchFromURL(ctx context.Context, u string, followRedirects bool) (*http.Response, error) {
req, err := s.client.NewRequest("GET", u, nil)
if err != nil {
return nil, resp, err
return nil, err
}

var resp *http.Response
// Use http.DefaultTransport if no custom Transport is configured
req = withContext(ctx, req)
if s.client.client.Transport == nil {
resp, err = http.DefaultTransport.RoundTrip(req)
} else {
resp, err = s.client.client.Transport.RoundTrip(req)
}
if err != nil {
return nil, err
}

return b, resp, nil
// If redirect response is returned, follow it
if followRedirects && resp.StatusCode == http.StatusMovedPermanently {
resp.Body.Close()
u = resp.Header.Get("Location")
resp, err = s.getBranchFromURL(ctx, u, false)
}
return resp, err
}

// GetBranchProtection gets the protection of a given branch.
Expand Down
10 changes: 10 additions & 0 deletions github/repos_contents_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package github

import (
"context"
"errors"
"fmt"
"io/ioutil"
"net/http"
Expand Down Expand Up @@ -690,6 +691,15 @@ func TestRepositoriesService_GetArchiveLink(t *testing.T) {
_, _, err = client.Repositories.GetArchiveLink(ctx, "\n", "\n", Tarball, &RepositoryContentGetOptions{}, true)
return err
})

// Add custom round tripper
client.client.Transport = roundTripperFunc(func(r *http.Request) (*http.Response, error) {
return nil, errors.New("failed to get archive link")
})
testBadOptions(t, methodName, func() (err error) {
_, _, err = client.Repositories.GetArchiveLink(ctx, "o", "r", Tarball, &RepositoryContentGetOptions{}, true)
return err
})
}

func TestRepositoriesService_GetArchiveLink_StatusMovedPermanently_dontFollowRedirects(t *testing.T) {
Expand Down
76 changes: 68 additions & 8 deletions github/repos_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ package github
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"testing"

Expand Down Expand Up @@ -886,7 +888,7 @@ func TestRepositoriesService_GetBranch(t *testing.T) {
})

ctx := context.Background()
branch, _, err := client.Repositories.GetBranch(ctx, "o", "r", "b")
branch, _, err := client.Repositories.GetBranch(ctx, "o", "r", "b", false)
if err != nil {
t.Errorf("Repositories.GetBranch returned error: %v", err)
}
Expand All @@ -908,16 +910,74 @@ func TestRepositoriesService_GetBranch(t *testing.T) {

const methodName = "GetBranch"
testBadOptions(t, methodName, func() (err error) {
_, _, err = client.Repositories.GetBranch(ctx, "\n", "\n", "\n")
_, _, err = client.Repositories.GetBranch(ctx, "\n", "\n", "\n", false)
return err
})
}

testNewRequestAndDoFailure(t, methodName, client, func() (*Response, error) {
got, resp, err := client.Repositories.GetBranch(ctx, "o", "r", "b")
if got != nil {
t.Errorf("testNewRequestAndDoFailure %v = %#v, want nil", methodName, got)
}
return resp, err
func TestRepositoriesService_GetBranch_StatusMovedPermanently_followRedirects(t *testing.T) {
client, mux, serverURL, teardown := setup()
defer teardown()

mux.HandleFunc("/repos/o/r/branches/b", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "GET")
redirectURL, _ := url.Parse(serverURL + baseURLPath + "/repos/o/r/branches/br")
http.Redirect(w, r, redirectURL.String(), http.StatusMovedPermanently)
})
mux.HandleFunc("/repos/o/r/branches/br", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "GET")
fmt.Fprint(w, `{"name":"n", "commit":{"sha":"s","commit":{"message":"m"}}, "protected":true}`)
})
ctx := context.Background()
branch, resp, err := client.Repositories.GetBranch(ctx, "o", "r", "b", true)
if err != nil {
t.Errorf("Repositories.GetBranch returned error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Errorf("Repositories.GetBranch returned status: %d, want %d", resp.StatusCode, http.StatusOK)
}

want := &Branch{
Name: String("n"),
Commit: &RepositoryCommit{
SHA: String("s"),
Commit: &Commit{
Message: String("m"),
},
},
Protected: Bool(true),
}
if !cmp.Equal(branch, want) {
t.Errorf("Repositories.GetBranch returned %+v, want %+v", branch, want)
}
}

func TestRepositoriesService_GetBranch_notFound(t *testing.T) {
client, mux, _, teardown := setup()
defer teardown()

mux.HandleFunc("/repos/o/r/branches/b", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "GET")
http.Error(w, "branch not found", http.StatusNotFound)
})
ctx := context.Background()
_, resp, err := client.Repositories.GetBranch(ctx, "o", "r", "b", true)
if err == nil {
t.Error("Repositories.GetBranch returned error: nil")
}
if resp.StatusCode != http.StatusNotFound {
t.Errorf("Repositories.GetBranch returned status: %d, want %d", resp.StatusCode, http.StatusNotFound)
}

// Add custom round tripper
client.client.Transport = roundTripperFunc(func(r *http.Request) (*http.Response, error) {
return nil, errors.New("failed to get branch")
})

const methodName = "GetBranch"
testBadOptions(t, methodName, func() (err error) {
_, _, err = client.Repositories.GetBranch(ctx, "o", "r", "b", true)
return err
})
}

Expand Down
4 changes: 2 additions & 2 deletions test/integration/repos_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func TestRepositories_BranchesTags(t *testing.T) {
t.Fatalf("Repositories.ListBranches('git', 'git') returned no branches")
}

_, _, err = client.Repositories.GetBranch(context.Background(), "git", "git", *branches[0].Name)
_, _, err = client.Repositories.GetBranch(context.Background(), "git", "git", *branches[0].Name, false)
if err != nil {
t.Fatalf("Repositories.GetBranch() returned error: %v", err)
}
Expand Down Expand Up @@ -102,7 +102,7 @@ func TestRepositories_EditBranches(t *testing.T) {
t.Fatalf("createRandomTestRepository returned error: %v", err)
}

branch, _, err := client.Repositories.GetBranch(context.Background(), *repo.Owner.Login, *repo.Name, "master")
branch, _, err := client.Repositories.GetBranch(context.Background(), *repo.Owner.Login, *repo.Name, "master", false)
if err != nil {
t.Fatalf("Repositories.GetBranch() returned error: %v", err)
}
Expand Down

0 comments on commit c9fec82

Please sign in to comment.