Skip to content

Commit

Permalink
Unify tests
Browse files Browse the repository at this point in the history
  • Loading branch information
another-rex committed Dec 23, 2024
1 parent 6cca258 commit 2ec6518
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 56 deletions.
10 changes: 7 additions & 3 deletions internal/osvdev/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,20 @@ type ClientConfig struct {
MaxConcurrentRequests int
MaxConcurrentBatchRequests int

MaxRetryAttempts int
JitterMultiplier float64
UserAgent string
MaxRetryAttempts int
JitterMultiplier float64
BackoffDurationExponential float64
BackoffDurationMultiplier float64
UserAgent string
}

// Default make a default client config
func DefaultConfig() ClientConfig {
return ClientConfig{
MaxRetryAttempts: 4,
JitterMultiplier: 2,
BackoffDurationExponential: 2,
BackoffDurationMultiplier: 1,
UserAgent: "osv-scanner-v2-" + version.OSVVersion,
MaxConcurrentRequests: 1000,
MaxConcurrentBatchRequests: 10,
Expand Down
31 changes: 17 additions & 14 deletions internal/osvdev/osvdev.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"io"
"math"
"math/rand/v2"
"net/http"
"time"
Expand All @@ -27,23 +28,23 @@ const (
)

type OSVClient struct {
HttpClient http.Client
HTTPClient *http.Client
Config ClientConfig
BaseHostURL string
}

// DefaultClient() creates a new OSVClient with default settings
func DefaultClient() *OSVClient {
return &OSVClient{
HttpClient: http.Client{},
HTTPClient: http.DefaultClient,
Config: DefaultConfig(),
BaseHostURL: "https://api.osv.dev",
}
}

// GetVulnsByID is an interface to this endpoint: https://google.github.io/osv.dev/get-v1-vulns/
func (c *OSVClient) GetVulnsByID(ctx context.Context, id string) (*models.Vulnerability, error) {
resp, err := c.makeRetryRequest(func() (*http.Response, error) {
resp, err := c.makeRetryRequest(func(client *http.Client) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.BaseHostURL+GetEndpoint+"/"+id, nil)
if err != nil {
return nil, err
Expand All @@ -52,7 +53,7 @@ func (c *OSVClient) GetVulnsByID(ctx context.Context, id string) (*models.Vulner
req.Header.Set("User-Agent", c.Config.UserAgent)
}

return c.HttpClient.Do(req)
return client.Do(req)
})

if err != nil {
Expand Down Expand Up @@ -92,7 +93,7 @@ func (c *OSVClient) QueryBatch(ctx context.Context, queries []*Query) (*BatchedR
return nil
}

resp, err := c.makeRetryRequest(func() (*http.Response, error) {
resp, err := c.makeRetryRequest(func(client *http.Client) (*http.Response, error) {
// Make sure request buffer is inside retry, if outside
// http request would finish the buffer, and retried requests would be empty
requestBuf := bytes.NewBuffer(requestBytes)
Expand All @@ -105,7 +106,7 @@ func (c *OSVClient) QueryBatch(ctx context.Context, queries []*Query) (*BatchedR
req.Header.Set("User-Agent", c.Config.UserAgent)
}

return c.HttpClient.Do(req)
return client.Do(req)
})
if err != nil {
return err
Expand Down Expand Up @@ -147,7 +148,7 @@ func (c *OSVClient) Query(ctx context.Context, query *Query) (*Response, error)
return nil, err
}

resp, err := c.makeRetryRequest(func() (*http.Response, error) {
resp, err := c.makeRetryRequest(func(client *http.Client) (*http.Response, error) {
requestBuf := bytes.NewBuffer(requestBytes)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.BaseHostURL+QueryEndpoint, requestBuf)
if err != nil {
Expand All @@ -159,7 +160,7 @@ func (c *OSVClient) Query(ctx context.Context, query *Query) (*Response, error)
req.Header.Set("User-Agent", c.Config.UserAgent)
}

return c.HttpClient.Do(req)
return client.Do(req)
})

if err != nil {
Expand All @@ -184,7 +185,7 @@ func (c *OSVClient) ExperimentalDetermineVersion(ctx context.Context, query *Det
return nil, err
}

resp, err := c.makeRetryRequest(func() (*http.Response, error) {
resp, err := c.makeRetryRequest(func(client *http.Client) (*http.Response, error) {
// Make sure request buffer is inside retry, if outside
// http request would finish the buffer, and retried requests would be empty
requestBuf := bytes.NewBuffer(requestBytes)
Expand All @@ -197,7 +198,7 @@ func (c *OSVClient) ExperimentalDetermineVersion(ctx context.Context, query *Det
req.Header.Set("User-Agent", c.Config.UserAgent)
}

return http.DefaultClient.Do(req)
return client.Do(req)
})

if err != nil {
Expand All @@ -215,7 +216,7 @@ func (c *OSVClient) ExperimentalDetermineVersion(ctx context.Context, query *Det
}

// makeRetryRequest will return an error on both network errors, and if the response is not 200
func (c *OSVClient) makeRetryRequest(action func() (*http.Response, error)) (*http.Response, error) {
func (c *OSVClient) makeRetryRequest(action func(client *http.Client) (*http.Response, error)) (*http.Response, error) {
var resp *http.Response
var err error
var lastErr error
Expand All @@ -225,9 +226,11 @@ func (c *OSVClient) makeRetryRequest(action func() (*http.Response, error)) (*ht
// we do not need to use a cryptographically secure random jitter, this is just to spread out the retry requests
// #nosec G404
jitterAmount := (rand.Float64() * float64(c.Config.JitterMultiplier) * float64(i))
time.Sleep(time.Duration(i*i)*time.Second + time.Duration(jitterAmount*1000)*time.Millisecond)
time.Sleep(
time.Duration(math.Pow(float64(i), c.Config.BackoffDurationExponential)*c.Config.BackoffDurationMultiplier*1000)*time.Millisecond +
time.Duration(jitterAmount*1000)*time.Millisecond)

resp, err = action()
resp, err = action(c.HTTPClient)

// The network request itself failed, did not even get a response
if err != nil {
Expand Down Expand Up @@ -263,7 +266,7 @@ func (c *OSVClient) makeRetryRequest(action func() (*http.Response, error)) (*ht
lastErr = fmt.Errorf("server error: status=%q body=%s", resp.Status, errBody)
}

return nil, lastErr
return nil, fmt.Errorf("max retries exceeded: %w", lastErr)
}

// From: https://stackoverflow.com/a/72408490
Expand Down
109 changes: 109 additions & 0 deletions internal/osvdev/osvdev_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package osvdev

import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/osv-scalibr/testing/extracttest"
)

func TestMakeRetryRequest(t *testing.T) {
t.Parallel()

tests := []struct {
name string
statusCodes []int
wantErr error
wantAttempts int
}{
{
name: "success on first attempt",
statusCodes: []int{http.StatusOK},
wantAttempts: 1,
},
{
name: "client error no retry",
statusCodes: []int{http.StatusBadRequest},
wantErr: extracttest.ContainsErrStr{
Str: "client error: status=\"400 Bad Request\"",
},
wantAttempts: 1,
},
{
name: "server error then success",
statusCodes: []int{http.StatusInternalServerError, http.StatusOK},
wantAttempts: 2,
},
{
name: "max retries on server error",
statusCodes: []int{http.StatusInternalServerError, http.StatusInternalServerError, http.StatusInternalServerError, http.StatusInternalServerError},
wantErr: extracttest.ContainsErrStr{
Str: "max retries exceeded",
},
wantAttempts: 4,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

client := DefaultClient()
// Low multipliers to make the test run faster
client.Config.JitterMultiplier = 0
client.Config.BackoffDurationMultiplier = 0
client.Config.MaxRetryAttempts = 4
client.HTTPClient = &http.Client{Timeout: time.Second}

attempts := 0
idx := 0

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
attempts++
status := tt.statusCodes[idx]
if idx < len(tt.statusCodes)-1 {
idx++
}

w.WriteHeader(status)
message := fmt.Sprintf("response-%d", attempts)
_, _ = w.Write([]byte(message))
}))
defer server.Close()

resp, err := client.makeRetryRequest(func(hc *http.Client) (*http.Response, error) {
//nolint:noctx // because this is test code
return hc.Get(server.URL)
})

if attempts != tt.wantAttempts {
t.Errorf("got %d attempts, want %d", attempts, tt.wantAttempts)
}

if diff := cmp.Diff(tt.wantErr, err, cmpopts.EquateErrors()); diff != "" {
t.Fatalf("Unexpected error (-want +got):\n%s", diff)
}

if err != nil {
return
}

defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response body: %v", err)
}

expectedBody := fmt.Sprintf("response-%d", attempts)
if string(body) != expectedBody {
t.Errorf("got body %q, want %q", string(body), expectedBody)
}
})
}
}
Loading

0 comments on commit 2ec6518

Please sign in to comment.