diff --git a/go/vt/mysqlctl/s3backupstorage/s3.go b/go/vt/mysqlctl/s3backupstorage/s3.go index e6347a4a9fc..8d566351e9d 100644 --- a/go/vt/mysqlctl/s3backupstorage/s3.go +++ b/go/vt/mysqlctl/s3backupstorage/s3.go @@ -24,10 +24,13 @@ limitations under the License. package s3backupstorage import ( + "crypto/md5" "crypto/tls" + "encoding/base64" "flag" "fmt" "io" + "io/ioutil" "math" "net/http" "sort" @@ -74,7 +77,7 @@ var ( requiredLogLevel = flag.String("s3_backup_log_level", "LogOff", "determine the S3 loglevel to use from LogOff, LogDebug, LogDebugWithSigning, LogDebugWithHTTPBody, LogDebugWithRequestRetries, LogDebugWithRequestErrors") // sse is the server-side encryption algorithm used when storing this object in S3 - sse = flag.String("s3_backup_server_side_encryption", "", "server-side encryption algorithm (e.g., AES256, aws:kms)") + sse = flag.String("s3_backup_server_side_encryption", "", "server-side encryption algorithm (e.g., AES256, aws:kms, sse_c:/path/to/key/file)") // path component delimiter delimiter = "/" @@ -84,6 +87,8 @@ type logNameToLogLevel map[string]aws.LogLevelType var logNameMap logNameToLogLevel +const sseCustomerPrefix = "sse_c:" + // S3BackupHandle implements the backupstorage.BackupHandle interface. type S3BackupHandle struct { client s3iface.S3API @@ -147,15 +152,14 @@ func (bh *S3BackupHandle) AddFile(ctx context.Context, filename string, filesize }) object := objName(bh.dir, bh.name, filename) - var sseOption *string - if *sse != "" { - sseOption = sse - } _, err := uploader.Upload(&s3manager.UploadInput{ Bucket: bucket, Key: object, Body: reader, - ServerSideEncryption: sseOption, + ServerSideEncryption: bh.bs.s3SSE.awsAlg, + SSECustomerAlgorithm: bh.bs.s3SSE.customerAlg, + SSECustomerKey: bh.bs.s3SSE.customerKey, + SSECustomerKeyMD5: bh.bs.s3SSE.customerMd5, }) if err != nil { reader.CloseWithError(err) @@ -190,8 +194,11 @@ func (bh *S3BackupHandle) ReadFile(ctx context.Context, filename string) (io.Rea } object := objName(bh.dir, bh.name, filename) out, err := bh.client.GetObject(&s3.GetObjectInput{ - Bucket: bucket, - Key: object, + Bucket: bucket, + Key: object, + SSECustomerAlgorithm: bh.bs.s3SSE.customerAlg, + SSECustomerKey: bh.bs.s3SSE.customerKey, + SSECustomerKeyMD5: bh.bs.s3SSE.customerMd5, }) if err != nil { return nil, err @@ -201,10 +208,51 @@ func (bh *S3BackupHandle) ReadFile(ctx context.Context, filename string) (io.Rea var _ backupstorage.BackupHandle = (*S3BackupHandle)(nil) +type S3ServerSideEncryption struct { + awsAlg *string + customerAlg *string + customerKey *string + customerMd5 *string +} + +func (s3ServerSideEncryption *S3ServerSideEncryption) init() error { + s3ServerSideEncryption.reset() + + if strings.HasPrefix(*sse, sseCustomerPrefix) { + sseCustomerKeyFile := strings.TrimPrefix(*sse, sseCustomerPrefix) + base64CodedKey, err := ioutil.ReadFile(sseCustomerKeyFile) + if err != nil { + log.Errorf(err.Error()) + return err + } + + decodedKey, err := base64.StdEncoding.DecodeString(string(base64CodedKey)) + if err != nil { + decodedKey = base64CodedKey + } + + md5Hash := md5.Sum(decodedKey) + s3ServerSideEncryption.customerAlg = aws.String("AES256") + s3ServerSideEncryption.customerKey = aws.String(string(decodedKey)) + s3ServerSideEncryption.customerMd5 = aws.String(base64.StdEncoding.EncodeToString(md5Hash[:])) + } else if *sse != "" { + s3ServerSideEncryption.awsAlg = sse + } + return nil +} + +func (s3ServerSideEncryption *S3ServerSideEncryption) reset() { + s3ServerSideEncryption.awsAlg = nil + s3ServerSideEncryption.customerAlg = nil + s3ServerSideEncryption.customerKey = nil + s3ServerSideEncryption.customerMd5 = nil +} + // S3BackupStorage implements the backupstorage.BackupStorage interface. type S3BackupStorage struct { _client *s3.S3 mu sync.Mutex + s3SSE S3ServerSideEncryption } // ListBackups is part of the backupstorage.BackupStorage interface. @@ -339,6 +387,7 @@ func (bs *S3BackupStorage) Close() error { bs.mu.Lock() defer bs.mu.Unlock() bs._client = nil + bs.s3SSE.reset() return nil } @@ -394,6 +443,10 @@ func (bs *S3BackupStorage) client() (*s3.S3, error) { if _, err := bs._client.HeadBucket(&s3.HeadBucketInput{Bucket: bucket}); err != nil { return nil, err } + + if err := bs.s3SSE.init(); err != nil { + return nil, err + } } return bs._client, nil } diff --git a/go/vt/mysqlctl/s3backupstorage/s3_test.go b/go/vt/mysqlctl/s3backupstorage/s3_test.go index 25d958934f7..98bac3a59ba 100644 --- a/go/vt/mysqlctl/s3backupstorage/s3_test.go +++ b/go/vt/mysqlctl/s3backupstorage/s3_test.go @@ -1,8 +1,13 @@ package s3backupstorage import ( + "crypto/md5" + "crypto/rand" + "encoding/base64" "errors" + "io/ioutil" "net/http" + "os" "testing" "github.com/aws/aws-sdk-go/aws" @@ -25,7 +30,7 @@ func (s3errclient *s3ErrorClient) PutObjectRequest(in *s3.PutObjectInput) (*requ } func TestAddFileError(t *testing.T) { - bh := &S3BackupHandle{client: &s3ErrorClient{}, readOnly: false} + bh := &S3BackupHandle{client: &s3ErrorClient{}, bs: &S3BackupStorage{}, readOnly: false} wc, err := bh.AddFile(aws.BackgroundContext(), "somefile", 100000) require.NoErrorf(t, err, "AddFile() expected no error, got %s", err) @@ -41,3 +46,122 @@ func TestAddFileError(t *testing.T) { require.Equal(t, bh.HasErrors(), true, "AddFile() expected bh to record async error but did not") } + +func TestNoSSE(t *testing.T) { + sseData := S3ServerSideEncryption{} + err := sseData.init() + require.NoErrorf(t, err, "init() expected to succeed") + + assert.Nil(t, sseData.awsAlg, "awsAlg expected to be nil") + assert.Nil(t, sseData.customerAlg, "customerAlg expected to be nil") + assert.Nil(t, sseData.customerKey, "customerKey expected to be nil") + assert.Nil(t, sseData.customerMd5, "customerMd5 expected to be nil") + + sseData.reset() + require.NoErrorf(t, err, "reset() expected to succeed") +} + +func TestSSEAws(t *testing.T) { + sse = aws.String("aws:kms") + sseData := S3ServerSideEncryption{} + err := sseData.init() + require.NoErrorf(t, err, "init() expected to succeed") + + assert.Equal(t, aws.String("aws:kms"), sseData.awsAlg, "awsAlg expected to be aws:kms") + assert.Nil(t, sseData.customerAlg, "customerAlg expected to be nil") + assert.Nil(t, sseData.customerKey, "customerKey expected to be nil") + assert.Nil(t, sseData.customerMd5, "customerMd5 expected to be nil") + + sseData.reset() + require.NoErrorf(t, err, "reset() expected to succeed") + + assert.Nil(t, sseData.awsAlg, "awsAlg expected to be nil") + assert.Nil(t, sseData.customerAlg, "customerAlg expected to be nil") + assert.Nil(t, sseData.customerKey, "customerKey expected to be nil") + assert.Nil(t, sseData.customerMd5, "customerMd5 expected to be nil") +} + +func TestSSECustomerFileNotFound(t *testing.T) { + tempFile, err := ioutil.TempFile("", "filename") + require.NoErrorf(t, err, "TempFile() expected to succeed") + defer os.Remove(tempFile.Name()) + + err = tempFile.Close() + require.NoErrorf(t, err, "Close() expected to succeed") + + err = os.Remove(tempFile.Name()) + require.NoErrorf(t, err, "Remove() expected to succeed") + + sse = aws.String(sseCustomerPrefix + tempFile.Name()) + sseData := S3ServerSideEncryption{} + err = sseData.init() + require.Errorf(t, err, "init() expected to fail") +} + +func TestSSECustomerFileBinaryKey(t *testing.T) { + tempFile, err := ioutil.TempFile("", "filename") + require.NoErrorf(t, err, "TempFile() expected to succeed") + defer os.Remove(tempFile.Name()) + + randomKey := make([]byte, 32) + _, err = rand.Read(randomKey) + require.NoErrorf(t, err, "Read() expected to succeed") + _, err = tempFile.Write(randomKey) + require.NoErrorf(t, err, "Write() expected to succeed") + err = tempFile.Close() + require.NoErrorf(t, err, "Close() expected to succeed") + + sse = aws.String(sseCustomerPrefix + tempFile.Name()) + sseData := S3ServerSideEncryption{} + err = sseData.init() + require.NoErrorf(t, err, "init() expected to succeed") + + assert.Nil(t, sseData.awsAlg, "awsAlg expected to be nil") + assert.Equal(t, aws.String("AES256"), sseData.customerAlg, "customerAlg expected to be AES256") + assert.Equal(t, aws.String(string(randomKey)), sseData.customerKey, "customerKey expected to be equal to the generated randomKey") + md5Hash := md5.Sum(randomKey) + assert.Equal(t, aws.String(base64.StdEncoding.EncodeToString(md5Hash[:])), sseData.customerMd5, "customerMd5 expected to be equal to the customerMd5 hash of the generated randomKey") + + sseData.reset() + require.NoErrorf(t, err, "reset() expected to succeed") + + assert.Nil(t, sseData.awsAlg, "awsAlg expected to be nil") + assert.Nil(t, sseData.customerAlg, "customerAlg expected to be nil") + assert.Nil(t, sseData.customerKey, "customerKey expected to be nil") + assert.Nil(t, sseData.customerMd5, "customerMd5 expected to be nil") +} + +func TestSSECustomerFileBase64Key(t *testing.T) { + tempFile, err := ioutil.TempFile("", "filename") + require.NoErrorf(t, err, "TempFile() expected to succeed") + defer os.Remove(tempFile.Name()) + + randomKey := make([]byte, 32) + _, err = rand.Read(randomKey) + require.NoErrorf(t, err, "Read() expected to succeed") + + base64Key := base64.StdEncoding.EncodeToString(randomKey[:]) + _, err = tempFile.WriteString(base64Key) + require.NoErrorf(t, err, "WriteString() expected to succeed") + err = tempFile.Close() + require.NoErrorf(t, err, "Close() expected to succeed") + + sse = aws.String(sseCustomerPrefix + tempFile.Name()) + sseData := S3ServerSideEncryption{} + err = sseData.init() + require.NoErrorf(t, err, "init() expected to succeed") + + assert.Nil(t, sseData.awsAlg, "awsAlg expected to be nil") + assert.Equal(t, aws.String("AES256"), sseData.customerAlg, "customerAlg expected to be AES256") + assert.Equal(t, aws.String(string(randomKey)), sseData.customerKey, "customerKey expected to be equal to the generated randomKey") + md5Hash := md5.Sum(randomKey) + assert.Equal(t, aws.String(base64.StdEncoding.EncodeToString(md5Hash[:])), sseData.customerMd5, "customerMd5 expected to be equal to the customerMd5 hash of the generated randomKey") + + sseData.reset() + require.NoErrorf(t, err, "reset() expected to succeed") + + assert.Nil(t, sseData.awsAlg, "awsAlg expected to be nil") + assert.Nil(t, sseData.customerAlg, "customerAlg expected to be nil") + assert.Nil(t, sseData.customerKey, "customerKey expected to be nil") + assert.Nil(t, sseData.customerMd5, "customerMd5 expected to be nil") +}