Skip to content

Commit

Permalink
Merge pull request #7088 from liortamari/s3_backup_server_side_encryp…
Browse files Browse the repository at this point in the history
…tion_with_customer_provided_key

Add s3 server-side encryption and decryption with customer provided key
  • Loading branch information
deepthi authored Dec 3, 2020
2 parents bebcd07 + d25a7c1 commit 6454b7d
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 9 deletions.
69 changes: 61 additions & 8 deletions go/vt/mysqlctl/s3backupstorage/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@ limitations under the License.
package s3backupstorage

import (
"crypto/md5"
"crypto/tls"
"encoding/base64"
"flag"
"fmt"
"io"
"io/ioutil"
"math"
"net/http"
"sort"
Expand Down Expand Up @@ -74,7 +77,7 @@ var (
requiredLogLevel = flag.String("s3_backup_log_level", "LogOff", "determine the S3 loglevel to use from LogOff, LogDebug, LogDebugWithSigning, LogDebugWithHTTPBody, LogDebugWithRequestRetries, LogDebugWithRequestErrors")

// sse is the server-side encryption algorithm used when storing this object in S3
sse = flag.String("s3_backup_server_side_encryption", "", "server-side encryption algorithm (e.g., AES256, aws:kms)")
sse = flag.String("s3_backup_server_side_encryption", "", "server-side encryption algorithm (e.g., AES256, aws:kms, sse_c:/path/to/key/file)")

// path component delimiter
delimiter = "/"
Expand All @@ -84,6 +87,8 @@ type logNameToLogLevel map[string]aws.LogLevelType

var logNameMap logNameToLogLevel

const sseCustomerPrefix = "sse_c:"

// S3BackupHandle implements the backupstorage.BackupHandle interface.
type S3BackupHandle struct {
client s3iface.S3API
Expand Down Expand Up @@ -147,15 +152,14 @@ func (bh *S3BackupHandle) AddFile(ctx context.Context, filename string, filesize
})
object := objName(bh.dir, bh.name, filename)

var sseOption *string
if *sse != "" {
sseOption = sse
}
_, err := uploader.Upload(&s3manager.UploadInput{
Bucket: bucket,
Key: object,
Body: reader,
ServerSideEncryption: sseOption,
ServerSideEncryption: bh.bs.s3SSE.awsAlg,
SSECustomerAlgorithm: bh.bs.s3SSE.customerAlg,
SSECustomerKey: bh.bs.s3SSE.customerKey,
SSECustomerKeyMD5: bh.bs.s3SSE.customerMd5,
})
if err != nil {
reader.CloseWithError(err)
Expand Down Expand Up @@ -190,8 +194,11 @@ func (bh *S3BackupHandle) ReadFile(ctx context.Context, filename string) (io.Rea
}
object := objName(bh.dir, bh.name, filename)
out, err := bh.client.GetObject(&s3.GetObjectInput{
Bucket: bucket,
Key: object,
Bucket: bucket,
Key: object,
SSECustomerAlgorithm: bh.bs.s3SSE.customerAlg,
SSECustomerKey: bh.bs.s3SSE.customerKey,
SSECustomerKeyMD5: bh.bs.s3SSE.customerMd5,
})
if err != nil {
return nil, err
Expand All @@ -201,10 +208,51 @@ func (bh *S3BackupHandle) ReadFile(ctx context.Context, filename string) (io.Rea

var _ backupstorage.BackupHandle = (*S3BackupHandle)(nil)

type S3ServerSideEncryption struct {
awsAlg *string
customerAlg *string
customerKey *string
customerMd5 *string
}

func (s3ServerSideEncryption *S3ServerSideEncryption) init() error {
s3ServerSideEncryption.reset()

if strings.HasPrefix(*sse, sseCustomerPrefix) {
sseCustomerKeyFile := strings.TrimPrefix(*sse, sseCustomerPrefix)
base64CodedKey, err := ioutil.ReadFile(sseCustomerKeyFile)
if err != nil {
log.Errorf(err.Error())
return err
}

decodedKey, err := base64.StdEncoding.DecodeString(string(base64CodedKey))
if err != nil {
decodedKey = base64CodedKey
}

md5Hash := md5.Sum(decodedKey)
s3ServerSideEncryption.customerAlg = aws.String("AES256")
s3ServerSideEncryption.customerKey = aws.String(string(decodedKey))
s3ServerSideEncryption.customerMd5 = aws.String(base64.StdEncoding.EncodeToString(md5Hash[:]))
} else if *sse != "" {
s3ServerSideEncryption.awsAlg = sse
}
return nil
}

func (s3ServerSideEncryption *S3ServerSideEncryption) reset() {
s3ServerSideEncryption.awsAlg = nil
s3ServerSideEncryption.customerAlg = nil
s3ServerSideEncryption.customerKey = nil
s3ServerSideEncryption.customerMd5 = nil
}

// S3BackupStorage implements the backupstorage.BackupStorage interface.
type S3BackupStorage struct {
_client *s3.S3
mu sync.Mutex
s3SSE S3ServerSideEncryption
}

// ListBackups is part of the backupstorage.BackupStorage interface.
Expand Down Expand Up @@ -339,6 +387,7 @@ func (bs *S3BackupStorage) Close() error {
bs.mu.Lock()
defer bs.mu.Unlock()
bs._client = nil
bs.s3SSE.reset()
return nil
}

Expand Down Expand Up @@ -394,6 +443,10 @@ func (bs *S3BackupStorage) client() (*s3.S3, error) {
if _, err := bs._client.HeadBucket(&s3.HeadBucketInput{Bucket: bucket}); err != nil {
return nil, err
}

if err := bs.s3SSE.init(); err != nil {
return nil, err
}
}
return bs._client, nil
}
Expand Down
126 changes: 125 additions & 1 deletion go/vt/mysqlctl/s3backupstorage/s3_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
package s3backupstorage

import (
"crypto/md5"
"crypto/rand"
"encoding/base64"
"errors"
"io/ioutil"
"net/http"
"os"
"testing"

"github.com/aws/aws-sdk-go/aws"
Expand All @@ -25,7 +30,7 @@ func (s3errclient *s3ErrorClient) PutObjectRequest(in *s3.PutObjectInput) (*requ
}

func TestAddFileError(t *testing.T) {
bh := &S3BackupHandle{client: &s3ErrorClient{}, readOnly: false}
bh := &S3BackupHandle{client: &s3ErrorClient{}, bs: &S3BackupStorage{}, readOnly: false}

wc, err := bh.AddFile(aws.BackgroundContext(), "somefile", 100000)
require.NoErrorf(t, err, "AddFile() expected no error, got %s", err)
Expand All @@ -41,3 +46,122 @@ func TestAddFileError(t *testing.T) {

require.Equal(t, bh.HasErrors(), true, "AddFile() expected bh to record async error but did not")
}

func TestNoSSE(t *testing.T) {
sseData := S3ServerSideEncryption{}
err := sseData.init()
require.NoErrorf(t, err, "init() expected to succeed")

assert.Nil(t, sseData.awsAlg, "awsAlg expected to be nil")
assert.Nil(t, sseData.customerAlg, "customerAlg expected to be nil")
assert.Nil(t, sseData.customerKey, "customerKey expected to be nil")
assert.Nil(t, sseData.customerMd5, "customerMd5 expected to be nil")

sseData.reset()
require.NoErrorf(t, err, "reset() expected to succeed")
}

func TestSSEAws(t *testing.T) {
sse = aws.String("aws:kms")
sseData := S3ServerSideEncryption{}
err := sseData.init()
require.NoErrorf(t, err, "init() expected to succeed")

assert.Equal(t, aws.String("aws:kms"), sseData.awsAlg, "awsAlg expected to be aws:kms")
assert.Nil(t, sseData.customerAlg, "customerAlg expected to be nil")
assert.Nil(t, sseData.customerKey, "customerKey expected to be nil")
assert.Nil(t, sseData.customerMd5, "customerMd5 expected to be nil")

sseData.reset()
require.NoErrorf(t, err, "reset() expected to succeed")

assert.Nil(t, sseData.awsAlg, "awsAlg expected to be nil")
assert.Nil(t, sseData.customerAlg, "customerAlg expected to be nil")
assert.Nil(t, sseData.customerKey, "customerKey expected to be nil")
assert.Nil(t, sseData.customerMd5, "customerMd5 expected to be nil")
}

func TestSSECustomerFileNotFound(t *testing.T) {
tempFile, err := ioutil.TempFile("", "filename")
require.NoErrorf(t, err, "TempFile() expected to succeed")
defer os.Remove(tempFile.Name())

err = tempFile.Close()
require.NoErrorf(t, err, "Close() expected to succeed")

err = os.Remove(tempFile.Name())
require.NoErrorf(t, err, "Remove() expected to succeed")

sse = aws.String(sseCustomerPrefix + tempFile.Name())
sseData := S3ServerSideEncryption{}
err = sseData.init()
require.Errorf(t, err, "init() expected to fail")
}

func TestSSECustomerFileBinaryKey(t *testing.T) {
tempFile, err := ioutil.TempFile("", "filename")
require.NoErrorf(t, err, "TempFile() expected to succeed")
defer os.Remove(tempFile.Name())

randomKey := make([]byte, 32)
_, err = rand.Read(randomKey)
require.NoErrorf(t, err, "Read() expected to succeed")
_, err = tempFile.Write(randomKey)
require.NoErrorf(t, err, "Write() expected to succeed")
err = tempFile.Close()
require.NoErrorf(t, err, "Close() expected to succeed")

sse = aws.String(sseCustomerPrefix + tempFile.Name())
sseData := S3ServerSideEncryption{}
err = sseData.init()
require.NoErrorf(t, err, "init() expected to succeed")

assert.Nil(t, sseData.awsAlg, "awsAlg expected to be nil")
assert.Equal(t, aws.String("AES256"), sseData.customerAlg, "customerAlg expected to be AES256")
assert.Equal(t, aws.String(string(randomKey)), sseData.customerKey, "customerKey expected to be equal to the generated randomKey")
md5Hash := md5.Sum(randomKey)
assert.Equal(t, aws.String(base64.StdEncoding.EncodeToString(md5Hash[:])), sseData.customerMd5, "customerMd5 expected to be equal to the customerMd5 hash of the generated randomKey")

sseData.reset()
require.NoErrorf(t, err, "reset() expected to succeed")

assert.Nil(t, sseData.awsAlg, "awsAlg expected to be nil")
assert.Nil(t, sseData.customerAlg, "customerAlg expected to be nil")
assert.Nil(t, sseData.customerKey, "customerKey expected to be nil")
assert.Nil(t, sseData.customerMd5, "customerMd5 expected to be nil")
}

func TestSSECustomerFileBase64Key(t *testing.T) {
tempFile, err := ioutil.TempFile("", "filename")
require.NoErrorf(t, err, "TempFile() expected to succeed")
defer os.Remove(tempFile.Name())

randomKey := make([]byte, 32)
_, err = rand.Read(randomKey)
require.NoErrorf(t, err, "Read() expected to succeed")

base64Key := base64.StdEncoding.EncodeToString(randomKey[:])
_, err = tempFile.WriteString(base64Key)
require.NoErrorf(t, err, "WriteString() expected to succeed")
err = tempFile.Close()
require.NoErrorf(t, err, "Close() expected to succeed")

sse = aws.String(sseCustomerPrefix + tempFile.Name())
sseData := S3ServerSideEncryption{}
err = sseData.init()
require.NoErrorf(t, err, "init() expected to succeed")

assert.Nil(t, sseData.awsAlg, "awsAlg expected to be nil")
assert.Equal(t, aws.String("AES256"), sseData.customerAlg, "customerAlg expected to be AES256")
assert.Equal(t, aws.String(string(randomKey)), sseData.customerKey, "customerKey expected to be equal to the generated randomKey")
md5Hash := md5.Sum(randomKey)
assert.Equal(t, aws.String(base64.StdEncoding.EncodeToString(md5Hash[:])), sseData.customerMd5, "customerMd5 expected to be equal to the customerMd5 hash of the generated randomKey")

sseData.reset()
require.NoErrorf(t, err, "reset() expected to succeed")

assert.Nil(t, sseData.awsAlg, "awsAlg expected to be nil")
assert.Nil(t, sseData.customerAlg, "customerAlg expected to be nil")
assert.Nil(t, sseData.customerKey, "customerKey expected to be nil")
assert.Nil(t, sseData.customerMd5, "customerMd5 expected to be nil")
}

0 comments on commit 6454b7d

Please sign in to comment.