Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify zeroDevice to just zero first block #1672

Merged
merged 3 commits into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 14 additions & 23 deletions internal/guest/storage/crypt/crypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@ import (

// Test dependencies
var (
_cryptsetupClose = cryptsetupClose
_cryptsetupFormat = cryptsetupFormat
_cryptsetupOpen = cryptsetupOpen
_generateKeyFile = generateKeyFile
_getBlockDeviceSize = getBlockDeviceSize
_osMkdirTemp = os.MkdirTemp
_mkfsXfs = mkfsXfs
_zeroDevice = zeroDevice
_osRemoveAll = os.RemoveAll
_cryptsetupClose = cryptsetupClose
_cryptsetupFormat = cryptsetupFormat
_cryptsetupOpen = cryptsetupOpen
_generateKeyFile = generateKeyFile
_osMkdirTemp = os.MkdirTemp
_mkfsXfs = mkfsXfs
_osRemoveAll = os.RemoveAll
_zeroFirstBlock = zeroFirstBlock
)

// cryptsetupCommand runs cryptsetup with the provided arguments
Expand Down Expand Up @@ -165,20 +164,12 @@ func EncryptDevice(ctx context.Context, source string, dmCryptName string) (path
}()

deviceNamePath := "/dev/mapper/" + dmCryptName

// Get actual size of the scratch device
deviceSize, err := _getBlockDeviceSize(ctx, deviceNamePath)
if err != nil {
return "", fmt.Errorf("error getting size of: %s: %w", deviceNamePath, err)
}

if deviceSize == 0 {
return "", fmt.Errorf("invalid size obtained for: %s", deviceNamePath)
}

// 4.1. Zero the first block. It appears that mkfs.xfs reads this before formatting.
if err = _zeroDevice(deviceNamePath, 4096, 1); err != nil {
return "", fmt.Errorf("failed zero'ing start of device %s: %w", deviceNamePath, err)
// 4.1. Zero the first block.
// In the xfs mkfs case it appears to attempt to read the first block of the device.
// This results in an integrity error. This function zeros out the start of the device,
// so we are sure that when it is read it has already been hashed so matches.
if err := _zeroFirstBlock(deviceNamePath, 4096); err != nil {
return "", fmt.Errorf("failed to zero first block: %w", err)
}

// 4.2. Format it as xfs
Expand Down
76 changes: 5 additions & 71 deletions internal/guest/storage/crypt/crypt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@ package crypt

import (
"context"
"fmt"
"testing"

"github.com/Microsoft/hcsshim/internal/memory"
"github.com/pkg/errors"
)

Expand All @@ -23,9 +21,9 @@ func clearCryptTestDependencies() {
_cryptsetupFormat = nil
_cryptsetupOpen = nil
_generateKeyFile = nil
_getBlockDeviceSize = nil
_osMkdirTemp = osMkdirTempTest
_osRemoveAll = nil
_zeroFirstBlock = nil
}

func Test_Encrypt_Generate_Key_Error(t *testing.T) {
Expand Down Expand Up @@ -127,63 +125,11 @@ func Test_Encrypt_Cryptsetup_Open_Error(t *testing.T) {
}
}

func Test_Encrypt_Get_Device_Size_Error(t *testing.T) {
clearCryptTestDependencies()

// Test what happens when cryptsetup fails to get the size of the
// unencrypted block device.

_generateKeyFile = func(path string, size int64) error {
return nil
}
_osRemoveAll = func(path string) error {
return nil
}
_cryptsetupFormat = func(source string, keyFilePath string) error {
return nil
}
_cryptsetupOpen = func(source string, deviceName string, keyFilePath string) error {
return nil
}
_cryptsetupClose = func(deviceName string) error {
return nil
}

source := "/dev/sda"
dmCryptName := "dm-crypt-target"
deviceNamePath := "/dev/mapper/" + dmCryptName

expectedErr := errors.New("expected error message")
_getBlockDeviceSize = func(ctx context.Context, path string) (int64, error) {
return 0, expectedErr
}

_, err := EncryptDevice(context.Background(), source, dmCryptName)
if errors.Unwrap(err) != expectedErr {
t.Fatalf("expected err: '%v' got: '%v'", expectedErr, err)
}

// Check that it fails when the size of the block device is zero

expectedErr = fmt.Errorf("invalid size obtained for: %s", deviceNamePath)
_getBlockDeviceSize = func(ctx context.Context, path string) (int64, error) {
return 0, nil
}

_, err = EncryptDevice(context.Background(), source, dmCryptName)
if err.Error() != expectedErr.Error() {
t.Fatalf("expected err: '%v' got: '%v'", expectedErr, err)
}
}

func Test_Encrypt_Mkfs_Error(t *testing.T) {
clearCryptTestDependencies()

// Test what happens when mkfs fails to format the unencrypted device.
// Verify that the arguments passed to it are the right ones.

blockDeviceSize := int64(memory.GiB)

_generateKeyFile = func(path string, size int64) error {
return nil
}
Expand All @@ -199,14 +145,7 @@ func Test_Encrypt_Mkfs_Error(t *testing.T) {
_cryptsetupClose = func(deviceName string) error {
return nil
}
_getBlockDeviceSize = func(ctx context.Context, path string) (int64, error) {
// Return a non-zero size
return blockDeviceSize, nil
}
_mkfsXfs = func(string) error {
return nil
}
_zeroDevice = func(string, int64, int64) error {
_zeroFirstBlock = func(_ string, _ int) error {
return nil
}

Expand All @@ -221,8 +160,7 @@ func Test_Encrypt_Mkfs_Error(t *testing.T) {
return expectedErr
}

_, err := EncryptDevice(context.Background(), source, "dm-crypt-name")
if errors.Unwrap(err) != expectedErr {
if _, err := EncryptDevice(context.Background(), source, "dm-crypt-name"); errors.Unwrap(err) != expectedErr {
t.Fatalf("expected err: '%v' got: '%v'", expectedErr, err)
}
}
Expand All @@ -231,9 +169,6 @@ func Test_Encrypt_Success(t *testing.T) {
clearCryptTestDependencies()

// Test what happens when everything goes right.

blockDeviceSize := int64(memory.GiB)

_generateKeyFile = func(path string, size int64) error {
return nil
}
Expand All @@ -246,9 +181,8 @@ func Test_Encrypt_Success(t *testing.T) {
_cryptsetupOpen = func(source string, deviceName string, keyFilePath string) error {
return nil
}
_getBlockDeviceSize = func(ctx context.Context, path string) (int64, error) {
// Return a non-zero size
return blockDeviceSize, nil
_zeroFirstBlock = func(_ string, _ int) error {
return nil
}
_mkfsXfs = func(arg string) error {
return nil
Expand Down
65 changes: 13 additions & 52 deletions internal/guest/storage/crypt/utilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,75 +4,36 @@
package crypt

import (
"context"
"bytes"
"crypto/rand"
"fmt"
"io"
"os"

"github.com/Microsoft/hcsshim/internal/log"
)

// getBlockDeviceSize returns the size of the specified block device.
func getBlockDeviceSize(ctx context.Context, path string) (int64, error) {
file, err := os.Open(path)
if err != nil {
return 0, fmt.Errorf("error opening %s: %w", path, err)
}

defer func() {
if err := file.Close(); err != nil {
log.G(ctx).WithError(err).Debug("error closing: " + path)
}
}()

pos, err := file.Seek(0, io.SeekEnd)
if err != nil {
return 0, fmt.Errorf("error seeking end of %s: %w", path, err)
}
return pos, nil
}

// In the xfs mkfs case it appears to attempt to read the first block of the device.
// This results in an integrity error. This function zeros out the start of the device
// so we are sure that when it is read it has already been hashed so matches.
func zeroDevice(devicePath string, blockSize int64, numberOfBlocks int64) error {
fout, err := os.OpenFile(devicePath, os.O_WRONLY, 0)
func zeroFirstBlock(path string, blockSize int) error {
fout, err := os.OpenFile(path, os.O_WRONLY, 0)
if err != nil {
return fmt.Errorf("failed to open device file %s: %w", devicePath, err)
return fmt.Errorf("failed to open file for zero'ing: %w", err)
}
defer fout.Close()

zeros := make([]byte, blockSize)
for i := range zeros {
zeros[i] = 0
}

// get the size so we don't overrun the end of the device
foutSize, err := fout.Seek(0, io.SeekEnd)
size, err := fout.Seek(0, io.SeekEnd)
if err != nil {
return fmt.Errorf("zeroDevice: failed to seek to end, device file %s: %w", devicePath, err)
return fmt.Errorf("error seeking end of %s: %w", path, err)
}
if size < int64(blockSize) {
return fmt.Errorf("file size is smaller than minimum expected: %d < %d", size, blockSize)
}

// move back to the front.
_, err = fout.Seek(0, io.SeekStart)
if err != nil {
return fmt.Errorf("zeroDevice: failed to seek to start, device file %s: %w", devicePath, err)
return fmt.Errorf("error seeking start of %s: %w", path, err)
}

var offset int64 = 0
var which int64
for which = 0; which < numberOfBlocks; which++ {
// Exit when the end of the file is reached
if offset >= foutSize {
break
}
// Write data to destination file
written, err := fout.Write(zeros)
if err != nil {
return fmt.Errorf("failed to write destination file %s offset %d: %w", devicePath, offset, err)
}
offset += int64(written)
zeros := bytes.Repeat([]byte{0}, blockSize)
if _, err := fout.Write(zeros); err != nil {
return fmt.Errorf("failed to zero-out bytes: %w", err)
}
return nil
}
Expand Down