diff --git a/internal/osvdev/config.go b/internal/osvdev/config.go new file mode 100644 index 0000000000..0a8fd76397 --- /dev/null +++ b/internal/osvdev/config.go @@ -0,0 +1,27 @@ +package osvdev + +import "github.com/google/osv-scanner/internal/version" + +type ClientConfig struct { + MaxConcurrentRequests int + MaxConcurrentBatchRequests int + + MaxRetryAttempts int + JitterMultiplier float64 + BackoffDurationExponential float64 + BackoffDurationMultiplier float64 + UserAgent string +} + +// DefaultConfig make a default client config +func DefaultConfig() ClientConfig { + return ClientConfig{ + MaxRetryAttempts: 4, + JitterMultiplier: 2, + BackoffDurationExponential: 2, + BackoffDurationMultiplier: 1, + UserAgent: "osv-scanner/" + version.OSVVersion, + MaxConcurrentRequests: 1000, + MaxConcurrentBatchRequests: 10, + } +} diff --git a/internal/osvdev/models.go b/internal/osvdev/models.go new file mode 100644 index 0000000000..7f4563cb27 --- /dev/null +++ b/internal/osvdev/models.go @@ -0,0 +1,73 @@ +package osvdev + +import "github.com/google/osv-scanner/pkg/models" + +// Package represents a package identifier for OSV. +type Package struct { + PURL string `json:"purl,omitempty"` + Name string `json:"name,omitempty"` + Ecosystem string `json:"ecosystem,omitempty"` +} + +// Query represents a query to OSV. +type Query struct { + Commit string `json:"commit,omitempty"` + Package Package `json:"package,omitempty"` + Version string `json:"version,omitempty"` +} + +// BatchedQuery represents a batched query to OSV. +type BatchedQuery struct { + Queries []*Query `json:"queries"` +} + +// MinimalVulnerability represents an unhydrated vulnerability entry from OSV. +type MinimalVulnerability struct { + ID string `json:"id"` +} + +// Response represents a full response from OSV. +type Response struct { + Vulns []models.Vulnerability `json:"vulns"` +} + +// MinimalResponse represents an unhydrated response from OSV. +type MinimalResponse struct { + Vulns []MinimalVulnerability `json:"vulns"` +} + +// BatchedResponse represents an unhydrated batched response from OSV. +type BatchedResponse struct { + Results []MinimalResponse `json:"results"` +} + +// HydratedBatchedResponse represents a hydrated batched response from OSV. +type HydratedBatchedResponse struct { + Results []Response `json:"results"` +} + +// DetermineVersionHash holds the per file hash and path information for determineversion. +type DetermineVersionHash struct { + Path string `json:"path"` + Hash []byte `json:"hash"` +} + +// DetermineVersionResponse is the response from the determineversions endpoint +type DetermineVersionResponse struct { + Matches []struct { + Score float64 `json:"score"` + RepoInfo struct { + Type string `json:"type"` + Address string `json:"address"` + Tag string `json:"tag"` + Version string `json:"version"` + Commit string `json:"commit"` + } `json:"repo_info"` + } `json:"matches"` +} + +// DetermineVersionsRequest is the request format to the determineversions endpoint +type DetermineVersionsRequest struct { + Name string `json:"name"` + FileHashes []DetermineVersionHash `json:"file_hashes"` +} diff --git a/internal/osvdev/osvdev.go b/internal/osvdev/osvdev.go new file mode 100644 index 0000000000..51efcd472c --- /dev/null +++ b/internal/osvdev/osvdev.go @@ -0,0 +1,280 @@ +package osvdev + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "math" + "math/rand/v2" + "net/http" + "time" + + "github.com/google/osv-scanner/pkg/models" + "golang.org/x/sync/errgroup" +) + +const ( + QueryBatchEndpoint = "/v1/querybatch" + QueryEndpoint = "/v1/query" + GetEndpoint = "/v1/vulns" + + // DetermineVersionEndpoint is the URL for posting determineversion queries to OSV. + DetermineVersionEndpoint = "/v1experimental/determineversion" + + // MaxQueriesPerQueryBatchRequest is a limit set in osv.dev's API, so is not configurable + MaxQueriesPerQueryBatchRequest = 1000 +) + +type OSVClient struct { + HTTPClient *http.Client + Config ClientConfig + BaseHostURL string +} + +// DefaultClient() creates a new OSVClient with default settings +func DefaultClient() *OSVClient { + return &OSVClient{ + HTTPClient: http.DefaultClient, + Config: DefaultConfig(), + BaseHostURL: "https://api.osv.dev", + } +} + +// GetVulnsByID is an interface to this endpoint: https://google.github.io/osv.dev/get-v1-vulns/ +func (c *OSVClient) GetVulnsByID(ctx context.Context, id string) (*models.Vulnerability, error) { + resp, err := c.makeRetryRequest(func(client *http.Client) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.BaseHostURL+GetEndpoint+"/"+id, nil) + if err != nil { + return nil, err + } + if c.Config.UserAgent != "" { + req.Header.Set("User-Agent", c.Config.UserAgent) + } + + return client.Do(req) + }) + + if err != nil { + return nil, err + } + + defer resp.Body.Close() + + var vuln models.Vulnerability + decoder := json.NewDecoder(resp.Body) + err = decoder.Decode(&vuln) + if err != nil { + return nil, err + } + + return &vuln, nil +} + +// QueryBatch is an interface to this endpoint: https://google.github.io/osv.dev/post-v1-querybatch/ +func (c *OSVClient) QueryBatch(ctx context.Context, queries []*Query) (*BatchedResponse, error) { + // API has a limit of how many queries are in one batch + queryChunks := chunkBy(queries, MaxQueriesPerQueryBatchRequest) + totalOsvRespBatched := make([][]MinimalResponse, len(queryChunks)) + + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(c.Config.MaxConcurrentBatchRequests) + for batchIndex, queries := range queryChunks { + requestBytes, err := json.Marshal(BatchedQuery{Queries: queries}) + if err != nil { + return nil, err + } + + g.Go(func() error { + // exit early if another hydration request has already failed + // results are thrown away later, so avoid needless work + if ctx.Err() != nil { + return nil + } + + resp, err := c.makeRetryRequest(func(client *http.Client) (*http.Response, error) { + // Make sure request buffer is inside retry, if outside + // http request would finish the buffer, and retried requests would be empty + requestBuf := bytes.NewBuffer(requestBytes) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.BaseHostURL+QueryBatchEndpoint, requestBuf) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + if c.Config.UserAgent != "" { + req.Header.Set("User-Agent", c.Config.UserAgent) + } + + return client.Do(req) + }) + if err != nil { + return err + } + defer resp.Body.Close() + + var osvResp BatchedResponse + decoder := json.NewDecoder(resp.Body) + err = decoder.Decode(&osvResp) + if err != nil { + return err + } + + // Store batch results in the corresponding index to maintain original query order. + totalOsvRespBatched[batchIndex] = osvResp.Results + + return nil + }) + } + + if err := g.Wait(); err != nil { + return nil, err + } + + totalOsvResp := BatchedResponse{ + Results: make([]MinimalResponse, 0, len(queries)), + } + for _, results := range totalOsvRespBatched { + totalOsvResp.Results = append(totalOsvResp.Results, results...) + } + + return &totalOsvResp, nil +} + +// Query is an interface to this endpoint: https://google.github.io/osv.dev/post-v1-query/ +func (c *OSVClient) Query(ctx context.Context, query *Query) (*Response, error) { + requestBytes, err := json.Marshal(query) + if err != nil { + return nil, err + } + + resp, err := c.makeRetryRequest(func(client *http.Client) (*http.Response, error) { + requestBuf := bytes.NewBuffer(requestBytes) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.BaseHostURL+QueryEndpoint, requestBuf) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + if c.Config.UserAgent != "" { + req.Header.Set("User-Agent", c.Config.UserAgent) + } + + return client.Do(req) + }) + + if err != nil { + return nil, err + } + defer resp.Body.Close() + + var osvResp Response + decoder := json.NewDecoder(resp.Body) + err = decoder.Decode(&osvResp) + if err != nil { + return nil, err + } + + return &osvResp, nil +} + +// ExperimentalDetermineVersion +func (c *OSVClient) ExperimentalDetermineVersion(ctx context.Context, query *DetermineVersionsRequest) (*DetermineVersionResponse, error) { + requestBytes, err := json.Marshal(query) + if err != nil { + return nil, err + } + + resp, err := c.makeRetryRequest(func(client *http.Client) (*http.Response, error) { + // Make sure request buffer is inside retry, if outside + // http request would finish the buffer, and retried requests would be empty + requestBuf := bytes.NewBuffer(requestBytes) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.BaseHostURL+DetermineVersionEndpoint, requestBuf) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + if c.Config.UserAgent != "" { + req.Header.Set("User-Agent", c.Config.UserAgent) + } + + return client.Do(req) + }) + + if err != nil { + return nil, err + } + defer resp.Body.Close() + + var result DetermineVersionResponse + decoder := json.NewDecoder(resp.Body) + if err := decoder.Decode(&result); err != nil { + return nil, err + } + + return &result, nil +} + +// makeRetryRequest will return an error on both network errors, and if the response is not 200 +func (c *OSVClient) makeRetryRequest(action func(client *http.Client) (*http.Response, error)) (*http.Response, error) { + var resp *http.Response + var err error + var lastErr error + + for i := range c.Config.MaxRetryAttempts { + // rand is initialized with a random number (since go1.20), and is also safe to use concurrently + // we do not need to use a cryptographically secure random jitter, this is just to spread out the retry requests + // #nosec G404 + jitterAmount := (rand.Float64() * float64(c.Config.JitterMultiplier) * float64(i)) + time.Sleep( + time.Duration(math.Pow(float64(i), c.Config.BackoffDurationExponential)*c.Config.BackoffDurationMultiplier*1000)*time.Millisecond + + time.Duration(jitterAmount*1000)*time.Millisecond) + + resp, err = action(c.HTTPClient) + + // The network request itself failed, did not even get a response + if err != nil { + lastErr = fmt.Errorf("attempt %d: request failed: %w", i+1, err) + continue + } + + // Everything is fine + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + return resp, nil + } + + errBody, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + lastErr = fmt.Errorf("attempt %d: failed to read response: %w", i+1, err) + continue + } + + // Special case for too many requests, it should try again after a delay. + if resp.StatusCode == http.StatusTooManyRequests { + lastErr = fmt.Errorf("attempt %d: too many requests: status=%q body=%s", i+1, resp.Status, errBody) + continue + } + + // Otherwise any other 400 error should be fatal, as the request we are sending is incorrect + // Retrying won't make a difference + if resp.StatusCode >= 400 && resp.StatusCode < 500 { + return nil, fmt.Errorf("client error: status=%q body=%s", resp.Status, errBody) + } + + // Most likely a 500 >= error + lastErr = fmt.Errorf("server error: status=%q body=%s", resp.Status, errBody) + } + + return nil, fmt.Errorf("max retries exceeded: %w", lastErr) +} + +// From: https://stackoverflow.com/a/72408490 +func chunkBy[T any](items []T, chunkSize int) [][]T { + chunks := make([][]T, 0, (len(items)/chunkSize)+1) + for chunkSize < len(items) { + items, chunks = items[chunkSize:], append(chunks, items[0:chunkSize:chunkSize]) + } + + return append(chunks, items) +} diff --git a/internal/osvdev/osvdev_internal_test.go b/internal/osvdev/osvdev_internal_test.go new file mode 100644 index 0000000000..d81b379925 --- /dev/null +++ b/internal/osvdev/osvdev_internal_test.go @@ -0,0 +1,109 @@ +package osvdev + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/google/osv-scalibr/testing/extracttest" +) + +func TestMakeRetryRequest(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + statusCodes []int + wantErr error + wantAttempts int + }{ + { + name: "success on first attempt", + statusCodes: []int{http.StatusOK}, + wantAttempts: 1, + }, + { + name: "client error no retry", + statusCodes: []int{http.StatusBadRequest}, + wantErr: extracttest.ContainsErrStr{ + Str: "client error: status=\"400 Bad Request\"", + }, + wantAttempts: 1, + }, + { + name: "server error then success", + statusCodes: []int{http.StatusInternalServerError, http.StatusOK}, + wantAttempts: 2, + }, + { + name: "max retries on server error", + statusCodes: []int{http.StatusInternalServerError, http.StatusInternalServerError, http.StatusInternalServerError, http.StatusInternalServerError}, + wantErr: extracttest.ContainsErrStr{ + Str: "max retries exceeded", + }, + wantAttempts: 4, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + client := DefaultClient() + // Low multipliers to make the test run faster + client.Config.JitterMultiplier = 0 + client.Config.BackoffDurationMultiplier = 0 + client.Config.MaxRetryAttempts = 4 + client.HTTPClient = &http.Client{Timeout: time.Second} + + attempts := 0 + idx := 0 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + attempts++ + status := tt.statusCodes[idx] + if idx < len(tt.statusCodes)-1 { + idx++ + } + + w.WriteHeader(status) + message := fmt.Sprintf("response-%d", attempts) + _, _ = w.Write([]byte(message)) + })) + defer server.Close() + + resp, err := client.makeRetryRequest(func(hc *http.Client) (*http.Response, error) { + //nolint:noctx // because this is test code + return hc.Get(server.URL) + }) + + if attempts != tt.wantAttempts { + t.Errorf("got %d attempts, want %d", attempts, tt.wantAttempts) + } + + if diff := cmp.Diff(tt.wantErr, err, cmpopts.EquateErrors()); diff != "" { + t.Fatalf("Unexpected error (-want +got):\n%s", diff) + } + + if err != nil { + return + } + + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } + + expectedBody := fmt.Sprintf("response-%d", attempts) + if string(body) != expectedBody { + t.Errorf("got body %q, want %q", string(body), expectedBody) + } + }) + } +} diff --git a/internal/osvdev/osvdev_test.go b/internal/osvdev/osvdev_test.go new file mode 100644 index 0000000000..8069e23b22 --- /dev/null +++ b/internal/osvdev/osvdev_test.go @@ -0,0 +1,290 @@ +package osvdev_test + +import ( + "context" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/google/osv-scalibr/testing/extracttest" + "github.com/google/osv-scanner/internal/osvdev" + "github.com/ossf/osv-schema/bindings/go/osvschema" +) + +func TestOSVClient_GetVulnsByID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + id string + wantErr error + }{ + { + name: "Simple ID lookup", + id: "GO-2024-3333", + }, + { + name: "Missing ID lookup", + id: "GO-1000-1000", + wantErr: extracttest.ContainsErrStr{ + Str: `client error: status="404 Not Found" body={"code":5,"message":"Bug not found."}`, + }, + }, + { + name: "Invalid ID", + id: "_--_--", + wantErr: extracttest.ContainsErrStr{ + Str: `client error: status="404 Not Found" body={"code":5,"message":"Bug not found."}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + c := osvdev.DefaultClient() + c.Config.UserAgent = "osv-scanner-api-test" + + got, err := c.GetVulnsByID(context.Background(), tt.id) + + if diff := cmp.Diff(tt.wantErr, err, cmpopts.EquateErrors()); diff != "" { + t.Fatalf("Unexpected error (-want +got):\n%s", diff) + } + + if err != nil { + return + } + + if got.ID != tt.id { + t.Errorf("OSVClient.GetVulnsByID() = %v, want %v", got, tt.id) + } + }) + } +} + +func TestOSVClient_QueryBatch(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + queries []*osvdev.Query + wantIDs [][]string + wantErr error + }{ + { + name: "multiple queries lookup", + queries: []*osvdev.Query{ + { + Package: osvdev.Package{ + Name: "faker", + Ecosystem: string(osvschema.EcosystemNPM), + }, + Version: "6.6.6", + }, + { + Commit: "60e572dbf7b4ded66b488f54773f66aaf6184321", + }, + { + Package: osvdev.Package{ + Name: "abcd-definitely-does-not-exist", + Ecosystem: string(osvschema.EcosystemNPM), + }, + Version: "1.0.0", + }, + }, + wantIDs: [][]string{ + { // Package Query + "GHSA-5w9c-rv96-fr7g", + }, + { // Commit + "OSV-2023-890", + }, + // non-existent package + {}, + }, + }, + { + name: "multiple queries with invalid", + queries: []*osvdev.Query{ + { + Package: osvdev.Package{ + Name: "faker", + Ecosystem: string(osvschema.EcosystemNPM), + }, + Version: "6.6.6", + }, + { + Package: osvdev.Package{ + Name: "abcd-definitely-does-not-exist", + }, + }, + }, + wantIDs: [][]string{}, + wantErr: extracttest.ContainsErrStr{ + Str: `client error: status="400 Bad Request" body={"code":3,"message":"Invalid query."}`, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + c := osvdev.DefaultClient() + c.Config.UserAgent = "osv-scanner-api-test" + + got, err := c.QueryBatch(context.Background(), tt.queries) + + if diff := cmp.Diff(tt.wantErr, err, cmpopts.EquateErrors()); diff != "" { + t.Fatalf("Unexpected error (-want +got):\n%s", diff) + } + + if err != nil { + return + } + + gotResults := make([][]string, 0, len(got.Results)) + for _, res := range got.Results { + gotVulnIDs := make([]string, 0, len(res.Vulns)) + for _, vuln := range res.Vulns { + gotVulnIDs = append(gotVulnIDs, vuln.ID) + } + gotResults = append(gotResults, gotVulnIDs) + } + + if diff := cmp.Diff(tt.wantIDs, gotResults); diff != "" { + t.Errorf("Unexpected vuln IDs (-want +got):\n%s", diff) + } + }) + } +} + +func TestOSVClient_Query(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + query osvdev.Query + wantIDs []string + wantErr error + }{ + { + name: "npm Package lookup", + query: osvdev.Query{ + Package: osvdev.Package{ + // Use a deleted package as it is less likely new vulns will be published for it + Name: "faker", + Ecosystem: string(osvschema.EcosystemNPM), + }, + Version: "6.6.6", + }, + wantIDs: []string{ + "GHSA-5w9c-rv96-fr7g", + }, + }, + { + name: "commit lookup", + query: osvdev.Query{ + Commit: "60e572dbf7b4ded66b488f54773f66aaf6184321", + }, + wantIDs: []string{ + "OSV-2023-890", + }, + }, + { + name: "unknown package lookup", + query: osvdev.Query{ + Package: osvdev.Package{ + Name: "abcd-definitely-does-not-exist", + Ecosystem: string(osvschema.EcosystemNPM), + }, + Version: "1.0.0", + }, + wantIDs: []string{}, + }, + { + name: "invalid query", + query: osvdev.Query{ + Package: osvdev.Package{ + Name: "abcd-definitely-does-not-exist", + }, + }, + wantErr: extracttest.ContainsErrStr{ + Str: `client error: status="400 Bad Request" body={"code":3,"message":"Invalid query."}`, + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + c := osvdev.DefaultClient() + c.Config.UserAgent = "osv-scanner-api-test" + + got, err := c.Query(context.Background(), &tt.query) + + if diff := cmp.Diff(tt.wantErr, err, cmpopts.EquateErrors()); diff != "" { + t.Fatalf("Unexpected error (-want +got):\n%s", diff) + } + + if err != nil { + return + } + + gotVulnIDs := make([]string, 0, len(got.Vulns)) + for _, vuln := range got.Vulns { + gotVulnIDs = append(gotVulnIDs, vuln.ID) + } + + if diff := cmp.Diff(tt.wantIDs, gotVulnIDs); diff != "" { + t.Errorf("Unexpected vuln IDs (-want +got):\n%s", diff) + } + }) + } +} + +func TestOSVClient_ExperimentalDetermineVersion(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + query osvdev.DetermineVersionsRequest + wantPkgs []string + }{ + { + name: "Simple non existent package query", + query: osvdev.DetermineVersionsRequest{ + Name: "test file", + FileHashes: []osvdev.DetermineVersionHash{ + { + Path: "test file/file", + Hash: []byte{}, + }, + }, + }, + wantPkgs: []string{}, + }, + // TODO: Add query for an actual package, this is not added at the moment as it requires too many hashes + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + c := osvdev.DefaultClient() + c.Config.UserAgent = "osv-scanner-api-test" + + got, err := c.ExperimentalDetermineVersion(context.Background(), &tt.query) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + + gotPkgInfo := make([]string, 0, len(got.Matches)) + for _, vuln := range got.Matches { + gotPkgInfo = append(gotPkgInfo, vuln.RepoInfo.Address+"@"+vuln.RepoInfo.Version) + } + + if diff := cmp.Diff(tt.wantPkgs, gotPkgInfo); diff != "" { + t.Errorf("Unexpected vuln IDs (-want +got):\n%s", diff) + } + }) + } +}