From 43b3496f0001cec231c80af1f9a9b3417d04e8d4 Mon Sep 17 00:00:00 2001 From: Matthew Penner Date: Tue, 7 Mar 2023 15:19:09 -0700 Subject: [PATCH] server(filesystem): fix archiver path matching Closes https://github.com/pterodactyl/panel/issues/4630 --- server/filesystem/archive.go | 19 ++++- server/filesystem/archive_test.go | 125 ++++++++++++++++++++++++++++++ 2 files changed, 141 insertions(+), 3 deletions(-) create mode 100644 server/filesystem/archive_test.go diff --git a/server/filesystem/archive.go b/server/filesystem/archive.go index 968a2800..0e95224c 100644 --- a/server/filesystem/archive.go +++ b/server/filesystem/archive.go @@ -3,6 +3,7 @@ package filesystem import ( "archive/tar" "context" + "fmt" "io" "io/fs" "os" @@ -66,6 +67,8 @@ type Archive struct { // Files specifies the files to archive, this takes priority over the Ignore option, if // unspecified, all files in the BasePath will be archived unless Ignore is set. + // + // All items in Files must be absolute within BasePath. Files []string // Progress wraps the writer of the archive to pass through the progress tracker. @@ -97,6 +100,14 @@ func (a *Archive) Create(ctx context.Context, dst string) error { // Stream . func (a *Archive) Stream(ctx context.Context, w io.Writer) error { + for _, f := range a.Files { + if strings.HasPrefix(f, a.BasePath) { + continue + } + + return fmt.Errorf("archive: all entries in Files must be absolute and within BasePath: %s\n", f) + } + // Choose which compression level to use based on the compression_level configuration option var compressionLevel int switch config.Get().System.Backups.CompressionLevel { @@ -190,9 +201,11 @@ func (a *Archive) callback(tw *TarProgress, opts ...func(path string, relative s func (a *Archive) withFilesCallback(tw *TarProgress) func(path string, de *godirwalk.Dirent) error { return a.callback(tw, func(p string, rp string) error { for _, f := range a.Files { - // If the given doesn't match, or doesn't have the same prefix continue - // to the next item in the loop. - if p != f && !strings.HasPrefix(strings.TrimSuffix(p, "/")+"/", f) { + // Allow exact file matches, otherwise check if file is within a parent directory. + // + // The slashes are added in the prefix checks to prevent partial name matches from being + // included in the archive. + if f != p && !strings.HasPrefix(strings.TrimSuffix(p, "/")+"/", strings.TrimSuffix(f, "/")+"/") { continue } diff --git a/server/filesystem/archive_test.go b/server/filesystem/archive_test.go new file mode 100644 index 00000000..91bbfe72 --- /dev/null +++ b/server/filesystem/archive_test.go @@ -0,0 +1,125 @@ +package filesystem + +import ( + "context" + iofs "io/fs" + "os" + "path/filepath" + "strings" + "testing" + + . "github.com/franela/goblin" + "github.com/mholt/archiver/v4" +) + +func TestArchive_Stream(t *testing.T) { + g := Goblin(t) + fs, rfs := NewFs() + + g.Describe("Archive", func() { + g.AfterEach(func() { + // Reset the filesystem after each run. + rfs.reset() + }) + + g.It("throws an error when passed invalid file paths", func() { + a := &Archive{ + BasePath: fs.Path(), + Files: []string{ + // To use the archiver properly, this needs to be filepath.Join(BasePath, "yeet") + // However, this test tests that we actually validate that behavior. + "yeet", + }, + } + + g.Assert(a.Create(context.Background(), "")).IsNotNil() + }) + + g.It("creates archive with intended files", func() { + g.Assert(fs.CreateDirectory("test", "/")).IsNil() + g.Assert(fs.CreateDirectory("test2", "/")).IsNil() + + err := fs.Writefile("test/file.txt", strings.NewReader("hello, world!\n")) + g.Assert(err).IsNil() + + err = fs.Writefile("test2/file.txt", strings.NewReader("hello, world!\n")) + g.Assert(err).IsNil() + + err = fs.Writefile("test_file.txt", strings.NewReader("hello, world!\n")) + g.Assert(err).IsNil() + + err = fs.Writefile("test_file.txt.old", strings.NewReader("hello, world!\n")) + g.Assert(err).IsNil() + + a := &Archive{ + BasePath: fs.Path(), + Files: []string{ + filepath.Join(fs.Path(), "test"), + filepath.Join(fs.Path(), "test_file.txt"), + }, + } + + // Create the archive + archivePath := filepath.Join(rfs.root, "archive.tar.gz") + g.Assert(a.Create(context.Background(), archivePath)).IsNil() + + // Ensure the archive exists. + _, err = os.Stat(archivePath) + g.Assert(err).IsNil() + + // Open the archive. + genericFs, err := archiver.FileSystem(archivePath) + g.Assert(err).IsNil() + + // Assert that we are opening an archive. + afs, ok := genericFs.(archiver.ArchiveFS) + g.Assert(ok).IsTrue() + + // Get the names of the files recursively from the archive. + files, err := getFiles(afs, ".") + g.Assert(err).IsNil() + + // Ensure the files in the archive match what we are expecting. + g.Assert(files).Equal([]string{ + "test_file.txt", + "test/file.txt", + }) + }) + }) +} + +func getFiles(f iofs.ReadDirFS, name string) ([]string, error) { + var v []string + + entries, err := f.ReadDir(name) + if err != nil { + return nil, err + } + + for _, e := range entries { + entryName := e.Name() + if name != "." { + entryName = filepath.Join(name, entryName) + } + + if e.IsDir() { + files, err := getFiles(f, entryName) + if err != nil { + return nil, err + } + + if files == nil { + return nil, nil + } + + for _, f := range files { + v = append(v, f) + } + continue + } + + v = append(v, entryName) + } + + return v, nil +}