diff --git a/go.mod b/go.mod index 4fa673cc96..82af8bbd5c 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/go-ozzo/ozzo-validation v3.6.0+incompatible github.com/go-playground/validator/v10 v10.22.1 github.com/go-test/deep v1.1.1 + github.com/gofri/go-github-ratelimit v1.1.0 github.com/golang-jwt/jwt/v5 v5.2.1 github.com/google/go-github/v65 v65.0.0 github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 diff --git a/go.sum b/go.sum index 8613b6dde2..140153e53f 100644 --- a/go.sum +++ b/go.sum @@ -163,6 +163,8 @@ github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4 github.com/go-test/deep v1.0.4/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U= github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= +github.com/gofri/go-github-ratelimit v1.1.0 h1:ijQ2bcv5pjZXNil5FiwglCg8wc9s8EgjTmNkqjw8nuk= +github.com/gofri/go-github-ratelimit v1.1.0/go.mod h1:OnCi5gV+hAG/LMR7llGhU7yHt44se9sYgKPnafoL7RY= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-jwt/jwt/v4 v4.5.1 h1:JdqV9zKUdtaa9gdPlywC3aeoEsR681PlKC+4F5gQgeo= diff --git a/server/events/vcs/github_client.go b/server/events/vcs/github_client.go index d9c3d90541..8fdf84f74f 100644 --- a/server/events/vcs/github_client.go +++ b/server/events/vcs/github_client.go @@ -25,6 +25,7 @@ import ( "strings" "time" + "github.com/gofri/go-github-ratelimit/github_ratelimit" "github.com/google/go-github/v65/github" "github.com/pkg/errors" "github.com/runatlantis/atlantis/server/events/command" @@ -124,15 +125,20 @@ func NewGithubClient(hostname string, credentials GithubCredentials, config Gith return nil, errors.Wrap(err, "error initializing github authentication transport") } + transportWithRateLimit, err := github_ratelimit.NewRateLimitWaiterClient(transport.Transport, github_ratelimit.WithTotalSleepLimit(time.Minute, nil)) + if err != nil { + return nil, errors.Wrap(err, "error initializing github rate limit transport") + } + var graphqlURL string var client *github.Client if hostname == "github.com" { - client = github.NewClient(transport) + client = github.NewClient(transportWithRateLimit) graphqlURL = "https://api.github.com/graphql" } else { apiURL := resolveGithubAPIURL(hostname) // TODO: Deprecated: Use NewClient(httpClient).WithEnterpriseURLs(baseURL, uploadURL) instead - client, err = github.NewEnterpriseClient(apiURL.String(), apiURL.String(), transport) //nolint:staticcheck + client, err = github.NewEnterpriseClient(apiURL.String(), apiURL.String(), transportWithRateLimit) //nolint:staticcheck if err != nil { return nil, err } @@ -140,7 +146,7 @@ func NewGithubClient(hostname string, credentials GithubCredentials, config Gith } // Use the client from shurcooL's githubv4 library for queries. - v4Client := githubv4.NewEnterpriseClient(graphqlURL, transport) + v4Client := githubv4.NewEnterpriseClient(graphqlURL, transportWithRateLimit) user, err := credentials.GetUser() logger.Debug("GH User: %s", user) diff --git a/server/events/vcs/github_client_test.go b/server/events/vcs/github_client_test.go index 81ec7ee7a4..637cbae207 100644 --- a/server/events/vcs/github_client_test.go +++ b/server/events/vcs/github_client_test.go @@ -11,6 +11,7 @@ import ( "os" "strings" "testing" + "time" "github.com/runatlantis/atlantis/server/events/command" "github.com/runatlantis/atlantis/server/events/models" @@ -1707,3 +1708,54 @@ func TestGithubClient_GetPullLabels_EmptyResponse(t *testing.T) { Ok(t, err) Equals(t, 0, len(labels)) } + +func TestGithubClient_SecondaryRateLimitHandling_CreateComment(t *testing.T) { + logger := logging.NewNoopLogger(t) + calls := 0 + maxCalls := 2 + + testServer := httptest.NewTLSServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/api/v3/repos/owner/repo/issues/1/comments" { + t.Errorf("Unexpected request: %s %s", r.Method, r.URL.Path) + w.WriteHeader(http.StatusNotFound) + return + } + + if calls < maxCalls { + // Secondary rate limiting, x-ratelimit-remaining must be > 0 + w.Header().Set("x-ratelimit-remaining", "1") + w.Header().Set("x-ratelimit-reset", fmt.Sprintf("%d", time.Now().Unix()+1)) + w.WriteHeader(http.StatusForbidden) + w.Write([]byte(`{"message": "You have exceeded a secondary rate limit"}`)) // nolint: errcheck + } else { + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{"id": 1, "body": "Test comment"}`)) // nolint: errcheck + } + calls++ + }), + ) + + testServerURL, err := url.Parse(testServer.URL) + Ok(t, err) + + client, err := vcs.NewGithubClient(testServerURL.Host, &vcs.GithubUserCredentials{"user", "pass"}, vcs.GithubConfig{}, 0, logger) + Ok(t, err) + defer disableSSLVerification()() + + // Simulate creating a comment + repo := models.Repo{ + FullName: "owner/repo", + Owner: "owner", + Name: "repo", + } + pullNum := 1 + comment := "Test comment" + + err = client.CreateComment(logger, repo, pullNum, comment, "") + Ok(t, err) + + // Verify that the number of calls is greater than maxCalls, indicating that retries occurred + Assert(t, calls > maxCalls, "Expected more than %d calls due to rate limiting, but got %d", maxCalls, calls) + +}