Skip to content

Commit

Permalink
aead: Use Get* methods for chained proto field access.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 584652165
Change-Id: I1a1f8b853265a8aaf6def2a34c8dc7667a85ddcc
  • Loading branch information
chuckx authored and copybara-github committed Nov 22, 2023
1 parent b0b39f1 commit 70fd12e
Show file tree
Hide file tree
Showing 5 changed files with 240 additions and 46 deletions.
3 changes: 3 additions & 0 deletions aead/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,13 @@ go_test(
"//keyset",
"//mac",
"//monitoring",
"//proto/aes_ctr_go_proto",
"//proto/aes_ctr_hmac_aead_go_proto",
"//proto/aes_gcm_go_proto",
"//proto/aes_gcm_siv_go_proto",
"//proto/chacha20_poly1305_go_proto",
"//proto/common_go_proto",
"//proto/hmac_go_proto",
"//proto/kms_envelope_go_proto",
"//proto/tink_go_proto",
"//proto/xchacha20_poly1305_go_proto",
Expand Down
46 changes: 23 additions & 23 deletions aead/aes_ctr_hmac_aead_key_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,18 @@ func (km *aesCTRHMACAEADKeyManager) Primitive(serializedKey []byte) (interface{}
return nil, err
}

ctr, err := subtle.NewAESCTR(key.AesCtrKey.KeyValue, int(key.AesCtrKey.Params.IvSize))
ctr, err := subtle.NewAESCTR(key.GetAesCtrKey().GetKeyValue(), int(key.GetAesCtrKey().GetParams().GetIvSize()))
if err != nil {
return nil, fmt.Errorf("aes_ctr_hmac_aead_key_manager: cannot create new primitive: %v", err)
}

hmacKey := key.HmacKey
hmac, err := subtleMac.NewHMAC(hmacKey.Params.Hash.String(), hmacKey.KeyValue, hmacKey.Params.TagSize)
hmacKey := key.GetHmacKey()
hmac, err := subtleMac.NewHMAC(hmacKey.GetParams().GetHash().String(), hmacKey.GetKeyValue(), hmacKey.GetParams().GetTagSize())
if err != nil {
return nil, fmt.Errorf("aes_ctr_hmac_aead_key_manager: cannot create mac primitive, error: %v", err)
}

aead, err := subtle.NewEncryptThenAuthenticate(ctr, hmac, int(hmacKey.Params.TagSize))
aead, err := subtle.NewEncryptThenAuthenticate(ctr, hmac, int(hmacKey.GetParams().GetTagSize()))
if err != nil {
return nil, fmt.Errorf("aes_ctr_hmac_aead_key_manager: cannot create encrypt then authenticate primitive, error: %v", err)
}
Expand All @@ -94,13 +94,13 @@ func (km *aesCTRHMACAEADKeyManager) NewKey(serializedKeyFormat []byte) (proto.Me
Version: aesCTRHMACAEADKeyVersion,
AesCtrKey: &ctrpb.AesCtrKey{
Version: aesCTRHMACAEADKeyVersion,
KeyValue: random.GetRandomBytes(keyFormat.AesCtrKeyFormat.KeySize),
Params: keyFormat.AesCtrKeyFormat.Params,
KeyValue: random.GetRandomBytes(keyFormat.GetAesCtrKeyFormat().GetKeySize()),
Params: keyFormat.GetAesCtrKeyFormat().GetParams(),
},
HmacKey: &hmacpb.HmacKey{
Version: aesCTRHMACAEADKeyVersion,
KeyValue: random.GetRandomBytes(keyFormat.HmacKeyFormat.KeySize),
Params: keyFormat.HmacKeyFormat.Params,
KeyValue: random.GetRandomBytes(keyFormat.GetHmacKeyFormat().GetKeySize()),
Params: keyFormat.GetHmacKeyFormat().GetParams(),
},
}, nil
}
Expand Down Expand Up @@ -139,19 +139,19 @@ func (km *aesCTRHMACAEADKeyManager) validateKey(key *aeadpb.AesCtrHmacAeadKey) e
if err := keyset.ValidateKeyVersion(key.Version, aesCTRHMACAEADKeyVersion); err != nil {
return fmt.Errorf("aes_ctr_hmac_aead_key_manager: %v", err)
}
if err := keyset.ValidateKeyVersion(key.AesCtrKey.Version, aesCTRHMACAEADKeyVersion); err != nil {
if err := keyset.ValidateKeyVersion(key.GetAesCtrKey().GetVersion(), aesCTRHMACAEADKeyVersion); err != nil {
return fmt.Errorf("aes_ctr_hmac_aead_key_manager: %v", err)
}
if err := keyset.ValidateKeyVersion(key.HmacKey.Version, aesCTRHMACAEADKeyVersion); err != nil {
if err := keyset.ValidateKeyVersion(key.GetHmacKey().GetVersion(), aesCTRHMACAEADKeyVersion); err != nil {
return fmt.Errorf("aes_ctr_hmac_aead_key_manager: %v", err)
}
// Validate AesCtrKey.
keySize := uint32(len(key.AesCtrKey.KeyValue))
if err := subtle.ValidateAESKeySize(keySize); err != nil {
return fmt.Errorf("aes_ctr_hmac_aead_key_manager: %v", err)
}
params := key.AesCtrKey.Params
if params.IvSize < subtle.AESCTRMinIVSize || params.IvSize > 16 {
params := key.AesCtrKey.GetParams()
if params.GetIvSize() < subtle.AESCTRMinIVSize || params.GetIvSize() > 16 {
return errors.New("aes_ctr_hmac_aead_key_manager: invalid AesCtrHmacAeadKey: IV size out of range")
}
return nil
Expand All @@ -160,37 +160,37 @@ func (km *aesCTRHMACAEADKeyManager) validateKey(key *aeadpb.AesCtrHmacAeadKey) e
// validateKeyFormat validates the given AesCtrHmacAeadKeyFormat proto.
func (km *aesCTRHMACAEADKeyManager) validateKeyFormat(format *aeadpb.AesCtrHmacAeadKeyFormat) error {
// Validate AesCtrKeyFormat.
if err := subtle.ValidateAESKeySize(format.AesCtrKeyFormat.KeySize); err != nil {
if err := subtle.ValidateAESKeySize(format.GetAesCtrKeyFormat().GetKeySize()); err != nil {
return fmt.Errorf("aes_ctr_hmac_aead_key_manager: %s", err)
}
if format.AesCtrKeyFormat.Params.IvSize < subtle.AESCTRMinIVSize || format.AesCtrKeyFormat.Params.IvSize > 16 {
if format.GetAesCtrKeyFormat().GetParams().GetIvSize() < subtle.AESCTRMinIVSize || format.GetAesCtrKeyFormat().GetParams().GetIvSize() > 16 {
return errors.New("aes_ctr_hmac_aead_key_manager: invalid AesCtrHmacAeadKeyFormat: IV size out of range")
}

// Validate HmacKeyFormat.
hmacKeyFormat := format.HmacKeyFormat
if hmacKeyFormat.KeySize < minHMACKeySizeInBytes {
hmacKeyFormat := format.GetHmacKeyFormat()
if hmacKeyFormat.GetKeySize() < minHMACKeySizeInBytes {
return errors.New("aes_ctr_hmac_aead_key_manager: HMAC KeySize is too small")
}
if hmacKeyFormat.Params.TagSize < minTagSizeInBytes {
return fmt.Errorf("aes_ctr_hmac_aead_key_manager: invalid HmacParams: TagSize %d is too small", hmacKeyFormat.Params.TagSize)
if hmacKeyFormat.GetParams().GetTagSize() < minTagSizeInBytes {
return fmt.Errorf("aes_ctr_hmac_aead_key_manager: invalid HmacParams: TagSize %d is too small", hmacKeyFormat.GetParams().GetTagSize())
}

maxTagSize := map[commonpb.HashType]uint32{
maxTagSizes := map[commonpb.HashType]uint32{
commonpb.HashType_SHA1: 20,
commonpb.HashType_SHA224: 28,
commonpb.HashType_SHA256: 32,
commonpb.HashType_SHA384: 48,
commonpb.HashType_SHA512: 64}

tagSize, ok := maxTagSize[hmacKeyFormat.Params.Hash]
maxTagSize, ok := maxTagSizes[hmacKeyFormat.GetParams().GetHash()]
if !ok {
return fmt.Errorf("aes_ctr_hmac_aead_key_manager: invalid HmacParams: HashType %q not supported",
hmacKeyFormat.Params.Hash)
hmacKeyFormat.GetParams().GetHash())
}
if hmacKeyFormat.Params.TagSize > tagSize {
if hmacKeyFormat.GetParams().GetTagSize() > maxTagSize {
return fmt.Errorf("aes_ctr_hmac_aead_key_manager: invalid HmacParams: tagSize %d is too big for HashType %q",
hmacKeyFormat.Params.TagSize, hmacKeyFormat.Params.Hash)
hmacKeyFormat.GetParams().GetTagSize(), hmacKeyFormat.GetParams().GetHash())
}

return nil
Expand Down
145 changes: 133 additions & 12 deletions aead/aes_ctr_hmac_aead_key_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@ import (
"github.com/tink-crypto/tink-go/v2/aead"
"github.com/tink-crypto/tink-go/v2/core/registry"
"github.com/tink-crypto/tink-go/v2/testutil"
ctrpb "github.com/tink-crypto/tink-go/v2/proto/aes_ctr_go_proto"
aeadpb "github.com/tink-crypto/tink-go/v2/proto/aes_ctr_hmac_aead_go_proto"
ctrhmacpb "github.com/tink-crypto/tink-go/v2/proto/aes_ctr_hmac_aead_go_proto"
tinkpb "github.com/tink-crypto/tink-go/v2/proto/tink_go_proto"
commonpb "github.com/tink-crypto/tink-go/v2/proto/common_go_proto"
hmacpb "github.com/tink-crypto/tink-go/v2/proto/hmac_go_proto"
)

func TestNewKeyMultipleTimes(t *testing.T) {
func TestAESCTRHMACNewKeyMultipleTimes(t *testing.T) {
keyTemplate := aead.AES128CTRHMACSHA256KeyTemplate()
aeadKeyFormat := new(ctrhmacpb.AesCtrHmacAeadKeyFormat)
if err := proto.Unmarshal(keyTemplate.Value, aeadKeyFormat); err != nil {
Expand Down Expand Up @@ -65,24 +68,142 @@ func TestNewKeyMultipleTimes(t *testing.T) {
}
}

func TestNewKeyWithCorruptedFormat(t *testing.T) {
keyTemplate := new(tinkpb.KeyTemplate)
func TestAESCTRHMACNewKeyWithInvalidSerializedKeyFormat(t *testing.T) {
keyManager, err := registry.GetKeyManager(testutil.AESCTRHMACAEADTypeURL)
if err != nil {
t.Errorf("cannot obtain AES-CTR-HMAC-AEAD key manager: %s", err)
}

testcases := []struct {
name string
serializedKeyFormat []byte
keyFormat *ctrhmacpb.AesCtrHmacAeadKeyFormat
}{
{
name: "empty",
serializedKeyFormat: make([]byte, 128),
},
{
name: "params_unset",
keyFormat: &ctrhmacpb.AesCtrHmacAeadKeyFormat{
AesCtrKeyFormat: &ctrpb.AesCtrKeyFormat{
Params: nil,
KeySize: 32,
},
HmacKeyFormat: &hmacpb.HmacKeyFormat{
Params: nil,
KeySize: 32,
},
},
},
{
name: "nested_key_formats_unset",
keyFormat: &ctrhmacpb.AesCtrHmacAeadKeyFormat{
AesCtrKeyFormat: nil,
HmacKeyFormat: nil,
},
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
serializedKeyFormat := tc.serializedKeyFormat
if serializedKeyFormat == nil {
var err error
serializedKeyFormat, err = proto.Marshal(tc.keyFormat)
if err != nil {
t.Fatalf("failed to marshal key format: %s", err)
}
}

_, err = keyManager.NewKey(serializedKeyFormat)
if err == nil {
t.Error("NewKey() err = nil, want not error")
}

_, err = keyManager.NewKeyData(serializedKeyFormat)
if err == nil {
t.Error("NewKeyData() err = nil, want error")
}
})
}
}

func TestAESCTRHMACPrimitive(t *testing.T) {
keyManager, err := registry.GetKeyManager(testutil.AESCTRHMACAEADTypeURL)
if err != nil {
t.Errorf("cannot obtain AES-CTR-HMAC-AEAD key manager: %s", err)
}

keyTemplate.TypeUrl = testutil.AESCTRHMACAEADTypeURL
keyTemplate.Value = make([]byte, 128)
key := &aeadpb.AesCtrHmacAeadKey{
Version: 0,
AesCtrKey: &ctrpb.AesCtrKey{
Version: 0,
KeyValue: make([]byte, 32),
Params: &ctrpb.AesCtrParams{IvSize: 16},
},
HmacKey: &hmacpb.HmacKey{
Version: 0,
KeyValue: make([]byte, 32),
Params: &hmacpb.HmacParams{Hash: commonpb.HashType_SHA256, TagSize: 32},
},
}
serializedKey, err := proto.Marshal(key)
if err != nil {
t.Fatalf("failed to marshal key: %s", err)
}

_, err = keyManager.Primitive(serializedKey)
if err != nil {
t.Errorf("Primitive() err = %v, want nil", err)
}
}

func TestAESCTRHMACPrimitiveWithInvalidKey(t *testing.T) {
keyManager, err := registry.GetKeyManager(testutil.AESCTRHMACAEADTypeURL)
if err != nil {
t.Errorf("cannot obtain AES-CTR-HMAC-AEAD key manager: %s", err)
}

_, err = keyManager.NewKey(keyTemplate.Value)
if err == nil {
t.Error("NewKey got: success, want: error due to corrupted format")
testcases := []struct {
name string
key *ctrhmacpb.AesCtrHmacAeadKey
}{
{
name: "nil_nested_keys",
key: &aeadpb.AesCtrHmacAeadKey{
Version: 0,
AesCtrKey: nil,
HmacKey: nil,
},
},
{
name: "nil_key_params",
key: &aeadpb.AesCtrHmacAeadKey{
Version: 0,
AesCtrKey: &ctrpb.AesCtrKey{
Version: 0,
KeyValue: make([]byte, 32),
Params: nil,
},
HmacKey: &hmacpb.HmacKey{
Version: 0,
KeyValue: make([]byte, 32),
Params: nil,
},
},
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
serializedKey, err := proto.Marshal(tc.key)
if err != nil {
t.Fatalf("failed to marshal key: %s", err)
}

_, err = keyManager.NewKeyData(keyTemplate.Value)
if err == nil {
t.Error("NewKeyData got: success, want: error due to corrupted format")
_, err = keyManager.Primitive(serializedKey)
if err == nil {
t.Error("Primitive() err = nil, want error")
}
})
}
}
26 changes: 15 additions & 11 deletions aead/kms_envelope_aead_key_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ func (km *kmsEnvelopeAEADKeyManager) Primitive(serializedKey []byte) (interface{
return nil, errors.New("kms_envelope_aead_key_manager: invalid key")
}
if err := km.validateKey(key); err != nil {
return nil, err
return nil, fmt.Errorf("kms_envelope_aead_key_manager: %v", err)
}
uri := key.Params.KekUri
uri := key.GetParams().GetKekUri()
kmsClient, err := registry.GetKMSClient(uri)
if err != nil {
return nil, err
Expand All @@ -58,7 +58,7 @@ func (km *kmsEnvelopeAEADKeyManager) Primitive(serializedKey []byte) (interface{
return nil, errors.New("kms_envelope_aead_key_manager: invalid aead backend")
}

return NewKMSEnvelopeAEAD2(key.Params.DekTemplate, backend), nil
return NewKMSEnvelopeAEAD2(key.GetParams().GetDekTemplate(), backend), nil
}

// NewKey creates a new key according to specification the given serialized KMSEnvelopeAEADKeyFormat.
Expand All @@ -70,11 +70,9 @@ func (km *kmsEnvelopeAEADKeyManager) NewKey(serializedKeyFormat []byte) (proto.M
if err := proto.Unmarshal(serializedKeyFormat, keyFormat); err != nil {
return nil, errors.New("kms_envelope_aead_key_manager: invalid key format")
}
dekKeyType := keyFormat.GetDekTemplate().GetTypeUrl()
if !isSupporedKMSEnvelopeDEK(dekKeyType) {
return nil, fmt.Errorf("unsupported DEK key type %s. Only Tink AEAD key types are supported with KMSEnvelopeAEAD", dekKeyType)
if err := km.validateKeyFormat(keyFormat); err != nil {
return nil, fmt.Errorf("kms_envelope_aead_key_manager: %v", err)
}

return &kmsepb.KmsEnvelopeAeadKey{
Version: kmsEnvelopeAEADKeyVersion,
Params: keyFormat,
Expand Down Expand Up @@ -112,11 +110,17 @@ func (km *kmsEnvelopeAEADKeyManager) TypeURL() string {

// validateKey validates the given KmsEnvelopeAeadKey.
func (km *kmsEnvelopeAEADKeyManager) validateKey(key *kmsepb.KmsEnvelopeAeadKey) error {
err := keyset.ValidateKeyVersion(key.Version, kmsEnvelopeAEADKeyVersion)
if err != nil {
return fmt.Errorf("kms_envelope_aead_key_manager: %s", err)
if err := keyset.ValidateKeyVersion(key.Version, kmsEnvelopeAEADKeyVersion); err != nil {
return err
}
dekKeyType := key.GetParams().GetDekTemplate().GetTypeUrl()
if err := km.validateKeyFormat(key.GetParams()); err != nil {
return err
}
return nil
}

func (km *kmsEnvelopeAEADKeyManager) validateKeyFormat(keyFormat *kmsepb.KmsEnvelopeAeadKeyFormat) error {
dekKeyType := keyFormat.GetDekTemplate().GetTypeUrl()
if !isSupporedKMSEnvelopeDEK(dekKeyType) {
return fmt.Errorf("unsupported DEK key type %s. Only Tink AEAD key types are supported with KMSEnvelopeAEAD", dekKeyType)
}
Expand Down
Loading

0 comments on commit 70fd12e

Please sign in to comment.