diff --git a/oci/client/client.go b/oci/client/client.go index b3cd257a..c816855a 100644 --- a/oci/client/client.go +++ b/oci/client/client.go @@ -21,13 +21,15 @@ import ( "github.com/google/go-containerregistry/pkg/crane" "github.com/google/go-containerregistry/pkg/v1/remote" +"github.com/hashicorp/go-retryablehttp" "github.com/fluxcd/pkg/oci" ) // Client holds the options for accessing remote OCI registries. type Client struct { - options []crane.Option + options []crane.Option + httpClient *retryablehttp.Client } // NewClient returns an OCI client configured with the given crane options. diff --git a/oci/client/pull.go b/oci/client/pull.go index 633bf4b3..b49892a7 100644 --- a/oci/client/pull.go +++ b/oci/client/pull.go @@ -22,13 +22,28 @@ import ( "context" "fmt" "io" + "net/http" + "net/url" "os" + "github.com/fluxcd/pkg/tar" + "github.com/google/go-containerregistry/pkg/authn" "github.com/google/go-containerregistry/pkg/crane" "github.com/google/go-containerregistry/pkg/name" - gcrv1 "github.com/google/go-containerregistry/pkg/v1" + "github.com/hashicorp/go-retryablehttp" - "github.com/fluxcd/pkg/tar" + v1 "github.com/google/go-containerregistry/pkg/v1" + "github.com/google/go-containerregistry/pkg/v1/remote" + "github.com/google/go-containerregistry/pkg/v1/remote/transport" + "golang.org/x/sync/errgroup" +) + +const ( + // thresholdForConcurrentPull is the maximum size of a layer to be extracted in one go. + // If the layer is larger than this, it will be downloaded in chunks. + thresholdForConcurrentPull = 100 * 1024 * 1024 // 100MB + // maxConcurrentPulls is the maximum number of concurrent downloads. + maxConcurrentPulls = 10 ) var ( @@ -39,8 +54,12 @@ var ( // PullOptions contains options for pulling a layer. type PullOptions struct { - layerIndex int - layerType LayerType + layerIndex int + layerType LayerType + transport http.RoundTripper + auth authn.Authenticator + keychain authn.Keychain + concurrency int } // PullOption is a function for configuring PullOptions. @@ -60,22 +79,53 @@ func WithPullLayerIndex(i int) PullOption { } } +func WithTransport(t http.RoundTripper) PullOption { + return func(o *PullOptions) { + o.transport = t + } +} + +func WithConcurrency(c int) PullOption { + return func(o *PullOptions) { + o.concurrency = c + } +} + // Pull downloads an artifact from an OCI repository and extracts the content. // It untar or copies the content to the given outPath depending on the layerType. // If no layer type is given, it tries to determine the right type by checking compressed content of the layer. -func (c *Client) Pull(ctx context.Context, url, outPath string, opts ...PullOption) (*Metadata, error) { +func (c *Client) Pull(ctx context.Context, urlString, outPath string, opts ...PullOption) (*Metadata, error) { o := &PullOptions{ layerIndex: 0, } + o.keychain = authn.DefaultKeychain for _, opt := range opts { opt(o) } - ref, err := name.ParseReference(url) + + if o.concurrency == 0 || o.concurrency > maxConcurrentPulls { + o.concurrency = maxConcurrentPulls + } + + if o.transport == nil { + transport := remote.DefaultTransport.(*http.Transport).Clone() + o.transport = transport + } + + ref, err := name.ParseReference(urlString) if err != nil { return nil, fmt.Errorf("invalid URL: %w", err) } - img, err := crane.Pull(url, c.optionsWithContext(ctx)...) + if c.httpClient == nil { + h, err := makeHttpClient(ctx, ref.Context(), *o) + if err != nil { + return nil, err + } + c.httpClient = h + } + + img, err := crane.Pull(urlString, c.optionsWithContext(ctx)...) if err != nil { return nil, err } @@ -91,7 +141,7 @@ func (c *Client) Pull(ctx context.Context, url, outPath string, opts ...PullOpti } meta := MetadataFromAnnotations(manifest.Annotations) - meta.URL = url + meta.URL = urlString meta.Digest = ref.Context().Digest(digest.String()).String() layers, err := img.Layers() @@ -107,6 +157,34 @@ func (c *Client) Pull(ctx context.Context, url, outPath string, opts ...PullOpti return nil, fmt.Errorf("index '%d' out of bound for '%d' layers in artifact", o.layerIndex, len(layers)) } + size, err := layers[o.layerIndex].Size() + if err != nil { + return nil, fmt.Errorf("failed to get layer size: %w", err) + } + + if size > thresholdForConcurrentPull { + digest, err := layers[o.layerIndex].Digest() + if err != nil { + return nil, fmt.Errorf("parsing digest failed: %w", err) + } + u := url.URL{ + Scheme: ref.Context().Scheme(), + Host: ref.Context().RegistryStr(), + Path: fmt.Sprintf("/v2/%s/blobs/%s", ref.Context().RepositoryStr(), digest.String()), + } + ok, err := c.IsRangeRequestEnabled(ctx, u) + if err != nil { + return nil, fmt.Errorf("failed to check range request support: %w", err) + } + if ok { + err = c.concurrentExtractLayer(ctx, u, layers[o.layerIndex], outPath, digest, size, o.concurrency) + if err != nil { + return nil, err + } + return meta, nil + } + } + err = extractLayer(layers[o.layerIndex], outPath, o.layerType) if err != nil { return nil, err @@ -114,8 +192,98 @@ func (c *Client) Pull(ctx context.Context, url, outPath string, opts ...PullOpti return meta, nil } +// TO DO: handle authentication handle using keychain for authentication +func (c *Client) IsRangeRequestEnabled(ctx context.Context, u url.URL) (bool, error) { + req, err := retryablehttp.NewRequest(http.MethodHead, u.String(), nil) + if err != nil { + return false, err + } + + resp, err := c.httpClient.Do(req.WithContext(ctx)) + if err != nil { + return false, err + } + + if err := transport.CheckError(resp, http.StatusOK); err != nil { + return false, err + } + + if rangeUnit := resp.Header.Get("Accept-Ranges"); rangeUnit == "bytes" { + return true, nil + } + for k, v := range resp.Header { + fmt.Printf("Header: %s, Value: %s\n", k, v) + } + return false, nil +} + +func (c *Client) concurrentExtractLayer(ctx context.Context, u url.URL, layer v1.Layer, path string, digest v1.Hash, size int64, concurrency int) error { + chunkSize := size / int64(concurrency) + chunks := make([][]byte, concurrency+1) + diff := size % int64(concurrency) + + g, ctx := errgroup.WithContext(ctx) + for i := 0; i < concurrency; i++ { + i := i + g.Go(func() (err error) { + start, end := int64(i)*chunkSize, int64(i+1)*chunkSize + if i == concurrency-1 { + end += diff + } + req, err := retryablehttp.NewRequest(http.MethodGet, u.String(), nil) + if err != nil { + return fmt.Errorf("failed to create a new request: %w", err) + } + req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", start, end-1)) + resp, err := c.httpClient.Do(req.WithContext(ctx)) + if err != nil { + return fmt.Errorf("failed to download archive: %w", err) + } + defer resp.Body.Close() + + if err := transport.CheckError(resp, http.StatusPartialContent); err != nil { + return fmt.Errorf("failed to download archive from %s (status: %s)", u.String(), resp.Status) + } + + c, err := io.ReadAll(io.LimitReader(resp.Body, end-start)) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } + chunks[i] = c + return nil + }) + } + err := g.Wait() + if err != nil { + return err + } + + content := bufio.NewReader(bytes.NewReader(bytes.Join(chunks, nil))) + d, s, err := v1.SHA256(content) + if err != nil { + return err + } + if d != digest { + return fmt.Errorf("digest mismatch: expected %s, got %s", digest, d) + } + if s != size { + return fmt.Errorf("size mismatch: expected %d, got %d", size, size) + } + + f, err := os.Create(path) + if err != nil { + return err + } + + _, err = io.Copy(f, content) + if err != nil { + return fmt.Errorf("error copying layer content: %s", err) + } + return nil +} + // extractLayer extracts the Layer to the path -func extractLayer(layer gcrv1.Layer, path string, layerType LayerType) error { +func extractLayer(layer v1.Layer, path string, layerType LayerType) error { var blob io.Reader blob, err := layer.Compressed() if err != nil { @@ -173,3 +341,40 @@ func isGzipBlob(buf *bufio.Reader) (bool, error) { } return bytes.Equal(b, gzipMagicHeader), nil } + +type resource interface { + Scheme() string + RegistryStr() string + Scope(string) string + + authn.Resource +} + +func makeHttpClient(ctx context.Context, target resource, o PullOptions) (*retryablehttp.Client, error) { + auth := o.auth + if o.keychain != nil { + kauth, err := o.keychain.Resolve(target) + if err != nil { + return nil, err + } + auth = kauth + } + + reg, ok := target.(name.Registry) + if !ok { + repo, ok := target.(name.Repository) + if !ok { + return nil, fmt.Errorf("unexpected resource: %T", target) + } + reg = repo.Registry + } + + tr, err := transport.NewWithContext(ctx, reg, auth, o.transport, []string{target.Scope(transport.PullScope)}) + if err != nil { + return nil, err + } + + h := retryablehttp.NewClient() + h.HTTPClient = &http.Client{Transport: tr} + return h, nil +} diff --git a/oci/client/pull_test.go b/oci/client/pull_test.go index 86795284..b68dd15a 100644 --- a/oci/client/pull_test.go +++ b/oci/client/pull_test.go @@ -41,6 +41,7 @@ func Test_PullAnyTarball(t *testing.T) { repo := "test-no-annotations" + randStringRunes(5) dst := fmt.Sprintf("%s/%s:%s", dockerReg, repo, tag) + fmt.Println("Pulling from:", dst) artifact := filepath.Join(t.TempDir(), "artifact.tgz") g.Expect(build(artifact, testDir, nil)).To(Succeed()) @@ -82,3 +83,23 @@ func Test_PullAnyTarball(t *testing.T) { g.Expect(extractTo + "/" + entry).To(Or(BeAnExistingFile(), BeADirectory())) } } + +func Test_PullLargeTarball(t *testing.T) { + g := NewWithT(t) + ctx := context.Background() + c := NewClient(DefaultOptions()) + dst := "vnp505/zephyr-7b-alpha:alpha" + extractTo := filepath.Join(t.TempDir(), "artifact") + m, err := c.Pull(ctx, dst, extractTo, WithPullLayerIndex(19)) + fmt.Println("Pulled from:", dst) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(m).ToNot(BeNil()) + g.Expect(m.Annotations).To(BeEmpty()) + g.Expect(m.Created).To(BeEmpty()) + g.Expect(m.Revision).To(BeEmpty()) + g.Expect(m.Source).To(BeEmpty()) + g.Expect(m.URL).To(Equal(dst)) + g.Expect(m.Digest).ToNot(BeEmpty()) + g.Expect(extractTo).ToNot(BeEmpty()) +} diff --git a/oci/client/push_pull_test.go b/oci/client/push_pull_test.go index 3c68b253..9d02f101 100644 --- a/oci/client/push_pull_test.go +++ b/oci/client/push_pull_test.go @@ -305,6 +305,7 @@ func Test_Push_Pull(t *testing.T) { g.Expect(err).ToNot(HaveOccurred()) fileInfo, err := os.Stat(tt.sourcePath) + g.Expect(err).ToNot(HaveOccurred()) // if a directory was pushed, then the created file should be a gzipped archive if fileInfo.IsDir() { bufReader := bufio.NewReader(bytes.NewReader(got)) diff --git a/oci/go.mod b/oci/go.mod index e992b575..50681ead 100644 --- a/oci/go.mod +++ b/oci/go.mod @@ -21,9 +21,11 @@ require ( github.com/fluxcd/pkg/tar v0.4.0 github.com/fluxcd/pkg/version v0.2.2 github.com/google/go-containerregistry v0.18.0 + github.com/hashicorp/go-retryablehttp v0.7.5 github.com/onsi/gomega v1.31.1 github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 github.com/sirupsen/logrus v1.9.3 + golang.org/x/sync v0.6.0 sigs.k8s.io/controller-runtime v0.16.3 ) @@ -80,6 +82,7 @@ require ( github.com/gorilla/handlers v1.5.1 // indirect github.com/gorilla/mux v1.8.1 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 // indirect + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/golang-lru/arc/v2 v2.0.5 // indirect github.com/hashicorp/golang-lru/v2 v2.0.5 // indirect github.com/imdario/mergo v0.3.15 // indirect @@ -130,7 +133,6 @@ require ( golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e // indirect golang.org/x/net v0.20.0 // indirect golang.org/x/oauth2 v0.16.0 // indirect - golang.org/x/sync v0.6.0 // indirect golang.org/x/sys v0.16.0 // indirect golang.org/x/term v0.16.0 // indirect golang.org/x/text v0.14.0 // indirect diff --git a/oci/go.sum b/oci/go.sum index 87ecfa5d..45aeee99 100644 --- a/oci/go.sum +++ b/oci/go.sum @@ -155,6 +155,12 @@ github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 h1:YBftPWNWd4WwGqtY2yeZL2ef8rHAxPBD8KFhJpmcqms= github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0/go.mod h1:YN5jB8ie0yfIUg6VvR9Kz84aCaG7AsGZnLjhHbUqwPg= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v0.9.2 h1:CG6TE5H9/JXsFWJCfoIVpKFIkFe6ysEuHirp4DxCsHI= +github.com/hashicorp/go-hclog v0.9.2/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ= +github.com/hashicorp/go-retryablehttp v0.7.5 h1:bJj+Pj19UZMIweq/iie+1u5YCdGrnxCT9yvm0e+Nd5M= +github.com/hashicorp/go-retryablehttp v0.7.5/go.mod h1:Jy/gPYAdjqffZ/yFGCFV2doI5wjtH1ewM9u8iYVjtX8= github.com/hashicorp/golang-lru/arc/v2 v2.0.5 h1:l2zaLDubNhW4XO3LnliVj0GXO3+/CGNJAg1dcN2Fpfw= github.com/hashicorp/golang-lru/arc/v2 v2.0.5/go.mod h1:ny6zBSQZi2JxIeYcv7kt2sH2PXJtirBN7RDhRpxPkxU= github.com/hashicorp/golang-lru/v2 v2.0.5 h1:wW7h1TG88eUIJ2i69gaE3uNVtEPIagzhGvHgwfx2Vm4=