diff --git a/internal/osvdev/config.go b/internal/osvdev/config.go index c501fad13b..f6fee4ec39 100644 --- a/internal/osvdev/config.go +++ b/internal/osvdev/config.go @@ -6,9 +6,11 @@ type ClientConfig struct { MaxConcurrentRequests int MaxConcurrentBatchRequests int - MaxRetryAttempts int - JitterMultiplier float64 - UserAgent string + MaxRetryAttempts int + JitterMultiplier float64 + BackoffDurationExponential float64 + BackoffDurationMultiplier float64 + UserAgent string } // Default make a default client config @@ -16,6 +18,8 @@ func DefaultConfig() ClientConfig { return ClientConfig{ MaxRetryAttempts: 4, JitterMultiplier: 2, + BackoffDurationExponential: 2, + BackoffDurationMultiplier: 1, UserAgent: "osv-scanner-v2-" + version.OSVVersion, MaxConcurrentRequests: 1000, MaxConcurrentBatchRequests: 10, diff --git a/internal/osvdev/osvdev.go b/internal/osvdev/osvdev.go index bf4a81a06d..51efcd472c 100644 --- a/internal/osvdev/osvdev.go +++ b/internal/osvdev/osvdev.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "io" + "math" "math/rand/v2" "net/http" "time" @@ -27,7 +28,7 @@ const ( ) type OSVClient struct { - HttpClient http.Client + HTTPClient *http.Client Config ClientConfig BaseHostURL string } @@ -35,7 +36,7 @@ type OSVClient struct { // DefaultClient() creates a new OSVClient with default settings func DefaultClient() *OSVClient { return &OSVClient{ - HttpClient: http.Client{}, + HTTPClient: http.DefaultClient, Config: DefaultConfig(), BaseHostURL: "https://api.osv.dev", } @@ -43,7 +44,7 @@ func DefaultClient() *OSVClient { // GetVulnsByID is an interface to this endpoint: https://google.github.io/osv.dev/get-v1-vulns/ func (c *OSVClient) GetVulnsByID(ctx context.Context, id string) (*models.Vulnerability, error) { - resp, err := c.makeRetryRequest(func() (*http.Response, error) { + resp, err := c.makeRetryRequest(func(client *http.Client) (*http.Response, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.BaseHostURL+GetEndpoint+"/"+id, nil) if err != nil { return nil, err @@ -52,7 +53,7 @@ func (c *OSVClient) GetVulnsByID(ctx context.Context, id string) (*models.Vulner req.Header.Set("User-Agent", c.Config.UserAgent) } - return c.HttpClient.Do(req) + return client.Do(req) }) if err != nil { @@ -92,7 +93,7 @@ func (c *OSVClient) QueryBatch(ctx context.Context, queries []*Query) (*BatchedR return nil } - resp, err := c.makeRetryRequest(func() (*http.Response, error) { + resp, err := c.makeRetryRequest(func(client *http.Client) (*http.Response, error) { // Make sure request buffer is inside retry, if outside // http request would finish the buffer, and retried requests would be empty requestBuf := bytes.NewBuffer(requestBytes) @@ -105,7 +106,7 @@ func (c *OSVClient) QueryBatch(ctx context.Context, queries []*Query) (*BatchedR req.Header.Set("User-Agent", c.Config.UserAgent) } - return c.HttpClient.Do(req) + return client.Do(req) }) if err != nil { return err @@ -147,7 +148,7 @@ func (c *OSVClient) Query(ctx context.Context, query *Query) (*Response, error) return nil, err } - resp, err := c.makeRetryRequest(func() (*http.Response, error) { + resp, err := c.makeRetryRequest(func(client *http.Client) (*http.Response, error) { requestBuf := bytes.NewBuffer(requestBytes) req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.BaseHostURL+QueryEndpoint, requestBuf) if err != nil { @@ -159,7 +160,7 @@ func (c *OSVClient) Query(ctx context.Context, query *Query) (*Response, error) req.Header.Set("User-Agent", c.Config.UserAgent) } - return c.HttpClient.Do(req) + return client.Do(req) }) if err != nil { @@ -184,7 +185,7 @@ func (c *OSVClient) ExperimentalDetermineVersion(ctx context.Context, query *Det return nil, err } - resp, err := c.makeRetryRequest(func() (*http.Response, error) { + resp, err := c.makeRetryRequest(func(client *http.Client) (*http.Response, error) { // Make sure request buffer is inside retry, if outside // http request would finish the buffer, and retried requests would be empty requestBuf := bytes.NewBuffer(requestBytes) @@ -197,7 +198,7 @@ func (c *OSVClient) ExperimentalDetermineVersion(ctx context.Context, query *Det req.Header.Set("User-Agent", c.Config.UserAgent) } - return http.DefaultClient.Do(req) + return client.Do(req) }) if err != nil { @@ -215,7 +216,7 @@ func (c *OSVClient) ExperimentalDetermineVersion(ctx context.Context, query *Det } // makeRetryRequest will return an error on both network errors, and if the response is not 200 -func (c *OSVClient) makeRetryRequest(action func() (*http.Response, error)) (*http.Response, error) { +func (c *OSVClient) makeRetryRequest(action func(client *http.Client) (*http.Response, error)) (*http.Response, error) { var resp *http.Response var err error var lastErr error @@ -225,9 +226,11 @@ func (c *OSVClient) makeRetryRequest(action func() (*http.Response, error)) (*ht // we do not need to use a cryptographically secure random jitter, this is just to spread out the retry requests // #nosec G404 jitterAmount := (rand.Float64() * float64(c.Config.JitterMultiplier) * float64(i)) - time.Sleep(time.Duration(i*i)*time.Second + time.Duration(jitterAmount*1000)*time.Millisecond) + time.Sleep( + time.Duration(math.Pow(float64(i), c.Config.BackoffDurationExponential)*c.Config.BackoffDurationMultiplier*1000)*time.Millisecond + + time.Duration(jitterAmount*1000)*time.Millisecond) - resp, err = action() + resp, err = action(c.HTTPClient) // The network request itself failed, did not even get a response if err != nil { @@ -263,7 +266,7 @@ func (c *OSVClient) makeRetryRequest(action func() (*http.Response, error)) (*ht lastErr = fmt.Errorf("server error: status=%q body=%s", resp.Status, errBody) } - return nil, lastErr + return nil, fmt.Errorf("max retries exceeded: %w", lastErr) } // From: https://stackoverflow.com/a/72408490 diff --git a/internal/osvdev/osvdev_internal_test.go b/internal/osvdev/osvdev_internal_test.go new file mode 100644 index 0000000000..d81b379925 --- /dev/null +++ b/internal/osvdev/osvdev_internal_test.go @@ -0,0 +1,109 @@ +package osvdev + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/google/osv-scalibr/testing/extracttest" +) + +func TestMakeRetryRequest(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + statusCodes []int + wantErr error + wantAttempts int + }{ + { + name: "success on first attempt", + statusCodes: []int{http.StatusOK}, + wantAttempts: 1, + }, + { + name: "client error no retry", + statusCodes: []int{http.StatusBadRequest}, + wantErr: extracttest.ContainsErrStr{ + Str: "client error: status=\"400 Bad Request\"", + }, + wantAttempts: 1, + }, + { + name: "server error then success", + statusCodes: []int{http.StatusInternalServerError, http.StatusOK}, + wantAttempts: 2, + }, + { + name: "max retries on server error", + statusCodes: []int{http.StatusInternalServerError, http.StatusInternalServerError, http.StatusInternalServerError, http.StatusInternalServerError}, + wantErr: extracttest.ContainsErrStr{ + Str: "max retries exceeded", + }, + wantAttempts: 4, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + client := DefaultClient() + // Low multipliers to make the test run faster + client.Config.JitterMultiplier = 0 + client.Config.BackoffDurationMultiplier = 0 + client.Config.MaxRetryAttempts = 4 + client.HTTPClient = &http.Client{Timeout: time.Second} + + attempts := 0 + idx := 0 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + attempts++ + status := tt.statusCodes[idx] + if idx < len(tt.statusCodes)-1 { + idx++ + } + + w.WriteHeader(status) + message := fmt.Sprintf("response-%d", attempts) + _, _ = w.Write([]byte(message)) + })) + defer server.Close() + + resp, err := client.makeRetryRequest(func(hc *http.Client) (*http.Response, error) { + //nolint:noctx // because this is test code + return hc.Get(server.URL) + }) + + if attempts != tt.wantAttempts { + t.Errorf("got %d attempts, want %d", attempts, tt.wantAttempts) + } + + if diff := cmp.Diff(tt.wantErr, err, cmpopts.EquateErrors()); diff != "" { + t.Fatalf("Unexpected error (-want +got):\n%s", diff) + } + + if err != nil { + return + } + + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } + + expectedBody := fmt.Sprintf("response-%d", attempts) + if string(body) != expectedBody { + t.Errorf("got body %q, want %q", string(body), expectedBody) + } + }) + } +} diff --git a/internal/osvdev/osvdev_test.go b/internal/osvdev/osvdev_test.go index d605e1ac3d..8069e23b22 100644 --- a/internal/osvdev/osvdev_test.go +++ b/internal/osvdev/osvdev_test.go @@ -2,10 +2,11 @@ package osvdev_test import ( "context" - "strings" "testing" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/google/osv-scalibr/testing/extracttest" "github.com/google/osv-scanner/internal/osvdev" "github.com/ossf/osv-schema/bindings/go/osvschema" ) @@ -14,23 +15,27 @@ func TestOSVClient_GetVulnsByID(t *testing.T) { t.Parallel() tests := []struct { - name string - id string - wantErrContains string + name string + id string + wantErr error }{ { name: "Simple ID lookup", id: "GO-2024-3333", }, { - name: "Missing ID lookup", - id: "GO-1000-1000", - wantErrContains: `client error: status="404 Not Found" body={"code":5,"message":"Bug not found."}`, + name: "Missing ID lookup", + id: "GO-1000-1000", + wantErr: extracttest.ContainsErrStr{ + Str: `client error: status="404 Not Found" body={"code":5,"message":"Bug not found."}`, + }, }, { - name: "Invalid ID", - id: "_--_--", - wantErrContains: `client error: status="404 Not Found" body={"code":5,"message":"Bug not found."}`, + name: "Invalid ID", + id: "_--_--", + wantErr: extracttest.ContainsErrStr{ + Str: `client error: status="404 Not Found" body={"code":5,"message":"Bug not found."}`, + }, }, } for _, tt := range tests { @@ -39,13 +44,17 @@ func TestOSVClient_GetVulnsByID(t *testing.T) { c := osvdev.DefaultClient() c.Config.UserAgent = "osv-scanner-api-test" + got, err := c.GetVulnsByID(context.Background(), tt.id) + + if diff := cmp.Diff(tt.wantErr, err, cmpopts.EquateErrors()); diff != "" { + t.Fatalf("Unexpected error (-want +got):\n%s", diff) + } + if err != nil { - if tt.wantErrContains == "" || !strings.Contains(err.Error(), tt.wantErrContains) { - t.Errorf("OSVClient.GetVulnsByID() error = %v, wantErr %q", err, tt.wantErrContains) - } return } + if got.ID != tt.id { t.Errorf("OSVClient.GetVulnsByID() = %v, want %v", got, tt.id) } @@ -54,11 +63,13 @@ func TestOSVClient_GetVulnsByID(t *testing.T) { } func TestOSVClient_QueryBatch(t *testing.T) { + t.Parallel() + tests := []struct { - name string - queries []*osvdev.Query - wantIDs [][]string - wantErrContains string + name string + queries []*osvdev.Query + wantIDs [][]string + wantErr error }{ { name: "multiple queries lookup", @@ -108,8 +119,10 @@ func TestOSVClient_QueryBatch(t *testing.T) { }, }, }, - wantIDs: [][]string{}, - wantErrContains: `client error: status="400 Bad Request" body={"code":3,"message":"Invalid query."}`, + wantIDs: [][]string{}, + wantErr: extracttest.ContainsErrStr{ + Str: `client error: status="400 Bad Request" body={"code":3,"message":"Invalid query."}`, + }, }, } @@ -119,11 +132,14 @@ func TestOSVClient_QueryBatch(t *testing.T) { c := osvdev.DefaultClient() c.Config.UserAgent = "osv-scanner-api-test" + got, err := c.QueryBatch(context.Background(), tt.queries) + + if diff := cmp.Diff(tt.wantErr, err, cmpopts.EquateErrors()); diff != "" { + t.Fatalf("Unexpected error (-want +got):\n%s", diff) + } + if err != nil { - if tt.wantErrContains == "" || !strings.Contains(err.Error(), tt.wantErrContains) { - t.Errorf("OSVClient.GetVulnsByID() error = %v, wantErr %q", err, tt.wantErrContains) - } return } @@ -147,10 +163,10 @@ func TestOSVClient_Query(t *testing.T) { t.Parallel() tests := []struct { - name string - query osvdev.Query - wantIDs []string - wantErrContains string + name string + query osvdev.Query + wantIDs []string + wantErr error }{ { name: "npm Package lookup", @@ -193,8 +209,9 @@ func TestOSVClient_Query(t *testing.T) { Name: "abcd-definitely-does-not-exist", }, }, - wantErrContains: `client error: status="400 Bad Request" body={"code":3,"message":"Invalid query."}`, - }, + wantErr: extracttest.ContainsErrStr{ + Str: `client error: status="400 Bad Request" body={"code":3,"message":"Invalid query."}`, + }}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -202,11 +219,14 @@ func TestOSVClient_Query(t *testing.T) { c := osvdev.DefaultClient() c.Config.UserAgent = "osv-scanner-api-test" + got, err := c.Query(context.Background(), &tt.query) + + if diff := cmp.Diff(tt.wantErr, err, cmpopts.EquateErrors()); diff != "" { + t.Fatalf("Unexpected error (-want +got):\n%s", diff) + } + if err != nil { - if tt.wantErrContains == "" || !strings.Contains(err.Error(), tt.wantErrContains) { - t.Errorf("OSVClient.GetVulnsByID() error = %v, wantErr %q", err, tt.wantErrContains) - } return } @@ -226,10 +246,9 @@ func TestOSVClient_ExperimentalDetermineVersion(t *testing.T) { t.Parallel() tests := []struct { - name string - query osvdev.DetermineVersionsRequest - wantPkgs []string - wantErrContains string + name string + query osvdev.DetermineVersionsRequest + wantPkgs []string }{ { name: "Simple non existent package query", @@ -252,12 +271,10 @@ func TestOSVClient_ExperimentalDetermineVersion(t *testing.T) { c := osvdev.DefaultClient() c.Config.UserAgent = "osv-scanner-api-test" + got, err := c.ExperimentalDetermineVersion(context.Background(), &tt.query) if err != nil { - if tt.wantErrContains == "" || !strings.Contains(err.Error(), tt.wantErrContains) { - t.Errorf("OSVClient.GetVulnsByID() error = %v, wantErr %q", err, tt.wantErrContains) - } - return + t.Fatalf("Unexpected error %v", err) } gotPkgInfo := make([]string, 0, len(got.Matches))