Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update the symlink creation #225

Merged
merged 3 commits into from
Sep 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions vacation/symlink_sorting.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package vacation

import (
"fmt"
"path/filepath"
"strings"
)

type symlink struct {
name string
path string
}

func sortSymlinks(symlinks []symlink) ([]symlink, error) {
// Create a map of all of the symlink names and where they are pointing to to
// act as a quasi-graph
index := map[string]string{}
for _, s := range symlinks {
index[filepath.Clean(s.path)] = s.name
}

// Check to see if the link name lies on the path of another symlink in
// the table or if it is another symlink in the table
//
// Example:
// path = dir/file
// a-symlink -> dir
// b-symlink -> a-symlink
// c-symlink -> a-symlink/file
shouldSkipLink := func(linkname, linkpath string) bool {
sln := strings.Split(linkname, "/")
for j := 0; j < len(sln); j++ {
if _, ok := index[linknameFullPath(linkpath, filepath.Join(sln[:j+1]...))]; ok {
return true
}
}
return false
}

// Iterate over the symlink map for every link that is found this ensures
// that all symlinks that can be created will be created and any that are
// left over are cyclically dependent
var links []symlink
maxIterations := len(index)
for i := 0; i < maxIterations; i++ {
for path, name := range index {
// If there is a match either of the symlink or it is on the path then
// skip the creation of this symlink for now
if shouldSkipLink(name, path) {
continue
}

links = append(links, symlink{
name: name,
path: path,
})

// Remove the created symlink from the symlink table so that its
// dependent symlinks can be created in the next iteration
delete(index, path)
break
}
}

// Check to see if there are any symlinks left in the map which would
// indicate a cyclical dependency
if len(index) > 0 {
return nil, fmt.Errorf("failed: max iterations reached: this symlink graph contains a cycle")
}

return links, nil
}
55 changes: 12 additions & 43 deletions vacation/tar_archive.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"io"
"os"
"path/filepath"
"sort"
"strings"
)

Expand All @@ -31,15 +30,7 @@ func (ta TarArchive) Decompress(destination string) error {
// metadata.
directories := map[string]interface{}{}

// Struct and slice to collect symlinks and create them after all files have
// been created
type header struct {
name string
linkname string
path string
}

var symlinkHeaders []header
var symlinks []symlink

tarReader := tar.NewReader(ta.reader)
for {
Expand Down Expand Up @@ -118,48 +109,26 @@ func (ta TarArchive) Decompress(destination string) error {
case tar.TypeSymlink:
// Collect all of the headers for symlinks so that they can be verified
// after all other files are written
symlinkHeaders = append(symlinkHeaders, header{
name: hdr.Name,
linkname: hdr.Linkname,
path: path,
symlinks = append(symlinks, symlink{
name: hdr.Linkname,
path: path,
})
}
}

// Sort the symlinks so that symlinks of symlinks have their base link
// created before they are created.
//
// For example:
// b-sym -> a-sym/x
// a-sym -> z
// c-sym -> d-sym
// d-sym -> z
//
// Will sort to:
// a-sym -> z
// b-sym -> a-sym/x
// d-sym -> z
// c-sym -> d-sym
sort.Slice(symlinkHeaders, func(i, j int) bool {
if filepath.Clean(symlinkHeaders[i].name) == linknameFullPath(symlinkHeaders[j].name, symlinkHeaders[j].linkname) {
return true
}

if filepath.Clean(symlinkHeaders[j].name) == linknameFullPath(symlinkHeaders[i].name, symlinkHeaders[i].linkname) {
return false
}

return filepath.Clean(symlinkHeaders[i].name) < linknameFullPath(symlinkHeaders[j].name, symlinkHeaders[j].linkname)
})
symlinks, err := sortSymlinks(symlinks)
if err != nil {
return err
}

for _, h := range symlinkHeaders {
for _, link := range symlinks {
// Check to see if the file that will be linked to is valid for symlinking
_, err := filepath.EvalSymlinks(linknameFullPath(h.path, h.linkname))
_, err := filepath.EvalSymlinks(linknameFullPath(link.path, link.name))
if err != nil {
return fmt.Errorf("failed to evaluate symlink %s: %w", h.path, err)
return fmt.Errorf("failed to evaluate symlink %s: %w", link.path, err)
}

err = os.Symlink(h.linkname, h.path)
err = os.Symlink(link.name, link.path)
if err != nil {
return fmt.Errorf("failed to extract symlink: %s", err)
}
Expand Down
28 changes: 28 additions & 0 deletions vacation/tar_archive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,5 +289,33 @@ func testTarArchive(t *testing.T, context spec.G, it spec.S) {
})
})
})

context("when there is a symlink cycle", func() {
var cyclicalSymlinkTar vacation.TarArchive

it.Before(func() {
var err error

buffer := bytes.NewBuffer(nil)
tw := tar.NewWriter(buffer)

Expect(tw.WriteHeader(&tar.Header{Name: "a-symlink", Mode: 0755, Size: int64(0), Typeflag: tar.TypeSymlink, Linkname: "b-symlink"})).To(Succeed())
_, err = tw.Write([]byte{})
Expect(err).NotTo(HaveOccurred())

Expect(tw.WriteHeader(&tar.Header{Name: "b-symlink", Mode: 0755, Size: int64(0), Typeflag: tar.TypeSymlink, Linkname: "a-symlink"})).To(Succeed())
_, err = tw.Write([]byte{})
Expect(err).NotTo(HaveOccurred())

Expect(tw.Close()).To(Succeed())

cyclicalSymlinkTar = vacation.NewTarArchive(bytes.NewReader(buffer.Bytes()))
})

it("returns an error", func() {
err := cyclicalSymlinkTar.Decompress(tempDir)
Expect(err).To(MatchError(ContainSubstring("failed: max iterations reached: this symlink graph contains a cycle")))
})
})
})
}
57 changes: 13 additions & 44 deletions vacation/zip_archive.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"io"
"os"
"path/filepath"
"sort"
"strings"
)

Expand All @@ -24,15 +23,6 @@ func NewZipArchive(inputReader io.Reader) ZipArchive {
// Decompress reads from ZipArchive and writes files into the destination
// specified.
func (z ZipArchive) Decompress(destination string) error {
// Struct and slice to collect symlinks and create them after all files have
// been created
type header struct {
name string
linkname string
path string
}

var symlinkHeaders []header

// Use an os.File to buffer the zip contents. This is needed because
// zip.NewReader requires an io.ReaderAt so that it can jump around within
Expand All @@ -53,6 +43,7 @@ func (z ZipArchive) Decompress(destination string) error {
return fmt.Errorf("failed to create zip reader: %w", err)
}

var symlinks []symlink
for _, f := range zr.File {
// Clean the name in the header to prevent './filename' being stripped to
// 'filename' also to skip if the destination it the destination directory
Expand Down Expand Up @@ -96,10 +87,9 @@ func (z ZipArchive) Decompress(destination string) error {

// Collect all of the headers for symlinks so that they can be verified
// after all other files are written
symlinkHeaders = append(symlinkHeaders, header{
name: f.Name,
linkname: string(linkname),
path: path,
symlinks = append(symlinks, symlink{
name: string(linkname),
path: path,
})

default:
Expand Down Expand Up @@ -133,42 +123,21 @@ func (z ZipArchive) Decompress(destination string) error {
}
}

// Sort the symlinks so that symlinks of symlinks have their base link
// created before they are created.
//
// For example:
// b-sym -> a-sym/x
// a-sym -> z
// c-sym -> d-sym
// d-sym -> z
//
// Will sort to:
// a-sym -> z
// b-sym -> a-sym/x
// d-sym -> z
// c-sym -> d-sym
sort.Slice(symlinkHeaders, func(i, j int) bool {
if filepath.Clean(symlinkHeaders[i].name) == linknameFullPath(symlinkHeaders[j].name, symlinkHeaders[j].linkname) {
return true
}

if filepath.Clean(symlinkHeaders[j].name) == linknameFullPath(symlinkHeaders[i].name, symlinkHeaders[i].linkname) {
return false
}

return filepath.Clean(symlinkHeaders[i].name) < linknameFullPath(symlinkHeaders[j].name, symlinkHeaders[j].linkname)
})
symlinks, err = sortSymlinks(symlinks)
if err != nil {
return err
}

for _, h := range symlinkHeaders {
for _, link := range symlinks {
// Check to see if the file that will be linked to is valid for symlinking
_, err := filepath.EvalSymlinks(linknameFullPath(h.path, h.linkname))
_, err := filepath.EvalSymlinks(linknameFullPath(link.path, link.name))
if err != nil {
return fmt.Errorf("failed to evaluate symlink %s: %w", h.path, err)
return fmt.Errorf("failed to evaluate symlink %s: %w", link.path, err)
}

err = os.Symlink(h.linkname, h.path)
err = os.Symlink(link.name, link.path)
if err != nil {
return fmt.Errorf("failed to unzip symlink: %w", err)
return fmt.Errorf("failed to unzip symlink: %s", err)
}
}

Expand Down
52 changes: 44 additions & 8 deletions vacation/zip_archive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,33 @@ func testZipArchive(t *testing.T, context spec.G, it spec.S) {
})
})

context("when it fails to unzip a file", func() {
var buffer *bytes.Buffer
it.Before(func() {
var err error
buffer = bytes.NewBuffer(nil)
zw := zip.NewWriter(buffer)

_, err = zw.Create("some-file")
Expect(err).NotTo(HaveOccurred())

Expect(zw.Close()).To(Succeed())

Expect(os.Chmod(tempDir, 0000)).To(Succeed())
})

it.After(func() {
Expect(os.Chmod(tempDir, os.ModePerm)).To(Succeed())
})

it("returns an error", func() {
readyArchive := vacation.NewZipArchive(buffer)

err := readyArchive.Decompress(tempDir)
Expect(err).To(MatchError(ContainSubstring("failed to unzip file")))
})
})

context("when it tries to symlink to a file that does not exist", func() {
var buffer *bytes.Buffer
it.Before(func() {
Expand Down Expand Up @@ -305,30 +332,39 @@ func testZipArchive(t *testing.T, context spec.G, it spec.S) {
})
})

context("when it fails to unzip a file", func() {
context("when there is a symlink cycle", func() {
var buffer *bytes.Buffer
it.Before(func() {
var err error
buffer = bytes.NewBuffer(nil)
zw := zip.NewWriter(buffer)

_, err = zw.Create("some-file")
header := &zip.FileHeader{Name: "a-symlink"}
header.SetMode(0755 | os.ModeSymlink)

aSymlink, err := zw.CreateHeader(header)
Expect(err).NotTo(HaveOccurred())

Expect(zw.Close()).To(Succeed())
_, err = aSymlink.Write([]byte(filepath.Join("b-symlink")))
Expect(err).NotTo(HaveOccurred())

Expect(os.Chmod(tempDir, 0000)).To(Succeed())
})
header = &zip.FileHeader{Name: "b-symlink"}
header.SetMode(0755 | os.ModeSymlink)

it.After(func() {
Expect(os.Chmod(tempDir, os.ModePerm)).To(Succeed())
bSymlink, err := zw.CreateHeader(header)
Expect(err).NotTo(HaveOccurred())

_, err = bSymlink.Write([]byte(filepath.Join("a-symlink")))
Expect(err).NotTo(HaveOccurred())

Expect(zw.Close()).To(Succeed())
})

it("returns an error", func() {
readyArchive := vacation.NewZipArchive(buffer)

err := readyArchive.Decompress(tempDir)
Expect(err).To(MatchError(ContainSubstring("failed to unzip file")))
Expect(err).To(MatchError("failed: max iterations reached: this symlink graph contains a cycle"))
})
})
})
Expand Down