Skip to content

Commit

Permalink
Merge branch 'main' into renovate/osv-scanner-minor
Browse files Browse the repository at this point in the history
  • Loading branch information
cuixq authored Dec 23, 2024
2 parents d2d07fa + bd2b403 commit dd2f961
Show file tree
Hide file tree
Showing 5 changed files with 779 additions and 0 deletions.
27 changes: 27 additions & 0 deletions internal/osvdev/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package osvdev

import "github.com/google/osv-scanner/internal/version"

type ClientConfig struct {
MaxConcurrentRequests int
MaxConcurrentBatchRequests int

MaxRetryAttempts int
JitterMultiplier float64
BackoffDurationExponential float64
BackoffDurationMultiplier float64
UserAgent string
}

// DefaultConfig make a default client config
func DefaultConfig() ClientConfig {
return ClientConfig{
MaxRetryAttempts: 4,
JitterMultiplier: 2,
BackoffDurationExponential: 2,
BackoffDurationMultiplier: 1,
UserAgent: "osv-scanner/" + version.OSVVersion,
MaxConcurrentRequests: 1000,
MaxConcurrentBatchRequests: 10,
}
}
73 changes: 73 additions & 0 deletions internal/osvdev/models.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package osvdev

import "github.com/google/osv-scanner/pkg/models"

// Package represents a package identifier for OSV.
type Package struct {
PURL string `json:"purl,omitempty"`
Name string `json:"name,omitempty"`
Ecosystem string `json:"ecosystem,omitempty"`
}

// Query represents a query to OSV.
type Query struct {
Commit string `json:"commit,omitempty"`
Package Package `json:"package,omitempty"`
Version string `json:"version,omitempty"`
}

// BatchedQuery represents a batched query to OSV.
type BatchedQuery struct {
Queries []*Query `json:"queries"`
}

// MinimalVulnerability represents an unhydrated vulnerability entry from OSV.
type MinimalVulnerability struct {
ID string `json:"id"`
}

// Response represents a full response from OSV.
type Response struct {
Vulns []models.Vulnerability `json:"vulns"`
}

// MinimalResponse represents an unhydrated response from OSV.
type MinimalResponse struct {
Vulns []MinimalVulnerability `json:"vulns"`
}

// BatchedResponse represents an unhydrated batched response from OSV.
type BatchedResponse struct {
Results []MinimalResponse `json:"results"`
}

// HydratedBatchedResponse represents a hydrated batched response from OSV.
type HydratedBatchedResponse struct {
Results []Response `json:"results"`
}

// DetermineVersionHash holds the per file hash and path information for determineversion.
type DetermineVersionHash struct {
Path string `json:"path"`
Hash []byte `json:"hash"`
}

// DetermineVersionResponse is the response from the determineversions endpoint
type DetermineVersionResponse struct {
Matches []struct {
Score float64 `json:"score"`
RepoInfo struct {
Type string `json:"type"`
Address string `json:"address"`
Tag string `json:"tag"`
Version string `json:"version"`
Commit string `json:"commit"`
} `json:"repo_info"`
} `json:"matches"`
}

// DetermineVersionsRequest is the request format to the determineversions endpoint
type DetermineVersionsRequest struct {
Name string `json:"name"`
FileHashes []DetermineVersionHash `json:"file_hashes"`
}
280 changes: 280 additions & 0 deletions internal/osvdev/osvdev.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
package osvdev

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"math"
"math/rand/v2"
"net/http"
"time"

"github.com/google/osv-scanner/pkg/models"
"golang.org/x/sync/errgroup"
)

const (
QueryBatchEndpoint = "/v1/querybatch"
QueryEndpoint = "/v1/query"
GetEndpoint = "/v1/vulns"

// DetermineVersionEndpoint is the URL for posting determineversion queries to OSV.
DetermineVersionEndpoint = "/v1experimental/determineversion"

// MaxQueriesPerQueryBatchRequest is a limit set in osv.dev's API, so is not configurable
MaxQueriesPerQueryBatchRequest = 1000
)

type OSVClient struct {
HTTPClient *http.Client
Config ClientConfig
BaseHostURL string
}

// DefaultClient() creates a new OSVClient with default settings
func DefaultClient() *OSVClient {
return &OSVClient{
HTTPClient: http.DefaultClient,
Config: DefaultConfig(),
BaseHostURL: "https://api.osv.dev",
}
}

// GetVulnsByID is an interface to this endpoint: https://google.github.io/osv.dev/get-v1-vulns/
func (c *OSVClient) GetVulnsByID(ctx context.Context, id string) (*models.Vulnerability, error) {
resp, err := c.makeRetryRequest(func(client *http.Client) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.BaseHostURL+GetEndpoint+"/"+id, nil)
if err != nil {
return nil, err
}
if c.Config.UserAgent != "" {
req.Header.Set("User-Agent", c.Config.UserAgent)
}

return client.Do(req)
})

if err != nil {
return nil, err
}

defer resp.Body.Close()

var vuln models.Vulnerability
decoder := json.NewDecoder(resp.Body)
err = decoder.Decode(&vuln)
if err != nil {
return nil, err
}

return &vuln, nil
}

// QueryBatch is an interface to this endpoint: https://google.github.io/osv.dev/post-v1-querybatch/
func (c *OSVClient) QueryBatch(ctx context.Context, queries []*Query) (*BatchedResponse, error) {
// API has a limit of how many queries are in one batch
queryChunks := chunkBy(queries, MaxQueriesPerQueryBatchRequest)
totalOsvRespBatched := make([][]MinimalResponse, len(queryChunks))

g, ctx := errgroup.WithContext(ctx)
g.SetLimit(c.Config.MaxConcurrentBatchRequests)
for batchIndex, queries := range queryChunks {
requestBytes, err := json.Marshal(BatchedQuery{Queries: queries})
if err != nil {
return nil, err
}

g.Go(func() error {
// exit early if another hydration request has already failed
// results are thrown away later, so avoid needless work
if ctx.Err() != nil {
return nil
}

resp, err := c.makeRetryRequest(func(client *http.Client) (*http.Response, error) {
// Make sure request buffer is inside retry, if outside
// http request would finish the buffer, and retried requests would be empty
requestBuf := bytes.NewBuffer(requestBytes)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.BaseHostURL+QueryBatchEndpoint, requestBuf)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
if c.Config.UserAgent != "" {
req.Header.Set("User-Agent", c.Config.UserAgent)
}

return client.Do(req)
})
if err != nil {
return err
}
defer resp.Body.Close()

var osvResp BatchedResponse
decoder := json.NewDecoder(resp.Body)
err = decoder.Decode(&osvResp)
if err != nil {
return err
}

// Store batch results in the corresponding index to maintain original query order.
totalOsvRespBatched[batchIndex] = osvResp.Results

return nil
})
}

if err := g.Wait(); err != nil {
return nil, err
}

totalOsvResp := BatchedResponse{
Results: make([]MinimalResponse, 0, len(queries)),
}
for _, results := range totalOsvRespBatched {
totalOsvResp.Results = append(totalOsvResp.Results, results...)
}

return &totalOsvResp, nil
}

// Query is an interface to this endpoint: https://google.github.io/osv.dev/post-v1-query/
func (c *OSVClient) Query(ctx context.Context, query *Query) (*Response, error) {
requestBytes, err := json.Marshal(query)
if err != nil {
return nil, err
}

resp, err := c.makeRetryRequest(func(client *http.Client) (*http.Response, error) {
requestBuf := bytes.NewBuffer(requestBytes)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.BaseHostURL+QueryEndpoint, requestBuf)
if err != nil {
return nil, err
}

req.Header.Set("Content-Type", "application/json")
if c.Config.UserAgent != "" {
req.Header.Set("User-Agent", c.Config.UserAgent)
}

return client.Do(req)
})

if err != nil {
return nil, err
}
defer resp.Body.Close()

var osvResp Response
decoder := json.NewDecoder(resp.Body)
err = decoder.Decode(&osvResp)
if err != nil {
return nil, err
}

return &osvResp, nil
}

// ExperimentalDetermineVersion
func (c *OSVClient) ExperimentalDetermineVersion(ctx context.Context, query *DetermineVersionsRequest) (*DetermineVersionResponse, error) {
requestBytes, err := json.Marshal(query)
if err != nil {
return nil, err
}

resp, err := c.makeRetryRequest(func(client *http.Client) (*http.Response, error) {
// Make sure request buffer is inside retry, if outside
// http request would finish the buffer, and retried requests would be empty
requestBuf := bytes.NewBuffer(requestBytes)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.BaseHostURL+DetermineVersionEndpoint, requestBuf)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
if c.Config.UserAgent != "" {
req.Header.Set("User-Agent", c.Config.UserAgent)
}

return client.Do(req)
})

if err != nil {
return nil, err
}
defer resp.Body.Close()

var result DetermineVersionResponse
decoder := json.NewDecoder(resp.Body)
if err := decoder.Decode(&result); err != nil {
return nil, err
}

return &result, nil
}

// makeRetryRequest will return an error on both network errors, and if the response is not 200
func (c *OSVClient) makeRetryRequest(action func(client *http.Client) (*http.Response, error)) (*http.Response, error) {
var resp *http.Response
var err error
var lastErr error

for i := range c.Config.MaxRetryAttempts {
// rand is initialized with a random number (since go1.20), and is also safe to use concurrently
// we do not need to use a cryptographically secure random jitter, this is just to spread out the retry requests
// #nosec G404
jitterAmount := (rand.Float64() * float64(c.Config.JitterMultiplier) * float64(i))
time.Sleep(
time.Duration(math.Pow(float64(i), c.Config.BackoffDurationExponential)*c.Config.BackoffDurationMultiplier*1000)*time.Millisecond +
time.Duration(jitterAmount*1000)*time.Millisecond)

resp, err = action(c.HTTPClient)

// The network request itself failed, did not even get a response
if err != nil {
lastErr = fmt.Errorf("attempt %d: request failed: %w", i+1, err)
continue
}

// Everything is fine
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return resp, nil
}

errBody, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
lastErr = fmt.Errorf("attempt %d: failed to read response: %w", i+1, err)
continue
}

// Special case for too many requests, it should try again after a delay.
if resp.StatusCode == http.StatusTooManyRequests {
lastErr = fmt.Errorf("attempt %d: too many requests: status=%q body=%s", i+1, resp.Status, errBody)
continue
}

// Otherwise any other 400 error should be fatal, as the request we are sending is incorrect
// Retrying won't make a difference
if resp.StatusCode >= 400 && resp.StatusCode < 500 {
return nil, fmt.Errorf("client error: status=%q body=%s", resp.Status, errBody)
}

// Most likely a 500 >= error
lastErr = fmt.Errorf("server error: status=%q body=%s", resp.Status, errBody)
}

return nil, fmt.Errorf("max retries exceeded: %w", lastErr)
}

// From: https://stackoverflow.com/a/72408490
func chunkBy[T any](items []T, chunkSize int) [][]T {
chunks := make([][]T, 0, (len(items)/chunkSize)+1)
for chunkSize < len(items) {
items, chunks = items[chunkSize:], append(chunks, items[0:chunkSize:chunkSize])
}

return append(chunks, items)
}
Loading

0 comments on commit dd2f961

Please sign in to comment.