Skip to content

Commit

Permalink
Prevent arbitrary file overwrite via path traversal (Zip Slip attack)
Browse files Browse the repository at this point in the history
  • Loading branch information
giuliocomi authored and petemoore committed Sep 28, 2020
1 parent 45ba413 commit 8217ed3
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 0 deletions.
5 changes: 5 additions & 0 deletions archiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ type ExtensionChecker interface {
CheckExt(name string) error
}

// FilenameChecker validates filenames to prevent path traversal attacks
type FilenameChecker interface {
CheckPath(to, filename string) error
}

// Unarchiver is a type that can extract archive files
// into a folder.
type Unarchiver interface {
Expand Down
20 changes: 20 additions & 0 deletions rar.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@ func (*Rar) CheckExt(filename string) error {
return nil
}

// CheckPath ensures that the filename has not been crafted to perform path traversal attacks
func (*Rar) CheckPath(to, filename string) error {
to, _ = filepath.Abs(to) //explicit the destination folder to prevent that 'string.HasPrefix' check can be 'bypassed' when no destination folder is supplied in input
dest := filepath.Join(to, filename)
//prevent path traversal attacks
if !strings.HasPrefix(dest, to) {
return fmt.Errorf("illegal file path: %s", filename)
}
return nil
}

// Unarchive unpacks the .rar file at source to destination.
// Destination will be treated as a folder name. It supports
// multi-volume archives.
Expand Down Expand Up @@ -145,10 +156,18 @@ func (r *Rar) unrarNext(to string) error {
if err != nil {
return err // don't wrap error; calling loop must break on io.EOF
}
defer f.Close()

header, ok := f.Header.(*rardecode.FileHeader)
if !ok {
return fmt.Errorf("expected header to be *rardecode.FileHeader but was %T", f.Header)
}

errPath := r.CheckPath(to, header.Name)
if errPath != nil {
return fmt.Errorf("checking path traversal attempt: %v", errPath)
}

return r.unrarFile(f, filepath.Join(to, header.Name))
}

Expand Down Expand Up @@ -404,6 +423,7 @@ var (
_ = Extractor(new(Rar))
_ = Matcher(new(Rar))
_ = ExtensionChecker(new(Rar))
_ = FilenameChecker(new(Rar))
_ = os.FileInfo(rarFileInfo{})
)

Expand Down
19 changes: 19 additions & 0 deletions tar.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ func (*Tar) CheckExt(filename string) error {
return nil
}

// CheckPath ensures that the filename has not been crafted to perform path traversal attacks
func (*Tar) CheckPath(to, filename string) error {
to, _ = filepath.Abs(to) //explicit the destination folder to prevent that 'string.HasPrefix' check can be 'bypassed' when no destination folder is supplied in input
dest := filepath.Join(to, filename)
//prevent path traversal attacks
if !strings.HasPrefix(dest, to) {
return fmt.Errorf("illegal file path: %s", filename)
}
return nil
}

// Archive creates a tarball file at destination containing
// the files listed in sources. The destination must end with
// ".tar". File paths can be those of regular files or
Expand Down Expand Up @@ -211,10 +222,17 @@ func (t *Tar) untarNext(destination string) error {
if err != nil {
return err // don't wrap error; calling loop must break on io.EOF
}
defer f.Close()

header, ok := f.Header.(*tar.Header)
if !ok {
return fmt.Errorf("expected header to be *tar.Header but was %T", f.Header)
}

errPath := t.CheckPath(destination, header.Name)
if errPath != nil {
return fmt.Errorf("checking path traversal attempt: %v", errPath)
}
return t.untarFile(f, destination, header)
}

Expand Down Expand Up @@ -614,6 +632,7 @@ var (
_ = Extractor(new(Tar))
_ = Matcher(new(Tar))
_ = ExtensionChecker(new(Tar))
_ = FilenameChecker(new(Tar))
)

// DefaultTar is a default instance that is conveniently ready to use.
Expand Down
17 changes: 17 additions & 0 deletions zip.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,17 @@ func registerDecompressor(zr *zip.Reader) {
})
}

// CheckPath ensures the file extension matches the format.
func (*Zip) CheckPath(to, filename string) error {
to, _ = filepath.Abs(to) //explicit the destination folder to prevent that 'string.HasPrefix' check can be 'bypassed' when no destination folder is supplied in input
dest := filepath.Join(to, filename)
//prevent path traversal attacks
if !strings.HasPrefix(dest, to) {
return fmt.Errorf("illegal file path: %s", filename)
}
return nil
}

// Archive creates a .zip file at destination containing
// the files listed in sources. The destination must end
// with ".zip". File paths can be those of regular files
Expand Down Expand Up @@ -231,6 +242,11 @@ func (z *Zip) extractNext(to string) error {
return err // don't wrap error; calling loop must break on io.EOF
}
defer f.Close()

errPath := z.CheckPath(to, f.Header.(zip.FileHeader).Name)
if errPath != nil {
return fmt.Errorf("checking path traversal attempt: %v", errPath)
}
return z.extractFile(f, to)
}

Expand Down Expand Up @@ -629,6 +645,7 @@ var (
_ = Extractor(new(Zip))
_ = Matcher(new(Zip))
_ = ExtensionChecker(new(Zip))
_ = FilenameChecker(new(Zip))
)

// compressedFormats is a (non-exhaustive) set of lowercased
Expand Down

0 comments on commit 8217ed3

Please sign in to comment.