Skip to content

Commit

Permalink
Merge pull request #2903 from nirs/download-time
Browse files Browse the repository at this point in the history
Fix races during parallel downloads
  • Loading branch information
AkihiroSuda authored Nov 25, 2024
2 parents 3fdfaeb + 5071535 commit f81fa90
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 95 deletions.
159 changes: 92 additions & 67 deletions pkg/downloader/downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,31 @@ func Download(ctx context.Context, local, remote string, opts ...Opt) (*Result,
return res, nil
}

shad := cacheDirectoryPath(o.cacheDir, remote)
if err := os.MkdirAll(shad, 0o700); err != nil {
return nil, err
}

var res *Result
err := lockutil.WithDirLock(shad, func() error {
var err error
res, err = getCached(ctx, localPath, remote, o)
if err != nil {
return err
}
if res != nil {
return nil
}
res, err = fetch(ctx, localPath, remote, o)
return err
})
return res, err
}

// getCached tries to copy the file from the cache to local path. Return result,
// nil if the file was copied, nil, nil if the file is not in the cache or the
// cache needs update, or nil, error on fatal error.
func getCached(ctx context.Context, localPath, remote string, o options) (*Result, error) {
shad := cacheDirectoryPath(o.cacheDir, remote)
shadData := filepath.Join(shad, "data")
shadTime := filepath.Join(shad, "time")
Expand All @@ -237,53 +262,62 @@ func Download(ctx context.Context, local, remote string, opts ...Opt) (*Result,
if err != nil {
return nil, err
}
if _, err := os.Stat(shadData); err == nil {
logrus.Debugf("file %q is cached as %q", localPath, shadData)
useCache := true
if _, err := os.Stat(shadDigest); err == nil {
logrus.Debugf("Comparing digest %q with the cached digest file %q, not computing the actual digest of %q",
o.expectedDigest, shadDigest, shadData)
if err := validateCachedDigest(shadDigest, o.expectedDigest); err != nil {
return nil, err
}
if err := copyLocal(ctx, localPath, shadData, ext, o.decompress, "", ""); err != nil {
if _, err := os.Stat(shadData); err != nil {
return nil, nil
}
ext := path.Ext(remote)
logrus.Debugf("file %q is cached as %q", localPath, shadData)
if _, err := os.Stat(shadDigest); err == nil {
logrus.Debugf("Comparing digest %q with the cached digest file %q, not computing the actual digest of %q",
o.expectedDigest, shadDigest, shadData)
if err := validateCachedDigest(shadDigest, o.expectedDigest); err != nil {
return nil, err
}
if err := copyLocal(ctx, localPath, shadData, ext, o.decompress, "", ""); err != nil {
return nil, err
}
} else {
if match, lmCached, lmRemote, err := matchLastModified(ctx, shadTime, remote); err != nil {
logrus.WithError(err).Info("Failed to retrieve last-modified for cached digest-less image; using cached image.")
} else if match {
if err := copyLocal(ctx, localPath, shadData, ext, o.decompress, o.description, o.expectedDigest); err != nil {
return nil, err
}
} else {
if match, lmCached, lmRemote, err := matchLastModified(ctx, shadTime, remote); err != nil {
logrus.WithError(err).Info("Failed to retrieve last-modified for cached digest-less image; using cached image.")
} else if match {
if err := copyLocal(ctx, localPath, shadData, ext, o.decompress, o.description, o.expectedDigest); err != nil {
return nil, err
}
} else {
logrus.Infof("Re-downloading digest-less image: last-modified mismatch (cached: %q, remote: %q)", lmCached, lmRemote)
useCache = false
}
}
if useCache {
res := &Result{
Status: StatusUsedCache,
CachePath: shadData,
LastModified: readTime(shadTime),
ContentType: readFile(shadType),
ValidatedDigest: o.expectedDigest != "",
}
return res, nil
logrus.Infof("Re-downloading digest-less image: last-modified mismatch (cached: %q, remote: %q)", lmCached, lmRemote)
return nil, nil
}
}
if err := os.MkdirAll(shad, 0o700); err != nil {
res := &Result{
Status: StatusUsedCache,
CachePath: shadData,
LastModified: readTime(shadTime),
ContentType: readFile(shadType),
ValidatedDigest: o.expectedDigest != "",
}
return res, nil
}

// fetch downloads remote to the cache and copy the cached file to local path.
func fetch(ctx context.Context, localPath, remote string, o options) (*Result, error) {
shad := cacheDirectoryPath(o.cacheDir, remote)
shadData := filepath.Join(shad, "data")
shadTime := filepath.Join(shad, "time")
shadType := filepath.Join(shad, "type")
shadDigest, err := cacheDigestPath(shad, o.expectedDigest)
if err != nil {
return nil, err
}
ext := path.Ext(remote)
shadURL := filepath.Join(shad, "url")
if err := writeFirst(shadURL, []byte(remote), 0o644); err != nil {
if err := os.WriteFile(shadURL, []byte(remote), 0o644); err != nil {
return nil, err
}
if err := downloadHTTP(ctx, shadData, shadTime, shadType, remote, o.description, o.expectedDigest); err != nil {
return nil, err
}
if shadDigest != "" && o.expectedDigest != "" {
if err := writeFirst(shadDigest, []byte(o.expectedDigest.String()), 0o644); err != nil {
if err := os.WriteFile(shadDigest, []byte(o.expectedDigest.String()), 0o644); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -327,18 +361,33 @@ func Cached(remote string, opts ...Opt) (*Result, error) {
if err != nil {
return nil, err
}

// Checking if data file exists is safe without locking.
if _, err := os.Stat(shadData); err != nil {
return nil, err
}
if _, err := os.Stat(shadDigest); err != nil {
if err := validateCachedDigest(shadDigest, o.expectedDigest); err != nil {
return nil, err
}
} else {
if err := validateLocalFileDigest(shadData, o.expectedDigest); err != nil {
return nil, err

// But validating the digest or the data file must take the lock to avoid races
// with parallel downloads.
if err := os.MkdirAll(shad, 0o700); err != nil {
return nil, err
}
err = lockutil.WithDirLock(shad, func() error {
if _, err := os.Stat(shadDigest); err != nil {
if err := validateCachedDigest(shadDigest, o.expectedDigest); err != nil {
return err
}
} else {
if err := validateLocalFileDigest(shadData, o.expectedDigest); err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, err
}

res := &Result{
Status: StatusUsedCache,
CachePath: shadData,
Expand Down Expand Up @@ -612,13 +661,13 @@ func downloadHTTP(ctx context.Context, localPath, lastModified, contentType, url
}
if lastModified != "" {
lm := resp.Header.Get("Last-Modified")
if err := writeFirst(lastModified, []byte(lm), 0o644); err != nil {
if err := os.WriteFile(lastModified, []byte(lm), 0o644); err != nil {
return err
}
}
if contentType != "" {
ct := resp.Header.Get("Content-Type")
if err := writeFirst(contentType, []byte(ct), 0o644); err != nil {
if err := os.WriteFile(contentType, []byte(ct), 0o644); err != nil {
return err
}
}
Expand Down Expand Up @@ -679,19 +728,7 @@ func downloadHTTP(ctx context.Context, localPath, lastModified, contentType, url
return err
}

// If localPath was created by a parallel download keep it. Replacing it
// while another process is copying it to the destination may fail the
// clonefile syscall. We use a lock to ensure that only one process updates
// data, and when we return data file exists.

return lockutil.WithDirLock(filepath.Dir(localPath), func() error {
if _, err := os.Stat(localPath); err == nil {
return nil
} else if !errors.Is(err, os.ErrNotExist) {
return err
}
return os.Rename(localPathTmp, localPath)
})
return os.Rename(localPathTmp, localPath)
}

var tempfileCount atomic.Uint64
Expand All @@ -706,18 +743,6 @@ func perProcessTempfile(path string) string {
return fmt.Sprintf("%s.tmp.%d.%d", path, os.Getpid(), tempfileCount.Add(1))
}

// writeFirst writes data to path unless path already exists.
//
// The write happens under the directory lock so that concurrent callers
// racing on the same path agree on a single winner: whoever writes first
// wins, and every later caller leaves the existing file untouched.
func writeFirst(path string, data []byte, perm os.FileMode) error {
	return lockutil.WithDirLock(filepath.Dir(path), func() error {
		_, statErr := os.Stat(path)
		switch {
		case statErr == nil:
			// A parallel process already created the file; keep its contents.
			return nil
		case errors.Is(statErr, os.ErrNotExist):
			return os.WriteFile(path, data, perm)
		default:
			// Stat failed for a reason other than absence (e.g. permissions).
			return statErr
		}
	})
}

// CacheEntries returns a map of cache entries.
// The key is the SHA256 of the URL.
// The value is the path to the cache entry.
Expand Down
68 changes: 40 additions & 28 deletions pkg/downloader/downloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"os/exec"
"path/filepath"
"runtime"
"slices"
"strings"
"testing"
"time"
Expand All @@ -31,11 +30,6 @@ type downloadResult struct {
// races quicker. 20 parallel downloads take about 120 milliseconds on M1 Pro.
const parallelDownloads = 20

// When downloading in parallel usually all downloads completed with
// StatusDownload, but some may be delayed and find the data file when they
// start. Can be reproduced locally using 100 parallel downloads.
var parallelStatus = []Status{StatusDownloaded, StatusUsedCache}

func TestDownloadRemote(t *testing.T) {
ts := httptest.NewServer(http.FileServer(http.Dir("testdata")))
t.Cleanup(ts.Close)
Expand Down Expand Up @@ -103,15 +97,10 @@ func TestDownloadRemote(t *testing.T) {
results <- downloadResult{r, err}
}()
}
// We must process all results before cleanup.
for i := 0; i < parallelDownloads; i++ {
result := <-results
if result.err != nil {
t.Errorf("Download failed: %s", result.err)
} else if !slices.Contains(parallelStatus, result.r.Status) {
t.Errorf("Expected download status %s, got %s", parallelStatus, result.r.Status)
}
}
// Only one thread should download, the rest should use the cache.
downloaded, cached := countResults(t, results)
assert.Equal(t, downloaded, 1)
assert.Equal(t, cached, parallelDownloads-1)
})
})
t.Run("caching-only mode", func(t *testing.T) {
Expand Down Expand Up @@ -146,15 +135,10 @@ func TestDownloadRemote(t *testing.T) {
results <- downloadResult{r, err}
}()
}
// We must process all results before cleanup.
for i := 0; i < parallelDownloads; i++ {
result := <-results
if result.err != nil {
t.Errorf("Download failed: %s", result.err)
} else if !slices.Contains(parallelStatus, result.r.Status) {
t.Errorf("Expected download status %s, got %s", parallelStatus, result.r.Status)
}
}
// Only one thread should download, the rest should use the cache.
downloaded, cached := countResults(t, results)
assert.Equal(t, downloaded, 1)
assert.Equal(t, cached, parallelDownloads-1)
})
})
t.Run("cached", func(t *testing.T) {
Expand Down Expand Up @@ -188,6 +172,26 @@ func TestDownloadRemote(t *testing.T) {
})
}

// countResults drains exactly parallelDownloads entries from results and
// tallies how many finished with StatusDownloaded versus StatusUsedCache.
// Errors and unexpected statuses are reported as test failures.
func countResults(t *testing.T, results chan downloadResult) (downloaded, cached int) {
	t.Helper()
	for i := 0; i < parallelDownloads; i++ {
		res := <-results
		if res.err != nil {
			t.Errorf("Download failed: %s", res.err)
			continue
		}
		switch res.r.Status {
		case StatusDownloaded:
			downloaded++
		case StatusUsedCache:
			cached++
		default:
			t.Errorf("Unexpected download status %q", res.r.Status)
		}
	}
	return downloaded, cached
}

func TestRedownloadRemote(t *testing.T) {
remoteDir := t.TempDir()
ts := httptest.NewServer(http.FileServer(http.Dir(remoteDir)))
Expand All @@ -203,18 +207,26 @@ func TestRedownloadRemote(t *testing.T) {
assert.NilError(t, os.Chtimes(remoteFile, time.Now(), time.Now().Add(-time.Hour)))
opt := []Opt{cacheOpt}

r, err := Download(context.Background(), filepath.Join(downloadDir, "digest-less1.txt"), ts.URL+"/digest-less.txt", opt...)
// Download on the first call
r, err := Download(context.Background(), filepath.Join(downloadDir, "1"), ts.URL+"/digest-less.txt", opt...)
assert.NilError(t, err)
assert.Equal(t, StatusDownloaded, r.Status)
r, err = Download(context.Background(), filepath.Join(downloadDir, "digest-less2.txt"), ts.URL+"/digest-less.txt", opt...)

// Next download will use the cached download
r, err = Download(context.Background(), filepath.Join(downloadDir, "2"), ts.URL+"/digest-less.txt", opt...)
assert.NilError(t, err)
assert.Equal(t, StatusUsedCache, r.Status)

// modifying remote file will cause redownload
// Modifying remote file will cause redownload
assert.NilError(t, os.Chtimes(remoteFile, time.Now(), time.Now()))
r, err = Download(context.Background(), filepath.Join(downloadDir, "digest-less3.txt"), ts.URL+"/digest-less.txt", opt...)
r, err = Download(context.Background(), filepath.Join(downloadDir, "3"), ts.URL+"/digest-less.txt", opt...)
assert.NilError(t, err)
assert.Equal(t, StatusDownloaded, r.Status)

// Next download will use the cached download
r, err = Download(context.Background(), filepath.Join(downloadDir, "4"), ts.URL+"/digest-less.txt", opt...)
assert.NilError(t, err)
assert.Equal(t, StatusUsedCache, r.Status)
})

t.Run("has-digest", func(t *testing.T) {
Expand Down

0 comments on commit f81fa90

Please sign in to comment.