diff --git a/internal/transport/receiver.go b/internal/transport/receiver.go index cef3a29f..f67ac51f 100644 --- a/internal/transport/receiver.go +++ b/internal/transport/receiver.go @@ -117,7 +117,10 @@ func (c *Client) NewReceive(code string, pathname chan string, progress func(int return err } - err = zip.Extract(util.NewProgressReaderAt(tmp, progress, contents.Max), n, path) + err = zip.ExtractSafe( + util.NewProgressReaderAt(tmp, progress, contents.Max), + n, path, msg.UncompressedBytes64, msg.FileCount, + ) if err != nil { fyne.LogError("Error on unzipping contents", err) return err diff --git a/zip/zip.go b/zip/zip.go index 17073b4d..56a46eaa 100644 --- a/zip/zip.go +++ b/zip/zip.go @@ -12,8 +12,48 @@ import ( "github.com/klauspost/compress/zip" ) -// ErrorDangerousFilename indicates that a dangerous filename was found. -var ErrorDangerousFilename = errors.New("dangerous filename detected") +var ( + // ErrorDangerousFilename indicates that a dangerous filename was found. + ErrorDangerousFilename = errors.New("dangerous filename detected") + + // ErrorSizeMismatch indicates that the uncompressed size was unexpected. + ErrorSizeMismatch = errors.New("mismatch between offered and actual size") + + // ErrorFileCountMismatch indicates that the file count was unexpected. + ErrorFileCountMismatch = errors.New("mismatch between offered and actual file count") +) + +// ExtractSafe works like Extract() but verifies that the uncompressed size and file count are as expected. +// This can only be used if you know the file count and uncompressed size before extracting. +func ExtractSafe(source io.ReaderAt, length int64, target string, uncompressedBytes int64, files int) error { + reader, err := zip.NewReader(source, length) + if err != nil { + fyne.LogError("Could not create zip reader", err) + return err + } + + // Check that the file count is as expected. + if files < len(reader.File) { + return ErrorFileCountMismatch + } + + // Check that the extracted size is as expected. + actualUncompressedSize := uint64(0) + for _, f := range reader.File { + actualUncompressedSize += f.FileHeader.UncompressedSize64 + } + if uncompressedBytes < int64(actualUncompressedSize) { + return ErrorSizeMismatch + } + + for _, file := range reader.File { + if err := extractFile(file, target); err != nil { + return err + } + } + + return nil +} // Extract takes a reader and the length and then extracts it to the target. // The target should be the path to a folder where the extracted files can be placed.