Skip to content

Commit

Permalink
http/fetch: Optimise memory consumption
Browse files Browse the repository at this point in the history
The previous `http/fetch` logic would load into memory the tar file,
causing large files to increase the likelihood of concurrent
reconciliations to cause OOM.

The Fetch func downloads a file, then hashes its content, and if the
checksum matches, it goes ahead and extracts its contents. The
`resp.Body` is not an `io.ReadSeeker`, which means that to avoid loading
the full size of the file into memory, we need to save it into a temporary
file, and then reuse that file for the subsequent operations. With this approach
the memory consumption per operation was reduced from 23MB to 2.1MB:
```
Benchmark_Fetch-16      5  227630480 ns/op  23003358 B/op  19511 allocs/op
Benchmark_FetchNew-16   5  227570375 ns/op   2106795 B/op  19504 allocs/op
```
The tar file used was 7MB in size.

Expanding on defensive programming, the download process and subsequent
operations are short-circuited once the max download size is reached. With
a max limit set to 100 bytes, the error message yielded is:

`artifact is 7879239 bytes greater than the max download size of 100 bytes`

Signed-off-by: Paulo Gomes <paulo.gomes@weave.works>
  • Loading branch information
Paulo Gomes committed Oct 14, 2022
1 parent 3868547 commit 30a4aca
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 21 deletions.
66 changes: 51 additions & 15 deletions http/fetch/archive_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ limitations under the License.
package fetch

import (
"bytes"
"crypto/sha256"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"time"

"github.com/hashicorp/go-retryablehttp"
Expand All @@ -35,15 +35,15 @@ import (
// the file server is offline.
type ArchiveFetcher struct {
httpClient *retryablehttp.Client
maxUntarSize int
maxDownloadSize int
hostnameOverwrite string
}

// FileNotFoundError is an error type used to signal 404 HTTP status code responses.
var FileNotFoundError = errors.New("file not found")

// NewArchiveFetcher configures the retryable http client used for fetching archives.
func NewArchiveFetcher(retries, maxUntarSize int, hostnameOverwrite string) *ArchiveFetcher {
func NewArchiveFetcher(retries, maxDownloadSize int, hostnameOverwrite string) *ArchiveFetcher {
httpClient := retryablehttp.NewClient()
httpClient.RetryWaitMin = 5 * time.Second
httpClient.RetryWaitMax = 30 * time.Second
Expand All @@ -52,7 +52,7 @@ func NewArchiveFetcher(retries, maxUntarSize int, hostnameOverwrite string) *Arc

return &ArchiveFetcher{
httpClient: httpClient,
maxUntarSize: maxUntarSize,
maxDownloadSize: maxDownloadSize,
hostnameOverwrite: hostnameOverwrite,
}
}
Expand Down Expand Up @@ -89,34 +89,70 @@ func (r *ArchiveFetcher) Fetch(archiveURL, checksum, dir string) error {
return fmt.Errorf("failed to download archive from %s, status: %s", archiveURL, resp.Status)
}

var buf bytes.Buffer
f, err := os.CreateTemp("", "fetch.*.tmp")
if err != nil {
return fmt.Errorf("failed to create temp file: %w", err)
}
defer os.Remove(f.Name())

// Save temporary file, but limit download to the max download size.
if r.maxDownloadSize > 0 {
// Headers can lie, so instead of trusting resp.ContentLength,
// limit the download to the max download size and error in case
// there are still bytes left.
// Note that discarding of remaining bytes in resp.Body is a
// requirement for Go to effectively reuse HTTP connections.
_, err = io.Copy(f, io.LimitReader(resp.Body, int64(r.maxDownloadSize)))
n, _ := io.Copy(io.Discard, resp.Body)
if n > 0 {
return fmt.Errorf("artifact is %d bytes greater than the max download size of %d bytes", n, r.maxDownloadSize)
}
} else {
_, err = io.Copy(f, resp.Body)
}
if err != nil {
return fmt.Errorf("failed to copy temp contents: %w", err)
}

// verify checksum matches origin
if err := r.verifyChecksum(checksum, &buf, resp.Body); err != nil {
// We have just filled the file, to be able to read it from
// the start we must go back to its beginning.
_, err = f.Seek(0, 0)
if err != nil {
return fmt.Errorf("failed to seek back to beginning: %w", err)
}

// Ensure that the checksum of the downloaded file matches the
// known checksum.
if err := r.verifyChecksum(checksum, f); err != nil {
return err
}

// extract
if err = tar.Untar(&buf, dir, tar.WithMaxUntarSize(r.maxUntarSize)); err != nil {
return fmt.Errorf("failed to extract archive, error: %w", err)
// Jump back at the beginning of the file stream again.
_, err = f.Seek(0, 0)
if err != nil {
return fmt.Errorf("failed to seek back to beginning again: %w", err)
}

// Extracts the tar file.
if err = tar.Untar(f, dir, tar.WithMaxUntarSize(-1)); err != nil {
return fmt.Errorf("failed to extract archive (check whether file size exceeds max download size): %w", err)
}

return nil
}

// verifyChecksum computes the checksum of the tarball and returns an error if the computed value
// does not match the artifact advertised checksum.
func (r *ArchiveFetcher) verifyChecksum(checksum string, buf *bytes.Buffer, reader io.Reader) error {
func (r *ArchiveFetcher) verifyChecksum(checksum string, reader io.Reader) error {
hasher := sha256.New()

// compute checksum
mw := io.MultiWriter(hasher, buf)
if _, err := io.Copy(mw, reader); err != nil {
// Computes reader's checksum.
if _, err := io.Copy(hasher, reader); err != nil {
return err
}

if newChecksum := fmt.Sprintf("%x", hasher.Sum(nil)); newChecksum != checksum {
return fmt.Errorf("failed to verify archive: computed checksum '%s' doesn't match provided '%s'",
return fmt.Errorf("failed to verify archive: computed checksum '%s' doesn't match provided '%s' (check whether file size exceeds max download size)",
newChecksum, checksum)
}

Expand Down
20 changes: 14 additions & 6 deletions http/fetch/archive_fetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,26 @@ func TestArchiveFetcher_Fetch(t *testing.T) {
g.Expect(err).ToNot(HaveOccurred())

tests := []struct {
name string
url string
checksum string
wantErr bool
wantErrType error
name string
url string
checksum string
maxDownloadSize int
wantErr bool
wantErrType error
}{
{
name: "fetches and verifies the checksum",
url: artifactURL,
checksum: artifactChecksum,
wantErr: false,
},
{
name: "breaches max download size",
url: artifactURL,
checksum: artifactChecksum,
maxDownloadSize: 1,
wantErr: true,
},
{
name: "fails to verify the checksum",
url: artifactURL,
Expand All @@ -76,7 +84,7 @@ func TestArchiveFetcher_Fetch(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
g := NewWithT(t)

fetcher := NewArchiveFetcher(1, -1, "")
fetcher := NewArchiveFetcher(1, tt.maxDownloadSize, "")
err = fetcher.Fetch(tt.url, tt.checksum, tmpDir)

if tt.wantErr {
Expand Down

0 comments on commit 30a4aca

Please sign in to comment.