Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,12 @@ func parseRate(r *http.Response) Rate {
return rate
}

type requestContext uint8

const (
bypassRateLimitCheck requestContext = iota
)

// BareDo sends an API request and lets you handle the api response. If an error
// or API Error occurs, the error will contain more information. Otherwise you
// are supposed to read and close the response's Body. If rate limit is exceeded
Expand All @@ -538,12 +544,14 @@ func (c *Client) BareDo(ctx context.Context, req *http.Request) (*Response, erro

rateLimitCategory := category(req.URL.Path)

// If we've hit rate limit, don't make further requests before Reset time.
if err := c.checkRateLimitBeforeDo(req, rateLimitCategory); err != nil {
return &Response{
Response: err.Response,
Rate: err.Rate,
}, err
if bypass := ctx.Value(bypassRateLimitCheck); bypass == nil {
// If we've hit rate limit, don't make further requests before Reset time.
if err := c.checkRateLimitBeforeDo(req, rateLimitCategory); err != nil {
return &Response{
Response: err.Response,
Rate: err.Rate,
}, err
}
}

resp, err := c.client.Do(req)
Expand Down Expand Up @@ -1022,6 +1030,9 @@ func (c *Client) RateLimits(ctx context.Context) (*RateLimits, *Response, error)
response := new(struct {
Resources *RateLimits `json:"resources"`
})

// This resource is not subject to rate limits.
ctx = context.WithValue(ctx, bypassRateLimitCheck, true)
resp, err := c.Do(ctx, req, response)
if err != nil {
return nil, resp, err
Expand Down
49 changes: 49 additions & 0 deletions github/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ func testNewRequestAndDoFailure(t *testing.T, methodName string, client *Client,
client.BaseURL.Path = "/api-v3/"
client.rateLimits[0].Reset.Time = time.Now().Add(10 * time.Minute)
resp, err = f()
if bypass := resp.Request.Context().Value(bypassRateLimitCheck); bypass != nil {
return
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm concerned that we are not checking the value of err here... might there be a case where we get an error that will be ignored here?

Copy link
Contributor Author

@sa-spag sa-spag Jun 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This helper explicitly sets the client up to trigger errors, and at this stage a rate limit error. As is, the test would fail if we check for err != nil for methodName == "RateLimits" as the request would result in a HTTP 404, since no handler has been created. We can set one up on TestRateLimits_coverage but I think it makes the test less straightforward. Client.RateLimits rate limit check bypass behavior is already covered by the newly introduced TestRateLimits_overQuota.

}
if want := http.StatusForbidden; resp == nil || resp.Response.StatusCode != want {
if resp != nil {
t.Errorf("rate.Reset.Time > now %v resp = %#v, want StatusCode=%v", methodName, resp.Response, want)
Expand Down Expand Up @@ -1711,6 +1714,52 @@ func TestRateLimits_coverage(t *testing.T) {
})
}

func TestRateLimits_overQuota(t *testing.T) {
client, mux, _, teardown := setup()
defer teardown()

client.rateLimits[coreCategory] = Rate{
Limit: 1,
Remaining: 0,
Reset: Timestamp{time.Now().Add(time.Hour).Local()},
}
mux.HandleFunc("/rate_limit", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, `{"resources":{
"core": {"limit":2,"remaining":1,"reset":1372700873},
"search": {"limit":3,"remaining":2,"reset":1372700874}
}}`)
})

ctx := context.Background()
rate, _, err := client.RateLimits(ctx)
if err != nil {
t.Errorf("RateLimits returned error: %v", err)
}

want := &RateLimits{
Core: &Rate{
Limit: 2,
Remaining: 1,
Reset: Timestamp{time.Date(2013, time.July, 1, 17, 47, 53, 0, time.UTC).Local()},
},
Search: &Rate{
Limit: 3,
Remaining: 2,
Reset: Timestamp{time.Date(2013, time.July, 1, 17, 47, 54, 0, time.UTC).Local()},
},
}
if !cmp.Equal(rate, want) {
t.Errorf("RateLimits returned %+v, want %+v", rate, want)
}

if got, want := client.rateLimits[coreCategory], *want.Core; got != want {
t.Errorf("client.rateLimits[coreCategory] is %+v, want %+v", got, want)
}
if got, want := client.rateLimits[searchCategory], *want.Search; got != want {
t.Errorf("client.rateLimits[searchCategory] is %+v, want %+v", got, want)
}
}

func TestSetCredentialsAsHeaders(t *testing.T) {
req := new(http.Request)
id, secret := "id", "secret"
Expand Down