Skip to content
Merged
96 changes: 74 additions & 22 deletions internal/pkg/agent/application/upgrade/step_unpack.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,43 @@ type UnpackResult struct {
VersionedHome string `json:"versioned-home" yaml:"versioned-home"`
}

type copyFunc func(dst io.Writer, src io.Reader) (written int64, err error)
type mkdirAllFunc func(name string, perm fs.FileMode) error
type openFileFunc func(name string, flag int, perm fs.FileMode) (*os.File, error)
type unarchiveFunc func(log *logger.Logger, archivePath, dataDir string, flavor string, copy copyFunc, mkdirAll mkdirAllFunc, openFile openFileFunc) (UnpackResult, error)

type unpacker struct {
log *logger.Logger
// Abstractsions for testability
unzip unarchiveFunc
untar unarchiveFunc
// stdlib abstractions for testability
copy copyFunc
mkdirAll mkdirAllFunc
openFile openFileFunc
}

func newUnpacker(log *logger.Logger) *unpacker {
return &unpacker{
log: log,
unzip: unzip,
untar: untar,
copy: io.Copy,
mkdirAll: os.MkdirAll,
openFile: os.OpenFile,
}
}

// unpack unpacks archive correctly, skips root (symlink, config...) unpacks data/*
func (u *Upgrader) unpack(version, archivePath, dataDir string, flavor string) (UnpackResult, error) {
func (u *unpacker) unpack(version, archivePath, dataDir string, flavor string) (UnpackResult, error) {
// unpack must occur in directory that holds the installation directory
// or the extraction will be double nested
var unpackRes UnpackResult
var err error
if runtime.GOOS == windows {
unpackRes, err = unzip(u.log, archivePath, dataDir, flavor)
unpackRes, err = u.unzip(u.log, archivePath, dataDir, flavor, u.copy, u.mkdirAll, u.openFile)
} else {
unpackRes, err = untar(u.log, archivePath, dataDir, flavor)
unpackRes, err = u.untar(u.log, archivePath, dataDir, flavor, u.copy, u.mkdirAll, u.openFile)
}

if err != nil {
Expand All @@ -61,7 +88,7 @@ type packageMetadata struct {
hash string
}

func (u *Upgrader) getPackageMetadata(archivePath string) (packageMetadata, error) {
func (u *unpacker) getPackageMetadata(archivePath string) (packageMetadata, error) {
ext := filepath.Ext(archivePath)
if ext == ".gz" {
// if we got gzip extension we need another extension before last
Expand All @@ -78,7 +105,8 @@ func (u *Upgrader) getPackageMetadata(archivePath string) (packageMetadata, erro
}
}

func unzip(log *logger.Logger, archivePath, dataDir string, flavor string) (UnpackResult, error) {
// injecting copy, mkdirAll and openFile for testability
func unzip(log *logger.Logger, archivePath, dataDir string, flavor string, copy copyFunc, mkdirAll mkdirAllFunc, openFile openFileFunc) (UnpackResult, error) {
var hash, rootDir string
r, err := zip.OpenReader(archivePath)
if err != nil {
Expand Down Expand Up @@ -148,8 +176,10 @@ func unzip(log *logger.Logger, archivePath, dataDir string, flavor string) (Unpa
// check if the directory already exists
_, err = os.Stat(dstPath)
if errors.Is(err, fs.ErrNotExist) {
// the directory does not exist, create it and any non-existing parent directory with the same permissions
if err := os.MkdirAll(dstPath, f.Mode().Perm()&0770); err != nil {
// the directory does not exist, create it and any non-existing
// parent directory with the same permissions.
// Using mkdirAll instead of os.MkdirAll so that we can mock it in tests.
if err := mkdirAll(dstPath, f.Mode().Perm()&0770); err != nil {
return fmt.Errorf("creating directory %q: %w", dstPath, err)
}
} else if err != nil {
Expand All @@ -162,13 +192,23 @@ func unzip(log *logger.Logger, archivePath, dataDir string, flavor string) (Unpa
}
}

_ = os.MkdirAll(dstPath, f.Mode()&0770)
// Using mkdirAll instead of os.MkdirAll so that we can mock it in tests.
err = mkdirAll(dstPath, f.Mode()&0770)
if err != nil {
return fmt.Errorf("creating directory %q: %w", dstPath, err)
}
} else {
log.Debugw("Unpacking file", "archive", "zip", "file.path", dstPath)
// create non-existing containing folders with 0770 permissions right now, we'll fix the permission of each
// directory as we come across them while processing the other package entries
_ = os.MkdirAll(filepath.Dir(dstPath), 0770)
f, err := os.OpenFile(dstPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()&0770)
// directory as we come across them while processing the other
// package entries
// Using mkdirAll instead of os.MkdirAll so that we can mock it in tests.
err = mkdirAll(filepath.Dir(dstPath), 0770)
if err != nil {
return fmt.Errorf("creating directory %q: %w", dstPath, err)
}
// Using openFile instead of os.OpenFile so that we can mock it in tests.
f, err := openFile(dstPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()&0770)
if err != nil {
return err
}
Expand All @@ -178,7 +218,9 @@ func unzip(log *logger.Logger, archivePath, dataDir string, flavor string) (Unpa
}
}()

if _, err = io.Copy(f, rc); err != nil { //nolint:gosec // legacy
// Using copy instead of io.Copy so that we can
// mock it in tests.
if _, err = copy(f, rc); err != nil {
return err
}
}
Expand Down Expand Up @@ -313,7 +355,8 @@ func getPackageMetadataFromZipReader(r *zip.ReadCloser, fileNamePrefix string) (
return ret, nil
}

func untar(log *logger.Logger, archivePath, dataDir string, flavor string) (UnpackResult, error) {
// injecting copy, mkdirAll and openFile for testability
func untar(log *logger.Logger, archivePath, dataDir string, flavor string, copy copyFunc, mkdirAll mkdirAllFunc, openFile openFileFunc) (UnpackResult, error) {
var versionedHome string
var rootDir string
var hash string
Expand Down Expand Up @@ -413,17 +456,23 @@ func untar(log *logger.Logger, archivePath, dataDir string, flavor string) (Unpa
log.Debugw("Unpacking file", "archive", "tar", "file.path", abs)
// create non-existing containing folders with 0750 permissions right now, we'll fix the permission of each
// directory as we come across them while processing the other package entries
if err = os.MkdirAll(filepath.Dir(abs), 0750); err != nil {
return UnpackResult{}, errors.New(err, "TarInstaller: creating directory for file "+abs, errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs))
// Using mkdirAll instead of os.MkdirAll so that we can
// mock it in tests.
if err = mkdirAll(filepath.Dir(abs), 0750); err != nil {
return UnpackResult{}, goerrors.Join(err, errors.New("TarInstaller: creating directory for file "+abs, errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs)))
}

// remove any world permissions from the file
wf, err := os.OpenFile(abs, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode.Perm()&0770)
// Using openFile instead of os.OpenFile so that we can
// mock it in tests.
wf, err := openFile(abs, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode.Perm()&0770)
if err != nil {
return UnpackResult{}, errors.New(err, "TarInstaller: creating file "+abs, errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs))
return UnpackResult{}, goerrors.Join(err, errors.New("TarInstaller: creating file "+abs, errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs)))
}

_, err = io.Copy(wf, tr) //nolint:gosec // legacy
// Using copy instead of io.Copy so that we can
// mock it in tests.
_, err = copy(wf, tr)
if closeErr := wf.Close(); closeErr != nil && err == nil {
err = closeErr
}
Expand All @@ -435,17 +484,20 @@ func untar(log *logger.Logger, archivePath, dataDir string, flavor string) (Unpa
// check if the directory already exists
_, err = os.Stat(abs)
if errors.Is(err, fs.ErrNotExist) {
// the directory does not exist, create it and any non-existing parent directory with the same permissions
if err := os.MkdirAll(abs, mode.Perm()&0770); err != nil {
return UnpackResult{}, errors.New(err, "TarInstaller: creating directory for file "+abs, errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs))
// the directory does not exist, create it and any non-existing
// parent directory with the same permissions.
// Using mkdirAll instead of os.MkdirAll so that we can
// mock it in tests.
if err := mkdirAll(abs, mode.Perm()&0770); err != nil {
return UnpackResult{}, goerrors.Join(err, errors.New("TarInstaller: creating directory for file "+abs, errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs)))
}
} else if err != nil {
return UnpackResult{}, errors.New(err, "TarInstaller: stat() directory for file "+abs, errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs))
} else {
// directory already exists, set the appropriate permissions
err = os.Chmod(abs, mode.Perm()&0770)
if err != nil {
return UnpackResult{}, errors.New(err, fmt.Sprintf("TarInstaller: setting permissions %O for directory %q", mode.Perm()&0770, abs), errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs))
return UnpackResult{}, goerrors.Join(err, errors.New("TarInstaller: setting permissions %O for directory %q", mode.Perm()&0770, abs, errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs)))
}
}
default:
Expand Down
Loading
Loading