Skip to content

Commit

Permalink
br: have better crypter key error msg (#56589)
Browse files Browse the repository at this point in the history
close #56388
  • Loading branch information
Tristan1900 authored Oct 17, 2024
1 parent 2e51209 commit 65fcaec
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 54 deletions.
2 changes: 1 addition & 1 deletion br/pkg/task/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ go_test(
],
embed = [":task"],
flaky = True,
shard_count = 38,
shard_count = 39,
deps = [
"//br/pkg/backup",
"//br/pkg/config",
Expand Down
59 changes: 38 additions & 21 deletions br/pkg/task/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (
"github.com/pingcap/tidb/br/pkg/metautil"
"github.com/pingcap/tidb/br/pkg/storage"
"github.com/pingcap/tidb/br/pkg/utils"
"github.com/pingcap/tidb/pkg/meta/model"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
filter "github.com/pingcap/tidb/pkg/util/table-filter"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -115,9 +114,7 @@ const (
)

const (
// Once TableInfoVersion updated. BR need to check compatibility with
// new TableInfoVersion. both snapshot restore and pitr need to be checked.
CURRENT_BACKUP_SUPPORT_TABLE_INFO_VERSION = model.TableInfoVersion5
cipherKeyNonHexErrorMsg = "cipher key must be a valid hexadecimal string"
)

// FullBackupType type when doing full backup or restore
Expand Down Expand Up @@ -464,34 +461,52 @@ func GetCipherKeyContent(cipherKey, cipherKeyFile string) ([]byte, error) {
return nil, errors.Trace(err)
}

// if cipher-key is valid, convert the hexadecimal string to bytes
var hexString string

// Check if cipher-key is provided directly
if len(cipherKey) > 0 {
return hex.DecodeString(cipherKey)
hexString = cipherKey
} else {
// Read content from cipher-file
content, err := os.ReadFile(cipherKeyFile)
if err != nil {
return nil, errors.Annotate(err, "failed to read cipher file")
}
hexString = string(bytes.TrimSuffix(content, []byte("\n")))
}

// convert the content(as hexadecimal string) from cipher-file to bytes
content, err := os.ReadFile(cipherKeyFile)
// Attempt to decode the hex string
decodedKey, err := hex.DecodeString(hexString)
if err != nil {
return nil, errors.Annotate(err, "failed to read cipher file")
return nil, errors.Annotate(berrors.ErrInvalidArgument, cipherKeyNonHexErrorMsg)
}

content = bytes.TrimSuffix(content, []byte("\n"))
return hex.DecodeString(string(content))
return decodedKey, nil
}

func checkCipherKeyMatch(cipher *backuppb.CipherInfo) bool {
func checkCipherKeyMatch(cipher *backuppb.CipherInfo) error {
switch cipher.CipherType {
case encryptionpb.EncryptionMethod_PLAINTEXT:
return true
return nil
case encryptionpb.EncryptionMethod_AES128_CTR:
return len(cipher.CipherKey) == crypterAES128KeyLen
if len(cipher.CipherKey) != crypterAES128KeyLen {
return errors.Annotatef(berrors.ErrInvalidArgument, "AES-128 key length mismatch: expected %d, got %d",
crypterAES128KeyLen, len(cipher.CipherKey))
}
case encryptionpb.EncryptionMethod_AES192_CTR:
return len(cipher.CipherKey) == crypterAES192KeyLen
if len(cipher.CipherKey) != crypterAES192KeyLen {
return errors.Annotatef(berrors.ErrInvalidArgument, "AES-192 key length mismatch: expected %d, got %d",
crypterAES192KeyLen, len(cipher.CipherKey))
}
case encryptionpb.EncryptionMethod_AES256_CTR:
return len(cipher.CipherKey) == crypterAES256KeyLen
if len(cipher.CipherKey) != crypterAES256KeyLen {
return errors.Annotatef(berrors.ErrInvalidArgument, "AES-256 key length mismatch: expected %d, got %d",
crypterAES256KeyLen, len(cipher.CipherKey))
}
default:
return false
return errors.Errorf("Unknown encryption method: %v", cipher.CipherType)
}
return nil
}

func (cfg *Config) parseCipherInfo(flags *pflag.FlagSet) error {
Expand Down Expand Up @@ -524,8 +539,9 @@ func (cfg *Config) parseCipherInfo(flags *pflag.FlagSet) error {
return errors.Trace(err)
}

if !checkCipherKeyMatch(&cfg.CipherInfo) {
return errors.Annotate(berrors.ErrInvalidArgument, "crypter method and key length not match")
err = checkCipherKeyMatch(&cfg.CipherInfo)
if err != nil {
return errors.Trace(err)
}

return nil
Expand Down Expand Up @@ -561,8 +577,9 @@ func (cfg *Config) parseLogBackupCipherInfo(flags *pflag.FlagSet) (bool, error)
return false, errors.Trace(err)
}

if !checkCipherKeyMatch(&cfg.CipherInfo) {
return false, errors.Annotate(berrors.ErrInvalidArgument, "log backup encryption method and key length not match")
err = checkCipherKeyMatch(&cfg.CipherInfo)
if err != nil {
return false, errors.Trace(err)
}

return true, nil
Expand Down
102 changes: 70 additions & 32 deletions br/pkg/task/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
package task

import (
"encoding/hex"
"fmt"
"testing"

Expand Down Expand Up @@ -70,57 +69,89 @@ func TestStripingPDURL(t *testing.T) {

func TestCheckCipherKeyMatch(t *testing.T) {
cases := []struct {
CipherType encryptionpb.EncryptionMethod
CipherKey string
ok bool
name string
cipherInfo *backup.CipherInfo
expectErr bool
errMsg string
}{
{
CipherType: encryptionpb.EncryptionMethod_PLAINTEXT,
ok: true,
name: "PLAINTEXT",
cipherInfo: &backup.CipherInfo{
CipherType: encryptionpb.EncryptionMethod_PLAINTEXT,
},
expectErr: false,
},
{
CipherType: encryptionpb.EncryptionMethod_UNKNOWN,
ok: false,
name: "UNKNOWN",
cipherInfo: &backup.CipherInfo{
CipherType: encryptionpb.EncryptionMethod_UNKNOWN,
},
expectErr: true,
errMsg: "Unknown encryption method: UNKNOWN",
},
{
CipherType: encryptionpb.EncryptionMethod_AES128_CTR,
CipherKey: "0123456789abcdef0123456789abcdef",
ok: true,
name: "AES128_CTR valid",
cipherInfo: &backup.CipherInfo{
CipherType: encryptionpb.EncryptionMethod_AES128_CTR,
CipherKey: make([]byte, crypterAES128KeyLen),
},
expectErr: false,
},
{
CipherType: encryptionpb.EncryptionMethod_AES128_CTR,
CipherKey: "0123456789abcdef0123456789abcd",
ok: false,
name: "AES128_CTR invalid length",
cipherInfo: &backup.CipherInfo{
CipherType: encryptionpb.EncryptionMethod_AES128_CTR,
CipherKey: make([]byte, crypterAES128KeyLen-1),
},
expectErr: true,
errMsg: fmt.Sprintf("AES-128 key length mismatch: expected %d, got %d", crypterAES128KeyLen, crypterAES128KeyLen-1),
},
{
CipherType: encryptionpb.EncryptionMethod_AES192_CTR,
CipherKey: "0123456789abcdef0123456789abcdef0123456789abcdef",
ok: true,
name: "AES192_CTR valid",
cipherInfo: &backup.CipherInfo{
CipherType: encryptionpb.EncryptionMethod_AES192_CTR,
CipherKey: make([]byte, crypterAES192KeyLen),
},
expectErr: false,
},
{
CipherType: encryptionpb.EncryptionMethod_AES192_CTR,
CipherKey: "0123456789abcdef0123456789abcdef0123456789abcdefff",
ok: false,
name: "AES192_CTR invalid length",
cipherInfo: &backup.CipherInfo{
CipherType: encryptionpb.EncryptionMethod_AES192_CTR,
CipherKey: make([]byte, crypterAES192KeyLen+1),
},
expectErr: true,
errMsg: fmt.Sprintf("AES-192 key length mismatch: expected %d, got %d", crypterAES192KeyLen, crypterAES192KeyLen+1),
},
{
CipherType: encryptionpb.EncryptionMethod_AES256_CTR,
CipherKey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
ok: true,
name: "AES256_CTR valid",
cipherInfo: &backup.CipherInfo{
CipherType: encryptionpb.EncryptionMethod_AES256_CTR,
CipherKey: make([]byte, crypterAES256KeyLen),
},
expectErr: false,
},
{
CipherType: encryptionpb.EncryptionMethod_AES256_CTR,
CipherKey: "",
ok: false,
name: "AES256_CTR invalid length",
cipherInfo: &backup.CipherInfo{
CipherType: encryptionpb.EncryptionMethod_AES256_CTR,
CipherKey: make([]byte, 0),
},
expectErr: true,
errMsg: fmt.Sprintf("AES-256 key length mismatch: expected %d, got %d", crypterAES256KeyLen, 0),
},
}

for _, c := range cases {
cipherKey, err := hex.DecodeString(c.CipherKey)
require.NoError(t, err)
require.Equal(t, c.ok, checkCipherKeyMatch(&backup.CipherInfo{
CipherType: c.CipherType,
CipherKey: cipherKey,
}))
t.Run(c.name, func(t *testing.T) {
err := checkCipherKeyMatch(c.cipherInfo)
if c.expectErr {
require.Error(t, err)
require.Contains(t, err.Error(), c.errMsg)
} else {
require.NoError(t, err)
}
})
}
}

Expand Down Expand Up @@ -162,6 +193,13 @@ func TestCheckCipherKey(t *testing.T) {
}
}

func TestGetCipherKey(t *testing.T) {
nonHexKey := "this is not a hex string"
_, err := GetCipherKeyContent(nonHexKey, "")
require.Error(t, err)
require.Contains(t, err.Error(), cipherKeyNonHexErrorMsg)
}

func must[T any](t T, err error) T {
if err != nil {
panic(err)
Expand Down

0 comments on commit 65fcaec

Please sign in to comment.