Skip to content

Commit

Permalink
Starter PKI CA Storage API (#14796)
Browse files Browse the repository at this point in the history
* Simple starting PKI storage api for CA rotation
* Add key and issuer storage apis
* Add listKeys and listIssuers storage implementations
* Add simple keys and issuers configuration storage api methods
  • Loading branch information
stevendpclark authored and cipherboy committed Apr 5, 2022
1 parent e23ff1b commit 7a19beb
Show file tree
Hide file tree
Showing 4 changed files with 355 additions and 4 deletions.
2 changes: 1 addition & 1 deletion builtin/logical/pki/managed_key_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,6 @@ func parseCABundle(_ context.Context, _ *backend, _ *logical.Request, bundle *ce
return bundle.ToParsedCertBundle()
}

func withManagedPKIKey(_ context.Context, _ *backend, _ keyId, _ string, _ logical.ManagedSigningKeyConsumer) error {
func withManagedPKIKey(_ context.Context, _ *backend, _ managedKeyId, _ string, _ logical.ManagedSigningKeyConsumer) error {
return errEntOnly
}
188 changes: 188 additions & 0 deletions builtin/logical/pki/storage.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
package pki

import (
"context"
"fmt"

"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical"
)

const (
storageKeyConfig = "/config/keys"
storeageIssuerConfig = "/config/issuers"
keyPrefix = "/config/key/"
issuerPrefix = "/config/issuer/"
)

type keyId string

func (p keyId) String() string {
return string(p)
}

type issuerId string

func (p issuerId) String() string {
return string(p)
}

type key struct {
ID keyId `json:"id" structs:"id" mapstructure:"id"`
PrivateKeyType certutil.PrivateKeyType `json:"private_key_type" structs:"private_key_type" mapstructure:"private_key_type"`
PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"`
}

type issuer struct {
ID issuerId `json:"id" structs:"id" mapstructure:"id"`
Name string `json:"name" structs:"name" mapstructure:"name"`
KeyID keyId `json:"key_id" structs:"key_id" mapstructure:"key_id"`
Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"`
CAChain []string `json:"ca_chain" structs:"ca_chain" mapstructure:"ca_chain"`
SerialNumber string `json:"serial_number" structs:"serial_number" mapstructure:"serial_number"`
}

type keyConfig struct {
DefaultKeyId keyId `json:"default" structs:"default" mapstructure:"default"`
}

type issuerConfig struct {
DefaultIssuerId issuerId `json:"default" structs:"default" mapstructure:"default"`
}

func listKeys(ctx context.Context, s logical.Storage) ([]keyId, error) {
strList, err := s.List(ctx, keyPrefix)
if err != nil {
return nil, err
}

keyIds := make([]keyId, 0, len(strList))
for _, entry := range strList {
keyIds = append(keyIds, keyId(entry))
}

return keyIds, nil
}

func fetchKeyById(ctx context.Context, s logical.Storage, keyId keyId) (*key, error) {
keyEntry, err := s.Get(ctx, keyPrefix+keyId.String())
if err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to fetch pki key: %v", err)}
}
if keyEntry == nil {
// FIXME: Dedicated/specific error for this?
return nil, errutil.UserError{Err: fmt.Sprintf("pki key id %s does not exist", keyId.String())}
}

var key key
if err := keyEntry.DecodeJSON(&key); err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to decode pki key with id %s: %v", keyId.String(), err)}
}

return &key, nil
}

func writeKey(ctx context.Context, s logical.Storage, key key) error {
keyId := key.ID

json, err := logical.StorageEntryJSON(keyPrefix+keyId.String(), key)
if err != nil {
return err
}

return s.Put(ctx, json)
}

func listIssuers(ctx context.Context, s logical.Storage) ([]issuerId, error) {
strList, err := s.List(ctx, issuerPrefix)
if err != nil {
return nil, err
}

issuerIds := make([]issuerId, 0, len(strList))
for _, entry := range strList {
issuerIds = append(issuerIds, issuerId(entry))
}

return issuerIds, nil
}

func fetchIssuerById(ctx context.Context, s logical.Storage, issuerId issuerId) (*issuer, error) {
issuerEntry, err := s.Get(ctx, issuerPrefix+issuerId.String())
if err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to fetch pki issuer: %v", err)}
}
if issuerEntry == nil {
// FIXME: Dedicated/specific error for this?
return nil, errutil.UserError{Err: fmt.Sprintf("pki issuer id %s does not exist", issuerId.String())}
}

var issuer issuer
if err := issuerEntry.DecodeJSON(&issuer); err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to decode pki issuer with id %s: %v", issuerId.String(), err)}
}

return &issuer, nil
}

func writeIssuer(ctx context.Context, s logical.Storage, issuer issuer) error {
issuerId := issuer.ID

json, err := logical.StorageEntryJSON(issuerPrefix+issuerId.String(), issuer)
if err != nil {
return err
}

return s.Put(ctx, json)
}

func setKeysConfig(ctx context.Context, s logical.Storage, config *keyConfig) error {
json, err := logical.StorageEntryJSON(storageKeyConfig, config)
if err != nil {
return err
}

return s.Put(ctx, json)
}

func getKeysConfig(ctx context.Context, s logical.Storage) (*keyConfig, error) {
keyConfigEntry, err := s.Get(ctx, storageKeyConfig)
if err != nil {
return nil, err
}

keyConfig := &keyConfig{}
if keyConfigEntry != nil {
if err := keyConfigEntry.DecodeJSON(keyConfig); err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to decode key configuration: %v", err)}
}
}

return keyConfig, nil
}

func setIssuersConfig(ctx context.Context, s logical.Storage, config *issuerConfig) error {
json, err := logical.StorageEntryJSON(storeageIssuerConfig, config)
if err != nil {
return err
}

return s.Put(ctx, json)
}

func getIssuersConfig(ctx context.Context, s logical.Storage) (*issuerConfig, error) {
issuerConfigEntry, err := s.Get(ctx, storeageIssuerConfig)
if err != nil {
return nil, err
}

issuerConfig := &issuerConfig{}
if issuerConfigEntry != nil {
if err := issuerConfigEntry.DecodeJSON(issuerConfig); err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to decode issuer configuration: %v", err)}
}
}

return issuerConfig, nil
}
163 changes: 163 additions & 0 deletions builtin/logical/pki/storage_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
package pki

import (
"context"
"crypto/rand"
"testing"

"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)

var ctx = context.Background()

func Test_ConfigsRoundTrip(t *testing.T) {
_, s := createBackendWithStorage(t)

// Verify we handle nothing stored properly
keyConfigEmpty, err := getKeysConfig(ctx, s)
require.NoError(t, err)
require.Equal(t, &keyConfig{}, keyConfigEmpty)

issuerConfigEmpty, err := getIssuersConfig(ctx, s)
require.NoError(t, err)
require.Equal(t, &issuerConfig{}, issuerConfigEmpty)

// Now attempt to store and reload properly
origKeyConfig := &keyConfig{
DefaultKeyId: genKeyId(t),
}
origIssuerConfig := &issuerConfig{
DefaultIssuerId: genIssuerId(t),
}

err = setKeysConfig(ctx, s, origKeyConfig)
require.NoError(t, err)
err = setIssuersConfig(ctx, s, origIssuerConfig)
require.NoError(t, err)

keyConfig, err := getKeysConfig(ctx, s)
require.NoError(t, err)
require.Equal(t, origKeyConfig, keyConfig)

issuerConfig, err := getIssuersConfig(ctx, s)
require.NoError(t, err)
require.Equal(t, origIssuerConfig, issuerConfig)
}

func Test_IssuerRoundTrip(t *testing.T) {
b, s := createBackendWithStorage(t)
issuer1, key1 := genIssuerAndKey(t, b)
issuer2, key2 := genIssuerAndKey(t, b)

// We get an error when issuer id not found
_, err := fetchIssuerById(ctx, s, issuer1.ID)
require.Error(t, err)

// We get an error when key id not found
_, err = fetchKeyById(ctx, s, key1.ID)
require.Error(t, err)

// Now write out our issuers and keys
err = writeKey(ctx, s, key1)
require.NoError(t, err)
err = writeIssuer(ctx, s, issuer1)
require.NoError(t, err)

err = writeKey(ctx, s, key2)
require.NoError(t, err)
err = writeIssuer(ctx, s, issuer2)
require.NoError(t, err)

fetchedKey1, err := fetchKeyById(ctx, s, key1.ID)
require.NoError(t, err)

fetchedIssuer1, err := fetchIssuerById(ctx, s, issuer1.ID)
require.NoError(t, err)

require.Equal(t, &key1, fetchedKey1)
require.Equal(t, &issuer1, fetchedIssuer1)

keys, err := listKeys(ctx, s)
require.NoError(t, err)

require.ElementsMatch(t, []keyId{key1.ID, key2.ID}, keys)

issuers, err := listIssuers(ctx, s)
require.NoError(t, err)

require.ElementsMatch(t, []issuerId{issuer1.ID, issuer2.ID}, issuers)
}

func genIssuerAndKey(t *testing.T, b *backend) (issuer, key) {
certBundle, err := genCertBundle(t, b)
require.NoError(t, err)

keyId := genKeyId(t)

pkiKey := key{
ID: keyId,
PrivateKeyType: certBundle.PrivateKeyType,
PrivateKey: certBundle.PrivateKey,
}

issuerId := genIssuerId(t)

pkiIssuer := issuer{
ID: issuerId,
KeyID: keyId,
Certificate: certBundle.Certificate,
CAChain: certBundle.CAChain,
SerialNumber: certBundle.SerialNumber,
}

return pkiIssuer, pkiKey
}

func genIssuerId(t *testing.T) issuerId {
issuerIdStr, err := uuid.GenerateUUID()
require.NoError(t, err)
return issuerId(issuerIdStr)
}

func genKeyId(t *testing.T) keyId {
keyIdStr, err := uuid.GenerateUUID()
require.NoError(t, err)
return keyId(keyIdStr)
}

func genCertBundle(t *testing.T, b *backend) (*certutil.CertBundle, error) {
// Pretty gross just to generate a cert bundle, but
fields := addCACommonFields(map[string]*framework.FieldSchema{})
fields = addCAKeyGenerationFields(fields)
fields = addCAIssueFields(fields)
apiData := &framework.FieldData{
Schema: fields,
Raw: map[string]interface{}{
"exported": "internal",
"cn": "example.com",
"ttl": 3600,
},
}
_, _, role, respErr := b.getGenerationParams(ctx, apiData, "/pki")
require.Nil(t, respErr)

input := &inputBundle{
req: &logical.Request{
Operation: logical.UpdateOperation,
Path: "issue/testrole",
Storage: b.storage,
},
apiData: apiData,
role: role,
}
parsedCertBundle, err := generateCert(ctx, b, input, nil, true, rand.Reader)

require.NoError(t, err)
certBundle, err := parsedCertBundle.ToCertBundle()
require.NoError(t, err)
return certBundle, err
}
6 changes: 3 additions & 3 deletions builtin/logical/pki/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func kmsRequested(input *inputBundle) bool {
return exportedStr.(string) == "kms"
}

type keyId interface {
type managedKeyId interface {
String() string
}

Expand All @@ -49,13 +49,13 @@ func (n NameKey) String() string {

// getManagedKeyId returns a NameKey or a UUIDKey, whichever was specified in the
// request API data.
func getManagedKeyId(data *framework.FieldData) (keyId, error) {
func getManagedKeyId(data *framework.FieldData) (managedKeyId, error) {
name, UUID, err := getManagedKeyNameOrUUID(data)
if err != nil {
return nil, err
}

var keyId keyId = NameKey(name)
var keyId managedKeyId = NameKey(name)
if len(UUID) > 0 {
keyId = UUIDKey(UUID)
}
Expand Down

0 comments on commit 7a19beb

Please sign in to comment.