Skip to content

Commit

Permalink
Add nvpci.Interface to the nvmdev struct to aid in unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Christopher Desiniotis <cdesiniotis@nvidia.com>
  • Loading branch information
cdesiniotis committed Jul 15, 2024
1 parent 44a5440 commit 7374add
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 11 deletions.
28 changes: 25 additions & 3 deletions pkg/nvmdev/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
// MockNvmdev mock mdev device.
type MockNvmdev struct {
*nvmdev
pciDevicesRoot string
}

var _ Interface = (*MockNvmdev)(nil)
Expand All @@ -53,8 +54,24 @@ func NewMock() (mock *MockNvmdev, rerr error) {
}
}()

pciRootDir, err := os.MkdirTemp(os.TempDir(), "")
if err != nil {
return nil, err
}
defer func() {
if rerr != nil {
os.RemoveAll(pciRootDir)
}
}()

nvpciLib := nvpci.New(nvpci.WithPCIDevicesRoot(pciRootDir))
mock = &MockNvmdev{
&nvmdev{mdevParentsRootDir, mdevDevicesRootDir},
nvmdev: &nvmdev{
mdevParentsRoot: mdevParentsRootDir,
mdevDevicesRoot: mdevDevicesRootDir,
nvpci: nvpciLib,
},
pciDevicesRoot: pciRootDir,
}

return mock, nil
Expand All @@ -64,16 +81,21 @@ func NewMock() (mock *MockNvmdev, rerr error) {
func (m *MockNvmdev) Cleanup() {
os.RemoveAll(m.mdevParentsRoot)
os.RemoveAll(m.mdevDevicesRoot)
os.RemoveAll(m.pciDevicesRoot)
}

// AddMockA100Parent creates an A100 like parent GPU mock device.
func (m *MockNvmdev) AddMockA100Parent(address string, numaNode int) error {
deviceDir := filepath.Join(m.mdevParentsRoot, address)
err := os.MkdirAll(deviceDir, 0755)
pciDeviceDir := filepath.Join(m.pciDevicesRoot, address)
err := os.MkdirAll(pciDeviceDir, 0755)
if err != nil {
return err
}

// /sys/class/mdev_bus/<address> is a symlink to /sys/bus/pci/devices/<address>
deviceDir := filepath.Join(m.mdevParentsRoot, address)
os.Symlink(pciDeviceDir, deviceDir)

Check failure on line 97 in pkg/nvmdev/mock.go

View workflow job for this annotation

GitHub Actions / check

Error return value of `os.Symlink` is not checked (errcheck)

vendor, err := os.Create(filepath.Join(deviceDir, "vendor"))
if err != nil {
return err
Expand Down
34 changes: 26 additions & 8 deletions pkg/nvmdev/nvmdev.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ type Interface interface {
type nvmdev struct {
mdevParentsRoot string
mdevDevicesRoot string
nvpci nvpci.Interface
}

var _ Interface = (*nvmdev)(nil)
Expand All @@ -63,8 +64,25 @@ type Device struct {
}

// New interface that allows us to get a list of all NVIDIA parent and MDEV (vGPU) devices.
func New() Interface {
return &nvmdev{mdevParentsRoot, mdevDevicesRoot}
func New(opts ...Option) Interface {
n := &nvmdev{mdevParentsRoot: mdevParentsRoot, mdevDevicesRoot: mdevDevicesRoot}
for _, opt := range opts {
opt(n)
}
if n.nvpci == nil {
n.nvpci = nvpci.New()
}
return n
}

// Option defines a function for passing options to the New() call.
type Option func(*nvmdev)

// WithNvpciLib provides an Option to set the nvpci library.
func WithNvpciLib(nvpciLib nvpci.Interface) Option {
return func(n *nvmdev) {
n.nvpci = nvpciLib
}
}

// GetAllParentDevices returns all NVIDIA Parent PCI devices on the system.
Expand All @@ -77,7 +95,7 @@ func (m *nvmdev) GetAllParentDevices() ([]*ParentDevice, error) {
var nvdevices []*ParentDevice
for _, deviceDir := range deviceDirs {
devicePath := path.Join(m.mdevParentsRoot, deviceDir.Name())
nvdevice, err := NewParentDevice(devicePath)
nvdevice, err := m.NewParentDevice(devicePath)
if err != nil {
return nil, fmt.Errorf("error constructing NVIDIA parent device: %v", err)
}
Expand Down Expand Up @@ -110,7 +128,7 @@ func (m *nvmdev) GetAllDevices() ([]*Device, error) {

var nvdevices []*Device
for _, deviceDir := range deviceDirs {
nvdevice, err := NewDevice(m.mdevDevicesRoot, deviceDir.Name())
nvdevice, err := m.NewDevice(m.mdevDevicesRoot, deviceDir.Name())
if err != nil {
return nil, fmt.Errorf("error constructing MDEV device: %v", err)
}
Expand All @@ -124,15 +142,15 @@ func (m *nvmdev) GetAllDevices() ([]*Device, error) {
}

// NewDevice constructs a Device, which represents an NVIDIA mdev (vGPU) device.
func NewDevice(root string, uuid string) (*Device, error) {
func (n *nvmdev) NewDevice(root string, uuid string) (*Device, error) {
path := path.Join(root, uuid)

m, err := newMdev(path)
if err != nil {
return nil, err
}

parent, err := NewParentDevice(m.parentDevicePath())
parent, err := n.NewParentDevice(m.parentDevicePath())
if err != nil {
return nil, fmt.Errorf("error constructing NVIDIA PCI device: %v", err)
}
Expand Down Expand Up @@ -241,9 +259,9 @@ func (m mdev) iommuGroup() (int, error) {
}

// NewParentDevice constructs a ParentDevice.
func NewParentDevice(devicePath string) (*ParentDevice, error) {
func (m *nvmdev) NewParentDevice(devicePath string) (*ParentDevice, error) {
address := filepath.Base(devicePath)
nvdevice, err := nvpci.New().GetGPUByPciBusID(address)
nvdevice, err := m.nvpci.GetGPUByPciBusID(address)
if err != nil {
return nil, fmt.Errorf("failed to construct NVIDIA PCI device: %v", err)
}
Expand Down

0 comments on commit 7374add

Please sign in to comment.