-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Simple starting PKI storage api for CA rotation * Add key and issuer storage apis * Add listKeys and listIssuers storage implementations * Add simple keys and issuers configuration storage api methods
- Loading branch information
1 parent
e23ff1b
commit 7a19beb
Showing
4 changed files
with
355 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
package pki | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
|
||
"github.com/hashicorp/vault/sdk/helper/certutil" | ||
"github.com/hashicorp/vault/sdk/helper/errutil" | ||
"github.com/hashicorp/vault/sdk/logical" | ||
) | ||
|
||
const ( | ||
storageKeyConfig = "/config/keys" | ||
storeageIssuerConfig = "/config/issuers" | ||
keyPrefix = "/config/key/" | ||
issuerPrefix = "/config/issuer/" | ||
) | ||
|
||
type keyId string | ||
|
||
func (p keyId) String() string { | ||
return string(p) | ||
} | ||
|
||
type issuerId string | ||
|
||
func (p issuerId) String() string { | ||
return string(p) | ||
} | ||
|
||
type key struct { | ||
ID keyId `json:"id" structs:"id" mapstructure:"id"` | ||
PrivateKeyType certutil.PrivateKeyType `json:"private_key_type" structs:"private_key_type" mapstructure:"private_key_type"` | ||
PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"` | ||
} | ||
|
||
type issuer struct { | ||
ID issuerId `json:"id" structs:"id" mapstructure:"id"` | ||
Name string `json:"name" structs:"name" mapstructure:"name"` | ||
KeyID keyId `json:"key_id" structs:"key_id" mapstructure:"key_id"` | ||
Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"` | ||
CAChain []string `json:"ca_chain" structs:"ca_chain" mapstructure:"ca_chain"` | ||
SerialNumber string `json:"serial_number" structs:"serial_number" mapstructure:"serial_number"` | ||
} | ||
|
||
type keyConfig struct { | ||
DefaultKeyId keyId `json:"default" structs:"default" mapstructure:"default"` | ||
} | ||
|
||
type issuerConfig struct { | ||
DefaultIssuerId issuerId `json:"default" structs:"default" mapstructure:"default"` | ||
} | ||
|
||
func listKeys(ctx context.Context, s logical.Storage) ([]keyId, error) { | ||
strList, err := s.List(ctx, keyPrefix) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
keyIds := make([]keyId, 0, len(strList)) | ||
for _, entry := range strList { | ||
keyIds = append(keyIds, keyId(entry)) | ||
} | ||
|
||
return keyIds, nil | ||
} | ||
|
||
func fetchKeyById(ctx context.Context, s logical.Storage, keyId keyId) (*key, error) { | ||
keyEntry, err := s.Get(ctx, keyPrefix+keyId.String()) | ||
if err != nil { | ||
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to fetch pki key: %v", err)} | ||
} | ||
if keyEntry == nil { | ||
// FIXME: Dedicated/specific error for this? | ||
return nil, errutil.UserError{Err: fmt.Sprintf("pki key id %s does not exist", keyId.String())} | ||
} | ||
|
||
var key key | ||
if err := keyEntry.DecodeJSON(&key); err != nil { | ||
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to decode pki key with id %s: %v", keyId.String(), err)} | ||
} | ||
|
||
return &key, nil | ||
} | ||
|
||
func writeKey(ctx context.Context, s logical.Storage, key key) error { | ||
keyId := key.ID | ||
|
||
json, err := logical.StorageEntryJSON(keyPrefix+keyId.String(), key) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
return s.Put(ctx, json) | ||
} | ||
|
||
func listIssuers(ctx context.Context, s logical.Storage) ([]issuerId, error) { | ||
strList, err := s.List(ctx, issuerPrefix) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
issuerIds := make([]issuerId, 0, len(strList)) | ||
for _, entry := range strList { | ||
issuerIds = append(issuerIds, issuerId(entry)) | ||
} | ||
|
||
return issuerIds, nil | ||
} | ||
|
||
func fetchIssuerById(ctx context.Context, s logical.Storage, issuerId issuerId) (*issuer, error) { | ||
issuerEntry, err := s.Get(ctx, issuerPrefix+issuerId.String()) | ||
if err != nil { | ||
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to fetch pki issuer: %v", err)} | ||
} | ||
if issuerEntry == nil { | ||
// FIXME: Dedicated/specific error for this? | ||
return nil, errutil.UserError{Err: fmt.Sprintf("pki issuer id %s does not exist", issuerId.String())} | ||
} | ||
|
||
var issuer issuer | ||
if err := issuerEntry.DecodeJSON(&issuer); err != nil { | ||
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to decode pki issuer with id %s: %v", issuerId.String(), err)} | ||
} | ||
|
||
return &issuer, nil | ||
} | ||
|
||
func writeIssuer(ctx context.Context, s logical.Storage, issuer issuer) error { | ||
issuerId := issuer.ID | ||
|
||
json, err := logical.StorageEntryJSON(issuerPrefix+issuerId.String(), issuer) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
return s.Put(ctx, json) | ||
} | ||
|
||
func setKeysConfig(ctx context.Context, s logical.Storage, config *keyConfig) error { | ||
json, err := logical.StorageEntryJSON(storageKeyConfig, config) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
return s.Put(ctx, json) | ||
} | ||
|
||
func getKeysConfig(ctx context.Context, s logical.Storage) (*keyConfig, error) { | ||
keyConfigEntry, err := s.Get(ctx, storageKeyConfig) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
keyConfig := &keyConfig{} | ||
if keyConfigEntry != nil { | ||
if err := keyConfigEntry.DecodeJSON(keyConfig); err != nil { | ||
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to decode key configuration: %v", err)} | ||
} | ||
} | ||
|
||
return keyConfig, nil | ||
} | ||
|
||
func setIssuersConfig(ctx context.Context, s logical.Storage, config *issuerConfig) error { | ||
json, err := logical.StorageEntryJSON(storeageIssuerConfig, config) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
return s.Put(ctx, json) | ||
} | ||
|
||
func getIssuersConfig(ctx context.Context, s logical.Storage) (*issuerConfig, error) { | ||
issuerConfigEntry, err := s.Get(ctx, storeageIssuerConfig) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
issuerConfig := &issuerConfig{} | ||
if issuerConfigEntry != nil { | ||
if err := issuerConfigEntry.DecodeJSON(issuerConfig); err != nil { | ||
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to decode issuer configuration: %v", err)} | ||
} | ||
} | ||
|
||
return issuerConfig, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
package pki | ||
|
||
import ( | ||
"context" | ||
"crypto/rand" | ||
"testing" | ||
|
||
"github.com/hashicorp/go-uuid" | ||
"github.com/hashicorp/vault/sdk/framework" | ||
"github.com/hashicorp/vault/sdk/helper/certutil" | ||
"github.com/hashicorp/vault/sdk/logical" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
var ctx = context.Background() | ||
|
||
func Test_ConfigsRoundTrip(t *testing.T) { | ||
_, s := createBackendWithStorage(t) | ||
|
||
// Verify we handle nothing stored properly | ||
keyConfigEmpty, err := getKeysConfig(ctx, s) | ||
require.NoError(t, err) | ||
require.Equal(t, &keyConfig{}, keyConfigEmpty) | ||
|
||
issuerConfigEmpty, err := getIssuersConfig(ctx, s) | ||
require.NoError(t, err) | ||
require.Equal(t, &issuerConfig{}, issuerConfigEmpty) | ||
|
||
// Now attempt to store and reload properly | ||
origKeyConfig := &keyConfig{ | ||
DefaultKeyId: genKeyId(t), | ||
} | ||
origIssuerConfig := &issuerConfig{ | ||
DefaultIssuerId: genIssuerId(t), | ||
} | ||
|
||
err = setKeysConfig(ctx, s, origKeyConfig) | ||
require.NoError(t, err) | ||
err = setIssuersConfig(ctx, s, origIssuerConfig) | ||
require.NoError(t, err) | ||
|
||
keyConfig, err := getKeysConfig(ctx, s) | ||
require.NoError(t, err) | ||
require.Equal(t, origKeyConfig, keyConfig) | ||
|
||
issuerConfig, err := getIssuersConfig(ctx, s) | ||
require.NoError(t, err) | ||
require.Equal(t, origIssuerConfig, issuerConfig) | ||
} | ||
|
||
func Test_IssuerRoundTrip(t *testing.T) { | ||
b, s := createBackendWithStorage(t) | ||
issuer1, key1 := genIssuerAndKey(t, b) | ||
issuer2, key2 := genIssuerAndKey(t, b) | ||
|
||
// We get an error when issuer id not found | ||
_, err := fetchIssuerById(ctx, s, issuer1.ID) | ||
require.Error(t, err) | ||
|
||
// We get an error when key id not found | ||
_, err = fetchKeyById(ctx, s, key1.ID) | ||
require.Error(t, err) | ||
|
||
// Now write out our issuers and keys | ||
err = writeKey(ctx, s, key1) | ||
require.NoError(t, err) | ||
err = writeIssuer(ctx, s, issuer1) | ||
require.NoError(t, err) | ||
|
||
err = writeKey(ctx, s, key2) | ||
require.NoError(t, err) | ||
err = writeIssuer(ctx, s, issuer2) | ||
require.NoError(t, err) | ||
|
||
fetchedKey1, err := fetchKeyById(ctx, s, key1.ID) | ||
require.NoError(t, err) | ||
|
||
fetchedIssuer1, err := fetchIssuerById(ctx, s, issuer1.ID) | ||
require.NoError(t, err) | ||
|
||
require.Equal(t, &key1, fetchedKey1) | ||
require.Equal(t, &issuer1, fetchedIssuer1) | ||
|
||
keys, err := listKeys(ctx, s) | ||
require.NoError(t, err) | ||
|
||
require.ElementsMatch(t, []keyId{key1.ID, key2.ID}, keys) | ||
|
||
issuers, err := listIssuers(ctx, s) | ||
require.NoError(t, err) | ||
|
||
require.ElementsMatch(t, []issuerId{issuer1.ID, issuer2.ID}, issuers) | ||
} | ||
|
||
func genIssuerAndKey(t *testing.T, b *backend) (issuer, key) { | ||
certBundle, err := genCertBundle(t, b) | ||
require.NoError(t, err) | ||
|
||
keyId := genKeyId(t) | ||
|
||
pkiKey := key{ | ||
ID: keyId, | ||
PrivateKeyType: certBundle.PrivateKeyType, | ||
PrivateKey: certBundle.PrivateKey, | ||
} | ||
|
||
issuerId := genIssuerId(t) | ||
|
||
pkiIssuer := issuer{ | ||
ID: issuerId, | ||
KeyID: keyId, | ||
Certificate: certBundle.Certificate, | ||
CAChain: certBundle.CAChain, | ||
SerialNumber: certBundle.SerialNumber, | ||
} | ||
|
||
return pkiIssuer, pkiKey | ||
} | ||
|
||
func genIssuerId(t *testing.T) issuerId { | ||
issuerIdStr, err := uuid.GenerateUUID() | ||
require.NoError(t, err) | ||
return issuerId(issuerIdStr) | ||
} | ||
|
||
func genKeyId(t *testing.T) keyId { | ||
keyIdStr, err := uuid.GenerateUUID() | ||
require.NoError(t, err) | ||
return keyId(keyIdStr) | ||
} | ||
|
||
func genCertBundle(t *testing.T, b *backend) (*certutil.CertBundle, error) { | ||
// Pretty gross just to generate a cert bundle, but | ||
fields := addCACommonFields(map[string]*framework.FieldSchema{}) | ||
fields = addCAKeyGenerationFields(fields) | ||
fields = addCAIssueFields(fields) | ||
apiData := &framework.FieldData{ | ||
Schema: fields, | ||
Raw: map[string]interface{}{ | ||
"exported": "internal", | ||
"cn": "example.com", | ||
"ttl": 3600, | ||
}, | ||
} | ||
_, _, role, respErr := b.getGenerationParams(ctx, apiData, "/pki") | ||
require.Nil(t, respErr) | ||
|
||
input := &inputBundle{ | ||
req: &logical.Request{ | ||
Operation: logical.UpdateOperation, | ||
Path: "issue/testrole", | ||
Storage: b.storage, | ||
}, | ||
apiData: apiData, | ||
role: role, | ||
} | ||
parsedCertBundle, err := generateCert(ctx, b, input, nil, true, rand.Reader) | ||
|
||
require.NoError(t, err) | ||
certBundle, err := parsedCertBundle.ToCertBundle() | ||
require.NoError(t, err) | ||
return certBundle, err | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters