Skip to content

Commit

Permalink
Local adapter clean user input before file access (#1037)
Browse files Browse the repository at this point in the history
  • Loading branch information
nopcoder authored Dec 30, 2020
1 parent 96f9f87 commit 60c9a64
Showing 1 changed file with 33 additions and 9 deletions.
42 changes: 33 additions & 9 deletions block/local/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ var (
ErrPathNotValid = errors.New("path provided is not a valid directory")
ErrPathNotWritable = errors.New("path provided is not writable")
ErrInventoryNotSupported = errors.New("inventory feature not implemented for local storage adapter")
ErrInvalidUploadIDFormat = errors.New("invalid upload id format")
)

func (l *Adapter) WithContext(ctx context.Context) block.Adapter {
Expand All @@ -51,6 +52,7 @@ func WithTranslator(t block.UploadIDTranslator) func(a *Adapter) {
}

func NewAdapter(path string, opts ...func(a *Adapter)) (*Adapter, error) {
path = filepath.Clean(path)
stt, err := os.Stat(path)
if err != nil {
return nil, err
Expand All @@ -62,7 +64,8 @@ func NewAdapter(path string, opts ...func(a *Adapter)) (*Adapter, error) {
return nil, ErrPathNotWritable
}
adapter := &Adapter{
path: path, ctx: context.Background(),
path: path,
ctx: context.Background(),
uploadIDTranslator: &block.NoOpTranslator{},
}
for _, opt := range opts {
Expand All @@ -86,7 +89,8 @@ func (l *Adapter) getPath(identifier block.ObjectPointer) (string, error) {
if err != nil {
return "", err
}
return path.Join(l.path, obj.StorageNamespace, obj.Key), nil
p := path.Join(l.path, obj.StorageNamespace, obj.Key)
return p, nil
}

// maybeMkdir runs f(path), but if f fails due to file-not-found MkdirAll's its dir and then
Expand All @@ -96,7 +100,8 @@ func maybeMkdir(path string, f func(p string) (*os.File, error)) (*os.File, erro
if !errors.Is(err, os.ErrNotExist) {
return ret, err
}
if err = os.MkdirAll(filepath.Dir(path), 0777); err != nil {
d := filepath.Dir(filepath.Clean(path))
if err = os.MkdirAll(d, 0750); err != nil {
return nil, err
}
return f(path)
Expand All @@ -107,6 +112,7 @@ func (l *Adapter) Put(obj block.ObjectPointer, _ int64, reader io.Reader, _ bloc
if err != nil {
return err
}
p = filepath.Clean(p)
f, err := maybeMkdir(p, os.Create)
if err != nil {
return err
Expand All @@ -123,6 +129,7 @@ func (l *Adapter) Remove(obj block.ObjectPointer) error {
if err != nil {
return err
}
p = filepath.Clean(p)
return os.Remove(p)
}

Expand All @@ -131,7 +138,7 @@ func (l *Adapter) Copy(sourceObj, destinationObj block.ObjectPointer) error {
if err != nil {
return err
}
sourceFile, err := os.Open(source)
sourceFile, err := os.Open(filepath.Clean(source))
defer func() {
_ = sourceFile.Close()
}()
Expand All @@ -158,7 +165,7 @@ func (l *Adapter) Get(obj block.ObjectPointer, _ int64) (reader io.ReadCloser, e
if err != nil {
return nil, err
}
f, err := os.OpenFile(p, os.O_RDONLY, 0755)
f, err := os.OpenFile(filepath.Clean(p), os.O_RDONLY, 0600)
if err != nil {
return nil, err
}
Expand All @@ -170,7 +177,7 @@ func (l *Adapter) GetRange(obj block.ObjectPointer, start int64, end int64) (io.
if err != nil {
return nil, err
}
f, err := os.Open(p)
f, err := os.Open(filepath.Clean(p))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -217,7 +224,7 @@ func (l *Adapter) CreateMultiPartUpload(obj block.ObjectPointer, _ *http.Request
return "", err
}
fullDir := path.Dir(fullPath)
err = os.MkdirAll(fullDir, 0755)
err = os.MkdirAll(fullDir, 0750)
if err != nil {
return "", err
}
Expand All @@ -229,14 +236,20 @@ func (l *Adapter) CreateMultiPartUpload(obj block.ObjectPointer, _ *http.Request
}

func (l *Adapter) UploadPart(obj block.ObjectPointer, _ int64, reader io.Reader, uploadID string, partNumber int64) (string, error) {
if err := isValidUploadID(uploadID); err != nil {
return "", err
}
md5Read := block.NewHashingReader(reader, block.HashFunctionMD5)
fName := uploadID + fmt.Sprintf("-%05d", (partNumber))
fName := uploadID + fmt.Sprintf("-%05d", partNumber)
err := l.Put(block.ObjectPointer{StorageNamespace: obj.StorageNamespace, Identifier: fName}, -1, md5Read, block.PutOpts{})
etag := "\"" + hex.EncodeToString(md5Read.Md5.Sum(nil)) + "\""
return etag, err
}

func (l *Adapter) AbortMultiPartUpload(obj block.ObjectPointer, uploadID string) error {
if err := isValidUploadID(uploadID); err != nil {
return err
}
files, err := l.getPartFiles(uploadID, obj)
if err != nil {
return err
Expand All @@ -246,6 +259,9 @@ func (l *Adapter) AbortMultiPartUpload(obj block.ObjectPointer, uploadID string)
}

func (l *Adapter) CompleteMultiPartUpload(obj block.ObjectPointer, uploadID string, multipartList *block.MultipartUploadCompletion) (*string, int64, error) {
if err := isValidUploadID(uploadID); err != nil {
return nil, -1, err
}
etag := computeETag(multipartList.Part) + "-" + strconv.Itoa(len(multipartList.Part))
partFiles, err := l.getPartFiles(uploadID, obj)
if err != nil {
Expand Down Expand Up @@ -289,7 +305,7 @@ func (l *Adapter) unitePartFiles(identifier block.ObjectPointer, files []string)
}()
var readers = []io.Reader{}
for _, name := range files {
f, err := os.Open(name)
f, err := os.Open(filepath.Clean(name))
if err != nil {
return 0, fmt.Errorf("open file %s: %w", name, err)
}
Expand Down Expand Up @@ -338,3 +354,11 @@ func (l *Adapter) GenerateInventory(_ context.Context, _ logging.Logger, _ strin
func (l *Adapter) BlockstoreType() string {
return BlockstoreType
}

func isValidUploadID(uploadID string) error {
_, err := hex.DecodeString(uploadID)
if err != nil {
return fmt.Errorf("%w: %s", ErrInvalidUploadIDFormat, err)
}
return nil
}

0 comments on commit 60c9a64

Please sign in to comment.