From 5b15548ca1b129eff6360c0393ff744be5719c49 Mon Sep 17 00:00:00 2001 From: Daniel J Walsh Date: Thu, 7 Sep 2017 12:10:02 +0000 Subject: [PATCH] Update packages to match latest code in moby/pkg Had to vendor in a new version of golang.org/x/net to build Also had to make some changes to drivers to handle archive.Reader -> io.Reader archive.Archive -> io.ReadCloser Also update .gitingore to ignore emacs files, containers-storage.* and generated man pages. Also no longer test travis against golang 1.7, cri-o, moby have also done this. Signed-off-by: Daniel J Walsh --- .gitignore | 10 +- .travis.yml | 1 - drivers/aufs/aufs.go | 7 +- drivers/driver.go | 7 +- drivers/fsdiff.go | 12 +- drivers/overlay/overlay.go | 5 +- drivers/proxy.go | 7 +- drivers/vfs/driver.go | 2 +- drivers/windows/windows.go | 4 +- layers.go | 8 +- pkg/archive/README.md | 1 + pkg/archive/archive.go | 485 ++++++++++-------- pkg/archive/archive_linux.go | 25 +- pkg/archive/archive_linux_test.go | 160 ++++++ pkg/archive/archive_test.go | 391 ++++++++++---- pkg/archive/archive_unix.go | 50 +- pkg/archive/archive_unix_test.go | 180 ++++--- pkg/archive/archive_windows.go | 23 +- pkg/archive/archive_windows_test.go | 6 +- pkg/archive/changes.go | 15 +- pkg/archive/changes_linux.go | 13 +- pkg/archive/changes_posix_test.go | 11 +- pkg/archive/changes_test.go | 273 ++++------ pkg/archive/changes_unix.go | 7 +- pkg/archive/changes_windows.go | 6 +- pkg/archive/copy.go | 15 +- pkg/archive/copy_unix_test.go | 114 ++-- pkg/archive/diff.go | 45 +- pkg/archive/diff_test.go | 6 +- pkg/archive/utils_test.go | 2 +- pkg/archive/wrap.go | 6 +- pkg/archive/wrap_test.go | 18 +- pkg/chrootarchive/archive.go | 47 +- pkg/chrootarchive/archive_test.go | 72 ++- pkg/chrootarchive/chroot_linux.go | 39 +- pkg/chrootarchive/chroot_unix.go | 6 +- pkg/chrootarchive/diff.go | 10 +- pkg/chrootarchive/diff_unix.go | 14 +- pkg/chrootarchive/diff_windows.go | 5 +- pkg/devicemapper/devmapper.go | 132 ++--- 
pkg/devicemapper/devmapper_log.go | 96 +++- pkg/devicemapper/devmapper_wrapper.go | 31 +- .../devmapper_wrapper_deferred_remove.go | 9 +- pkg/devicemapper/devmapper_wrapper_dynamic.go | 6 + .../devmapper_wrapper_no_deferred_remove.go | 4 +- pkg/devicemapper/devmapper_wrapper_static.go | 6 + pkg/devicemapper/ioctl.go | 9 +- pkg/fileutils/fileutils.go | 207 ++++---- pkg/fileutils/fileutils_test.go | 154 +++--- pkg/fsutils/fsutils_linux.go | 88 ++++ pkg/homedir/homedir_linux.go | 23 + pkg/homedir/homedir_others.go | 13 + pkg/homedir/{homedir.go => homedir_unix.go} | 11 +- pkg/homedir/homedir_windows.go | 24 + pkg/idtools/idtools.go | 146 ++++-- pkg/idtools/idtools_unix.go | 154 +++++- pkg/idtools/idtools_unix_test.go | 60 +-- pkg/idtools/idtools_windows.go | 9 +- pkg/idtools/usergroupadd_linux.go | 24 - pkg/idtools/utils_unix.go | 32 ++ pkg/ioutils/buffer_test.go | 78 +++ pkg/ioutils/bytespipe.go | 186 +++++++ pkg/ioutils/bytespipe_test.go | 159 ++++++ pkg/ioutils/fmt.go | 22 - pkg/ioutils/fmt_test.go | 17 - pkg/ioutils/fswriters.go | 80 +++ pkg/ioutils/fswriters_test.go | 99 +++- pkg/ioutils/multireader.go | 226 -------- pkg/ioutils/multireader_test.go | 149 ------ pkg/ioutils/readers.go | 71 +++ pkg/ioutils/readers_test.go | 23 +- pkg/locker/README.md | 65 +++ pkg/locker/locker.go | 112 ++++ pkg/locker/locker_test.go | 124 +++++ pkg/mount/flags_freebsd.go | 1 + pkg/mount/flags_linux.go | 48 +- pkg/mount/flags_unsupported.go | 1 + pkg/mount/mount.go | 38 +- pkg/mount/mount_unix_test.go | 2 +- pkg/mount/mounter_freebsd.go | 5 +- pkg/mount/mounter_linux.go | 52 +- pkg/mount/mounter_linux_test.go | 228 ++++++++ pkg/mount/mounter_solaris.go | 3 +- pkg/mount/mountinfo.go | 14 + pkg/mount/sharedsubtree_linux_test.go | 5 +- pkg/mount/sharedsubtree_solaris.go | 58 +++ pkg/parsers/kernel/kernel_unix.go | 17 +- pkg/parsers/kernel/kernel_windows.go | 21 +- pkg/parsers/kernel/uname_linux.go | 14 +- .../operatingsystem/operatingsystem_linux.go | 77 +++ 
.../operatingsystem_solaris.go | 37 ++ .../operatingsystem/operatingsystem_unix.go | 25 + .../operatingsystem_unix_test.go | 247 +++++++++ .../operatingsystem_windows.go | 50 ++ pkg/plugins/client.go | 39 +- pkg/plugins/client_test.go | 108 +++- pkg/plugins/discovery.go | 5 +- pkg/plugins/discovery_test.go | 67 ++- pkg/plugins/discovery_unix.go | 5 + pkg/plugins/discovery_unix_test.go | 43 +- pkg/plugins/discovery_windows.go | 8 + pkg/plugins/plugin_test.go | 156 ++++++ pkg/plugins/pluginrpc-gen/README.md | 58 +++ pkg/plugins/pluginrpc-gen/fixtures/foo.go | 6 - pkg/plugins/pluginrpc-gen/parser_test.go | 2 +- pkg/plugins/plugins.go | 120 +++-- pkg/plugins/plugins_unix.go | 9 + pkg/plugins/plugins_windows.go | 8 + pkg/plugins/transport/http_test.go | 20 + pkg/promise/promise_test.go | 25 + pkg/random/random.go | 71 --- pkg/random/random_test.go | 22 - pkg/reexec/README.md | 5 + pkg/reexec/command_linux.go | 6 +- pkg/reexec/command_unix.go | 4 +- pkg/reexec/command_unsupported.go | 4 +- pkg/reexec/command_windows.go | 2 +- pkg/reexec/reexec_test.go | 53 ++ pkg/stringid/README.md | 1 + pkg/stringid/stringid.go | 60 ++- pkg/stringid/stringid_test.go | 20 +- pkg/stringutils/README.md | 1 + pkg/stringutils/stringutils.go | 24 +- pkg/stringutils/stringutils_test.go | 32 +- pkg/system/chtimes.go | 17 - pkg/system/chtimes_test.go | 2 +- pkg/system/chtimes_unix_test.go | 2 +- pkg/system/chtimes_windows.go | 19 +- pkg/system/chtimes_windows_test.go | 2 +- pkg/system/events_windows.go | 83 --- pkg/system/exitcode.go | 33 ++ pkg/system/filesys.go | 50 +- pkg/system/filesys_windows.go | 226 +++++++- pkg/system/init.go | 22 + pkg/system/init_windows.go | 17 + pkg/system/lcow_unix.go | 8 + pkg/system/lcow_windows.go | 6 + pkg/system/{lstat.go => lstat_unix.go} | 0 pkg/system/lstat_windows.go | 15 +- pkg/system/meminfo_solaris.go | 3 +- pkg/system/meminfo_windows.go | 5 +- pkg/system/mknod.go | 4 +- pkg/system/path.go | 21 + pkg/system/path_unix.go | 5 - pkg/system/path_windows.go | 
6 +- pkg/system/process_unix.go | 24 + pkg/system/rm.go | 80 +++ pkg/system/rm_test.go | 84 +++ .../{stat_unsupported.go => stat_darwin.go} | 8 +- pkg/system/stat_freebsd.go | 16 +- pkg/system/stat_linux.go | 20 +- pkg/system/stat_openbsd.go | 6 +- pkg/system/stat_solaris.go | 25 +- pkg/system/{stat.go => stat_unix.go} | 13 +- pkg/system/stat_windows.go | 46 +- pkg/system/syscall_unix.go | 4 +- pkg/system/syscall_windows.go | 35 +- pkg/system/umask.go | 4 +- pkg/system/utimes_darwin.go | 8 - pkg/system/utimes_freebsd.go | 8 +- pkg/system/utimes_linux.go | 13 +- pkg/system/utimes_unix_test.go | 4 +- pkg/system/utimes_unsupported.go | 4 +- pkg/system/xattrs_linux.go | 48 +- pkg/testutil/assert/assert.go | 70 --- store.go | 8 +- vendor.conf | 2 +- vendor/golang.org/x/net/README | 3 + vendor/golang.org/x/net/context/context.go | 156 ++++++ vendor/golang.org/x/net/context/go17.go | 72 +++ vendor/golang.org/x/net/context/pre_go17.go | 300 +++++++++++ vendor/golang.org/x/net/proxy/socks5.go | 57 +- 172 files changed, 6012 insertions(+), 2336 deletions(-) create mode 100644 pkg/archive/README.md create mode 100644 pkg/archive/archive_linux_test.go create mode 100644 pkg/devicemapper/devmapper_wrapper_dynamic.go create mode 100644 pkg/devicemapper/devmapper_wrapper_static.go create mode 100644 pkg/fsutils/fsutils_linux.go create mode 100644 pkg/homedir/homedir_linux.go create mode 100644 pkg/homedir/homedir_others.go rename pkg/homedir/{homedir.go => homedir_unix.go} (77%) create mode 100644 pkg/homedir/homedir_windows.go create mode 100644 pkg/idtools/utils_unix.go create mode 100644 pkg/ioutils/bytespipe.go create mode 100644 pkg/ioutils/bytespipe_test.go delete mode 100644 pkg/ioutils/fmt.go delete mode 100644 pkg/ioutils/fmt_test.go delete mode 100644 pkg/ioutils/multireader.go delete mode 100644 pkg/ioutils/multireader_test.go create mode 100644 pkg/locker/README.md create mode 100644 pkg/locker/locker.go create mode 100644 pkg/locker/locker_test.go create mode 100644 
pkg/mount/mounter_linux_test.go create mode 100644 pkg/mount/sharedsubtree_solaris.go create mode 100644 pkg/parsers/operatingsystem/operatingsystem_linux.go create mode 100644 pkg/parsers/operatingsystem/operatingsystem_solaris.go create mode 100644 pkg/parsers/operatingsystem/operatingsystem_unix.go create mode 100644 pkg/parsers/operatingsystem/operatingsystem_unix_test.go create mode 100644 pkg/parsers/operatingsystem/operatingsystem_windows.go create mode 100644 pkg/plugins/discovery_unix.go create mode 100644 pkg/plugins/discovery_windows.go create mode 100644 pkg/plugins/plugin_test.go create mode 100644 pkg/plugins/pluginrpc-gen/README.md create mode 100644 pkg/plugins/plugins_unix.go create mode 100644 pkg/plugins/plugins_windows.go create mode 100644 pkg/plugins/transport/http_test.go create mode 100644 pkg/promise/promise_test.go delete mode 100644 pkg/random/random.go delete mode 100644 pkg/random/random_test.go create mode 100644 pkg/reexec/README.md create mode 100644 pkg/reexec/reexec_test.go create mode 100644 pkg/stringid/README.md create mode 100644 pkg/stringutils/README.md delete mode 100644 pkg/system/events_windows.go create mode 100644 pkg/system/exitcode.go create mode 100644 pkg/system/init.go create mode 100644 pkg/system/init_windows.go create mode 100644 pkg/system/lcow_unix.go create mode 100644 pkg/system/lcow_windows.go rename pkg/system/{lstat.go => lstat_unix.go} (100%) create mode 100644 pkg/system/path.go create mode 100644 pkg/system/process_unix.go create mode 100644 pkg/system/rm.go create mode 100644 pkg/system/rm_test.go rename pkg/system/{stat_unsupported.go => stat_darwin.go} (59%) rename pkg/system/{stat.go => stat_unix.go} (74%) delete mode 100644 pkg/system/utimes_darwin.go delete mode 100644 pkg/testutil/assert/assert.go create mode 100644 vendor/golang.org/x/net/README create mode 100644 vendor/golang.org/x/net/context/context.go create mode 100644 vendor/golang.org/x/net/context/go17.go create mode 100644 
vendor/golang.org/x/net/context/pre_go17.go diff --git a/.gitignore b/.gitignore index 99899e9b5a..b424a96efa 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,9 @@ -# Docker project generated files to ignore +# containers/storage project generated files to ignore # if you want to ignore files created by your editor/tools, # please consider a global .gitignore https://help.github.com/articles/ignoring-files +*.1 *.exe -*.exe~ +*~ *.orig *.test .*.swp @@ -24,3 +25,8 @@ man/man5 man/man8 vendor/pkg/ .vagrant +containers-storage +containers-storage.darwin.amd64 +containers-storage.linux.amd64 +containers-storage.linux.386 +containers-storage.linux.arm diff --git a/.travis.yml b/.travis.yml index 4be7ab78c4..892344e574 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,7 +2,6 @@ language: go go: - tip - 1.8 - - 1.7 dist: trusty sudo: required before_install: diff --git a/drivers/aufs/aufs.go b/drivers/aufs/aufs.go index 2e7f1e6599..7301dbdcd6 100644 --- a/drivers/aufs/aufs.go +++ b/drivers/aufs/aufs.go @@ -25,6 +25,7 @@ package aufs import ( "bufio" "fmt" + "io" "io/ioutil" "os" "os/exec" @@ -363,7 +364,7 @@ func (a *Driver) Put(id string) error { // Diff produces an archive of the changes between the specified // layer and its parent layer which may be "". -func (a *Driver) Diff(id, parent string) (archive.Archive, error) { +func (a *Driver) Diff(id, parent string) (io.ReadCloser, error) { // AUFS doesn't need the parent layer to produce a diff. 
return archive.TarWithOptions(path.Join(a.rootPath(), "diff", id), &archive.TarOptions{ Compression: archive.Uncompressed, @@ -394,7 +395,7 @@ func (a *Driver) DiffGetter(id string) (graphdriver.FileGetCloser, error) { return fileGetNilCloser{storage.NewPathFileGetter(p)}, nil } -func (a *Driver) applyDiff(id string, diff archive.Reader) error { +func (a *Driver) applyDiff(id string, diff io.Reader) error { return chrootarchive.UntarUncompressed(diff, path.Join(a.rootPath(), "diff", id), &archive.TarOptions{ UIDMaps: a.uidMaps, GIDMaps: a.gidMaps, @@ -412,7 +413,7 @@ func (a *Driver) DiffSize(id, parent string) (size int64, err error) { // ApplyDiff extracts the changeset from the given diff into the // layer with the specified id and parent, returning the size of the // new layer in bytes. -func (a *Driver) ApplyDiff(id, parent string, diff archive.Reader) (size int64, err error) { +func (a *Driver) ApplyDiff(id, parent string, diff io.Reader) (size int64, err error) { // AUFS doesn't need the parent id to apply the diff. if err = a.applyDiff(id, diff); err != nil { return diff --git a/drivers/driver.go b/drivers/driver.go index c16fc33e11..cfa3320988 100644 --- a/drivers/driver.go +++ b/drivers/driver.go @@ -2,6 +2,7 @@ package graphdriver import ( "fmt" + "io" "os" "path/filepath" "strings" @@ -83,15 +84,15 @@ type Driver interface { ProtoDriver // Diff produces an archive of the changes between the specified // layer and its parent layer which may be "". - Diff(id, parent string) (archive.Archive, error) + Diff(id, parent string) (io.ReadCloser, error) // Changes produces a list of changes between the specified layer // and its parent layer. If parent is "", then all changes will be ADD changes. Changes(id, parent string) ([]archive.Change, error) // ApplyDiff extracts the changeset from the given diff into the // layer with the specified id and parent, returning the size of the // new layer in bytes. - // The archive.Reader must be an uncompressed stream. 
- ApplyDiff(id, parent string, diff archive.Reader) (size int64, err error) + // The io.Reader must be an uncompressed stream. + ApplyDiff(id, parent string, diff io.Reader) (size int64, err error) // DiffSize calculates the changes between the specified id // and its parent and returns the size in bytes of the changes // relative to its base filesystem directory. diff --git a/drivers/fsdiff.go b/drivers/fsdiff.go index 693107295e..74e4325707 100644 --- a/drivers/fsdiff.go +++ b/drivers/fsdiff.go @@ -1,14 +1,14 @@ package graphdriver import ( + "io" "time" - "github.com/sirupsen/logrus" - "github.com/containers/storage/pkg/archive" "github.com/containers/storage/pkg/chrootarchive" "github.com/containers/storage/pkg/idtools" "github.com/containers/storage/pkg/ioutils" + "github.com/sirupsen/logrus" ) var ( @@ -31,9 +31,9 @@ type NaiveDiffDriver struct { // NewNaiveDiffDriver returns a fully functional driver that wraps the // given ProtoDriver and adds the capability of the following methods which // it may or may not support on its own: -// Diff(id, parent string) (archive.Archive, error) +// Diff(id, parent string) (io.ReadCloser, error) // Changes(id, parent string) ([]archive.Change, error) -// ApplyDiff(id, parent string, diff archive.Reader) (size int64, err error) +// ApplyDiff(id, parent string, diff io.Reader) (size int64, err error) // DiffSize(id, parent string) (size int64, err error) func NewNaiveDiffDriver(driver ProtoDriver, uidMaps, gidMaps []idtools.IDMap) Driver { gdw := &NaiveDiffDriver{ @@ -46,7 +46,7 @@ func NewNaiveDiffDriver(driver ProtoDriver, uidMaps, gidMaps []idtools.IDMap) Dr // Diff produces an archive of the changes between the specified // layer and its parent layer which may be "". 
-func (gdw *NaiveDiffDriver) Diff(id, parent string) (arch archive.Archive, err error) { +func (gdw *NaiveDiffDriver) Diff(id, parent string) (arch io.ReadCloser, err error) { layerFs, err := gdw.Get(id, "") if err != nil { return nil, err @@ -118,7 +118,7 @@ func (gdw *NaiveDiffDriver) Changes(id, parent string) ([]archive.Change, error) // ApplyDiff extracts the changeset from the given diff into the // layer with the specified id and parent, returning the size of the // new layer in bytes. -func (gdw *NaiveDiffDriver) ApplyDiff(id, parent string, diff archive.Reader) (size int64, err error) { +func (gdw *NaiveDiffDriver) ApplyDiff(id, parent string, diff io.Reader) (size int64, err error) { // Mount the root filesystem so we can apply the diff/layer. layerFs, err := gdw.Get(id, "") if err != nil { diff --git a/drivers/overlay/overlay.go b/drivers/overlay/overlay.go index 8c58c59088..f5e302d758 100644 --- a/drivers/overlay/overlay.go +++ b/drivers/overlay/overlay.go @@ -5,6 +5,7 @@ package overlay import ( "bufio" "fmt" + "io" "io/ioutil" "os" "os/exec" @@ -528,7 +529,7 @@ func (d *Driver) Exists(id string) bool { } // ApplyDiff applies the new layer into a root -func (d *Driver) ApplyDiff(id string, parent string, diff archive.Reader) (size int64, err error) { +func (d *Driver) ApplyDiff(id string, parent string, diff io.Reader) (size int64, err error) { applyDir := d.getDiffPath(id) logrus.Debugf("Applying tar in %s", applyDir) @@ -559,7 +560,7 @@ func (d *Driver) DiffSize(id, parent string) (size int64, err error) { // Diff produces an archive of the changes between the specified // layer and its parent layer which may be "". 
-func (d *Driver) Diff(id, parent string) (archive.Archive, error) { +func (d *Driver) Diff(id, parent string) (io.ReadCloser, error) { diffPath := d.getDiffPath(id) logrus.Debugf("Tar with options on %s", diffPath) return archive.TarWithOptions(diffPath, &archive.TarOptions{ diff --git a/drivers/proxy.go b/drivers/proxy.go index d56b8731f1..e4f742fb25 100644 --- a/drivers/proxy.go +++ b/drivers/proxy.go @@ -4,6 +4,7 @@ package graphdriver import ( "fmt" + "io" "github.com/containers/storage/pkg/archive" "github.com/pkg/errors" @@ -170,7 +171,7 @@ func (d *graphDriverProxy) Cleanup() error { return nil } -func (d *graphDriverProxy) Diff(id, parent string) (archive.Archive, error) { +func (d *graphDriverProxy) Diff(id, parent string) (io.ReadCloser, error) { args := &graphDriverRequest{ ID: id, Parent: parent, @@ -179,7 +180,7 @@ func (d *graphDriverProxy) Diff(id, parent string) (archive.Archive, error) { if err != nil { return nil, err } - return archive.Archive(body), nil + return io.ReadCloser(body), nil } func (d *graphDriverProxy) Changes(id, parent string) ([]archive.Change, error) { @@ -198,7 +199,7 @@ func (d *graphDriverProxy) Changes(id, parent string) ([]archive.Change, error) return ret.Changes, nil } -func (d *graphDriverProxy) ApplyDiff(id, parent string, diff archive.Reader) (int64, error) { +func (d *graphDriverProxy) ApplyDiff(id, parent string, diff io.Reader) (int64, error) { var ret graphDriverResponse if err := d.client.SendFile(fmt.Sprintf("GraphDriver.ApplyDiff?id=%s&parent=%s", id, parent), diff, &ret); err != nil { return -1, err diff --git a/drivers/vfs/driver.go b/drivers/vfs/driver.go index ff7a88f1a2..a5aefd9280 100644 --- a/drivers/vfs/driver.go +++ b/drivers/vfs/driver.go @@ -14,7 +14,7 @@ import ( var ( // CopyWithTar defines the copy method to use.
- CopyWithTar = chrootarchive.CopyWithTar + CopyWithTar = chrootarchive.NewArchiver(nil).CopyWithTar ) func init() { diff --git a/drivers/windows/windows.go b/drivers/windows/windows.go index 7ab36513c1..a502e963de 100644 --- a/drivers/windows/windows.go +++ b/drivers/windows/windows.go @@ -300,7 +300,7 @@ func (d *Driver) Cleanup() error { // Diff produces an archive of the changes between the specified // layer and its parent layer which may be "". // The layer should be mounted when calling this function -func (d *Driver) Diff(id, parent string) (_ archive.Archive, err error) { +func (d *Driver) Diff(id, parent string) (_ io.ReadCloser, err error) { rID, err := d.resolveID(id) if err != nil { return @@ -483,7 +483,7 @@ func writeTarFromLayer(r hcsshim.LayerReader, w io.Writer) error { } // exportLayer generates an archive from a layer based on the given ID. -func (d *Driver) exportLayer(id string, parentLayerPaths []string) (archive.Archive, error) { +func (d *Driver) exportLayer(id string, parentLayerPaths []string) (io.ReadCloser, error) { archive, w := io.Pipe() go func() { err := winio.RunWithPrivilege(winio.SeBackupPrivilege, func() error { diff --git a/layers.go b/layers.go index 774aa65495..90481ec761 100644 --- a/layers.go +++ b/layers.go @@ -184,7 +184,7 @@ type LayerStore interface { CreateWithFlags(id, parent string, names []string, mountLabel string, options map[string]string, writeable bool, flags map[string]interface{}) (layer *Layer, err error) // Put combines the functions of CreateWithFlags and ApplyDiff. - Put(id, parent string, names []string, mountLabel string, options map[string]string, writeable bool, flags map[string]interface{}, diff archive.Reader) (*Layer, int64, error) + Put(id, parent string, names []string, mountLabel string, options map[string]string, writeable bool, flags map[string]interface{}, diff io.Reader) (*Layer, int64, error) // SetNames replaces the list of names associated with a layer with the // supplied values. 
@@ -206,7 +206,7 @@ type LayerStore interface { // ApplyDiff reads a tarstream which was created by a previous call to Diff and // applies its changes to a specified layer. - ApplyDiff(to string, diff archive.Reader) (int64, error) + ApplyDiff(to string, diff io.Reader) (int64, error) } type layerStore struct { @@ -463,7 +463,7 @@ func (r *layerStore) Status() ([][2]string, error) { return r.driver.Status(), nil } -func (r *layerStore) Put(id, parent string, names []string, mountLabel string, options map[string]string, writeable bool, flags map[string]interface{}, diff archive.Reader) (layer *Layer, size int64, err error) { +func (r *layerStore) Put(id, parent string, names []string, mountLabel string, options map[string]string, writeable bool, flags map[string]interface{}, diff io.Reader) (layer *Layer, size int64, err error) { if !r.IsReadWrite() { return nil, -1, errors.Wrapf(ErrStoreIsReadOnly, "not allowed to create new layers at %q", r.layerspath()) } @@ -900,7 +900,7 @@ func (r *layerStore) DiffSize(from, to string) (size int64, err error) { return r.driver.DiffSize(to, from) } -func (r *layerStore) ApplyDiff(to string, diff archive.Reader) (size int64, err error) { +func (r *layerStore) ApplyDiff(to string, diff io.Reader) (size int64, err error) { if !r.IsReadWrite() { return -1, errors.Wrapf(ErrStoreIsReadOnly, "not allowed to modify layer contents at %q", r.layerspath()) } diff --git a/pkg/archive/README.md b/pkg/archive/README.md new file mode 100644 index 0000000000..7307d9694f --- /dev/null +++ b/pkg/archive/README.md @@ -0,0 +1 @@ +This code provides helper functions for dealing with archive files. 
diff --git a/pkg/archive/archive.go b/pkg/archive/archive.go index a4071d71b5..c7cd298275 100644 --- a/pkg/archive/archive.go +++ b/pkg/archive/archive.go @@ -6,7 +6,6 @@ import ( "bytes" "compress/bzip2" "compress/gzip" - "errors" "fmt" "io" "io/ioutil" @@ -27,18 +26,11 @@ import ( ) type ( - // Archive is a type of io.ReadCloser which has two interfaces Read and Closer. - Archive io.ReadCloser - // Reader is a type of io.Reader. - Reader io.Reader // Compression is the state represents if compressed or not. Compression int // WhiteoutFormat is the format of whiteouts unpacked WhiteoutFormat int - // TarChownOptions wraps the chown options UID and GID. - TarChownOptions struct { - UID, GID int - } + // TarOptions wraps the tar options. TarOptions struct { IncludeFiles []string @@ -47,7 +39,7 @@ type ( NoLchown bool UIDMaps []idtools.IDMap GIDMaps []idtools.IDMap - ChownOpts *TarChownOptions + ChownOpts *idtools.IDPair IncludeSourceDir bool // WhiteoutFormat is the expected on disk format for whiteout files. // This format will be converted to the standard format on pack @@ -59,34 +51,28 @@ type ( // For each include when creating an archive, the included name will be // replaced with the matching name from this map. RebaseNames map[string]string + InUserNS bool } - - // Archiver allows the reuse of most utility functions of this package - // with a pluggable Untar function. Also, to facilitate the passing of - // specific id mappings for untar, an archiver can be created with maps - // which will then be passed to Untar operations - Archiver struct { - Untar func(io.Reader, string, *TarOptions) error - UIDMaps []idtools.IDMap - GIDMaps []idtools.IDMap - } - - // breakoutError is used to differentiate errors related to breaking out - // When testing archive breakout in the unit tests, this error is expected - // in order for the test to pass. - breakoutError error ) -var ( - // ErrNotImplemented is the error message of function not implemented. 
- ErrNotImplemented = errors.New("Function not implemented") - defaultArchiver = &Archiver{Untar: Untar, UIDMaps: nil, GIDMaps: nil} -) +// Archiver allows the reuse of most utility functions of this package +// with a pluggable Untar function. Also, to facilitate the passing of +// specific id mappings for untar, an archiver can be created with maps +// which will then be passed to Untar operations +type Archiver struct { + Untar func(io.Reader, string, *TarOptions) error + IDMappings *idtools.IDMappings +} -const ( - // HeaderSize is the size in bytes of a tar header - HeaderSize = 512 -) +// NewDefaultArchiver returns a new Archiver without any IDMappings +func NewDefaultArchiver() *Archiver { + return &Archiver{Untar: Untar, IDMappings: &idtools.IDMappings{}} +} + +// breakoutError is used to differentiate errors related to breaking out +// When testing archive breakout in the unit tests, this error is expected +// in order for the test to pass. +type breakoutError error const ( // Uncompressed represents the uncompressed. @@ -107,17 +93,15 @@ const ( OverlayWhiteoutFormat ) -// IsArchive checks for the magic bytes of a tar or any supported compression -// algorithm. -func IsArchive(header []byte) bool { - compression := DetectCompression(header) - if compression != Uncompressed { - return true - } - r := tar.NewReader(bytes.NewBuffer(header)) - _, err := r.Next() - return err == nil -} +const ( + modeISDIR = 040000 // Directory + modeISFIFO = 010000 // FIFO + modeISREG = 0100000 // Regular file + modeISLNK = 0120000 // Symbolic link + modeISBLK = 060000 // Block special file + modeISCHR = 020000 // Character special file + modeISSOCK = 0140000 // Socket +) // IsArchivePath checks if the (possibly compressed) file at the given path // starts with a tar file header. 
@@ -147,7 +131,7 @@ func DetectCompression(source []byte) Compression { logrus.Debug("Len too short") continue } - if bytes.Compare(m, source[:len(m)]) == 0 { + if bytes.Equal(m, source[:len(m)]) { return compression } } @@ -206,7 +190,7 @@ func DecompressStream(archive io.Reader) (io.ReadCloser, error) { } } -// CompressStream compresseses the dest with specified compression algorithm. +// CompressStream compresses the dest with specified compression algorithm. func CompressStream(dest io.Writer, compression Compression) (io.WriteCloser, error) { p := pools.BufioWriter32KPool buf := p.Get(dest) @@ -220,13 +204,100 @@ func CompressStream(dest io.Writer, compression Compression) (io.WriteCloser, er return writeBufWrapper, nil case Bzip2, Xz: // archive/bzip2 does not support writing, and there is no xz support at all - // However, this is not a problem as we only currently generates gzipped tars + // However, this is not a problem as docker only currently generates gzipped tars return nil, fmt.Errorf("Unsupported compression format %s", (&compression).Extension()) default: return nil, fmt.Errorf("Unsupported compression format %s", (&compression).Extension()) } } +// TarModifierFunc is a function that can be passed to ReplaceFileTarWrapper to +// modify the contents or header of an entry in the archive. If the file already +// exists in the archive the TarModifierFunc will be called with the Header and +// a reader which will return the files content. If the file does not exist both +// header and content will be nil. +type TarModifierFunc func(path string, header *tar.Header, content io.Reader) (*tar.Header, []byte, error) + +// ReplaceFileTarWrapper converts inputTarStream to a new tar stream. Files in the +// tar stream are modified if they match any of the keys in mods. 
+func ReplaceFileTarWrapper(inputTarStream io.ReadCloser, mods map[string]TarModifierFunc) io.ReadCloser { + pipeReader, pipeWriter := io.Pipe() + + go func() { + tarReader := tar.NewReader(inputTarStream) + tarWriter := tar.NewWriter(pipeWriter) + defer inputTarStream.Close() + defer tarWriter.Close() + + modify := func(name string, original *tar.Header, modifier TarModifierFunc, tarReader io.Reader) error { + header, data, err := modifier(name, original, tarReader) + switch { + case err != nil: + return err + case header == nil: + return nil + } + + header.Name = name + header.Size = int64(len(data)) + if err := tarWriter.WriteHeader(header); err != nil { + return err + } + if len(data) != 0 { + if _, err := tarWriter.Write(data); err != nil { + return err + } + } + return nil + } + + var err error + var originalHeader *tar.Header + for { + originalHeader, err = tarReader.Next() + if err == io.EOF { + break + } + if err != nil { + pipeWriter.CloseWithError(err) + return + } + + modifier, ok := mods[originalHeader.Name] + if !ok { + // No modifiers for this file, copy the header and data + if err := tarWriter.WriteHeader(originalHeader); err != nil { + pipeWriter.CloseWithError(err) + return + } + if _, err := pools.Copy(tarWriter, tarReader); err != nil { + pipeWriter.CloseWithError(err) + return + } + continue + } + delete(mods, originalHeader.Name) + + if err := modify(originalHeader.Name, originalHeader, modifier, tarReader); err != nil { + pipeWriter.CloseWithError(err) + return + } + } + + // Apply the modifiers that haven't matched any files in the archive + for name, modifier := range mods { + if err := modify(name, nil, modifier, nil); err != nil { + pipeWriter.CloseWithError(err) + return + } + } + + pipeWriter.Close() + + }() + return pipeReader +} + // Extension returns the extension of a file that uses the specified compression algorithm. 
func (compression *Compression) Extension() string { switch *compression { @@ -242,8 +313,65 @@ func (compression *Compression) Extension() string { return "" } +// FileInfoHeader creates a populated Header from fi. +// Compared to archive pkg this function fills in more information. +// Also, regardless of Go version, this function fills file type bits (e.g. hdr.Mode |= modeISDIR), +// which have been deleted since Go 1.9 archive/tar. +func FileInfoHeader(name string, fi os.FileInfo, link string) (*tar.Header, error) { + hdr, err := tar.FileInfoHeader(fi, link) + if err != nil { + return nil, err + } + hdr.Mode = fillGo18FileTypeBits(int64(chmodTarEntry(os.FileMode(hdr.Mode))), fi) + name, err = canonicalTarName(name, fi.IsDir()) + if err != nil { + return nil, fmt.Errorf("tar: cannot canonicalize path: %v", err) + } + hdr.Name = name + if err := setHeaderForSpecialDevice(hdr, name, fi.Sys()); err != nil { + return nil, err + } + return hdr, nil +} + +// fillGo18FileTypeBits fills type bits which have been removed on Go 1.9 archive/tar +// https://github.com/golang/go/commit/66b5a2f +func fillGo18FileTypeBits(mode int64, fi os.FileInfo) int64 { + fm := fi.Mode() + switch { + case fm.IsRegular(): + mode |= modeISREG + case fi.IsDir(): + mode |= modeISDIR + case fm&os.ModeSymlink != 0: + mode |= modeISLNK + case fm&os.ModeDevice != 0: + if fm&os.ModeCharDevice != 0 { + mode |= modeISCHR + } else { + mode |= modeISBLK + } + case fm&os.ModeNamedPipe != 0: + mode |= modeISFIFO + case fm&os.ModeSocket != 0: + mode |= modeISSOCK + } + return mode +} + +// ReadSecurityXattrToTarHeader reads security.capability xattr from filesystem +// to a tar header +func ReadSecurityXattrToTarHeader(path string, hdr *tar.Header) error { + capability, _ := system.Lgetxattr(path, "security.capability") + if capability != nil { + hdr.Xattrs = make(map[string]string) + hdr.Xattrs["security.capability"] = string(capability) + } + return nil +} + type tarWhiteoutConverter interface { - 
ConvertWrite(*tar.Header, string, os.FileInfo) error + ConvertWrite(*tar.Header, string, os.FileInfo) (*tar.Header, error) ConvertRead(*tar.Header, string) (bool, error) } @@ -252,9 +380,9 @@ type tarAppender struct { Buffer *bufio.Writer // for hardlink mapping - SeenFiles map[uint64]string - UIDMaps []idtools.IDMap - GIDMaps []idtools.IDMap + SeenFiles map[uint64]string + IDMappings *idtools.IDMappings + ChownOpts *idtools.IDPair // For packing and unpacking whiteout files in the // non standard format. The whiteout files defined @@ -263,6 +391,16 @@ type tarAppender struct { WhiteoutConverter tarWhiteoutConverter } +func newTarAppender(idMapping *idtools.IDMappings, writer io.Writer, chownOpts *idtools.IDPair) *tarAppender { + return &tarAppender{ + SeenFiles: make(map[uint64]string), + TarWriter: tar.NewWriter(writer), + Buffer: pools.BufioWriter32KPool.Get(nil), + IDMappings: idMapping, + ChownOpts: chownOpts, + } +} + // canonicalTarName provides a platform-independent and consistent posix-style //path for files and directories to be archived regardless of the platform. 
func canonicalTarName(name string, isDir bool) (string, error) { @@ -285,33 +423,30 @@ func (ta *tarAppender) addTarFile(path, name string) error { return err } - link := "" + var link string if fi.Mode()&os.ModeSymlink != 0 { - if link, err = os.Readlink(path); err != nil { + var err error + link, err = os.Readlink(path) + if err != nil { return err } } - hdr, err := tar.FileInfoHeader(fi, link) + hdr, err := FileInfoHeader(name, fi, link) if err != nil { return err } - hdr.Mode = int64(chmodTarEntry(os.FileMode(hdr.Mode))) - - name, err = canonicalTarName(name, fi.IsDir()) - if err != nil { - return fmt.Errorf("tar: cannot canonicalize path: %v", err) - } - hdr.Name = name - - inode, err := setHeaderForSpecialDevice(hdr, ta, name, fi.Sys()) - if err != nil { + if err := ReadSecurityXattrToTarHeader(path, hdr); err != nil { return err } // if it's not a directory and has more than 1 link, - // it's hardlinked, so set the type flag accordingly + // it's hard linked, so set the type flag accordingly if !fi.IsDir() && hasHardlinks(fi) { + inode, err := getInodeFromStat(fi.Sys()) + if err != nil { + return err + } // a link should have a name that it links too // and that linked name should be first in the tar archive if oldpath, ok := ta.SeenFiles[inode]; ok { @@ -323,36 +458,46 @@ func (ta *tarAppender) addTarFile(path, name string) error { } } - capability, _ := system.Lgetxattr(path, "security.capability") - if capability != nil { - hdr.Xattrs = make(map[string]string) - hdr.Xattrs["security.capability"] = string(capability) - } - //handle re-mapping container ID mappings back to host ID mappings before //writing tar headers/files. 
We skip whiteout files because they were written //by the kernel and already have proper ownership relative to the host - if !strings.HasPrefix(filepath.Base(hdr.Name), WhiteoutPrefix) && (ta.UIDMaps != nil || ta.GIDMaps != nil) { - uid, gid, err := getFileUIDGID(fi.Sys()) - if err != nil { - return err - } - xUID, err := idtools.ToContainer(uid, ta.UIDMaps) + if !strings.HasPrefix(filepath.Base(hdr.Name), WhiteoutPrefix) && !ta.IDMappings.Empty() { + fileIDPair, err := getFileUIDGID(fi.Sys()) if err != nil { return err } - xGID, err := idtools.ToContainer(gid, ta.GIDMaps) + hdr.Uid, hdr.Gid, err = ta.IDMappings.ToContainer(fileIDPair) if err != nil { return err } - hdr.Uid = xUID - hdr.Gid = xGID + } + + // explicitly override with ChownOpts + if ta.ChownOpts != nil { + hdr.Uid = ta.ChownOpts.UID + hdr.Gid = ta.ChownOpts.GID } if ta.WhiteoutConverter != nil { - if err := ta.WhiteoutConverter.ConvertWrite(hdr, path, fi); err != nil { + wo, err := ta.WhiteoutConverter.ConvertWrite(hdr, path, fi) + if err != nil { return err } + + // If a new whiteout file exists, write original hdr, then + // replace hdr with wo to be written after. Whiteouts should + // always be written after the original. Note the original + // hdr may have been updated to be a whiteout with returning + // a whiteout header + if wo != nil { + if err := ta.TarWriter.WriteHeader(hdr); err != nil { + return err + } + if hdr.Typeflag == tar.TypeReg && hdr.Size > 0 { + return fmt.Errorf("tar: cannot use whiteout for non-empty file") + } + hdr = wo + } } if err := ta.TarWriter.WriteHeader(hdr); err != nil { @@ -360,7 +505,10 @@ func (ta *tarAppender) addTarFile(path, name string) error { } if hdr.Typeflag == tar.TypeReg && hdr.Size > 0 { - file, err := os.Open(path) + // We use system.OpenSequential to ensure we use sequential file + // access on Windows to avoid depleting the standby list. + // On Linux, this equates to a regular os.Open. 
+ file, err := system.OpenSequential(path) if err != nil { return err } @@ -381,7 +529,7 @@ func (ta *tarAppender) addTarFile(path, name string) error { return nil } -func createTarFile(path, extractDir string, hdr *tar.Header, reader io.Reader, Lchown bool, chownOpts *TarChownOptions) error { +func createTarFile(path, extractDir string, hdr *tar.Header, reader io.Reader, Lchown bool, chownOpts *idtools.IDPair, inUserns bool) error { // hdr.Mode is in linux format, which we can use for sycalls, // but for os.Foo() calls we need the mode converted to os.FileMode, // so use hdrInfo.Mode() (they differ for e.g. setuid bits) @@ -398,8 +546,10 @@ func createTarFile(path, extractDir string, hdr *tar.Header, reader io.Reader, L } case tar.TypeReg, tar.TypeRegA: - // Source is regular file - file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, hdrInfo.Mode()) + // Source is regular file. We use system.OpenFileSequential to use sequential + // file access to avoid depleting the standby list on Windows. 
+ // On Linux, this equates to a regular os.OpenFile + file, err := system.OpenFileSequential(path, os.O_CREATE|os.O_WRONLY, hdrInfo.Mode()) if err != nil { return err } @@ -409,7 +559,16 @@ func createTarFile(path, extractDir string, hdr *tar.Header, reader io.Reader, L } file.Close() - case tar.TypeBlock, tar.TypeChar, tar.TypeFifo: + case tar.TypeBlock, tar.TypeChar: + if inUserns { // cannot create devices in a userns + return nil + } + // Handle this is an OS-specific way + if err := handleTarTypeBlockCharFifo(hdr, path); err != nil { + return err + } + + case tar.TypeFifo: // Handle this is an OS-specific way if err := handleTarTypeBlockCharFifo(hdr, path); err != nil { return err @@ -444,13 +603,13 @@ func createTarFile(path, extractDir string, hdr *tar.Header, reader io.Reader, L return nil default: - return fmt.Errorf("Unhandled tar header type %d\n", hdr.Typeflag) + return fmt.Errorf("unhandled tar header type %d", hdr.Typeflag) } // Lchown is not supported on Windows. if Lchown && runtime.GOOS != "windows" { if chownOpts == nil { - chownOpts = &TarChownOptions{UID: hdr.Uid, GID: hdr.Gid} + chownOpts = &idtools.IDPair{UID: hdr.Uid, GID: hdr.Gid} } if err := os.Lchown(path, chownOpts.UID, chownOpts.GID); err != nil { return err @@ -525,8 +684,7 @@ func TarWithOptions(srcPath string, options *TarOptions) (io.ReadCloser, error) // on platforms other than Windows. 
srcPath = fixVolumePathPrefix(srcPath) - patterns, patDirs, exceptions, err := fileutils.CleanPatterns(options.ExcludePatterns) - + pm, err := fileutils.NewPatternMatcher(options.ExcludePatterns) if err != nil { return nil, err } @@ -539,14 +697,12 @@ func TarWithOptions(srcPath string, options *TarOptions) (io.ReadCloser, error) } go func() { - ta := &tarAppender{ - TarWriter: tar.NewWriter(compressWriter), - Buffer: pools.BufioWriter32KPool.Get(nil), - SeenFiles: make(map[uint64]string), - UIDMaps: options.UIDMaps, - GIDMaps: options.GIDMaps, - WhiteoutConverter: getWhiteoutConverter(options.WhiteoutFormat), - } + ta := newTarAppender( + idtools.NewIDMappingsFromMaps(options.UIDMaps, options.GIDMaps), + compressWriter, + options.ChownOpts, + ) + ta.WhiteoutConverter = getWhiteoutConverter(options.WhiteoutFormat) defer func() { // Make sure to check the error on Close. @@ -623,7 +779,7 @@ func TarWithOptions(srcPath string, options *TarOptions) (io.ReadCloser, error) // is asking for that file no matter what - which is true // for some files, like .dockerignore and Dockerfile (sometimes) if include != relFilePath { - skip, err = fileutils.OptimizedMatches(relFilePath, patterns, patDirs) + skip, err = pm.Matches(relFilePath) if err != nil { logrus.Errorf("Error matching %s: %v", relFilePath, err) return err @@ -633,7 +789,7 @@ func TarWithOptions(srcPath string, options *TarOptions) (io.ReadCloser, error) if skip { // If we want to skip this file and its a directory // then we should first check to see if there's an - // excludes pattern (eg !dir/file) that starts with this + // excludes pattern (e.g. !dir/file) that starts with this // dir. If so then we can't skip this dir. // Its not a dir then so we can just return/skip. @@ -642,18 +798,17 @@ func TarWithOptions(srcPath string, options *TarOptions) (io.ReadCloser, error) } // No exceptions (!...) 
in patterns so just skip dir - if !exceptions { + if !pm.Exclusions() { return filepath.SkipDir } dirSlash := relFilePath + string(filepath.Separator) - for _, pat := range patterns { - if pat[0] != '!' { + for _, pat := range pm.Patterns() { + if !pat.Exclusion() { continue } - pat = pat[1:] + string(filepath.Separator) - if strings.HasPrefix(pat, dirSlash) { + if strings.HasPrefix(pat.String()+string(filepath.Separator), dirSlash) { // found a match - so can't skip this dir return nil } @@ -703,10 +858,8 @@ func Unpack(decompressedArchive io.Reader, dest string, options *TarOptions) err defer pools.BufioReader32KPool.Put(trBuf) var dirs []*tar.Header - remappedRootUID, remappedRootGID, err := idtools.GetRootUIDGID(options.UIDMaps, options.GIDMaps) - if err != nil { - return err - } + idMappings := idtools.NewIDMappingsFromMaps(options.UIDMaps, options.GIDMaps) + rootIDs := idMappings.RootPair() whiteoutConverter := getWhiteoutConverter(options.WhiteoutFormat) // Iterate through the files in the archive. @@ -740,7 +893,7 @@ loop: parent := filepath.Dir(hdr.Name) parentPath := filepath.Join(dest, parent) if _, err := os.Lstat(parentPath); err != nil && os.IsNotExist(err) { - err = idtools.MkdirAllNewAs(parentPath, 0777, remappedRootUID, remappedRootGID) + err = idtools.MkdirAllAndChownNew(parentPath, 0777, rootIDs) if err != nil { return err } @@ -785,26 +938,8 @@ loop: } trBuf.Reset(tr) - // if the options contain a uid & gid maps, convert header uid/gid - // entries using the maps such that lchown sets the proper mapped - // uid/gid after writing the file. We only perform this mapping if - // the file isn't already owned by the remapped root UID or GID, as - // that specific uid/gid has no mapping from container -> host, and - // those files already have the proper ownership for inside the - // container. 
- if hdr.Uid != remappedRootUID { - xUID, err := idtools.ToHost(hdr.Uid, options.UIDMaps) - if err != nil { - return err - } - hdr.Uid = xUID - } - if hdr.Gid != remappedRootGID { - xGID, err := idtools.ToHost(hdr.Gid, options.GIDMaps) - if err != nil { - return err - } - hdr.Gid = xGID + if err := remapIDs(idMappings, hdr); err != nil { + return err } if whiteoutConverter != nil { @@ -817,7 +952,7 @@ loop: } } - if err := createTarFile(path, dest, hdr, trBuf, !options.NoLchown, options.ChownOpts); err != nil { + if err := createTarFile(path, dest, hdr, trBuf, !options.NoLchown, options.ChownOpts, options.InUserNS); err != nil { return err } @@ -889,23 +1024,13 @@ func (archiver *Archiver) TarUntar(src, dst string) error { return err } defer archive.Close() - - var options *TarOptions - if archiver.UIDMaps != nil || archiver.GIDMaps != nil { - options = &TarOptions{ - UIDMaps: archiver.UIDMaps, - GIDMaps: archiver.GIDMaps, - } + options := &TarOptions{ + UIDMaps: archiver.IDMappings.UIDs(), + GIDMaps: archiver.IDMappings.GIDs(), } return archiver.Untar(archive, dst, options) } -// TarUntar is a convenience function which calls Tar and Untar, with the output of one piped into the other. -// If either Tar or Untar fails, TarUntar aborts and returns the error. -func TarUntar(src, dst string) error { - return defaultArchiver.TarUntar(src, dst) -} - // UntarPath untar a file from path to a destination, src is the source tar file path. 
func (archiver *Archiver) UntarPath(src, dst string) error { archive, err := os.Open(src) @@ -913,22 +1038,13 @@ func (archiver *Archiver) UntarPath(src, dst string) error { return err } defer archive.Close() - var options *TarOptions - if archiver.UIDMaps != nil || archiver.GIDMaps != nil { - options = &TarOptions{ - UIDMaps: archiver.UIDMaps, - GIDMaps: archiver.GIDMaps, - } + options := &TarOptions{ + UIDMaps: archiver.IDMappings.UIDs(), + GIDMaps: archiver.IDMappings.GIDs(), } return archiver.Untar(archive, dst, options) } -// UntarPath is a convenience function which looks for an archive -// at filesystem path `src`, and unpacks it at `dst`. -func UntarPath(src, dst string) error { - return defaultArchiver.UntarPath(src, dst) -} - // CopyWithTar creates a tar archive of filesystem path `src`, and // unpacks it at filesystem path `dst`. // The archive is streamed directly with fixed buffering and no @@ -945,27 +1061,16 @@ func (archiver *Archiver) CopyWithTar(src, dst string) error { // if this archiver is set up with ID mapping we need to create // the new destination directory with the remapped root UID/GID pair // as owner - rootUID, rootGID, err := idtools.GetRootUIDGID(archiver.UIDMaps, archiver.GIDMaps) - if err != nil { - return err - } + rootIDs := archiver.IDMappings.RootPair() // Create dst, copy src's content into it logrus.Debugf("Creating dest directory: %s", dst) - if err := idtools.MkdirAllNewAs(dst, 0755, rootUID, rootGID); err != nil { + if err := idtools.MkdirAllAndChownNew(dst, 0755, rootIDs); err != nil { return err } logrus.Debugf("Calling TarUntar(%s, %s)", src, dst) return archiver.TarUntar(src, dst) } -// CopyWithTar creates a tar archive of filesystem path `src`, and -// unpacks it at filesystem path `dst`. -// The archive is streamed directly with fixed buffering and no -// intermediary disk IO. 
-func CopyWithTar(src, dst string) error { - return defaultArchiver.CopyWithTar(src, dst) -} - // CopyFileWithTar emulates the behavior of the 'cp' command-line // for a single file. It copies a regular file from path `src` to // path `dst`, and preserves all its metadata. @@ -986,7 +1091,7 @@ func (archiver *Archiver) CopyFileWithTar(src, dst string) (err error) { dst = filepath.Join(dst, filepath.Base(src)) } // Create the holding directory if necessary - if err := system.MkdirAll(filepath.Dir(dst), 0700); err != nil { + if err := system.MkdirAll(filepath.Dir(dst), 0700, ""); err != nil { return err } @@ -1007,28 +1112,10 @@ func (archiver *Archiver) CopyFileWithTar(src, dst string) (err error) { hdr.Name = filepath.Base(dst) hdr.Mode = int64(chmodTarEntry(os.FileMode(hdr.Mode))) - remappedRootUID, remappedRootGID, err := idtools.GetRootUIDGID(archiver.UIDMaps, archiver.GIDMaps) - if err != nil { + if err := remapIDs(archiver.IDMappings, hdr); err != nil { return err } - // only perform mapping if the file being copied isn't already owned by the - // uid or gid of the remapped root in the container - if remappedRootUID != hdr.Uid { - xUID, err := idtools.ToHost(hdr.Uid, archiver.UIDMaps) - if err != nil { - return err - } - hdr.Uid = xUID - } - if remappedRootGID != hdr.Gid { - xGID, err := idtools.ToHost(hdr.Gid, archiver.GIDMaps) - if err != nil { - return err - } - hdr.Gid = xGID - } - tw := tar.NewWriter(w) defer tw.Close() if err := tw.WriteHeader(hdr); err != nil { @@ -1040,7 +1127,7 @@ func (archiver *Archiver) CopyFileWithTar(src, dst string) (err error) { return nil }) defer func() { - if er := <-errC; err != nil { + if er := <-errC; err == nil && er != nil { err = er } }() @@ -1052,16 +1139,10 @@ func (archiver *Archiver) CopyFileWithTar(src, dst string) (err error) { return err } -// CopyFileWithTar emulates the behavior of the 'cp' command-line -// for a single file. 
It copies a regular file from path `src` to -// path `dst`, and preserves all its metadata. -// -// Destination handling is in an operating specific manner depending -// where the daemon is running. If `dst` ends with a trailing slash -// the final destination path will be `dst/base(src)` (Linux) or -// `dst\base(src)` (Windows). -func CopyFileWithTar(src, dst string) (err error) { - return defaultArchiver.CopyFileWithTar(src, dst) +func remapIDs(idMappings *idtools.IDMappings, hdr *tar.Header) error { + ids, err := idMappings.ToHost(idtools.IDPair{UID: hdr.Uid, GID: hdr.Gid}) + hdr.Uid, hdr.Gid = ids.UID, ids.GID + return err } // cmdStream executes a command, and returns its stdout as a stream. @@ -1096,7 +1177,7 @@ func cmdStream(cmd *exec.Cmd, input io.Reader) (io.ReadCloser, <-chan struct{}, // NewTempArchive reads the content of src into a temporary file, and returns the contents // of that file as an archive. The archive can only be read once - as soon as reading completes, // the file will be deleted. 
-func NewTempArchive(src Archive, dir string) (*TempArchive, error) { +func NewTempArchive(src io.Reader, dir string) (*TempArchive, error) { f, err := ioutil.TempFile(dir, "") if err != nil { return nil, err diff --git a/pkg/archive/archive_linux.go b/pkg/archive/archive_linux.go index e944ca2a75..5a14eb91a9 100644 --- a/pkg/archive/archive_linux.go +++ b/pkg/archive/archive_linux.go @@ -5,9 +5,9 @@ import ( "os" "path/filepath" "strings" - "syscall" "github.com/containers/storage/pkg/system" + "golang.org/x/sys/unix" ) func getWhiteoutConverter(format WhiteoutFormat) tarWhiteoutConverter { @@ -19,7 +19,7 @@ func getWhiteoutConverter(format WhiteoutFormat) tarWhiteoutConverter { type overlayWhiteoutConverter struct{} -func (overlayWhiteoutConverter) ConvertWrite(hdr *tar.Header, path string, fi os.FileInfo) error { +func (overlayWhiteoutConverter) ConvertWrite(hdr *tar.Header, path string, fi os.FileInfo) (wo *tar.Header, err error) { // convert whiteouts to AUFS format if fi.Mode()&os.ModeCharDevice != 0 && hdr.Devmajor == 0 && hdr.Devminor == 0 { // we just rename the file and make it normal @@ -34,12 +34,16 @@ func (overlayWhiteoutConverter) ConvertWrite(hdr *tar.Header, path string, fi os // convert opaque dirs to AUFS format by writing an empty file with the prefix opaque, err := system.Lgetxattr(path, "trusted.overlay.opaque") if err != nil { - return err + return nil, err } - if opaque != nil && len(opaque) == 1 && opaque[0] == 'y' { + if len(opaque) == 1 && opaque[0] == 'y' { + if hdr.Xattrs != nil { + delete(hdr.Xattrs, "trusted.overlay.opaque") + } + // create a header for the whiteout file // it should inherit some properties from the parent, but be a regular file - *hdr = tar.Header{ + wo = &tar.Header{ Typeflag: tar.TypeReg, Mode: hdr.Mode & int64(os.ModePerm), Name: filepath.Join(hdr.Name, WhiteoutOpaqueDir), @@ -54,7 +58,7 @@ func (overlayWhiteoutConverter) ConvertWrite(hdr *tar.Header, path string, fi os } } - return nil + return } func 
(overlayWhiteoutConverter) ConvertRead(hdr *tar.Header, path string) (bool, error) { @@ -63,12 +67,9 @@ func (overlayWhiteoutConverter) ConvertRead(hdr *tar.Header, path string) (bool, // if a directory is marked as opaque by the AUFS special file, we need to translate that to overlay if base == WhiteoutOpaqueDir { - if err := syscall.Setxattr(dir, "trusted.overlay.opaque", []byte{'y'}, 0); err != nil { - return false, err - } - + err := unix.Setxattr(dir, "trusted.overlay.opaque", []byte{'y'}, 0) // don't write the file itself - return false, nil + return false, err } // if a file was deleted and we are using overlay, we need to create a character device @@ -76,7 +77,7 @@ func (overlayWhiteoutConverter) ConvertRead(hdr *tar.Header, path string) (bool, originalBase := base[len(WhiteoutPrefix):] originalPath := filepath.Join(dir, originalBase) - if err := syscall.Mknod(originalPath, syscall.S_IFCHR, 0); err != nil { + if err := unix.Mknod(originalPath, unix.S_IFCHR, 0); err != nil { return false, err } if err := os.Chown(originalPath, hdr.Uid, hdr.Gid); err != nil { diff --git a/pkg/archive/archive_linux_test.go b/pkg/archive/archive_linux_test.go new file mode 100644 index 0000000000..4d5441beea --- /dev/null +++ b/pkg/archive/archive_linux_test.go @@ -0,0 +1,160 @@ +package archive + +import ( + "io/ioutil" + "os" + "path/filepath" + "syscall" + "testing" + + "github.com/containers/storage/pkg/system" + "github.com/stretchr/testify/require" + "golang.org/x/sys/unix" +) + +// setupOverlayTestDir creates files in a directory with overlay whiteouts +// Tree layout +// . 
+// ├── d1 # opaque, 0700 +// │   └── f1 # empty file, 0600 +// ├── d2 # opaque, 0750 +// │   └── f1 # empty file, 0660 +// └── d3 # 0700 +// └── f1 # whiteout, 0644 +func setupOverlayTestDir(t *testing.T, src string) { + // Create opaque directory containing single file and permission 0700 + err := os.Mkdir(filepath.Join(src, "d1"), 0700) + require.NoError(t, err) + + err = system.Lsetxattr(filepath.Join(src, "d1"), "trusted.overlay.opaque", []byte("y"), 0) + require.NoError(t, err) + + err = ioutil.WriteFile(filepath.Join(src, "d1", "f1"), []byte{}, 0600) + require.NoError(t, err) + + // Create another opaque directory containing single file but with permission 0750 + err = os.Mkdir(filepath.Join(src, "d2"), 0750) + require.NoError(t, err) + + err = system.Lsetxattr(filepath.Join(src, "d2"), "trusted.overlay.opaque", []byte("y"), 0) + require.NoError(t, err) + + err = ioutil.WriteFile(filepath.Join(src, "d2", "f1"), []byte{}, 0660) + require.NoError(t, err) + + // Create regular directory with deleted file + err = os.Mkdir(filepath.Join(src, "d3"), 0700) + require.NoError(t, err) + + err = system.Mknod(filepath.Join(src, "d3", "f1"), unix.S_IFCHR, 0) + require.NoError(t, err) +} + +func checkOpaqueness(t *testing.T, path string, opaque string) { + xattrOpaque, err := system.Lgetxattr(path, "trusted.overlay.opaque") + require.NoError(t, err) + + if string(xattrOpaque) != opaque { + t.Fatalf("Unexpected opaque value: %q, expected %q", string(xattrOpaque), opaque) + } + +} + +func checkOverlayWhiteout(t *testing.T, path string) { + stat, err := os.Stat(path) + require.NoError(t, err) + + statT, ok := stat.Sys().(*syscall.Stat_t) + if !ok { + t.Fatalf("Unexpected type: %t, expected *syscall.Stat_t", stat.Sys()) + } + if statT.Rdev != 0 { + t.Fatalf("Non-zero device number for whiteout") + } +} + +func checkFileMode(t *testing.T, path string, perm os.FileMode) { + stat, err := os.Stat(path) + require.NoError(t, err) + + if stat.Mode() != perm { + t.Fatalf("Unexpected 
file mode for %s: %o, expected %o", path, stat.Mode(), perm) + } +} + +func TestOverlayTarUntar(t *testing.T) { + oldmask, err := system.Umask(0) + require.NoError(t, err) + defer system.Umask(oldmask) + + src, err := ioutil.TempDir("", "storage-test-overlay-tar-src") + require.NoError(t, err) + defer os.RemoveAll(src) + + setupOverlayTestDir(t, src) + + dst, err := ioutil.TempDir("", "storage-test-overlay-tar-dst") + require.NoError(t, err) + defer os.RemoveAll(dst) + + options := &TarOptions{ + Compression: Uncompressed, + WhiteoutFormat: OverlayWhiteoutFormat, + } + archive, err := TarWithOptions(src, options) + require.NoError(t, err) + defer archive.Close() + + err = Untar(archive, dst, options) + require.NoError(t, err) + + checkFileMode(t, filepath.Join(dst, "d1"), 0700|os.ModeDir) + checkFileMode(t, filepath.Join(dst, "d2"), 0750|os.ModeDir) + checkFileMode(t, filepath.Join(dst, "d3"), 0700|os.ModeDir) + checkFileMode(t, filepath.Join(dst, "d1", "f1"), 0600) + checkFileMode(t, filepath.Join(dst, "d2", "f1"), 0660) + checkFileMode(t, filepath.Join(dst, "d3", "f1"), os.ModeCharDevice|os.ModeDevice) + + checkOpaqueness(t, filepath.Join(dst, "d1"), "y") + checkOpaqueness(t, filepath.Join(dst, "d2"), "y") + checkOpaqueness(t, filepath.Join(dst, "d3"), "") + checkOverlayWhiteout(t, filepath.Join(dst, "d3", "f1")) +} + +func TestOverlayTarAUFSUntar(t *testing.T) { + oldmask, err := system.Umask(0) + require.NoError(t, err) + defer system.Umask(oldmask) + + src, err := ioutil.TempDir("", "storage-test-overlay-tar-src") + require.NoError(t, err) + defer os.RemoveAll(src) + + setupOverlayTestDir(t, src) + + dst, err := ioutil.TempDir("", "storage-test-overlay-tar-dst") + require.NoError(t, err) + defer os.RemoveAll(dst) + + archive, err := TarWithOptions(src, &TarOptions{ + Compression: Uncompressed, + WhiteoutFormat: OverlayWhiteoutFormat, + }) + require.NoError(t, err) + defer archive.Close() + + err = Untar(archive, dst, &TarOptions{ + Compression: Uncompressed, + 
WhiteoutFormat: AUFSWhiteoutFormat, + }) + require.NoError(t, err) + + checkFileMode(t, filepath.Join(dst, "d1"), 0700|os.ModeDir) + checkFileMode(t, filepath.Join(dst, "d1", WhiteoutOpaqueDir), 0700) + checkFileMode(t, filepath.Join(dst, "d2"), 0750|os.ModeDir) + checkFileMode(t, filepath.Join(dst, "d2", WhiteoutOpaqueDir), 0750) + checkFileMode(t, filepath.Join(dst, "d3"), 0700|os.ModeDir) + checkFileMode(t, filepath.Join(dst, "d1", "f1"), 0600) + checkFileMode(t, filepath.Join(dst, "d2", "f1"), 0660) + checkFileMode(t, filepath.Join(dst, "d3", WhiteoutPrefix+"f1"), 0600) +} diff --git a/pkg/archive/archive_test.go b/pkg/archive/archive_test.go index 85e41227c0..74508c914b 100644 --- a/pkg/archive/archive_test.go +++ b/pkg/archive/archive_test.go @@ -13,6 +13,10 @@ import ( "strings" "testing" "time" + + "github.com/containers/storage/pkg/idtools" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var tmp string @@ -24,35 +28,22 @@ func init() { } } -func TestIsArchiveNilHeader(t *testing.T) { - out := IsArchive(nil) - if out { - t.Fatalf("isArchive should return false as nil is not a valid archive header") - } +var defaultArchiver = NewDefaultArchiver() + +func defaultTarUntar(src, dst string) error { + return defaultArchiver.TarUntar(src, dst) } -func TestIsArchiveInvalidHeader(t *testing.T) { - header := []byte{0x00, 0x01, 0x02} - out := IsArchive(header) - if out { - t.Fatalf("isArchive should return false as %s is not a valid archive header", header) - } +func defaultUntarPath(src, dst string) error { + return defaultArchiver.UntarPath(src, dst) } -func TestIsArchiveBzip2(t *testing.T) { - header := []byte{0x42, 0x5A, 0x68} - out := IsArchive(header) - if !out { - t.Fatalf("isArchive should return true as %s is a bz2 header", header) - } +func defaultCopyFileWithTar(src, dst string) (err error) { + return defaultArchiver.CopyFileWithTar(src, dst) } -func TestIsArchive7zip(t *testing.T) { - header := []byte{0x50, 0x4b, 0x03, 
0x04} - out := IsArchive(header) - if out { - t.Fatalf("isArchive should return false as %s is a 7z header and it is not supported", header) - } +func defaultCopyWithTar(src, dst string) error { + return defaultArchiver.CopyWithTar(src, dst) } func TestIsArchivePathDir(t *testing.T) { @@ -67,7 +58,7 @@ func TestIsArchivePathDir(t *testing.T) { } func TestIsArchivePathInvalidFile(t *testing.T) { - cmd := exec.Command("sh", "-c", "dd if=/dev/zero bs=1K count=1 of=/tmp/archive && gzip --stdout /tmp/archive > /tmp/archive.gz") + cmd := exec.Command("sh", "-c", "dd if=/dev/zero bs=1024 count=1 of=/tmp/archive && gzip --stdout /tmp/archive > /tmp/archive.gz") output, err := cmd.CombinedOutput() if err != nil { t.Fatalf("Fail to create an archive file for test : %s.", output) @@ -81,7 +72,14 @@ func TestIsArchivePathInvalidFile(t *testing.T) { } func TestIsArchivePathTar(t *testing.T) { - cmd := exec.Command("sh", "-c", "touch /tmp/archivedata && tar -cf /tmp/archive /tmp/archivedata && gzip --stdout /tmp/archive > /tmp/archive.gz") + var whichTar string + if runtime.GOOS == "solaris" { + whichTar = "gtar" + } else { + whichTar = "tar" + } + cmdStr := fmt.Sprintf("touch /tmp/archivedata && %s -cf /tmp/archive /tmp/archivedata && gzip --stdout /tmp/archive > /tmp/archive.gz", whichTar) + cmd := exec.Command("sh", "-c", cmdStr) output, err := cmd.CombinedOutput() if err != nil { t.Fatalf("Fail to create an archive file for test : %s.", output) @@ -94,53 +92,54 @@ func TestIsArchivePathTar(t *testing.T) { } } -func TestDecompressStreamGzip(t *testing.T) { - cmd := exec.Command("sh", "-c", "touch /tmp/archive && gzip -f /tmp/archive") +func testDecompressStream(t *testing.T, ext, compressCommand string) { + cmd := exec.Command("sh", "-c", + fmt.Sprintf("touch /tmp/archive && %s /tmp/archive", compressCommand)) output, err := cmd.CombinedOutput() if err != nil { - t.Fatalf("Fail to create an archive file for test : %s.", output) + t.Fatalf("Failed to create an archive file for 
test : %s.", output) } - archive, err := os.Open(tmp + "archive.gz") - _, err = DecompressStream(archive) + filename := "archive." + ext + archive, err := os.Open(tmp + filename) if err != nil { - t.Fatalf("Failed to decompress a gzip file.") + t.Fatalf("Failed to open file %s: %v", filename, err) } -} + defer archive.Close() -func TestDecompressStreamBzip2(t *testing.T) { - cmd := exec.Command("sh", "-c", "touch /tmp/archive && bzip2 -f /tmp/archive") - output, err := cmd.CombinedOutput() + r, err := DecompressStream(archive) if err != nil { - t.Fatalf("Fail to create an archive file for test : %s.", output) + t.Fatalf("Failed to decompress %s: %v", filename, err) } - archive, err := os.Open(tmp + "archive.bz2") - _, err = DecompressStream(archive) - if err != nil { - t.Fatalf("Failed to decompress a bzip2 file.") + if _, err = ioutil.ReadAll(r); err != nil { + t.Fatalf("Failed to read the decompressed stream: %v ", err) + } + if err = r.Close(); err != nil { + t.Fatalf("Failed to close the decompressed stream: %v ", err) } } +func TestDecompressStreamGzip(t *testing.T) { + testDecompressStream(t, "gz", "gzip -f") +} + +func TestDecompressStreamBzip2(t *testing.T) { + testDecompressStream(t, "bz2", "bzip2 -f") +} + func TestDecompressStreamXz(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("Xz not present in msys2") } - cmd := exec.Command("sh", "-c", "touch /tmp/archive && xz -f /tmp/archive") - output, err := cmd.CombinedOutput() - if err != nil { - t.Fatalf("Fail to create an archive file for test : %s.", output) - } - archive, err := os.Open(tmp + "archive.xz") - _, err = DecompressStream(archive) - if err != nil { - t.Fatalf("Failed to decompress an xz file.") - } + testDecompressStream(t, "xz", "xz -f") } -func TestCompressStreamXzUnsuported(t *testing.T) { +func TestCompressStreamXzUnsupported(t *testing.T) { dest, err := os.Create(tmp + "dest") if err != nil { t.Fatalf("Fail to create the destination file") } + defer dest.Close() + _, err = 
CompressStream(dest, Xz) if err == nil { t.Fatalf("Should fail as xz is unsupported for compression format.") @@ -152,6 +151,8 @@ func TestCompressStreamBzip2Unsupported(t *testing.T) { if err != nil { t.Fatalf("Fail to create the destination file") } + defer dest.Close() + _, err = CompressStream(dest, Xz) if err == nil { t.Fatalf("Should fail as xz is unsupported for compression format.") @@ -163,6 +164,8 @@ func TestCompressStreamInvalid(t *testing.T) { if err != nil { t.Fatalf("Fail to create the destination file") } + defer dest.Close() + _, err = CompressStream(dest, -1) if err == nil { t.Fatalf("Should fail as xz is unsupported for compression format.") @@ -260,10 +263,8 @@ func TestCmdStreamGood(t *testing.T) { } func TestUntarPathWithInvalidDest(t *testing.T) { - tempFolder, err := ioutil.TempDir("", "docker-archive-test") - if err != nil { - t.Fatal(err) - } + tempFolder, err := ioutil.TempDir("", "storage-archive-test") + require.NoError(t, err) defer os.RemoveAll(tempFolder) invalidDestFolder := filepath.Join(tempFolder, "invalidDest") // Create a src file @@ -282,33 +283,29 @@ func TestUntarPathWithInvalidDest(t *testing.T) { cmd := exec.Command("sh", "-c", "tar cf "+tarFileU+" "+srcFileU) _, err = cmd.CombinedOutput() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - err = UntarPath(tarFile, invalidDestFolder) + err = defaultUntarPath(tarFile, invalidDestFolder) if err == nil { t.Fatalf("UntarPath with invalid destination path should throw an error.") } } func TestUntarPathWithInvalidSrc(t *testing.T) { - dest, err := ioutil.TempDir("", "docker-archive-test") + dest, err := ioutil.TempDir("", "storage-archive-test") if err != nil { t.Fatalf("Fail to create the destination file") } defer os.RemoveAll(dest) - err = UntarPath("/invalid/path", dest) + err = defaultUntarPath("/invalid/path", dest) if err == nil { t.Fatalf("UntarPath with invalid src path should throw an error.") } } func TestUntarPath(t *testing.T) { - tmpFolder, err := 
ioutil.TempDir("", "docker-archive-test") - if err != nil { - t.Fatal(err) - } + tmpFolder, err := ioutil.TempDir("", "storage-archive-test") + require.NoError(t, err) defer os.RemoveAll(tmpFolder) srcFile := filepath.Join(tmpFolder, "src") tarFile := filepath.Join(tmpFolder, "src.tar") @@ -329,11 +326,9 @@ func TestUntarPath(t *testing.T) { } cmd := exec.Command("sh", "-c", "tar cf "+tarFileU+" "+srcFileU) _, err = cmd.CombinedOutput() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - err = UntarPath(tarFile, destFolder) + err = defaultUntarPath(tarFile, destFolder) if err != nil { t.Fatalf("UntarPath shouldn't throw an error, %s.", err) } @@ -346,7 +341,7 @@ func TestUntarPath(t *testing.T) { // Do the same test as above but with the destination as file, it should fail func TestUntarPathWithDestinationFile(t *testing.T) { - tmpFolder, err := ioutil.TempDir("", "docker-archive-test") + tmpFolder, err := ioutil.TempDir("", "storage-archive-test") if err != nil { t.Fatal(err) } @@ -372,7 +367,7 @@ func TestUntarPathWithDestinationFile(t *testing.T) { if err != nil { t.Fatalf("Fail to create the destination file") } - err = UntarPath(tarFile, destFile) + err = defaultUntarPath(tarFile, destFile) if err == nil { t.Fatalf("UntarPath should throw an error if the destination if a file") } @@ -382,7 +377,7 @@ func TestUntarPathWithDestinationFile(t *testing.T) { // and the destination file is a directory // It's working, see https://github.com/docker/docker/issues/10040 func TestUntarPathWithDestinationSrcFileAsFolder(t *testing.T) { - tmpFolder, err := ioutil.TempDir("", "docker-archive-test") + tmpFolder, err := ioutil.TempDir("", "storage-archive-test") if err != nil { t.Fatal(err) } @@ -415,14 +410,14 @@ func TestUntarPathWithDestinationSrcFileAsFolder(t *testing.T) { if err != nil { t.Fatal(err) } - err = UntarPath(tarFile, destFolder) + err = defaultUntarPath(tarFile, destFolder) if err != nil { t.Fatalf("UntarPath should throw not throw an error if 
the extracted file already exists and is a folder") } } func TestCopyWithTarInvalidSrc(t *testing.T) { - tempFolder, err := ioutil.TempDir("", "docker-archive-test") + tempFolder, err := ioutil.TempDir("", "storage-archive-test") if err != nil { t.Fatal(nil) } @@ -432,14 +427,14 @@ func TestCopyWithTarInvalidSrc(t *testing.T) { if err != nil { t.Fatal(err) } - err = CopyWithTar(invalidSrc, destFolder) + err = defaultCopyWithTar(invalidSrc, destFolder) if err == nil { t.Fatalf("archiver.CopyWithTar with invalid src path should throw an error.") } } func TestCopyWithTarInexistentDestWillCreateIt(t *testing.T) { - tempFolder, err := ioutil.TempDir("", "docker-archive-test") + tempFolder, err := ioutil.TempDir("", "storage-archive-test") if err != nil { t.Fatal(nil) } @@ -449,7 +444,7 @@ func TestCopyWithTarInexistentDestWillCreateIt(t *testing.T) { if err != nil { t.Fatal(err) } - err = CopyWithTar(srcFolder, inexistentDestFolder) + err = defaultCopyWithTar(srcFolder, inexistentDestFolder) if err != nil { t.Fatalf("CopyWithTar with an inexistent folder shouldn't fail.") } @@ -461,7 +456,7 @@ func TestCopyWithTarInexistentDestWillCreateIt(t *testing.T) { // Test CopyWithTar with a file as src func TestCopyWithTarSrcFile(t *testing.T) { - folder, err := ioutil.TempDir("", "docker-archive-test") + folder, err := ioutil.TempDir("", "storage-archive-test") if err != nil { t.Fatal(err) } @@ -478,7 +473,7 @@ func TestCopyWithTarSrcFile(t *testing.T) { t.Fatal(err) } ioutil.WriteFile(src, []byte("content"), 0777) - err = CopyWithTar(src, dest) + err = defaultCopyWithTar(src, dest) if err != nil { t.Fatalf("archiver.CopyWithTar shouldn't throw an error, %s.", err) } @@ -491,7 +486,7 @@ func TestCopyWithTarSrcFile(t *testing.T) { // Test CopyWithTar with a folder as src func TestCopyWithTarSrcFolder(t *testing.T) { - folder, err := ioutil.TempDir("", "docker-archive-test") + folder, err := ioutil.TempDir("", "storage-archive-test") if err != nil { t.Fatal(err) } @@ -507,7 
+502,7 @@ func TestCopyWithTarSrcFolder(t *testing.T) { t.Fatal(err) } ioutil.WriteFile(filepath.Join(src, "file"), []byte("content"), 0777) - err = CopyWithTar(src, dest) + err = defaultCopyWithTar(src, dest) if err != nil { t.Fatalf("archiver.CopyWithTar shouldn't throw an error, %s.", err) } @@ -519,7 +514,7 @@ func TestCopyWithTarSrcFolder(t *testing.T) { } func TestCopyFileWithTarInvalidSrc(t *testing.T) { - tempFolder, err := ioutil.TempDir("", "docker-archive-test") + tempFolder, err := ioutil.TempDir("", "storage-archive-test") if err != nil { t.Fatal(err) } @@ -530,14 +525,14 @@ func TestCopyFileWithTarInvalidSrc(t *testing.T) { t.Fatal(err) } invalidFile := filepath.Join(tempFolder, "doesnotexists") - err = CopyFileWithTar(invalidFile, destFolder) + err = defaultCopyFileWithTar(invalidFile, destFolder) if err == nil { t.Fatalf("archiver.CopyWithTar with invalid src path should throw an error.") } } func TestCopyFileWithTarInexistentDestWillCreateIt(t *testing.T) { - tempFolder, err := ioutil.TempDir("", "docker-archive-test") + tempFolder, err := ioutil.TempDir("", "storage-archive-test") if err != nil { t.Fatal(nil) } @@ -548,7 +543,7 @@ func TestCopyFileWithTarInexistentDestWillCreateIt(t *testing.T) { if err != nil { t.Fatal(err) } - err = CopyFileWithTar(srcFile, inexistentDestFolder) + err = defaultCopyFileWithTar(srcFile, inexistentDestFolder) if err != nil { t.Fatalf("CopyWithTar with an inexistent folder shouldn't fail.") } @@ -560,7 +555,7 @@ func TestCopyFileWithTarInexistentDestWillCreateIt(t *testing.T) { } func TestCopyFileWithTarSrcFolder(t *testing.T) { - folder, err := ioutil.TempDir("", "docker-archive-copyfilewithtar-test") + folder, err := ioutil.TempDir("", "storage-archive-copyfilewithtar-test") if err != nil { t.Fatal(err) } @@ -575,14 +570,14 @@ func TestCopyFileWithTarSrcFolder(t *testing.T) { if err != nil { t.Fatal(err) } - err = CopyFileWithTar(src, dest) + err = defaultCopyFileWithTar(src, dest) if err == nil { 
t.Fatalf("CopyFileWithTar should throw an error with a folder.") } } func TestCopyFileWithTarSrcFile(t *testing.T) { - folder, err := ioutil.TempDir("", "docker-archive-test") + folder, err := ioutil.TempDir("", "storage-archive-test") if err != nil { t.Fatal(err) } @@ -599,7 +594,7 @@ func TestCopyFileWithTarSrcFile(t *testing.T) { t.Fatal(err) } ioutil.WriteFile(src, []byte("content"), 0777) - err = CopyWithTar(src, dest+"/") + err = defaultCopyWithTar(src, dest+"/") if err != nil { t.Fatalf("archiver.CopyFileWithTar shouldn't throw an error, %s.", err) } @@ -625,13 +620,13 @@ func TestTarFiles(t *testing.T) { } func checkNoChanges(fileNum int, hardlinks bool) error { - srcDir, err := ioutil.TempDir("", "docker-test-srcDir") + srcDir, err := ioutil.TempDir("", "storage-test-srcDir") if err != nil { return err } defer os.RemoveAll(srcDir) - destDir, err := ioutil.TempDir("", "docker-test-destDir") + destDir, err := ioutil.TempDir("", "storage-test-destDir") if err != nil { return err } @@ -642,7 +637,7 @@ func checkNoChanges(fileNum int, hardlinks bool) error { return err } - err = TarUntar(srcDir, destDir) + err = defaultTarUntar(srcDir, destDir) if err != nil { return err } @@ -676,7 +671,7 @@ func tarUntar(t *testing.T, origin string, options *TarOptions) ([]Change, error return nil, fmt.Errorf("Wrong compression detected. 
Actual compression: %s, found %s", compression.Extension(), detectedCompression.Extension()) } - tmp, err := ioutil.TempDir("", "docker-test-untar") + tmp, err := ioutil.TempDir("", "storage-test-untar") if err != nil { return nil, err } @@ -696,7 +691,7 @@ func TestTarUntar(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("Failing on Windows") } - origin, err := ioutil.TempDir("", "docker-test-untar-origin") + origin, err := ioutil.TempDir("", "storage-test-untar-origin") if err != nil { t.Fatal(err) } @@ -730,12 +725,63 @@ func TestTarUntar(t *testing.T) { } } +func TestTarWithOptionsChownOptsAlwaysOverridesIdPair(t *testing.T) { + origin, err := ioutil.TempDir("", "storage-test-tar-chown-opt") + require.NoError(t, err) + + defer os.RemoveAll(origin) + filePath := filepath.Join(origin, "1") + err = ioutil.WriteFile(filePath, []byte("hello world"), 0700) + require.NoError(t, err) + + idMaps := []idtools.IDMap{ + 0: { + ContainerID: 0, + HostID: 0, + Size: 65536, + }, + 1: { + ContainerID: 0, + HostID: 100000, + Size: 65536, + }, + } + + cases := []struct { + opts *TarOptions + expectedUID int + expectedGID int + }{ + {&TarOptions{ChownOpts: &idtools.IDPair{UID: 1337, GID: 42}}, 1337, 42}, + {&TarOptions{ChownOpts: &idtools.IDPair{UID: 100001, GID: 100001}, UIDMaps: idMaps, GIDMaps: idMaps}, 100001, 100001}, + {&TarOptions{ChownOpts: &idtools.IDPair{UID: 0, GID: 0}, NoLchown: false}, 0, 0}, + {&TarOptions{ChownOpts: &idtools.IDPair{UID: 1, GID: 1}, NoLchown: true}, 1, 1}, + {&TarOptions{ChownOpts: &idtools.IDPair{UID: 1000, GID: 1000}, NoLchown: true}, 1000, 1000}, + } + for _, testCase := range cases { + reader, err := TarWithOptions(filePath, testCase.opts) + require.NoError(t, err) + tr := tar.NewReader(reader) + defer reader.Close() + for { + hdr, err := tr.Next() + if err == io.EOF { + // end of tar archive + break + } + require.NoError(t, err) + assert.Equal(t, hdr.Uid, testCase.expectedUID, "Uid equals expected value") + assert.Equal(t, hdr.Gid, 
testCase.expectedGID, "Gid equals expected value") + } + } +} + func TestTarWithOptions(t *testing.T) { // TODO Windows: Figure out how to fix this test. if runtime.GOOS == "windows" { t.Skip("Failing on Windows") } - origin, err := ioutil.TempDir("", "docker-test-untar-origin") + origin, err := ioutil.TempDir("", "storage-test-untar-origin") if err != nil { t.Fatal(err) } @@ -777,12 +823,12 @@ func TestTarWithOptions(t *testing.T) { // Failing prevents the archives from being uncompressed during ADD func TestTypeXGlobalHeaderDoesNotFail(t *testing.T) { hdr := tar.Header{Typeflag: tar.TypeXGlobalHeader} - tmpDir, err := ioutil.TempDir("", "docker-test-archive-pax-test") + tmpDir, err := ioutil.TempDir("", "storage-test-archive-pax-test") if err != nil { t.Fatal(err) } defer os.RemoveAll(tmpDir) - err = createTarFile(filepath.Join(tmpDir, "pax_global_header"), tmpDir, &hdr, nil, true, nil) + err = createTarFile(filepath.Join(tmpDir, "pax_global_header"), tmpDir, &hdr, nil, true, nil, false) if err != nil { t.Fatal(err) } @@ -795,6 +841,8 @@ func TestUntarUstarGnuConflict(t *testing.T) { if err != nil { t.Fatal(err) } + defer f.Close() + found := false tr := tar.NewReader(f) // Iterate through the files in the archive. 
@@ -835,11 +883,11 @@ func prepareUntarSourceDirectory(numberOfFiles int, targetPath string, makeLinks } func BenchmarkTarUntar(b *testing.B) { - origin, err := ioutil.TempDir("", "docker-test-untar-origin") + origin, err := ioutil.TempDir("", "storage-test-untar-origin") if err != nil { b.Fatal(err) } - tempDir, err := ioutil.TempDir("", "docker-test-untar-destination") + tempDir, err := ioutil.TempDir("", "storage-test-untar-destination") if err != nil { b.Fatal(err) } @@ -854,7 +902,7 @@ func BenchmarkTarUntar(b *testing.B) { b.ResetTimer() b.SetBytes(int64(n)) for n := 0; n < b.N; n++ { - err := TarUntar(origin, target) + err := defaultTarUntar(origin, target) if err != nil { b.Fatal(err) } @@ -863,11 +911,11 @@ func BenchmarkTarUntar(b *testing.B) { } func BenchmarkTarUntarWithLinks(b *testing.B) { - origin, err := ioutil.TempDir("", "docker-test-untar-origin") + origin, err := ioutil.TempDir("", "storage-test-untar-origin") if err != nil { b.Fatal(err) } - tempDir, err := ioutil.TempDir("", "docker-test-untar-destination") + tempDir, err := ioutil.TempDir("", "storage-test-untar-destination") if err != nil { b.Fatal(err) } @@ -882,7 +930,7 @@ func BenchmarkTarUntarWithLinks(b *testing.B) { b.ResetTimer() b.SetBytes(int64(n)) for n := 0; n < b.N; n++ { - err := TarUntar(origin, target) + err := defaultTarUntar(origin, target) if err != nil { b.Fatal(err) } @@ -912,7 +960,7 @@ func TestUntarInvalidFilenames(t *testing.T) { }, }, } { - if err := testBreakout("untar", "docker-TestUntarInvalidFilenames", headers); err != nil { + if err := testBreakout("untar", "storage-TestUntarInvalidFilenames", headers); err != nil { t.Fatalf("i=%d. %v", i, err) } } @@ -944,7 +992,7 @@ func TestUntarHardlinkToSymlink(t *testing.T) { }, }, } { - if err := testBreakout("untar", "docker-TestUntarHardlinkToSymlink", headers); err != nil { + if err := testBreakout("untar", "storage-TestUntarHardlinkToSymlink", headers); err != nil { t.Fatalf("i=%d. 
%v", i, err) } } @@ -1028,7 +1076,7 @@ func TestUntarInvalidHardlink(t *testing.T) { }, }, } { - if err := testBreakout("untar", "docker-TestUntarInvalidHardlink", headers); err != nil { + if err := testBreakout("untar", "storage-TestUntarInvalidHardlink", headers); err != nil { t.Fatalf("i=%d. %v", i, err) } } @@ -1126,7 +1174,7 @@ func TestUntarInvalidSymlink(t *testing.T) { }, }, } { - if err := testBreakout("untar", "docker-TestUntarInvalidSymlink", headers); err != nil { + if err := testBreakout("untar", "storage-TestUntarInvalidSymlink", headers); err != nil { t.Fatalf("i=%d. %v", i, err) } } @@ -1146,3 +1194,132 @@ func TestTempArchiveCloseMultipleTimes(t *testing.T) { } } } + +func TestReplaceFileTarWrapper(t *testing.T) { + filesInArchive := 20 + testcases := []struct { + doc string + filename string + modifier TarModifierFunc + expected string + fileCount int + }{ + { + doc: "Modifier creates a new file", + filename: "newfile", + modifier: createModifier(t), + expected: "the new content", + fileCount: filesInArchive + 1, + }, + { + doc: "Modifier replaces a file", + filename: "file-2", + modifier: createOrReplaceModifier, + expected: "the new content", + fileCount: filesInArchive, + }, + { + doc: "Modifier replaces the last file", + filename: fmt.Sprintf("file-%d", filesInArchive-1), + modifier: createOrReplaceModifier, + expected: "the new content", + fileCount: filesInArchive, + }, + { + doc: "Modifier appends to a file", + filename: "file-3", + modifier: appendModifier, + expected: "fooo\nnext line", + fileCount: filesInArchive, + }, + } + + for _, testcase := range testcases { + sourceArchive, cleanup := buildSourceArchive(t, filesInArchive) + defer cleanup() + + resultArchive := ReplaceFileTarWrapper( + sourceArchive, + map[string]TarModifierFunc{testcase.filename: testcase.modifier}) + + actual := readFileFromArchive(t, resultArchive, testcase.filename, testcase.fileCount, testcase.doc) + assert.Equal(t, testcase.expected, actual, testcase.doc) + } 
+} + +/* +// TestPrefixHeaderReadable tests that files that could be created with the +// version of this package that was built with <=go17 are still readable. +func TestPrefixHeaderReadable(t *testing.T) { + // https://gist.github.com/stevvooe/e2a790ad4e97425896206c0816e1a882#file-out-go + var testFile = []byte("\x1f\x8b\x08\x08\x44\x21\x68\x59\x00\x03\x74\x2e\x74\x61\x72\x00\x4b\xcb\xcf\x67\xa0\x35\x30\x80\x00\x86\x06\x10\x47\x01\xc1\x37\x40\x00\x54\xb6\xb1\xa1\xa9\x99\x09\x48\x25\x1d\x40\x69\x71\x49\x62\x91\x02\xe5\x76\xa1\x79\x84\x21\x91\xd6\x80\x72\xaf\x8f\x82\x51\x30\x0a\x46\x36\x00\x00\xf0\x1c\x1e\x95\x00\x06\x00\x00") + + tmpDir, err := ioutil.TempDir("", "prefix-test") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + err = Untar(bytes.NewReader(testFile), tmpDir, nil) + require.NoError(t, err) + + baseName := "foo" + pth := strings.Repeat("a", 100-len(baseName)) + "/" + baseName + + _, err = os.Lstat(filepath.Join(tmpDir, pth)) + require.NoError(t, err) +} +*/ + +func buildSourceArchive(t *testing.T, numberOfFiles int) (io.ReadCloser, func()) { + srcDir, err := ioutil.TempDir("", "storage-test-srcDir") + require.NoError(t, err) + + _, err = prepareUntarSourceDirectory(numberOfFiles, srcDir, false) + require.NoError(t, err) + + sourceArchive, err := TarWithOptions(srcDir, &TarOptions{}) + require.NoError(t, err) + return sourceArchive, func() { + os.RemoveAll(srcDir) + sourceArchive.Close() + } +} + +func createOrReplaceModifier(path string, header *tar.Header, content io.Reader) (*tar.Header, []byte, error) { + return &tar.Header{ + Mode: 0600, + Typeflag: tar.TypeReg, + }, []byte("the new content"), nil +} + +func createModifier(t *testing.T) TarModifierFunc { + return func(path string, header *tar.Header, content io.Reader) (*tar.Header, []byte, error) { + assert.Nil(t, content) + return createOrReplaceModifier(path, header, content) + } +} + +func appendModifier(path string, header *tar.Header, content io.Reader) (*tar.Header, []byte, error) { 
+ buffer := bytes.Buffer{} + if content != nil { + if _, err := buffer.ReadFrom(content); err != nil { + return nil, nil, err + } + } + buffer.WriteString("\nnext line") + return &tar.Header{Mode: 0600, Typeflag: tar.TypeReg}, buffer.Bytes(), nil +} + +func readFileFromArchive(t *testing.T, archive io.ReadCloser, name string, expectedCount int, doc string) string { + destDir, err := ioutil.TempDir("", "storage-test-destDir") + require.NoError(t, err) + defer os.RemoveAll(destDir) + + err = Untar(archive, destDir, nil) + require.NoError(t, err) + + files, _ := ioutil.ReadDir(destDir) + assert.Len(t, files, expectedCount, doc) + + content, err := ioutil.ReadFile(filepath.Join(destDir, name)) + assert.NoError(t, err) + return string(content) +} diff --git a/pkg/archive/archive_unix.go b/pkg/archive/archive_unix.go index 19d731fd23..bdc1a3d794 100644 --- a/pkg/archive/archive_unix.go +++ b/pkg/archive/archive_unix.go @@ -9,7 +9,10 @@ import ( "path/filepath" "syscall" + "github.com/containers/storage/pkg/idtools" "github.com/containers/storage/pkg/system" + rsystem "github.com/opencontainers/runc/libcontainer/system" + "golang.org/x/sys/unix" ) // fixVolumePathPrefix does platform specific processing to ensure that if @@ -40,33 +43,38 @@ func chmodTarEntry(perm os.FileMode) os.FileMode { return perm // noop for unix as golang APIs provide perm bits correctly } -func setHeaderForSpecialDevice(hdr *tar.Header, ta *tarAppender, name string, stat interface{}) (inode uint64, err error) { +func setHeaderForSpecialDevice(hdr *tar.Header, name string, stat interface{}) (err error) { s, ok := stat.(*syscall.Stat_t) - if !ok { - err = errors.New("cannot convert stat value to syscall.Stat_t") - return + if ok { + // Currently go does not fill in the major/minors + if s.Mode&unix.S_IFBLK != 0 || + s.Mode&unix.S_IFCHR != 0 { + hdr.Devmajor = int64(major(uint64(s.Rdev))) // nolint: unconvert + hdr.Devminor = int64(minor(uint64(s.Rdev))) // nolint: unconvert + } } - inode = 
uint64(s.Ino) + return +} + +func getInodeFromStat(stat interface{}) (inode uint64, err error) { + s, ok := stat.(*syscall.Stat_t) - // Currently go does not fill in the major/minors - if s.Mode&syscall.S_IFBLK != 0 || - s.Mode&syscall.S_IFCHR != 0 { - hdr.Devmajor = int64(major(uint64(s.Rdev))) - hdr.Devminor = int64(minor(uint64(s.Rdev))) + if ok { + inode = s.Ino } return } -func getFileUIDGID(stat interface{}) (int, int, error) { +func getFileUIDGID(stat interface{}) (idtools.IDPair, error) { s, ok := stat.(*syscall.Stat_t) if !ok { - return -1, -1, errors.New("cannot convert stat value to syscall.Stat_t") + return idtools.IDPair{}, errors.New("cannot convert stat value to syscall.Stat_t") } - return int(s.Uid), int(s.Gid), nil + return idtools.IDPair{UID: int(s.Uid), GID: int(s.Gid)}, nil } func major(device uint64) uint64 { @@ -80,20 +88,22 @@ func minor(device uint64) uint64 { // handleTarTypeBlockCharFifo is an OS-specific helper function used by // createTarFile to handle the following types of header: Block; Char; Fifo func handleTarTypeBlockCharFifo(hdr *tar.Header, path string) error { + if rsystem.RunningInUserNS() { + // cannot create a device if running in user namespace + return nil + } + mode := uint32(hdr.Mode & 07777) switch hdr.Typeflag { case tar.TypeBlock: - mode |= syscall.S_IFBLK + mode |= unix.S_IFBLK case tar.TypeChar: - mode |= syscall.S_IFCHR + mode |= unix.S_IFCHR case tar.TypeFifo: - mode |= syscall.S_IFIFO + mode |= unix.S_IFIFO } - if err := system.Mknod(path, mode, int(system.Mkdev(hdr.Devmajor, hdr.Devminor))); err != nil { - return err - } - return nil + return system.Mknod(path, mode, int(system.Mkdev(hdr.Devmajor, hdr.Devminor))) } func handleLChmod(hdr *tar.Header, path string, hdrInfo os.FileInfo) error { diff --git a/pkg/archive/archive_unix_test.go b/pkg/archive/archive_unix_test.go index 4bf0ae2dfa..095a909317 100644 --- a/pkg/archive/archive_unix_test.go +++ b/pkg/archive/archive_unix_test.go @@ -8,10 +8,14 @@ import ( 
"io/ioutil" "os" "path/filepath" + "runtime" "syscall" "testing" "github.com/containers/storage/pkg/system" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sys/unix" ) func TestCanonicalTarNameForPath(t *testing.T) { @@ -67,61 +71,90 @@ func TestChmodTarEntry(t *testing.T) { } func TestTarWithHardLink(t *testing.T) { - origin, err := ioutil.TempDir("", "docker-test-tar-hardlink") - if err != nil { - t.Fatal(err) - } + origin, err := ioutil.TempDir("", "storage-test-tar-hardlink") + require.NoError(t, err) defer os.RemoveAll(origin) - if err := ioutil.WriteFile(filepath.Join(origin, "1"), []byte("hello world"), 0700); err != nil { - t.Fatal(err) - } - if err := os.Link(filepath.Join(origin, "1"), filepath.Join(origin, "2")); err != nil { - t.Fatal(err) - } + + err = ioutil.WriteFile(filepath.Join(origin, "1"), []byte("hello world"), 0700) + require.NoError(t, err) + + err = os.Link(filepath.Join(origin, "1"), filepath.Join(origin, "2")) + require.NoError(t, err) var i1, i2 uint64 - if i1, err = getNlink(filepath.Join(origin, "1")); err != nil { - t.Fatal(err) - } + i1, err = getNlink(filepath.Join(origin, "1")) + require.NoError(t, err) + // sanity check that we can hardlink if i1 != 2 { t.Skipf("skipping since hardlinks don't work here; expected 2 links, got %d", i1) } - dest, err := ioutil.TempDir("", "docker-test-tar-hardlink-dest") - if err != nil { - t.Fatal(err) - } + dest, err := ioutil.TempDir("", "storage-test-tar-hardlink-dest") + require.NoError(t, err) defer os.RemoveAll(dest) // we'll do this in two steps to separate failure fh, err := Tar(origin, Uncompressed) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // ensure we can read the whole thing with no error, before writing back out buf, err := ioutil.ReadAll(fh) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) bRdr := bytes.NewReader(buf) err = Untar(bRdr, dest, &TarOptions{Compression: Uncompressed}) - if err != nil { - 
t.Fatal(err) - } + require.NoError(t, err) - if i1, err = getInode(filepath.Join(dest, "1")); err != nil { - t.Fatal(err) - } - if i2, err = getInode(filepath.Join(dest, "2")); err != nil { - t.Fatal(err) - } + i1, err = getInode(filepath.Join(dest, "1")) + require.NoError(t, err) + + i2, err = getInode(filepath.Join(dest, "2")) + require.NoError(t, err) + + assert.Equal(t, i1, i2) +} + +func TestTarWithHardLinkAndRebase(t *testing.T) { + tmpDir, err := ioutil.TempDir("", "storage-test-tar-hardlink-rebase") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + origin := filepath.Join(tmpDir, "origin") + err = os.Mkdir(origin, 0700) + require.NoError(t, err) - if i1 != i2 { - t.Errorf("expected matching inodes, but got %d and %d", i1, i2) + err = ioutil.WriteFile(filepath.Join(origin, "1"), []byte("hello world"), 0700) + require.NoError(t, err) + + err = os.Link(filepath.Join(origin, "1"), filepath.Join(origin, "2")) + require.NoError(t, err) + + var i1, i2 uint64 + i1, err = getNlink(filepath.Join(origin, "1")) + require.NoError(t, err) + + // sanity check that we can hardlink + if i1 != 2 { + t.Skipf("skipping since hardlinks don't work here; expected 2 links, got %d", i1) } + + dest := filepath.Join(tmpDir, "dest") + bRdr, err := TarResourceRebase(origin, "origin") + require.NoError(t, err) + + dstDir, srcBase := SplitPathDirEntry(origin) + _, dstBase := SplitPathDirEntry(dest) + content := RebaseArchiveEntries(bRdr, srcBase, dstBase) + err = Untar(content, dstDir, &TarOptions{Compression: Uncompressed, NoLchown: true, NoOverwriteDirNonDir: true}) + require.NoError(t, err) + + i1, err = getInode(filepath.Join(dest, "1")) + require.NoError(t, err) + i2, err = getInode(filepath.Join(dest, "2")) + require.NoError(t, err) + + assert.Equal(t, i1, i2) } func getNlink(path string) (uint64, error) { @@ -150,52 +183,39 @@ func getInode(path string) (uint64, error) { } func TestTarWithBlockCharFifo(t *testing.T) { - origin, err := ioutil.TempDir("", 
"docker-test-tar-hardlink") - if err != nil { - t.Fatal(err) - } + origin, err := ioutil.TempDir("", "storage-test-tar-hardlink") + require.NoError(t, err) + defer os.RemoveAll(origin) - if err := ioutil.WriteFile(filepath.Join(origin, "1"), []byte("hello world"), 0700); err != nil { - t.Fatal(err) - } - if err := system.Mknod(filepath.Join(origin, "2"), syscall.S_IFBLK, int(system.Mkdev(int64(12), int64(5)))); err != nil { - t.Fatal(err) - } - if err := system.Mknod(filepath.Join(origin, "3"), syscall.S_IFCHR, int(system.Mkdev(int64(12), int64(5)))); err != nil { - t.Fatal(err) - } - if err := system.Mknod(filepath.Join(origin, "4"), syscall.S_IFIFO, int(system.Mkdev(int64(12), int64(5)))); err != nil { - t.Fatal(err) - } + err = ioutil.WriteFile(filepath.Join(origin, "1"), []byte("hello world"), 0700) + require.NoError(t, err) - dest, err := ioutil.TempDir("", "docker-test-tar-hardlink-dest") - if err != nil { - t.Fatal(err) - } + err = system.Mknod(filepath.Join(origin, "2"), unix.S_IFBLK, int(system.Mkdev(int64(12), int64(5)))) + require.NoError(t, err) + err = system.Mknod(filepath.Join(origin, "3"), unix.S_IFCHR, int(system.Mkdev(int64(12), int64(5)))) + require.NoError(t, err) + err = system.Mknod(filepath.Join(origin, "4"), unix.S_IFIFO, int(system.Mkdev(int64(12), int64(5)))) + require.NoError(t, err) + + dest, err := ioutil.TempDir("", "storage-test-tar-hardlink-dest") + require.NoError(t, err) defer os.RemoveAll(dest) // we'll do this in two steps to separate failure fh, err := Tar(origin, Uncompressed) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // ensure we can read the whole thing with no error, before writing back out buf, err := ioutil.ReadAll(fh) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) bRdr := bytes.NewReader(buf) err = Untar(bRdr, dest, &TarOptions{Compression: Uncompressed}) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) changes, err := ChangesDirs(origin, dest) - if err != nil { - 
t.Fatal(err) - } + require.NoError(t, err) + if len(changes) > 0 { t.Fatalf("Tar with special device (block, char, fifo) should keep them (recreate them when untar) : %v", changes) } @@ -203,23 +223,21 @@ func TestTarWithBlockCharFifo(t *testing.T) { // TestTarUntarWithXattr is Unix as Lsetxattr is not supported on Windows func TestTarUntarWithXattr(t *testing.T) { - origin, err := ioutil.TempDir("", "docker-test-untar-origin") - if err != nil { - t.Fatal(err) + if runtime.GOOS == "solaris" { + t.Skip() } + origin, err := ioutil.TempDir("", "storage-test-untar-origin") + require.NoError(t, err) defer os.RemoveAll(origin) - if err := ioutil.WriteFile(filepath.Join(origin, "1"), []byte("hello world"), 0700); err != nil { - t.Fatal(err) - } - if err := ioutil.WriteFile(filepath.Join(origin, "2"), []byte("welcome!"), 0700); err != nil { - t.Fatal(err) - } - if err := ioutil.WriteFile(filepath.Join(origin, "3"), []byte("will be ignored"), 0700); err != nil { - t.Fatal(err) - } - if err := system.Lsetxattr(filepath.Join(origin, "2"), "security.capability", []byte{0x00}, 0); err != nil { - t.Fatal(err) - } + err = ioutil.WriteFile(filepath.Join(origin, "1"), []byte("hello world"), 0700) + require.NoError(t, err) + + err = ioutil.WriteFile(filepath.Join(origin, "2"), []byte("welcome!"), 0700) + require.NoError(t, err) + err = ioutil.WriteFile(filepath.Join(origin, "3"), []byte("will be ignored"), 0700) + require.NoError(t, err) + err = system.Lsetxattr(filepath.Join(origin, "2"), "security.capability", []byte{0x00}, 0) + require.NoError(t, err) for _, c := range []Compression{ Uncompressed, diff --git a/pkg/archive/archive_windows.go b/pkg/archive/archive_windows.go index 828d3b9d0b..0bcbb925d2 100644 --- a/pkg/archive/archive_windows.go +++ b/pkg/archive/archive_windows.go @@ -9,6 +9,7 @@ import ( "path/filepath" "strings" + "github.com/containers/storage/pkg/idtools" "github.com/containers/storage/pkg/longpath" ) @@ -42,15 +43,23 @@ func CanonicalTarNameForPath(p string) 
(string, error) { // chmodTarEntry is used to adjust the file permissions used in tar header based // on the platform the archival is done. func chmodTarEntry(perm os.FileMode) os.FileMode { - perm &= 0755 + //perm &= 0755 // this 0-ed out tar flags (like link, regular file, directory marker etc.) + permPart := perm & os.ModePerm + noPermPart := perm &^ os.ModePerm // Add the x bit: make everything +x from windows - perm |= 0111 + permPart |= 0111 + permPart &= 0755 - return perm + return noPermPart | permPart } -func setHeaderForSpecialDevice(hdr *tar.Header, ta *tarAppender, name string, stat interface{}) (inode uint64, err error) { - // do nothing. no notion of Rdev, Inode, Nlink in stat on Windows +func setHeaderForSpecialDevice(hdr *tar.Header, name string, stat interface{}) (err error) { + // do nothing. no notion of Rdev, Nlink in stat on Windows + return +} + +func getInodeFromStat(stat interface{}) (inode uint64, err error) { + // do nothing. no notion of Inode in stat on Windows return } @@ -64,7 +73,7 @@ func handleLChmod(hdr *tar.Header, path string, hdrInfo os.FileInfo) error { return nil } -func getFileUIDGID(stat interface{}) (int, int, error) { +func getFileUIDGID(stat interface{}) (idtools.IDPair, error) { // no notion of file ownership mapping yet on Windows - return 0, 0, nil + return idtools.IDPair{0, 0}, nil } diff --git a/pkg/archive/archive_windows_test.go b/pkg/archive/archive_windows_test.go index 0c6733d6bd..849d99c2d3 100644 --- a/pkg/archive/archive_windows_test.go +++ b/pkg/archive/archive_windows_test.go @@ -14,7 +14,7 @@ func TestCopyFileWithInvalidDest(t *testing.T) { // recently changed in CopyWithTar as used to pass. Further investigation // is required. 
t.Skip("Currently fails") - folder, err := ioutil.TempDir("", "docker-archive-test") + folder, err := ioutil.TempDir("", "storage-archive-test") if err != nil { t.Fatal(err) } @@ -27,7 +27,7 @@ func TestCopyFileWithInvalidDest(t *testing.T) { t.Fatal(err) } ioutil.WriteFile(src, []byte("content"), 0777) - err = CopyWithTar(src, dest) + err = defaultCopyWithTar(src, dest) if err == nil { t.Fatalf("archiver.CopyWithTar should throw an error on invalid dest.") } @@ -82,6 +82,8 @@ func TestChmodTarEntry(t *testing.T) { {0644, 0755}, {0755, 0755}, {0444, 0555}, + {0755 | os.ModeDir, 0755 | os.ModeDir}, + {0755 | os.ModeSymlink, 0755 | os.ModeSymlink}, } for _, v := range cases { if out := chmodTarEntry(v.in); out != v.expected { diff --git a/pkg/archive/changes.go b/pkg/archive/changes.go index 488e12989a..6ba4b8ec6f 100644 --- a/pkg/archive/changes.go +++ b/pkg/archive/changes.go @@ -267,7 +267,7 @@ func (info *FileInfo) addChanges(oldInfo *FileInfo, changes *[]Change) { } for name, newChild := range info.children { - oldChild, _ := oldChildren[name] + oldChild := oldChildren[name] if oldChild != nil { // change? oldStat := oldChild.stat @@ -279,7 +279,7 @@ func (info *FileInfo) addChanges(oldInfo *FileInfo, changes *[]Change) { // breaks down is if some code intentionally hides a change by setting // back mtime if statDifferent(oldStat, newStat) || - bytes.Compare(oldChild.capability, newChild.capability) != 0 { + !bytes.Equal(oldChild.capability, newChild.capability) { change := Change{ Path: newChild.path(), Kind: ChangeModify, @@ -391,16 +391,11 @@ func ChangesSize(newDir string, changes []Change) int64 { } // ExportChanges produces an Archive from the provided changes, relative to dir. 
-func ExportChanges(dir string, changes []Change, uidMaps, gidMaps []idtools.IDMap) (Archive, error) { +func ExportChanges(dir string, changes []Change, uidMaps, gidMaps []idtools.IDMap) (io.ReadCloser, error) { reader, writer := io.Pipe() go func() { - ta := &tarAppender{ - TarWriter: tar.NewWriter(writer), - Buffer: pools.BufioWriter32KPool.Get(nil), - SeenFiles: make(map[uint64]string), - UIDMaps: uidMaps, - GIDMaps: gidMaps, - } + ta := newTarAppender(idtools.NewIDMappingsFromMaps(uidMaps, gidMaps), writer, nil) + // this buffer is needed for the duration of this piped stream defer pools.BufioWriter32KPool.Put(ta.Buffer) diff --git a/pkg/archive/changes_linux.go b/pkg/archive/changes_linux.go index 798d7bfccb..90c9a627e5 100644 --- a/pkg/archive/changes_linux.go +++ b/pkg/archive/changes_linux.go @@ -10,6 +10,7 @@ import ( "unsafe" "github.com/containers/storage/pkg/system" + "golang.org/x/sys/unix" ) // walker is used to implement collectFileInfoForChanges on linux. Where this @@ -65,7 +66,7 @@ func walkchunk(path string, fi os.FileInfo, dir string, root *FileInfo) error { } parent := root.LookUp(filepath.Dir(path)) if parent == nil { - return fmt.Errorf("collectFileInfoForChanges: Unexpectedly no parent for %s", path) + return fmt.Errorf("walkchunk: Unexpectedly no parent for %s", path) } info := &FileInfo{ name: filepath.Base(path), @@ -233,7 +234,7 @@ func readdirnames(dirname string) (names []nameIno, err error) { // Refill the buffer if necessary if bufp >= nbuf { bufp = 0 - nbuf, err = syscall.ReadDirent(int(f.Fd()), buf) // getdents on linux + nbuf, err = unix.ReadDirent(int(f.Fd()), buf) // getdents on linux if nbuf < 0 { nbuf = 0 } @@ -255,12 +256,12 @@ func readdirnames(dirname string) (names []nameIno, err error) { return sl, nil } -// parseDirent is a minor modification of syscall.ParseDirent (linux version) +// parseDirent is a minor modification of unix.ParseDirent (linux version) // which returns {name,inode} pairs instead of just names. 
func parseDirent(buf []byte, names []nameIno) (consumed int, newnames []nameIno) { origlen := len(buf) for len(buf) > 0 { - dirent := (*syscall.Dirent)(unsafe.Pointer(&buf[0])) + dirent := (*unix.Dirent)(unsafe.Pointer(&buf[0])) buf = buf[dirent.Reclen:] if dirent.Ino == 0 { // File absent in directory. continue @@ -293,7 +294,7 @@ func OverlayChanges(layers []string, rw string) ([]Change, error) { func overlayDeletedFile(root, path string, fi os.FileInfo) (string, error) { if fi.Mode()&os.ModeCharDevice != 0 { s := fi.Sys().(*syscall.Stat_t) - if major(uint64(s.Rdev)) == 0 && minor(uint64(s.Rdev)) == 0 { + if major(s.Rdev) == 0 && minor(s.Rdev) == 0 { return path, nil } } @@ -302,7 +303,7 @@ func overlayDeletedFile(root, path string, fi os.FileInfo) (string, error) { if err != nil { return "", err } - if opaque != nil && len(opaque) == 1 && opaque[0] == 'y' { + if len(opaque) == 1 && opaque[0] == 'y' { return path, nil } } diff --git a/pkg/archive/changes_posix_test.go b/pkg/archive/changes_posix_test.go index 5a3282b5a8..9d6defc565 100644 --- a/pkg/archive/changes_posix_test.go +++ b/pkg/archive/changes_posix_test.go @@ -7,20 +7,25 @@ import ( "io/ioutil" "os" "path" + "runtime" "sort" "testing" ) func TestHardLinkOrder(t *testing.T) { + //TODO Should run for Solaris + if runtime.GOOS == "solaris" { + t.Skip("gcp failures on Solaris") + } names := []string{"file1.txt", "file2.txt", "file3.txt"} msg := []byte("Hey y'all") // Create dir - src, err := ioutil.TempDir("", "docker-hardlink-test-src-") + src, err := ioutil.TempDir("", "storage-hardlink-test-src-") if err != nil { t.Fatal(err) } - //defer os.RemoveAll(src) + defer os.RemoveAll(src) for _, name := range names { func() { fh, err := os.Create(path.Join(src, name)) @@ -34,7 +39,7 @@ func TestHardLinkOrder(t *testing.T) { }() } // Create dest, with changes that includes hardlinks - dest, err := ioutil.TempDir("", "docker-hardlink-test-dest-") + dest, err := ioutil.TempDir("", "storage-hardlink-test-dest-") if 
err != nil { t.Fatal(err) } diff --git a/pkg/archive/changes_test.go b/pkg/archive/changes_test.go index f19c9fd148..c92b6ee38d 100644 --- a/pkg/archive/changes_test.go +++ b/pkg/archive/changes_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/containers/storage/pkg/system" + "github.com/stretchr/testify/require" ) func max(x, y int) int { @@ -22,6 +23,10 @@ func max(x, y int) int { func copyDir(src, dst string) error { cmd := exec.Command("cp", "-a", src, dst) + if runtime.GOOS == "solaris" { + cmd = exec.Command("gcp", "-a", src, dst) + } + if err := cmd.Run(); err != nil { return err } @@ -75,33 +80,29 @@ func createSampleDir(t *testing.T, root string) { for _, info := range files { p := path.Join(root, info.path) if info.filetype == Dir { - if err := os.MkdirAll(p, info.permissions); err != nil { - t.Fatal(err) - } + err := os.MkdirAll(p, info.permissions) + require.NoError(t, err) } else if info.filetype == Regular { - if err := ioutil.WriteFile(p, []byte(info.contents), info.permissions); err != nil { - t.Fatal(err) - } + err := ioutil.WriteFile(p, []byte(info.contents), info.permissions) + require.NoError(t, err) } else if info.filetype == Symlink { - if err := os.Symlink(info.contents, p); err != nil { - t.Fatal(err) - } + err := os.Symlink(info.contents, p) + require.NoError(t, err) } if info.filetype != Symlink { // Set a consistent ctime, atime for all files and dirs - if err := system.Chtimes(p, now, now); err != nil { - t.Fatal(err) - } + err := system.Chtimes(p, now, now) + require.NoError(t, err) } } } func TestChangeString(t *testing.T) { - modifiyChange := Change{"change", ChangeModify} - toString := modifiyChange.String() + modifyChange := Change{"change", ChangeModify} + toString := modifyChange.String() if toString != "C change" { - t.Fatalf("String() of a change with ChangeModifiy Kind should have been %s but was %s", "C change", toString) + t.Fatalf("String() of a change with ChangeModify Kind should have been %s but was %s", "C change", 
toString) } addChange := Change{"change", ChangeAdd} toString = addChange.String() @@ -121,21 +122,15 @@ func TestChangesWithNoChanges(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("symlinks on Windows") } - rwLayer, err := ioutil.TempDir("", "docker-changes-test") - if err != nil { - t.Fatal(err) - } + rwLayer, err := ioutil.TempDir("", "storage-changes-test") + require.NoError(t, err) defer os.RemoveAll(rwLayer) - layer, err := ioutil.TempDir("", "docker-changes-test-layer") - if err != nil { - t.Fatal(err) - } + layer, err := ioutil.TempDir("", "storage-changes-test-layer") + require.NoError(t, err) defer os.RemoveAll(layer) createSampleDir(t, layer) changes, err := Changes([]string{layer}, rwLayer) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if len(changes) != 0 { t.Fatalf("Changes with no difference should have detect no changes, but detected %d", len(changes)) } @@ -148,19 +143,15 @@ func TestChangesWithChanges(t *testing.T) { t.Skip("symlinks on Windows") } // Mock the readonly layer - layer, err := ioutil.TempDir("", "docker-changes-test-layer") - if err != nil { - t.Fatal(err) - } + layer, err := ioutil.TempDir("", "storage-changes-test-layer") + require.NoError(t, err) defer os.RemoveAll(layer) createSampleDir(t, layer) os.MkdirAll(path.Join(layer, "dir1/subfolder"), 0740) // Mock the RW layer - rwLayer, err := ioutil.TempDir("", "docker-changes-test") - if err != nil { - t.Fatal(err) - } + rwLayer, err := ioutil.TempDir("", "storage-changes-test") + require.NoError(t, err) defer os.RemoveAll(rwLayer) // Create a folder in RW layer @@ -177,9 +168,7 @@ func TestChangesWithChanges(t *testing.T) { ioutil.WriteFile(newFile, []byte{}, 0740) changes, err := Changes([]string{layer}, rwLayer) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) expectedChanges := []Change{ {"/dir1", ChangeModify}, @@ -198,7 +187,7 @@ func TestChangesWithChangesGH13590(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("symlinks on Windows") 
} - baseLayer, err := ioutil.TempDir("", "docker-changes-test.") + baseLayer, err := ioutil.TempDir("", "storage-changes-test.") defer os.RemoveAll(baseLayer) dir3 := path.Join(baseLayer, "dir1/dir2/dir3") @@ -207,7 +196,7 @@ func TestChangesWithChangesGH13590(t *testing.T) { file := path.Join(dir3, "file.txt") ioutil.WriteFile(file, []byte("hello"), 0666) - layer, err := ioutil.TempDir("", "docker-changes-test2.") + layer, err := ioutil.TempDir("", "storage-changes-test2.") defer os.RemoveAll(layer) // Test creating a new file @@ -220,9 +209,7 @@ func TestChangesWithChangesGH13590(t *testing.T) { ioutil.WriteFile(file, []byte("bye"), 0666) changes, err := Changes([]string{baseLayer}, layer) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) expectedChanges := []Change{ {"/dir1/dir2/dir3", ChangeModify}, @@ -231,7 +218,7 @@ func TestChangesWithChangesGH13590(t *testing.T) { checkChanges(expectedChanges, changes, t) // Now test changing a file - layer, err = ioutil.TempDir("", "docker-changes-test3.") + layer, err = ioutil.TempDir("", "storage-changes-test3.") defer os.RemoveAll(layer) if err := copyDir(baseLayer+"/dir1", layer+"/"); err != nil { @@ -242,9 +229,7 @@ func TestChangesWithChangesGH13590(t *testing.T) { ioutil.WriteFile(file, []byte("bye"), 0666) changes, err = Changes([]string{baseLayer}, layer) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) expectedChanges = []Change{ {"/dir1/dir2/dir3/file.txt", ChangeModify}, @@ -256,24 +241,20 @@ func TestChangesWithChangesGH13590(t *testing.T) { func TestChangesDirsEmpty(t *testing.T) { // TODO Windows. There may be a way of running this, but turning off for now // as createSampleDir uses symlinks. 
- if runtime.GOOS == "windows" { - t.Skip("symlinks on Windows") - } - src, err := ioutil.TempDir("", "docker-changes-test") - if err != nil { - t.Fatal(err) + // TODO Should work for Solaris + if runtime.GOOS == "windows" || runtime.GOOS == "solaris" { + t.Skip("symlinks on Windows; gcp failure on Solaris") } + src, err := ioutil.TempDir("", "storage-changes-test") + require.NoError(t, err) defer os.RemoveAll(src) createSampleDir(t, src) dst := src + "-copy" - if err := copyDir(src, dst); err != nil { - t.Fatal(err) - } + err = copyDir(src, dst) + require.NoError(t, err) defer os.RemoveAll(dst) changes, err := ChangesDirs(dst, src) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if len(changes) != 0 { t.Fatalf("Reported changes for identical dirs: %v", changes) @@ -284,107 +265,87 @@ func TestChangesDirsEmpty(t *testing.T) { func mutateSampleDir(t *testing.T, root string) { // Remove a regular file - if err := os.RemoveAll(path.Join(root, "file1")); err != nil { - t.Fatal(err) - } + err := os.RemoveAll(path.Join(root, "file1")) + require.NoError(t, err) // Remove a directory - if err := os.RemoveAll(path.Join(root, "dir1")); err != nil { - t.Fatal(err) - } + err = os.RemoveAll(path.Join(root, "dir1")) + require.NoError(t, err) // Remove a symlink - if err := os.RemoveAll(path.Join(root, "symlink1")); err != nil { - t.Fatal(err) - } + err = os.RemoveAll(path.Join(root, "symlink1")) + require.NoError(t, err) // Rewrite a file - if err := ioutil.WriteFile(path.Join(root, "file2"), []byte("fileNN\n"), 0777); err != nil { - t.Fatal(err) - } + err = ioutil.WriteFile(path.Join(root, "file2"), []byte("fileNN\n"), 0777) + require.NoError(t, err) // Replace a file - if err := os.RemoveAll(path.Join(root, "file3")); err != nil { - t.Fatal(err) - } - if err := ioutil.WriteFile(path.Join(root, "file3"), []byte("fileMM\n"), 0404); err != nil { - t.Fatal(err) - } + err = os.RemoveAll(path.Join(root, "file3")) + require.NoError(t, err) + err = 
ioutil.WriteFile(path.Join(root, "file3"), []byte("fileMM\n"), 0404) + require.NoError(t, err) // Touch file - if err := system.Chtimes(path.Join(root, "file4"), time.Now().Add(time.Second), time.Now().Add(time.Second)); err != nil { - t.Fatal(err) - } + err = system.Chtimes(path.Join(root, "file4"), time.Now().Add(time.Second), time.Now().Add(time.Second)) + require.NoError(t, err) // Replace file with dir - if err := os.RemoveAll(path.Join(root, "file5")); err != nil { - t.Fatal(err) - } - if err := os.MkdirAll(path.Join(root, "file5"), 0666); err != nil { - t.Fatal(err) - } + err = os.RemoveAll(path.Join(root, "file5")) + require.NoError(t, err) + err = os.MkdirAll(path.Join(root, "file5"), 0666) + require.NoError(t, err) // Create new file - if err := ioutil.WriteFile(path.Join(root, "filenew"), []byte("filenew\n"), 0777); err != nil { - t.Fatal(err) - } + err = ioutil.WriteFile(path.Join(root, "filenew"), []byte("filenew\n"), 0777) + require.NoError(t, err) // Create new dir - if err := os.MkdirAll(path.Join(root, "dirnew"), 0766); err != nil { - t.Fatal(err) - } + err = os.MkdirAll(path.Join(root, "dirnew"), 0766) + require.NoError(t, err) // Create a new symlink - if err := os.Symlink("targetnew", path.Join(root, "symlinknew")); err != nil { - t.Fatal(err) - } + err = os.Symlink("targetnew", path.Join(root, "symlinknew")) + require.NoError(t, err) // Change a symlink - if err := os.RemoveAll(path.Join(root, "symlink2")); err != nil { - t.Fatal(err) - } - if err := os.Symlink("target2change", path.Join(root, "symlink2")); err != nil { - t.Fatal(err) - } + err = os.RemoveAll(path.Join(root, "symlink2")) + require.NoError(t, err) + + err = os.Symlink("target2change", path.Join(root, "symlink2")) + require.NoError(t, err) // Replace dir with file - if err := os.RemoveAll(path.Join(root, "dir2")); err != nil { - t.Fatal(err) - } - if err := ioutil.WriteFile(path.Join(root, "dir2"), []byte("dir2\n"), 0777); err != nil { - t.Fatal(err) - } + err = 
os.RemoveAll(path.Join(root, "dir2")) + require.NoError(t, err) + err = ioutil.WriteFile(path.Join(root, "dir2"), []byte("dir2\n"), 0777) + require.NoError(t, err) // Touch dir - if err := system.Chtimes(path.Join(root, "dir3"), time.Now().Add(time.Second), time.Now().Add(time.Second)); err != nil { - t.Fatal(err) - } + err = system.Chtimes(path.Join(root, "dir3"), time.Now().Add(time.Second), time.Now().Add(time.Second)) + require.NoError(t, err) } func TestChangesDirsMutated(t *testing.T) { // TODO Windows. There may be a way of running this, but turning off for now // as createSampleDir uses symlinks. - if runtime.GOOS == "windows" { - t.Skip("symlinks on Windows") - } - src, err := ioutil.TempDir("", "docker-changes-test") - if err != nil { - t.Fatal(err) + // TODO Should work for Solaris + if runtime.GOOS == "windows" || runtime.GOOS == "solaris" { + t.Skip("symlinks on Windows; gcp failures on Solaris") } + src, err := ioutil.TempDir("", "storage-changes-test") + require.NoError(t, err) createSampleDir(t, src) dst := src + "-copy" - if err := copyDir(src, dst); err != nil { - t.Fatal(err) - } + err = copyDir(src, dst) + require.NoError(t, err) defer os.RemoveAll(src) defer os.RemoveAll(dst) mutateSampleDir(t, dst) changes, err := ChangesDirs(dst, src) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) sort.Sort(changesByPath(changes)) @@ -425,45 +386,34 @@ func TestChangesDirsMutated(t *testing.T) { func TestApplyLayer(t *testing.T) { // TODO Windows. There may be a way of running this, but turning off for now // as createSampleDir uses symlinks. 
- if runtime.GOOS == "windows" { - t.Skip("symlinks on Windows") - } - src, err := ioutil.TempDir("", "docker-changes-test") - if err != nil { - t.Fatal(err) + // TODO Should work for Solaris + if runtime.GOOS == "windows" || runtime.GOOS == "solaris" { + t.Skip("symlinks on Windows; gcp failures on Solaris") } + src, err := ioutil.TempDir("", "storage-changes-test") + require.NoError(t, err) createSampleDir(t, src) defer os.RemoveAll(src) dst := src + "-copy" - if err := copyDir(src, dst); err != nil { - t.Fatal(err) - } + err = copyDir(src, dst) + require.NoError(t, err) mutateSampleDir(t, dst) defer os.RemoveAll(dst) changes, err := ChangesDirs(dst, src) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) layer, err := ExportChanges(dst, changes, nil, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) layerCopy, err := NewTempArchive(layer, "") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - if _, err := ApplyLayer(src, layerCopy); err != nil { - t.Fatal(err) - } + _, err = ApplyLayer(src, layerCopy) + require.NoError(t, err) changes2, err := ChangesDirs(src, dst) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if len(changes2) != 0 { t.Fatalf("Unexpected differences after reapplying mutation: %v", changes2) @@ -476,27 +426,19 @@ func TestChangesSizeWithHardlinks(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("hardlinks on Windows") } - srcDir, err := ioutil.TempDir("", "docker-test-srcDir") - if err != nil { - t.Fatal(err) - } + srcDir, err := ioutil.TempDir("", "storage-test-srcDir") + require.NoError(t, err) defer os.RemoveAll(srcDir) - destDir, err := ioutil.TempDir("", "docker-test-destDir") - if err != nil { - t.Fatal(err) - } + destDir, err := ioutil.TempDir("", "storage-test-destDir") + require.NoError(t, err) defer os.RemoveAll(destDir) creationSize, err := prepareUntarSourceDirectory(100, destDir, true) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) changes, err := 
ChangesDirs(destDir, srcDir) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) got := ChangesSize(destDir, changes) if got != int64(creationSize) { @@ -522,16 +464,15 @@ func TestChangesSizeWithOnlyDeleteChanges(t *testing.T) { } func TestChangesSize(t *testing.T) { - parentPath, err := ioutil.TempDir("", "docker-changes-test") + parentPath, err := ioutil.TempDir("", "storage-changes-test") defer os.RemoveAll(parentPath) addition := path.Join(parentPath, "addition") - if err := ioutil.WriteFile(addition, []byte{0x01, 0x01, 0x01}, 0744); err != nil { - t.Fatal(err) - } + err = ioutil.WriteFile(addition, []byte{0x01, 0x01, 0x01}, 0744) + require.NoError(t, err) modification := path.Join(parentPath, "modification") - if err = ioutil.WriteFile(modification, []byte{0x01, 0x01, 0x01}, 0744); err != nil { - t.Fatal(err) - } + err = ioutil.WriteFile(modification, []byte{0x01, 0x01, 0x01}, 0744) + require.NoError(t, err) + changes := []Change{ {Path: "addition", Kind: ChangeAdd}, {Path: "modification", Kind: ChangeModify}, diff --git a/pkg/archive/changes_unix.go b/pkg/archive/changes_unix.go index 43dd94e2d7..d669c01b46 100644 --- a/pkg/archive/changes_unix.go +++ b/pkg/archive/changes_unix.go @@ -7,6 +7,7 @@ import ( "syscall" "github.com/containers/storage/pkg/system" + "golang.org/x/sys/unix" ) func statDifferent(oldStat *system.StatT, newStat *system.StatT) bool { @@ -16,7 +17,7 @@ func statDifferent(oldStat *system.StatT, newStat *system.StatT) bool { oldStat.GID() != newStat.GID() || oldStat.Rdev() != newStat.Rdev() || // Don't look at size for dirs, its not a good measure of change - (oldStat.Mode()&syscall.S_IFDIR != syscall.S_IFDIR && + (oldStat.Mode()&unix.S_IFDIR != unix.S_IFDIR && (!sameFsTimeSpec(oldStat.Mtim(), newStat.Mtim()) || (oldStat.Size() != newStat.Size()))) { return true } @@ -24,11 +25,11 @@ func statDifferent(oldStat *system.StatT, newStat *system.StatT) bool { } func (info *FileInfo) isDir() bool { - return info.parent == nil || 
info.stat.Mode()&syscall.S_IFDIR != 0 + return info.parent == nil || info.stat.Mode()&unix.S_IFDIR != 0 } func getIno(fi os.FileInfo) uint64 { - return uint64(fi.Sys().(*syscall.Stat_t).Ino) + return fi.Sys().(*syscall.Stat_t).Ino } func hasHardlinks(fi os.FileInfo) bool { diff --git a/pkg/archive/changes_windows.go b/pkg/archive/changes_windows.go index 06eadd662b..5ad3d7e38d 100644 --- a/pkg/archive/changes_windows.go +++ b/pkg/archive/changes_windows.go @@ -9,16 +9,16 @@ import ( func statDifferent(oldStat *system.StatT, newStat *system.StatT) bool { // Don't look at size for dirs, its not a good measure of change - if oldStat.ModTime() != newStat.ModTime() || + if oldStat.Mtim() != newStat.Mtim() || oldStat.Mode() != newStat.Mode() || - oldStat.Size() != newStat.Size() && !oldStat.IsDir() { + oldStat.Size() != newStat.Size() && !oldStat.Mode().IsDir() { return true } return false } func (info *FileInfo) isDir() bool { - return info.parent == nil || info.stat.IsDir() + return info.parent == nil || info.stat.Mode().IsDir() } func getIno(fi os.FileInfo) (inode uint64) { diff --git a/pkg/archive/copy.go b/pkg/archive/copy.go index c970f422aa..ea012b2d99 100644 --- a/pkg/archive/copy.go +++ b/pkg/archive/copy.go @@ -88,13 +88,13 @@ func SplitPathDirEntry(path string) (dir, base string) { // This function acts as a convenient wrapper around TarWithOptions, which // requires a directory as the source path. TarResource accepts either a // directory or a file path and correctly sets the Tar options. -func TarResource(sourceInfo CopyInfo) (content Archive, err error) { +func TarResource(sourceInfo CopyInfo) (content io.ReadCloser, err error) { return TarResourceRebase(sourceInfo.Path, sourceInfo.RebaseName) } // TarResourceRebase is like TarResource but renames the first path element of // items in the resulting tar archive to match the given rebaseName if not "". 
-func TarResourceRebase(sourcePath, rebaseName string) (content Archive, err error) { +func TarResourceRebase(sourcePath, rebaseName string) (content io.ReadCloser, err error) { sourcePath = normalizePath(sourcePath) if _, err = os.Lstat(sourcePath); err != nil { // Catches the case where the source does not exist or is not a @@ -103,7 +103,7 @@ func TarResourceRebase(sourcePath, rebaseName string) (content Archive, err erro return } - // Separate the source path between it's directory and + // Separate the source path between its directory and // the entry in that directory which we are archiving. sourceDir, sourceBase := SplitPathDirEntry(sourcePath) @@ -241,7 +241,7 @@ func CopyInfoDestinationPath(path string) (info CopyInfo, err error) { // contain the archived resource described by srcInfo, to the destination // described by dstInfo. Returns the possibly modified content archive along // with the path to the destination directory which it should be extracted to. -func PrepareArchiveCopy(srcContent Reader, srcInfo, dstInfo CopyInfo) (dstDir string, content Archive, err error) { +func PrepareArchiveCopy(srcContent io.Reader, srcInfo, dstInfo CopyInfo) (dstDir string, content io.ReadCloser, err error) { // Ensure in platform semantics srcInfo.Path = normalizePath(srcInfo.Path) dstInfo.Path = normalizePath(dstInfo.Path) @@ -304,7 +304,7 @@ func PrepareArchiveCopy(srcContent Reader, srcInfo, dstInfo CopyInfo) (dstDir st // RebaseArchiveEntries rewrites the given srcContent archive replacing // an occurrence of oldBase with newBase at the beginning of entry names. 
-func RebaseArchiveEntries(srcContent Reader, oldBase, newBase string) Archive { +func RebaseArchiveEntries(srcContent io.Reader, oldBase, newBase string) io.ReadCloser { if oldBase == string(os.PathSeparator) { // If oldBase specifies the root directory, use an empty string as // oldBase instead so that newBase doesn't replace the path separator @@ -332,6 +332,9 @@ func RebaseArchiveEntries(srcContent Reader, oldBase, newBase string) Archive { } hdr.Name = strings.Replace(hdr.Name, oldBase, newBase, 1) + if hdr.Typeflag == tar.TypeLink { + hdr.Linkname = strings.Replace(hdr.Linkname, oldBase, newBase, 1) + } if err = rebasedTar.WriteHeader(hdr); err != nil { w.CloseWithError(err) @@ -380,7 +383,7 @@ func CopyResource(srcPath, dstPath string, followLink bool) error { // CopyTo handles extracting the given content whose // entries should be sourced from srcInfo to dstPath. -func CopyTo(content Reader, srcInfo CopyInfo, dstPath string) error { +func CopyTo(content io.Reader, srcInfo CopyInfo, dstPath string) error { // The destination path need not exist, but CopyInfoDestinationPath will // ensure that at least the parent directory exists. dstInfo, err := CopyInfoDestinationPath(normalizePath(dstPath)) diff --git a/pkg/archive/copy_unix_test.go b/pkg/archive/copy_unix_test.go index ecbfc172b0..e08bcb4916 100644 --- a/pkg/archive/copy_unix_test.go +++ b/pkg/archive/copy_unix_test.go @@ -1,6 +1,6 @@ // +build !windows -// TODO Windows: Some of these tests may be salvagable and portable to Windows. +// TODO Windows: Some of these tests may be salvageable and portable to Windows. 
package archive @@ -15,6 +15,8 @@ import ( "path/filepath" "strings" "testing" + + "github.com/stretchr/testify/require" ) func removeAllPaths(paths ...string) { @@ -26,13 +28,11 @@ func removeAllPaths(paths ...string) { func getTestTempDirs(t *testing.T) (tmpDirA, tmpDirB string) { var err error - if tmpDirA, err = ioutil.TempDir("", "archive-copy-test"); err != nil { - t.Fatal(err) - } + tmpDirA, err = ioutil.TempDir("", "archive-copy-test") + require.NoError(t, err) - if tmpDirB, err = ioutil.TempDir("", "archive-copy-test"); err != nil { - t.Fatal(err) - } + tmpDirB, err = ioutil.TempDir("", "archive-copy-test") + require.NoError(t, err) return } @@ -118,9 +118,8 @@ func logDirContents(t *testing.T, dirPath string) { t.Logf("logging directory contents: %q", dirPath) - if err := filepath.Walk(dirPath, logWalkedPaths); err != nil { - t.Fatal(err) - } + err := filepath.Walk(dirPath, logWalkedPaths) + require.NoError(t, err) } func testCopyHelper(t *testing.T, srcPath, dstPath string) (err error) { @@ -293,9 +292,8 @@ func TestCopyCaseA(t *testing.T) { t.Fatalf("unexpected error %T: %s", err, err) } - if err = fileContentsEqual(t, srcPath, dstPath); err != nil { - t.Fatal(err) - } + err = fileContentsEqual(t, srcPath, dstPath) + require.NoError(t, err) os.Remove(dstPath) symlinkPath := filepath.Join(tmpDirA, "symlink3") @@ -306,17 +304,15 @@ func TestCopyCaseA(t *testing.T) { t.Fatalf("unexpected error %T: %s", err, err) } - if err = fileContentsEqual(t, linkTarget, dstPath); err != nil { - t.Fatal(err) - } + err = fileContentsEqual(t, linkTarget, dstPath) + require.NoError(t, err) os.Remove(dstPath) if err = testCopyHelperFSym(t, symlinkPath1, dstPath); err != nil { t.Fatalf("unexpected error %T: %s", err, err) } - if err = fileContentsEqual(t, linkTarget, dstPath); err != nil { - t.Fatal(err) - } + err = fileContentsEqual(t, linkTarget, dstPath) + require.NoError(t, err) } // B. 
SRC specifies a file and DST (with trailing path separator) doesn't @@ -377,9 +373,8 @@ func TestCopyCaseC(t *testing.T) { t.Fatalf("unexpected error %T: %s", err, err) } - if err = fileContentsEqual(t, srcPath, dstPath); err != nil { - t.Fatal(err) - } + err = fileContentsEqual(t, srcPath, dstPath) + require.NoError(t, err) } // C. Symbol link following version: @@ -415,9 +410,8 @@ func TestCopyCaseCFSym(t *testing.T) { t.Fatalf("unexpected error %T: %s", err, err) } - if err = fileContentsEqual(t, linkTarget, dstPath); err != nil { - t.Fatal(err) - } + err = fileContentsEqual(t, linkTarget, dstPath) + require.NoError(t, err) } // D. SRC specifies a file and DST exists as a directory. This should place @@ -446,9 +440,8 @@ func TestCopyCaseD(t *testing.T) { t.Fatalf("unexpected error %T: %s", err, err) } - if err = fileContentsEqual(t, srcPath, dstPath); err != nil { - t.Fatal(err) - } + err = fileContentsEqual(t, srcPath, dstPath) + require.NoError(t, err) // Now try again but using a trailing path separator for dstDir. @@ -466,9 +459,8 @@ func TestCopyCaseD(t *testing.T) { t.Fatalf("unexpected error %T: %s", err, err) } - if err = fileContentsEqual(t, srcPath, dstPath); err != nil { - t.Fatal(err) - } + err = fileContentsEqual(t, srcPath, dstPath) + require.NoError(t, err) } // D. Symbol link following version: @@ -499,9 +491,8 @@ func TestCopyCaseDFSym(t *testing.T) { t.Fatalf("unexpected error %T: %s", err, err) } - if err = fileContentsEqual(t, linkTarget, dstPath); err != nil { - t.Fatal(err) - } + err = fileContentsEqual(t, linkTarget, dstPath) + require.NoError(t, err) // Now try again but using a trailing path separator for dstDir. @@ -519,9 +510,8 @@ func TestCopyCaseDFSym(t *testing.T) { t.Fatalf("unexpected error %T: %s", err, err) } - if err = fileContentsEqual(t, linkTarget, dstPath); err != nil { - t.Fatal(err) - } + err = fileContentsEqual(t, linkTarget, dstPath) + require.NoError(t, err) } // E. SRC specifies a directory and DST does not exist. 
This should create a @@ -563,9 +553,8 @@ func TestCopyCaseE(t *testing.T) { t.Fatalf("unexpected error %T: %s", err, err) } - if err = dirContentsEqual(t, dstDir, srcDir); err != nil { - t.Fatal(err) - } + err = dirContentsEqual(t, dstDir, srcDir) + require.NoError(t, err) } // E. Symbol link following version: @@ -609,9 +598,8 @@ func TestCopyCaseEFSym(t *testing.T) { t.Fatalf("unexpected error %T: %s", err, err) } - if err = dirContentsEqual(t, dstDir, linkTarget); err != nil { - t.Fatal(err) - } + err = dirContentsEqual(t, dstDir, linkTarget) + require.NoError(t, err) } // F. SRC specifies a directory and DST exists as a file. This should cause an @@ -669,9 +657,8 @@ func TestCopyCaseG(t *testing.T) { t.Fatalf("unexpected error %T: %s", err, err) } - if err = dirContentsEqual(t, resultDir, srcDir); err != nil { - t.Fatal(err) - } + err = dirContentsEqual(t, resultDir, srcDir) + require.NoError(t, err) // Now try again but using a trailing path separator for dstDir. @@ -689,9 +676,8 @@ func TestCopyCaseG(t *testing.T) { t.Fatalf("unexpected error %T: %s", err, err) } - if err = dirContentsEqual(t, resultDir, srcDir); err != nil { - t.Fatal(err) - } + err = dirContentsEqual(t, resultDir, srcDir) + require.NoError(t, err) } // G. Symbol link version: @@ -717,9 +703,8 @@ func TestCopyCaseGFSym(t *testing.T) { t.Fatalf("unexpected error %T: %s", err, err) } - if err = dirContentsEqual(t, resultDir, linkTarget); err != nil { - t.Fatal(err) - } + err = dirContentsEqual(t, resultDir, linkTarget) + require.NoError(t, err) // Now try again but using a trailing path separator for dstDir. @@ -737,9 +722,8 @@ func TestCopyCaseGFSym(t *testing.T) { t.Fatalf("unexpected error %T: %s", err, err) } - if err = dirContentsEqual(t, resultDir, linkTarget); err != nil { - t.Fatal(err) - } + err = dirContentsEqual(t, resultDir, linkTarget) + require.NoError(t, err) } // H. SRC specifies a directory's contents only and DST does not exist. 
This @@ -899,9 +883,8 @@ func TestCopyCaseJ(t *testing.T) { t.Fatalf("unexpected error %T: %s", err, err) } - if err = dirContentsEqual(t, dstDir, srcDir); err != nil { - t.Fatal(err) - } + err = dirContentsEqual(t, dstDir, srcDir) + require.NoError(t, err) // Now try again but using a trailing path separator for dstDir. @@ -919,9 +902,8 @@ func TestCopyCaseJ(t *testing.T) { t.Fatalf("unexpected error %T: %s", err, err) } - if err = dirContentsEqual(t, dstDir, srcDir); err != nil { - t.Fatal(err) - } + err = dirContentsEqual(t, dstDir, srcDir) + require.NoError(t, err) } // J. Symbol link following version: @@ -952,9 +934,8 @@ func TestCopyCaseJFSym(t *testing.T) { t.Fatalf("unexpected error %T: %s", err, err) } - if err = dirContentsEqual(t, dstDir, linkTarget); err != nil { - t.Fatal(err) - } + err = dirContentsEqual(t, dstDir, linkTarget) + require.NoError(t, err) // Now try again but using a trailing path separator for dstDir. @@ -972,7 +953,6 @@ func TestCopyCaseJFSym(t *testing.T) { t.Fatalf("unexpected error %T: %s", err, err) } - if err = dirContentsEqual(t, dstDir, linkTarget); err != nil { - t.Fatal(err) - } + err = dirContentsEqual(t, dstDir, linkTarget) + require.NoError(t, err) } diff --git a/pkg/archive/diff.go b/pkg/archive/diff.go index c7ad4d9407..f93f4cb175 100644 --- a/pkg/archive/diff.go +++ b/pkg/archive/diff.go @@ -19,7 +19,7 @@ import ( // UnpackLayer unpack `layer` to a `dest`. The stream `layer` can be // compressed or uncompressed. // Returns the size in bytes of the contents of the layer. 
-func UnpackLayer(dest string, layer Reader, options *TarOptions) (size int64, err error) { +func UnpackLayer(dest string, layer io.Reader, options *TarOptions) (size int64, err error) { tr := tar.NewReader(layer) trBuf := pools.BufioReader32KPool.Get(tr) defer pools.BufioReader32KPool.Put(trBuf) @@ -33,17 +33,11 @@ func UnpackLayer(dest string, layer Reader, options *TarOptions) (size int64, er if options.ExcludePatterns == nil { options.ExcludePatterns = []string{} } - remappedRootUID, remappedRootGID, err := idtools.GetRootUIDGID(options.UIDMaps, options.GIDMaps) - if err != nil { - return 0, err - } + idMappings := idtools.NewIDMappingsFromMaps(options.UIDMaps, options.GIDMaps) aufsTempdir := "" aufsHardlinks := make(map[string]*tar.Header) - if options == nil { - options = &TarOptions{} - } // Iterate through the files in the archive. for { hdr, err := tr.Next() @@ -90,7 +84,7 @@ func UnpackLayer(dest string, layer Reader, options *TarOptions) (size int64, er parentPath := filepath.Join(dest, parent) if _, err := os.Lstat(parentPath); err != nil && os.IsNotExist(err) { - err = system.MkdirAll(parentPath, 0600) + err = system.MkdirAll(parentPath, 0600, "") if err != nil { return 0, err } @@ -111,7 +105,7 @@ func UnpackLayer(dest string, layer Reader, options *TarOptions) (size int64, er } defer os.RemoveAll(aufsTempdir) } - if err := createTarFile(filepath.Join(aufsTempdir, basename), dest, hdr, tr, true, nil); err != nil { + if err := createTarFile(filepath.Join(aufsTempdir, basename), dest, hdr, tr, true, nil, options.InUserNS); err != nil { return 0, err } } @@ -198,28 +192,11 @@ func UnpackLayer(dest string, layer Reader, options *TarOptions) (size int64, er srcData = tmpFile } - // if the options contain a uid & gid maps, convert header uid/gid - // entries using the maps such that lchown sets the proper mapped - // uid/gid after writing the file. 
We only perform this mapping if - // the file isn't already owned by the remapped root UID or GID, as - // that specific uid/gid has no mapping from container -> host, and - // those files already have the proper ownership for inside the - // container. - if srcHdr.Uid != remappedRootUID { - xUID, err := idtools.ToHost(srcHdr.Uid, options.UIDMaps) - if err != nil { - return 0, err - } - srcHdr.Uid = xUID - } - if srcHdr.Gid != remappedRootGID { - xGID, err := idtools.ToHost(srcHdr.Gid, options.GIDMaps) - if err != nil { - return 0, err - } - srcHdr.Gid = xGID + if err := remapIDs(idMappings, srcHdr); err != nil { + return 0, err } - if err := createTarFile(path, dest, srcHdr, srcData, true, nil); err != nil { + + if err := createTarFile(path, dest, srcHdr, srcData, true, nil, options.InUserNS); err != nil { return 0, err } @@ -246,7 +223,7 @@ func UnpackLayer(dest string, layer Reader, options *TarOptions) (size int64, er // and applies it to the directory `dest`. The stream `layer` can be // compressed or uncompressed. // Returns the size in bytes of the contents of the layer. -func ApplyLayer(dest string, layer Reader) (int64, error) { +func ApplyLayer(dest string, layer io.Reader) (int64, error) { return applyLayerHandler(dest, layer, &TarOptions{}, true) } @@ -254,12 +231,12 @@ func ApplyLayer(dest string, layer Reader) (int64, error) { // `layer`, and applies it to the directory `dest`. The stream `layer` // can only be uncompressed. // Returns the size in bytes of the contents of the layer. 
-func ApplyUncompressedLayer(dest string, layer Reader, options *TarOptions) (int64, error) { +func ApplyUncompressedLayer(dest string, layer io.Reader, options *TarOptions) (int64, error) { return applyLayerHandler(dest, layer, options, false) } // do the bulk load of ApplyLayer, but allow for not calling DecompressStream -func applyLayerHandler(dest string, layer Reader, options *TarOptions, decompress bool) (int64, error) { +func applyLayerHandler(dest string, layer io.Reader, options *TarOptions, decompress bool) (int64, error) { dest = filepath.Clean(dest) // We need to be able to set any perms diff --git a/pkg/archive/diff_test.go b/pkg/archive/diff_test.go index b6f654b76a..d3fec31ed9 100644 --- a/pkg/archive/diff_test.go +++ b/pkg/archive/diff_test.go @@ -35,7 +35,7 @@ func TestApplyLayerInvalidFilenames(t *testing.T) { }, }, } { - if err := testBreakout("applylayer", "docker-TestApplyLayerInvalidFilenames", headers); err != nil { + if err := testBreakout("applylayer", "storage-TestApplyLayerInvalidFilenames", headers); err != nil { t.Fatalf("i=%d. %v", i, err) } } @@ -118,7 +118,7 @@ func TestApplyLayerInvalidHardlink(t *testing.T) { }, }, } { - if err := testBreakout("applylayer", "docker-TestApplyLayerInvalidHardlink", headers); err != nil { + if err := testBreakout("applylayer", "storage-TestApplyLayerInvalidHardlink", headers); err != nil { t.Fatalf("i=%d. %v", i, err) } } @@ -201,7 +201,7 @@ func TestApplyLayerInvalidSymlink(t *testing.T) { }, }, } { - if err := testBreakout("applylayer", "docker-TestApplyLayerInvalidSymlink", headers); err != nil { + if err := testBreakout("applylayer", "storage-TestApplyLayerInvalidSymlink", headers); err != nil { t.Fatalf("i=%d. 
%v", i, err) } } diff --git a/pkg/archive/utils_test.go b/pkg/archive/utils_test.go index 98719032f3..01b9e92d1c 100644 --- a/pkg/archive/utils_test.go +++ b/pkg/archive/utils_test.go @@ -16,7 +16,7 @@ var testUntarFns = map[string]func(string, io.Reader) error{ return Untar(r, dest, nil) }, "applylayer": func(dest string, r io.Reader) error { - _, err := ApplyLayer(dest, Reader(r)) + _, err := ApplyLayer(dest, r) return err }, } diff --git a/pkg/archive/wrap.go b/pkg/archive/wrap.go index dfb335c0b6..b39d12c878 100644 --- a/pkg/archive/wrap.go +++ b/pkg/archive/wrap.go @@ -3,7 +3,7 @@ package archive import ( "archive/tar" "bytes" - "io/ioutil" + "io" ) // Generate generates a new archive from the content provided @@ -22,7 +22,7 @@ import ( // // FIXME: stream content instead of buffering // FIXME: specify permissions and other archive metadata -func Generate(input ...string) (Archive, error) { +func Generate(input ...string) (io.Reader, error) { files := parseStringPairs(input...) buf := new(bytes.Buffer) tw := tar.NewWriter(buf) @@ -42,7 +42,7 @@ func Generate(input ...string) (Archive, error) { if err := tw.Close(); err != nil { return nil, err } - return ioutil.NopCloser(buf), nil + return buf, nil } func parseStringPairs(input ...string) (output [][2]string) { diff --git a/pkg/archive/wrap_test.go b/pkg/archive/wrap_test.go index 46ab36697a..bd26bda3a2 100644 --- a/pkg/archive/wrap_test.go +++ b/pkg/archive/wrap_test.go @@ -5,13 +5,13 @@ import ( "bytes" "io" "testing" + + "github.com/stretchr/testify/require" ) func TestGenerateEmptyFile(t *testing.T) { archive, err := Generate("emptyFile") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if archive == nil { t.Fatal("The generated archive should not be nil.") } @@ -28,9 +28,7 @@ func TestGenerateEmptyFile(t *testing.T) { if err == io.EOF { break } - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) buf := new(bytes.Buffer) buf.ReadFrom(tr) content := buf.String() @@ -54,9 +52,7 @@ 
func TestGenerateEmptyFile(t *testing.T) { func TestGenerateWithContent(t *testing.T) { archive, err := Generate("file", "content") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if archive == nil { t.Fatal("The generated archive should not be nil.") } @@ -73,9 +69,7 @@ func TestGenerateWithContent(t *testing.T) { if err == io.EOF { break } - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) buf := new(bytes.Buffer) buf.ReadFrom(tr) content := buf.String() diff --git a/pkg/chrootarchive/archive.go b/pkg/chrootarchive/archive.go index 649575c006..2735f14001 100644 --- a/pkg/chrootarchive/archive.go +++ b/pkg/chrootarchive/archive.go @@ -11,7 +11,13 @@ import ( "github.com/containers/storage/pkg/idtools" ) -var chrootArchiver = &archive.Archiver{Untar: Untar} +// NewArchiver returns a new Archiver which uses chrootarchive.Untar +func NewArchiver(idMappings *idtools.IDMappings) *archive.Archiver { + if idMappings == nil { + idMappings = &idtools.IDMappings{} + } + return &archive.Archiver{Untar: Untar, IDMappings: idMappings} +} // Untar reads a stream of bytes from `archive`, parses it as a tar archive, // and unpacks it into the directory at `dest`. 
@@ -30,7 +36,6 @@ func UntarUncompressed(tarArchive io.Reader, dest string, options *archive.TarOp // Handler for teasing out the automatic decompression func untarHandler(tarArchive io.Reader, dest string, options *archive.TarOptions, decompress bool) error { - if tarArchive == nil { return fmt.Errorf("Empty archive") } @@ -41,14 +46,12 @@ func untarHandler(tarArchive io.Reader, dest string, options *archive.TarOptions options.ExcludePatterns = []string{} } - rootUID, rootGID, err := idtools.GetRootUIDGID(options.UIDMaps, options.GIDMaps) - if err != nil { - return err - } + idMappings := idtools.NewIDMappingsFromMaps(options.UIDMaps, options.GIDMaps) + rootIDs := idMappings.RootPair() dest = filepath.Clean(dest) if _, err := os.Stat(dest); os.IsNotExist(err) { - if err := idtools.MkdirAllNewAs(dest, 0755, rootUID, rootGID); err != nil { + if err := idtools.MkdirAllAndChownNew(dest, 0755, rootIDs); err != nil { return err } } @@ -65,33 +68,3 @@ func untarHandler(tarArchive io.Reader, dest string, options *archive.TarOptions return invokeUnpack(r, dest, options) } - -// TarUntar is a convenience function which calls Tar and Untar, with the output of one piped into the other. -// If either Tar or Untar fails, TarUntar aborts and returns the error. -func TarUntar(src, dst string) error { - return chrootArchiver.TarUntar(src, dst) -} - -// CopyWithTar creates a tar archive of filesystem path `src`, and -// unpacks it at filesystem path `dst`. -// The archive is streamed directly with fixed buffering and no -// intermediary disk IO. -func CopyWithTar(src, dst string) error { - return chrootArchiver.CopyWithTar(src, dst) -} - -// CopyFileWithTar emulates the behavior of the 'cp' command-line -// for a single file. It copies a regular file from path `src` to -// path `dst`, and preserves all its metadata. 
-// -// If `dst` ends with a trailing slash '/' ('\' on Windows), the final -// destination path will be `dst/base(src)` or `dst\base(src)` -func CopyFileWithTar(src, dst string) (err error) { - return chrootArchiver.CopyFileWithTar(src, dst) -} - -// UntarPath is a convenience function which looks for an archive -// at filesystem path `src`, and unpacks it at `dst`. -func UntarPath(src, dst string) error { - return chrootArchiver.UntarPath(src, dst) -} diff --git a/pkg/chrootarchive/archive_test.go b/pkg/chrootarchive/archive_test.go index 4a9b0115db..cb3253d9f9 100644 --- a/pkg/chrootarchive/archive_test.go +++ b/pkg/chrootarchive/archive_test.go @@ -22,14 +22,32 @@ func init() { reexec.Init() } +var chrootArchiver = NewArchiver(nil) + +func TarUntar(src, dst string) error { + return chrootArchiver.TarUntar(src, dst) +} + +func CopyFileWithTar(src, dst string) (err error) { + return chrootArchiver.CopyFileWithTar(src, dst) +} + +func UntarPath(src, dst string) error { + return chrootArchiver.UntarPath(src, dst) +} + +func CopyWithTar(src, dst string) error { + return chrootArchiver.CopyWithTar(src, dst) +} + func TestChrootTarUntar(t *testing.T) { - tmpdir, err := ioutil.TempDir("", "docker-TestChrootTarUntar") + tmpdir, err := ioutil.TempDir("", "storage-TestChrootTarUntar") if err != nil { t.Fatal(err) } defer os.RemoveAll(tmpdir) src := filepath.Join(tmpdir, "src") - if err := system.MkdirAll(src, 0700); err != nil { + if err := system.MkdirAll(src, 0700, ""); err != nil { t.Fatal(err) } if err := ioutil.WriteFile(filepath.Join(src, "toto"), []byte("hello toto"), 0644); err != nil { @@ -43,7 +61,7 @@ func TestChrootTarUntar(t *testing.T) { t.Fatal(err) } dest := filepath.Join(tmpdir, "src") - if err := system.MkdirAll(dest, 0700); err != nil { + if err := system.MkdirAll(dest, 0700, ""); err != nil { t.Fatal(err) } if err := Untar(stream, dest, &archive.TarOptions{ExcludePatterns: []string{"lolo"}}); err != nil { @@ -54,13 +72,13 @@ func TestChrootTarUntar(t 
*testing.T) { // gh#10426: Verify the fix for having a huge excludes list (like on `docker load` with large # of // local images) func TestChrootUntarWithHugeExcludesList(t *testing.T) { - tmpdir, err := ioutil.TempDir("", "docker-TestChrootUntarHugeExcludes") + tmpdir, err := ioutil.TempDir("", "storage-TestChrootUntarHugeExcludes") if err != nil { t.Fatal(err) } defer os.RemoveAll(tmpdir) src := filepath.Join(tmpdir, "src") - if err := system.MkdirAll(src, 0700); err != nil { + if err := system.MkdirAll(src, 0700, ""); err != nil { t.Fatal(err) } if err := ioutil.WriteFile(filepath.Join(src, "toto"), []byte("hello toto"), 0644); err != nil { @@ -71,13 +89,13 @@ func TestChrootUntarWithHugeExcludesList(t *testing.T) { t.Fatal(err) } dest := filepath.Join(tmpdir, "dest") - if err := system.MkdirAll(dest, 0700); err != nil { + if err := system.MkdirAll(dest, 0700, ""); err != nil { t.Fatal(err) } options := &archive.TarOptions{} //65534 entries of 64-byte strings ~= 4MB of environment space which should overflow //on most systems when passed via environment or command line arguments - excludes := make([]string, 65534, 65534) + excludes := make([]string, 65534) for i := 0; i < 65534; i++ { excludes[i] = strings.Repeat(string(i), 64) } @@ -88,7 +106,7 @@ func TestChrootUntarWithHugeExcludesList(t *testing.T) { } func TestChrootUntarEmptyArchive(t *testing.T) { - tmpdir, err := ioutil.TempDir("", "docker-TestChrootUntarEmptyArchive") + tmpdir, err := ioutil.TempDir("", "storage-TestChrootUntarEmptyArchive") if err != nil { t.Fatal(err) } @@ -156,16 +174,16 @@ func TestChrootTarUntarWithSymlink(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("Failing on Windows") } - tmpdir, err := ioutil.TempDir("", "docker-TestChrootTarUntarWithSymlink") + tmpdir, err := ioutil.TempDir("", "storage-TestChrootTarUntarWithSymlink") if err != nil { t.Fatal(err) } defer os.RemoveAll(tmpdir) src := filepath.Join(tmpdir, "src") - if err := system.MkdirAll(src, 0700); err != nil { + if 
err := system.MkdirAll(src, 0700, ""); err != nil { t.Fatal(err) } - if _, err := prepareSourceDirectory(10, src, true); err != nil { + if _, err := prepareSourceDirectory(10, src, false); err != nil { t.Fatal(err) } dest := filepath.Join(tmpdir, "dest") @@ -179,16 +197,16 @@ func TestChrootTarUntarWithSymlink(t *testing.T) { func TestChrootCopyWithTar(t *testing.T) { // TODO Windows: Figure out why this is failing - if runtime.GOOS == "windows" { - t.Skip("Failing on Windows") + if runtime.GOOS == "windows" || runtime.GOOS == "solaris" { + t.Skip("Failing on Windows and Solaris") } - tmpdir, err := ioutil.TempDir("", "docker-TestChrootCopyWithTar") + tmpdir, err := ioutil.TempDir("", "storage-TestChrootCopyWithTar") if err != nil { t.Fatal(err) } defer os.RemoveAll(tmpdir) src := filepath.Join(tmpdir, "src") - if err := system.MkdirAll(src, 0700); err != nil { + if err := system.MkdirAll(src, 0700, ""); err != nil { t.Fatal(err) } if _, err := prepareSourceDirectory(10, src, true); err != nil { @@ -228,13 +246,13 @@ func TestChrootCopyWithTar(t *testing.T) { } func TestChrootCopyFileWithTar(t *testing.T) { - tmpdir, err := ioutil.TempDir("", "docker-TestChrootCopyFileWithTar") + tmpdir, err := ioutil.TempDir("", "storage-TestChrootCopyFileWithTar") if err != nil { t.Fatal(err) } defer os.RemoveAll(tmpdir) src := filepath.Join(tmpdir, "src") - if err := system.MkdirAll(src, 0700); err != nil { + if err := system.MkdirAll(src, 0700, ""); err != nil { t.Fatal(err) } if _, err := prepareSourceDirectory(10, src, true); err != nil { @@ -275,16 +293,16 @@ func TestChrootUntarPath(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("Failing on Windows") } - tmpdir, err := ioutil.TempDir("", "docker-TestChrootUntarPath") + tmpdir, err := ioutil.TempDir("", "storage-TestChrootUntarPath") if err != nil { t.Fatal(err) } defer os.RemoveAll(tmpdir) src := filepath.Join(tmpdir, "src") - if err := system.MkdirAll(src, 0700); err != nil { + if err := system.MkdirAll(src, 0700, 
""); err != nil { t.Fatal(err) } - if _, err := prepareSourceDirectory(10, src, true); err != nil { + if _, err := prepareSourceDirectory(10, src, false); err != nil { t.Fatal(err) } dest := filepath.Join(tmpdir, "dest") @@ -336,13 +354,13 @@ func (s *slowEmptyTarReader) Read(p []byte) (int, error) { } func TestChrootUntarEmptyArchiveFromSlowReader(t *testing.T) { - tmpdir, err := ioutil.TempDir("", "docker-TestChrootUntarEmptyArchiveFromSlowReader") + tmpdir, err := ioutil.TempDir("", "storage-TestChrootUntarEmptyArchiveFromSlowReader") if err != nil { t.Fatal(err) } defer os.RemoveAll(tmpdir) dest := filepath.Join(tmpdir, "dest") - if err := system.MkdirAll(dest, 0700); err != nil { + if err := system.MkdirAll(dest, 0700, ""); err != nil { t.Fatal(err) } stream := &slowEmptyTarReader{size: 10240, chunkSize: 1024} @@ -352,13 +370,13 @@ func TestChrootUntarEmptyArchiveFromSlowReader(t *testing.T) { } func TestChrootApplyEmptyArchiveFromSlowReader(t *testing.T) { - tmpdir, err := ioutil.TempDir("", "docker-TestChrootApplyEmptyArchiveFromSlowReader") + tmpdir, err := ioutil.TempDir("", "storage-TestChrootApplyEmptyArchiveFromSlowReader") if err != nil { t.Fatal(err) } defer os.RemoveAll(tmpdir) dest := filepath.Join(tmpdir, "dest") - if err := system.MkdirAll(dest, 0700); err != nil { + if err := system.MkdirAll(dest, 0700, ""); err != nil { t.Fatal(err) } stream := &slowEmptyTarReader{size: 10240, chunkSize: 1024} @@ -368,13 +386,13 @@ func TestChrootApplyEmptyArchiveFromSlowReader(t *testing.T) { } func TestChrootApplyDotDotFile(t *testing.T) { - tmpdir, err := ioutil.TempDir("", "docker-TestChrootApplyDotDotFile") + tmpdir, err := ioutil.TempDir("", "storage-TestChrootApplyDotDotFile") if err != nil { t.Fatal(err) } defer os.RemoveAll(tmpdir) src := filepath.Join(tmpdir, "src") - if err := system.MkdirAll(src, 0700); err != nil { + if err := system.MkdirAll(src, 0700, ""); err != nil { t.Fatal(err) } if err := ioutil.WriteFile(filepath.Join(src, "..gitme"), 
[]byte(""), 0644); err != nil { @@ -385,7 +403,7 @@ func TestChrootApplyDotDotFile(t *testing.T) { t.Fatal(err) } dest := filepath.Join(tmpdir, "dest") - if err := system.MkdirAll(dest, 0700); err != nil { + if err := system.MkdirAll(dest, 0700, ""); err != nil { t.Fatal(err) } if _, err := ApplyLayer(dest, stream); err != nil { diff --git a/pkg/chrootarchive/chroot_linux.go b/pkg/chrootarchive/chroot_linux.go index 54b5ff4899..e8bd22e36b 100644 --- a/pkg/chrootarchive/chroot_linux.go +++ b/pkg/chrootarchive/chroot_linux.go @@ -5,9 +5,10 @@ import ( "io/ioutil" "os" "path/filepath" - "syscall" "github.com/containers/storage/pkg/mount" + rsystem "github.com/opencontainers/runc/libcontainer/system" + "golang.org/x/sys/unix" ) // chroot on linux uses pivot_root instead of chroot @@ -17,14 +18,25 @@ import ( // Old root is removed after the call to pivot_root so it is no longer available under the new root. // This is similar to how libcontainer sets up a container's rootfs func chroot(path string) (err error) { - if err := syscall.Unshare(syscall.CLONE_NEWNS); err != nil { + // if the engine is running in a user namespace we need to use actual chroot + if rsystem.RunningInUserNS() { + return realChroot(path) + } + if err := unix.Unshare(unix.CLONE_NEWNS); err != nil { return fmt.Errorf("Error creating mount namespace before pivot: %v", err) } - if err := mount.MakeRPrivate(path); err != nil { + // make everything in new ns private + if err := mount.MakeRPrivate("/"); err != nil { return err } + if mounted, _ := mount.Mounted(path); !mounted { + if err := mount.Mount(path, path, "bind", "rbind,rw"); err != nil { + return realChroot(path) + } + } + // setup oldRoot for pivot_root pivotDir, err := ioutil.TempDir(path, ".pivot_root") if err != nil { @@ -35,7 +47,7 @@ func chroot(path string) (err error) { defer func() { if mounted { // make sure pivotDir is not mounted before we try to remove it - if errCleanup := syscall.Unmount(pivotDir, syscall.MNT_DETACH); errCleanup 
!= nil { + if errCleanup := unix.Unmount(pivotDir, unix.MNT_DETACH); errCleanup != nil { if err == nil { err = errCleanup } @@ -52,16 +64,9 @@ func chroot(path string) (err error) { err = errCleanup } } - - if errCleanup := syscall.Unmount("/", syscall.MNT_DETACH); errCleanup != nil { - if err == nil { - err = fmt.Errorf("error unmounting root: %v", errCleanup) - } - return - } }() - if err := syscall.PivotRoot(path, pivotDir); err != nil { + if err := unix.PivotRoot(path, pivotDir); err != nil { // If pivot fails, fall back to the normal chroot after cleaning up temp dir if err := os.Remove(pivotDir); err != nil { return fmt.Errorf("Error cleaning up after failed pivot: %v", err) @@ -74,17 +79,17 @@ func chroot(path string) (err error) { // This dir contains the rootfs of the caller, which we need to remove so it is not visible during extraction pivotDir = filepath.Join("/", filepath.Base(pivotDir)) - if err := syscall.Chdir("/"); err != nil { + if err := unix.Chdir("/"); err != nil { return fmt.Errorf("Error changing to new root: %v", err) } // Make the pivotDir (where the old root lives) private so it can be unmounted without propagating to the host - if err := syscall.Mount("", pivotDir, "", syscall.MS_PRIVATE|syscall.MS_REC, ""); err != nil { + if err := unix.Mount("", pivotDir, "", unix.MS_PRIVATE|unix.MS_REC, ""); err != nil { return fmt.Errorf("Error making old root private after pivot: %v", err) } // Now unmount the old root so it's no longer visible from the new root - if err := syscall.Unmount(pivotDir, syscall.MNT_DETACH); err != nil { + if err := unix.Unmount(pivotDir, unix.MNT_DETACH); err != nil { return fmt.Errorf("Error while unmounting old root after pivot: %v", err) } mounted = false @@ -93,10 +98,10 @@ func chroot(path string) (err error) { } func realChroot(path string) error { - if err := syscall.Chroot(path); err != nil { + if err := unix.Chroot(path); err != nil { return fmt.Errorf("Error after fallback to chroot: %v", err) } - if err := 
syscall.Chdir("/"); err != nil { + if err := unix.Chdir("/"); err != nil { return fmt.Errorf("Error changing to new root after chroot: %v", err) } return nil diff --git a/pkg/chrootarchive/chroot_unix.go b/pkg/chrootarchive/chroot_unix.go index 16354bf648..f9b5dece8c 100644 --- a/pkg/chrootarchive/chroot_unix.go +++ b/pkg/chrootarchive/chroot_unix.go @@ -2,11 +2,11 @@ package chrootarchive -import "syscall" +import "golang.org/x/sys/unix" func chroot(path string) error { - if err := syscall.Chroot(path); err != nil { + if err := unix.Chroot(path); err != nil { return err } - return syscall.Chdir("/") + return unix.Chdir("/") } diff --git a/pkg/chrootarchive/diff.go b/pkg/chrootarchive/diff.go index 377aeb9f0a..68b8f74f77 100644 --- a/pkg/chrootarchive/diff.go +++ b/pkg/chrootarchive/diff.go @@ -1,12 +1,16 @@ package chrootarchive -import "github.com/containers/storage/pkg/archive" +import ( + "io" + + "github.com/containers/storage/pkg/archive" +) // ApplyLayer parses a diff in the standard layer format from `layer`, // and applies it to the directory `dest`. The stream `layer` can only be // uncompressed. // Returns the size in bytes of the contents of the layer. -func ApplyLayer(dest string, layer archive.Reader) (size int64, err error) { +func ApplyLayer(dest string, layer io.Reader) (size int64, err error) { return applyLayerHandler(dest, layer, &archive.TarOptions{}, true) } @@ -14,6 +18,6 @@ func ApplyLayer(dest string, layer archive.Reader) (size int64, err error) { // `layer`, and applies it to the directory `dest`. The stream `layer` // can only be uncompressed. // Returns the size in bytes of the contents of the layer. 
-func ApplyUncompressedLayer(dest string, layer archive.Reader, options *archive.TarOptions) (int64, error) { +func ApplyUncompressedLayer(dest string, layer io.Reader, options *archive.TarOptions) (int64, error) { return applyLayerHandler(dest, layer, options, false) } diff --git a/pkg/chrootarchive/diff_unix.go b/pkg/chrootarchive/diff_unix.go index 3a9f9a822b..4369f30c99 100644 --- a/pkg/chrootarchive/diff_unix.go +++ b/pkg/chrootarchive/diff_unix.go @@ -7,6 +7,7 @@ import ( "encoding/json" "flag" "fmt" + "io" "io/ioutil" "os" "path/filepath" @@ -15,6 +16,7 @@ import ( "github.com/containers/storage/pkg/archive" "github.com/containers/storage/pkg/reexec" "github.com/containers/storage/pkg/system" + rsystem "github.com/opencontainers/runc/libcontainer/system" ) type applyLayerResponse struct { @@ -27,13 +29,14 @@ type applyLayerResponse struct { func applyLayer() { var ( - tmpDir = "" + tmpDir string err error options *archive.TarOptions ) runtime.LockOSThread() flag.Parse() + inUserns := rsystem.RunningInUserNS() if err := chroot(flag.Arg(0)); err != nil { fatal(err) } @@ -49,6 +52,10 @@ func applyLayer() { fatal(err) } + if inUserns { + options.InUserNS = true + } + if tmpDir, err = ioutil.TempDir("/", "temp-storage-extract"); err != nil { fatal(err) } @@ -75,7 +82,7 @@ func applyLayer() { // applyLayerHandler parses a diff in the standard layer format from `layer`, and // applies it to the directory `dest`. Returns the size in bytes of the // contents of the layer. 
-func applyLayerHandler(dest string, layer archive.Reader, options *archive.TarOptions, decompress bool) (size int64, err error) { +func applyLayerHandler(dest string, layer io.Reader, options *archive.TarOptions, decompress bool) (size int64, err error) { dest = filepath.Clean(dest) if decompress { decompressed, err := archive.DecompressStream(layer) @@ -88,6 +95,9 @@ func applyLayerHandler(dest string, layer archive.Reader, options *archive.TarOp } if options == nil { options = &archive.TarOptions{} + if rsystem.RunningInUserNS() { + options.InUserNS = true + } } if options.ExcludePatterns == nil { options.ExcludePatterns = []string{} diff --git a/pkg/chrootarchive/diff_windows.go b/pkg/chrootarchive/diff_windows.go index 534d2708aa..8f8e88bfbe 100644 --- a/pkg/chrootarchive/diff_windows.go +++ b/pkg/chrootarchive/diff_windows.go @@ -2,6 +2,7 @@ package chrootarchive import ( "fmt" + "io" "io/ioutil" "os" "path/filepath" @@ -13,7 +14,7 @@ import ( // applyLayerHandler parses a diff in the standard layer format from `layer`, and // applies it to the directory `dest`. Returns the size in bytes of the // contents of the layer. 
-func applyLayerHandler(dest string, layer archive.Reader, options *archive.TarOptions, decompress bool) (size int64, err error) { +func applyLayerHandler(dest string, layer io.Reader, options *archive.TarOptions, decompress bool) (size int64, err error) { dest = filepath.Clean(dest) // Ensure it is a Windows-style volume path @@ -37,7 +38,7 @@ func applyLayerHandler(dest string, layer archive.Reader, options *archive.TarOp s, err := archive.UnpackLayer(dest, layer, nil) os.RemoveAll(tmpDir) if err != nil { - return 0, fmt.Errorf("ApplyLayer %s failed UnpackLayer to %s", err, dest) + return 0, fmt.Errorf("ApplyLayer %s failed UnpackLayer to %s: %s", layer, dest, err) } return s, nil diff --git a/pkg/devicemapper/devmapper.go b/pkg/devicemapper/devmapper.go index 1ed0e861fc..6a0ac24647 100644 --- a/pkg/devicemapper/devmapper.go +++ b/pkg/devicemapper/devmapper.go @@ -1,4 +1,4 @@ -// +build linux +// +build linux,cgo package devicemapper @@ -7,17 +7,14 @@ import ( "fmt" "os" "runtime" - "syscall" "unsafe" "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" ) -// DevmapperLogger defines methods for logging with devicemapper. -type DevmapperLogger interface { - DMLog(level int, file string, line int, dmError int, message string) -} - +// Same as DM_DEVICE_* enum values from libdevmapper.h +// nolint: deadcode const ( deviceCreate TaskType = iota deviceReload @@ -155,6 +152,7 @@ func (t *Task) run() error { if res := DmTaskRun(t.unmanaged); res != 1 { return ErrTaskRun } + runtime.KeepAlive(t) return nil } @@ -257,25 +255,12 @@ func (t *Task) getNextTarget(next unsafe.Pointer) (nextPtr unsafe.Pointer, start // UdevWait waits for any processes that are waiting for udev to complete the specified cookie. 
func UdevWait(cookie *uint) error { if res := DmUdevWait(*cookie); res != 1 { - logrus.Debugf("devicemapper: Failed to wait on udev cookie %d", *cookie) + logrus.Debugf("devicemapper: Failed to wait on udev cookie %d, %d", *cookie, res) return ErrUdevWait } return nil } -// LogInitVerbose is an interface to initialize the verbose logger for the device mapper library. -func LogInitVerbose(level int) { - DmLogInitVerbose(level) -} - -var dmLogger DevmapperLogger - -// LogInit initializes the logger for the device mapper library. -func LogInit(logger DevmapperLogger) { - dmLogger = logger - LogWithErrnoInit() -} - // SetDevDir sets the dev folder for the device mapper library (usually /dev). func SetDevDir(dir string) error { if res := DmSetDevDir(dir); res != 1 { @@ -328,17 +313,21 @@ func RemoveDevice(name string) error { return err } - var cookie uint - if err := task.setCookie(&cookie, 0); err != nil { + cookie := new(uint) + if err := task.setCookie(cookie, 0); err != nil { return fmt.Errorf("devicemapper: Can not set cookie: %s", err) } - defer UdevWait(&cookie) + defer UdevWait(cookie) dmSawBusy = false // reset before the task is run + dmSawEnxio = false if err = task.run(); err != nil { if dmSawBusy { return ErrBusy } + if dmSawEnxio { + return ErrEnxio + } return fmt.Errorf("devicemapper: Error running RemoveDevice %s", err) } @@ -358,7 +347,32 @@ func RemoveDeviceDeferred(name string) error { return ErrTaskDeferredRemove } + // set a task cookie and disable library fallback, or else libdevmapper will + // disable udev dm rules and delete the symlink under /dev/mapper by itself, + // even if the removal is deferred by the kernel. 
+	cookie := new(uint) +	var flags uint16 +	flags = DmUdevDisableLibraryFallback +	if err := task.setCookie(cookie, flags); err != nil { +		return fmt.Errorf("devicemapper: Can not set cookie: %s", err) +	} + +	// libdevmapper and udev rely on System V semaphore for synchronization, +	// semaphores created in `task.setCookie` will be cleaned up in `UdevWait`. +	// So these two function calls must come in pairs, otherwise semaphores will +	// be leaked, and the limit of number of semaphores defined in `/proc/sys/kernel/sem` +	// will be reached, which will eventually make all following calls to 'task.SetCookie' +	// fail. +	// this call will not wait for the deferred removal's final execution, since no +	// udev event will be generated, and the semaphore's value will not be incremented +	// by udev, so UdevWait here is just cleaning up the semaphore. +	defer UdevWait(cookie) + +	dmSawEnxio = false 	if err = task.run(); err != nil { +		if dmSawEnxio { +			return ErrEnxio +		} 		return fmt.Errorf("devicemapper: Error running RemoveDeviceDeferred %s", err) 	} @@ -427,7 +441,7 @@ func BlockDeviceDiscard(path string) error { 	// Without this sometimes the remove of the device that happens after 	// discard fails with EBUSY.
- syscall.Sync() + unix.Sync() return nil } @@ -450,13 +464,13 @@ func CreatePool(poolName string, dataFile, metadataFile *os.File, poolBlockSize return fmt.Errorf("devicemapper: Can't add target %s", err) } - var cookie uint + cookie := new(uint) var flags uint16 flags = DmUdevDisableSubsystemRulesFlag | DmUdevDisableDiskRulesFlag | DmUdevDisableOtherRulesFlag - if err := task.setCookie(&cookie, flags); err != nil { + if err := task.setCookie(cookie, flags); err != nil { return fmt.Errorf("devicemapper: Can't set cookie %s", err) } - defer UdevWait(&cookie) + defer UdevWait(cookie) if err := task.run(); err != nil { return fmt.Errorf("devicemapper: Error running deviceCreate (CreatePool) %s", err) @@ -484,7 +498,7 @@ func ReloadPool(poolName string, dataFile, metadataFile *os.File, poolBlockSize } if err := task.run(); err != nil { - return fmt.Errorf("devicemapper: Error running deviceCreate %s", err) + return fmt.Errorf("devicemapper: Error running ReloadPool %s", err) } return nil @@ -638,11 +652,11 @@ func ResumeDevice(name string) error { return err } - var cookie uint - if err := task.setCookie(&cookie, 0); err != nil { + cookie := new(uint) + if err := task.setCookie(cookie, 0); err != nil { return fmt.Errorf("devicemapper: Can't set cookie %s", err) } - defer UdevWait(&cookie) + defer UdevWait(cookie) if err := task.run(); err != nil { return fmt.Errorf("devicemapper: Error running deviceResume %s", err) @@ -736,12 +750,12 @@ func activateDevice(poolName string, name string, deviceID int, size uint64, ext return fmt.Errorf("devicemapper: Can't add node %s", err) } - var cookie uint - if err := task.setCookie(&cookie, 0); err != nil { + cookie := new(uint) + if err := task.setCookie(cookie, 0); err != nil { return fmt.Errorf("devicemapper: Can't set cookie %s", err) } - defer UdevWait(&cookie) + defer UdevWait(cookie) if err := task.run(); err != nil { return fmt.Errorf("devicemapper: Error running deviceCreate (ActivateDevice) %s", err) @@ -750,51 +764,51 
@@ func activateDevice(poolName string, name string, deviceID int, size uint64, ext return nil } -// CreateSnapDevice creates a snapshot based on the device identified by the baseName and baseDeviceId, -func CreateSnapDevice(poolName string, deviceID int, baseName string, baseDeviceID int) error { - devinfo, _ := GetInfo(baseName) - doSuspend := devinfo != nil && devinfo.Exists != 0 - - if doSuspend { - if err := SuspendDevice(baseName); err != nil { - return err - } - } - +// CreateSnapDeviceRaw creates a snapshot device. Caller needs to suspend and resume the origin device if it is active. +func CreateSnapDeviceRaw(poolName string, deviceID int, baseDeviceID int) error { task, err := TaskCreateNamed(deviceTargetMsg, poolName) if task == nil { - if doSuspend { - ResumeDevice(baseName) - } return err } if err := task.setSector(0); err != nil { - if doSuspend { - ResumeDevice(baseName) - } return fmt.Errorf("devicemapper: Can't set sector %s", err) } if err := task.setMessage(fmt.Sprintf("create_snap %d %d", deviceID, baseDeviceID)); err != nil { - if doSuspend { - ResumeDevice(baseName) - } return fmt.Errorf("devicemapper: Can't set message %s", err) } dmSawExist = false // reset before the task is run if err := task.run(); err != nil { - if doSuspend { - ResumeDevice(baseName) - } // Caller wants to know about ErrDeviceIDExists so that it can try with a different device id. 
if dmSawExist { return ErrDeviceIDExists } + return fmt.Errorf("devicemapper: Error running deviceCreate (CreateSnapDeviceRaw) %s", err) + } - return fmt.Errorf("devicemapper: Error running deviceCreate (createSnapDevice) %s", err) + return nil +} +// CreateSnapDevice creates a snapshot based on the device identified by the baseName and baseDeviceId, +func CreateSnapDevice(poolName string, deviceID int, baseName string, baseDeviceID int) error { + devinfo, _ := GetInfo(baseName) + doSuspend := devinfo != nil && devinfo.Exists != 0 + + if doSuspend { + if err := SuspendDevice(baseName); err != nil { + return err + } + } + + if err := CreateSnapDeviceRaw(poolName, deviceID, baseDeviceID); err != nil { + if doSuspend { + if err2 := ResumeDevice(baseName); err2 != nil { + return fmt.Errorf("CreateSnapDeviceRaw Error: (%v): ResumeDevice Error: (%v)", err, err2) + } + } + return err } if doSuspend { diff --git a/pkg/devicemapper/devmapper_log.go b/pkg/devicemapper/devmapper_log.go index 76c9756601..65a202ad2a 100644 --- a/pkg/devicemapper/devmapper_log.go +++ b/pkg/devicemapper/devmapper_log.go @@ -1,21 +1,49 @@ -// +build linux +// +build linux,cgo package devicemapper import "C" import ( + "fmt" "strings" + + "github.com/sirupsen/logrus" ) +// DevmapperLogger defines methods required to register as a callback for +// logging events received from devicemapper. Note that devicemapper will send +// *all* logs regardless to callbacks (including debug logs) so it's +// recommended to not spam the console with the outputs. +type DevmapperLogger interface { + // DMLog is the logging callback containing all of the information from + // devicemapper. The interface is identical to the C libdm counterpart. + DMLog(level int, file string, line int, dmError int, message string) +} + +// dmLogger is the current logger in use that is being forwarded our messages. 
+var dmLogger DevmapperLogger + +// LogInit changes the logging callback called after processing libdm logs for +// error message information. The default logger simply forwards all logs to +// logrus. Calling LogInit(nil) disables the calling of callbacks. +func LogInit(logger DevmapperLogger) { + dmLogger = logger +} + // Due to the way cgo works this has to be in a separate file, as devmapper.go has // definitions in the cgo block, which is incompatible with using "//export" -// StorageDevmapperLogCallback exports the devmapper log callback for cgo. -//export StorageDevmapperLogCallback -func StorageDevmapperLogCallback(level C.int, file *C.char, line C.int, dmErrnoOrClass C.int, message *C.char) { +// DevmapperLogCallback exports the devmapper log callback for cgo. Note that +// because we are using callbacks, this function will be called for *every* log +// in libdm (even debug ones because there's no way of setting the verbosity +// level for an external logging callback). +//export DevmapperLogCallback +func DevmapperLogCallback(level C.int, file *C.char, line, dmErrnoOrClass C.int, message *C.char) { msg := C.GoString(message) - if level < 7 { + + // Track what errno libdm saw, because the library only gives us 0 or 1. + if level < LogLevelDebug { if strings.Contains(msg, "busy") { dmSawBusy = true } @@ -33,3 +61,61 @@ func StorageDevmapperLogCallback(level C.int, file *C.char, line C.int, dmErrnoO dmLogger.DMLog(int(level), C.GoString(file), int(line), int(dmErrnoOrClass), msg) } } + +// DefaultLogger is the default logger used by pkg/devicemapper. It forwards +// all logs that are of higher or equal priority to the given level to the +// corresponding logrus level. +type DefaultLogger struct { + // Level corresponds to the highest libdm level that will be forwarded to + // logrus. In order to change this, register a new DefaultLogger. + Level int +} + +// DMLog is the logging callback containing all of the information from +// devicemapper. 
The interface is identical to the C libdm counterpart. +func (l DefaultLogger) DMLog(level int, file string, line, dmError int, message string) { + if level <= l.Level { + // Forward the log to the correct logrus level, if allowed by dmLogLevel. + logMsg := fmt.Sprintf("libdevmapper(%d): %s:%d (%d) %s", level, file, line, dmError, message) + switch level { + case LogLevelFatal, LogLevelErr: + logrus.Error(logMsg) + case LogLevelWarn: + logrus.Warn(logMsg) + case LogLevelNotice, LogLevelInfo: + logrus.Info(logMsg) + case LogLevelDebug: + logrus.Debug(logMsg) + default: + // Don't drop any "unknown" levels. + logrus.Info(logMsg) + } + } +} + +// registerLogCallback registers our own logging callback function for libdm +// (which is DevmapperLogCallback). +// +// Because libdm only gives us {0,1} error codes we need to parse the logs +// produced by libdm (to set dmSawBusy and so on). Note that by registering a +// callback using DevmapperLogCallback, libdm will no longer output logs to +// stderr so we have to log everything ourselves. None of this handling is +// optional because we depend on log callbacks to parse the logs, and if we +// don't forward the log information we'll be in a lot of trouble when +// debugging things. +func registerLogCallback() { + LogWithErrnoInit() +} + +func init() { + // Use the default logger by default. We only allow LogLevelFatal by + // default, because internally we mask a lot of libdm errors by retrying + // and similar tricks. Also, libdm is very chatty and we don't want to + // worry users for no reason. + dmLogger = DefaultLogger{ + Level: LogLevelFatal, + } + + // Register as early as possible so we don't miss anything. 
+ registerLogCallback() +} diff --git a/pkg/devicemapper/devmapper_wrapper.go b/pkg/devicemapper/devmapper_wrapper.go index e37e02059c..190d83d499 100644 --- a/pkg/devicemapper/devmapper_wrapper.go +++ b/pkg/devicemapper/devmapper_wrapper.go @@ -1,9 +1,9 @@ -// +build linux +// +build linux,cgo package devicemapper /* -#cgo LDFLAGS: -L. -ldevmapper +#define _GNU_SOURCE #include #include // FIXME: present only for BLKGETSIZE64, maybe we can remove it? @@ -12,19 +12,25 @@ extern void StorageDevmapperLogCallback(int level, char *file, int line, int dm_ static void log_cb(int level, const char *file, int line, int dm_errno_or_class, const char *f, ...) { - char buffer[256]; - va_list ap; - - va_start(ap, f); - vsnprintf(buffer, 256, f, ap); - va_end(ap); + char *buffer = NULL; + va_list ap; + int ret; + + va_start(ap, f); + ret = vasprintf(&buffer, f, ap); + va_end(ap); + if (ret < 0) { + // memory allocation failed -- should never happen? + return; + } - StorageDevmapperLogCallback(level, (char *)file, line, dm_errno_or_class, buffer); + StorageDevmapperLogCallback(level, (char *)file, line, dm_errno_or_class, buffer); + free(buffer); } static void log_with_errno_init() { - dm_log_with_errno_init(log_cb); + dm_log_with_errno_init(log_cb); } */ import "C" @@ -56,7 +62,6 @@ const ( var ( DmGetLibraryVersion = dmGetLibraryVersionFct DmGetNextTarget = dmGetNextTargetFct - DmLogInitVerbose = dmLogInitVerboseFct DmSetDevDir = dmSetDevDirFct DmTaskAddTarget = dmTaskAddTargetFct DmTaskCreate = dmTaskCreateFct @@ -226,10 +231,6 @@ func dmCookieSupportedFct() int { return int(C.dm_cookie_supported()) } -func dmLogInitVerboseFct(level int) { - C.dm_log_init_verbose(C.int(level)) -} - func logWithErrnoInitFct() { C.log_with_errno_init() } diff --git a/pkg/devicemapper/devmapper_wrapper_deferred_remove.go b/pkg/devicemapper/devmapper_wrapper_deferred_remove.go index dc361eab76..7f793c2708 100644 --- a/pkg/devicemapper/devmapper_wrapper_deferred_remove.go +++ 
b/pkg/devicemapper/devmapper_wrapper_deferred_remove.go @@ -1,14 +1,11 @@ -// +build linux,!libdm_no_deferred_remove +// +build linux,cgo,!libdm_no_deferred_remove package devicemapper -/* -#cgo LDFLAGS: -L. -ldevmapper -#include -*/ +// #include import "C" -// LibraryDeferredRemovalSupport is supported when statically linked. +// LibraryDeferredRemovalSupport tells if the feature is enabled in the build const LibraryDeferredRemovalSupport = true func dmTaskDeferredRemoveFct(task *cdmTask) int { diff --git a/pkg/devicemapper/devmapper_wrapper_dynamic.go b/pkg/devicemapper/devmapper_wrapper_dynamic.go new file mode 100644 index 0000000000..7d84508982 --- /dev/null +++ b/pkg/devicemapper/devmapper_wrapper_dynamic.go @@ -0,0 +1,6 @@ +// +build linux,cgo,!static_build + +package devicemapper + +// #cgo pkg-config: devmapper +import "C" diff --git a/pkg/devicemapper/devmapper_wrapper_no_deferred_remove.go b/pkg/devicemapper/devmapper_wrapper_no_deferred_remove.go index 4a6665de86..a880fec8c4 100644 --- a/pkg/devicemapper/devmapper_wrapper_no_deferred_remove.go +++ b/pkg/devicemapper/devmapper_wrapper_no_deferred_remove.go @@ -1,8 +1,8 @@ -// +build linux,libdm_no_deferred_remove +// +build linux,cgo,libdm_no_deferred_remove package devicemapper -// LibraryDeferredRemovalsupport is not supported when statically linked. 
+// LibraryDeferredRemovalSupport tells if the feature is enabled in the build const LibraryDeferredRemovalSupport = false func dmTaskDeferredRemoveFct(task *cdmTask) int { diff --git a/pkg/devicemapper/devmapper_wrapper_static.go b/pkg/devicemapper/devmapper_wrapper_static.go new file mode 100644 index 0000000000..cf7f26a4c6 --- /dev/null +++ b/pkg/devicemapper/devmapper_wrapper_static.go @@ -0,0 +1,6 @@ +// +build linux,cgo,static_build + +package devicemapper + +// #cgo pkg-config: --static devmapper +import "C" diff --git a/pkg/devicemapper/ioctl.go b/pkg/devicemapper/ioctl.go index 581b57eb86..50ea7c4823 100644 --- a/pkg/devicemapper/ioctl.go +++ b/pkg/devicemapper/ioctl.go @@ -1,15 +1,16 @@ -// +build linux +// +build linux,cgo package devicemapper import ( - "syscall" "unsafe" + + "golang.org/x/sys/unix" ) func ioctlBlkGetSize64(fd uintptr) (int64, error) { var size int64 - if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, fd, BlkGetSize64, uintptr(unsafe.Pointer(&size))); err != 0 { + if _, _, err := unix.Syscall(unix.SYS_IOCTL, fd, BlkGetSize64, uintptr(unsafe.Pointer(&size))); err != 0 { return 0, err } return size, nil @@ -20,7 +21,7 @@ func ioctlBlkDiscard(fd uintptr, offset, length uint64) error { r[0] = offset r[1] = length - if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, fd, BlkDiscard, uintptr(unsafe.Pointer(&r[0]))); err != 0 { + if _, _, err := unix.Syscall(unix.SYS_IOCTL, fd, BlkDiscard, uintptr(unsafe.Pointer(&r[0]))); err != 0 { return err } return nil diff --git a/pkg/fileutils/fileutils.go b/pkg/fileutils/fileutils.go index 6f4a6e6138..a129e654ea 100644 --- a/pkg/fileutils/fileutils.go +++ b/pkg/fileutils/fileutils.go @@ -13,98 +13,74 @@ import ( "github.com/sirupsen/logrus" ) -// exclusion returns true if the specified pattern is an exclusion -func exclusion(pattern string) bool { - return pattern[0] == '!' 
+// PatternMatcher allows checking paths against a list of patterns +type PatternMatcher struct { + patterns []*Pattern + exclusions bool } -// empty returns true if the specified pattern is empty -func empty(pattern string) bool { - return pattern == "" -} - -// CleanPatterns takes a slice of patterns returns a new -// slice of patterns cleaned with filepath.Clean, stripped -// of any empty patterns and lets the caller know whether the -// slice contains any exception patterns (prefixed with !). -func CleanPatterns(patterns []string) ([]string, [][]string, bool, error) { - // Loop over exclusion patterns and: - // 1. Clean them up. - // 2. Indicate whether we are dealing with any exception rules. - // 3. Error if we see a single exclusion marker on it's own (!). - cleanedPatterns := []string{} - patternDirs := [][]string{} - exceptions := false - for _, pattern := range patterns { +// NewPatternMatcher creates a new matcher object for specific patterns that can +// be used later to match paths against patterns +func NewPatternMatcher(patterns []string) (*PatternMatcher, error) { + pm := &PatternMatcher{ + patterns: make([]*Pattern, 0, len(patterns)), + } + for _, p := range patterns { // Eliminate leading and trailing whitespace. - pattern = strings.TrimSpace(pattern) - if empty(pattern) { + p = strings.TrimSpace(p) + if p == "" { continue } - if exclusion(pattern) { - if len(pattern) == 1 { - return nil, nil, false, errors.New("Illegal exclusion pattern: !") + p = filepath.Clean(p) + newp := &Pattern{} + if p[0] == '!' { + if len(p) == 1 { + return nil, errors.New("illegal exclusion pattern: \"!\"") } - exceptions = true + newp.exclusion = true + p = p[1:] + pm.exclusions = true } - pattern = filepath.Clean(pattern) - cleanedPatterns = append(cleanedPatterns, pattern) - if exclusion(pattern) { - pattern = pattern[1:] + // Do some syntax checking on the pattern. 
+ // filepath's Match() has some really weird rules that are inconsistent + // so instead of trying to dup their logic, just call Match() for its + // error state and if there is an error in the pattern return it. + // If this becomes an issue we can remove this since its really only + // needed in the error (syntax) case - which isn't really critical. + if _, err := filepath.Match(p, "."); err != nil { + return nil, err } - patternDirs = append(patternDirs, strings.Split(pattern, string(os.PathSeparator))) + newp.cleanedPattern = p + newp.dirs = strings.Split(p, string(os.PathSeparator)) + pm.patterns = append(pm.patterns, newp) } - - return cleanedPatterns, patternDirs, exceptions, nil + return pm, nil } -// Matches returns true if file matches any of the patterns -// and isn't excluded by any of the subsequent patterns. -func Matches(file string, patterns []string) (bool, error) { - file = filepath.Clean(file) - - if file == "." { - // Don't let them exclude everything, kind of silly. - return false, nil - } - - patterns, patDirs, _, err := CleanPatterns(patterns) - if err != nil { - return false, err - } - - return OptimizedMatches(file, patterns, patDirs) -} - -// OptimizedMatches is basically the same as fileutils.Matches() but optimized for archive.go. -// It will assume that the inputs have been preprocessed and therefore the function -// doesn't need to do as much error checking and clean-up. This was done to avoid -// repeating these steps on each file being checked during the archive process. -// The more generic fileutils.Matches() can't make these assumptions. -func OptimizedMatches(file string, patterns []string, patDirs [][]string) (bool, error) { +// Matches matches path against all the patterns. 
Matches is not safe to be +// called concurrently +func (pm *PatternMatcher) Matches(file string) (bool, error) { matched := false file = filepath.FromSlash(file) parentPath := filepath.Dir(file) parentPathDirs := strings.Split(parentPath, string(os.PathSeparator)) - for i, pattern := range patterns { + for _, pattern := range pm.patterns { negative := false - if exclusion(pattern) { + if pattern.exclusion { negative = true - pattern = pattern[1:] } - match, err := regexpMatch(pattern, file) + match, err := pattern.match(file) if err != nil { - return false, fmt.Errorf("Error in pattern (%s): %s", pattern, err) + return false, err } if !match && parentPath != "." { // Check to see if the pattern matches one of our parent dirs. - if len(patDirs[i]) <= len(parentPathDirs) { - match, _ = regexpMatch(strings.Join(patDirs[i], string(os.PathSeparator)), - strings.Join(parentPathDirs[:len(patDirs[i])], string(os.PathSeparator))) + if len(pattern.dirs) <= len(parentPathDirs) { + match, _ = pattern.match(strings.Join(parentPathDirs[:len(pattern.dirs)], string(os.PathSeparator))) } } @@ -120,28 +96,49 @@ func OptimizedMatches(file string, patterns []string, patDirs [][]string) (bool, return matched, nil } -// regexpMatch tries to match the logic of filepath.Match but -// does so using regexp logic. We do this so that we can expand the -// wildcard set to include other things, like "**" to mean any number -// of directories. This means that we should be backwards compatible -// with filepath.Match(). We'll end up supporting more stuff, due to -// the fact that we're using regexp, but that's ok - it does no harm. -// -// As per the comment in golangs filepath.Match, on Windows, escaping -// is disabled. Instead, '\\' is treated as path separator. 
-func regexpMatch(pattern, path string) (bool, error) { - regStr := "^" +// Exclusions returns true if any of the patterns define exclusions +func (pm *PatternMatcher) Exclusions() bool { + return pm.exclusions +} - // Do some syntax checking on the pattern. - // filepath's Match() has some really weird rules that are inconsistent - // so instead of trying to dup their logic, just call Match() for its - // error state and if there is an error in the pattern return it. - // If this becomes an issue we can remove this since its really only - // needed in the error (syntax) case - which isn't really critical. - if _, err := filepath.Match(pattern, path); err != nil { - return false, err +// Patterns returns array of active patterns +func (pm *PatternMatcher) Patterns() []*Pattern { + return pm.patterns +} + +// Pattern defines a single regexp used to filter file paths. +type Pattern struct { + cleanedPattern string + dirs []string + regexp *regexp.Regexp + exclusion bool +} + +func (p *Pattern) String() string { + return p.cleanedPattern +} + +// Exclusion returns true if this pattern defines exclusion +func (p *Pattern) Exclusion() bool { + return p.exclusion +} + +func (p *Pattern) match(path string) (bool, error) { + + if p.regexp == nil { + if err := p.compile(); err != nil { + return false, filepath.ErrBadPattern + } } + b := p.regexp.MatchString(path) + + return b, nil +} + +func (p *Pattern) compile() error { + regStr := "^" + pattern := p.cleanedPattern // Go through the pattern and convert it to a regexp. // We use a scanner so we can support utf-8 chars. 
var scan scanner.Scanner @@ -161,17 +158,19 @@ func regexpMatch(pattern, path string) (bool, error) { // is some flavor of "**" scan.Next() + // Treat **/ as ** so eat the "/" + if string(scan.Peek()) == sl { + scan.Next() + } + if scan.Peek() == scanner.EOF { // is "**EOF" - to align with .gitignore just accept all regStr += ".*" } else { // is "**" - regStr += "((.*" + escSL + ")|([^" + escSL + "]*))" - } - - // Treat **/ as ** so eat the "/" - if string(scan.Peek()) == sl { - scan.Next() + // Note that this allows for any # of /'s (even 0) because + // the .* will eat everything, even /'s + regStr += "(.*" + escSL + ")?" } } else { // is "*" so map it to anything but "/" @@ -180,7 +179,7 @@ func regexpMatch(pattern, path string) (bool, error) { } else if ch == '?' { // "?" is any char except "/" regStr += "[^" + escSL + "]" - } else if strings.Index(".$", string(ch)) != -1 { + } else if ch == '.' || ch == '$' { // Escape some regexp special chars that have no meaning // in golang's filepath.Match regStr += `\` + string(ch) @@ -206,14 +205,30 @@ func regexpMatch(pattern, path string) (bool, error) { regStr += "$" - res, err := regexp.MatchString(regStr, path) + re, err := regexp.Compile(regStr) + if err != nil { + return err + } + + p.regexp = re + return nil +} - // Map regexp's error to filepath's so no one knows we're not using filepath +// Matches returns true if file matches any of the patterns +// and isn't excluded by any of the subsequent patterns. +func Matches(file string, patterns []string) (bool, error) { + pm, err := NewPatternMatcher(patterns) if err != nil { - err = filepath.ErrBadPattern + return false, err + } + file = filepath.Clean(file) + + if file == "." { + // Don't let them exclude everything, kind of silly. 
+ return false, nil } - return res, err + return pm.Matches(file) } // CopyFile copies from src to dst until either EOF is reached diff --git a/pkg/fileutils/fileutils_test.go b/pkg/fileutils/fileutils_test.go index 6df1be89bb..a70d8911fa 100644 --- a/pkg/fileutils/fileutils_test.go +++ b/pkg/fileutils/fileutils_test.go @@ -1,6 +1,7 @@ package fileutils import ( + "fmt" "io/ioutil" "os" "path" @@ -8,11 +9,14 @@ import ( "runtime" "strings" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // CopyFile with invalid src func TestCopyFileWithInvalidSrc(t *testing.T) { - tempFolder, err := ioutil.TempDir("", "docker-fileutils-test") + tempFolder, err := ioutil.TempDir("", "storage-fileutils-test") defer os.RemoveAll(tempFolder) if err != nil { t.Fatal(err) @@ -29,7 +33,7 @@ func TestCopyFileWithInvalidSrc(t *testing.T) { // CopyFile with invalid dest func TestCopyFileWithInvalidDest(t *testing.T) { - tempFolder, err := ioutil.TempDir("", "docker-fileutils-test") + tempFolder, err := ioutil.TempDir("", "storage-fileutils-test") defer os.RemoveAll(tempFolder) if err != nil { t.Fatal(err) @@ -51,7 +55,7 @@ func TestCopyFileWithInvalidDest(t *testing.T) { // CopyFile with same src and dest func TestCopyFileWithSameSrcAndDest(t *testing.T) { - tempFolder, err := ioutil.TempDir("", "docker-fileutils-test") + tempFolder, err := ioutil.TempDir("", "storage-fileutils-test") defer os.RemoveAll(tempFolder) if err != nil { t.Fatal(err) @@ -72,7 +76,7 @@ func TestCopyFileWithSameSrcAndDest(t *testing.T) { // CopyFile with same src and dest but path is different and not clean func TestCopyFileWithSameSrcAndDestWithPathNameDifferent(t *testing.T) { - tempFolder, err := ioutil.TempDir("", "docker-fileutils-test") + tempFolder, err := ioutil.TempDir("", "storage-fileutils-test") defer os.RemoveAll(tempFolder) if err != nil { t.Fatal(err) @@ -98,7 +102,7 @@ func TestCopyFileWithSameSrcAndDestWithPathNameDifferent(t *testing.T) { } func 
TestCopyFile(t *testing.T) { - tempFolder, err := ioutil.TempDir("", "docker-fileutils-test") + tempFolder, err := ioutil.TempDir("", "storage-fileutils-test") defer os.RemoveAll(tempFolder) if err != nil { t.Fatal(err) @@ -208,7 +212,7 @@ func TestReadSymlinkedDirectoryToFile(t *testing.T) { func TestWildcardMatches(t *testing.T) { match, _ := Matches("fileutils.go", []string{"*"}) - if match != true { + if !match { t.Errorf("failed to get a wildcard match, got %v", match) } } @@ -216,7 +220,7 @@ func TestWildcardMatches(t *testing.T) { // A simple pattern match should return true. func TestPatternMatches(t *testing.T) { match, _ := Matches("fileutils.go", []string{"*.go"}) - if match != true { + if !match { t.Errorf("failed to get a match, got %v", match) } } @@ -224,7 +228,7 @@ func TestPatternMatches(t *testing.T) { // An exclusion followed by an inclusion should return true. func TestExclusionPatternMatchesPatternBefore(t *testing.T) { match, _ := Matches("fileutils.go", []string{"!fileutils.go", "*.go"}) - if match != true { + if !match { t.Errorf("failed to get true match on exclusion pattern, got %v", match) } } @@ -232,7 +236,7 @@ func TestExclusionPatternMatchesPatternBefore(t *testing.T) { // A folder pattern followed by an exception should return false. func TestPatternMatchesFolderExclusions(t *testing.T) { match, _ := Matches("docs/README.md", []string{"docs", "!docs/README.md"}) - if match != false { + if match { t.Errorf("failed to get a false match on exclusion pattern, got %v", match) } } @@ -240,7 +244,7 @@ func TestPatternMatchesFolderExclusions(t *testing.T) { // A folder pattern followed by an exception should return false. 
func TestPatternMatchesFolderWithSlashExclusions(t *testing.T) { match, _ := Matches("docs/README.md", []string{"docs/", "!docs/README.md"}) - if match != false { + if match { t.Errorf("failed to get a false match on exclusion pattern, got %v", match) } } @@ -248,7 +252,7 @@ func TestPatternMatchesFolderWithSlashExclusions(t *testing.T) { // A folder pattern followed by an exception should return false. func TestPatternMatchesFolderWildcardExclusions(t *testing.T) { match, _ := Matches("docs/README.md", []string{"docs/*", "!docs/README.md"}) - if match != false { + if match { t.Errorf("failed to get a false match on exclusion pattern, got %v", match) } } @@ -256,7 +260,7 @@ func TestPatternMatchesFolderWildcardExclusions(t *testing.T) { // A pattern followed by an exclusion should return false. func TestExclusionPatternMatchesPatternAfter(t *testing.T) { match, _ := Matches("fileutils.go", []string{"*.go", "!fileutils.go"}) - if match != false { + if match { t.Errorf("failed to get false match on exclusion pattern, got %v", match) } } @@ -264,7 +268,7 @@ func TestExclusionPatternMatchesPatternAfter(t *testing.T) { // A filename evaluating to . should return false. func TestExclusionPatternMatchesWholeDirectory(t *testing.T) { match, _ := Matches(".", []string{"*.go"}) - if match != false { + if match { t.Errorf("failed to get false match on ., got %v", match) } } @@ -277,14 +281,6 @@ func TestSingleExclamationError(t *testing.T) { } } -// A string preceded with a ! should return true from Exclusion. 
-func TestExclusion(t *testing.T) { - exclusion := exclusion("!") - if !exclusion { - t.Errorf("failed to get true for a single !, got %v", exclusion) - } -} - // Matches with no patterns func TestMatchesWithNoPatterns(t *testing.T) { matches, err := Matches("/any/path/there", []string{}) @@ -307,17 +303,14 @@ func TestMatchesWithMalformedPatterns(t *testing.T) { } } -// Test lots of variants of patterns & strings +type matchesTestCase struct { + pattern string + text string + pass bool +} + func TestMatches(t *testing.T) { - // TODO Windows: Port this test - if runtime.GOOS == "windows" { - t.Skip("Needs porting to Windows") - } - tests := []struct { - pattern string - text string - pass bool - }{ + tests := []matchesTestCase{ {"**", "file", true}, {"**", "file/", true}, {"**/", "file", true}, // weird one @@ -325,7 +318,7 @@ func TestMatches(t *testing.T) { {"**", "/", true}, {"**/", "/", true}, {"**", "dir/file", true}, - {"**/", "dir/file", false}, + {"**/", "dir/file", true}, {"**", "dir/file/", true}, {"**/", "dir/file/", true}, {"**/**", "dir/file", true}, @@ -335,7 +328,7 @@ func TestMatches(t *testing.T) { {"dir/**", "dir/dir2/file", true}, {"dir/**", "dir/dir2/file/", true}, {"**/dir2/*", "dir/dir2/file", true}, - {"**/dir2/*", "dir/dir2/file/", false}, + {"**/dir2/*", "dir/dir2/file/", true}, {"**/dir2/**", "dir/dir2/dir3/file", true}, {"**/dir2/**", "dir/dir2/dir3/file/", true}, {"**file", "file", true}, @@ -369,9 +362,6 @@ func TestMatches(t *testing.T) { {"abc.def", "abcZdef", false}, {"abc?def", "abcZdef", true}, {"abc?def", "abcdef", false}, - {"a\\*b", "a*b", true}, - {"a\\", "a", false}, - {"a\\", "a\\", false}, {"a\\\\", "a\\", true}, {"**/foo/bar", "foo/bar", true}, {"**/foo/bar", "dir/foo/bar", true}, @@ -379,78 +369,94 @@ func TestMatches(t *testing.T) { {"abc/**", "abc", false}, {"abc/**", "abc/def", true}, {"abc/**", "abc/def/ghi", true}, + {"**/.foo", ".foo", true}, + {"**/.foo", "bar.foo", false}, } - for _, test := range tests { - res, _ 
:= regexpMatch(test.pattern, test.text) - if res != test.pass { - t.Fatalf("Failed: %v - res:%v", test, res) - } + if runtime.GOOS != "windows" { + tests = append(tests, []matchesTestCase{ + {"a\\*b", "a*b", true}, + {"a\\", "a", false}, + {"a\\", "a\\", false}, + }...) } -} -// An empty string should return true from Empty. -func TestEmpty(t *testing.T) { - empty := empty("") - if !empty { - t.Errorf("failed to get true for an empty string, got %v", empty) + for _, test := range tests { + desc := fmt.Sprintf("pattern=%q text=%q", test.pattern, test.text) + pm, err := NewPatternMatcher([]string{test.pattern}) + require.NoError(t, err, desc) + res, _ := pm.Matches(test.text) + assert.Equal(t, test.pass, res, desc) } } func TestCleanPatterns(t *testing.T) { - cleaned, _, _, _ := CleanPatterns([]string{"docs", "config"}) + patterns := []string{"docs", "config"} + pm, err := NewPatternMatcher(patterns) + if err != nil { + t.Fatalf("invalid pattern %v", patterns) + } + cleaned := pm.Patterns() if len(cleaned) != 2 { t.Errorf("expected 2 element slice, got %v", len(cleaned)) } } func TestCleanPatternsStripEmptyPatterns(t *testing.T) { - cleaned, _, _, _ := CleanPatterns([]string{"docs", "config", ""}) + patterns := []string{"docs", "config", ""} + pm, err := NewPatternMatcher(patterns) + if err != nil { + t.Fatalf("invalid pattern %v", patterns) + } + cleaned := pm.Patterns() if len(cleaned) != 2 { t.Errorf("expected 2 element slice, got %v", len(cleaned)) } } func TestCleanPatternsExceptionFlag(t *testing.T) { - _, _, exceptions, _ := CleanPatterns([]string{"docs", "!docs/README.md"}) - if !exceptions { - t.Errorf("expected exceptions to be true, got %v", exceptions) + patterns := []string{"docs", "!docs/README.md"} + pm, err := NewPatternMatcher(patterns) + if err != nil { + t.Fatalf("invalid pattern %v", patterns) + } + if !pm.Exclusions() { + t.Errorf("expected exceptions to be true, got %v", pm.Exclusions()) } } func TestCleanPatternsLeadingSpaceTrimmed(t 
*testing.T) { - _, _, exceptions, _ := CleanPatterns([]string{"docs", " !docs/README.md"}) - if !exceptions { - t.Errorf("expected exceptions to be true, got %v", exceptions) + patterns := []string{"docs", " !docs/README.md"} + pm, err := NewPatternMatcher(patterns) + if err != nil { + t.Fatalf("invalid pattern %v", patterns) + } + if !pm.Exclusions() { + t.Errorf("expected exceptions to be true, got %v", pm.Exclusions()) } } func TestCleanPatternsTrailingSpaceTrimmed(t *testing.T) { - _, _, exceptions, _ := CleanPatterns([]string{"docs", "!docs/README.md "}) - if !exceptions { - t.Errorf("expected exceptions to be true, got %v", exceptions) + patterns := []string{"docs", "!docs/README.md "} + pm, err := NewPatternMatcher(patterns) + if err != nil { + t.Fatalf("invalid pattern %v", patterns) + } + if !pm.Exclusions() { + t.Errorf("expected exceptions to be true, got %v", pm.Exclusions()) } } func TestCleanPatternsErrorSingleException(t *testing.T) { - _, _, _, err := CleanPatterns([]string{"!"}) + patterns := []string{"!"} + _, err := NewPatternMatcher(patterns) if err == nil { t.Errorf("expected error on single exclamation point, got %v", err) } } -func TestCleanPatternsFolderSplit(t *testing.T) { - _, dirs, _, _ := CleanPatterns([]string{"docs/config/CONFIG.md"}) - if dirs[0][0] != "docs" { - t.Errorf("expected first element in dirs slice to be docs, got %v", dirs[0][1]) - } - if dirs[0][1] != "config" { - t.Errorf("expected first element in dirs slice to be config, got %v", dirs[0][1]) - } -} - func TestCreateIfNotExistsDir(t *testing.T) { - tempFolder, err := ioutil.TempDir("", "docker-fileutils-test") + tempFolder, err := ioutil.TempDir("", "storage-fileutils-test") if err != nil { t.Fatal(err) } @@ -472,7 +478,7 @@ func TestCreateIfNotExistsDir(t *testing.T) { } func TestCreateIfNotExistsFile(t *testing.T) { - tempFolder, err := ioutil.TempDir("", "docker-fileutils-test") + tempFolder, err := ioutil.TempDir("", "storage-fileutils-test") if err != nil { 
t.Fatal(err) } @@ -506,7 +512,7 @@ var matchTests = []matchTest{ {"*c", "abc", true, nil}, {"a*", "a", true, nil}, {"a*", "abc", true, nil}, - {"a*", "ab/c", false, nil}, + {"a*", "ab/c", true, nil}, {"a*/b", "abc/b", true, nil}, {"a*/b", "a/c/b", false, nil}, {"a*b*c*d*e*/f", "axbxcxdxe/f", true, nil}, @@ -570,14 +576,14 @@ func TestMatch(t *testing.T) { pattern := tt.pattern s := tt.s if runtime.GOOS == "windows" { - if strings.Index(pattern, "\\") >= 0 { + if strings.Contains(pattern, "\\") { // no escape allowed on windows. continue } pattern = filepath.Clean(pattern) s = filepath.Clean(s) } - ok, err := regexpMatch(pattern, s) + ok, err := Matches(s, []string{pattern}) if ok != tt.match || err != tt.err { t.Fatalf("Match(%#q, %#q) = %v, %q want %v, %q", pattern, s, ok, errp(err), tt.match, errp(tt.err)) } diff --git a/pkg/fsutils/fsutils_linux.go b/pkg/fsutils/fsutils_linux.go new file mode 100644 index 0000000000..e6094b55b7 --- /dev/null +++ b/pkg/fsutils/fsutils_linux.go @@ -0,0 +1,88 @@ +// +build linux + +package fsutils + +import ( + "fmt" + "io/ioutil" + "os" + "unsafe" + + "golang.org/x/sys/unix" +) + +func locateDummyIfEmpty(path string) (string, error) { + children, err := ioutil.ReadDir(path) + if err != nil { + return "", err + } + if len(children) != 0 { + return "", nil + } + dummyFile, err := ioutil.TempFile(path, "fsutils-dummy") + if err != nil { + return "", err + } + name := dummyFile.Name() + err = dummyFile.Close() + return name, err +} + +// SupportsDType returns whether the filesystem mounted on path supports d_type +func SupportsDType(path string) (bool, error) { + // locate dummy so that we have at least one dirent + dummy, err := locateDummyIfEmpty(path) + if err != nil { + return false, err + } + if dummy != "" { + defer os.Remove(dummy) + } + + visited := 0 + supportsDType := true + fn := func(ent *unix.Dirent) bool { + visited++ + if ent.Type == unix.DT_UNKNOWN { + supportsDType = false + // stop iteration + return true + } + // 
continue iteration + return false + } + if err = iterateReadDir(path, fn); err != nil { + return false, err + } + if visited == 0 { + return false, fmt.Errorf("did not hit any dirent during iteration %s", path) + } + return supportsDType, nil +} + +func iterateReadDir(path string, fn func(*unix.Dirent) bool) error { + d, err := os.Open(path) + if err != nil { + return err + } + defer d.Close() + fd := int(d.Fd()) + buf := make([]byte, 4096) + for { + nbytes, err := unix.ReadDirent(fd, buf) + if err != nil { + return err + } + if nbytes == 0 { + break + } + for off := 0; off < nbytes; { + ent := (*unix.Dirent)(unsafe.Pointer(&buf[off])) + if stop := fn(ent); stop { + return nil + } + off += int(ent.Reclen) + } + } + return nil +} diff --git a/pkg/homedir/homedir_linux.go b/pkg/homedir/homedir_linux.go new file mode 100644 index 0000000000..c001fbecbf --- /dev/null +++ b/pkg/homedir/homedir_linux.go @@ -0,0 +1,23 @@ +// +build linux + +package homedir + +import ( + "os" + + "github.com/containers/storage/pkg/idtools" +) + +// GetStatic returns the home directory for the current user without calling +// os/user.Current(). This is useful for static-linked binary on glibc-based +// system, because a call to os/user.Current() in a static binary leads to +// segfault due to a glibc issue that won't be fixed in a short term. +// (#29344, golang/go#13470, https://sourceware.org/bugzilla/show_bug.cgi?id=19341) +func GetStatic() (string, error) { + uid := os.Getuid() + usr, err := idtools.LookupUID(uid) + if err != nil { + return "", err + } + return usr.Home, nil +} diff --git a/pkg/homedir/homedir_others.go b/pkg/homedir/homedir_others.go new file mode 100644 index 0000000000..6b96b856f6 --- /dev/null +++ b/pkg/homedir/homedir_others.go @@ -0,0 +1,13 @@ +// +build !linux + +package homedir + +import ( + "errors" +) + +// GetStatic is not needed for non-linux systems. +// (Precisely, it is needed only for glibc-based linux systems.) 
+func GetStatic() (string, error) { + return "", errors.New("homedir.GetStatic() is not supported on this system") +} diff --git a/pkg/homedir/homedir.go b/pkg/homedir/homedir_unix.go similarity index 77% rename from pkg/homedir/homedir.go rename to pkg/homedir/homedir_unix.go index 8154e83f0c..f2a20ea8f8 100644 --- a/pkg/homedir/homedir.go +++ b/pkg/homedir/homedir_unix.go @@ -1,8 +1,9 @@ +// +build !windows + package homedir import ( "os" - "runtime" "github.com/opencontainers/runc/libcontainer/user" ) @@ -10,9 +11,6 @@ import ( // Key returns the env var name for the user's home dir based on // the platform being run on func Key() string { - if runtime.GOOS == "windows" { - return "USERPROFILE" - } return "HOME" } @@ -21,7 +19,7 @@ func Key() string { // Returned path should be used with "path/filepath" to form new paths. func Get() string { home := os.Getenv(Key()) - if home == "" && runtime.GOOS != "windows" { + if home == "" { if u, err := user.CurrentUser(); err == nil { return u.Home } @@ -32,8 +30,5 @@ func Get() string { // GetShortcutString returns the string that is shortcut to user's home directory // in the native shell of the platform running on. func GetShortcutString() string { - if runtime.GOOS == "windows" { - return "%USERPROFILE%" // be careful while using in format functions - } return "~" } diff --git a/pkg/homedir/homedir_windows.go b/pkg/homedir/homedir_windows.go new file mode 100644 index 0000000000..fafdb2bbf9 --- /dev/null +++ b/pkg/homedir/homedir_windows.go @@ -0,0 +1,24 @@ +package homedir + +import ( + "os" +) + +// Key returns the env var name for the user's home dir based on +// the platform being run on +func Key() string { + return "USERPROFILE" +} + +// Get returns the home directory of the current user with the help of +// environment variables depending on the target operating system. +// Returned path should be used with "path/filepath" to form new paths. 
+func Get() string { + return os.Getenv(Key()) +} + +// GetShortcutString returns the string that is shortcut to user's home directory +// in the native shell of the platform running on. +func GetShortcutString() string { + return "%USERPROFILE%" // be careful while using in format functions +} diff --git a/pkg/idtools/idtools.go b/pkg/idtools/idtools.go index 6bca466286..68a072db22 100644 --- a/pkg/idtools/idtools.go +++ b/pkg/idtools/idtools.go @@ -37,49 +37,56 @@ const ( // MkdirAllAs creates a directory (include any along the path) and then modifies // ownership to the requested uid/gid. If the directory already exists, this // function will still change ownership to the requested uid/gid pair. +// Deprecated: Use MkdirAllAndChown func MkdirAllAs(path string, mode os.FileMode, ownerUID, ownerGID int) error { return mkdirAs(path, mode, ownerUID, ownerGID, true, true) } -// MkdirAllNewAs creates a directory (include any along the path) and then modifies -// ownership ONLY of newly created directories to the requested uid/gid. If the -// directories along the path exist, no change of ownership will be performed -func MkdirAllNewAs(path string, mode os.FileMode, ownerUID, ownerGID int) error { - return mkdirAs(path, mode, ownerUID, ownerGID, true, false) -} - // MkdirAs creates a directory and then modifies ownership to the requested uid/gid. // If the directory already exists, this function still changes ownership +// Deprecated: Use MkdirAndChown with a IDPair func MkdirAs(path string, mode os.FileMode, ownerUID, ownerGID int) error { return mkdirAs(path, mode, ownerUID, ownerGID, false, true) } +// MkdirAllAndChown creates a directory (include any along the path) and then modifies +// ownership to the requested uid/gid. If the directory already exists, this +// function will still change ownership to the requested uid/gid pair. 
+func MkdirAllAndChown(path string, mode os.FileMode, ids IDPair) error { + return mkdirAs(path, mode, ids.UID, ids.GID, true, true) +} + +// MkdirAndChown creates a directory and then modifies ownership to the requested uid/gid. +// If the directory already exists, this function still changes ownership +func MkdirAndChown(path string, mode os.FileMode, ids IDPair) error { + return mkdirAs(path, mode, ids.UID, ids.GID, false, true) +} + +// MkdirAllAndChownNew creates a directory (include any along the path) and then modifies +// ownership ONLY of newly created directories to the requested uid/gid. If the +// directories along the path exist, no change of ownership will be performed +func MkdirAllAndChownNew(path string, mode os.FileMode, ids IDPair) error { + return mkdirAs(path, mode, ids.UID, ids.GID, true, false) +} + // GetRootUIDGID retrieves the remapped root uid/gid pair from the set of maps. // If the maps are empty, then the root uid/gid will default to "real" 0/0 func GetRootUIDGID(uidMap, gidMap []IDMap) (int, int, error) { - var uid, gid int - - if uidMap != nil { - xUID, err := ToHost(0, uidMap) - if err != nil { - return -1, -1, err - } - uid = xUID + uid, err := toHost(0, uidMap) + if err != nil { + return -1, -1, err } - if gidMap != nil { - xGID, err := ToHost(0, gidMap) - if err != nil { - return -1, -1, err - } - gid = xGID + gid, err := toHost(0, gidMap) + if err != nil { + return -1, -1, err } return uid, gid, nil } -// ToContainer takes an id mapping, and uses it to translate a +// toContainer takes an id mapping, and uses it to translate a // host ID to the remapped ID. 
If no map is provided, then the translation // assumes a 1-to-1 mapping and returns the passed in id -func ToContainer(hostID int, idMap []IDMap) (int, error) { +func toContainer(hostID int, idMap []IDMap) (int, error) { if idMap == nil { return hostID, nil } @@ -92,10 +99,10 @@ func ToContainer(hostID int, idMap []IDMap) (int, error) { return -1, fmt.Errorf("Host ID %d cannot be mapped to a container ID", hostID) } -// ToHost takes an id mapping and a remapped ID, and translates the +// toHost takes an id mapping and a remapped ID, and translates the // ID to the mapped host ID. If no map is provided, then the translation // assumes a 1-to-1 mapping and returns the passed in id # -func ToHost(contID int, idMap []IDMap) (int, error) { +func toHost(contID int, idMap []IDMap) (int, error) { if idMap == nil { return contID, nil } @@ -108,26 +115,101 @@ func ToHost(contID int, idMap []IDMap) (int, error) { return -1, fmt.Errorf("Container ID %d cannot be mapped to a host ID", contID) } -// CreateIDMappings takes a requested user and group name and +// IDPair is a UID and GID pair +type IDPair struct { + UID int + GID int +} + +// IDMappings contains a mappings of UIDs and GIDs +type IDMappings struct { + uids []IDMap + gids []IDMap +} + +// NewIDMappings takes a requested user and group name and // using the data from /etc/sub{uid,gid} ranges, creates the // proper uid and gid remapping ranges for that user/group pair -func CreateIDMappings(username, groupname string) ([]IDMap, []IDMap, error) { +func NewIDMappings(username, groupname string) (*IDMappings, error) { subuidRanges, err := parseSubuid(username) if err != nil { - return nil, nil, err + return nil, err } subgidRanges, err := parseSubgid(groupname) if err != nil { - return nil, nil, err + return nil, err } if len(subuidRanges) == 0 { - return nil, nil, fmt.Errorf("No subuid ranges found for user %q", username) + return nil, fmt.Errorf("No subuid ranges found for user %q", username) } if len(subgidRanges) == 0 
{ - return nil, nil, fmt.Errorf("No subgid ranges found for group %q", groupname) + return nil, fmt.Errorf("No subgid ranges found for group %q", groupname) + } + + return &IDMappings{ + uids: createIDMap(subuidRanges), + gids: createIDMap(subgidRanges), + }, nil +} + +// NewIDMappingsFromMaps creates a new mapping from two slices +// Deprecated: this is a temporary shim while transitioning to IDMapping +func NewIDMappingsFromMaps(uids []IDMap, gids []IDMap) *IDMappings { + return &IDMappings{uids: uids, gids: gids} +} + +// RootPair returns a uid and gid pair for the root user. The error is ignored +// because a root user always exists, and the defaults are correct when the uid +// and gid maps are empty. +func (i *IDMappings) RootPair() IDPair { + uid, gid, _ := GetRootUIDGID(i.uids, i.gids) + return IDPair{UID: uid, GID: gid} +} + +// ToHost returns the host UID and GID for the container uid, gid. +// Remapping is only performed if the ids aren't already the remapped root ids +func (i *IDMappings) ToHost(pair IDPair) (IDPair, error) { + var err error + target := i.RootPair() + + if pair.UID != target.UID { + target.UID, err = toHost(pair.UID, i.uids) + if err != nil { + return target, err + } + } + + if pair.GID != target.GID { + target.GID, err = toHost(pair.GID, i.gids) + } + return target, err +} + +// ToContainer returns the container UID and GID for the host uid and gid +func (i *IDMappings) ToContainer(pair IDPair) (int, int, error) { + uid, err := toContainer(pair.UID, i.uids) + if err != nil { + return -1, -1, err } + gid, err := toContainer(pair.GID, i.gids) + return uid, gid, err +} + +// Empty returns true if there are no id mappings +func (i *IDMappings) Empty() bool { + return len(i.uids) == 0 && len(i.gids) == 0 +} + +// UIDs return the UID mapping +// TODO: remove this once everything has been refactored to use pairs +func (i *IDMappings) UIDs() []IDMap { + return i.uids +} - return createIDMap(subuidRanges), createIDMap(subgidRanges), nil +// 
GIDs return the UID mapping +// TODO: remove this once everything has been refactored to use pairs +func (i *IDMappings) GIDs() []IDMap { + return i.gids } func createIDMap(subidRanges ranges) []IDMap { diff --git a/pkg/idtools/idtools_unix.go b/pkg/idtools/idtools_unix.go index b2cfb05e4f..b5870506a0 100644 --- a/pkg/idtools/idtools_unix.go +++ b/pkg/idtools/idtools_unix.go @@ -3,10 +3,21 @@ package idtools import ( + "bytes" + "fmt" + "io" "os" "path/filepath" + "strings" + "sync" "github.com/containers/storage/pkg/system" + "github.com/opencontainers/runc/libcontainer/user" +) + +var ( + entOnce sync.Once + getentCmd string ) func mkdirAs(path string, mode os.FileMode, ownerUID, ownerGID int, mkAll, chownExisting bool) error { @@ -18,11 +29,8 @@ func mkdirAs(path string, mode os.FileMode, ownerUID, ownerGID int, mkAll, chown if _, err := os.Stat(path); err != nil && os.IsNotExist(err) { paths = []string{path} } else if err == nil && chownExisting { - if err := os.Chown(path, ownerUID, ownerGID); err != nil { - return err - } // short-circuit--we were called with an existing directory and chown was requested - return nil + return os.Chown(path, ownerUID, ownerGID) } else if err == nil { // nothing to do; directory path fully exists already and chown was NOT requested return nil @@ -41,7 +49,7 @@ func mkdirAs(path string, mode os.FileMode, ownerUID, ownerGID int, mkAll, chown paths = append(paths, dirPath) } } - if err := system.MkdirAll(path, mode); err != nil && !os.IsExist(err) { + if err := system.MkdirAll(path, mode, ""); err != nil && !os.IsExist(err) { return err } } else { @@ -58,3 +66,139 @@ func mkdirAs(path string, mode os.FileMode, ownerUID, ownerGID int, mkAll, chown } return nil } + +// CanAccess takes a valid (existing) directory and a uid, gid pair and determines +// if that uid, gid pair has access (execute bit) to the directory +func CanAccess(path string, pair IDPair) bool { + statInfo, err := system.Stat(path) + if err != nil { + return false + 
} + fileMode := os.FileMode(statInfo.Mode()) + permBits := fileMode.Perm() + return accessible(statInfo.UID() == uint32(pair.UID), + statInfo.GID() == uint32(pair.GID), permBits) +} + +func accessible(isOwner, isGroup bool, perms os.FileMode) bool { + if isOwner && (perms&0100 == 0100) { + return true + } + if isGroup && (perms&0010 == 0010) { + return true + } + if perms&0001 == 0001 { + return true + } + return false +} + +// LookupUser uses traditional local system files lookup (from libcontainer/user) on a username, +// followed by a call to `getent` for supporting host configured non-files passwd and group dbs +func LookupUser(username string) (user.User, error) { + // first try a local system files lookup using existing capabilities + usr, err := user.LookupUser(username) + if err == nil { + return usr, nil + } + // local files lookup failed; attempt to call `getent` to query configured passwd dbs + usr, err = getentUser(fmt.Sprintf("%s %s", "passwd", username)) + if err != nil { + return user.User{}, err + } + return usr, nil +} + +// LookupUID uses traditional local system files lookup (from libcontainer/user) on a uid, +// followed by a call to `getent` for supporting host configured non-files passwd and group dbs +func LookupUID(uid int) (user.User, error) { + // first try a local system files lookup using existing capabilities + usr, err := user.LookupUid(uid) + if err == nil { + return usr, nil + } + // local files lookup failed; attempt to call `getent` to query configured passwd dbs + return getentUser(fmt.Sprintf("%s %d", "passwd", uid)) +} + +func getentUser(args string) (user.User, error) { + reader, err := callGetent(args) + if err != nil { + return user.User{}, err + } + users, err := user.ParsePasswd(reader) + if err != nil { + return user.User{}, err + } + if len(users) == 0 { + return user.User{}, fmt.Errorf("getent failed to find passwd entry for %q", strings.Split(args, " ")[1]) + } + return users[0], nil +} + +// LookupGroup uses 
traditional local system files lookup (from libcontainer/user) on a group name, +// followed by a call to `getent` for supporting host configured non-files passwd and group dbs +func LookupGroup(groupname string) (user.Group, error) { + // first try a local system files lookup using existing capabilities + group, err := user.LookupGroup(groupname) + if err == nil { + return group, nil + } + // local files lookup failed; attempt to call `getent` to query configured group dbs + return getentGroup(fmt.Sprintf("%s %s", "group", groupname)) +} + +// LookupGID uses traditional local system files lookup (from libcontainer/user) on a group ID, +// followed by a call to `getent` for supporting host configured non-files passwd and group dbs +func LookupGID(gid int) (user.Group, error) { + // first try a local system files lookup using existing capabilities + group, err := user.LookupGid(gid) + if err == nil { + return group, nil + } + // local files lookup failed; attempt to call `getent` to query configured group dbs + return getentGroup(fmt.Sprintf("%s %d", "group", gid)) +} + +func getentGroup(args string) (user.Group, error) { + reader, err := callGetent(args) + if err != nil { + return user.Group{}, err + } + groups, err := user.ParseGroup(reader) + if err != nil { + return user.Group{}, err + } + if len(groups) == 0 { + return user.Group{}, fmt.Errorf("getent failed to find groups entry for %q", strings.Split(args, " ")[1]) + } + return groups[0], nil +} + +func callGetent(args string) (io.Reader, error) { + entOnce.Do(func() { getentCmd, _ = resolveBinary("getent") }) + // if no `getent` command on host, can't do anything else + if getentCmd == "" { + return nil, fmt.Errorf("") + } + out, err := execCmd(getentCmd, args) + if err != nil { + exitCode, errC := system.GetExitCode(err) + if errC != nil { + return nil, err + } + switch exitCode { + case 1: + return nil, fmt.Errorf("getent reported invalid parameters/database unknown") + case 2: + terms := 
strings.Split(args, " ") + return nil, fmt.Errorf("getent unable to find entry %q in %s database", terms[1], terms[0]) + case 3: + return nil, fmt.Errorf("getent database doesn't support enumeration") + default: + return nil, err + } + + } + return bytes.NewReader(out), nil +} diff --git a/pkg/idtools/idtools_unix_test.go b/pkg/idtools/idtools_unix_test.go index 540d3079ee..2463342a65 100644 --- a/pkg/idtools/idtools_unix_test.go +++ b/pkg/idtools/idtools_unix_test.go @@ -7,8 +7,10 @@ import ( "io/ioutil" "os" "path/filepath" - "syscall" "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/sys/unix" ) type node struct { @@ -76,12 +78,9 @@ func TestMkdirAllAs(t *testing.T) { } } -func TestMkdirAllNewAs(t *testing.T) { - +func TestMkdirAllAndChownNew(t *testing.T) { dirName, err := ioutil.TempDir("", "mkdirnew") - if err != nil { - t.Fatalf("Couldn't create temp dir: %v", err) - } + require.NoError(t, err) defer os.RemoveAll(dirName) testTree := map[string]node{ @@ -91,49 +90,32 @@ func TestMkdirAllNewAs(t *testing.T) { "lib/x86_64": {45, 45}, "lib/x86_64/share": {1, 1}, } - - if err := buildTree(dirName, testTree); err != nil { - t.Fatal(err) - } + require.NoError(t, buildTree(dirName, testTree)) // test adding a directory to a pre-existing dir; only the new dir is owned by the uid/gid - if err := MkdirAllNewAs(filepath.Join(dirName, "usr", "share"), 0755, 99, 99); err != nil { - t.Fatal(err) - } + err = MkdirAllAndChownNew(filepath.Join(dirName, "usr", "share"), 0755, IDPair{99, 99}) + require.NoError(t, err) + testTree["usr/share"] = node{99, 99} verifyTree, err := readTree(dirName, "") - if err != nil { - t.Fatal(err) - } - if err := compareTrees(testTree, verifyTree); err != nil { - t.Fatal(err) - } + require.NoError(t, err) + require.NoError(t, compareTrees(testTree, verifyTree)) // test 2-deep new directories--both should be owned by the uid/gid pair - if err := MkdirAllNewAs(filepath.Join(dirName, "lib", "some", "other"), 0755, 101, 101); err 
!= nil { - t.Fatal(err) - } + err = MkdirAllAndChownNew(filepath.Join(dirName, "lib", "some", "other"), 0755, IDPair{101, 101}) + require.NoError(t, err) testTree["lib/some"] = node{101, 101} testTree["lib/some/other"] = node{101, 101} verifyTree, err = readTree(dirName, "") - if err != nil { - t.Fatal(err) - } - if err := compareTrees(testTree, verifyTree); err != nil { - t.Fatal(err) - } + require.NoError(t, err) + require.NoError(t, compareTrees(testTree, verifyTree)) // test a directory that already exists; should NOT be chowned - if err := MkdirAllNewAs(filepath.Join(dirName, "usr"), 0755, 102, 102); err != nil { - t.Fatal(err) - } + err = MkdirAllAndChownNew(filepath.Join(dirName, "usr"), 0755, IDPair{102, 102}) + require.NoError(t, err) verifyTree, err = readTree(dirName, "") - if err != nil { - t.Fatal(err) - } - if err := compareTrees(testTree, verifyTree); err != nil { - t.Fatal(err) - } + require.NoError(t, err) + require.NoError(t, compareTrees(testTree, verifyTree)) } func TestMkdirAs(t *testing.T) { @@ -205,8 +187,8 @@ func readTree(base, root string) (map[string]node, error) { } for _, info := range dirInfos { - s := &syscall.Stat_t{} - if err := syscall.Stat(filepath.Join(base, info.Name()), s); err != nil { + s := &unix.Stat_t{} + if err := unix.Stat(filepath.Join(base, info.Name()), s); err != nil { return nil, fmt.Errorf("Can't stat file %q: %v", filepath.Join(base, info.Name()), err) } tree[filepath.Join(root, info.Name())] = node{int(s.Uid), int(s.Gid)} diff --git a/pkg/idtools/idtools_windows.go b/pkg/idtools/idtools_windows.go index 0cad173669..dbf6bc4c94 100644 --- a/pkg/idtools/idtools_windows.go +++ b/pkg/idtools/idtools_windows.go @@ -11,8 +11,15 @@ import ( // Platforms such as Windows do not support the UID/GID concept. So make this // just a wrapper around system.MkdirAll. 
func mkdirAs(path string, mode os.FileMode, ownerUID, ownerGID int, mkAll, chownExisting bool) error { - if err := system.MkdirAll(path, mode); err != nil && !os.IsExist(err) { + if err := system.MkdirAll(path, mode, ""); err != nil && !os.IsExist(err) { return err } return nil } + +// CanAccess takes a valid (existing) directory and a uid, gid pair and determines +// if that uid, gid pair has access (execute bit) to the directory +// Windows does not require/support this function, so always return true +func CanAccess(path string, pair IDPair) bool { + return true +} diff --git a/pkg/idtools/usergroupadd_linux.go b/pkg/idtools/usergroupadd_linux.go index 4a4aaed04d..9da7975e2c 100644 --- a/pkg/idtools/usergroupadd_linux.go +++ b/pkg/idtools/usergroupadd_linux.go @@ -2,8 +2,6 @@ package idtools import ( "fmt" - "os/exec" - "path/filepath" "regexp" "sort" "strconv" @@ -33,23 +31,6 @@ var ( userMod = "usermod" ) -func resolveBinary(binname string) (string, error) { - binaryPath, err := exec.LookPath(binname) - if err != nil { - return "", err - } - resolvedPath, err := filepath.EvalSymlinks(binaryPath) - if err != nil { - return "", err - } - //only return no error if the final resolved binary basename - //matches what was searched for - if filepath.Base(resolvedPath) == binname { - return resolvedPath, nil - } - return "", fmt.Errorf("Binary %q does not resolve to a binary of that name in $PATH (%q)", binname, resolvedPath) -} - // AddNamespaceRangesUser takes a username and uses the standard system // utility to create a system user/group pair used to hold the // /etc/sub{uid,gid} ranges which will be used for user namespace @@ -181,8 +162,3 @@ func wouldOverlap(arange subIDRange, ID int) bool { } return false } - -func execCmd(cmd, args string) ([]byte, error) { - execCmd := exec.Command(cmd, strings.Split(args, " ")...) 
- return execCmd.CombinedOutput() -} diff --git a/pkg/idtools/utils_unix.go b/pkg/idtools/utils_unix.go new file mode 100644 index 0000000000..9703ecbd9d --- /dev/null +++ b/pkg/idtools/utils_unix.go @@ -0,0 +1,32 @@ +// +build !windows + +package idtools + +import ( + "fmt" + "os/exec" + "path/filepath" + "strings" +) + +func resolveBinary(binname string) (string, error) { + binaryPath, err := exec.LookPath(binname) + if err != nil { + return "", err + } + resolvedPath, err := filepath.EvalSymlinks(binaryPath) + if err != nil { + return "", err + } + //only return no error if the final resolved binary basename + //matches what was searched for + if filepath.Base(resolvedPath) == binname { + return resolvedPath, nil + } + return "", fmt.Errorf("Binary %q does not resolve to a binary of that name in $PATH (%q)", binname, resolvedPath) +} + +func execCmd(cmd, args string) ([]byte, error) { + execCmd := exec.Command(cmd, strings.Split(args, " ")...) + return execCmd.CombinedOutput() +} diff --git a/pkg/ioutils/buffer_test.go b/pkg/ioutils/buffer_test.go index 41098fa6e7..f687124386 100644 --- a/pkg/ioutils/buffer_test.go +++ b/pkg/ioutils/buffer_test.go @@ -5,6 +5,81 @@ import ( "testing" ) +func TestFixedBufferCap(t *testing.T) { + buf := &fixedBuffer{buf: make([]byte, 0, 5)} + + n := buf.Cap() + if n != 5 { + t.Fatalf("expected buffer capacity to be 5 bytes, got %d", n) + } +} + +func TestFixedBufferLen(t *testing.T) { + buf := &fixedBuffer{buf: make([]byte, 0, 10)} + + buf.Write([]byte("hello")) + l := buf.Len() + if l != 5 { + t.Fatalf("expected buffer length to be 5 bytes, got %d", l) + } + + buf.Write([]byte("world")) + l = buf.Len() + if l != 10 { + t.Fatalf("expected buffer length to be 10 bytes, got %d", l) + } + + // read 5 bytes + b := make([]byte, 5) + buf.Read(b) + + l = buf.Len() + if l != 5 { + t.Fatalf("expected buffer length to be 5 bytes, got %d", l) + } + + n, err := buf.Write([]byte("i-wont-fit")) + if n != 0 { + t.Fatalf("expected no bytes to be 
written to buffer, got %d", n) + } + if err != errBufferFull { + t.Fatalf("expected errBufferFull, got %v", err) + } + + l = buf.Len() + if l != 5 { + t.Fatalf("expected buffer length to still be 5 bytes, got %d", l) + } + + buf.Reset() + l = buf.Len() + if l != 0 { + t.Fatalf("expected buffer length to still be 0 bytes, got %d", l) + } +} + +func TestFixedBufferString(t *testing.T) { + buf := &fixedBuffer{buf: make([]byte, 0, 10)} + + buf.Write([]byte("hello")) + buf.Write([]byte("world")) + + out := buf.String() + if out != "helloworld" { + t.Fatalf("expected output to be \"helloworld\", got %q", out) + } + + // read 5 bytes + b := make([]byte, 5) + buf.Read(b) + + // test that fixedBuffer.String() only returns the part that hasn't been read + out = buf.String() + if out != "world" { + t.Fatalf("expected output to be \"world\", got %q", out) + } +} + func TestFixedBufferWrite(t *testing.T) { buf := &fixedBuffer{buf: make([]byte, 0, 64)} n, err := buf.Write([]byte("hello")) @@ -21,6 +96,9 @@ func TestFixedBufferWrite(t *testing.T) { } n, err = buf.Write(bytes.Repeat([]byte{1}, 64)) + if n != 59 { + t.Fatalf("expected 59 bytes written before buffer is full, got %d", n) + } if err != errBufferFull { t.Fatalf("expected errBufferFull, got %v - %v", err, buf.buf[:64]) } diff --git a/pkg/ioutils/bytespipe.go b/pkg/ioutils/bytespipe.go new file mode 100644 index 0000000000..72a04f3491 --- /dev/null +++ b/pkg/ioutils/bytespipe.go @@ -0,0 +1,186 @@ +package ioutils + +import ( + "errors" + "io" + "sync" +) + +// maxCap is the highest capacity to use in byte slices that buffer data. +const maxCap = 1e6 + +// minCap is the lowest capacity to use in byte slices that buffer data +const minCap = 64 + +// blockThreshold is the minimum number of bytes in the buffer which will cause +// a write to BytesPipe to block when allocating a new slice. +const blockThreshold = 1e6 + +var ( + // ErrClosed is returned when Write is called on a closed BytesPipe. 
+ ErrClosed = errors.New("write to closed BytesPipe") + + bufPools = make(map[int]*sync.Pool) + bufPoolsLock sync.Mutex +) + +// BytesPipe is io.ReadWriteCloser which works similarly to pipe(queue). +// All written data may be read at most once. Also, BytesPipe allocates +// and releases new byte slices to adjust to current needs, so the buffer +// won't be overgrown after peak loads. +type BytesPipe struct { + mu sync.Mutex + wait *sync.Cond + buf []*fixedBuffer + bufLen int + closeErr error // error to return from next Read. set to nil if not closed. +} + +// NewBytesPipe creates new BytesPipe, initialized by specified slice. +// If buf is nil, then it will be initialized with slice which cap is 64. +// buf will be adjusted in a way that len(buf) == 0, cap(buf) == cap(buf). +func NewBytesPipe() *BytesPipe { + bp := &BytesPipe{} + bp.buf = append(bp.buf, getBuffer(minCap)) + bp.wait = sync.NewCond(&bp.mu) + return bp +} + +// Write writes p to BytesPipe. +// It can allocate new []byte slices in a process of writing. 
+func (bp *BytesPipe) Write(p []byte) (int, error) { + bp.mu.Lock() + + written := 0 +loop0: + for { + if bp.closeErr != nil { + bp.mu.Unlock() + return written, ErrClosed + } + + if len(bp.buf) == 0 { + bp.buf = append(bp.buf, getBuffer(64)) + } + // get the last buffer + b := bp.buf[len(bp.buf)-1] + + n, err := b.Write(p) + written += n + bp.bufLen += n + + // errBufferFull is an error we expect to get if the buffer is full + if err != nil && err != errBufferFull { + bp.wait.Broadcast() + bp.mu.Unlock() + return written, err + } + + // if there was enough room to write all then break + if len(p) == n { + break + } + + // more data: write to the next slice + p = p[n:] + + // make sure the buffer doesn't grow too big from this write + for bp.bufLen >= blockThreshold { + bp.wait.Wait() + if bp.closeErr != nil { + continue loop0 + } + } + + // add new byte slice to the buffers slice and continue writing + nextCap := b.Cap() * 2 + if nextCap > maxCap { + nextCap = maxCap + } + bp.buf = append(bp.buf, getBuffer(nextCap)) + } + bp.wait.Broadcast() + bp.mu.Unlock() + return written, nil +} + +// CloseWithError causes further reads from a BytesPipe to return immediately. +func (bp *BytesPipe) CloseWithError(err error) error { + bp.mu.Lock() + if err != nil { + bp.closeErr = err + } else { + bp.closeErr = io.EOF + } + bp.wait.Broadcast() + bp.mu.Unlock() + return nil +} + +// Close causes further reads from a BytesPipe to return immediately. +func (bp *BytesPipe) Close() error { + return bp.CloseWithError(nil) +} + +// Read reads bytes from BytesPipe. +// Data could be read only once. 
+func (bp *BytesPipe) Read(p []byte) (n int, err error) { + bp.mu.Lock() + if bp.bufLen == 0 { + if bp.closeErr != nil { + bp.mu.Unlock() + return 0, bp.closeErr + } + bp.wait.Wait() + if bp.bufLen == 0 && bp.closeErr != nil { + err := bp.closeErr + bp.mu.Unlock() + return 0, err + } + } + + for bp.bufLen > 0 { + b := bp.buf[0] + read, _ := b.Read(p) // ignore error since fixedBuffer doesn't really return an error + n += read + bp.bufLen -= read + + if b.Len() == 0 { + // it's empty so return it to the pool and move to the next one + returnBuffer(b) + bp.buf[0] = nil + bp.buf = bp.buf[1:] + } + + if len(p) == read { + break + } + + p = p[read:] + } + + bp.wait.Broadcast() + bp.mu.Unlock() + return +} + +func returnBuffer(b *fixedBuffer) { + b.Reset() + bufPoolsLock.Lock() + pool := bufPools[b.Cap()] + bufPoolsLock.Unlock() + if pool != nil { + pool.Put(b) + } +} + +func getBuffer(size int) *fixedBuffer { + bufPoolsLock.Lock() + pool, ok := bufPools[size] + if !ok { + pool = &sync.Pool{New: func() interface{} { return &fixedBuffer{buf: make([]byte, 0, size)} }} + bufPools[size] = pool + } + bufPoolsLock.Unlock() + return pool.Get().(*fixedBuffer) +} diff --git a/pkg/ioutils/bytespipe_test.go b/pkg/ioutils/bytespipe_test.go new file mode 100644 index 0000000000..300fb5f6d5 --- /dev/null +++ b/pkg/ioutils/bytespipe_test.go @@ -0,0 +1,159 @@ +package ioutils + +import ( + "crypto/sha1" + "encoding/hex" + "math/rand" + "testing" + "time" +) + +func TestBytesPipeRead(t *testing.T) { + buf := NewBytesPipe() + buf.Write([]byte("12")) + buf.Write([]byte("34")) + buf.Write([]byte("56")) + buf.Write([]byte("78")) + buf.Write([]byte("90")) + rd := make([]byte, 4) + n, err := buf.Read(rd) + if err != nil { + t.Fatal(err) + } + if n != 4 { + t.Fatalf("Wrong number of bytes read: %d, should be %d", n, 4) + } + if string(rd) != "1234" { + t.Fatalf("Read %s, but must be %s", rd, "1234") + } + n, err = buf.Read(rd) + if err != nil { + t.Fatal(err) + } + if n != 4 { + t.Fatalf("Wrong 
number of bytes read: %d, should be %d", n, 4) + } + if string(rd) != "5678" { + t.Fatalf("Read %s, but must be %s", rd, "5679") + } + n, err = buf.Read(rd) + if err != nil { + t.Fatal(err) + } + if n != 2 { + t.Fatalf("Wrong number of bytes read: %d, should be %d", n, 2) + } + if string(rd[:n]) != "90" { + t.Fatalf("Read %s, but must be %s", rd, "90") + } +} + +func TestBytesPipeWrite(t *testing.T) { + buf := NewBytesPipe() + buf.Write([]byte("12")) + buf.Write([]byte("34")) + buf.Write([]byte("56")) + buf.Write([]byte("78")) + buf.Write([]byte("90")) + if buf.buf[0].String() != "1234567890" { + t.Fatalf("Buffer %q, must be %q", buf.buf[0].String(), "1234567890") + } +} + +// Write and read in different speeds/chunk sizes and check valid data is read. +func TestBytesPipeWriteRandomChunks(t *testing.T) { + cases := []struct{ iterations, writesPerLoop, readsPerLoop int }{ + {100, 10, 1}, + {1000, 10, 5}, + {1000, 100, 0}, + {1000, 5, 6}, + {10000, 50, 25}, + } + + testMessage := []byte("this is a random string for testing") + // random slice sizes to read and write + writeChunks := []int{25, 35, 15, 20} + readChunks := []int{5, 45, 20, 25} + + for _, c := range cases { + // first pass: write directly to hash + hash := sha1.New() + for i := 0; i < c.iterations*c.writesPerLoop; i++ { + if _, err := hash.Write(testMessage[:writeChunks[i%len(writeChunks)]]); err != nil { + t.Fatal(err) + } + } + expected := hex.EncodeToString(hash.Sum(nil)) + + // write/read through buffer + buf := NewBytesPipe() + hash.Reset() + + done := make(chan struct{}) + + go func() { + // random delay before read starts + <-time.After(time.Duration(rand.Intn(10)) * time.Millisecond) + for i := 0; ; i++ { + p := make([]byte, readChunks[(c.iterations*c.readsPerLoop+i)%len(readChunks)]) + n, _ := buf.Read(p) + if n == 0 { + break + } + hash.Write(p[:n]) + } + + close(done) + }() + + for i := 0; i < c.iterations; i++ { + for w := 0; w < c.writesPerLoop; w++ { + 
buf.Write(testMessage[:writeChunks[(i*c.writesPerLoop+w)%len(writeChunks)]]) + } + } + buf.Close() + <-done + + actual := hex.EncodeToString(hash.Sum(nil)) + + if expected != actual { + t.Fatalf("BytesPipe returned invalid data. Expected checksum %v, got %v", expected, actual) + } + + } +} + +func BenchmarkBytesPipeWrite(b *testing.B) { + testData := []byte("pretty short line, because why not?") + for i := 0; i < b.N; i++ { + readBuf := make([]byte, 1024) + buf := NewBytesPipe() + go func() { + var err error + for err == nil { + _, err = buf.Read(readBuf) + } + }() + for j := 0; j < 1000; j++ { + buf.Write(testData) + } + buf.Close() + } +} + +func BenchmarkBytesPipeRead(b *testing.B) { + rd := make([]byte, 512) + for i := 0; i < b.N; i++ { + b.StopTimer() + buf := NewBytesPipe() + for j := 0; j < 500; j++ { + buf.Write(make([]byte, 1024)) + } + b.StartTimer() + for j := 0; j < 1000; j++ { + if n, _ := buf.Read(rd); n != 512 { + b.Fatalf("Wrong number of bytes: %d", n) + } + } + } +} diff --git a/pkg/ioutils/fmt.go b/pkg/ioutils/fmt.go deleted file mode 100644 index 0b04b0ba3e..0000000000 --- a/pkg/ioutils/fmt.go +++ /dev/null @@ -1,22 +0,0 @@ -package ioutils - -import ( - "fmt" - "io" -) - -// FprintfIfNotEmpty prints the string value if it's not empty -func FprintfIfNotEmpty(w io.Writer, format, value string) (int, error) { - if value != "" { - return fmt.Fprintf(w, format, value) - } - return 0, nil -} - -// FprintfIfTrue prints the boolean value if it's true -func FprintfIfTrue(w io.Writer, format string, ok bool) (int, error) { - if ok { - return fmt.Fprintf(w, format, ok) - } - return 0, nil -} diff --git a/pkg/ioutils/fmt_test.go b/pkg/ioutils/fmt_test.go deleted file mode 100644 index 8968863296..0000000000 --- a/pkg/ioutils/fmt_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package ioutils - -import "testing" - -func TestFprintfIfNotEmpty(t *testing.T) { - wc := NewWriteCounter(&NopWriter{}) - n, _ := FprintfIfNotEmpty(wc, "foo%s", "") - - if wc.Count != 0 || n 
!= 0 { - t.Errorf("Wrong count: %v vs. %v vs. 0", wc.Count, n) - } - - n, _ = FprintfIfNotEmpty(wc, "foo%s", "bar") - if wc.Count != 6 || n != 6 { - t.Errorf("Wrong count: %v vs. %v vs. 6", wc.Count, n) - } -} diff --git a/pkg/ioutils/fswriters.go b/pkg/ioutils/fswriters.go index 6dc50a03dc..a56c462651 100644 --- a/pkg/ioutils/fswriters.go +++ b/pkg/ioutils/fswriters.go @@ -80,3 +80,83 @@ func (w *atomicFileWriter) Close() (retErr error) { } return nil } + +// AtomicWriteSet is used to atomically write a set +// of files and ensure they are visible at the same time. +// Must be committed to a new directory. +type AtomicWriteSet struct { + root string +} + +// NewAtomicWriteSet creates a new atomic write set to +// atomically create a set of files. The given directory +// is used as the base directory for storing files before +// commit. If no temporary directory is given the system +// default is used. +func NewAtomicWriteSet(tmpDir string) (*AtomicWriteSet, error) { + td, err := ioutil.TempDir(tmpDir, "write-set-") + if err != nil { + return nil, err + } + + return &AtomicWriteSet{ + root: td, + }, nil +} + +// WriteFile writes a file to the set, guaranteeing the file +// has been synced. +func (ws *AtomicWriteSet) WriteFile(filename string, data []byte, perm os.FileMode) error { + f, err := ws.FileWriter(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm) + if err != nil { + return err + } + n, err := f.Write(data) + if err == nil && n < len(data) { + err = io.ErrShortWrite + } + if err1 := f.Close(); err == nil { + err = err1 + } + return err +} + +type syncFileCloser struct { + *os.File +} + +func (w syncFileCloser) Close() error { + err := w.File.Sync() + if err1 := w.File.Close(); err == nil { + err = err1 + } + return err +} + +// FileWriter opens a file writer inside the set. The file +// should be synced and closed before calling commit. 
+func (ws *AtomicWriteSet) FileWriter(name string, flag int, perm os.FileMode) (io.WriteCloser, error) { + f, err := os.OpenFile(filepath.Join(ws.root, name), flag, perm) + if err != nil { + return nil, err + } + return syncFileCloser{f}, nil +} + +// Cancel cancels the set and removes all temporary data +// created in the set. +func (ws *AtomicWriteSet) Cancel() error { + return os.RemoveAll(ws.root) +} + +// Commit moves all created files to the target directory. The +// target directory must not exist and the parent of the target +// directory must exist. +func (ws *AtomicWriteSet) Commit(target string) error { + return os.Rename(ws.root, target) +} + +// String returns the location the set is writing to. +func (ws *AtomicWriteSet) String() string { + return ws.root +} diff --git a/pkg/ioutils/fswriters_test.go b/pkg/ioutils/fswriters_test.go index 470ca1a6f4..5d286005d2 100644 --- a/pkg/ioutils/fswriters_test.go +++ b/pkg/ioutils/fswriters_test.go @@ -5,9 +5,21 @@ import ( "io/ioutil" "os" "path/filepath" + "runtime" "testing" ) +var ( + testMode os.FileMode = 0640 +) + +func init() { + // Windows does not support full Linux file mode + if runtime.GOOS == "windows" { + testMode = 0666 + } +} + func TestAtomicWriteToFile(t *testing.T) { tmpDir, err := ioutil.TempDir("", "atomic-writers-test") if err != nil { @@ -16,7 +28,7 @@ func TestAtomicWriteToFile(t *testing.T) { defer os.RemoveAll(tmpDir) expected := []byte("barbaz") - if err := AtomicWriteFile(filepath.Join(tmpDir, "foo"), expected, 0666); err != nil { + if err := AtomicWriteFile(filepath.Join(tmpDir, "foo"), expected, testMode); err != nil { t.Fatalf("Error writing to file: %v", err) } @@ -25,7 +37,7 @@ func TestAtomicWriteToFile(t *testing.T) { t.Fatalf("Error reading from file: %v", err) } - if bytes.Compare(actual, expected) != 0 { + if !bytes.Equal(actual, expected) { t.Fatalf("Data mismatch, expected %q, got %q", expected, actual) } @@ -33,7 +45,88 @@ func TestAtomicWriteToFile(t *testing.T) { if 
err != nil { t.Fatalf("Error statting file: %v", err) } - if expected := os.FileMode(0666); st.Mode() != expected { + if expected := os.FileMode(testMode); st.Mode() != expected { t.Fatalf("Mode mismatched, expected %o, got %o", expected, st.Mode()) } } + +func TestAtomicWriteSetCommit(t *testing.T) { + tmpDir, err := ioutil.TempDir("", "atomic-writerset-test") + if err != nil { + t.Fatalf("Error when creating temporary directory: %s", err) + } + defer os.RemoveAll(tmpDir) + + if err := os.Mkdir(filepath.Join(tmpDir, "tmp"), 0700); err != nil { + t.Fatalf("Error creating tmp directory: %s", err) + } + + targetDir := filepath.Join(tmpDir, "target") + ws, err := NewAtomicWriteSet(filepath.Join(tmpDir, "tmp")) + if err != nil { + t.Fatalf("Error creating atomic write set: %s", err) + } + + expected := []byte("barbaz") + if err := ws.WriteFile("foo", expected, testMode); err != nil { + t.Fatalf("Error writing to file: %v", err) + } + + if _, err := ioutil.ReadFile(filepath.Join(targetDir, "foo")); err == nil { + t.Fatalf("Expected error reading file where should not exist") + } + + if err := ws.Commit(targetDir); err != nil { + t.Fatalf("Error committing file: %s", err) + } + + actual, err := ioutil.ReadFile(filepath.Join(targetDir, "foo")) + if err != nil { + t.Fatalf("Error reading from file: %v", err) + } + + if !bytes.Equal(actual, expected) { + t.Fatalf("Data mismatch, expected %q, got %q", expected, actual) + } + + st, err := os.Stat(filepath.Join(targetDir, "foo")) + if err != nil { + t.Fatalf("Error statting file: %v", err) + } + if expected := os.FileMode(testMode); st.Mode() != expected { + t.Fatalf("Mode mismatched, expected %o, got %o", expected, st.Mode()) + } + +} + +func TestAtomicWriteSetCancel(t *testing.T) { + tmpDir, err := ioutil.TempDir("", "atomic-writerset-test") + if err != nil { + t.Fatalf("Error when creating temporary directory: %s", err) + } + defer os.RemoveAll(tmpDir) + + if err := os.Mkdir(filepath.Join(tmpDir, "tmp"), 0700); err != nil { 
+ t.Fatalf("Error creating tmp directory: %s", err) + } + + ws, err := NewAtomicWriteSet(filepath.Join(tmpDir, "tmp")) + if err != nil { + t.Fatalf("Error creating atomic write set: %s", err) + } + + expected := []byte("barbaz") + if err := ws.WriteFile("foo", expected, testMode); err != nil { + t.Fatalf("Error writing to file: %v", err) + } + + if err := ws.Cancel(); err != nil { + t.Fatalf("Error committing file: %s", err) + } + + if _, err := ioutil.ReadFile(filepath.Join(tmpDir, "target", "foo")); err == nil { + t.Fatalf("Expected error reading file where should not exist") + } else if !os.IsNotExist(err) { + t.Fatalf("Unexpected error reading file: %s", err) + } +} diff --git a/pkg/ioutils/multireader.go b/pkg/ioutils/multireader.go deleted file mode 100644 index 0d2d76b479..0000000000 --- a/pkg/ioutils/multireader.go +++ /dev/null @@ -1,226 +0,0 @@ -package ioutils - -import ( - "bytes" - "fmt" - "io" - "os" -) - -type pos struct { - idx int - offset int64 -} - -type multiReadSeeker struct { - readers []io.ReadSeeker - pos *pos - posIdx map[io.ReadSeeker]int -} - -func (r *multiReadSeeker) Seek(offset int64, whence int) (int64, error) { - var tmpOffset int64 - switch whence { - case os.SEEK_SET: - for i, rdr := range r.readers { - // get size of the current reader - s, err := rdr.Seek(0, os.SEEK_END) - if err != nil { - return -1, err - } - - if offset > tmpOffset+s { - if i == len(r.readers)-1 { - rdrOffset := s + (offset - tmpOffset) - if _, err := rdr.Seek(rdrOffset, os.SEEK_SET); err != nil { - return -1, err - } - r.pos = &pos{i, rdrOffset} - return offset, nil - } - - tmpOffset += s - continue - } - - rdrOffset := offset - tmpOffset - idx := i - - rdr.Seek(rdrOffset, os.SEEK_SET) - // make sure all following readers are at 0 - for _, rdr := range r.readers[i+1:] { - rdr.Seek(0, os.SEEK_SET) - } - - if rdrOffset == s && i != len(r.readers)-1 { - idx++ - rdrOffset = 0 - } - r.pos = &pos{idx, rdrOffset} - return offset, nil - } - case os.SEEK_END: - for _, 
rdr := range r.readers { - s, err := rdr.Seek(0, os.SEEK_END) - if err != nil { - return -1, err - } - tmpOffset += s - } - r.Seek(tmpOffset+offset, os.SEEK_SET) - return tmpOffset + offset, nil - case os.SEEK_CUR: - if r.pos == nil { - return r.Seek(offset, os.SEEK_SET) - } - // Just return the current offset - if offset == 0 { - return r.getCurOffset() - } - - curOffset, err := r.getCurOffset() - if err != nil { - return -1, err - } - rdr, rdrOffset, err := r.getReaderForOffset(curOffset + offset) - if err != nil { - return -1, err - } - - r.pos = &pos{r.posIdx[rdr], rdrOffset} - return curOffset + offset, nil - default: - return -1, fmt.Errorf("Invalid whence: %d", whence) - } - - return -1, fmt.Errorf("Error seeking for whence: %d, offset: %d", whence, offset) -} - -func (r *multiReadSeeker) getReaderForOffset(offset int64) (io.ReadSeeker, int64, error) { - var rdr io.ReadSeeker - var rdrOffset int64 - - for i, rdr := range r.readers { - offsetTo, err := r.getOffsetToReader(rdr) - if err != nil { - return nil, -1, err - } - if offsetTo > offset { - rdr = r.readers[i-1] - rdrOffset = offsetTo - offset - break - } - - if rdr == r.readers[len(r.readers)-1] { - rdrOffset = offsetTo + offset - break - } - } - - return rdr, rdrOffset, nil -} - -func (r *multiReadSeeker) getCurOffset() (int64, error) { - var totalSize int64 - for _, rdr := range r.readers[:r.pos.idx+1] { - if r.posIdx[rdr] == r.pos.idx { - totalSize += r.pos.offset - break - } - - size, err := getReadSeekerSize(rdr) - if err != nil { - return -1, fmt.Errorf("error getting seeker size: %v", err) - } - totalSize += size - } - return totalSize, nil -} - -func (r *multiReadSeeker) getOffsetToReader(rdr io.ReadSeeker) (int64, error) { - var offset int64 - for _, r := range r.readers { - if r == rdr { - break - } - - size, err := getReadSeekerSize(rdr) - if err != nil { - return -1, err - } - offset += size - } - return offset, nil -} - -func (r *multiReadSeeker) Read(b []byte) (int, error) { - if r.pos == 
nil { - r.pos = &pos{0, 0} - } - - bCap := int64(cap(b)) - buf := bytes.NewBuffer(nil) - var rdr io.ReadSeeker - - for _, rdr = range r.readers[r.pos.idx:] { - readBytes, err := io.CopyN(buf, rdr, bCap) - if err != nil && err != io.EOF { - return -1, err - } - bCap -= readBytes - - if bCap == 0 { - break - } - } - - rdrPos, err := rdr.Seek(0, os.SEEK_CUR) - if err != nil { - return -1, err - } - r.pos = &pos{r.posIdx[rdr], rdrPos} - return buf.Read(b) -} - -func getReadSeekerSize(rdr io.ReadSeeker) (int64, error) { - // save the current position - pos, err := rdr.Seek(0, os.SEEK_CUR) - if err != nil { - return -1, err - } - - // get the size - size, err := rdr.Seek(0, os.SEEK_END) - if err != nil { - return -1, err - } - - // reset the position - if _, err := rdr.Seek(pos, os.SEEK_SET); err != nil { - return -1, err - } - return size, nil -} - -// MultiReadSeeker returns a ReadSeeker that's the logical concatenation of the provided -// input readseekers. After calling this method the initial position is set to the -// beginning of the first ReadSeeker. At the end of a ReadSeeker, Read always advances -// to the beginning of the next ReadSeeker and returns EOF at the end of the last ReadSeeker. -// Seek can be used over the sum of lengths of all readseekers. -// -// When a MultiReadSeeker is used, no Read and Seek operations should be made on -// its ReadSeeker components. Also, users should make no assumption on the state -// of individual readseekers while the MultiReadSeeker is used. 
-func MultiReadSeeker(readers ...io.ReadSeeker) io.ReadSeeker { - if len(readers) == 1 { - return readers[0] - } - idx := make(map[io.ReadSeeker]int) - for i, rdr := range readers { - idx[rdr] = i - } - return &multiReadSeeker{ - readers: readers, - posIdx: idx, - } -} diff --git a/pkg/ioutils/multireader_test.go b/pkg/ioutils/multireader_test.go deleted file mode 100644 index de495b56da..0000000000 --- a/pkg/ioutils/multireader_test.go +++ /dev/null @@ -1,149 +0,0 @@ -package ioutils - -import ( - "bytes" - "fmt" - "io" - "io/ioutil" - "os" - "strings" - "testing" -) - -func TestMultiReadSeekerReadAll(t *testing.T) { - str := "hello world" - s1 := strings.NewReader(str + " 1") - s2 := strings.NewReader(str + " 2") - s3 := strings.NewReader(str + " 3") - mr := MultiReadSeeker(s1, s2, s3) - - expectedSize := int64(s1.Len() + s2.Len() + s3.Len()) - - b, err := ioutil.ReadAll(mr) - if err != nil { - t.Fatal(err) - } - - expected := "hello world 1hello world 2hello world 3" - if string(b) != expected { - t.Fatalf("ReadAll failed, got: %q, expected %q", string(b), expected) - } - - size, err := mr.Seek(0, os.SEEK_END) - if err != nil { - t.Fatal(err) - } - if size != expectedSize { - t.Fatalf("reader size does not match, got %d, expected %d", size, expectedSize) - } - - // Reset the position and read again - pos, err := mr.Seek(0, os.SEEK_SET) - if err != nil { - t.Fatal(err) - } - if pos != 0 { - t.Fatalf("expected position to be set to 0, got %d", pos) - } - - b, err = ioutil.ReadAll(mr) - if err != nil { - t.Fatal(err) - } - - if string(b) != expected { - t.Fatalf("ReadAll failed, got: %q, expected %q", string(b), expected) - } -} - -func TestMultiReadSeekerReadEach(t *testing.T) { - str := "hello world" - s1 := strings.NewReader(str + " 1") - s2 := strings.NewReader(str + " 2") - s3 := strings.NewReader(str + " 3") - mr := MultiReadSeeker(s1, s2, s3) - - var totalBytes int64 - for i, s := range []*strings.Reader{s1, s2, s3} { - sLen := int64(s.Len()) - buf := 
make([]byte, s.Len()) - expected := []byte(fmt.Sprintf("%s %d", str, i+1)) - - if _, err := mr.Read(buf); err != nil && err != io.EOF { - t.Fatal(err) - } - - if !bytes.Equal(buf, expected) { - t.Fatalf("expected %q to be %q", string(buf), string(expected)) - } - - pos, err := mr.Seek(0, os.SEEK_CUR) - if err != nil { - t.Fatalf("iteration: %d, error: %v", i+1, err) - } - - // check that the total bytes read is the current position of the seeker - totalBytes += sLen - if pos != totalBytes { - t.Fatalf("expected current position to be: %d, got: %d, iteration: %d", totalBytes, pos, i+1) - } - - // This tests not only that SEEK_SET and SEEK_CUR give the same values, but that the next iteration is in the expected position as well - newPos, err := mr.Seek(pos, os.SEEK_SET) - if err != nil { - t.Fatal(err) - } - if newPos != pos { - t.Fatalf("expected to get same position when calling SEEK_SET with value from SEEK_CUR, cur: %d, set: %d", pos, newPos) - } - } -} - -func TestMultiReadSeekerReadSpanningChunks(t *testing.T) { - str := "hello world" - s1 := strings.NewReader(str + " 1") - s2 := strings.NewReader(str + " 2") - s3 := strings.NewReader(str + " 3") - mr := MultiReadSeeker(s1, s2, s3) - - buf := make([]byte, s1.Len()+3) - _, err := mr.Read(buf) - if err != nil { - t.Fatal(err) - } - - // expected is the contents of s1 + 3 bytes from s2, ie, the `hel` at the end of this string - expected := "hello world 1hel" - if string(buf) != expected { - t.Fatalf("expected %s to be %s", string(buf), expected) - } -} - -func TestMultiReadSeekerNegativeSeek(t *testing.T) { - str := "hello world" - s1 := strings.NewReader(str + " 1") - s2 := strings.NewReader(str + " 2") - s3 := strings.NewReader(str + " 3") - mr := MultiReadSeeker(s1, s2, s3) - - s1Len := s1.Len() - s2Len := s2.Len() - s3Len := s3.Len() - - s, err := mr.Seek(int64(-1*s3.Len()), os.SEEK_END) - if err != nil { - t.Fatal(err) - } - if s != int64(s1Len+s2Len) { - t.Fatalf("expected %d to be %d", s, s1.Len()+s2.Len()) 
- } - - buf := make([]byte, s3Len) - if _, err := mr.Read(buf); err != nil && err != io.EOF { - t.Fatal(err) - } - expected := fmt.Sprintf("%s %d", str, 3) - if string(buf) != fmt.Sprintf("%s %d", str, 3) { - t.Fatalf("expected %q to be %q", string(buf), expected) - } -} diff --git a/pkg/ioutils/readers.go b/pkg/ioutils/readers.go index 5a61e6bd4d..63f3c07f46 100644 --- a/pkg/ioutils/readers.go +++ b/pkg/ioutils/readers.go @@ -4,6 +4,8 @@ import ( "crypto/sha256" "encoding/hex" "io" + + "golang.org/x/net/context" ) type readCloserWrapper struct { @@ -81,3 +83,72 @@ func (r *OnEOFReader) runFunc() { r.Fn = nil } } + +// cancelReadCloser wraps an io.ReadCloser with a context for cancelling read +// operations. +type cancelReadCloser struct { + cancel func() + pR *io.PipeReader // Stream to read from + pW *io.PipeWriter +} + +// NewCancelReadCloser creates a wrapper that closes the ReadCloser when the +// context is cancelled. The returned io.ReadCloser must be closed when it is +// no longer needed. +func NewCancelReadCloser(ctx context.Context, in io.ReadCloser) io.ReadCloser { + pR, pW := io.Pipe() + + // Create a context used to signal when the pipe is closed + doneCtx, cancel := context.WithCancel(context.Background()) + + p := &cancelReadCloser{ + cancel: cancel, + pR: pR, + pW: pW, + } + + go func() { + _, err := io.Copy(pW, in) + select { + case <-ctx.Done(): + // If the context was closed, p.closeWithError + // was already called. Calling it again would + // change the error that Read returns. + default: + p.closeWithError(err) + } + in.Close() + }() + go func() { + for { + select { + case <-ctx.Done(): + p.closeWithError(ctx.Err()) + case <-doneCtx.Done(): + return + } + } + }() + + return p +} + +// Read wraps the Read method of the pipe that provides data from the wrapped +// ReadCloser. +func (p *cancelReadCloser) Read(buf []byte) (n int, err error) { + return p.pR.Read(buf) +} + +// closeWithError closes the wrapper and its underlying reader. 
It will +// cause future calls to Read to return err. +func (p *cancelReadCloser) closeWithError(err error) { + p.pW.CloseWithError(err) + p.cancel() +} + +// Close closes the wrapper its underlying reader. It will cause +// future calls to Read to return io.EOF. +func (p *cancelReadCloser) Close() error { + p.closeWithError(io.EOF) + return nil +} diff --git a/pkg/ioutils/readers_test.go b/pkg/ioutils/readers_test.go index 3ccfdf9303..86e50d38f9 100644 --- a/pkg/ioutils/readers_test.go +++ b/pkg/ioutils/readers_test.go @@ -2,8 +2,13 @@ package ioutils import ( "fmt" + "io/ioutil" "strings" "testing" + "time" + + "github.com/stretchr/testify/assert" + "golang.org/x/net/context" ) // Implement io.Reader @@ -31,9 +36,7 @@ func TestReaderErrWrapperReadOnError(t *testing.T) { called = true }) _, err := wrapper.Read([]byte{}) - if err == nil || !strings.Contains(err.Error(), "error reader always fail") { - t.Fatalf("readErrWrapper should returned an error") - } + assert.EqualError(t, err, "error reader always fail") if !called { t.Fatalf("readErrWrapper should have call the anonymous function on failure") } @@ -74,3 +77,17 @@ func (p *perpetualReader) Read(buf []byte) (n int, err error) { } return len(buf), nil } + +func TestCancelReadCloser(t *testing.T) { + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + cancelReadCloser := NewCancelReadCloser(ctx, ioutil.NopCloser(&perpetualReader{})) + for { + var buf [128]byte + _, err := cancelReadCloser.Read(buf[:]) + if err == context.DeadlineExceeded { + break + } else if err != nil { + t.Fatalf("got unexpected error: %v", err) + } + } +} diff --git a/pkg/locker/README.md b/pkg/locker/README.md new file mode 100644 index 0000000000..ad15e89af1 --- /dev/null +++ b/pkg/locker/README.md @@ -0,0 +1,65 @@ +Locker +===== + +locker provides a mechanism for creating finer-grained locking to help +free up more global locks to handle other tasks. 
+ +The implementation looks close to a sync.Mutex, however, the user must provide a +reference to use to refer to the underlying lock when locking and unlocking, +and unlock may generate an error. + +If a lock with a given name does not exist when `Lock` is called, one is +created. +Lock references are automatically cleaned up on `Unlock` if nothing else is +waiting for the lock. + + +## Usage + +```go +package important + +import ( + "sync" + "time" + + "github.com/containers/storage/pkg/locker" +) + +type important struct { + locks *locker.Locker + data map[string]interface{} + mu sync.Mutex +} + +func (i *important) Get(name string) interface{} { + i.locks.Lock(name) + defer i.locks.Unlock(name) + return data[name] +} + +func (i *important) Create(name string, data interface{}) { + i.locks.Lock(name) + defer i.locks.Unlock(name) + + i.createImportant(data) + + s.mu.Lock() + i.data[name] = data + s.mu.Unlock() +} + +func (i *important) createImportant(data interface{}) { + time.Sleep(10 * time.Second) +} +``` + +For functions dealing with a given name, always lock at the beginning of the +function (or before doing anything with the underlying state), this ensures any +other function that is dealing with the same name will block. + +When needing to modify the underlying data, use the global lock to ensure nothing +else is modifying it at the same time. +Since name lock is already in place, no reads will occur while the modification +is being performed. + diff --git a/pkg/locker/locker.go b/pkg/locker/locker.go new file mode 100644 index 0000000000..0b22ddfab8 --- /dev/null +++ b/pkg/locker/locker.go @@ -0,0 +1,112 @@ +/* +Package locker provides a mechanism for creating finer-grained locking to help +free up more global locks to handle other tasks. + +The implementation looks close to a sync.Mutex, however the user must provide a +reference to use to refer to the underlying lock when locking and unlocking, +and unlock may generate an error. 
+ +If a lock with a given name does not exist when `Lock` is called, one is +created. +Lock references are automatically cleaned up on `Unlock` if nothing else is +waiting for the lock. +*/ +package locker + +import ( + "errors" + "sync" + "sync/atomic" +) + +// ErrNoSuchLock is returned when the requested lock does not exist +var ErrNoSuchLock = errors.New("no such lock") + +// Locker provides a locking mechanism based on the passed in reference name +type Locker struct { + mu sync.Mutex + locks map[string]*lockCtr +} + +// lockCtr is used by Locker to represent a lock with a given name. +type lockCtr struct { + mu sync.Mutex + // waiters is the number of waiters waiting to acquire the lock + // this is int32 instead of uint32 so we can add `-1` in `dec()` + waiters int32 +} + +// inc increments the number of waiters waiting for the lock +func (l *lockCtr) inc() { + atomic.AddInt32(&l.waiters, 1) +} + +// dec decrements the number of waiters waiting on the lock +func (l *lockCtr) dec() { + atomic.AddInt32(&l.waiters, -1) +} + +// count gets the current number of waiters +func (l *lockCtr) count() int32 { + return atomic.LoadInt32(&l.waiters) +} + +// Lock locks the mutex +func (l *lockCtr) Lock() { + l.mu.Lock() +} + +// Unlock unlocks the mutex +func (l *lockCtr) Unlock() { + l.mu.Unlock() +} + +// New creates a new Locker +func New() *Locker { + return &Locker{ + locks: make(map[string]*lockCtr), + } +} + +// Lock locks a mutex with the given name. 
If it doesn't exist, one is created +func (l *Locker) Lock(name string) { + l.mu.Lock() + if l.locks == nil { + l.locks = make(map[string]*lockCtr) + } + + nameLock, exists := l.locks[name] + if !exists { + nameLock = &lockCtr{} + l.locks[name] = nameLock + } + + // increment the nameLock waiters while inside the main mutex + // this makes sure that the lock isn't deleted if `Lock` and `Unlock` are called concurrently + nameLock.inc() + l.mu.Unlock() + + // Lock the nameLock outside the main mutex so we don't block other operations + // once locked then we can decrement the number of waiters for this lock + nameLock.Lock() + nameLock.dec() +} + +// Unlock unlocks the mutex with the given name +// If the given lock is not being waited on by any other callers, it is deleted +func (l *Locker) Unlock(name string) error { + l.mu.Lock() + nameLock, exists := l.locks[name] + if !exists { + l.mu.Unlock() + return ErrNoSuchLock + } + + if nameLock.count() == 0 { + delete(l.locks, name) + } + nameLock.Unlock() + + l.mu.Unlock() + return nil +} diff --git a/pkg/locker/locker_test.go b/pkg/locker/locker_test.go new file mode 100644 index 0000000000..5a297dd47b --- /dev/null +++ b/pkg/locker/locker_test.go @@ -0,0 +1,124 @@ +package locker + +import ( + "sync" + "testing" + "time" +) + +func TestLockCounter(t *testing.T) { + l := &lockCtr{} + l.inc() + + if l.waiters != 1 { + t.Fatal("counter inc failed") + } + + l.dec() + if l.waiters != 0 { + t.Fatal("counter dec failed") + } +} + +func TestLockerLock(t *testing.T) { + l := New() + l.Lock("test") + ctr := l.locks["test"] + + if ctr.count() != 0 { + t.Fatalf("expected waiters to be 0, got :%d", ctr.waiters) + } + + chDone := make(chan struct{}) + go func() { + l.Lock("test") + close(chDone) + }() + + chWaiting := make(chan struct{}) + go func() { + for range time.Tick(1 * time.Millisecond) { + if ctr.count() == 1 { + close(chWaiting) + break + } + } + }() + + select { + case <-chWaiting: + case <-time.After(3 * time.Second): + 
t.Fatal("timed out waiting for lock waiters to be incremented") + } + + select { + case <-chDone: + t.Fatal("lock should not have returned while it was still held") + default: + } + + if err := l.Unlock("test"); err != nil { + t.Fatal(err) + } + + select { + case <-chDone: + case <-time.After(3 * time.Second): + t.Fatalf("lock should have completed") + } + + if ctr.count() != 0 { + t.Fatalf("expected waiters to be 0, got: %d", ctr.count()) + } +} + +func TestLockerUnlock(t *testing.T) { + l := New() + + l.Lock("test") + l.Unlock("test") + + chDone := make(chan struct{}) + go func() { + l.Lock("test") + close(chDone) + }() + + select { + case <-chDone: + case <-time.After(3 * time.Second): + t.Fatalf("lock should not be blocked") + } +} + +func TestLockerConcurrency(t *testing.T) { + l := New() + + var wg sync.WaitGroup + for i := 0; i <= 10000; i++ { + wg.Add(1) + go func() { + l.Lock("test") + // if there is a concurrency issue, will very likely panic here + l.Unlock("test") + wg.Done() + }() + } + + chDone := make(chan struct{}) + go func() { + wg.Wait() + close(chDone) + }() + + select { + case <-chDone: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for locks to complete") + } + + // Since everything has unlocked this should not exist anymore + if ctr, exists := l.locks["test"]; exists { + t.Fatalf("lock should not exist: %v", ctr) + } +} diff --git a/pkg/mount/flags_freebsd.go b/pkg/mount/flags_freebsd.go index f166cb2f77..5f76f331b6 100644 --- a/pkg/mount/flags_freebsd.go +++ b/pkg/mount/flags_freebsd.go @@ -45,4 +45,5 @@ const ( RELATIME = 0 REMOUNT = 0 STRICTATIME = 0 + mntDetach = 0 ) diff --git a/pkg/mount/flags_linux.go b/pkg/mount/flags_linux.go index dc696dce90..0425d0dd63 100644 --- a/pkg/mount/flags_linux.go +++ b/pkg/mount/flags_linux.go @@ -1,85 +1,87 @@ package mount import ( - "syscall" + "golang.org/x/sys/unix" ) const ( // RDONLY will mount the file system read-only. 
- RDONLY = syscall.MS_RDONLY + RDONLY = unix.MS_RDONLY // NOSUID will not allow set-user-identifier or set-group-identifier bits to // take effect. - NOSUID = syscall.MS_NOSUID + NOSUID = unix.MS_NOSUID // NODEV will not interpret character or block special devices on the file // system. - NODEV = syscall.MS_NODEV + NODEV = unix.MS_NODEV // NOEXEC will not allow execution of any binaries on the mounted file system. - NOEXEC = syscall.MS_NOEXEC + NOEXEC = unix.MS_NOEXEC // SYNCHRONOUS will allow I/O to the file system to be done synchronously. - SYNCHRONOUS = syscall.MS_SYNCHRONOUS + SYNCHRONOUS = unix.MS_SYNCHRONOUS // DIRSYNC will force all directory updates within the file system to be done // synchronously. This affects the following system calls: create, link, // unlink, symlink, mkdir, rmdir, mknod and rename. - DIRSYNC = syscall.MS_DIRSYNC + DIRSYNC = unix.MS_DIRSYNC // REMOUNT will attempt to remount an already-mounted file system. This is // commonly used to change the mount flags for a file system, especially to // make a readonly file system writeable. It does not change device or mount // point. - REMOUNT = syscall.MS_REMOUNT + REMOUNT = unix.MS_REMOUNT // MANDLOCK will force mandatory locks on a filesystem. - MANDLOCK = syscall.MS_MANDLOCK + MANDLOCK = unix.MS_MANDLOCK // NOATIME will not update the file access time when reading from a file. - NOATIME = syscall.MS_NOATIME + NOATIME = unix.MS_NOATIME // NODIRATIME will not update the directory access time. - NODIRATIME = syscall.MS_NODIRATIME + NODIRATIME = unix.MS_NODIRATIME // BIND remounts a subtree somewhere else. - BIND = syscall.MS_BIND + BIND = unix.MS_BIND // RBIND remounts a subtree and all possible submounts somewhere else. - RBIND = syscall.MS_BIND | syscall.MS_REC + RBIND = unix.MS_BIND | unix.MS_REC // UNBINDABLE creates a mount which cannot be cloned through a bind operation. 
- UNBINDABLE = syscall.MS_UNBINDABLE + UNBINDABLE = unix.MS_UNBINDABLE // RUNBINDABLE marks the entire mount tree as UNBINDABLE. - RUNBINDABLE = syscall.MS_UNBINDABLE | syscall.MS_REC + RUNBINDABLE = unix.MS_UNBINDABLE | unix.MS_REC // PRIVATE creates a mount which carries no propagation abilities. - PRIVATE = syscall.MS_PRIVATE + PRIVATE = unix.MS_PRIVATE // RPRIVATE marks the entire mount tree as PRIVATE. - RPRIVATE = syscall.MS_PRIVATE | syscall.MS_REC + RPRIVATE = unix.MS_PRIVATE | unix.MS_REC // SLAVE creates a mount which receives propagation from its master, but not // vice versa. - SLAVE = syscall.MS_SLAVE + SLAVE = unix.MS_SLAVE // RSLAVE marks the entire mount tree as SLAVE. - RSLAVE = syscall.MS_SLAVE | syscall.MS_REC + RSLAVE = unix.MS_SLAVE | unix.MS_REC // SHARED creates a mount which provides the ability to create mirrors of // that mount such that mounts and unmounts within any of the mirrors // propagate to the other mirrors. - SHARED = syscall.MS_SHARED + SHARED = unix.MS_SHARED // RSHARED marks the entire mount tree as SHARED. - RSHARED = syscall.MS_SHARED | syscall.MS_REC + RSHARED = unix.MS_SHARED | unix.MS_REC // RELATIME updates inode access times relative to modify or change time. - RELATIME = syscall.MS_RELATIME + RELATIME = unix.MS_RELATIME // STRICTATIME allows to explicitly request full atime updates. This makes // it possible for the kernel to default to relatime or noatime but still // allow userspace to override it. 
- STRICTATIME = syscall.MS_STRICTATIME + STRICTATIME = unix.MS_STRICTATIME + + mntDetach = unix.MNT_DETACH ) diff --git a/pkg/mount/flags_unsupported.go b/pkg/mount/flags_unsupported.go index 5564f7b3cd..9ed741e3ff 100644 --- a/pkg/mount/flags_unsupported.go +++ b/pkg/mount/flags_unsupported.go @@ -27,4 +27,5 @@ const ( STRICTATIME = 0 SYNCHRONOUS = 0 RDONLY = 0 + mntDetach = 0 ) diff --git a/pkg/mount/mount.go b/pkg/mount/mount.go index e39aeddfb1..d3caa16bda 100644 --- a/pkg/mount/mount.go +++ b/pkg/mount/mount.go @@ -1,6 +1,8 @@ package mount import ( + "sort" + "strings" "time" "github.com/containers/storage/pkg/fileutils" @@ -52,13 +54,11 @@ func Mount(device, target, mType, options string) error { // flags.go for supported option flags. func ForceMount(device, target, mType, options string) error { flag, data := parseOptions(options) - if err := mount(device, target, mType, uintptr(flag), data); err != nil { - return err - } - return nil + return mount(device, target, mType, uintptr(flag), data) } -// Unmount will unmount the target filesystem, so long as it is mounted. +// Unmount lazily unmounts a filesystem on supported platforms, otherwise +// does a normal unmount. func Unmount(target string) error { if mounted, err := Mounted(target); err != nil || !mounted { return err @@ -66,6 +66,32 @@ func Unmount(target string) error { return ForceUnmount(target) } +// RecursiveUnmount unmounts the target and all mounts underneath, starting with +// the deepest mount first. 
+func RecursiveUnmount(target string) error { + mounts, err := GetMounts() + if err != nil { + return err + } + + // Make the deepest mount be first + sort.Sort(sort.Reverse(byMountpoint(mounts))) + + for i, m := range mounts { + if !strings.HasPrefix(m.Mountpoint, target) { + continue + } + if err := Unmount(m.Mountpoint); err != nil && i == len(mounts)-1 { + if mounted, err := Mounted(m.Mountpoint); err != nil || mounted { + return err + } + // Ignore errors for submounts and continue trying to unmount others + // The final unmount should fail if there are any submounts remaining + } + } + return nil +} + // ForceUnmount will force an unmount of the target filesystem, regardless if // it is mounted or not. func ForceUnmount(target string) (err error) { @@ -76,5 +102,5 @@ func ForceUnmount(target string) (err error) { } time.Sleep(100 * time.Millisecond) } - return + return nil } diff --git a/pkg/mount/mount_unix_test.go b/pkg/mount/mount_unix_test.go index 90fa348b22..253aff3b8e 100644 --- a/pkg/mount/mount_unix_test.go +++ b/pkg/mount/mount_unix_test.go @@ -1,4 +1,4 @@ -// +build !windows +// +build !windows,!solaris package mount diff --git a/pkg/mount/mounter_freebsd.go b/pkg/mount/mounter_freebsd.go index bb870e6f59..814896cc9e 100644 --- a/pkg/mount/mounter_freebsd.go +++ b/pkg/mount/mounter_freebsd.go @@ -13,8 +13,9 @@ import "C" import ( "fmt" "strings" - "syscall" "unsafe" + + "golang.org/x/sys/unix" ) func allocateIOVecs(options []string) []C.struct_iovec { @@ -55,5 +56,5 @@ func mount(device, target, mType string, flag uintptr, data string) error { } func unmount(target string, flag int) error { - return syscall.Unmount(target, flag) + return unix.Unmount(target, flag) } diff --git a/pkg/mount/mounter_linux.go b/pkg/mount/mounter_linux.go index dd4280c777..39c36d472a 100644 --- a/pkg/mount/mounter_linux.go +++ b/pkg/mount/mounter_linux.go @@ -1,21 +1,57 @@ package mount import ( - "syscall" + "golang.org/x/sys/unix" ) -func mount(device, target, mType 
string, flag uintptr, data string) error { - if err := syscall.Mount(device, target, mType, flag, data); err != nil { - return err +const ( + // ptypes is the set of propagation types. + ptypes = unix.MS_SHARED | unix.MS_PRIVATE | unix.MS_SLAVE | unix.MS_UNBINDABLE + + // pflags is the full set of valid flags for a change propagation call. + pflags = ptypes | unix.MS_REC | unix.MS_SILENT + + // broflags is the combination of bind and read only + broflags = unix.MS_BIND | unix.MS_RDONLY +) + +// isremount returns true if either device name or flags identify a remount request, false otherwise. +func isremount(device string, flags uintptr) bool { + switch { + // We treat device "" and "none" as a remount request to provide compatibility with + // requests that don't explicitly set MS_REMOUNT such as those manipulating bind mounts. + case flags&unix.MS_REMOUNT != 0, device == "", device == "none": + return true + default: + return false + } +} + +func mount(device, target, mType string, flags uintptr, data string) error { + oflags := flags &^ ptypes + if !isremount(device, flags) || data != "" { + // Initial call applying all non-propagation flags for mount + // or remount with changed data + if err := unix.Mount(device, target, mType, oflags, data); err != nil { + return err + } } - // If we have a bind mount or remount, remount... - if flag&syscall.MS_BIND == syscall.MS_BIND && flag&syscall.MS_RDONLY == syscall.MS_RDONLY { - return syscall.Mount(device, target, mType, flag|syscall.MS_REMOUNT, data) + if flags&ptypes != 0 { + // Change the propagation type. + if err := unix.Mount("", target, "", flags&pflags, ""); err != nil { + return err + } + } + + if oflags&broflags == broflags { + // Remount the bind to apply read only. 
+ return unix.Mount("", target, "", oflags|unix.MS_REMOUNT, "") + } + return nil } func unmount(target string, flag int) error { - return syscall.Unmount(target, flag) + return unix.Unmount(target, flag) } diff --git a/pkg/mount/mounter_linux_test.go b/pkg/mount/mounter_linux_test.go new file mode 100644 index 0000000000..47c03b3631 --- /dev/null +++ b/pkg/mount/mounter_linux_test.go @@ -0,0 +1,228 @@ +// +build linux + +package mount + +import ( + "fmt" + "io/ioutil" + "os" + "strings" + "testing" +) + +func TestMount(t *testing.T) { + if os.Getuid() != 0 { + t.Skip("not root tests would fail") + } + + source, err := ioutil.TempDir("", "mount-test-source-") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(source) + + // Ensure we have a known start point by mounting tmpfs with given options + if err := Mount("tmpfs", source, "tmpfs", "private"); err != nil { + t.Fatal(err) + } + defer ensureUnmount(t, source) + validateMount(t, source, "", "", "") + if t.Failed() { + t.FailNow() + } + + target, err := ioutil.TempDir("", "mount-test-target-") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(target) + + tests := []struct { + source string + ftype string + options string + expectedOpts string + expectedOptional string + expectedVFS string + }{ + // No options + {"tmpfs", "tmpfs", "", "", "", ""}, + // Default rw / ro test + {source, "", "bind", "", "", ""}, + {source, "", "bind,private", "", "", ""}, + {source, "", "bind,shared", "", "shared", ""}, + {source, "", "bind,slave", "", "master", ""}, + {source, "", "bind,unbindable", "", "unbindable", ""}, + // Read Write tests + {source, "", "bind,rw", "rw", "", ""}, + {source, "", "bind,rw,private", "rw", "", ""}, + {source, "", "bind,rw,shared", "rw", "shared", ""}, + {source, "", "bind,rw,slave", "rw", "master", ""}, + {source, "", "bind,rw,unbindable", "rw", "unbindable", ""}, + // Read Only tests + {source, "", "bind,ro", "ro", "", ""}, + {source, "", "bind,ro,private", "ro", "", ""}, + {source, 
"", "bind,ro,shared", "ro", "shared", ""}, + {source, "", "bind,ro,slave", "ro", "master", ""}, + {source, "", "bind,ro,unbindable", "ro", "unbindable", ""}, + // Remount tests to change per filesystem options + {"", "", "remount,size=128k", "rw", "", "rw,size=128k"}, + {"", "", "remount,ro,size=128k", "ro", "", "ro,size=128k"}, + } + + for _, tc := range tests { + ftype, options := tc.ftype, tc.options + if tc.ftype == "" { + ftype = "none" + } + if tc.options == "" { + options = "none" + } + + t.Run(fmt.Sprintf("%v-%v", ftype, options), func(t *testing.T) { + if strings.Contains(tc.options, "slave") { + // Slave requires a shared source + if err := MakeShared(source); err != nil { + t.Fatal(err) + } + defer func() { + if err := MakePrivate(source); err != nil { + t.Fatal(err) + } + }() + } + if strings.Contains(tc.options, "remount") { + // create a new mount to remount first + if err := Mount("tmpfs", target, "tmpfs", ""); err != nil { + t.Fatal(err) + } + } + if err := Mount(tc.source, target, tc.ftype, tc.options); err != nil { + t.Fatal(err) + } + defer ensureUnmount(t, target) + validateMount(t, target, tc.expectedOpts, tc.expectedOptional, tc.expectedVFS) + }) + } +} + +// ensureUnmount umounts mnt checking for errors +func ensureUnmount(t *testing.T, mnt string) { + if err := Unmount(mnt); err != nil { + t.Error(err) + } +} + +// validateMount checks that mnt has the given options +func validateMount(t *testing.T, mnt string, opts, optional, vfs string) { + info, err := GetMounts() + if err != nil { + t.Fatal(err) + } + + wantedOpts := make(map[string]struct{}) + if opts != "" { + for _, opt := range strings.Split(opts, ",") { + wantedOpts[opt] = struct{}{} + } + } + + wantedOptional := make(map[string]struct{}) + if optional != "" { + for _, opt := range strings.Split(optional, ",") { + wantedOptional[opt] = struct{}{} + } + } + + wantedVFS := make(map[string]struct{}) + if vfs != "" { + for _, opt := range strings.Split(vfs, ",") { + wantedVFS[opt] = 
struct{}{} + } + } + + mnts := make(map[int]*Info, len(info)) + for _, mi := range info { + mnts[mi.ID] = mi + } + + for _, mi := range info { + if mi.Mountpoint != mnt { + continue + } + + // Use parent info as the defaults + p := mnts[mi.Parent] + pOpts := make(map[string]struct{}) + if p.Opts != "" { + for _, opt := range strings.Split(p.Opts, ",") { + pOpts[clean(opt)] = struct{}{} + } + } + pOptional := make(map[string]struct{}) + if p.Optional != "" { + for _, field := range strings.Split(p.Optional, ",") { + pOptional[clean(field)] = struct{}{} + } + } + + // Validate Opts + if mi.Opts != "" { + for _, opt := range strings.Split(mi.Opts, ",") { + opt = clean(opt) + if !has(wantedOpts, opt) && !has(pOpts, opt) { + t.Errorf("unexpected mount option %q expected %q", opt, opts) + } + delete(wantedOpts, opt) + } + } + for opt := range wantedOpts { + t.Errorf("missing mount option %q found %q", opt, mi.Opts) + } + + // Validate Optional + if mi.Optional != "" { + for _, field := range strings.Split(mi.Optional, ",") { + field = clean(field) + if !has(wantedOptional, field) && !has(pOptional, field) { + t.Errorf("unexpected optional failed %q expected %q", field, optional) + } + delete(wantedOptional, field) + } + } + for field := range wantedOptional { + t.Errorf("missing optional field %q found %q", field, mi.Optional) + } + + // Validate VFS if set + if vfs != "" { + if mi.VfsOpts != "" { + for _, opt := range strings.Split(mi.VfsOpts, ",") { + opt = clean(opt) + if !has(wantedVFS, opt) { + t.Errorf("unexpected mount option %q expected %q", opt, vfs) + } + delete(wantedVFS, opt) + } + } + for opt := range wantedVFS { + t.Errorf("missing mount option %q found %q", opt, mi.VfsOpts) + } + } + + return + } + + t.Errorf("failed to find mount %q", mnt) +} + +// clean strips off any value param after the colon +func clean(v string) string { + return strings.SplitN(v, ":", 2)[0] +} + +// has returns true if key is a member of m +func has(m map[string]struct{}, key 
string) bool { + _, ok := m[key] + return ok +} diff --git a/pkg/mount/mounter_solaris.go b/pkg/mount/mounter_solaris.go index c684aa81fc..48b86771e7 100644 --- a/pkg/mount/mounter_solaris.go +++ b/pkg/mount/mounter_solaris.go @@ -3,8 +3,9 @@ package mount import ( - "golang.org/x/sys/unix" "unsafe" + + "golang.org/x/sys/unix" ) // #include diff --git a/pkg/mount/mountinfo.go b/pkg/mount/mountinfo.go index e3fc3535e9..ff4cc1d86b 100644 --- a/pkg/mount/mountinfo.go +++ b/pkg/mount/mountinfo.go @@ -38,3 +38,17 @@ type Info struct { // VfsOpts represents per super block options. VfsOpts string } + +type byMountpoint []*Info + +func (by byMountpoint) Len() int { + return len(by) +} + +func (by byMountpoint) Less(i, j int) bool { + return by[i].Mountpoint < by[j].Mountpoint +} + +func (by byMountpoint) Swap(i, j int) { + by[i], by[j] = by[j], by[i] +} diff --git a/pkg/mount/sharedsubtree_linux_test.go b/pkg/mount/sharedsubtree_linux_test.go index c1837942e3..f25ab19fed 100644 --- a/pkg/mount/sharedsubtree_linux_test.go +++ b/pkg/mount/sharedsubtree_linux_test.go @@ -5,8 +5,9 @@ package mount import ( "os" "path" - "syscall" "testing" + + "golang.org/x/sys/unix" ) // nothing is propagated in or out @@ -309,7 +310,7 @@ func TestSubtreeUnbindable(t *testing.T) { }() // then attempt to mount it to target. It should fail - if err := Mount(sourceDir, targetDir, "none", "bind,rw"); err != nil && err != syscall.EINVAL { + if err := Mount(sourceDir, targetDir, "none", "bind,rw"); err != nil && err != unix.EINVAL { t.Fatal(err) } else if err == nil { t.Fatalf("%q should not have been bindable", sourceDir) diff --git a/pkg/mount/sharedsubtree_solaris.go b/pkg/mount/sharedsubtree_solaris.go new file mode 100644 index 0000000000..09f6b03cbc --- /dev/null +++ b/pkg/mount/sharedsubtree_solaris.go @@ -0,0 +1,58 @@ +// +build solaris + +package mount + +// MakeShared ensures a mounted filesystem has the SHARED mount option enabled. 
+// See the supported options in flags.go for further reference. +func MakeShared(mountPoint string) error { + return ensureMountedAs(mountPoint, "shared") +} + +// MakeRShared ensures a mounted filesystem has the RSHARED mount option enabled. +// See the supported options in flags.go for further reference. +func MakeRShared(mountPoint string) error { + return ensureMountedAs(mountPoint, "rshared") +} + +// MakePrivate ensures a mounted filesystem has the PRIVATE mount option enabled. +// See the supported options in flags.go for further reference. +func MakePrivate(mountPoint string) error { + return ensureMountedAs(mountPoint, "private") +} + +// MakeRPrivate ensures a mounted filesystem has the RPRIVATE mount option +// enabled. See the supported options in flags.go for further reference. +func MakeRPrivate(mountPoint string) error { + return ensureMountedAs(mountPoint, "rprivate") +} + +// MakeSlave ensures a mounted filesystem has the SLAVE mount option enabled. +// See the supported options in flags.go for further reference. +func MakeSlave(mountPoint string) error { + return ensureMountedAs(mountPoint, "slave") +} + +// MakeRSlave ensures a mounted filesystem has the RSLAVE mount option enabled. +// See the supported options in flags.go for further reference. +func MakeRSlave(mountPoint string) error { + return ensureMountedAs(mountPoint, "rslave") +} + +// MakeUnbindable ensures a mounted filesystem has the UNBINDABLE mount option +// enabled. See the supported options in flags.go for further reference. +func MakeUnbindable(mountPoint string) error { + return ensureMountedAs(mountPoint, "unbindable") +} + +// MakeRUnbindable ensures a mounted filesystem has the RUNBINDABLE mount +// option enabled. See the supported options in flags.go for further reference. 
+func MakeRUnbindable(mountPoint string) error { + return ensureMountedAs(mountPoint, "runbindable") +} + +func ensureMountedAs(mountPoint, options string) error { + // TODO: Solaris does not support bind mounts. + // Evaluate lofs and also look at the relevant + // mount flags to be supported. + return nil +} diff --git a/pkg/parsers/kernel/kernel_unix.go b/pkg/parsers/kernel/kernel_unix.go index 54a89d28c6..76e1e499f3 100644 --- a/pkg/parsers/kernel/kernel_unix.go +++ b/pkg/parsers/kernel/kernel_unix.go @@ -1,4 +1,4 @@ -// +build linux freebsd solaris +// +build linux freebsd solaris openbsd // Package kernel provides helper function to get, parse and compare kernel // versions for different platforms. @@ -6,6 +6,8 @@ package kernel import ( "bytes" + + "github.com/sirupsen/logrus" ) // GetKernelVersion gets the current kernel version. @@ -28,3 +30,16 @@ func GetKernelVersion() (*VersionInfo, error) { return ParseRelease(string(release)) } + +// CheckKernelVersion checks if current kernel is newer than (or equal to) +// the given version. +func CheckKernelVersion(k, major, minor int) bool { + if v, err := GetKernelVersion(); err != nil { + logrus.Warnf("error getting kernel version: %s", err) + } else { + if CompareKernelVersion(*v, VersionInfo{Kernel: k, Major: major, Minor: minor}) < 0 { + return false + } + } + return true +} diff --git a/pkg/parsers/kernel/kernel_windows.go b/pkg/parsers/kernel/kernel_windows.go index 80fab8ff64..e598672776 100644 --- a/pkg/parsers/kernel/kernel_windows.go +++ b/pkg/parsers/kernel/kernel_windows.go @@ -4,8 +4,9 @@ package kernel import ( "fmt" - "syscall" "unsafe" + + "golang.org/x/sys/windows" ) // VersionInfo holds information about the kernel. 
@@ -24,28 +25,28 @@ func (k *VersionInfo) String() string { func GetKernelVersion() (*VersionInfo, error) { var ( - h syscall.Handle + h windows.Handle dwVersion uint32 err error ) KVI := &VersionInfo{"Unknown", 0, 0, 0} - if err = syscall.RegOpenKeyEx(syscall.HKEY_LOCAL_MACHINE, - syscall.StringToUTF16Ptr(`SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\`), + if err = windows.RegOpenKeyEx(windows.HKEY_LOCAL_MACHINE, + windows.StringToUTF16Ptr(`SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\`), 0, - syscall.KEY_READ, + windows.KEY_READ, &h); err != nil { return KVI, err } - defer syscall.RegCloseKey(h) + defer windows.RegCloseKey(h) var buf [1 << 10]uint16 var typ uint32 n := uint32(len(buf) * 2) // api expects array of bytes, not uint16 - if err = syscall.RegQueryValueEx(h, - syscall.StringToUTF16Ptr("BuildLabEx"), + if err = windows.RegQueryValueEx(h, + windows.StringToUTF16Ptr("BuildLabEx"), nil, &typ, (*byte)(unsafe.Pointer(&buf[0])), @@ -53,11 +54,11 @@ func GetKernelVersion() (*VersionInfo, error) { return KVI, err } - KVI.kvi = syscall.UTF16ToString(buf[:]) + KVI.kvi = windows.UTF16ToString(buf[:]) // Important - docker.exe MUST be manifested for this API to return // the correct information. - if dwVersion, err = syscall.GetVersion(); err != nil { + if dwVersion, err = windows.GetVersion(); err != nil { return KVI, err } diff --git a/pkg/parsers/kernel/uname_linux.go b/pkg/parsers/kernel/uname_linux.go index bb9b32641e..e913fad001 100644 --- a/pkg/parsers/kernel/uname_linux.go +++ b/pkg/parsers/kernel/uname_linux.go @@ -1,18 +1,16 @@ package kernel -import ( - "syscall" -) +import "golang.org/x/sys/unix" // Utsname represents the system name structure. -// It is passthrough for syscall.Utsname in order to make it portable with +// It is passthrough for unix.Utsname in order to make it portable with // other platforms where it is not available. 
-type Utsname syscall.Utsname +type Utsname unix.Utsname -func uname() (*syscall.Utsname, error) { - uts := &syscall.Utsname{} +func uname() (*unix.Utsname, error) { + uts := &unix.Utsname{} - if err := syscall.Uname(uts); err != nil { + if err := unix.Uname(uts); err != nil { return nil, err } return uts, nil diff --git a/pkg/parsers/operatingsystem/operatingsystem_linux.go b/pkg/parsers/operatingsystem/operatingsystem_linux.go new file mode 100644 index 0000000000..e04a3499af --- /dev/null +++ b/pkg/parsers/operatingsystem/operatingsystem_linux.go @@ -0,0 +1,77 @@ +// Package operatingsystem provides helper function to get the operating system +// name for different platforms. +package operatingsystem + +import ( + "bufio" + "bytes" + "fmt" + "io/ioutil" + "os" + "strings" + + "github.com/mattn/go-shellwords" +) + +var ( + // file to use to detect if the daemon is running in a container + proc1Cgroup = "/proc/1/cgroup" + + // file to check to determine Operating System + etcOsRelease = "/etc/os-release" + + // used by stateless systems like Clear Linux + altOsRelease = "/usr/lib/os-release" +) + +// GetOperatingSystem gets the name of the current operating system. 
+func GetOperatingSystem() (string, error) { + osReleaseFile, err := os.Open(etcOsRelease) + if err != nil { + if !os.IsNotExist(err) { + return "", fmt.Errorf("Error opening %s: %v", etcOsRelease, err) + } + osReleaseFile, err = os.Open(altOsRelease) + if err != nil { + return "", fmt.Errorf("Error opening %s: %v", altOsRelease, err) + } + } + defer osReleaseFile.Close() + + var prettyName string + scanner := bufio.NewScanner(osReleaseFile) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "PRETTY_NAME=") { + data := strings.SplitN(line, "=", 2) + prettyNames, err := shellwords.Parse(data[1]) + if err != nil { + return "", fmt.Errorf("PRETTY_NAME is invalid: %s", err.Error()) + } + if len(prettyNames) != 1 { + return "", fmt.Errorf("PRETTY_NAME needs to be enclosed by quotes if they have spaces: %s", data[1]) + } + prettyName = prettyNames[0] + } + } + if prettyName != "" { + return prettyName, nil + } + // If not set, defaults to PRETTY_NAME="Linux" + // c.f. http://www.freedesktop.org/software/systemd/man/os-release.html + return "Linux", nil +} + +// IsContainerized returns true if we are running inside a container. 
+func IsContainerized() (bool, error) { + b, err := ioutil.ReadFile(proc1Cgroup) + if err != nil { + return false, err + } + for _, line := range bytes.Split(b, []byte{'\n'}) { + if len(line) > 0 && !bytes.HasSuffix(line, []byte{'/'}) && !bytes.HasSuffix(line, []byte("init.scope")) { + return true, nil + } + } + return false, nil +} diff --git a/pkg/parsers/operatingsystem/operatingsystem_solaris.go b/pkg/parsers/operatingsystem/operatingsystem_solaris.go new file mode 100644 index 0000000000..d08ad14860 --- /dev/null +++ b/pkg/parsers/operatingsystem/operatingsystem_solaris.go @@ -0,0 +1,37 @@ +// +build solaris,cgo + +package operatingsystem + +/* +#include +*/ +import "C" + +import ( + "bytes" + "errors" + "io/ioutil" +) + +var etcOsRelease = "/etc/release" + +// GetOperatingSystem gets the name of the current operating system. +func GetOperatingSystem() (string, error) { + b, err := ioutil.ReadFile(etcOsRelease) + if err != nil { + return "", err + } + if i := bytes.Index(b, []byte("\n")); i >= 0 { + b = bytes.Trim(b[:i], " ") + return string(b), nil + } + return "", errors.New("release not found") +} + +// IsContainerized returns true if we are running inside a container. +func IsContainerized() (bool, error) { + if C.getzoneid() != 0 { + return true, nil + } + return false, nil +} diff --git a/pkg/parsers/operatingsystem/operatingsystem_unix.go b/pkg/parsers/operatingsystem/operatingsystem_unix.go new file mode 100644 index 0000000000..bc91c3c533 --- /dev/null +++ b/pkg/parsers/operatingsystem/operatingsystem_unix.go @@ -0,0 +1,25 @@ +// +build freebsd darwin + +package operatingsystem + +import ( + "errors" + "os/exec" +) + +// GetOperatingSystem gets the name of the current operating system. +func GetOperatingSystem() (string, error) { + cmd := exec.Command("uname", "-s") + osName, err := cmd.Output() + if err != nil { + return "", err + } + return string(osName), nil +} + +// IsContainerized returns true if we are running inside a container. 
+// No-op on FreeBSD and Darwin, always returns false. +func IsContainerized() (bool, error) { + // TODO: Implement jail detection for freeBSD + return false, errors.New("Cannot detect if we are in container") +} diff --git a/pkg/parsers/operatingsystem/operatingsystem_unix_test.go b/pkg/parsers/operatingsystem/operatingsystem_unix_test.go new file mode 100644 index 0000000000..e7120c65c4 --- /dev/null +++ b/pkg/parsers/operatingsystem/operatingsystem_unix_test.go @@ -0,0 +1,247 @@ +// +build linux freebsd + +package operatingsystem + +import ( + "io/ioutil" + "os" + "path/filepath" + "testing" +) + +func TestGetOperatingSystem(t *testing.T) { + var backup = etcOsRelease + + invalids := []struct { + content string + errorExpected string + }{ + { + `PRETTY_NAME=Source Mage GNU/Linux +PRETTY_NAME=Ubuntu 14.04.LTS`, + "PRETTY_NAME needs to be enclosed by quotes if they have spaces: Source Mage GNU/Linux", + }, + { + `PRETTY_NAME="Ubuntu Linux +PRETTY_NAME=Ubuntu 14.04.LTS`, + "PRETTY_NAME is invalid: invalid command line string", + }, + { + `PRETTY_NAME=Ubuntu' +PRETTY_NAME=Ubuntu 14.04.LTS`, + "PRETTY_NAME is invalid: invalid command line string", + }, + { + `PRETTY_NAME' +PRETTY_NAME=Ubuntu 14.04.LTS`, + "PRETTY_NAME needs to be enclosed by quotes if they have spaces: Ubuntu 14.04.LTS", + }, + } + + valids := []struct { + content string + expected string + }{ + { + `NAME="Ubuntu" +PRETTY_NAME_AGAIN="Ubuntu 14.04.LTS" +VERSION="14.04, Trusty Tahr" +ID=ubuntu +ID_LIKE=debian +VERSION_ID="14.04" +HOME_URL="http://www.ubuntu.com/" +SUPPORT_URL="http://help.ubuntu.com/" +BUG_REPORT_URL="http://bugs.launchpad.net/ubuntu/"`, + "Linux", + }, + { + `NAME="Ubuntu" +VERSION="14.04, Trusty Tahr" +ID=ubuntu +ID_LIKE=debian +VERSION_ID="14.04" +HOME_URL="http://www.ubuntu.com/" +SUPPORT_URL="http://help.ubuntu.com/" +BUG_REPORT_URL="http://bugs.launchpad.net/ubuntu/"`, + "Linux", + }, + { + `NAME=Gentoo +ID=gentoo +PRETTY_NAME="Gentoo/Linux" +ANSI_COLOR="1;32" 
+HOME_URL="http://www.gentoo.org/" +SUPPORT_URL="http://www.gentoo.org/main/en/support.xml" +BUG_REPORT_URL="https://bugs.gentoo.org/" +`, + "Gentoo/Linux", + }, + { + `NAME="Ubuntu" +VERSION="14.04, Trusty Tahr" +ID=ubuntu +ID_LIKE=debian +PRETTY_NAME="Ubuntu 14.04 LTS" +VERSION_ID="14.04" +HOME_URL="http://www.ubuntu.com/" +SUPPORT_URL="http://help.ubuntu.com/" +BUG_REPORT_URL="http://bugs.launchpad.net/ubuntu/"`, + "Ubuntu 14.04 LTS", + }, + { + `NAME="Ubuntu" +VERSION="14.04, Trusty Tahr" +ID=ubuntu +ID_LIKE=debian +PRETTY_NAME='Ubuntu 14.04 LTS'`, + "Ubuntu 14.04 LTS", + }, + { + `PRETTY_NAME=Source +NAME="Source Mage"`, + "Source", + }, + { + `PRETTY_NAME=Source +PRETTY_NAME="Source Mage"`, + "Source Mage", + }, + } + + dir := os.TempDir() + etcOsRelease = filepath.Join(dir, "etcOsRelease") + + defer func() { + os.Remove(etcOsRelease) + etcOsRelease = backup + }() + + for _, elt := range invalids { + if err := ioutil.WriteFile(etcOsRelease, []byte(elt.content), 0600); err != nil { + t.Fatalf("failed to write to %s: %v", etcOsRelease, err) + } + s, err := GetOperatingSystem() + if err == nil || err.Error() != elt.errorExpected { + t.Fatalf("Expected an error %q, got %q (err: %v)", elt.errorExpected, s, err) + } + } + + for _, elt := range valids { + if err := ioutil.WriteFile(etcOsRelease, []byte(elt.content), 0600); err != nil { + t.Fatalf("failed to write to %s: %v", etcOsRelease, err) + } + s, err := GetOperatingSystem() + if err != nil || s != elt.expected { + t.Fatalf("Expected %q, got %q (err: %v)", elt.expected, s, err) + } + } +} + +func TestIsContainerized(t *testing.T) { + var ( + backup = proc1Cgroup + nonContainerizedProc1Cgroupsystemd226 = []byte(`9:memory:/init.scope +8:net_cls,net_prio:/ +7:cpuset:/ +6:freezer:/ +5:devices:/init.scope +4:blkio:/init.scope +3:cpu,cpuacct:/init.scope +2:perf_event:/ +1:name=systemd:/init.scope +`) + nonContainerizedProc1Cgroup = []byte(`14:name=systemd:/ +13:hugetlb:/ +12:net_prio:/ +11:perf_event:/ +10:bfqio:/ 
+9:blkio:/ +8:net_cls:/ +7:freezer:/ +6:devices:/ +5:memory:/ +4:cpuacct:/ +3:cpu:/ +2:cpuset:/ +`) + containerizedProc1Cgroup = []byte(`9:perf_event:/docker/3cef1b53c50b0fa357d994f8a1a8cd783c76bbf4f5dd08b226e38a8bd331338d +8:blkio:/docker/3cef1b53c50b0fa357d994f8a1a8cd783c76bbf4f5dd08b226e38a8bd331338d +7:net_cls:/ +6:freezer:/docker/3cef1b53c50b0fa357d994f8a1a8cd783c76bbf4f5dd08b226e38a8bd331338d +5:devices:/docker/3cef1b53c50b0fa357d994f8a1a8cd783c76bbf4f5dd08b226e38a8bd331338d +4:memory:/docker/3cef1b53c50b0fa357d994f8a1a8cd783c76bbf4f5dd08b226e38a8bd331338d +3:cpuacct:/docker/3cef1b53c50b0fa357d994f8a1a8cd783c76bbf4f5dd08b226e38a8bd331338d +2:cpu:/docker/3cef1b53c50b0fa357d994f8a1a8cd783c76bbf4f5dd08b226e38a8bd331338d +1:cpuset:/`) + ) + + dir := os.TempDir() + proc1Cgroup = filepath.Join(dir, "proc1Cgroup") + + defer func() { + os.Remove(proc1Cgroup) + proc1Cgroup = backup + }() + + if err := ioutil.WriteFile(proc1Cgroup, nonContainerizedProc1Cgroup, 0600); err != nil { + t.Fatalf("failed to write to %s: %v", proc1Cgroup, err) + } + inContainer, err := IsContainerized() + if err != nil { + t.Fatal(err) + } + if inContainer { + t.Fatal("Wrongly assuming containerized") + } + + if err := ioutil.WriteFile(proc1Cgroup, nonContainerizedProc1Cgroupsystemd226, 0600); err != nil { + t.Fatalf("failed to write to %s: %v", proc1Cgroup, err) + } + inContainer, err = IsContainerized() + if err != nil { + t.Fatal(err) + } + if inContainer { + t.Fatal("Wrongly assuming containerized for systemd /init.scope cgroup layout") + } + + if err := ioutil.WriteFile(proc1Cgroup, containerizedProc1Cgroup, 0600); err != nil { + t.Fatalf("failed to write to %s: %v", proc1Cgroup, err) + } + inContainer, err = IsContainerized() + if err != nil { + t.Fatal(err) + } + if !inContainer { + t.Fatal("Wrongly assuming non-containerized") + } +} + +func TestOsReleaseFallback(t *testing.T) { + var backup = etcOsRelease + var altBackup = altOsRelease + dir := os.TempDir() + etcOsRelease = 
filepath.Join(dir, "etcOsRelease") + altOsRelease = filepath.Join(dir, "altOsRelease") + + defer func() { + os.Remove(dir) + etcOsRelease = backup + altOsRelease = altBackup + }() + content := `NAME=Gentoo +ID=gentoo +PRETTY_NAME="Gentoo/Linux" +ANSI_COLOR="1;32" +HOME_URL="http://www.gentoo.org/" +SUPPORT_URL="http://www.gentoo.org/main/en/support.xml" +BUG_REPORT_URL="https://bugs.gentoo.org/" +` + if err := ioutil.WriteFile(altOsRelease, []byte(content), 0600); err != nil { + t.Fatalf("failed to write to %s: %v", etcOsRelease, err) + } + s, err := GetOperatingSystem() + if err != nil || s != "Gentoo/Linux" { + t.Fatalf("Expected %q, got %q (err: %v)", "Gentoo/Linux", s, err) + } +} diff --git a/pkg/parsers/operatingsystem/operatingsystem_windows.go b/pkg/parsers/operatingsystem/operatingsystem_windows.go new file mode 100644 index 0000000000..5d8b42cc36 --- /dev/null +++ b/pkg/parsers/operatingsystem/operatingsystem_windows.go @@ -0,0 +1,50 @@ +package operatingsystem + +import ( + "unsafe" + + "golang.org/x/sys/windows" +) + +// See https://code.google.com/p/go/source/browse/src/pkg/mime/type_windows.go?r=d14520ac25bf6940785aabb71f5be453a286f58c +// for a similar sample + +// GetOperatingSystem gets the name of the current operating system. 
+func GetOperatingSystem() (string, error) { + + var h windows.Handle + + // Default return value + ret := "Unknown Operating System" + + if err := windows.RegOpenKeyEx(windows.HKEY_LOCAL_MACHINE, + windows.StringToUTF16Ptr(`SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\`), + 0, + windows.KEY_READ, + &h); err != nil { + return ret, err + } + defer windows.RegCloseKey(h) + + var buf [1 << 10]uint16 + var typ uint32 + n := uint32(len(buf) * 2) // api expects array of bytes, not uint16 + + if err := windows.RegQueryValueEx(h, + windows.StringToUTF16Ptr("ProductName"), + nil, + &typ, + (*byte)(unsafe.Pointer(&buf[0])), + &n); err != nil { + return ret, err + } + ret = windows.UTF16ToString(buf[:]) + + return ret, nil +} + +// IsContainerized returns true if we are running inside a container. +// No-op on Windows, always returns false. +func IsContainerized() (bool, error) { + return false, nil +} diff --git a/pkg/plugins/client.go b/pkg/plugins/client.go index b4c31c0569..2c0a91d13c 100644 --- a/pkg/plugins/client.go +++ b/pkg/plugins/client.go @@ -19,8 +19,7 @@ const ( defaultTimeOut = 30 ) -// NewClient creates a new plugin client (http). -func NewClient(addr string, tlsConfig *tlsconfig.Options) (*Client, error) { +func newTransport(addr string, tlsConfig *tlsconfig.Options) (transport.Transport, error) { tr := &http.Transport{} if tlsConfig != nil { @@ -45,15 +44,33 @@ func NewClient(addr string, tlsConfig *tlsconfig.Options) (*Client, error) { } scheme := httpScheme(u) - clientTransport := transport.NewHTTPTransport(tr, scheme, socket) - return NewClientWithTransport(clientTransport), nil + return transport.NewHTTPTransport(tr, scheme, socket), nil +} + +// NewClient creates a new plugin client (http). 
+func NewClient(addr string, tlsConfig *tlsconfig.Options) (*Client, error) { + clientTransport, err := newTransport(addr, tlsConfig) + if err != nil { + return nil, err + } + return newClientWithTransport(clientTransport, 0), nil +} + +// NewClientWithTimeout creates a new plugin client (http). +func NewClientWithTimeout(addr string, tlsConfig *tlsconfig.Options, timeout time.Duration) (*Client, error) { + clientTransport, err := newTransport(addr, tlsConfig) + if err != nil { + return nil, err + } + return newClientWithTransport(clientTransport, timeout), nil } -// NewClientWithTransport creates a new plugin client with a given transport. -func NewClientWithTransport(tr transport.Transport) *Client { +// newClientWithTransport creates a new plugin client with a given transport. +func newClientWithTransport(tr transport.Transport, timeout time.Duration) *Client { return &Client{ http: &http.Client{ Transport: tr, + Timeout: timeout, }, requestFactory: tr, } @@ -112,15 +129,15 @@ func (c *Client) SendFile(serviceMethod string, data io.Reader, ret interface{}) } func (c *Client) callWithRetry(serviceMethod string, data io.Reader, retry bool) (io.ReadCloser, error) { - req, err := c.requestFactory.NewRequest(serviceMethod, data) - if err != nil { - return nil, err - } - var retries int start := time.Now() for { + req, err := c.requestFactory.NewRequest(serviceMethod, data) + if err != nil { + return nil, err + } + resp, err := c.http.Do(req) if err != nil { if !retry { diff --git a/pkg/plugins/client_test.go b/pkg/plugins/client_test.go index 56b3d71007..884a1ec0a0 100644 --- a/pkg/plugins/client_test.go +++ b/pkg/plugins/client_test.go @@ -1,16 +1,19 @@ package plugins import ( + "bytes" + "encoding/json" "io" "net/http" "net/http/httptest" "net/url" - "reflect" + "strings" "testing" "time" "github.com/containers/storage/pkg/plugins/transport" "github.com/docker/go-connections/tlsconfig" + "github.com/stretchr/testify/assert" ) var ( @@ -38,6 +41,26 @@ func 
TestFailedConnection(t *testing.T) { } } +func TestFailOnce(t *testing.T) { + addr := setupRemotePluginServer() + defer teardownRemotePluginServer() + + failed := false + mux.HandleFunc("/Test.FailOnce", func(w http.ResponseWriter, r *http.Request) { + if !failed { + failed = true + panic("Plugin not ready") + } + }) + + c, _ := NewClient(addr, &tlsconfig.Options{InsecureSkipVerify: true}) + b := strings.NewReader("body") + _, err := c.callWithRetry("Test.FailOnce", b, true) + if err != nil { + t.Fatal(err) + } +} + func TestEchoInputOutput(t *testing.T) { addr := setupRemotePluginServer() defer teardownRemotePluginServer() @@ -62,9 +85,7 @@ func TestEchoInputOutput(t *testing.T) { t.Fatal(err) } - if !reflect.DeepEqual(output, m) { - t.Fatalf("Expected %v, was %v\n", m, output) - } + assert.Equal(t, m, output) err = c.Call("Test.Echo", nil, nil) if err != nil { t.Fatal(err) @@ -132,3 +153,82 @@ func TestClientScheme(t *testing.T) { } } } + +func TestNewClientWithTimeout(t *testing.T) { + addr := setupRemotePluginServer() + defer teardownRemotePluginServer() + + m := Manifest{[]string{"VolumeDriver", "NetworkDriver"}} + + mux.HandleFunc("/Test.Echo", func(w http.ResponseWriter, r *http.Request) { + time.Sleep(time.Duration(600) * time.Millisecond) + io.Copy(w, r.Body) + }) + + // setting timeout of 500ms + timeout := time.Duration(500) * time.Millisecond + c, _ := NewClientWithTimeout(addr, &tlsconfig.Options{InsecureSkipVerify: true}, timeout) + var output Manifest + err := c.Call("Test.Echo", m, &output) + if err == nil { + t.Fatal("Expected timeout error") + } +} + +func TestClientStream(t *testing.T) { + addr := setupRemotePluginServer() + defer teardownRemotePluginServer() + + m := Manifest{[]string{"VolumeDriver", "NetworkDriver"}} + var output Manifest + + mux.HandleFunc("/Test.Echo", func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Fatalf("Expected POST, got %s", r.Method) + } + + header := w.Header() + header.Set("Content-Type", 
transport.VersionMimetype) + + io.Copy(w, r.Body) + }) + + c, _ := NewClient(addr, &tlsconfig.Options{InsecureSkipVerify: true}) + body, err := c.Stream("Test.Echo", m) + if err != nil { + t.Fatal(err) + } + defer body.Close() + if err := json.NewDecoder(body).Decode(&output); err != nil { + t.Fatalf("Test.Echo: error reading plugin resp: %v", err) + } + assert.Equal(t, m, output) +} + +func TestClientSendFile(t *testing.T) { + addr := setupRemotePluginServer() + defer teardownRemotePluginServer() + + m := Manifest{[]string{"VolumeDriver", "NetworkDriver"}} + var output Manifest + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(m); err != nil { + t.Fatal(err) + } + mux.HandleFunc("/Test.Echo", func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Fatalf("Expected POST, got %s\n", r.Method) + } + + header := w.Header() + header.Set("Content-Type", transport.VersionMimetype) + + io.Copy(w, r.Body) + }) + + c, _ := NewClient(addr, &tlsconfig.Options{InsecureSkipVerify: true}) + if err := c.SendFile("Test.Echo", &buf, &output); err != nil { + t.Fatal(err) + } + assert.Equal(t, m, output) +} diff --git a/pkg/plugins/discovery.go b/pkg/plugins/discovery.go index 4cb5a1a3a7..8d8d2eadf2 100644 --- a/pkg/plugins/discovery.go +++ b/pkg/plugins/discovery.go @@ -15,8 +15,7 @@ import ( var ( // ErrNotFound plugin not found ErrNotFound = errors.New("plugin not found") - socketsPath = "/run/containers/storage/plugins" - specsPaths = []string{"/etc/containers/storage/plugins", "/usr/lib/containers/storage/plugins"} + socketsPath = "/run/container/storage/plugins" ) // localRegistry defines a registry that is local (using unix socket). 
@@ -116,7 +115,7 @@ func readPluginJSONInfo(name, path string) (*Plugin, error) { return nil, err } p.name = name - if len(p.TLSConfig.CAFile) == 0 { + if p.TLSConfig != nil && len(p.TLSConfig.CAFile) == 0 { p.TLSConfig.InsecureSkipVerify = true } p.activateWait = sync.NewCond(&sync.Mutex{}) diff --git a/pkg/plugins/discovery_test.go b/pkg/plugins/discovery_test.go index b8fa38c1d3..8da0bff4b5 100644 --- a/pkg/plugins/discovery_test.go +++ b/pkg/plugins/discovery_test.go @@ -8,7 +8,7 @@ import ( ) func Setup(t *testing.T) (string, func()) { - tmpdir, err := ioutil.TempDir("", "docker-test") + tmpdir, err := ioutil.TempDir("", "container-storage-test") if err != nil { t.Fatal(err) } @@ -59,14 +59,14 @@ func TestFileSpecPlugin(t *testing.T) { } if p.name != c.name { - t.Fatalf("Expected plugin `%s`, got %s\n", c.name, p.Name) + t.Fatalf("Expected plugin `%s`, got %s\n", c.name, p.name) } if p.Addr != c.addr { t.Fatalf("Expected plugin addr `%s`, got %s\n", c.addr, p.Addr) } - if p.TLSConfig.InsecureSkipVerify != true { + if !p.TLSConfig.InsecureSkipVerify { t.Fatalf("Expected TLS verification to be skipped") } } @@ -79,11 +79,11 @@ func TestFileJSONSpecPlugin(t *testing.T) { p := filepath.Join(tmpdir, "example.json") spec := `{ "Name": "plugin-example", - "Addr": "https://example.com/containers-storage/plugin", + "Addr": "https://example.com/containers/storage/plugin", "TLSConfig": { - "CAFile": "/usr/shared/containers-storage/certs/example-ca.pem", - "CertFile": "/usr/shared/containers-storage/certs/example-cert.pem", - "KeyFile": "/usr/shared/containers-storage/certs/example-key.pem" + "CAFile": "/usr/shared/containers/storage/certs/example-ca.pem", + "CertFile": "/usr/shared/containers/storage/certs/example-cert.pem", + "KeyFile": "/usr/shared/containers/storage/certs/example-key.pem" } }` @@ -97,23 +97,56 @@ func TestFileJSONSpecPlugin(t *testing.T) { t.Fatal(err) } - if plugin.name != "example" { - t.Fatalf("Expected plugin `plugin-example`, got %s\n", 
plugin.Name) + if expected, actual := "example", plugin.name; expected != actual { + t.Fatalf("Expected plugin %q, got %s\n", expected, actual) } - if plugin.Addr != "https://example.com/containers-storage/plugin" { - t.Fatalf("Expected plugin addr `https://example.com/containers-storage/plugin`, got %s\n", plugin.Addr) + if plugin.Addr != "https://example.com/containers/storage/plugin" { + t.Fatalf("Expected plugin addr `https://example.com/containers/storage/plugin`, got %s\n", plugin.Addr) } - if plugin.TLSConfig.CAFile != "/usr/shared/containers-storage/certs/example-ca.pem" { - t.Fatalf("Expected plugin CA `/usr/shared/containers-storage/certs/example-ca.pem`, got %s\n", plugin.TLSConfig.CAFile) + if plugin.TLSConfig.CAFile != "/usr/shared/containers/storage/certs/example-ca.pem" { + t.Fatalf("Expected plugin CA `/usr/shared/containers/storage/certs/example-ca.pem`, got %s\n", plugin.TLSConfig.CAFile) } - if plugin.TLSConfig.CertFile != "/usr/shared/containers-storage/certs/example-cert.pem" { - t.Fatalf("Expected plugin Certificate `/usr/shared/containers-storage/certs/example-cert.pem`, got %s\n", plugin.TLSConfig.CertFile) + if plugin.TLSConfig.CertFile != "/usr/shared/containers/storage/certs/example-cert.pem" { + t.Fatalf("Expected plugin Certificate `/usr/shared/containers/storage/certs/example-cert.pem`, got %s\n", plugin.TLSConfig.CertFile) } - if plugin.TLSConfig.KeyFile != "/usr/shared/containers-storage/certs/example-key.pem" { - t.Fatalf("Expected plugin Key `/usr/shared/containers-storage/certs/example-key.pem`, got %s\n", plugin.TLSConfig.KeyFile) + if plugin.TLSConfig.KeyFile != "/usr/shared/containers/storage/certs/example-key.pem" { + t.Fatalf("Expected plugin Key `/usr/shared/containers/storage/certs/example-key.pem`, got %s\n", plugin.TLSConfig.KeyFile) + } +} + +func TestFileJSONSpecPluginWithoutTLSConfig(t *testing.T) { + tmpdir, unregister := Setup(t) + defer unregister() + + p := filepath.Join(tmpdir, "example.json") + spec := `{ + 
"Name": "plugin-example", + "Addr": "https://example.com/containers/storage/plugin" +}` + + if err := ioutil.WriteFile(p, []byte(spec), 0644); err != nil { + t.Fatal(err) + } + + r := newLocalRegistry() + plugin, err := r.Plugin("example") + if err != nil { + t.Fatal(err) + } + + if expected, actual := "example", plugin.name; expected != actual { + t.Fatalf("Expected plugin %q, got %s\n", expected, actual) + } + + if plugin.Addr != "https://example.com/containers/storage/plugin" { + t.Fatalf("Expected plugin addr `https://example.com/containers/storage/plugin`, got %s\n", plugin.Addr) + } + + if plugin.TLSConfig != nil { + t.Fatalf("Expected plugin TLSConfig nil, got %v\n", plugin.TLSConfig) } } diff --git a/pkg/plugins/discovery_unix.go b/pkg/plugins/discovery_unix.go new file mode 100644 index 0000000000..dd01df3919 --- /dev/null +++ b/pkg/plugins/discovery_unix.go @@ -0,0 +1,5 @@ +// +build !windows + +package plugins + +var specsPaths = []string{"/etc/containers/storage/plugins", "/usr/lib/containers/storage/plugins"} diff --git a/pkg/plugins/discovery_unix_test.go b/pkg/plugins/discovery_unix_test.go index 53e02d2858..f6376dc461 100644 --- a/pkg/plugins/discovery_unix_test.go +++ b/pkg/plugins/discovery_unix_test.go @@ -4,6 +4,7 @@ package plugins import ( "fmt" + "io/ioutil" "net" "os" "path/filepath" @@ -46,16 +47,54 @@ func TestLocalSocket(t *testing.T) { } if p.name != "echo" { - t.Fatalf("Expected plugin `echo`, got %s\n", p.Name) + t.Fatalf("Expected plugin `echo`, got %s\n", p.name) } addr := fmt.Sprintf("unix://%s", c) if p.Addr != addr { t.Fatalf("Expected plugin addr `%s`, got %s\n", addr, p.Addr) } - if p.TLSConfig.InsecureSkipVerify != true { + if !p.TLSConfig.InsecureSkipVerify { t.Fatalf("Expected TLS verification to be skipped") } l.Close() } } + +func TestScan(t *testing.T) { + tmpdir, unregister := Setup(t) + defer unregister() + + pluginNames, err := Scan() + if err != nil { + t.Fatal(err) + } + if pluginNames != nil { + t.Fatal("Plugin names 
should be empty.") + } + + path := filepath.Join(tmpdir, "echo.spec") + addr := "unix://var/lib/containers/storage/plugins/echo.sock" + name := "echo" + + err = os.MkdirAll(filepath.Dir(path), 0755) + if err != nil { + t.Fatal(err) + } + + err = ioutil.WriteFile(path, []byte(addr), 0644) + if err != nil { + t.Fatal(err) + } + + r := newLocalRegistry() + p, err := r.Plugin(name) + + pluginNamesNotEmpty, err := Scan() + if err != nil { + t.Fatal(err) + } + if p.Name() != pluginNamesNotEmpty[0] { + t.Fatalf("Unable to scan plugin with name %s", p.name) + } +} diff --git a/pkg/plugins/discovery_windows.go b/pkg/plugins/discovery_windows.go new file mode 100644 index 0000000000..b4f55302f7 --- /dev/null +++ b/pkg/plugins/discovery_windows.go @@ -0,0 +1,8 @@ +package plugins + +import ( + "os" + "path/filepath" +) + +var specsPaths = []string{filepath.Join(os.Getenv("programdata"), "containers", "storage", "plugins")} diff --git a/pkg/plugins/plugin_test.go b/pkg/plugins/plugin_test.go new file mode 100644 index 0000000000..d32365ab02 --- /dev/null +++ b/pkg/plugins/plugin_test.go @@ -0,0 +1,156 @@ +package plugins + +import ( + "bytes" + "encoding/json" + "errors" + "io" + "io/ioutil" + "net/http" + "path/filepath" + "runtime" + "sync" + "testing" + "time" + + "github.com/containers/storage/pkg/plugins/transport" + "github.com/docker/go-connections/tlsconfig" + "github.com/stretchr/testify/assert" +) + +const ( + fruitPlugin = "fruit" + fruitImplements = "apple" +) + +// regression test for deadlock in handlers +func TestPluginAddHandler(t *testing.T) { + // make a plugin which is pre-activated + p := &Plugin{activateWait: sync.NewCond(&sync.Mutex{})} + p.Manifest = &Manifest{Implements: []string{"bananas"}} + storage.plugins["qwerty"] = p + + testActive(t, p) + Handle("bananas", func(_ string, _ *Client) {}) + testActive(t, p) +} + +func TestPluginWaitBadPlugin(t *testing.T) { + p := &Plugin{activateWait: sync.NewCond(&sync.Mutex{})} + p.activateErr = errors.New("some 
junk happened") + testActive(t, p) +} + +func testActive(t *testing.T, p *Plugin) { + done := make(chan struct{}) + go func() { + p.waitActive() + close(done) + }() + + select { + case <-time.After(100 * time.Millisecond): + _, f, l, _ := runtime.Caller(1) + t.Fatalf("%s:%d: deadlock in waitActive", filepath.Base(f), l) + case <-done: + } + +} + +func TestGet(t *testing.T) { + p := &Plugin{name: fruitPlugin, activateWait: sync.NewCond(&sync.Mutex{})} + p.Manifest = &Manifest{Implements: []string{fruitImplements}} + storage.plugins[fruitPlugin] = p + + plugin, err := Get(fruitPlugin, fruitImplements) + if err != nil { + t.Fatal(err) + } + if p.Name() != plugin.Name() { + t.Fatalf("No matching plugin with name %s found", plugin.Name()) + } + if plugin.Client() != nil { + t.Fatal("expected nil Client but found one") + } + if !plugin.IsV1() { + t.Fatal("Expected true for V1 plugin") + } + + // check negative case where plugin fruit doesn't implement banana + _, err = Get("fruit", "banana") + assert.Equal(t, err, ErrNotImplements) + + // check negative case where plugin vegetable doesn't exist + _, err = Get("vegetable", "potato") + assert.Equal(t, err, ErrNotFound) + +} + +func TestPluginWithNoManifest(t *testing.T) { + addr := setupRemotePluginServer() + defer teardownRemotePluginServer() + + m := Manifest{[]string{fruitImplements}} + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(m); err != nil { + t.Fatal(err) + } + + mux.HandleFunc("/Plugin.Activate", func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Fatalf("Expected POST, got %s\n", r.Method) + } + + header := w.Header() + header.Set("Content-Type", transport.VersionMimetype) + + io.Copy(w, &buf) + }) + + p := &Plugin{ + name: fruitPlugin, + activateWait: sync.NewCond(&sync.Mutex{}), + Addr: addr, + TLSConfig: &tlsconfig.Options{InsecureSkipVerify: true}, + } + storage.plugins[fruitPlugin] = p + + plugin, err := Get(fruitPlugin, fruitImplements) + if err != nil { + 
t.Fatal(err) + } + if p.Name() != plugin.Name() { + t.Fatalf("No matching plugin with name %s found", plugin.Name()) + } +} + +func TestGetAll(t *testing.T) { + tmpdir, unregister := Setup(t) + defer unregister() + + p := filepath.Join(tmpdir, "example.json") + spec := `{ + "Name": "example", + "Addr": "https://example.com/containers/storage/plugin" +}` + + if err := ioutil.WriteFile(p, []byte(spec), 0644); err != nil { + t.Fatal(err) + } + + r := newLocalRegistry() + plugin, err := r.Plugin("example") + if err != nil { + t.Fatal(err) + } + plugin.Manifest = &Manifest{Implements: []string{"apple"}} + storage.plugins["example"] = plugin + + fetchedPlugins, err := GetAll("apple") + if err != nil { + t.Fatal(err) + } + if fetchedPlugins[0].Name() != plugin.Name() { + t.Fatalf("Expected to get plugin with name %s", plugin.Name()) + } +} diff --git a/pkg/plugins/pluginrpc-gen/README.md b/pkg/plugins/pluginrpc-gen/README.md new file mode 100644 index 0000000000..5f6a421f19 --- /dev/null +++ b/pkg/plugins/pluginrpc-gen/README.md @@ -0,0 +1,58 @@ +Plugin RPC Generator +==================== + +Generates go code from a Go interface definition for proxying between the plugin +API and the subsystem being extended. + +## Usage + +Given an interface definition: + +```go +type volumeDriver interface { + Create(name string, opts opts) (err error) + Remove(name string) (err error) + Path(name string) (mountpoint string, err error) + Mount(name string) (mountpoint string, err error) + Unmount(name string) (err error) +} +``` + +**Note**: All function options and return values must be named in the definition. 
+ +Run the generator: + +```bash +$ pluginrpc-gen --type volumeDriver --name VolumeDriver -i volumes/drivers/extpoint.go -o volumes/drivers/proxy.go +``` + +Where: +- `--type` is the name of the interface to use +- `--name` is the subsystem that the plugin "Implements" +- `-i` is the input file containing the interface definition +- `-o` is the output file where the generated code should go + +**Note**: The generated code will use the same package name as the one defined in the input file + +Optionally, you can skip functions on the interface that should not be +implemented in the generated proxy code by passing in the function name to `--skip`. +This flag can be specified multiple times. + +You can also add build tags that should be prepended to the generated code by +supplying `--tag`. This flag can be specified multiple times. + +## Known issues + +## go-generate + +You can also use this with go-generate, which is pretty awesome. +To do so, place the code at the top of the file which contains the interface +definition (i.e., the input file): + +```go +//go:generate pluginrpc-gen -i $GOFILE -o proxy.go -type volumeDriver -name VolumeDriver +``` + +Then cd to the package dir and run `go generate` + +**Note**: the `pluginrpc-gen` binary must be within your `$PATH` diff --git a/pkg/plugins/pluginrpc-gen/fixtures/foo.go b/pkg/plugins/pluginrpc-gen/fixtures/foo.go index 642cefe8f6..3f3b29c78a 100644 --- a/pkg/plugins/pluginrpc-gen/fixtures/foo.go +++ b/pkg/plugins/pluginrpc-gen/fixtures/foo.go @@ -1,17 +1,11 @@ package foo import ( - "fmt" - aliasedio "io" "github.com/containers/storage/pkg/plugins/pluginrpc-gen/fixtures/otherfixture" ) -var ( - errFakeImport = fmt.Errorf("just to import fmt for imports tests") -) - type wobble struct { Some string Val string diff --git a/pkg/plugins/pluginrpc-gen/parser_test.go b/pkg/plugins/pluginrpc-gen/parser_test.go index a1b1ac9567..fe7fa5ade6 100644 --- a/pkg/plugins/pluginrpc-gen/parser_test.go +++ 
b/pkg/plugins/pluginrpc-gen/parser_test.go @@ -136,7 +136,7 @@ func TestParseWithMultipleFuncs(t *testing.T) { } } -func TestParseWithUnamedReturn(t *testing.T) { +func TestParseWithUnnamedReturn(t *testing.T) { _, err := Parse(testFixture, "Fooer4") if !strings.HasSuffix(err.Error(), errBadReturn.Error()) { t.Fatalf("expected ErrBadReturn, got %v", err) } } diff --git a/pkg/plugins/plugins.go b/pkg/plugins/plugins.go index e197c3fdcf..f7773b0a87 100644 --- a/pkg/plugins/plugins.go +++ b/pkg/plugins/plugins.go @@ -1,19 +1,17 @@ // Package plugins provides structures and helper functions to manage Docker // plugins. // -// Storage discovers plugins by looking for them in the plugin directory whenever +// Docker discovers plugins by looking for them in the plugin directory whenever // a user or container tries to use one by name. UNIX domain socket files must -// be located under /run/containers/storage/plugins, whereas spec files can be -// located either under /etc/containers/storage/plugins or -// /usr/lib/containers/storage/plugins. This is handled by the Registry -// interface, which lets you list all plugins or get a plugin by its name if it -// exists. +// be located under /run/container/storage/plugins, whereas spec files can be located +// either under /etc/container/storage/plugins or /usr/lib/container/storage/plugins. This is handled +// by the Registry interface, which lets you list all plugins or get a plugin by +// its name if it exists. // // The plugins need to implement an HTTP server and bind this to the UNIX socket // or the address specified in the spec files. // A handshake is send at /Plugin.Activate, and plugins are expected to return -// a Manifest with a list of subsystems which this plugin implements. As of -// this writing, the known subsystem is "GraphDriver". +// a Manifest with a list of Docker subsystems which this plugin implements. 
// // In order to use a plugins, you can use the ``Get`` with the name of the // plugin and the subsystem it implements. @@ -43,9 +41,14 @@ type plugins struct { plugins map[string]*Plugin } +type extpointHandlers struct { + sync.RWMutex + extpointHandlers map[string][]func(string, *Client) +} + var ( - storage = plugins{plugins: make(map[string]*Plugin)} - extpointHandlers = make(map[string]func(string, *Client)) + storage = plugins{plugins: make(map[string]*Plugin)} + handlers = extpointHandlers{extpointHandlers: make(map[string][]func(string, *Client))} ) // Manifest lists what a plugin implements. @@ -54,7 +57,7 @@ type Manifest struct { Implements []string } -// Plugin is the definition of a storage plugin. +// Plugin is the definition of a container/storage plugin. type Plugin struct { // Name of the plugin name string @@ -67,12 +70,12 @@ type Plugin struct { // Manifest of the plugin (see above) Manifest *Manifest `json:"-"` - // error produced by activation - activateErr error - // specifies if the activation sequence is completed (not if it is successful or not) - activated bool // wait for activation to finish activateWait *sync.Cond + // error produced by activation + activateErr error + // keeps track of callback handlers run against this plugin + handlersRun bool } // Name returns the name of the plugin. @@ -85,6 +88,11 @@ func (p *Plugin) Client() *Client { return p.client } +// IsV1 returns true for V1 plugins and false otherwise. +func (p *Plugin) IsV1() bool { + return true +} + // NewLocalPlugin creates a new local plugin. 
func NewLocalPlugin(name, addr string) *Plugin { return &Plugin{ @@ -98,19 +106,51 @@ func NewLocalPlugin(name, addr string) *Plugin { func (p *Plugin) activate() error { p.activateWait.L.Lock() - if p.activated { + + if p.activated() { + p.runHandlers() p.activateWait.L.Unlock() return p.activateErr } p.activateErr = p.activateWithLock() - p.activated = true + p.runHandlers() p.activateWait.L.Unlock() p.activateWait.Broadcast() return p.activateErr } +// runHandlers runs the registered handlers for the implemented plugin types +// This should only be run after activation, and while the activation lock is held. +func (p *Plugin) runHandlers() { + if !p.activated() { + return + } + + handlers.RLock() + if !p.handlersRun { + for _, iface := range p.Manifest.Implements { + hdlrs, handled := handlers.extpointHandlers[iface] + if !handled { + continue + } + for _, handler := range hdlrs { + handler(p.name, p.client) + } + } + p.handlersRun = true + } + handlers.RUnlock() + +} + +// activated returns if the plugin has already been activated. 
+// This should only be called with the activation lock held +func (p *Plugin) activated() bool { + return p.Manifest != nil +} + func (p *Plugin) activateWithLock() error { c, err := NewClient(p.Addr, p.TLSConfig) if err != nil { @@ -124,20 +164,12 @@ func (p *Plugin) activateWithLock() error { } p.Manifest = m - - for _, iface := range m.Implements { - handler, handled := extpointHandlers[iface] - if !handled { - continue - } - handler(p.name, p.client) - } return nil } func (p *Plugin) waitActive() error { p.activateWait.L.Lock() - for !p.activated { + for !p.activated() && p.activateErr == nil { p.activateWait.Wait() } p.activateWait.L.Unlock() @@ -145,7 +177,7 @@ func (p *Plugin) waitActive() error { } func (p *Plugin) implements(kind string) bool { - if err := p.waitActive(); err != nil { + if p.Manifest == nil { return false } for _, driver := range p.Manifest.Implements { @@ -183,6 +215,10 @@ func loadWithRetry(name string, retry bool) (*Plugin, error) { } storage.Lock() + if pl, exists := storage.plugins[name]; exists { + storage.Unlock() + return pl, pl.activate() + } storage.plugins[name] = pl storage.Unlock() @@ -214,7 +250,7 @@ func Get(name, imp string) (*Plugin, error) { if err != nil { return nil, err } - if pl.implements(imp) { + if err := pl.waitActive(); err == nil && pl.implements(imp) { logrus.Debugf("%s implements: %s", name, imp) return pl, nil } @@ -223,7 +259,26 @@ func Get(name, imp string) (*Plugin, error) { // Handle adds the specified function to the extpointHandlers. 
func Handle(iface string, fn func(string, *Client)) { - extpointHandlers[iface] = fn + handlers.Lock() + hdlrs, ok := handlers.extpointHandlers[iface] + if !ok { + hdlrs = []func(string, *Client){} + } + + hdlrs = append(hdlrs, fn) + handlers.extpointHandlers[iface] = hdlrs + + storage.Lock() + for _, p := range storage.plugins { + p.activateWait.L.Lock() + if p.activated() && p.implements(iface) { + p.handlersRun = false + } + p.activateWait.L.Unlock() + } + storage.Unlock() + + handlers.Unlock() } // GetAll returns all the plugins for the specified implementation @@ -241,7 +296,10 @@ func GetAll(imp string) ([]*Plugin, error) { chPl := make(chan *plLoad, len(pluginNames)) var wg sync.WaitGroup for _, name := range pluginNames { - if pl, ok := storage.plugins[name]; ok { + storage.Lock() + pl, ok := storage.plugins[name] + storage.Unlock() + if ok { chPl <- &plLoad{pl, nil} continue } @@ -263,7 +321,7 @@ func GetAll(imp string) ([]*Plugin, error) { logrus.Error(pl.err) continue } - if pl.pl.implements(imp) { + if err := pl.pl.waitActive(); err == nil && pl.pl.implements(imp) { out = append(out, pl.pl) } } diff --git a/pkg/plugins/plugins_unix.go b/pkg/plugins/plugins_unix.go new file mode 100644 index 0000000000..02f1da69a1 --- /dev/null +++ b/pkg/plugins/plugins_unix.go @@ -0,0 +1,9 @@ +// +build !windows + +package plugins + +// BasePath returns the path to which all paths returned by the plugin are relative to. +// For v1 plugins, this always returns the host's root directory. +func (p *Plugin) BasePath() string { + return "/" +} diff --git a/pkg/plugins/plugins_windows.go b/pkg/plugins/plugins_windows.go new file mode 100644 index 0000000000..3c8d8feb83 --- /dev/null +++ b/pkg/plugins/plugins_windows.go @@ -0,0 +1,8 @@ +package plugins + +// BasePath returns the path to which all paths returned by the plugin are relative to. +// For Windows v1 plugins, this returns an empty string, since the plugin is already aware +// of the absolute path of the mount. 
+func (p *Plugin) BasePath() string { + return "" +} diff --git a/pkg/plugins/transport/http_test.go b/pkg/plugins/transport/http_test.go new file mode 100644 index 0000000000..b724fd0df0 --- /dev/null +++ b/pkg/plugins/transport/http_test.go @@ -0,0 +1,20 @@ +package transport + +import ( + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHTTPTransport(t *testing.T) { + var r io.Reader + roundTripper := &http.Transport{} + newTransport := NewHTTPTransport(roundTripper, "http", "0.0.0.0") + request, err := newTransport.NewRequest("", r) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "POST", request.Method) +} diff --git a/pkg/promise/promise_test.go b/pkg/promise/promise_test.go new file mode 100644 index 0000000000..287213b504 --- /dev/null +++ b/pkg/promise/promise_test.go @@ -0,0 +1,25 @@ +package promise + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGo(t *testing.T) { + errCh := Go(functionWithError) + er := <-errCh + require.EqualValues(t, "Error Occurred", er.Error()) + + noErrCh := Go(functionWithNoError) + er = <-noErrCh + require.Nil(t, er) +} + +func functionWithError() (err error) { + return errors.New("Error Occurred") +} +func functionWithNoError() (err error) { + return nil +} diff --git a/pkg/random/random.go b/pkg/random/random.go deleted file mode 100644 index 70de4d1304..0000000000 --- a/pkg/random/random.go +++ /dev/null @@ -1,71 +0,0 @@ -package random - -import ( - cryptorand "crypto/rand" - "io" - "math" - "math/big" - "math/rand" - "sync" - "time" -) - -// Rand is a global *rand.Rand instance, which initialized with NewSource() source. -var Rand = rand.New(NewSource()) - -// Reader is a global, shared instance of a pseudorandom bytes generator. -// It doesn't consume entropy. 
-var Reader io.Reader = &reader{rnd: Rand} - -// copypaste from standard math/rand -type lockedSource struct { - lk sync.Mutex - src rand.Source -} - -func (r *lockedSource) Int63() (n int64) { - r.lk.Lock() - n = r.src.Int63() - r.lk.Unlock() - return -} - -func (r *lockedSource) Seed(seed int64) { - r.lk.Lock() - r.src.Seed(seed) - r.lk.Unlock() -} - -// NewSource returns math/rand.Source safe for concurrent use and initialized -// with current unix-nano timestamp -func NewSource() rand.Source { - var seed int64 - if cryptoseed, err := cryptorand.Int(cryptorand.Reader, big.NewInt(math.MaxInt64)); err != nil { - // This should not happen, but worst-case fallback to time-based seed. - seed = time.Now().UnixNano() - } else { - seed = cryptoseed.Int64() - } - return &lockedSource{ - src: rand.NewSource(seed), - } -} - -type reader struct { - rnd *rand.Rand -} - -func (r *reader) Read(b []byte) (int, error) { - i := 0 - for { - val := r.rnd.Int63() - for val > 0 { - b[i] = byte(val) - i++ - if i == len(b) { - return i, nil - } - val >>= 8 - } - } -} diff --git a/pkg/random/random_test.go b/pkg/random/random_test.go deleted file mode 100644 index cf405f78cb..0000000000 --- a/pkg/random/random_test.go +++ /dev/null @@ -1,22 +0,0 @@ -package random - -import ( - "math/rand" - "sync" - "testing" -) - -// for go test -v -race -func TestConcurrency(t *testing.T) { - rnd := rand.New(NewSource()) - var wg sync.WaitGroup - - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - rnd.Int63() - wg.Done() - }() - } - wg.Wait() -} diff --git a/pkg/reexec/README.md b/pkg/reexec/README.md new file mode 100644 index 0000000000..6658f69b69 --- /dev/null +++ b/pkg/reexec/README.md @@ -0,0 +1,5 @@ +# reexec + +The `reexec` package facilitates the busybox style reexec of the docker binary that we require because +of the forking limitations of using Go. Handlers can be registered with a name and the argv 0 of +the exec of the binary will be used to find and execute custom init paths. 
diff --git a/pkg/reexec/command_linux.go b/pkg/reexec/command_linux.go index 3c3a73a9d5..05319eacc9 100644 --- a/pkg/reexec/command_linux.go +++ b/pkg/reexec/command_linux.go @@ -5,6 +5,8 @@ package reexec import ( "os/exec" "syscall" + + "golang.org/x/sys/unix" ) // Self returns the path to the current process's binary. @@ -13,7 +15,7 @@ func Self() string { return "/proc/self/exe" } -// Command returns *exec.Cmd which have Path as current binary. Also it setting +// Command returns *exec.Cmd which has Path as current binary. Also it setting // SysProcAttr.Pdeathsig to SIGTERM. // This will use the in-memory version (/proc/self/exe) of the current binary, // it is thus safe to delete or replace the on-disk binary (os.Args[0]). @@ -22,7 +24,7 @@ func Command(args ...string) *exec.Cmd { Path: Self(), Args: args, SysProcAttr: &syscall.SysProcAttr{ - Pdeathsig: syscall.SIGTERM, + Pdeathsig: unix.SIGTERM, }, } } diff --git a/pkg/reexec/command_unix.go b/pkg/reexec/command_unix.go index b70edcb316..778a720e3b 100644 --- a/pkg/reexec/command_unix.go +++ b/pkg/reexec/command_unix.go @@ -1,4 +1,4 @@ -// +build freebsd solaris +// +build freebsd solaris darwin package reexec @@ -12,7 +12,7 @@ func Self() string { return naiveSelf() } -// Command returns *exec.Cmd which have Path as current binary. +// Command returns *exec.Cmd which has Path as current binary. // For example if current binary is "docker" at "/usr/bin/", then cmd.Path will // be set to "/usr/bin/docker". func Command(args ...string) *exec.Cmd { diff --git a/pkg/reexec/command_unsupported.go b/pkg/reexec/command_unsupported.go index 9aed004e86..76edd82427 100644 --- a/pkg/reexec/command_unsupported.go +++ b/pkg/reexec/command_unsupported.go @@ -1,4 +1,4 @@ -// +build !linux,!windows,!freebsd,!solaris +// +build !linux,!windows,!freebsd,!solaris,!darwin package reexec @@ -6,7 +6,7 @@ import ( "os/exec" ) -// Command is unsupported on operating systems apart from Linux and Windows. 
+// Command is unsupported on operating systems apart from Linux, Windows, FreeBSD, Solaris and Darwin. func Command(args ...string) *exec.Cmd { return nil } diff --git a/pkg/reexec/command_windows.go b/pkg/reexec/command_windows.go index 8d65e0ae1a..ca871c4227 100644 --- a/pkg/reexec/command_windows.go +++ b/pkg/reexec/command_windows.go @@ -12,7 +12,7 @@ func Self() string { return naiveSelf() } -// Command returns *exec.Cmd which have Path as current binary. +// Command returns *exec.Cmd which has Path as current binary. // For example if current binary is "docker.exe" at "C:\", then cmd.Path will // be set to "C:\docker.exe". func Command(args ...string) *exec.Cmd { diff --git a/pkg/reexec/reexec_test.go b/pkg/reexec/reexec_test.go new file mode 100644 index 0000000000..39e87a4a27 --- /dev/null +++ b/pkg/reexec/reexec_test.go @@ -0,0 +1,53 @@ +package reexec + +import ( + "os" + "os/exec" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func init() { + Register("reexec", func() { + panic("Return Error") + }) + Init() +} + +func TestRegister(t *testing.T) { + defer func() { + if r := recover(); r != nil { + require.Equal(t, `reexec func already registered under name "reexec"`, r) + } + }() + Register("reexec", func() {}) +} + +func TestCommand(t *testing.T) { + cmd := Command("reexec") + w, err := cmd.StdinPipe() + require.NoError(t, err, "Error on pipe creation: %v", err) + defer w.Close() + + err = cmd.Start() + require.NoError(t, err, "Error on re-exec cmd: %v", err) + err = cmd.Wait() + require.EqualError(t, err, "exit status 2") +} + +func TestNaiveSelf(t *testing.T) { + if os.Getenv("TEST_CHECK") == "1" { + os.Exit(2) + } + cmd := exec.Command(naiveSelf(), "-test.run=TestNaiveSelf") + cmd.Env = append(os.Environ(), "TEST_CHECK=1") + err := cmd.Start() + require.NoError(t, err, "Unable to start command") + err = cmd.Wait() + require.EqualError(t, err, "exit status 2") + + os.Args[0] = "mkdir" + assert.NotEqual(t, 
naiveSelf(), os.Args[0]) +} diff --git a/pkg/stringid/README.md b/pkg/stringid/README.md new file mode 100644 index 0000000000..37a5098fd9 --- /dev/null +++ b/pkg/stringid/README.md @@ -0,0 +1 @@ +This package provides helper functions for dealing with string identifiers diff --git a/pkg/stringid/stringid.go b/pkg/stringid/stringid.go index 74dfaaaa76..a0c7c42a05 100644 --- a/pkg/stringid/stringid.go +++ b/pkg/stringid/stringid.go @@ -2,19 +2,25 @@ package stringid import ( - "crypto/rand" + cryptorand "crypto/rand" "encoding/hex" + "fmt" "io" + "math" + "math/big" + "math/rand" "regexp" "strconv" "strings" - - "github.com/containers/storage/pkg/random" + "time" ) const shortLen = 12 -var validShortID = regexp.MustCompile("^[a-z0-9]{12}$") +var ( + validShortID = regexp.MustCompile("^[a-f0-9]{12}$") + validHex = regexp.MustCompile(`^[a-f0-9]{64}$`) +) // IsShortID determines if an arbitrary string *looks like* a short ID. func IsShortID(id string) bool { @@ -29,19 +35,14 @@ func TruncateID(id string) string { if i := strings.IndexRune(id, ':'); i >= 0 { id = id[i+1:] } - trimTo := shortLen - if len(id) < shortLen { - trimTo = len(id) + if len(id) > shortLen { + id = id[:shortLen] } - return id[:trimTo] + return id } -func generateID(crypto bool) string { +func generateID(r io.Reader) string { b := make([]byte, 32) - r := random.Reader - if crypto { - r = rand.Reader - } for { if _, err := io.ReadFull(r, b); err != nil { panic(err) // This shouldn't happen @@ -59,13 +60,40 @@ func generateID(crypto bool) string { // GenerateRandomID returns a unique id. func GenerateRandomID() string { - return generateID(true) - + return generateID(cryptorand.Reader) } // GenerateNonCryptoID generates unique id without using cryptographically // secure sources of random. // It helps you to save entropy. func GenerateNonCryptoID() string { - return generateID(false) + return generateID(readerFunc(rand.Read)) +} + +// ValidateID checks whether an ID string is a valid image ID. 
+func ValidateID(id string) error { + if ok := validHex.MatchString(id); !ok { + return fmt.Errorf("image ID %q is invalid", id) + } + return nil +} + +func init() { + // safely set the seed globally so we generate random ids. Tries to use a + // crypto seed before falling back to time. + var seed int64 + if cryptoseed, err := cryptorand.Int(cryptorand.Reader, big.NewInt(math.MaxInt64)); err != nil { + // This should not happen, but worst-case fallback to time-based seed. + seed = time.Now().UnixNano() + } else { + seed = cryptoseed.Int64() + } + + rand.Seed(seed) +} + +type readerFunc func(p []byte) (int, error) + +func (fn readerFunc) Read(p []byte) (int, error) { + return fn(p) } diff --git a/pkg/stringid/stringid_test.go b/pkg/stringid/stringid_test.go index bcb1365495..8ff6b4383d 100644 --- a/pkg/stringid/stringid_test.go +++ b/pkg/stringid/stringid_test.go @@ -13,10 +13,26 @@ func TestGenerateRandomID(t *testing.T) { } } +func TestGenerateNonCryptoID(t *testing.T) { + id := GenerateNonCryptoID() + + if len(id) != 64 { + t.Fatalf("Id returned is incorrect: %s", id) + } +} + func TestShortenId(t *testing.T) { - id := GenerateRandomID() + id := "90435eec5c4e124e741ef731e118be2fc799a68aba0466ec17717f24ce2ae6a2" + truncID := TruncateID(id) + if truncID != "90435eec5c4e" { + t.Fatalf("Id returned is incorrect: truncate on %s returned %s", id, truncID) + } +} + +func TestShortenSha256Id(t *testing.T) { + id := "sha256:4e38e38c8ce0b8d9041a9c4fefe786631d1416225e13b0bfe8cfa2321aec4bba" truncID := TruncateID(id) - if len(truncID) != 12 { + if truncID != "4e38e38c8ce0" { t.Fatalf("Id returned is incorrect: truncate on %s returned %s", id, truncID) } } diff --git a/pkg/stringutils/README.md b/pkg/stringutils/README.md new file mode 100644 index 0000000000..b3e454573c --- /dev/null +++ b/pkg/stringutils/README.md @@ -0,0 +1 @@ +This package provides helper functions for dealing with strings diff --git a/pkg/stringutils/stringutils.go b/pkg/stringutils/stringutils.go index 
078ceaf2f7..8c4c39875e 100644 --- a/pkg/stringutils/stringutils.go +++ b/pkg/stringutils/stringutils.go @@ -5,8 +5,6 @@ import ( "bytes" "math/rand" "strings" - - "github.com/containers/storage/pkg/random" ) // GenerateRandomAlphaOnlyString generates an alphabetical random string with length n. @@ -15,7 +13,7 @@ func GenerateRandomAlphaOnlyString(n int) string { letters := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") b := make([]byte, n) for i := range b { - b[i] = letters[random.Rand.Intn(len(letters))] + b[i] = letters[rand.Intn(len(letters))] } return string(b) } @@ -32,12 +30,26 @@ func GenerateRandomASCIIString(n int) string { return string(res) } +// Ellipsis truncates a string to fit within maxlen, and appends ellipsis (...). +// For maxlen of 3 and lower, no ellipsis is appended. +func Ellipsis(s string, maxlen int) string { + r := []rune(s) + if len(r) <= maxlen { + return s + } + if maxlen <= 3 { + return string(r[:maxlen]) + } + return string(r[:maxlen-3]) + "..." +} + // Truncate truncates a string to maxlen. func Truncate(s string, maxlen int) string { - if len(s) <= maxlen { + r := []rune(s) + if len(r) <= maxlen { return s } - return s[:maxlen] + return string(r[:maxlen]) } // InSlice tests whether a string is contained in a slice of strings or not. 
@@ -63,7 +75,7 @@ func quote(word string, buf *bytes.Buffer) { for i := 0; i < len(word); i++ { b := word[i] if b == '\'' { - // Replace literal ' with a close ', a \', and a open ' + // Replace literal ' with a close ', a \', and an open ' buf.WriteString("'\\''") } else { buf.WriteByte(b) diff --git a/pkg/stringutils/stringutils_test.go b/pkg/stringutils/stringutils_test.go index fec59450bc..8af2bdcc0b 100644 --- a/pkg/stringutils/stringutils_test.go +++ b/pkg/stringutils/stringutils_test.go @@ -57,24 +57,40 @@ func TestGenerateRandomAsciiStringIsAscii(t *testing.T) { } } +func TestEllipsis(t *testing.T) { + str := "t🐳ststring" + newstr := Ellipsis(str, 3) + if newstr != "t🐳s" { + t.Fatalf("Expected t🐳s, got %s", newstr) + } + newstr = Ellipsis(str, 8) + if newstr != "t🐳sts..." { + t.Fatalf("Expected tests..., got %s", newstr) + } + newstr = Ellipsis(str, 20) + if newstr != "t🐳ststring" { + t.Fatalf("Expected t🐳ststring, got %s", newstr) + } +} + func TestTruncate(t *testing.T) { - str := "teststring" + str := "t🐳ststring" newstr := Truncate(str, 4) - if newstr != "test" { - t.Fatalf("Expected test, got %s", newstr) + if newstr != "t🐳st" { + t.Fatalf("Expected t🐳st, got %s", newstr) } newstr = Truncate(str, 20) - if newstr != "teststring" { - t.Fatalf("Expected teststring, got %s", newstr) + if newstr != "t🐳ststring" { + t.Fatalf("Expected t🐳ststring, got %s", newstr) } } func TestInSlice(t *testing.T) { - slice := []string{"test", "in", "slice"} + slice := []string{"t🐳st", "in", "slice"} - test := InSlice(slice, "test") + test := InSlice(slice, "t🐳st") if !test { - t.Fatalf("Expected string test to be in slice") + t.Fatalf("Expected string t🐳st to be in slice") } test = InSlice(slice, "SLICE") if !test { diff --git a/pkg/system/chtimes.go b/pkg/system/chtimes.go index 7637f12e1a..056d19954d 100644 --- a/pkg/system/chtimes.go +++ b/pkg/system/chtimes.go @@ -2,26 +2,9 @@ package system import ( "os" - "syscall" "time" - "unsafe" ) -var ( - maxTime time.Time -) - 
-func init() { - if unsafe.Sizeof(syscall.Timespec{}.Nsec) == 8 { - // This is a 64 bit timespec - // os.Chtimes limits time to the following - maxTime = time.Unix(0, 1<<63-1) - } else { - // This is a 32 bit timespec - maxTime = time.Unix(1<<31-1, 0) - } -} - // Chtimes changes the access time and modified time of a file at the given path func Chtimes(name string, atime time.Time, mtime time.Time) error { unixMinTime := time.Unix(0, 0) diff --git a/pkg/system/chtimes_test.go b/pkg/system/chtimes_test.go index 5c87df32a2..fcd644f5c6 100644 --- a/pkg/system/chtimes_test.go +++ b/pkg/system/chtimes_test.go @@ -10,7 +10,7 @@ import ( // prepareTempFile creates a temporary file in a temporary directory. func prepareTempFile(t *testing.T) (string, string) { - dir, err := ioutil.TempDir("", "docker-system-test") + dir, err := ioutil.TempDir("", "storage-system-test") if err != nil { t.Fatal(err) } diff --git a/pkg/system/chtimes_unix_test.go b/pkg/system/chtimes_unix_test.go index 0aafe1d845..6ec9a7173c 100644 --- a/pkg/system/chtimes_unix_test.go +++ b/pkg/system/chtimes_unix_test.go @@ -9,7 +9,7 @@ import ( "time" ) -// TestChtimes tests Chtimes access time on a tempfile on Linux +// TestChtimesLinux tests Chtimes access time on a tempfile on Linux func TestChtimesLinux(t *testing.T) { file, dir := prepareTempFile(t) defer os.RemoveAll(dir) diff --git a/pkg/system/chtimes_windows.go b/pkg/system/chtimes_windows.go index 2945868465..45428c141c 100644 --- a/pkg/system/chtimes_windows.go +++ b/pkg/system/chtimes_windows.go @@ -3,25 +3,26 @@ package system import ( - "syscall" "time" + + "golang.org/x/sys/windows" ) //setCTime will set the create time on a file. On Windows, this requires //calling SetFileTime and explicitly including the create time. 
func setCTime(path string, ctime time.Time) error { - ctimespec := syscall.NsecToTimespec(ctime.UnixNano()) - pathp, e := syscall.UTF16PtrFromString(path) + ctimespec := windows.NsecToTimespec(ctime.UnixNano()) + pathp, e := windows.UTF16PtrFromString(path) if e != nil { return e } - h, e := syscall.CreateFile(pathp, - syscall.FILE_WRITE_ATTRIBUTES, syscall.FILE_SHARE_WRITE, nil, - syscall.OPEN_EXISTING, syscall.FILE_FLAG_BACKUP_SEMANTICS, 0) + h, e := windows.CreateFile(pathp, + windows.FILE_WRITE_ATTRIBUTES, windows.FILE_SHARE_WRITE, nil, + windows.OPEN_EXISTING, windows.FILE_FLAG_BACKUP_SEMANTICS, 0) if e != nil { return e } - defer syscall.Close(h) - c := syscall.NsecToFiletime(syscall.TimespecToNsec(ctimespec)) - return syscall.SetFileTime(h, &c, nil, nil) + defer windows.Close(h) + c := windows.NsecToFiletime(windows.TimespecToNsec(ctimespec)) + return windows.SetFileTime(h, &c, nil, nil) } diff --git a/pkg/system/chtimes_windows_test.go b/pkg/system/chtimes_windows_test.go index be57558e1b..72d8a10619 100644 --- a/pkg/system/chtimes_windows_test.go +++ b/pkg/system/chtimes_windows_test.go @@ -9,7 +9,7 @@ import ( "time" ) -// TestChtimes tests Chtimes access time on a tempfile on Windows +// TestChtimesWindows tests Chtimes access time on a tempfile on Windows func TestChtimesWindows(t *testing.T) { file, dir := prepareTempFile(t) defer os.RemoveAll(dir) diff --git a/pkg/system/events_windows.go b/pkg/system/events_windows.go deleted file mode 100644 index 04e2de7871..0000000000 --- a/pkg/system/events_windows.go +++ /dev/null @@ -1,83 +0,0 @@ -package system - -// This file implements syscalls for Win32 events which are not implemented -// in golang. 
- -import ( - "syscall" - "unsafe" -) - -var ( - procCreateEvent = modkernel32.NewProc("CreateEventW") - procOpenEvent = modkernel32.NewProc("OpenEventW") - procSetEvent = modkernel32.NewProc("SetEvent") - procResetEvent = modkernel32.NewProc("ResetEvent") - procPulseEvent = modkernel32.NewProc("PulseEvent") -) - -// CreateEvent implements win32 CreateEventW func in golang. It will create an event object. -func CreateEvent(eventAttributes *syscall.SecurityAttributes, manualReset bool, initialState bool, name string) (handle syscall.Handle, err error) { - namep, _ := syscall.UTF16PtrFromString(name) - var _p1 uint32 - if manualReset { - _p1 = 1 - } - var _p2 uint32 - if initialState { - _p2 = 1 - } - r0, _, e1 := procCreateEvent.Call(uintptr(unsafe.Pointer(eventAttributes)), uintptr(_p1), uintptr(_p2), uintptr(unsafe.Pointer(namep))) - use(unsafe.Pointer(namep)) - handle = syscall.Handle(r0) - if handle == syscall.InvalidHandle { - err = e1 - } - return -} - -// OpenEvent implements win32 OpenEventW func in golang. It opens an event object. -func OpenEvent(desiredAccess uint32, inheritHandle bool, name string) (handle syscall.Handle, err error) { - namep, _ := syscall.UTF16PtrFromString(name) - var _p1 uint32 - if inheritHandle { - _p1 = 1 - } - r0, _, e1 := procOpenEvent.Call(uintptr(desiredAccess), uintptr(_p1), uintptr(unsafe.Pointer(namep))) - use(unsafe.Pointer(namep)) - handle = syscall.Handle(r0) - if handle == syscall.InvalidHandle { - err = e1 - } - return -} - -// SetEvent implements win32 SetEvent func in golang. -func SetEvent(handle syscall.Handle) (err error) { - return setResetPulse(handle, procSetEvent) -} - -// ResetEvent implements win32 ResetEvent func in golang. -func ResetEvent(handle syscall.Handle) (err error) { - return setResetPulse(handle, procResetEvent) -} - -// PulseEvent implements win32 PulseEvent func in golang. 
-func PulseEvent(handle syscall.Handle) (err error) { - return setResetPulse(handle, procPulseEvent) -} - -func setResetPulse(handle syscall.Handle, proc *syscall.LazyProc) (err error) { - r0, _, _ := proc.Call(uintptr(handle)) - if r0 != 0 { - err = syscall.Errno(r0) - } - return -} - -var temp unsafe.Pointer - -// use ensures a variable is kept alive without the GC freeing while still needed -func use(p unsafe.Pointer) { - temp = p -} diff --git a/pkg/system/exitcode.go b/pkg/system/exitcode.go new file mode 100644 index 0000000000..60f0514b1d --- /dev/null +++ b/pkg/system/exitcode.go @@ -0,0 +1,33 @@ +package system + +import ( + "fmt" + "os/exec" + "syscall" +) + +// GetExitCode returns the ExitStatus of the specified error if its type is +// exec.ExitError, returns 0 and an error otherwise. +func GetExitCode(err error) (int, error) { + exitCode := 0 + if exiterr, ok := err.(*exec.ExitError); ok { + if procExit, ok := exiterr.Sys().(syscall.WaitStatus); ok { + return procExit.ExitStatus(), nil + } + } + return exitCode, fmt.Errorf("failed to get exit code") +} + +// ProcessExitCode process the specified error and returns the exit status code +// if the error was of type exec.ExitError, returns nothing otherwise. +func ProcessExitCode(err error) (exitCode int) { + if err != nil { + var exiterr error + if exitCode, exiterr = GetExitCode(err); exiterr != nil { + // TODO: Fix this so we check the error's text. + // we've failed to retrieve exit code, so we set it to 127 + exitCode = 127 + } + } + return +} diff --git a/pkg/system/filesys.go b/pkg/system/filesys.go index c14feb8496..102565f760 100644 --- a/pkg/system/filesys.go +++ b/pkg/system/filesys.go @@ -3,13 +3,19 @@ package system import ( + "io/ioutil" "os" "path/filepath" ) +// MkdirAllWithACL is a wrapper for MkdirAll on unix systems. 
+func MkdirAllWithACL(path string, perm os.FileMode, sddl string) error { + return MkdirAll(path, perm, sddl) +} + // MkdirAll creates a directory named path along with any necessary parents, // with permission specified by attribute perm for all dir created. -func MkdirAll(path string, perm os.FileMode) error { +func MkdirAll(path string, perm os.FileMode, sddl string) error { return os.MkdirAll(path, perm) } @@ -17,3 +23,45 @@ func MkdirAll(path string, perm os.FileMode) error { func IsAbs(path string) bool { return filepath.IsAbs(path) } + +// The functions below here are wrappers for the equivalents in the os and ioutils packages. +// They are passthrough on Unix platforms, and only relevant on Windows. + +// CreateSequential creates the named file with mode 0666 (before umask), truncating +// it if it already exists. If successful, methods on the returned +// File can be used for I/O; the associated file descriptor has mode +// O_RDWR. +// If there is an error, it will be of type *PathError. +func CreateSequential(name string) (*os.File, error) { + return os.Create(name) +} + +// OpenSequential opens the named file for reading. If successful, methods on +// the returned file can be used for reading; the associated file +// descriptor has mode O_RDONLY. +// If there is an error, it will be of type *PathError. +func OpenSequential(name string) (*os.File, error) { + return os.Open(name) +} + +// OpenFileSequential is the generalized open call; most users will use Open +// or Create instead. It opens the named file with specified flag +// (O_RDONLY etc.) and perm, (0666 etc.) if applicable. If successful, +// methods on the returned File can be used for I/O. +// If there is an error, it will be of type *PathError. 
+func OpenFileSequential(name string, flag int, perm os.FileMode) (*os.File, error) { + return os.OpenFile(name, flag, perm) +} + +// TempFileSequential creates a new temporary file in the directory dir +// with a name beginning with prefix, opens the file for reading +// and writing, and returns the resulting *os.File. +// If dir is the empty string, TempFile uses the default directory +// for temporary files (see os.TempDir). +// Multiple programs calling TempFile simultaneously +// will not choose the same file. The caller can use f.Name() +// to find the pathname of the file. It is the caller's responsibility +// to remove the file when no longer needed. +func TempFileSequential(dir, prefix string) (f *os.File, err error) { + return ioutil.TempFile(dir, prefix) +} diff --git a/pkg/system/filesys_windows.go b/pkg/system/filesys_windows.go index 16823d5517..a61b53d0ba 100644 --- a/pkg/system/filesys_windows.go +++ b/pkg/system/filesys_windows.go @@ -6,17 +6,44 @@ import ( "os" "path/filepath" "regexp" + "strconv" "strings" + "sync" "syscall" + "time" + "unsafe" + + winio "github.com/Microsoft/go-winio" + "golang.org/x/sys/windows" +) + +const ( + // SddlAdministratorsLocalSystem is local administrators plus NT AUTHORITY\System + SddlAdministratorsLocalSystem = "D:P(A;OICI;GA;;;BA)(A;OICI;GA;;;SY)" + // SddlNtvmAdministratorsLocalSystem is NT VIRTUAL MACHINE\Virtual Machines plus local administrators plus NT AUTHORITY\System + SddlNtvmAdministratorsLocalSystem = "D:P(A;OICI;GA;;;S-1-5-83-0)(A;OICI;GA;;;BA)(A;OICI;GA;;;SY)" ) +// MkdirAllWithACL is a wrapper for MkdirAll that creates a directory +// with an appropriate SDDL defined ACL. +func MkdirAllWithACL(path string, perm os.FileMode, sddl string) error { + return mkdirall(path, true, sddl) +} + // MkdirAll implementation that is volume path aware for Windows. 
-func MkdirAll(path string, perm os.FileMode) error { +func MkdirAll(path string, _ os.FileMode, sddl string) error { + return mkdirall(path, false, sddl) +} + +// mkdirall is a custom version of os.MkdirAll modified for use on Windows +// so that it is both volume path aware, and can create a directory with +// a DACL. +func mkdirall(path string, applyACL bool, sddl string) error { if re := regexp.MustCompile(`^\\\\\?\\Volume{[a-z0-9-]+}$`); re.MatchString(path) { return nil } - // The rest of this method is copied from os.MkdirAll and should be kept + // The rest of this method is largely copied from os.MkdirAll and should be kept // as-is to ensure compatibility. // Fast path: if we can tell whether path is a directory or file, stop with success or error. @@ -45,14 +72,19 @@ func MkdirAll(path string, perm os.FileMode) error { if j > 1 { // Create parent - err = MkdirAll(path[0:j-1], perm) + err = mkdirall(path[0:j-1], false, sddl) if err != nil { return err } } - // Parent now exists; invoke Mkdir and use its result. - err = os.Mkdir(path, perm) + // Parent now exists; invoke os.Mkdir or mkdirWithACL and use its result. + if applyACL { + err = mkdirWithACL(path, sddl) + } else { + err = os.Mkdir(path, 0) + } + if err != nil { // Handle arguments like "foo/." by // double-checking that directory doesn't exist. @@ -65,6 +97,35 @@ func MkdirAll(path string, perm os.FileMode) error { return nil } +// mkdirWithACL creates a new directory. If there is an error, it will be of +// type *PathError. +// +// This is a modified and combined version of os.Mkdir and windows.Mkdir +// in golang to cater for creating a directory with an ACL permitting full +// access, with inheritance, to any subfolder/file for Built-in Administrators +// and Local System. 
+func mkdirWithACL(name string, sddl string) error { + sa := windows.SecurityAttributes{Length: 0} + sd, err := winio.SddlToSecurityDescriptor(sddl) + if err != nil { + return &os.PathError{Op: "mkdir", Path: name, Err: err} + } + sa.Length = uint32(unsafe.Sizeof(sa)) + sa.InheritHandle = 1 + sa.SecurityDescriptor = uintptr(unsafe.Pointer(&sd[0])) + + namep, err := windows.UTF16PtrFromString(name) + if err != nil { + return &os.PathError{Op: "mkdir", Path: name, Err: err} + } + + e := windows.CreateDirectory(namep, &sa) + if e != nil { + return &os.PathError{Op: "mkdir", Path: name, Err: e} + } + return nil +} + // IsAbs is a platform-specific wrapper for filepath.IsAbs. On Windows, // golang filepath.IsAbs does not consider a path \windows\system32 as absolute // as it doesn't start with a drive-letter/colon combination. However, in @@ -80,3 +141,158 @@ func IsAbs(path string) bool { } return true } + +// The origin of the functions below here are the golang OS and windows packages, +// slightly modified to only cope with files, not directories due to the +// specific use case. +// +// The alteration is to allow a file on Windows to be opened with +// FILE_FLAG_SEQUENTIAL_SCAN (particular for docker load), to avoid eating +// the standby list, particularly when accessing large files such as layer.tar. + +// CreateSequential creates the named file with mode 0666 (before umask), truncating +// it if it already exists. If successful, methods on the returned +// File can be used for I/O; the associated file descriptor has mode +// O_RDWR. +// If there is an error, it will be of type *PathError. +func CreateSequential(name string) (*os.File, error) { + return OpenFileSequential(name, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0) +} + +// OpenSequential opens the named file for reading. If successful, methods on +// the returned file can be used for reading; the associated file +// descriptor has mode O_RDONLY. +// If there is an error, it will be of type *PathError. 
+func OpenSequential(name string) (*os.File, error) { + return OpenFileSequential(name, os.O_RDONLY, 0) +} + +// OpenFileSequential is the generalized open call; most users will use Open +// or Create instead. +// If there is an error, it will be of type *PathError. +func OpenFileSequential(name string, flag int, _ os.FileMode) (*os.File, error) { + if name == "" { + return nil, &os.PathError{Op: "open", Path: name, Err: syscall.ENOENT} + } + r, errf := windowsOpenFileSequential(name, flag, 0) + if errf == nil { + return r, nil + } + return nil, &os.PathError{Op: "open", Path: name, Err: errf} +} + +func windowsOpenFileSequential(name string, flag int, _ os.FileMode) (file *os.File, err error) { + r, e := windowsOpenSequential(name, flag|windows.O_CLOEXEC, 0) + if e != nil { + return nil, e + } + return os.NewFile(uintptr(r), name), nil +} + +func makeInheritSa() *windows.SecurityAttributes { + var sa windows.SecurityAttributes + sa.Length = uint32(unsafe.Sizeof(sa)) + sa.InheritHandle = 1 + return &sa +} + +func windowsOpenSequential(path string, mode int, _ uint32) (fd windows.Handle, err error) { + if len(path) == 0 { + return windows.InvalidHandle, windows.ERROR_FILE_NOT_FOUND + } + pathp, err := windows.UTF16PtrFromString(path) + if err != nil { + return windows.InvalidHandle, err + } + var access uint32 + switch mode & (windows.O_RDONLY | windows.O_WRONLY | windows.O_RDWR) { + case windows.O_RDONLY: + access = windows.GENERIC_READ + case windows.O_WRONLY: + access = windows.GENERIC_WRITE + case windows.O_RDWR: + access = windows.GENERIC_READ | windows.GENERIC_WRITE + } + if mode&windows.O_CREAT != 0 { + access |= windows.GENERIC_WRITE + } + if mode&windows.O_APPEND != 0 { + access &^= windows.GENERIC_WRITE + access |= windows.FILE_APPEND_DATA + } + sharemode := uint32(windows.FILE_SHARE_READ | windows.FILE_SHARE_WRITE) + var sa *windows.SecurityAttributes + if mode&windows.O_CLOEXEC == 0 { + sa = makeInheritSa() + } + var createmode uint32 + switch { + case 
mode&(windows.O_CREAT|windows.O_EXCL) == (windows.O_CREAT | windows.O_EXCL): + createmode = windows.CREATE_NEW + case mode&(windows.O_CREAT|windows.O_TRUNC) == (windows.O_CREAT | windows.O_TRUNC): + createmode = windows.CREATE_ALWAYS + case mode&windows.O_CREAT == windows.O_CREAT: + createmode = windows.OPEN_ALWAYS + case mode&windows.O_TRUNC == windows.O_TRUNC: + createmode = windows.TRUNCATE_EXISTING + default: + createmode = windows.OPEN_EXISTING + } + // Use FILE_FLAG_SEQUENTIAL_SCAN rather than FILE_ATTRIBUTE_NORMAL as implemented in golang. + //https://msdn.microsoft.com/en-us/library/windows/desktop/aa363858(v=vs.85).aspx + const fileFlagSequentialScan = 0x08000000 // FILE_FLAG_SEQUENTIAL_SCAN + h, e := windows.CreateFile(pathp, access, sharemode, sa, createmode, fileFlagSequentialScan, 0) + return h, e +} + +// Helpers for TempFileSequential +var rand uint32 +var randmu sync.Mutex + +func reseed() uint32 { + return uint32(time.Now().UnixNano() + int64(os.Getpid())) +} +func nextSuffix() string { + randmu.Lock() + r := rand + if r == 0 { + r = reseed() + } + r = r*1664525 + 1013904223 // constants from Numerical Recipes + rand = r + randmu.Unlock() + return strconv.Itoa(int(1e9 + r%1e9))[1:] +} + +// TempFileSequential is a copy of ioutil.TempFile, modified to use sequential +// file access. Below is the original comment from golang: +// TempFile creates a new temporary file in the directory dir +// with a name beginning with prefix, opens the file for reading +// and writing, and returns the resulting *os.File. +// If dir is the empty string, TempFile uses the default directory +// for temporary files (see os.TempDir). +// Multiple programs calling TempFile simultaneously +// will not choose the same file. The caller can use f.Name() +// to find the pathname of the file. It is the caller's responsibility +// to remove the file when no longer needed. 
+func TempFileSequential(dir, prefix string) (f *os.File, err error) { + if dir == "" { + dir = os.TempDir() + } + + nconflict := 0 + for i := 0; i < 10000; i++ { + name := filepath.Join(dir, prefix+nextSuffix()) + f, err = OpenFileSequential(name, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0600) + if os.IsExist(err) { + if nconflict++; nconflict > 10 { + randmu.Lock() + rand = reseed() + randmu.Unlock() + } + continue + } + break + } + return +} diff --git a/pkg/system/init.go b/pkg/system/init.go new file mode 100644 index 0000000000..17935088de --- /dev/null +++ b/pkg/system/init.go @@ -0,0 +1,22 @@ +package system + +import ( + "syscall" + "time" + "unsafe" +) + +// Used by chtimes +var maxTime time.Time + +func init() { + // chtimes initialization + if unsafe.Sizeof(syscall.Timespec{}.Nsec) == 8 { + // This is a 64 bit timespec + // os.Chtimes limits time to the following + maxTime = time.Unix(0, 1<<63-1) + } else { + // This is a 32 bit timespec + maxTime = time.Unix(1<<31-1, 0) + } +} diff --git a/pkg/system/init_windows.go b/pkg/system/init_windows.go new file mode 100644 index 0000000000..019c66441c --- /dev/null +++ b/pkg/system/init_windows.go @@ -0,0 +1,17 @@ +package system + +import "os" + +// LCOWSupported determines if Linux Containers on Windows are supported. +// Note: This feature is in development (06/17) and enabled through an +// environment variable. At a future time, it will be enabled based +// on build number. @jhowardmsft +var lcowSupported = false + +func init() { + // LCOW initialization + if os.Getenv("LCOW_SUPPORTED") != "" { + lcowSupported = true + } + +} diff --git a/pkg/system/lcow_unix.go b/pkg/system/lcow_unix.go new file mode 100644 index 0000000000..cff33bb408 --- /dev/null +++ b/pkg/system/lcow_unix.go @@ -0,0 +1,8 @@ +// +build !windows + +package system + +// LCOWSupported returns true if Linux containers on Windows are supported. 
+func LCOWSupported() bool { + return false +} diff --git a/pkg/system/lcow_windows.go b/pkg/system/lcow_windows.go new file mode 100644 index 0000000000..e54d01e696 --- /dev/null +++ b/pkg/system/lcow_windows.go @@ -0,0 +1,6 @@ +package system + +// LCOWSupported returns true if Linux containers on Windows are supported. +func LCOWSupported() bool { + return lcowSupported +} diff --git a/pkg/system/lstat.go b/pkg/system/lstat_unix.go similarity index 100% rename from pkg/system/lstat.go rename to pkg/system/lstat_unix.go diff --git a/pkg/system/lstat_windows.go b/pkg/system/lstat_windows.go index 49e87eb40b..e51df0dafe 100644 --- a/pkg/system/lstat_windows.go +++ b/pkg/system/lstat_windows.go @@ -1,25 +1,14 @@ -// +build windows - package system -import ( - "os" -) +import "os" // Lstat calls os.Lstat to get a fileinfo interface back. // This is then copied into our own locally defined structure. -// Note the Linux version uses fromStatT to do the copy back, -// but that not strictly necessary when already in an OS specific module. 
func Lstat(path string) (*StatT, error) { fi, err := os.Lstat(path) if err != nil { return nil, err } - return &StatT{ - name: fi.Name(), - size: fi.Size(), - mode: fi.Mode(), - modTime: fi.ModTime(), - isDir: fi.IsDir()}, nil + return fromStatT(&fi) } diff --git a/pkg/system/meminfo_solaris.go b/pkg/system/meminfo_solaris.go index 313c601b12..925776e789 100644 --- a/pkg/system/meminfo_solaris.go +++ b/pkg/system/meminfo_solaris.go @@ -7,6 +7,7 @@ import ( "unsafe" ) +// #cgo CFLAGS: -std=c99 // #cgo LDFLAGS: -lkstat // #include // #include @@ -89,7 +90,7 @@ func ReadMemInfo() (*MemInfo, error) { if ppKernel < 0 || MemTotal < 0 || MemFree < 0 || SwapTotal < 0 || SwapFree < 0 { - return nil, fmt.Errorf("Error getting system memory info %v\n", err) + return nil, fmt.Errorf("error getting system memory info %v\n", err) } meminfo := &MemInfo{} diff --git a/pkg/system/meminfo_windows.go b/pkg/system/meminfo_windows.go index d46642598c..883944a4c5 100644 --- a/pkg/system/meminfo_windows.go +++ b/pkg/system/meminfo_windows.go @@ -1,12 +1,13 @@ package system import ( - "syscall" "unsafe" + + "golang.org/x/sys/windows" ) var ( - modkernel32 = syscall.NewLazyDLL("kernel32.dll") + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") procGlobalMemoryStatusEx = modkernel32.NewProc("GlobalMemoryStatusEx") ) diff --git a/pkg/system/mknod.go b/pkg/system/mknod.go index 73958182b4..af79a65383 100644 --- a/pkg/system/mknod.go +++ b/pkg/system/mknod.go @@ -3,13 +3,13 @@ package system import ( - "syscall" + "golang.org/x/sys/unix" ) // Mknod creates a filesystem node (file, device special file or named pipe) named path // with attributes specified by mode and dev. 
func Mknod(path string, mode uint32, dev int) error { - return syscall.Mknod(path, mode, dev) + return unix.Mknod(path, mode, dev) } // Mkdev is used to build the value of linux devices (in /dev/) which specifies major diff --git a/pkg/system/path.go b/pkg/system/path.go new file mode 100644 index 0000000000..f634a6be67 --- /dev/null +++ b/pkg/system/path.go @@ -0,0 +1,21 @@ +package system + +import "runtime" + +const defaultUnixPathEnv = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin" + +// DefaultPathEnv is unix style list of directories to search for +// executables. Each directory is separated from the next by a colon +// ':' character . +func DefaultPathEnv(platform string) string { + if runtime.GOOS == "windows" { + if platform != runtime.GOOS && LCOWSupported() { + return defaultUnixPathEnv + } + // Deliberately empty on Windows containers on Windows as the default path will be set by + // the container. Docker has no context of what the default path should be. + return "" + } + return defaultUnixPathEnv + +} diff --git a/pkg/system/path_unix.go b/pkg/system/path_unix.go index c607c4db09..f3762e69d3 100644 --- a/pkg/system/path_unix.go +++ b/pkg/system/path_unix.go @@ -2,11 +2,6 @@ package system -// DefaultPathEnv is unix style list of directories to search for -// executables. Each directory is separated from the next by a colon -// ':' character . -const DefaultPathEnv = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin" - // CheckSystemDriveAndRemoveDriveLetter verifies that a path, if it includes a drive letter, // is the system drive. This is a no-op on Linux. 
func CheckSystemDriveAndRemoveDriveLetter(path string) (string, error) { diff --git a/pkg/system/path_windows.go b/pkg/system/path_windows.go index cbfe2c1576..aab891522d 100644 --- a/pkg/system/path_windows.go +++ b/pkg/system/path_windows.go @@ -8,15 +8,11 @@ import ( "strings" ) -// DefaultPathEnv is deliberately empty on Windows as the default path will be set by -// the container. Docker has no context of what the default path should be. -const DefaultPathEnv = "" - // CheckSystemDriveAndRemoveDriveLetter verifies and manipulates a Windows path. // This is used, for example, when validating a user provided path in docker cp. // If a drive letter is supplied, it must be the system drive. The drive letter // is always removed. Also, it translates it to OS semantics (IOW / to \). We -// need the path in this syntax so that it can ultimately be contatenated with +// need the path in this syntax so that it can ultimately be concatenated with // a Windows long-path which doesn't support drive-letters. Examples: // C: --> Fail // C:\ --> \ diff --git a/pkg/system/process_unix.go b/pkg/system/process_unix.go new file mode 100644 index 0000000000..26c8b42c17 --- /dev/null +++ b/pkg/system/process_unix.go @@ -0,0 +1,24 @@ +// +build linux freebsd solaris darwin + +package system + +import ( + "syscall" + + "golang.org/x/sys/unix" +) + +// IsProcessAlive returns true if process with a given pid is running. +func IsProcessAlive(pid int) bool { + err := unix.Kill(pid, syscall.Signal(0)) + if err == nil || err == unix.EPERM { + return true + } + + return false +} + +// KillProcess force-stops a process. 
+func KillProcess(pid int) { + unix.Kill(pid, unix.SIGKILL) +} diff --git a/pkg/system/rm.go b/pkg/system/rm.go new file mode 100644 index 0000000000..fc03c3e6b6 --- /dev/null +++ b/pkg/system/rm.go @@ -0,0 +1,80 @@ +package system + +import ( + "os" + "syscall" + "time" + + "github.com/containers/storage/pkg/mount" + "github.com/pkg/errors" +) + +// EnsureRemoveAll wraps `os.RemoveAll` to check for specific errors that can +// often be remedied. +// Only use `EnsureRemoveAll` if you really want to make every effort to remove +// a directory. +// +// Because of the way `os.Remove` (and by extension `os.RemoveAll`) works, there +// can be a race between reading directory entries and then actually attempting +// to remove everything in the directory. +// These types of errors do not need to be returned since it's ok for the dir to +// be gone we can just retry the remove operation. +// +// This should not return a `os.ErrNotExist` kind of error under any circumstances +func EnsureRemoveAll(dir string) error { + notExistErr := make(map[string]bool) + + // track retries + exitOnErr := make(map[string]int) + maxRetry := 5 + + // Attempt to unmount anything beneath this dir first + mount.RecursiveUnmount(dir) + + for { + err := os.RemoveAll(dir) + if err == nil { + return err + } + + pe, ok := err.(*os.PathError) + if !ok { + return err + } + + if os.IsNotExist(err) { + if notExistErr[pe.Path] { + return err + } + notExistErr[pe.Path] = true + + // There is a race where some subdir can be removed but after the parent + // dir entries have been read. + // So the path could be from `os.Remove(subdir)` + // If the reported non-existent path is not the passed in `dir` we + // should just retry, but otherwise return with no error. 
+ if pe.Path == dir { + return nil + } + continue + } + + if pe.Err != syscall.EBUSY { + return err + } + + if mounted, _ := mount.Mounted(pe.Path); mounted { + if e := mount.Unmount(pe.Path); e != nil { + if mounted, _ := mount.Mounted(pe.Path); mounted { + return errors.Wrapf(e, "error while removing %s", dir) + } + } + } + + if exitOnErr[pe.Path] == maxRetry { + return err + } + exitOnErr[pe.Path]++ + time.Sleep(100 * time.Millisecond) + } +} diff --git a/pkg/system/rm_test.go b/pkg/system/rm_test.go new file mode 100644 index 0000000000..98371d4a64 --- /dev/null +++ b/pkg/system/rm_test.go @@ -0,0 +1,84 @@ +package system + +import ( + "io/ioutil" + "os" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/containers/storage/pkg/mount" +) + +func TestEnsureRemoveAllNotExist(t *testing.T) { + // should never return an error for a non-existent path + if err := EnsureRemoveAll("/non/existent/path"); err != nil { + t.Fatal(err) + } +} + +func TestEnsureRemoveAllWithDir(t *testing.T) { + dir, err := ioutil.TempDir("", "test-ensure-removeall-with-dir") + if err != nil { + t.Fatal(err) + } + if err := EnsureRemoveAll(dir); err != nil { + t.Fatal(err) + } +} + +func TestEnsureRemoveAllWithFile(t *testing.T) { + tmp, err := ioutil.TempFile("", "test-ensure-removeall-with-dir") + if err != nil { + t.Fatal(err) + } + tmp.Close() + if err := EnsureRemoveAll(tmp.Name()); err != nil { + t.Fatal(err) + } +} + +func TestEnsureRemoveAllWithMount(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("mount not supported on Windows") + } + + dir1, err := ioutil.TempDir("", "test-ensure-removeall-with-dir1") + if err != nil { + t.Fatal(err) + } + dir2, err := ioutil.TempDir("", "test-ensure-removeall-with-dir2") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir2) + + bindDir := filepath.Join(dir1, "bind") + if err := os.MkdirAll(bindDir, 0755); err != nil { + t.Fatal(err) + } + + if err := mount.Mount(dir2, bindDir, "none", "bind"); err != nil { + 
t.Fatal(err) + } + + done := make(chan struct{}) + go func() { + err = EnsureRemoveAll(dir1) + close(done) + }() + + select { + case <-done: + if err != nil { + t.Fatal(err) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for EnsureRemoveAll to finish") + } + + if _, err := os.Stat(dir1); !os.IsNotExist(err) { + t.Fatalf("expected %q to not exist", dir1) + } +} diff --git a/pkg/system/stat_unsupported.go b/pkg/system/stat_darwin.go similarity index 59% rename from pkg/system/stat_unsupported.go rename to pkg/system/stat_darwin.go index f53e9de4d1..715f05b938 100644 --- a/pkg/system/stat_unsupported.go +++ b/pkg/system/stat_darwin.go @@ -1,12 +1,8 @@ -// +build !linux,!windows,!freebsd,!solaris,!openbsd - package system -import ( - "syscall" -) +import "syscall" -// fromStatT creates a system.StatT type from a syscall.Stat_t type +// fromStatT converts a syscall.Stat_t type to a system.Stat_t type func fromStatT(s *syscall.Stat_t) (*StatT, error) { return &StatT{size: s.Size, mode: uint32(s.Mode), diff --git a/pkg/system/stat_freebsd.go b/pkg/system/stat_freebsd.go index d0fb6f1519..715f05b938 100644 --- a/pkg/system/stat_freebsd.go +++ b/pkg/system/stat_freebsd.go @@ -1,8 +1,6 @@ package system -import ( - "syscall" -) +import "syscall" // fromStatT converts a syscall.Stat_t type to a system.Stat_t type func fromStatT(s *syscall.Stat_t) (*StatT, error) { @@ -13,15 +11,3 @@ func fromStatT(s *syscall.Stat_t) (*StatT, error) { rdev: uint64(s.Rdev), mtim: s.Mtimespec}, nil } - -// Stat takes a path to a file and returns -// a system.Stat_t type pertaining to that file. 
-// -// Throws an error if the file does not exist -func Stat(path string) (*StatT, error) { - s := &syscall.Stat_t{} - if err := syscall.Stat(path, s); err != nil { - return nil, err - } - return fromStatT(s) -} diff --git a/pkg/system/stat_linux.go b/pkg/system/stat_linux.go index 8b1eded138..1939f95181 100644 --- a/pkg/system/stat_linux.go +++ b/pkg/system/stat_linux.go @@ -1,8 +1,6 @@ package system -import ( - "syscall" -) +import "syscall" // fromStatT converts a syscall.Stat_t type to a system.Stat_t type func fromStatT(s *syscall.Stat_t) (*StatT, error) { @@ -14,20 +12,8 @@ func fromStatT(s *syscall.Stat_t) (*StatT, error) { mtim: s.Mtim}, nil } -// FromStatT exists only on linux, and loads a system.StatT from a -// syscal.Stat_t. +// FromStatT converts a syscall.Stat_t type to a system.Stat_t type +// This is exposed on Linux as pkg/archive/changes uses it. func FromStatT(s *syscall.Stat_t) (*StatT, error) { return fromStatT(s) } - -// Stat takes a path to a file and returns -// a system.StatT type pertaining to that file. 
-// -// Throws an error if the file does not exist -func Stat(path string) (*StatT, error) { - s := &syscall.Stat_t{} - if err := syscall.Stat(path, s); err != nil { - return nil, err - } - return fromStatT(s) -} diff --git a/pkg/system/stat_openbsd.go b/pkg/system/stat_openbsd.go index 3c3b71fb21..b607dea946 100644 --- a/pkg/system/stat_openbsd.go +++ b/pkg/system/stat_openbsd.go @@ -1,10 +1,8 @@ package system -import ( - "syscall" -) +import "syscall" -// fromStatT creates a system.StatT type from a syscall.Stat_t type +// fromStatT converts a syscall.Stat_t type to a system.Stat_t type func fromStatT(s *syscall.Stat_t) (*StatT, error) { return &StatT{size: s.Size, mode: uint32(s.Mode), diff --git a/pkg/system/stat_solaris.go b/pkg/system/stat_solaris.go index 0216985a25..b607dea946 100644 --- a/pkg/system/stat_solaris.go +++ b/pkg/system/stat_solaris.go @@ -1,12 +1,8 @@ -// +build solaris - package system -import ( - "syscall" -) +import "syscall" -// fromStatT creates a system.StatT type from a syscall.Stat_t type +// fromStatT converts a syscall.Stat_t type to a system.Stat_t type func fromStatT(s *syscall.Stat_t) (*StatT, error) { return &StatT{size: s.Size, mode: uint32(s.Mode), @@ -15,20 +11,3 @@ func fromStatT(s *syscall.Stat_t) (*StatT, error) { rdev: uint64(s.Rdev), mtim: s.Mtim}, nil } - -// FromStatT loads a system.StatT from a syscal.Stat_t. -func FromStatT(s *syscall.Stat_t) (*StatT, error) { - return fromStatT(s) -} - -// Stat takes a path to a file and returns -// a system.StatT type pertaining to that file. 
-// -// Throws an error if the file does not exist -func Stat(path string) (*StatT, error) { - s := &syscall.Stat_t{} - if err := syscall.Stat(path, s); err != nil { - return nil, err - } - return fromStatT(s) -} diff --git a/pkg/system/stat.go b/pkg/system/stat_unix.go similarity index 74% rename from pkg/system/stat.go rename to pkg/system/stat_unix.go index 087034c5ec..91c7d121cc 100644 --- a/pkg/system/stat.go +++ b/pkg/system/stat_unix.go @@ -47,7 +47,14 @@ func (s StatT) Mtim() syscall.Timespec { return s.mtim } -// GetLastModification returns file's last modification time. -func (s StatT) GetLastModification() syscall.Timespec { - return s.Mtim() +// Stat takes a path to a file and returns +// a system.StatT type pertaining to that file. +// +// Throws an error if the file does not exist +func Stat(path string) (*StatT, error) { + s := &syscall.Stat_t{} + if err := syscall.Stat(path, s); err != nil { + return nil, err + } + return fromStatT(s) } diff --git a/pkg/system/stat_windows.go b/pkg/system/stat_windows.go index 39490c625c..6c63972682 100644 --- a/pkg/system/stat_windows.go +++ b/pkg/system/stat_windows.go @@ -1,5 +1,3 @@ -// +build windows - package system import ( @@ -8,18 +6,11 @@ import ( ) // StatT type contains status of a file. It contains metadata -// like name, permission, size, etc about a file. +// like permission, size, etc about a file. type StatT struct { - name string - size int64 - mode os.FileMode - modTime time.Time - isDir bool -} - -// Name returns file's name. -func (s StatT) Name() string { - return s.name + mode os.FileMode + size int64 + mtim time.Time } // Size returns file's size. @@ -29,15 +20,30 @@ func (s StatT) Size() int64 { // Mode returns file's permission mode. func (s StatT) Mode() os.FileMode { - return s.mode + return os.FileMode(s.mode) +} + +// Mtim returns file's last modification time. +func (s StatT) Mtim() time.Time { + return time.Time(s.mtim) } -// ModTime returns file's last modification time. 
-func (s StatT) ModTime() time.Time { - return s.modTime +// Stat takes a path to a file and returns +// a system.StatT type pertaining to that file. +// +// Throws an error if the file does not exist +func Stat(path string) (*StatT, error) { + fi, err := os.Stat(path) + if err != nil { + return nil, err + } + return fromStatT(&fi) } -// IsDir returns whether file is actually a directory. -func (s StatT) IsDir() bool { - return s.isDir +// fromStatT converts a os.FileInfo type to a system.StatT type +func fromStatT(fi *os.FileInfo) (*StatT, error) { + return &StatT{ + size: (*fi).Size(), + mode: (*fi).Mode(), + mtim: (*fi).ModTime()}, nil } diff --git a/pkg/system/syscall_unix.go b/pkg/system/syscall_unix.go index 3ae9128468..49dbdd3781 100644 --- a/pkg/system/syscall_unix.go +++ b/pkg/system/syscall_unix.go @@ -2,12 +2,12 @@ package system -import "syscall" +import "golang.org/x/sys/unix" // Unmount is a platform-specific helper function to call // the unmount syscall. func Unmount(dest string) error { - return syscall.Unmount(dest, 0) + return unix.Unmount(dest, 0) } // CommandLineToArgv should not be used on Unix. 
diff --git a/pkg/system/syscall_windows.go b/pkg/system/syscall_windows.go index 7aaab7e7fb..23e9b207c7 100644 --- a/pkg/system/syscall_windows.go +++ b/pkg/system/syscall_windows.go @@ -1,15 +1,16 @@ package system import ( - "syscall" "unsafe" "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" ) var ( - ntuserApiset = syscall.NewLazyDLL("ext-ms-win-ntuser-window-l1-1-0") - procGetVersionExW = modkernel32.NewProc("GetVersionExW") + ntuserApiset = windows.NewLazyDLL("ext-ms-win-ntuser-window-l1-1-0") + procGetVersionExW = modkernel32.NewProc("GetVersionExW") + procGetProductInfo = modkernel32.NewProc("GetProductInfo") ) // OSVersion is a wrapper for Windows version information @@ -41,7 +42,7 @@ type osVersionInfoEx struct { func GetOSVersion() OSVersion { var err error osv := OSVersion{} - osv.Version, err = syscall.GetVersion() + osv.Version, err = windows.GetVersion() if err != nil { // GetVersion never fails. panic(err) @@ -53,6 +54,8 @@ func GetOSVersion() OSVersion { } // IsWindowsClient returns true if the SKU is client +// @engine maintainers - this function should not be removed or modified as it +// is used to enforce licensing restrictions on Windows. func IsWindowsClient() bool { osviex := &osVersionInfoEx{OSVersionInfoSize: 284} r1, _, err := procGetVersionExW.Call(uintptr(unsafe.Pointer(osviex))) @@ -64,6 +67,22 @@ func IsWindowsClient() bool { return osviex.ProductType == verNTWorkstation } +// IsIoTCore returns true if the currently running image is based off of +// Windows 10 IoT Core. +// @engine maintainers - this function should not be removed or modified as it +// is used to enforce licensing restrictions on Windows. 
+func IsIoTCore() bool { + var returnedProductType uint32 + r1, _, err := procGetProductInfo.Call(6, 1, 0, 0, uintptr(unsafe.Pointer(&returnedProductType))) + if r1 == 0 { + logrus.Warnf("GetProductInfo failed - assuming this is not IoT: %v", err) + return false + } + const productIoTUAP = 0x0000007B + const productIoTUAPCommercial = 0x00000083 + return returnedProductType == productIoTUAP || returnedProductType == productIoTUAPCommercial +} + // Unmount is a platform-specific helper function to call // the unmount syscall. Not supported on Windows func Unmount(dest string) error { @@ -74,20 +93,20 @@ func Unmount(dest string) error { func CommandLineToArgv(commandLine string) ([]string, error) { var argc int32 - argsPtr, err := syscall.UTF16PtrFromString(commandLine) + argsPtr, err := windows.UTF16PtrFromString(commandLine) if err != nil { return nil, err } - argv, err := syscall.CommandLineToArgv(argsPtr, &argc) + argv, err := windows.CommandLineToArgv(argsPtr, &argc) if err != nil { return nil, err } - defer syscall.LocalFree(syscall.Handle(uintptr(unsafe.Pointer(argv)))) + defer windows.LocalFree(windows.Handle(uintptr(unsafe.Pointer(argv)))) newArgs := make([]string, argc) for i, v := range (*argv)[:argc] { - newArgs[i] = string(syscall.UTF16ToString((*v)[:])) + newArgs[i] = string(windows.UTF16ToString((*v)[:])) } return newArgs, nil diff --git a/pkg/system/umask.go b/pkg/system/umask.go index 3d0146b01a..5a10eda5af 100644 --- a/pkg/system/umask.go +++ b/pkg/system/umask.go @@ -3,11 +3,11 @@ package system import ( - "syscall" + "golang.org/x/sys/unix" ) // Umask sets current process's file mode creation mask to newmask // and returns oldmask. 
func Umask(newmask int) (oldmask int, err error) { - return syscall.Umask(newmask), nil + return unix.Umask(newmask), nil } diff --git a/pkg/system/utimes_darwin.go b/pkg/system/utimes_darwin.go deleted file mode 100644 index 0a16197544..0000000000 --- a/pkg/system/utimes_darwin.go +++ /dev/null @@ -1,8 +0,0 @@ -package system - -import "syscall" - -// LUtimesNano is not supported by darwin platform. -func LUtimesNano(path string, ts []syscall.Timespec) error { - return ErrNotSupportedPlatform -} diff --git a/pkg/system/utimes_freebsd.go b/pkg/system/utimes_freebsd.go index e2eac3b553..6a77524376 100644 --- a/pkg/system/utimes_freebsd.go +++ b/pkg/system/utimes_freebsd.go @@ -3,18 +3,20 @@ package system import ( "syscall" "unsafe" + + "golang.org/x/sys/unix" ) // LUtimesNano is used to change access and modification time of the specified path. -// It's used for symbol link file because syscall.UtimesNano doesn't support a NOFOLLOW flag atm. +// It's used for symbol link file because unix.UtimesNano doesn't support a NOFOLLOW flag atm. func LUtimesNano(path string, ts []syscall.Timespec) error { var _path *byte - _path, err := syscall.BytePtrFromString(path) + _path, err := unix.BytePtrFromString(path) if err != nil { return err } - if _, _, err := syscall.Syscall(syscall.SYS_LUTIMES, uintptr(unsafe.Pointer(_path)), uintptr(unsafe.Pointer(&ts[0])), 0); err != 0 && err != syscall.ENOSYS { + if _, _, err := unix.Syscall(unix.SYS_LUTIMES, uintptr(unsafe.Pointer(_path)), uintptr(unsafe.Pointer(&ts[0])), 0); err != 0 && err != unix.ENOSYS { return err } diff --git a/pkg/system/utimes_linux.go b/pkg/system/utimes_linux.go index fc8a1aba95..edc588a63f 100644 --- a/pkg/system/utimes_linux.go +++ b/pkg/system/utimes_linux.go @@ -3,22 +3,21 @@ package system import ( "syscall" "unsafe" + + "golang.org/x/sys/unix" ) // LUtimesNano is used to change access and modification time of the specified path. 
-// It's used for symbol link file because syscall.UtimesNano doesn't support a NOFOLLOW flag atm. +// It's used for symbol link file because unix.UtimesNano doesn't support a NOFOLLOW flag atm. func LUtimesNano(path string, ts []syscall.Timespec) error { - // These are not currently available in syscall - atFdCwd := -100 - atSymLinkNoFollow := 0x100 + atFdCwd := unix.AT_FDCWD var _path *byte - _path, err := syscall.BytePtrFromString(path) + _path, err := unix.BytePtrFromString(path) if err != nil { return err } - - if _, _, err := syscall.Syscall6(syscall.SYS_UTIMENSAT, uintptr(atFdCwd), uintptr(unsafe.Pointer(_path)), uintptr(unsafe.Pointer(&ts[0])), uintptr(atSymLinkNoFollow), 0, 0); err != 0 && err != syscall.ENOSYS { + if _, _, err := unix.Syscall6(unix.SYS_UTIMENSAT, uintptr(atFdCwd), uintptr(unsafe.Pointer(_path)), uintptr(unsafe.Pointer(&ts[0])), unix.AT_SYMLINK_NOFOLLOW, 0, 0); err != 0 && err != unix.ENOSYS { return err } diff --git a/pkg/system/utimes_unix_test.go b/pkg/system/utimes_unix_test.go index 1ee0d099f9..af6a544245 100644 --- a/pkg/system/utimes_unix_test.go +++ b/pkg/system/utimes_unix_test.go @@ -12,7 +12,7 @@ import ( // prepareFiles creates files for testing in the temp directory func prepareFiles(t *testing.T) (string, string, string, string) { - dir, err := ioutil.TempDir("", "docker-system-test") + dir, err := ioutil.TempDir("", "storage-system-test") if err != nil { t.Fatal(err) } @@ -41,7 +41,7 @@ func TestLUtimesNano(t *testing.T) { t.Fatal(err) } - ts := []syscall.Timespec{{0, 0}, {0, 0}} + ts := []syscall.Timespec{{Sec: 0, Nsec: 0}, {Sec: 0, Nsec: 0}} if err := LUtimesNano(symlink, ts); err != nil { t.Fatal(err) } diff --git a/pkg/system/utimes_unsupported.go b/pkg/system/utimes_unsupported.go index 50c3a04364..139714544d 100644 --- a/pkg/system/utimes_unsupported.go +++ b/pkg/system/utimes_unsupported.go @@ -1,10 +1,10 @@ -// +build !linux,!freebsd,!darwin +// +build !linux,!freebsd package system import "syscall" -// LUtimesNano 
is not supported on platforms other than linux, freebsd and darwin. +// LUtimesNano is only supported on linux and freebsd. func LUtimesNano(path string, ts []syscall.Timespec) error { return ErrNotSupportedPlatform } diff --git a/pkg/system/xattrs_linux.go b/pkg/system/xattrs_linux.go index d2e2c05799..98b111be42 100644 --- a/pkg/system/xattrs_linux.go +++ b/pkg/system/xattrs_linux.go @@ -1,63 +1,29 @@ package system -import ( - "syscall" - "unsafe" -) +import "golang.org/x/sys/unix" // Lgetxattr retrieves the value of the extended attribute identified by attr // and associated with the given path in the file system. // It will returns a nil slice and nil error if the xattr is not set. func Lgetxattr(path string, attr string) ([]byte, error) { - pathBytes, err := syscall.BytePtrFromString(path) - if err != nil { - return nil, err - } - attrBytes, err := syscall.BytePtrFromString(attr) - if err != nil { - return nil, err - } - dest := make([]byte, 128) - destBytes := unsafe.Pointer(&dest[0]) - sz, _, errno := syscall.Syscall6(syscall.SYS_LGETXATTR, uintptr(unsafe.Pointer(pathBytes)), uintptr(unsafe.Pointer(attrBytes)), uintptr(destBytes), uintptr(len(dest)), 0, 0) - if errno == syscall.ENODATA { + sz, errno := unix.Lgetxattr(path, attr, dest) + if errno == unix.ENODATA { return nil, nil } - if errno == syscall.ERANGE { + if errno == unix.ERANGE { dest = make([]byte, sz) - destBytes := unsafe.Pointer(&dest[0]) - sz, _, errno = syscall.Syscall6(syscall.SYS_LGETXATTR, uintptr(unsafe.Pointer(pathBytes)), uintptr(unsafe.Pointer(attrBytes)), uintptr(destBytes), uintptr(len(dest)), 0, 0) + sz, errno = unix.Lgetxattr(path, attr, dest) } - if errno != 0 { + if errno != nil { return nil, errno } return dest[:sz], nil } -var _zero uintptr - // Lsetxattr sets the value of the extended attribute identified by attr // and associated with the given path in the file system. 
func Lsetxattr(path string, attr string, data []byte, flags int) error { - pathBytes, err := syscall.BytePtrFromString(path) - if err != nil { - return err - } - attrBytes, err := syscall.BytePtrFromString(attr) - if err != nil { - return err - } - var dataBytes unsafe.Pointer - if len(data) > 0 { - dataBytes = unsafe.Pointer(&data[0]) - } else { - dataBytes = unsafe.Pointer(&_zero) - } - _, _, errno := syscall.Syscall6(syscall.SYS_LSETXATTR, uintptr(unsafe.Pointer(pathBytes)), uintptr(unsafe.Pointer(attrBytes)), uintptr(dataBytes), uintptr(len(data)), uintptr(flags), 0) - if errno != 0 { - return errno - } - return nil + return unix.Lsetxattr(path, attr, data, flags) } diff --git a/pkg/testutil/assert/assert.go b/pkg/testutil/assert/assert.go deleted file mode 100644 index 5b0dcce67a..0000000000 --- a/pkg/testutil/assert/assert.go +++ /dev/null @@ -1,70 +0,0 @@ -// Package assert contains functions for making assertions in unit tests -package assert - -import ( - "fmt" - "path/filepath" - "runtime" - "strings" -) - -// TestingT is an interface which defines the methods of testing.T that are -// required by this package -type TestingT interface { - Fatalf(string, ...interface{}) -} - -// Equal compare the actual value to the expected value and fails the test if -// they are not equal. -func Equal(t TestingT, actual, expected interface{}) { - if expected != actual { - fatal(t, fmt.Sprintf("Expected '%v' (%T) got '%v' (%T)", expected, expected, actual, actual)) - } -} - -//EqualStringSlice compares two slices and fails the test if they do not contain -// the same items. 
-func EqualStringSlice(t TestingT, actual, expected []string) { - if len(actual) != len(expected) { - t.Fatalf("Expected (length %d): %q\nActual (length %d): %q", - len(expected), expected, len(actual), actual) - } - for i, item := range actual { - if item != expected[i] { - t.Fatalf("Slices differ at element %d, expected %q got %q", - i, expected[i], item) - } - } -} - -// NilError asserts that the error is nil, otherwise it fails the test. -func NilError(t TestingT, err error) { - if err != nil { - fatal(t, fmt.Sprintf("Expected no error, got: %s", err.Error())) - } -} - -// Error asserts that error is not nil, and contains the expected text, -// otherwise it fails the test. -func Error(t TestingT, err error, contains string) { - if err == nil { - fatal(t, "Expected an error, but error was nil") - } - - if !strings.Contains(err.Error(), contains) { - fatal(t, fmt.Sprintf("Expected error to contain '%s', got '%s'", contains, err.Error())) - } -} - -// Contains asserts that the string contains a substring, otherwise it fails the -// test. 
-func Contains(t TestingT, actual, contains string) { - if !strings.Contains(actual, contains) { - fatal(t, fmt.Sprintf("Expected '%s' to contain '%s'", actual, contains)) - } -} - -func fatal(t TestingT, msg string) { - _, file, line, _ := runtime.Caller(2) - t.Fatalf("%s:%d: %s", filepath.Base(file), line, msg) -} diff --git a/store.go b/store.go index 7635e667c4..0baabb6f8d 100644 --- a/store.go +++ b/store.go @@ -170,7 +170,7 @@ type Store interface { // if reexec.Init { // return // } - PutLayer(id, parent string, names []string, mountLabel string, writeable bool, diff archive.Reader) (*Layer, int64, error) + PutLayer(id, parent string, names []string, mountLabel string, writeable bool, diff io.Reader) (*Layer, int64, error) // CreateImage creates a new image, optionally with the specified ID // (one will be assigned if none is specified), with optional names, @@ -285,7 +285,7 @@ type Store interface { // if reexec.Init { // return // } - ApplyDiff(to string, diff archive.Reader) (int64, error) + ApplyDiff(to string, diff io.Reader) (int64, error) // LayersByCompressedDigest returns a slice of the layers with the // specified compressed digest value recorded for them. 
@@ -716,7 +716,7 @@ func (s *store) ContainerStore() (ContainerStore, error) { return nil, ErrLoadError } -func (s *store) PutLayer(id, parent string, names []string, mountLabel string, writeable bool, diff archive.Reader) (*Layer, int64, error) { +func (s *store) PutLayer(id, parent string, names []string, mountLabel string, writeable bool, diff io.Reader) (*Layer, int64, error) { rlstore, err := s.LayerStore() if err != nil { return nil, -1, err @@ -1786,7 +1786,7 @@ func (s *store) Diff(from, to string, options *DiffOptions) (io.ReadCloser, erro return nil, ErrLayerUnknown } -func (s *store) ApplyDiff(to string, diff archive.Reader) (int64, error) { +func (s *store) ApplyDiff(to string, diff io.Reader) (int64, error) { rlstore, err := s.LayerStore() if err != nil { return -1, err diff --git a/vendor.conf b/vendor.conf index d09e9bbc6d..27eeee2cc6 100644 --- a/vendor.conf +++ b/vendor.conf @@ -17,5 +17,5 @@ github.com/sirupsen/logrus v1.0.0 github.com/stretchr/testify 4d4bfba8f1d1027c4fdbe371823030df51419987 github.com/tchap/go-patricia v2.2.6 github.com/vbatts/tar-split bd4c5d64c3e9297f410025a3b1bd0c58f659e721 -golang.org/x/net f2499483f923065a842d38eb4c7f1927e6fc6e6d +golang.org/x/net 7dcfb8076726a3fdd9353b6b8a1f1b6be6811bd6 golang.org/x/sys 07c182904dbd53199946ba614a412c61d3c548f5 diff --git a/vendor/golang.org/x/net/README b/vendor/golang.org/x/net/README new file mode 100644 index 0000000000..6b13d8e505 --- /dev/null +++ b/vendor/golang.org/x/net/README @@ -0,0 +1,3 @@ +This repository holds supplementary Go networking libraries. + +To submit changes to this repository, see http://golang.org/doc/contribute.html. diff --git a/vendor/golang.org/x/net/context/context.go b/vendor/golang.org/x/net/context/context.go new file mode 100644 index 0000000000..f143ed6a1e --- /dev/null +++ b/vendor/golang.org/x/net/context/context.go @@ -0,0 +1,156 @@ +// Copyright 2014 The Go Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package context defines the Context type, which carries deadlines, +// cancelation signals, and other request-scoped values across API boundaries +// and between processes. +// +// Incoming requests to a server should create a Context, and outgoing calls to +// servers should accept a Context. The chain of function calls between must +// propagate the Context, optionally replacing it with a modified copy created +// using WithDeadline, WithTimeout, WithCancel, or WithValue. +// +// Programs that use Contexts should follow these rules to keep interfaces +// consistent across packages and enable static analysis tools to check context +// propagation: +// +// Do not store Contexts inside a struct type; instead, pass a Context +// explicitly to each function that needs it. The Context should be the first +// parameter, typically named ctx: +// +// func DoSomething(ctx context.Context, arg Arg) error { +// // ... use ctx ... +// } +// +// Do not pass a nil Context, even if a function permits it. Pass context.TODO +// if you are unsure about which Context to use. +// +// Use context Values only for request-scoped data that transits processes and +// APIs, not for passing optional parameters to functions. +// +// The same Context may be passed to functions running in different goroutines; +// Contexts are safe for simultaneous use by multiple goroutines. +// +// See http://blog.golang.org/context for example code for a server that uses +// Contexts. +package context // import "golang.org/x/net/context" + +import "time" + +// A Context carries a deadline, a cancelation signal, and other values across +// API boundaries. +// +// Context's methods may be called by multiple goroutines simultaneously. +type Context interface { + // Deadline returns the time when work done on behalf of this context + // should be canceled. 
Deadline returns ok==false when no deadline is + // set. Successive calls to Deadline return the same results. + Deadline() (deadline time.Time, ok bool) + + // Done returns a channel that's closed when work done on behalf of this + // context should be canceled. Done may return nil if this context can + // never be canceled. Successive calls to Done return the same value. + // + // WithCancel arranges for Done to be closed when cancel is called; + // WithDeadline arranges for Done to be closed when the deadline + // expires; WithTimeout arranges for Done to be closed when the timeout + // elapses. + // + // Done is provided for use in select statements: + // + // // Stream generates values with DoSomething and sends them to out + // // until DoSomething returns an error or ctx.Done is closed. + // func Stream(ctx context.Context, out chan<- Value) error { + // for { + // v, err := DoSomething(ctx) + // if err != nil { + // return err + // } + // select { + // case <-ctx.Done(): + // return ctx.Err() + // case out <- v: + // } + // } + // } + // + // See http://blog.golang.org/pipelines for more examples of how to use + // a Done channel for cancelation. + Done() <-chan struct{} + + // Err returns a non-nil error value after Done is closed. Err returns + // Canceled if the context was canceled or DeadlineExceeded if the + // context's deadline passed. No other values for Err are defined. + // After Done is closed, successive calls to Err return the same value. + Err() error + + // Value returns the value associated with this context for key, or nil + // if no value is associated with key. Successive calls to Value with + // the same key returns the same result. + // + // Use context values only for request-scoped data that transits + // processes and API boundaries, not for passing optional parameters to + // functions. + // + // A key identifies a specific value in a Context. 
Functions that wish + // to store values in Context typically allocate a key in a global + // variable then use that key as the argument to context.WithValue and + // Context.Value. A key can be any type that supports equality; + // packages should define keys as an unexported type to avoid + // collisions. + // + // Packages that define a Context key should provide type-safe accessors + // for the values stores using that key: + // + // // Package user defines a User type that's stored in Contexts. + // package user + // + // import "golang.org/x/net/context" + // + // // User is the type of value stored in the Contexts. + // type User struct {...} + // + // // key is an unexported type for keys defined in this package. + // // This prevents collisions with keys defined in other packages. + // type key int + // + // // userKey is the key for user.User values in Contexts. It is + // // unexported; clients use user.NewContext and user.FromContext + // // instead of using this key directly. + // var userKey key = 0 + // + // // NewContext returns a new Context that carries value u. + // func NewContext(ctx context.Context, u *User) context.Context { + // return context.WithValue(ctx, userKey, u) + // } + // + // // FromContext returns the User value stored in ctx, if any. + // func FromContext(ctx context.Context) (*User, bool) { + // u, ok := ctx.Value(userKey).(*User) + // return u, ok + // } + Value(key interface{}) interface{} +} + +// Background returns a non-nil, empty Context. It is never canceled, has no +// values, and has no deadline. It is typically used by the main function, +// initialization, and tests, and as the top-level Context for incoming +// requests. +func Background() Context { + return background +} + +// TODO returns a non-nil, empty Context. Code should use context.TODO when +// it's unclear which Context to use or it is not yet available (because the +// surrounding function has not yet been extended to accept a Context +// parameter). 
TODO is recognized by static analysis tools that determine +// whether Contexts are propagated correctly in a program. +func TODO() Context { + return todo +} + +// A CancelFunc tells an operation to abandon its work. +// A CancelFunc does not wait for the work to stop. +// After the first call, subsequent calls to a CancelFunc do nothing. +type CancelFunc func() diff --git a/vendor/golang.org/x/net/context/go17.go b/vendor/golang.org/x/net/context/go17.go new file mode 100644 index 0000000000..d20f52b7de --- /dev/null +++ b/vendor/golang.org/x/net/context/go17.go @@ -0,0 +1,72 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.7 + +package context + +import ( + "context" // standard library's context, as of Go 1.7 + "time" +) + +var ( + todo = context.TODO() + background = context.Background() +) + +// Canceled is the error returned by Context.Err when the context is canceled. +var Canceled = context.Canceled + +// DeadlineExceeded is the error returned by Context.Err when the context's +// deadline passes. +var DeadlineExceeded = context.DeadlineExceeded + +// WithCancel returns a copy of parent with a new Done channel. The returned +// context's Done channel is closed when the returned cancel function is called +// or when the parent context's Done channel is closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { + ctx, f := context.WithCancel(parent) + return ctx, CancelFunc(f) +} + +// WithDeadline returns a copy of the parent context with the deadline adjusted +// to be no later than d. If the parent's deadline is already earlier than d, +// WithDeadline(parent, d) is semantically equivalent to parent. 
The returned +// context's Done channel is closed when the deadline expires, when the returned +// cancel function is called, or when the parent context's Done channel is +// closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { + ctx, f := context.WithDeadline(parent, deadline) + return ctx, CancelFunc(f) +} + +// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)). +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete: +// +// func slowOperationWithTimeout(ctx context.Context) (Result, error) { +// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) +// defer cancel() // releases resources if slowOperation completes before timeout elapses +// return slowOperation(ctx) +// } +func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { + return WithDeadline(parent, time.Now().Add(timeout)) +} + +// WithValue returns a copy of parent in which the value associated with key is +// val. +// +// Use context Values only for request-scoped data that transits processes and +// APIs, not for passing optional parameters to functions. +func WithValue(parent Context, key interface{}, val interface{}) Context { + return context.WithValue(parent, key, val) +} diff --git a/vendor/golang.org/x/net/context/pre_go17.go b/vendor/golang.org/x/net/context/pre_go17.go new file mode 100644 index 0000000000..0f35592df5 --- /dev/null +++ b/vendor/golang.org/x/net/context/pre_go17.go @@ -0,0 +1,300 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +// +build !go1.7 + +package context + +import ( + "errors" + "fmt" + "sync" + "time" +) + +// An emptyCtx is never canceled, has no values, and has no deadline. It is not +// struct{}, since vars of this type must have distinct addresses. +type emptyCtx int + +func (*emptyCtx) Deadline() (deadline time.Time, ok bool) { + return +} + +func (*emptyCtx) Done() <-chan struct{} { + return nil +} + +func (*emptyCtx) Err() error { + return nil +} + +func (*emptyCtx) Value(key interface{}) interface{} { + return nil +} + +func (e *emptyCtx) String() string { + switch e { + case background: + return "context.Background" + case todo: + return "context.TODO" + } + return "unknown empty Context" +} + +var ( + background = new(emptyCtx) + todo = new(emptyCtx) +) + +// Canceled is the error returned by Context.Err when the context is canceled. +var Canceled = errors.New("context canceled") + +// DeadlineExceeded is the error returned by Context.Err when the context's +// deadline passes. +var DeadlineExceeded = errors.New("context deadline exceeded") + +// WithCancel returns a copy of parent with a new Done channel. The returned +// context's Done channel is closed when the returned cancel function is called +// or when the parent context's Done channel is closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { + c := newCancelCtx(parent) + propagateCancel(parent, c) + return c, func() { c.cancel(true, Canceled) } +} + +// newCancelCtx returns an initialized cancelCtx. +func newCancelCtx(parent Context) *cancelCtx { + return &cancelCtx{ + Context: parent, + done: make(chan struct{}), + } +} + +// propagateCancel arranges for child to be canceled when parent is. 
+func propagateCancel(parent Context, child canceler) { + if parent.Done() == nil { + return // parent is never canceled + } + if p, ok := parentCancelCtx(parent); ok { + p.mu.Lock() + if p.err != nil { + // parent has already been canceled + child.cancel(false, p.err) + } else { + if p.children == nil { + p.children = make(map[canceler]bool) + } + p.children[child] = true + } + p.mu.Unlock() + } else { + go func() { + select { + case <-parent.Done(): + child.cancel(false, parent.Err()) + case <-child.Done(): + } + }() + } +} + +// parentCancelCtx follows a chain of parent references until it finds a +// *cancelCtx. This function understands how each of the concrete types in this +// package represents its parent. +func parentCancelCtx(parent Context) (*cancelCtx, bool) { + for { + switch c := parent.(type) { + case *cancelCtx: + return c, true + case *timerCtx: + return c.cancelCtx, true + case *valueCtx: + parent = c.Context + default: + return nil, false + } + } +} + +// removeChild removes a context from its parent. +func removeChild(parent Context, child canceler) { + p, ok := parentCancelCtx(parent) + if !ok { + return + } + p.mu.Lock() + if p.children != nil { + delete(p.children, child) + } + p.mu.Unlock() +} + +// A canceler is a context type that can be canceled directly. The +// implementations are *cancelCtx and *timerCtx. +type canceler interface { + cancel(removeFromParent bool, err error) + Done() <-chan struct{} +} + +// A cancelCtx can be canceled. When canceled, it also cancels any children +// that implement canceler. +type cancelCtx struct { + Context + + done chan struct{} // closed by the first cancel call. 
+ + mu sync.Mutex + children map[canceler]bool // set to nil by the first cancel call + err error // set to non-nil by the first cancel call +} + +func (c *cancelCtx) Done() <-chan struct{} { + return c.done +} + +func (c *cancelCtx) Err() error { + c.mu.Lock() + defer c.mu.Unlock() + return c.err +} + +func (c *cancelCtx) String() string { + return fmt.Sprintf("%v.WithCancel", c.Context) +} + +// cancel closes c.done, cancels each of c's children, and, if +// removeFromParent is true, removes c from its parent's children. +func (c *cancelCtx) cancel(removeFromParent bool, err error) { + if err == nil { + panic("context: internal error: missing cancel error") + } + c.mu.Lock() + if c.err != nil { + c.mu.Unlock() + return // already canceled + } + c.err = err + close(c.done) + for child := range c.children { + // NOTE: acquiring the child's lock while holding parent's lock. + child.cancel(false, err) + } + c.children = nil + c.mu.Unlock() + + if removeFromParent { + removeChild(c.Context, c) + } +} + +// WithDeadline returns a copy of the parent context with the deadline adjusted +// to be no later than d. If the parent's deadline is already earlier than d, +// WithDeadline(parent, d) is semantically equivalent to parent. The returned +// context's Done channel is closed when the deadline expires, when the returned +// cancel function is called, or when the parent context's Done channel is +// closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { + if cur, ok := parent.Deadline(); ok && cur.Before(deadline) { + // The current deadline is already sooner than the new one. 
+ return WithCancel(parent) + } + c := &timerCtx{ + cancelCtx: newCancelCtx(parent), + deadline: deadline, + } + propagateCancel(parent, c) + d := deadline.Sub(time.Now()) + if d <= 0 { + c.cancel(true, DeadlineExceeded) // deadline has already passed + return c, func() { c.cancel(true, Canceled) } + } + c.mu.Lock() + defer c.mu.Unlock() + if c.err == nil { + c.timer = time.AfterFunc(d, func() { + c.cancel(true, DeadlineExceeded) + }) + } + return c, func() { c.cancel(true, Canceled) } +} + +// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to +// implement Done and Err. It implements cancel by stopping its timer then +// delegating to cancelCtx.cancel. +type timerCtx struct { + *cancelCtx + timer *time.Timer // Under cancelCtx.mu. + + deadline time.Time +} + +func (c *timerCtx) Deadline() (deadline time.Time, ok bool) { + return c.deadline, true +} + +func (c *timerCtx) String() string { + return fmt.Sprintf("%v.WithDeadline(%s [%s])", c.cancelCtx.Context, c.deadline, c.deadline.Sub(time.Now())) +} + +func (c *timerCtx) cancel(removeFromParent bool, err error) { + c.cancelCtx.cancel(false, err) + if removeFromParent { + // Remove this timerCtx from its parent cancelCtx's children. + removeChild(c.cancelCtx.Context, c) + } + c.mu.Lock() + if c.timer != nil { + c.timer.Stop() + c.timer = nil + } + c.mu.Unlock() +} + +// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)). 
+// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete: +// +// func slowOperationWithTimeout(ctx context.Context) (Result, error) { +// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) +// defer cancel() // releases resources if slowOperation completes before timeout elapses +// return slowOperation(ctx) +// } +func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { + return WithDeadline(parent, time.Now().Add(timeout)) +} + +// WithValue returns a copy of parent in which the value associated with key is +// val. +// +// Use context Values only for request-scoped data that transits processes and +// APIs, not for passing optional parameters to functions. +func WithValue(parent Context, key interface{}, val interface{}) Context { + return &valueCtx{parent, key, val} +} + +// A valueCtx carries a key-value pair. It implements Value for that key and +// delegates all other calls to the embedded Context. 
+type valueCtx struct { + Context + key, val interface{} +} + +func (c *valueCtx) String() string { + return fmt.Sprintf("%v.WithValue(%#v, %#v)", c.Context, c.key, c.val) +} + +func (c *valueCtx) Value(key interface{}) interface{} { + if c.key == key { + return c.val + } + return c.Context.Value(key) +} diff --git a/vendor/golang.org/x/net/proxy/socks5.go b/vendor/golang.org/x/net/proxy/socks5.go index 9b9628239a..973f57f197 100644 --- a/vendor/golang.org/x/net/proxy/socks5.go +++ b/vendor/golang.org/x/net/proxy/socks5.go @@ -72,24 +72,28 @@ func (s *socks5) Dial(network, addr string) (net.Conn, error) { if err != nil { return nil, err } - closeConn := &conn - defer func() { - if closeConn != nil { - (*closeConn).Close() - } - }() + if err := s.connect(conn, addr); err != nil { + conn.Close() + return nil, err + } + return conn, nil +} - host, portStr, err := net.SplitHostPort(addr) +// connect takes an existing connection to a socks5 proxy server, +// and commands the server to extend that connection to target, +// which must be a canonical address with a host and port. 
+func (s *socks5) connect(conn net.Conn, target string) error { + host, portStr, err := net.SplitHostPort(target) if err != nil { - return nil, err + return err } port, err := strconv.Atoi(portStr) if err != nil { - return nil, errors.New("proxy: failed to parse port number: " + portStr) + return errors.New("proxy: failed to parse port number: " + portStr) } if port < 1 || port > 0xffff { - return nil, errors.New("proxy: port number out of range: " + portStr) + return errors.New("proxy: port number out of range: " + portStr) } // the size here is just an estimate @@ -103,17 +107,17 @@ func (s *socks5) Dial(network, addr string) (net.Conn, error) { } if _, err := conn.Write(buf); err != nil { - return nil, errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error()) + return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error()) } if _, err := io.ReadFull(conn, buf[:2]); err != nil { - return nil, errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error()) } if buf[0] != 5 { - return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0]))) + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0]))) } if buf[1] == 0xff { - return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication") + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication") } if buf[1] == socks5AuthPassword { @@ -125,15 +129,15 @@ func (s *socks5) Dial(network, addr string) (net.Conn, error) { buf = append(buf, s.password...) 
if _, err := conn.Write(buf); err != nil { - return nil, errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error()) + return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error()) } if _, err := io.ReadFull(conn, buf[:2]); err != nil { - return nil, errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error()) } if buf[1] != 0 { - return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password") + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password") } } @@ -150,7 +154,7 @@ func (s *socks5) Dial(network, addr string) (net.Conn, error) { buf = append(buf, ip...) } else { if len(host) > 255 { - return nil, errors.New("proxy: destination hostname too long: " + host) + return errors.New("proxy: destination hostname too long: " + host) } buf = append(buf, socks5Domain) buf = append(buf, byte(len(host))) @@ -159,11 +163,11 @@ func (s *socks5) Dial(network, addr string) (net.Conn, error) { buf = append(buf, byte(port>>8), byte(port)) if _, err := conn.Write(buf); err != nil { - return nil, errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error()) + return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error()) } if _, err := io.ReadFull(conn, buf[:4]); err != nil { - return nil, errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error()) } failure := "unknown error" @@ -172,7 +176,7 @@ func (s *socks5) Dial(network, addr string) (net.Conn, error) { } if len(failure) > 0 { - return nil, errors.New("proxy: 
SOCKS5 proxy at " + s.addr + " failed to connect: " + failure) + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure) } bytesToDiscard := 0 @@ -184,11 +188,11 @@ func (s *socks5) Dial(network, addr string) (net.Conn, error) { case socks5Domain: _, err := io.ReadFull(conn, buf[:1]) if err != nil { - return nil, errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error()) } bytesToDiscard = int(buf[0]) default: - return nil, errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr) + return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr) } if cap(buf) < bytesToDiscard { @@ -197,14 +201,13 @@ func (s *socks5) Dial(network, addr string) (net.Conn, error) { buf = buf[:bytesToDiscard] } if _, err := io.ReadFull(conn, buf); err != nil { - return nil, errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error()) } // Also need to discard the port number if _, err := io.ReadFull(conn, buf[:2]); err != nil { - return nil, errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error()) } - closeConn = nil - return conn, nil + return nil }