diff --git a/vacation/symlink_sorting.go b/vacation/symlink_sorting.go new file mode 100644 index 00000000..8c1b61e8 --- /dev/null +++ b/vacation/symlink_sorting.go @@ -0,0 +1,72 @@ +package vacation + +import ( + "fmt" + "path/filepath" + "strings" +) + +type symlink struct { + name string + path string +} + +func sortSymlinks(symlinks []symlink) ([]symlink, error) { + // Create a map of all of the symlink names and where they are pointing to to + // act as a quasi-graph + index := map[string]string{} + for _, s := range symlinks { + index[filepath.Clean(s.path)] = s.name + } + + // Check to see if the link name lies on the path of another symlink in + // the table or if it is another symlink in the table + // + // Example: + // path = dir/file + // a-symlink -> dir + // b-symlink -> a-symlink + // c-symlink -> a-symlink/file + shouldSkipLink := func(linkname, linkpath string) bool { + sln := strings.Split(linkname, "/") + for j := 0; j < len(sln); j++ { + if _, ok := index[linknameFullPath(linkpath, filepath.Join(sln[:j+1]...))]; ok { + return true + } + } + return false + } + + // Iterate over the symlink map for every link that is found this ensures + // that all symlinks that can be created will be created and any that are + // left over are cyclically dependent + var links []symlink + maxIterations := len(index) + for i := 0; i < maxIterations; i++ { + for path, name := range index { + // If there is a match either of the symlink or it is on the path then + // skip the creation of this symlink for now + if shouldSkipLink(name, path) { + continue + } + + links = append(links, symlink{ + name: name, + path: path, + }) + + // Remove the created symlink from the symlink table so that its + // dependent symlinks can be created in the next iteration + delete(index, path) + break + } + } + + // Check to see if there are any symlinks left in the map which would + // indicate a cyclical dependency + if len(index) > 0 { + return nil, fmt.Errorf("failed: max iterations reached: this symlink graph contains a cycle") + } + + return links, nil +} diff --git a/vacation/tar_archive.go b/vacation/tar_archive.go index e0a30ca2..4f3d1956 100644 --- a/vacation/tar_archive.go +++ b/vacation/tar_archive.go @@ -6,7 +6,6 @@ import ( "io" "os" "path/filepath" - "sort" "strings" ) @@ -31,15 +30,7 @@ func (ta TarArchive) Decompress(destination string) error { // metadata. directories := map[string]interface{}{} - // Struct and slice to collect symlinks and create them after all files have - // been created - type header struct { - name string - linkname string - path string - } - - var symlinkHeaders []header + var symlinks []symlink tarReader := tar.NewReader(ta.reader) for { @@ -118,48 +109,26 @@ func (ta TarArchive) Decompress(destination string) error { case tar.TypeSymlink: // Collect all of the headers for symlinks so that they can be verified // after all other files are written - symlinkHeaders = append(symlinkHeaders, header{ - name: hdr.Name, - linkname: hdr.Linkname, - path: path, + symlinks = append(symlinks, symlink{ + name: hdr.Linkname, + path: path, }) } } - // Sort the symlinks so that symlinks of symlinks have their base link - // created before they are created. - // - // For example: - // b-sym -> a-sym/x - // a-sym -> z - // c-sym -> d-sym - // d-sym -> z - // - // Will sort to: - // a-sym -> z - // b-sym -> a-sym/x - // d-sym -> z - // c-sym -> d-sym - sort.Slice(symlinkHeaders, func(i, j int) bool { - if filepath.Clean(symlinkHeaders[i].name) == linknameFullPath(symlinkHeaders[j].name, symlinkHeaders[j].linkname) { - return true - } - - if filepath.Clean(symlinkHeaders[j].name) == linknameFullPath(symlinkHeaders[i].name, symlinkHeaders[i].linkname) { - return false - } - - return filepath.Clean(symlinkHeaders[i].name) < linknameFullPath(symlinkHeaders[j].name, symlinkHeaders[j].linkname) - }) + symlinks, err := sortSymlinks(symlinks) + if err != nil { + return err + } - for _, h := range symlinkHeaders { + for _, link := range symlinks { // Check to see if the file that will be linked to is valid for symlinking - _, err := filepath.EvalSymlinks(linknameFullPath(h.path, h.linkname)) + _, err := filepath.EvalSymlinks(linknameFullPath(link.path, link.name)) if err != nil { - return fmt.Errorf("failed to evaluate symlink %s: %w", h.path, err) + return fmt.Errorf("failed to evaluate symlink %s: %w", link.path, err) } - err = os.Symlink(h.linkname, h.path) + err = os.Symlink(link.name, link.path) if err != nil { return fmt.Errorf("failed to extract symlink: %s", err) } diff --git a/vacation/tar_archive_test.go b/vacation/tar_archive_test.go index f30149bd..e8c3ab1f 100644 --- a/vacation/tar_archive_test.go +++ b/vacation/tar_archive_test.go @@ -289,5 +289,33 @@ func testTarArchive(t *testing.T, context spec.G, it spec.S) { }) }) }) + + context("when there is a symlink cycle", func() { + var cyclicalSymlinkTar vacation.TarArchive + + it.Before(func() { + var err error + + buffer := bytes.NewBuffer(nil) + tw := tar.NewWriter(buffer) + + Expect(tw.WriteHeader(&tar.Header{Name: "a-symlink", Mode: 0755, Size: int64(0), Typeflag: tar.TypeSymlink, Linkname: "b-symlink"})).To(Succeed()) + _, err = tw.Write([]byte{}) + Expect(err).NotTo(HaveOccurred()) + + Expect(tw.WriteHeader(&tar.Header{Name: "b-symlink", Mode: 0755, Size: int64(0), Typeflag: tar.TypeSymlink, Linkname: "a-symlink"})).To(Succeed()) + _, err = tw.Write([]byte{}) + Expect(err).NotTo(HaveOccurred()) + + Expect(tw.Close()).To(Succeed()) + + cyclicalSymlinkTar = vacation.NewTarArchive(bytes.NewReader(buffer.Bytes())) + }) + + it("returns an error", func() { + err := cyclicalSymlinkTar.Decompress(tempDir) + Expect(err).To(MatchError(ContainSubstring("failed: max iterations reached: this symlink graph contains a cycle"))) + }) + }) }) } diff --git a/vacation/zip_archive.go b/vacation/zip_archive.go index 6c11151a..906523ef 100644 --- a/vacation/zip_archive.go +++ b/vacation/zip_archive.go @@ -6,7 +6,6 @@ import ( "io" "os" "path/filepath" - "sort" "strings" ) @@ -24,15 +23,6 @@ func NewZipArchive(inputReader io.Reader) ZipArchive { // Decompress reads from ZipArchive and writes files into the destination // specified. func (z ZipArchive) Decompress(destination string) error { - // Struct and slice to collect symlinks and create them after all files have - // been created - type header struct { - name string - linkname string - path string - } - - var symlinkHeaders []header // Use an os.File to buffer the zip contents. This is needed because // zip.NewReader requires an io.ReaderAt so that it can jump around within @@ -53,6 +43,7 @@ func (z ZipArchive) Decompress(destination string) error { return fmt.Errorf("failed to create zip reader: %w", err) } + var symlinks []symlink for _, f := range zr.File { // Clean the name in the header to prevent './filename' being stripped to // 'filename' also to skip if the destination it the destination directory @@ -96,10 +87,9 @@ func (z ZipArchive) Decompress(destination string) error { // Collect all of the headers for symlinks so that they can be verified // after all other files are written - symlinkHeaders = append(symlinkHeaders, header{ - name: f.Name, - linkname: string(linkname), - path: path, + symlinks = append(symlinks, symlink{ + name: string(linkname), + path: path, }) default: @@ -133,42 +123,21 @@ func (z ZipArchive) Decompress(destination string) error { } } - // Sort the symlinks so that symlinks of symlinks have their base link - // created before they are created. - // - // For example: - // b-sym -> a-sym/x - // a-sym -> z - // c-sym -> d-sym - // d-sym -> z - // - // Will sort to: - // a-sym -> z - // b-sym -> a-sym/x - // d-sym -> z - // c-sym -> d-sym - sort.Slice(symlinkHeaders, func(i, j int) bool { - if filepath.Clean(symlinkHeaders[i].name) == linknameFullPath(symlinkHeaders[j].name, symlinkHeaders[j].linkname) { - return true - } - - if filepath.Clean(symlinkHeaders[j].name) == linknameFullPath(symlinkHeaders[i].name, symlinkHeaders[i].linkname) { - return false - } - - return filepath.Clean(symlinkHeaders[i].name) < linknameFullPath(symlinkHeaders[j].name, symlinkHeaders[j].linkname) - }) + symlinks, err = sortSymlinks(symlinks) + if err != nil { + return err + } - for _, h := range symlinkHeaders { + for _, link := range symlinks { // Check to see if the file that will be linked to is valid for symlinking - _, err := filepath.EvalSymlinks(linknameFullPath(h.path, h.linkname)) + _, err := filepath.EvalSymlinks(linknameFullPath(link.path, link.name)) if err != nil { - return fmt.Errorf("failed to evaluate symlink %s: %w", h.path, err) + return fmt.Errorf("failed to evaluate symlink %s: %w", link.path, err) } - err = os.Symlink(h.linkname, h.path) + err = os.Symlink(link.name, link.path) if err != nil { - return fmt.Errorf("failed to unzip symlink: %w", err) + return fmt.Errorf("failed to unzip symlink: %s", err) } } diff --git a/vacation/zip_archive_test.go b/vacation/zip_archive_test.go index c6244297..9a3ce570 100644 --- a/vacation/zip_archive_test.go +++ b/vacation/zip_archive_test.go @@ -244,6 +244,33 @@ func testZipArchive(t *testing.T, context spec.G, it spec.S) { }) }) + context("when it fails to unzip a file", func() { + var buffer *bytes.Buffer + it.Before(func() { + var err error + buffer = bytes.NewBuffer(nil) + zw := zip.NewWriter(buffer) + + _, err = zw.Create("some-file") + Expect(err).NotTo(HaveOccurred()) + + Expect(zw.Close()).To(Succeed()) + + Expect(os.Chmod(tempDir, 0000)).To(Succeed()) + }) + + it.After(func() { + Expect(os.Chmod(tempDir, os.ModePerm)).To(Succeed()) + }) + + it("returns an error", func() { + readyArchive := vacation.NewZipArchive(buffer) + + err := readyArchive.Decompress(tempDir) + Expect(err).To(MatchError(ContainSubstring("failed to unzip file"))) + }) + }) + context("when it tries to symlink to a file that does not exist", func() { var buffer *bytes.Buffer it.Before(func() { @@ -305,30 +332,39 @@ func testZipArchive(t *testing.T, context spec.G, it spec.S) { }) }) - context("when it fails to unzip a file", func() { + context("when there is a symlink cycle", func() { var buffer *bytes.Buffer it.Before(func() { var err error buffer = bytes.NewBuffer(nil) zw := zip.NewWriter(buffer) - _, err = zw.Create("some-file") + header := &zip.FileHeader{Name: "a-symlink"} + header.SetMode(0755 | os.ModeSymlink) + + aSymlink, err := zw.CreateHeader(header) Expect(err).NotTo(HaveOccurred()) - Expect(zw.Close()).To(Succeed()) + _, err = aSymlink.Write([]byte(filepath.Join("b-symlink"))) + Expect(err).NotTo(HaveOccurred()) - Expect(os.Chmod(tempDir, 0000)).To(Succeed()) - }) + header = &zip.FileHeader{Name: "b-symlink"} + header.SetMode(0755 | os.ModeSymlink) - it.After(func() { - Expect(os.Chmod(tempDir, os.ModePerm)).To(Succeed()) + bSymlink, err := zw.CreateHeader(header) + Expect(err).NotTo(HaveOccurred()) + + _, err = bSymlink.Write([]byte(filepath.Join("a-symlink"))) + Expect(err).NotTo(HaveOccurred()) + + Expect(zw.Close()).To(Succeed()) }) it("returns an error", func() { readyArchive := vacation.NewZipArchive(buffer) err := readyArchive.Decompress(tempDir) - Expect(err).To(MatchError(ContainSubstring("failed to unzip file"))) + Expect(err).To(MatchError("failed: max iterations reached: this symlink graph contains a cycle")) }) }) })