Skip to content

Commit

Permalink
refactor(guided remediation): Take PreFetch out of `DependencyClien…
Browse files Browse the repository at this point in the history
…t` interface and prevent repeated datasource network calls (#1224)

What I mentioned in
#1207 (comment)

Make `PreFetch` a standalone function that takes in a client and uses
every `DependencyClient` method call.
Since the underlying datasources tend to use the same request for
multiple methods, I've made a `requestCache` type that uses logic based
on the
[singleflight](https://cs.opensource.google/go/x/sync/+/refs/tags/v0.8.0:singleflight/singleflight.go;l=91)
package to prevent the same requests being made multiple times. I've
simplified it a bit by skipping the bespoke handling of panics /
`runtime.Goexit`.
  • Loading branch information
michaelkedar authored Sep 18, 2024
1 parent c3295de commit 1856add
Show file tree
Hide file tree
Showing 16 changed files with 360 additions and 339 deletions.
2 changes: 1 addition & 1 deletion cmd/osv-scanner/fix/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ func doInitialRelock(ctx context.Context, opts osvFixOptions) tea.Msg {
if err != nil {
return doRelockMsg{err: err}
}
opts.Client.PreFetch(ctx, m.Requirements, m.FilePath)
client.PreFetch(opts.Client, ctx, m.Requirements, m.FilePath)

return doRelock(ctx, opts.Client, m, opts.ResolveOpts, opts.MatchVuln)
}
Expand Down
5 changes: 3 additions & 2 deletions cmd/osv-scanner/fix/noninteractive.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"deps.dev/util/resolve"
"github.com/google/osv-scanner/internal/remediation"
"github.com/google/osv-scanner/internal/resolution"
"github.com/google/osv-scanner/internal/resolution/client"
lf "github.com/google/osv-scanner/internal/resolution/lockfile"
"github.com/google/osv-scanner/internal/resolution/manifest"
"github.com/google/osv-scanner/pkg/lockfile"
Expand Down Expand Up @@ -129,7 +130,7 @@ func autoRelock(ctx context.Context, r reporter.Reporter, opts osvFixOptions, ma
return err
}

opts.Client.PreFetch(ctx, manif.Requirements, manif.FilePath)
client.PreFetch(opts.Client, ctx, manif.Requirements, manif.FilePath)
res, err := resolution.Resolve(ctx, opts.Client, manif, opts.ResolveOpts)
if err != nil {
return err
Expand Down Expand Up @@ -294,7 +295,7 @@ func autoOverride(ctx context.Context, r reporter.Reporter, opts osvFixOptions,
return err
}

opts.Client.PreFetch(ctx, manif.Requirements, manif.FilePath)
client.PreFetch(opts.Client, ctx, manif.Requirements, manif.FilePath)
res, err := resolution.Resolve(ctx, opts.Client, manif, opts.ResolveOpts)
if err != nil {
return err
Expand Down
105 changes: 99 additions & 6 deletions internal/resolution/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,121 @@ package client

import (
"context"
"crypto/x509"

pb "deps.dev/api/v3"
"deps.dev/util/resolve"
"github.com/google/osv-scanner/pkg/depsdev"
"github.com/google/osv-scanner/pkg/models"
"github.com/google/osv-scanner/pkg/osv"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)

// ResolutionClient bundles a DependencyClient and a VulnerabilityClient into a
// single value by embedding both interfaces, so the methods of either can be
// called directly on it.
type ResolutionClient struct {
DependencyClient
VulnerabilityClient
}

// VulnerabilityClient looks up the vulnerabilities affecting the nodes of a
// resolve.Graph.
type VulnerabilityClient interface {
// FindVulns finds the vulnerabilities affecting each of Nodes in the graph.
// The returned Vulnerabilities[i] corresponds to the vulnerabilities in g.Nodes[i].
FindVulns(g *resolve.Graph) ([]models.Vulnerabilities, error)
}

// DependencyClient is a resolve.Client that additionally supports a
// manifest-specific resolution cache.
type DependencyClient interface {
resolve.Client
// WriteCache writes a manifest-specific resolution cache.
WriteCache(filepath string) error
// LoadCache loads a manifest-specific resolution cache.
LoadCache(filepath string) error
// PreFetch loads cache, then makes and caches likely queries needed for resolving a package with a list of requirements
PreFetch(ctx context.Context, requirements []resolve.RequirementVersion, manifestPath string)
}

type VulnerabilityClient interface {
// FindVulns finds the vulnerabilities affecting each of Nodes in the graph.
// The returned Vulnerabilities[i] corresponds to the vulnerabilities in g.Nodes[i].
FindVulns(g *resolve.Graph) ([]models.Vulnerabilities, error)
// PreFetch loads cache, then makes and caches likely queries needed for resolving a package with a list of requirements.
//
// It is strictly best-effort: every failure (cache load, TLS setup, dialing
// deps.dev, individual lookups) is ignored and nothing is reported to the
// caller. For each requirement it picks a preferred version via
// c.MatchingVersions, asks deps.dev for that version's dependency graph, and
// then calls c.Requirements, c.Version, c.Versions and c.MatchingVersions in
// background goroutines for the packages in that graph. The function returns
// without waiting for those goroutines.
func PreFetch(c DependencyClient, ctx context.Context, requirements []resolve.RequirementVersion, manifestPath string) {
	// It doesn't matter if loading the cache fails
	_ = c.LoadCache(manifestPath)

	certPool, err := x509.SystemCertPool()
	if err != nil {
		return
	}
	creds := credentials.NewClientTLSFromCert(certPool, "")
	dialOpts := []grpc.DialOption{grpc.WithTransportCredentials(creds)}

	if osv.RequestUserAgent != "" {
		dialOpts = append(dialOpts, grpc.WithUserAgent(osv.RequestUserAgent))
	}

	conn, err := grpc.NewClient(depsdev.DepsdevAPI, dialOpts...)
	if err != nil {
		return
	}
	// The connection is only used synchronously in the loop below (the
	// background goroutines go through c, not through this connection), so it
	// is safe to release it as soon as PreFetch returns.
	defer func() { _ = conn.Close() }()
	insights := pb.NewInsightsClient(conn)

	// Use the deps.dev client to fetch complete dependency graphs of our direct imports
	for _, im := range requirements {
		// Get the preferred version of the import requirement
		vks, err := c.MatchingVersions(ctx, im.VersionKey)
		if err != nil || len(vks) == 0 {
			continue
		}

		// Default to the last matching version.
		vk := vks[len(vks)-1]

		// We prefer the exact version for soft requirements.
		for _, v := range vks {
			if im.Version == v.Version {
				vk = v
				break
			}
		}

		// Make a request for the precomputed dependency tree
		resp, err := insights.GetDependencies(ctx, &pb.GetDependenciesRequest{
			VersionKey: &pb.VersionKey{
				System:  pb.System(vk.System),
				Name:    vk.Name,
				Version: vk.Version,
			},
		})
		if err != nil {
			continue
		}

		// Send off queries to cache the packages in the dependency tree
		nodes := resp.GetNodes()
		for _, node := range nodes {
			pbvk := node.GetVersionKey()
			vk := resolve.VersionKey{
				PackageKey: resolve.PackageKey{
					System: resolve.System(pbvk.GetSystem()),
					Name:   pbvk.GetName(),
				},
				Version:     pbvk.GetVersion(),
				VersionType: resolve.Concrete,
			}

			// TODO: We might want to limit the number of goroutines this creates.
			go c.Requirements(ctx, vk)        //nolint:errcheck
			go c.Version(ctx, vk)             //nolint:errcheck
			go c.Versions(ctx, vk.PackageKey) //nolint:errcheck
		}

		for _, edge := range resp.GetEdges() {
			// Skip edges that point outside the node list rather than
			// panicking on a malformed response; this is only a prefetch.
			to := int(edge.GetToNode())
			if to < 0 || to >= len(nodes) {
				continue
			}

			req := edge.GetRequirement()
			pbvk := nodes[to].GetVersionKey()
			vk := resolve.VersionKey{
				PackageKey: resolve.PackageKey{
					System: resolve.System(pbvk.GetSystem()),
					Name:   pbvk.GetName(),
				},
				Version:     req,
				VersionType: resolve.Requirement,
			}
			go c.MatchingVersions(ctx, vk) //nolint:errcheck
		}
	}

	// don't bother waiting for goroutines to finish.
}
57 changes: 0 additions & 57 deletions internal/resolution/client/depsdev_client.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
package client

import (
"context"
"encoding/gob"
"os"

pb "deps.dev/api/v3"
"deps.dev/util/resolve"
"github.com/google/osv-scanner/internal/resolution/datasource"
)
Expand All @@ -27,61 +25,6 @@ func NewDepsDevClient(addr string) (*DepsDevClient, error) {
return &DepsDevClient{APIClient: *resolve.NewAPIClient(c), c: c}, nil
}

// PreFetch warms the client for resolving the given requirements.
//
// It first attempts to load the cache for manifestPath, then, for each
// requirement, picks a preferred matching version, requests that version's
// dependency graph via GetDependencies, and starts background Versions,
// Requirements and Version calls for every node in the graph. All errors are
// ignored and the background goroutines are not waited for.
func (d *DepsDevClient) PreFetch(ctx context.Context, requirements []resolve.RequirementVersion, manifestPath string) {
	// A failed cache load is harmless here, so the error is dropped.
	_ = d.LoadCache(manifestPath)

	// Walk the direct requirements and pull each one's full dependency graph.
	for i := range requirements {
		direct := requirements[i]

		// Find the candidate versions satisfying this requirement.
		candidates, err := d.MatchingVersions(ctx, direct.VersionKey)
		if err != nil || len(candidates) == 0 {
			continue
		}

		// Start from the last candidate, but switch to an exact version match
		// if there is one (preferred for soft requirements).
		chosen := candidates[len(candidates)-1]
		for j := range candidates {
			if candidates[j].Version == direct.Version {
				chosen = candidates[j]
				break
			}
		}

		// Ask for the precomputed dependency tree of the chosen version.
		target := &pb.VersionKey{
			System:  pb.System(chosen.System),
			Name:    chosen.Name,
			Version: chosen.Version,
		}
		graph, err := d.c.GetDependencies(ctx, &pb.GetDependenciesRequest{VersionKey: target})
		if err != nil {
			continue
		}

		// Kick off lookups for every package in the tree.
		for _, n := range graph.GetNodes() {
			key := n.GetVersionKey()

			pkg := resolve.PackageKey{
				System: resolve.System(key.GetSystem()),
				Name:   key.GetName(),
			}
			go d.Versions(ctx, pkg) //nolint:errcheck

			concrete := resolve.VersionKey{
				PackageKey:  pkg,
				Version:     key.GetVersion(),
				VersionType: resolve.Concrete,
			}
			go d.Requirements(ctx, concrete) //nolint:errcheck
			go d.Version(ctx, concrete)      //nolint:errcheck
		}
	}
	// The goroutines above are intentionally left running.
}

func (d *DepsDevClient) WriteCache(path string) error {
f, err := os.Create(path + depsDevCacheExt)
if err != nil {
Expand Down
75 changes: 0 additions & 75 deletions internal/resolution/client/maven_registry_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,18 @@ package client

import (
"context"
"crypto/x509"
"encoding/gob"
"fmt"
"os"
"slices"
"strings"

pb "deps.dev/api/v3"
"deps.dev/util/maven"
"deps.dev/util/resolve"
"deps.dev/util/resolve/version"
"deps.dev/util/semver"
"github.com/google/osv-scanner/internal/resolution/datasource"
mavenutil "github.com/google/osv-scanner/internal/utility/maven"
"github.com/google/osv-scanner/pkg/depsdev"
"github.com/google/osv-scanner/pkg/osv"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)

const mavenRegistryCacheExt = ".resolve.maven"
Expand Down Expand Up @@ -176,72 +170,3 @@ func (c *MavenRegistryClient) LoadCache(path string) error {

return gob.NewDecoder(f).Decode(&c.api)
}

// PreFetch loads the cache for manifestPath, then uses deps.dev to find the
// dependency graphs of the direct imports and starts background queries on c
// for the packages in those graphs.
//
// It is best-effort: every error is ignored, nothing is returned, and the
// background goroutines are not waited for. The cache is loaded before any
// network setup so that a failure to build the deps.dev connection does not
// also skip the cache.
func (c *MavenRegistryClient) PreFetch(ctx context.Context, imports []resolve.RequirementVersion, manifestPath string) {
	// It doesn't matter if loading the cache fails
	_ = c.LoadCache(manifestPath)

	certPool, err := x509.SystemCertPool()
	if err != nil {
		return
	}
	creds := credentials.NewClientTLSFromCert(certPool, "")
	dialOpts := []grpc.DialOption{grpc.WithTransportCredentials(creds)}
	if osv.RequestUserAgent != "" {
		dialOpts = append(dialOpts, grpc.WithUserAgent(osv.RequestUserAgent))
	}

	conn, err := grpc.NewClient(depsdev.DepsdevAPI, dialOpts...)
	if err != nil {
		return
	}
	// The connection is only used synchronously in the loop below (the
	// background goroutines call methods on c), so release it on return.
	defer func() { _ = conn.Close() }()
	insights := pb.NewInsightsClient(conn)

	// Use the deps.dev client to fetch complete dependency graphs of our direct imports
	for _, im := range imports {
		// Get the preferred version of the import requirement
		vks, err := c.MatchingVersions(ctx, im.VersionKey)
		if err != nil || len(vks) == 0 {
			continue
		}

		vk := vks[len(vks)-1]
		for _, v := range vks {
			// We prefer the exact version for soft requirements.
			if im.Version == v.Version {
				vk = v
				break
			}
		}

		// Make a request for the pre-computed dependency tree
		resp, err := insights.GetDependencies(ctx, &pb.GetDependenciesRequest{
			VersionKey: &pb.VersionKey{
				System:  pb.System(vk.System),
				Name:    vk.Name,
				Version: vk.Version,
			},
		})
		if err != nil {
			continue
		}

		// Send off queries to cache the packages in the dependency tree
		for _, node := range resp.GetNodes() {
			pbvk := node.GetVersionKey()
			vk := resolve.VersionKey{
				PackageKey: resolve.PackageKey{
					System: resolve.System(pbvk.GetSystem()),
					Name:   pbvk.GetName(),
				},
				Version:     pbvk.GetVersion(),
				VersionType: resolve.Concrete,
			}
			// To cache Metadata.
			go c.Versions(ctx, vk.PackageKey) //nolint:errcheck
			// To cache Projects.
			go c.Requirements(ctx, vk) //nolint:errcheck
		}
	}
	// Don't bother waiting for goroutines to finish
}
43 changes: 0 additions & 43 deletions internal/resolution/client/npm_registry_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,49 +229,6 @@ func isNpmBundle(pk resolve.PackageKey) bool {
return strings.Contains(pk.Name, ">")
}

// PreFetch warms the client for resolving the given direct imports.
//
// After attempting to load the cache for manifestPath, it takes the last
// matching version of each import, requests that version's dependency graph
// via GetDependencies, and starts a background Requirements call for every
// node in the graph. Errors are ignored and the goroutines are not waited for.
func (c *NpmRegistryClient) PreFetch(ctx context.Context, imports []resolve.RequirementVersion, manifestPath string) {
	// A failed cache load is harmless here, so the error is dropped.
	_ = c.LoadCache(manifestPath)

	// Fetch the complete dependency graph of each direct import.
	for i := range imports {
		// Resolve the import requirement to its matching versions.
		candidates, err := c.MatchingVersions(ctx, imports[i].VersionKey)
		if err != nil || len(candidates) == 0 {
			continue
		}

		// The last matching version is the one we look up.
		chosen := candidates[len(candidates)-1]

		// Ask for the precomputed dependency tree of that version.
		target := &pb.VersionKey{
			System:  pb.System(chosen.System),
			Name:    chosen.Name,
			Version: chosen.Version,
		}
		graph, err := c.ic.GetDependencies(ctx, &pb.GetDependenciesRequest{VersionKey: target})
		if err != nil {
			continue
		}

		// Kick off a Requirements query for each package in the tree.
		for _, n := range graph.GetNodes() {
			key := n.GetVersionKey()
			concrete := resolve.VersionKey{
				PackageKey: resolve.PackageKey{
					System: resolve.System(key.GetSystem()),
					Name:   key.GetName(),
				},
				Version:     key.GetVersion(),
				VersionType: resolve.Concrete,
			}
			go c.Requirements(ctx, concrete) //nolint:errcheck
		}
	}
	// The goroutines above are intentionally left running.
}

func (c *NpmRegistryClient) WriteCache(path string) error {
f, err := os.Create(path + npmRegistryCacheExt)
if err != nil {
Expand Down
Loading

0 comments on commit 1856add

Please sign in to comment.