diff --git a/block/local/adapter.go b/block/local/adapter.go index d3870967e33..5353ceb0852 100644 --- a/block/local/adapter.go +++ b/block/local/adapter.go @@ -70,9 +70,23 @@ func NewAdapter(path string, opts ...func(a *Adapter)) (*Adapter, error) { } return adapter, nil } +func resolveNamespace(obj block.ObjectPointer) (block.QualifiedKey, error) { + qualifiedKey, err := block.ResolveNamespace(obj.StorageNamespace, obj.Identifier) + if err != nil { + return qualifiedKey, err + } + if qualifiedKey.StorageType != block.StorageTypeLocal { + return qualifiedKey, block.ErrInvalidNamespace + } + return qualifiedKey, nil +} -func (l *Adapter) getPath(identifier string) string { - return path.Join(l.path, identifier) +func (l *Adapter) getPath(identifier block.ObjectPointer) (string, error) { + obj, err := resolveNamespace(identifier) + if err != nil { + return "", err + } + return path.Join(l.path, obj.StorageNamespace, obj.Key), nil } // maybeMkdir runs f(path), but if f fails due to file-not-found MkdirAll's its dir and then @@ -89,7 +103,10 @@ func maybeMkdir(path string, f func(p string) (*os.File, error)) (*os.File, erro } func (l *Adapter) Put(obj block.ObjectPointer, _ int64, reader io.Reader, _ block.PutOpts) error { - p := l.getPath(obj.Identifier) + p, err := l.getPath(obj) + if err != nil { + return err + } f, err := maybeMkdir(p, os.Create) if err != nil { return err @@ -102,12 +119,18 @@ func (l *Adapter) Put(obj block.ObjectPointer, _ int64, reader io.Reader, _ bloc } func (l *Adapter) Remove(obj block.ObjectPointer) error { - p := l.getPath(obj.Identifier) + p, err := l.getPath(obj) + if err != nil { + return err + } return os.Remove(p) } func (l *Adapter) Copy(sourceObj, destinationObj block.ObjectPointer) error { - source := l.getPath(sourceObj.Identifier) + source, err := l.getPath(sourceObj) + if err != nil { + return err + } sourceFile, err := os.Open(source) defer func() { _ = sourceFile.Close() @@ -115,7 +138,10 @@ func (l *Adapter) Copy(sourceObj, destinationObj block.ObjectPointer) error { if err != nil { return err } - dest := l.getPath(destinationObj.Identifier) + dest, err := l.getPath(destinationObj) + if err != nil { + return err + } destinationFile, err := maybeMkdir(dest, os.Create) if err != nil { return err @@ -128,7 +154,10 @@ func (l *Adapter) Copy(sourceObj, destinationObj block.ObjectPointer) error { } func (l *Adapter) Get(obj block.ObjectPointer, _ int64) (reader io.ReadCloser, err error) { - p := l.getPath(obj.Identifier) + p, err := l.getPath(obj) + if err != nil { + return nil, err + } f, err := os.OpenFile(p, os.O_RDONLY, 0755) if err != nil { return nil, err @@ -137,7 +166,10 @@ func (l *Adapter) Get(obj block.ObjectPointer, _ int64) (reader io.ReadCloser, e } func (l *Adapter) GetRange(obj block.ObjectPointer, start int64, end int64) (io.ReadCloser, error) { - p := l.getPath(obj.Identifier) + p, err := l.getPath(obj) + if err != nil { + return nil, err + } f, err := os.Open(p) if err != nil { return nil, err @@ -152,8 +184,11 @@ func (l *Adapter) GetRange(obj block.ObjectPointer, start int64, end int64) (io. } func (l *Adapter) GetProperties(obj block.ObjectPointer) (block.Properties, error) { - p := l.getPath(obj.Identifier) - _, err := os.Stat(p) + p, err := l.getPath(obj) + if err != nil { + return block.Properties{}, err + } + _, err = os.Stat(p) if err != nil { return block.Properties{}, err } @@ -175,11 +210,14 @@ func isDirectoryWritable(pth string) bool { return true } -func (l *Adapter) CreateMultiPartUpload(obj block.ObjectPointer, r *http.Request, opts block.CreateMultiPartUploadOpts) (string, error) { +func (l *Adapter) CreateMultiPartUpload(obj block.ObjectPointer, _ *http.Request, _ block.CreateMultiPartUploadOpts) (string, error) { if strings.Contains(obj.Identifier, "/") { - fullPath := l.getPath(obj.Identifier) + fullPath, err := l.getPath(obj) + if err != nil { + return "", err + } fullDir := path.Dir(fullPath) - err := os.MkdirAll(fullDir, 0755) + err = os.MkdirAll(fullDir, 0755) if err != nil { return "", err } @@ -190,16 +228,16 @@ func (l *Adapter) CreateMultiPartUpload(obj block.ObjectPointer, r *http.Request return uploadID, nil } -func (l *Adapter) UploadPart(obj block.ObjectPointer, sizeBytes int64, reader io.Reader, uploadID string, partNumber int64) (string, error) { +func (l *Adapter) UploadPart(obj block.ObjectPointer, _ int64, reader io.Reader, uploadID string, partNumber int64) (string, error) { md5Read := block.NewHashingReader(reader, block.HashFunctionMD5) fName := uploadID + fmt.Sprintf("-%05d", (partNumber)) - err := l.Put(block.ObjectPointer{StorageNamespace: "", Identifier: fName}, -1, md5Read, block.PutOpts{}) + err := l.Put(block.ObjectPointer{StorageNamespace: obj.StorageNamespace, Identifier: fName}, -1, md5Read, block.PutOpts{}) etag := "\"" + hex.EncodeToString(md5Read.Md5.Sum(nil)) + "\"" return etag, err } func (l *Adapter) AbortMultiPartUpload(obj block.ObjectPointer, uploadID string) error { - files, err := l.getPartFiles(uploadID) + files, err := l.getPartFiles(uploadID, obj) if err != nil { return err } @@ -209,11 +247,11 @@ func (l *Adapter) AbortMultiPartUpload(obj block.ObjectPointer, uploadID string) func (l *Adapter) CompleteMultiPartUpload(obj block.ObjectPointer, uploadID string, multipartList *block.MultipartUploadCompletion) (*string, int64, error) { etag := computeETag(multipartList.Part) + "-" + strconv.Itoa(len(multipartList.Part)) - partFiles, err := l.getPartFiles(uploadID) + partFiles, err := l.getPartFiles(uploadID, obj) if err != nil { return nil, -1, fmt.Errorf("part files not found for %s: %w", uploadID, err) } - size, err := l.unitePartFiles(obj.Identifier, partFiles) + size, err := l.unitePartFiles(obj, partFiles) if err != nil { return nil, -1, fmt.Errorf("multipart upload unite for %s: %w", uploadID, err) } @@ -237,8 +275,11 @@ func computeETag(parts []*s3.CompletedPart) string { return csm } -func (l *Adapter) unitePartFiles(identifier string, files []string) (int64, error) { - p := l.getPath(identifier) +func (l *Adapter) unitePartFiles(identifier block.ObjectPointer, files []string) (int64, error) { + p, err := l.getPath(identifier) + if err != nil { + return 0, err + } unitedFile, err := os.Create(p) if err != nil { return 0, fmt.Errorf("create path %s: %w", p, err) @@ -268,8 +309,16 @@ func (l *Adapter) removePartFiles(files []string) { } } -func (l *Adapter) getPartFiles(uploadID string) ([]string, error) { - globPathPattern := l.getPath(uploadID) + "-*" +func (l *Adapter) getPartFiles(uploadID string, obj block.ObjectPointer) ([]string, error) { + newObj := block.ObjectPointer{ + StorageNamespace: obj.StorageNamespace, + Identifier: uploadID, + } + globPathPattern, err := l.getPath(newObj) + if err != nil { + return nil, err + } + globPathPattern += "*" names, err := filepath.Glob(globPathPattern) if err != nil { return nil, err diff --git a/block/local/adapter_test.go b/block/local/adapter_test.go index 8e9af32a25b..d868b6b09c6 100644 --- a/block/local/adapter_test.go +++ b/block/local/adapter_test.go @@ -1,6 +1,8 @@ package local_test import ( + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/s3" "io/ioutil" "os" "strings" @@ -15,7 +17,7 @@ func makeAdapter(t *testing.T) (*local.Adapter, func()) { t.Helper() dir, err := ioutil.TempDir("", "testing-local-adapter-*") testutil.MustDo(t, "TempDir", err) - os.MkdirAll(dir, 0700) + testutil.MustDo(t, "NewAdapter", os.MkdirAll(dir, 0700)) a, err := local.NewAdapter(dir) testutil.MustDo(t, "NewAdapter", err) @@ -27,7 +29,7 @@ func makeAdapter(t *testing.T) (*local.Adapter, func()) { } func makePointer(path string) block.ObjectPointer { - return block.ObjectPointer{Identifier: path} + return block.ObjectPointer{Identifier: path, StorageNamespace: "local://test/"} } func TestLocalPutGet(t *testing.T) { @@ -58,6 +60,49 @@ func TestLocalPutGet(t *testing.T) { } } +func TestLocalMultipartUpload(t *testing.T) { + a, cleanup := makeAdapter(t) + defer cleanup() + + cases := []struct { + name string + path string + partData []string + }{ + {"simple", "abc", []string{"one ", "two ", "three"}}, + {"nested", "foo/bar", []string{"one ", "two ", "three"}}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + pointer := makePointer(c.path) + uploadID, err := a.CreateMultiPartUpload(pointer, nil, block.CreateMultiPartUploadOpts{}) + testutil.MustDo(t, "CreateMultiPartUpload", err) + parts := make([]*s3.CompletedPart, 0) + for partNumber, content := range c.partData { + cs, err := a.UploadPart(pointer, 0, strings.NewReader(content), uploadID, int64(partNumber)) + testutil.MustDo(t, "UploadPart", err) + parts = append(parts, &s3.CompletedPart{ + ETag: aws.String(cs), + PartNumber: aws.Int64(int64(partNumber)), + }) + } + _, _, err = a.CompleteMultiPartUpload(pointer, uploadID, &block.MultipartUploadCompletion{ + Part: parts, + }) + testutil.MustDo(t, "CompleteMultiPartUpload", err) + reader, err := a.Get(pointer, 0) + testutil.MustDo(t, "Get", err) + got, err := ioutil.ReadAll(reader) + testutil.MustDo(t, "ReadAll", err) + expected := strings.Join(c.partData, "") + if string(got) != expected { + t.Errorf("expected to read \"%s\" as written, got \"%s\"", expected, string(got)) + } + }) + } +} + func TestLocalCopy(t *testing.T) { a, cleanup := makeAdapter(t) defer cleanup()