Skip to content

Commit

Permalink
Guest agent support for partitions on SCSI devices
Browse files Browse the repository at this point in the history
* Update `ControllerLunToName` to `GetDevicePath` and take in partition
as an additional param
* Wait for partition subdirectory to appear for the devices
* Update device encryption and verity device names with partition index
* Update device encryption and verity device tests
* Add new unit tests for `GetDevicePath`

Signed-off-by: Kathryn Baldauf <kabaldau@microsoft.com>
  • Loading branch information
katiewasnothere committed May 1, 2023
1 parent 3d744da commit b21c405
Show file tree
Hide file tree
Showing 3 changed files with 302 additions and 47 deletions.
7 changes: 4 additions & 3 deletions internal/guest/runtime/hcsv2/uvm.go
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req *
if !mvd.ReadOnly {
localCtx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
source, err := scsi.ControllerLunToName(localCtx, mvd.Controller, mvd.Lun)
source, err := scsi.GetDevicePath(localCtx, mvd.Controller, mvd.Lun, mvd.Partition)
if err != nil {
return err
}
Expand Down Expand Up @@ -980,7 +980,7 @@ func modifyMappedVirtualDisk(
}
}

return scsi.Mount(mountCtx, mvd.Controller, mvd.Lun, mvd.MountPath,
return scsi.Mount(mountCtx, mvd.Controller, mvd.Lun, mvd.Partition, mvd.MountPath,
mvd.ReadOnly, mvd.Encrypted, mvd.Options, mvd.VerityInfo)
}
return nil
Expand All @@ -992,7 +992,8 @@ func modifyMappedVirtualDisk(
}
}

if err := scsi.Unmount(ctx, mvd.Controller, mvd.Lun, mvd.MountPath, mvd.Encrypted, mvd.VerityInfo); err != nil {
if err := scsi.Unmount(ctx, mvd.Controller, mvd.Lun, mvd.Partition,
mvd.MountPath, mvd.Encrypted, mvd.VerityInfo); err != nil {
return err
}
}
Expand Down
77 changes: 59 additions & 18 deletions internal/guest/storage/scsi/scsi.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,12 @@ var (
osRemoveAll = os.RemoveAll
unixMount = unix.Mount

// controllerLunToName is stubbed to make testing `Mount` easier.
controllerLunToName = ControllerLunToName
// mock functions for testing getDevicePath
osReadDir = os.ReadDir
osStat = os.Stat

// getDevicePath is stubbed to make testing `Mount` easier.
getDevicePath = GetDevicePath
// createVerityTarget is stubbed for unit testing `Mount`.
createVerityTarget = dm.CreateVerityTarget
// removeDevice is stubbed for unit testing `Mount`.
Expand All @@ -49,8 +53,8 @@ var (
const (
scsiDevicesPath = "/sys/bus/scsi/devices"
vmbusDevicesPath = "/sys/bus/vmbus/devices"
verityDeviceFmt = "dm-verity-scsi-contr%d-lun%d-%s"
cryptDeviceFmt = "dm-crypt-scsi-contr%d-lun%d"
verityDeviceFmt = "dm-verity-scsi-contr%d-lun%d-p%d-%s"
cryptDeviceFmt = "dm-crypt-scsi-contr%d-lun%d-p%d"
)

// ActualControllerNumber retrieves the actual controller number assigned to a SCSI controller
Expand Down Expand Up @@ -98,6 +102,7 @@ func Mount(
ctx context.Context,
controller,
lun uint8,
partition uint64,
target string,
readonly bool,
encrypted bool,
Expand All @@ -109,9 +114,11 @@ func Mount(

span.AddAttributes(
trace.Int64Attribute("controller", int64(controller)),
trace.Int64Attribute("lun", int64(lun)))
trace.Int64Attribute("lun", int64(lun)),
trace.Int64Attribute("partition", int64(partition)),
)

source, err := controllerLunToName(spnCtx, controller, lun)
source, err := getDevicePath(spnCtx, controller, lun, partition)
if err != nil {
return err
}
Expand All @@ -123,7 +130,7 @@ func Mount(
}

if verityInfo != nil {
dmVerityName := fmt.Sprintf(verityDeviceFmt, controller, lun, deviceHash)
dmVerityName := fmt.Sprintf(verityDeviceFmt, controller, lun, partition, deviceHash)
if source, err = createVerityTarget(spnCtx, source, dmVerityName, verityInfo); err != nil {
return err
}
Expand Down Expand Up @@ -156,7 +163,7 @@ func Mount(

mountType := "ext4"
if encrypted {
cryptDeviceName := fmt.Sprintf(cryptDeviceFmt, controller, lun)
cryptDeviceName := fmt.Sprintf(cryptDeviceFmt, controller, lun, partition)
encryptedSource, err := encryptDevice(spnCtx, source, cryptDeviceName)
if err != nil {
// todo (maksiman): add better retry logic, similar to how SCSI device mounts are
Expand All @@ -173,7 +180,7 @@ func Mount(

for {
if err := unixMount(source, target, mountType, flags, data); err != nil {
// The `source` found by controllerLunToName can take some time
// The `source` found by GetDevicePath can take some time
// before its actually available under `/dev/sd*`. Retry while we
// wait for `source` to show up.
if errors.Is(err, unix.ENOENT) || errors.Is(err, unix.ENXIO) {
Expand Down Expand Up @@ -210,6 +217,7 @@ func Unmount(
ctx context.Context,
controller,
lun uint8,
partition uint64,
target string,
encrypted bool,
verityInfo *guestresource.DeviceVerityInfo,
Expand All @@ -221,6 +229,7 @@ func Unmount(
span.AddAttributes(
trace.Int64Attribute("controller", int64(controller)),
trace.Int64Attribute("lun", int64(lun)),
trace.Int64Attribute("partition", int64(partition)),
trace.StringAttribute("target", target))

// unmount target
Expand All @@ -229,15 +238,15 @@ func Unmount(
}

if verityInfo != nil {
dmVerityName := fmt.Sprintf(verityDeviceFmt, controller, lun, verityInfo.RootDigest)
dmVerityName := fmt.Sprintf(verityDeviceFmt, controller, lun, partition, verityInfo.RootDigest)
if err := removeDevice(dmVerityName); err != nil {
// Ignore failures, since the path has been unmounted at this point.
log.G(ctx).WithError(err).Debugf("failed to remove dm verity target: %s", dmVerityName)
}
}

if encrypted {
dmCryptName := fmt.Sprintf(cryptDeviceFmt, controller, lun)
dmCryptName := fmt.Sprintf(cryptDeviceFmt, controller, lun, partition)
if err := cleanupCryptDevice(dmCryptName); err != nil {
return fmt.Errorf("failed to cleanup dm-crypt target %s: %w", dmCryptName, err)
}
Expand All @@ -246,24 +255,26 @@ func Unmount(
return nil
}

// ControllerLunToName finds the `/dev/sd*` path to the SCSI device on
// `controller` index `lun`.
func ControllerLunToName(ctx context.Context, controller, lun uint8) (_ string, err error) {
ctx, span := oc.StartSpan(ctx, "scsi::ControllerLunToName")
// GetDevicePath finds the `/dev/sd*` path to the SCSI device on `controller`
// index `lun` with partition index `partition`.
func GetDevicePath(ctx context.Context, controller, lun uint8, partition uint64) (_ string, err error) {
ctx, span := oc.StartSpan(ctx, "scsi::GetDevicePath")
defer span.End()
defer func() { oc.SetSpanStatus(span, err) }()

span.AddAttributes(
trace.Int64Attribute("controller", int64(controller)),
trace.Int64Attribute("lun", int64(lun)))
trace.Int64Attribute("lun", int64(lun)),
trace.Int64Attribute("partition", int64(partition)),
)

scsiID := fmt.Sprintf("%d:0:0:%d", controller, lun)
// Devices matching the given SCSI code should each have a subdirectory
// under /sys/bus/scsi/devices/<scsiID>/block.
blockPath := filepath.Join(scsiDevicesPath, scsiID, "block")
var deviceNames []os.DirEntry
for {
deviceNames, err = os.ReadDir(blockPath)
deviceNames, err = osReadDir(blockPath)
if err != nil && !os.IsNotExist(err) {
return "", err
}
Expand All @@ -282,8 +293,38 @@ func ControllerLunToName(ctx context.Context, controller, lun uint8) (_ string,
if len(deviceNames) > 1 {
return "", errors.Errorf("more than one block device could match SCSI ID \"%s\"", scsiID)
}
deviceName := deviceNames[0].Name()

// devices that have partitions have a subdirectory under
// /sys/bus/scsi/devices/<scsiID>/block/<deviceName> for each partition.
// Partitions use 1-based indexing, so if `partition` is 0, then we should
// return the device name without a partition index.
if partition != 0 {
partitionName := deviceName + fmt.Sprintf("%d", partition)
partitionPath := filepath.Join(blockPath, deviceName, partitionName)

// Wait for the device partition to show up
for {
fi, err := osStat(partitionPath)
if err != nil && !os.IsNotExist(err) {
return "", err
} else if fi == nil {
// if the fileinfo is nil that means we didn't find the device, keep
// trying until the context is done or the device path shows up
select {
case <-ctx.Done():
return "", ctx.Err()
default:
time.Sleep(time.Millisecond * 10)
continue
}
}
break
}
deviceName = partitionName
}

devicePath := filepath.Join("/dev", deviceNames[0].Name())
devicePath := filepath.Join("/dev", deviceName)
log.G(ctx).WithField("devicePath", devicePath).Debug("found device path")
return devicePath, nil
}
Expand Down
Loading

0 comments on commit b21c405

Please sign in to comment.