diff --git a/internal/discover/char_devices.go b/internal/discover/char_devices.go index b7503fcc..83f9d843 100644 --- a/internal/discover/char_devices.go +++ b/internal/discover/char_devices.go @@ -30,16 +30,14 @@ var _ Discover = (*charDevices)(nil) func NewCharDeviceDiscoverer(logger *logrus.Logger, devices []string, root string) Discover { locator := lookup.NewCharDeviceLocator(logger, root) - return NewDeviceDiscoverer(logger, locator, devices) + return NewDeviceDiscoverer(logger, locator, root, devices) } // NewDeviceDiscoverer creates a discoverer which locates the specified set of device nodes using the specified locator. -func NewDeviceDiscoverer(logger *logrus.Logger, locator lookup.Locator, devices []string) Discover { - return &charDevices{ - logger: logger, - lookup: locator, - required: devices, - } +func NewDeviceDiscoverer(logger *logrus.Logger, locator lookup.Locator, root string, devices []string) Discover { + m := NewMounts(logger, locator, root, devices).(*mounts) + + return (*charDevices)(m) } // Mounts returns the discovered mounts for the charDevices. diff --git a/internal/discover/csv.go b/internal/discover/csv.go index 64e5d25c..6235f44f 100644 --- a/internal/discover/csv.go +++ b/internal/discover/csv.go @@ -52,7 +52,7 @@ func NewFromCSVFiles(logger *logrus.Logger, files []string, root string) (Discov mountSpecs = append(mountSpecs, targets...) } - return newFromMountSpecs(logger, locators, mountSpecs) + return newFromMountSpecs(logger, locators, root, mountSpecs) } // loadCSVFile loads the specified CSV file and returns the list of mount specs @@ -71,7 +71,7 @@ func loadCSVFile(logger *logrus.Logger, filename string) ([]*csv.MountSpec, erro // newFromMountSpecs creates a discoverer for the CSV file. A logger is also supplied. // A list of csvDiscoverers is returned, with each being associated with a single MountSpecType. -func newFromMountSpecs(logger *logrus.Logger, locators map[csv.MountSpecType]lookup.Locator, targets []*csv.MountSpec) (Discover, error) { +func newFromMountSpecs(logger *logrus.Logger, locators map[csv.MountSpecType]lookup.Locator, root string, targets []*csv.MountSpec) (Discover, error) { if len(targets) == 0 { return &None{}, nil } @@ -95,13 +95,9 @@ func newFromMountSpecs(logger *logrus.Logger, locators map[csv.MountSpecType]loo var m Discover switch t { case csv.MountSpecDev: - m = NewDeviceDiscoverer(logger, locator, candidatesByType[t]) + m = NewDeviceDiscoverer(logger, locator, root, candidatesByType[t]) default: - m = &mounts{ - logger: logger, - lookup: locator, - required: candidatesByType[t], - } + m = NewMounts(logger, locator, root, candidatesByType[t]) } discoverers = append(discoverers, m) diff --git a/internal/discover/csv_test.go b/internal/discover/csv_test.go index b46f9509..f3c6dfeb 100644 --- a/internal/discover/csv_test.go +++ b/internal/discover/csv_test.go @@ -36,6 +36,7 @@ func TestNewFromMountSpec(t *testing.T) { testCases := []struct { description string + root string targets []*csv.MountSpec expectedError error expectedDiscoverer Discover @@ -76,12 +77,50 @@ func TestNewFromMountSpec(t *testing.T) { &mounts{ logger: logger, lookup: locators["dev"], + root: "/", required: []string{"dev0", "dev1"}, }, ), &mounts{ logger: logger, lookup: locators["lib"], + root: "/", + required: []string{"lib0"}, + }, + }, + }, + }, + { + description: "sets root", + targets: []*csv.MountSpec{ + { + Type: "dev", + Path: "dev0", + }, + { + Type: "lib", + Path: "lib0", + }, + { + Type: "dev", + Path: "dev1", + }, + }, + root: "/some/root", + expectedDiscoverer: &list{ + discoverers: []Discover{ + (*charDevices)( + &mounts{ + logger: logger, + lookup: locators["dev"], + root: "/some/root", + required: []string{"dev0", "dev1"}, + }, + ), + &mounts{ + logger: logger, + lookup: locators["lib"], + root: "/some/root", required: []string{"lib0"}, }, }, @@ -91,7 +130,7 @@ func TestNewFromMountSpec(t *testing.T) { for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { - discoverer, err := newFromMountSpecs(logger, locators, tc.targets) + discoverer, err := newFromMountSpecs(logger, locators, tc.root, tc.targets) if tc.expectedError != nil { require.Error(t, err) return diff --git a/internal/discover/gds.go b/internal/discover/gds.go index fe11aa6d..dab2f735 100644 --- a/internal/discover/gds.go +++ b/internal/discover/gds.go @@ -36,17 +36,19 @@ func NewGDSDiscoverer(logger *logrus.Logger, root string) (Discover, error) { root, ) - udev := &mounts{ - logger: logger, - lookup: lookup.NewDirectoryLocator(logger, root), - required: []string{"/run/udev"}, - } + udev := NewMounts( + logger, + lookup.NewDirectoryLocator(logger, root), + root, + []string{"/run/udev"}, + ) - cufile := &mounts{ - logger: logger, - lookup: lookup.NewFileLocator(logger, root), - required: []string{"/etc/cufile.json"}, - } + cufile := NewMounts( + logger, + lookup.NewFileLocator(logger, root), + root, + []string{"/etc/cufile.json"}, + ) d := gdsDeviceDiscoverer{ logger: logger, diff --git a/internal/discover/mounts.go b/internal/discover/mounts.go index 006c7d56..3c026aea 100644 --- a/internal/discover/mounts.go +++ b/internal/discover/mounts.go @@ -18,6 +18,8 @@ package discover import ( "fmt" + "path/filepath" + "strings" "sync" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" @@ -31,6 +33,7 @@ type mounts struct { None logger *logrus.Logger lookup lookup.Locator + root string required []string sync.Mutex cache []Mount @@ -38,6 +41,16 @@ type mounts struct { var _ Discover = (*mounts)(nil) +// NewMounts creates a discoverer for the required mounts using the specified locator. +func NewMounts(logger *logrus.Logger, lookup lookup.Locator, root string, required []string) Discover { + return &mounts{ + logger: logger, + lookup: lookup, + root: filepath.Join("/", root), + required: required, + } +} + func (d *mounts) Mounts() ([]Mount, error) { if d.lookup == nil { return nil, fmt.Errorf("no lookup defined") @@ -71,11 +84,7 @@ func (d *mounts) Mounts() ([]Mount, error) { continue } - r, err := d.lookup.Relative(p) - if err != nil { - d.logger.Warnf("Failed to get relative path of %v: %v", p, err) - continue - } + r := d.relativeTo(p) if r == "" { r = p } @@ -97,3 +106,12 @@ func (d *mounts) Mounts() ([]Mount, error) { return d.cache, nil } + +// relativeTo returns the path relative to the root for the file locator +func (d *mounts) relativeTo(path string) string { + if d.root == "/" { + return path + } + + return strings.TrimPrefix(path, d.root) +} diff --git a/internal/discover/mounts_test.go b/internal/discover/mounts_test.go index 6d8e3d20..ee5254ac 100644 --- a/internal/discover/mounts_test.go +++ b/internal/discover/mounts_test.go @@ -140,37 +140,14 @@ func TestMounts(t *testing.T) { input: &mounts{ lookup: &lookup.LocatorMock{ LocateFunc: func(s string) ([]string, error) { - return []string{"located"}, nil - }, - RelativeFunc: func(s string) (string, error) { - return "relative", nil + return []string{"/some/root/located"}, nil }, }, + root: "/some/root", required: []string{"required0", "multiple", "required1"}, }, expectedMounts: []Mount{ - {Path: "relative", HostPath: "located"}, - }, - }, - { - description: "mounts skips relative error", - input: &mounts{ - lookup: &lookup.LocatorMock{ - LocateFunc: func(s string) ([]string, error) { - return []string{s}, nil - }, - RelativeFunc: func(s string) (string, error) { - if s == "error" { - return "", fmt.Errorf("no relative path") - } - return "relative" + s, nil - }, - }, - required: []string{"required0", "error", "required1"}, - }, - expectedMounts: []Mount{ - {Path: "relativerequired0", HostPath: "required0"}, - {Path: "relativerequired1", HostPath: "required1"}, + {Path: "/located", HostPath: "/some/root/located"}, }, }, } diff --git a/internal/lookup/file.go b/internal/lookup/file.go index 1283f0ae..ab52d03d 100644 --- a/internal/lookup/file.go +++ b/internal/lookup/file.go @@ -28,7 +28,6 @@ import ( // prefixes. The validity of a file is determined by a filter function. type file struct { logger *log.Logger - root string prefixes []string filter func(string) error } @@ -78,15 +77,6 @@ func (p file) Locate(pattern string) ([]string, error) { return filenames, nil } -// Relative returns the path relative to the root for the file locator -func (p file) Relative(path string) (string, error) { - if p.root == "" || p.root == "/" { - return path, nil - } - - return filepath.Rel(p.root, path) -} - // assertFile checks whether the specified path is a regular file func assertFile(filename string) error { info, err := os.Stat(filename) diff --git a/internal/lookup/locator.go b/internal/lookup/locator.go index 76ade332..871e1b02 100644 --- a/internal/lookup/locator.go +++ b/internal/lookup/locator.go @@ -21,5 +21,4 @@ package lookup // Locator defines the interface for locating files on a system. type Locator interface { Locate(string) ([]string, error) - Relative(string) (string, error) }