Skip to content

Commit

Permalink
simplify zeroDevice to just zero first block (microsoft#1672)
Browse files Browse the repository at this point in the history
Signed-off-by: Maksim An <maksiman@microsoft.com>
  • Loading branch information
anmaxvl authored Feb 28, 2023
1 parent 411a183 commit 38a2b19
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 146 deletions.
37 changes: 14 additions & 23 deletions internal/guest/storage/crypt/crypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@ import (

// Test dependencies
var (
_cryptsetupClose = cryptsetupClose
_cryptsetupFormat = cryptsetupFormat
_cryptsetupOpen = cryptsetupOpen
_generateKeyFile = generateKeyFile
_getBlockDeviceSize = getBlockDeviceSize
_osMkdirTemp = os.MkdirTemp
_mkfsXfs = mkfsXfs
_zeroDevice = zeroDevice
_osRemoveAll = os.RemoveAll
_cryptsetupClose = cryptsetupClose
_cryptsetupFormat = cryptsetupFormat
_cryptsetupOpen = cryptsetupOpen
_generateKeyFile = generateKeyFile
_osMkdirTemp = os.MkdirTemp
_mkfsXfs = mkfsXfs
_osRemoveAll = os.RemoveAll
_zeroFirstBlock = zeroFirstBlock
)

// cryptsetupCommand runs cryptsetup with the provided arguments
Expand Down Expand Up @@ -165,20 +164,12 @@ func EncryptDevice(ctx context.Context, source string, dmCryptName string) (path
}()

deviceNamePath := "/dev/mapper/" + dmCryptName

// Get actual size of the scratch device
deviceSize, err := _getBlockDeviceSize(ctx, deviceNamePath)
if err != nil {
return "", fmt.Errorf("error getting size of: %s: %w", deviceNamePath, err)
}

if deviceSize == 0 {
return "", fmt.Errorf("invalid size obtained for: %s", deviceNamePath)
}

// 4.1. Zero the first block. It appears that mkfs.xfs reads this before formatting.
if err = _zeroDevice(deviceNamePath, 4096, 1); err != nil {
return "", fmt.Errorf("failed zero'ing start of device %s: %w", deviceNamePath, err)
// 4.1. Zero the first block.
// In the xfs mkfs case it appears to attempt to read the first block of the device.
// This results in an integrity error. This function zeros out the start of the device,
// so we are sure that when it is read it has already been hashed so matches.
if err := _zeroFirstBlock(deviceNamePath, 4096); err != nil {
return "", fmt.Errorf("failed to zero first block: %w", err)
}

// 4.2. Format it as xfs
Expand Down
76 changes: 5 additions & 71 deletions internal/guest/storage/crypt/crypt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@ package crypt

import (
"context"
"fmt"
"testing"

"github.com/Microsoft/hcsshim/internal/memory"
"github.com/pkg/errors"
)

Expand All @@ -23,9 +21,9 @@ func clearCryptTestDependencies() {
_cryptsetupFormat = nil
_cryptsetupOpen = nil
_generateKeyFile = nil
_getBlockDeviceSize = nil
_osMkdirTemp = osMkdirTempTest
_osRemoveAll = nil
_zeroFirstBlock = nil
}

func Test_Encrypt_Generate_Key_Error(t *testing.T) {
Expand Down Expand Up @@ -127,63 +125,11 @@ func Test_Encrypt_Cryptsetup_Open_Error(t *testing.T) {
}
}

func Test_Encrypt_Get_Device_Size_Error(t *testing.T) {
clearCryptTestDependencies()

// Test what happens when cryptsetup fails to get the size of the
// unencrypted block device.

_generateKeyFile = func(path string, size int64) error {
return nil
}
_osRemoveAll = func(path string) error {
return nil
}
_cryptsetupFormat = func(source string, keyFilePath string) error {
return nil
}
_cryptsetupOpen = func(source string, deviceName string, keyFilePath string) error {
return nil
}
_cryptsetupClose = func(deviceName string) error {
return nil
}

source := "/dev/sda"
dmCryptName := "dm-crypt-target"
deviceNamePath := "/dev/mapper/" + dmCryptName

expectedErr := errors.New("expected error message")
_getBlockDeviceSize = func(ctx context.Context, path string) (int64, error) {
return 0, expectedErr
}

_, err := EncryptDevice(context.Background(), source, dmCryptName)
if errors.Unwrap(err) != expectedErr {
t.Fatalf("expected err: '%v' got: '%v'", expectedErr, err)
}

// Check that it fails when the size of the block device is zero

expectedErr = fmt.Errorf("invalid size obtained for: %s", deviceNamePath)
_getBlockDeviceSize = func(ctx context.Context, path string) (int64, error) {
return 0, nil
}

_, err = EncryptDevice(context.Background(), source, dmCryptName)
if err.Error() != expectedErr.Error() {
t.Fatalf("expected err: '%v' got: '%v'", expectedErr, err)
}
}

func Test_Encrypt_Mkfs_Error(t *testing.T) {
clearCryptTestDependencies()

// Test what happens when mkfs fails to format the unencrypted device.
// Verify that the arguments passed to it are the right ones.

blockDeviceSize := int64(memory.GiB)

_generateKeyFile = func(path string, size int64) error {
return nil
}
Expand All @@ -199,14 +145,7 @@ func Test_Encrypt_Mkfs_Error(t *testing.T) {
_cryptsetupClose = func(deviceName string) error {
return nil
}
_getBlockDeviceSize = func(ctx context.Context, path string) (int64, error) {
// Return a non-zero size
return blockDeviceSize, nil
}
_mkfsXfs = func(string) error {
return nil
}
_zeroDevice = func(string, int64, int64) error {
_zeroFirstBlock = func(_ string, _ int) error {
return nil
}

Expand All @@ -221,8 +160,7 @@ func Test_Encrypt_Mkfs_Error(t *testing.T) {
return expectedErr
}

_, err := EncryptDevice(context.Background(), source, "dm-crypt-name")
if errors.Unwrap(err) != expectedErr {
if _, err := EncryptDevice(context.Background(), source, "dm-crypt-name"); errors.Unwrap(err) != expectedErr {
t.Fatalf("expected err: '%v' got: '%v'", expectedErr, err)
}
}
Expand All @@ -231,9 +169,6 @@ func Test_Encrypt_Success(t *testing.T) {
clearCryptTestDependencies()

// Test what happens when everything goes right.

blockDeviceSize := int64(memory.GiB)

_generateKeyFile = func(path string, size int64) error {
return nil
}
Expand All @@ -246,9 +181,8 @@ func Test_Encrypt_Success(t *testing.T) {
_cryptsetupOpen = func(source string, deviceName string, keyFilePath string) error {
return nil
}
_getBlockDeviceSize = func(ctx context.Context, path string) (int64, error) {
// Return a non-zero size
return blockDeviceSize, nil
_zeroFirstBlock = func(_ string, _ int) error {
return nil
}
_mkfsXfs = func(arg string) error {
return nil
Expand Down
65 changes: 13 additions & 52 deletions internal/guest/storage/crypt/utilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,75 +4,36 @@
package crypt

import (
"context"
"bytes"
"crypto/rand"
"fmt"
"io"
"os"

"github.com/Microsoft/hcsshim/internal/log"
)

// getBlockDeviceSize returns the size of the specified block device.
func getBlockDeviceSize(ctx context.Context, path string) (int64, error) {
file, err := os.Open(path)
if err != nil {
return 0, fmt.Errorf("error opening %s: %w", path, err)
}

defer func() {
if err := file.Close(); err != nil {
log.G(ctx).WithError(err).Debug("error closing: " + path)
}
}()

pos, err := file.Seek(0, io.SeekEnd)
if err != nil {
return 0, fmt.Errorf("error seeking end of %s: %w", path, err)
}
return pos, nil
}

// In the xfs mkfs case it appears to attempt to read the first block of the device.
// This results in an integrity error. This function zeros out the start of the device
// so we are sure that when it is read it has already been hashed so matches.
func zeroDevice(devicePath string, blockSize int64, numberOfBlocks int64) error {
fout, err := os.OpenFile(devicePath, os.O_WRONLY, 0)
func zeroFirstBlock(path string, blockSize int) error {
fout, err := os.OpenFile(path, os.O_WRONLY, 0)
if err != nil {
return fmt.Errorf("failed to open device file %s: %w", devicePath, err)
return fmt.Errorf("failed to open file for zero'ing: %w", err)
}
defer fout.Close()

zeros := make([]byte, blockSize)
for i := range zeros {
zeros[i] = 0
}

// get the size so we don't overrun the end of the device
foutSize, err := fout.Seek(0, io.SeekEnd)
size, err := fout.Seek(0, io.SeekEnd)
if err != nil {
return fmt.Errorf("zeroDevice: failed to seek to end, device file %s: %w", devicePath, err)
return fmt.Errorf("error seeking end of %s: %w", path, err)
}
if size < int64(blockSize) {
return fmt.Errorf("file size is smaller than minimum expected: %d < %d", size, blockSize)
}

// move back to the front.
_, err = fout.Seek(0, io.SeekStart)
if err != nil {
return fmt.Errorf("zeroDevice: failed to seek to start, device file %s: %w", devicePath, err)
return fmt.Errorf("error seeking start of %s: %w", path, err)
}

var offset int64 = 0
var which int64
for which = 0; which < numberOfBlocks; which++ {
// Exit when the end of the file is reached
if offset >= foutSize {
break
}
// Write data to destination file
written, err := fout.Write(zeros)
if err != nil {
return fmt.Errorf("failed to write destination file %s offset %d: %w", devicePath, offset, err)
}
offset += int64(written)
zeros := bytes.Repeat([]byte{0}, blockSize)
if _, err := fout.Write(zeros); err != nil {
return fmt.Errorf("failed to zero-out bytes: %w", err)
}
return nil
}
Expand Down

0 comments on commit 38a2b19

Please sign in to comment.