Skip to content

Commit

Permalink
Add support for different types of private keys for the account
Browse files Browse the repository at this point in the history
This change the data associated with the account so all accounts will
need to be destroyed and recreated before requesting a new certificate:

    $ vault delete acme/accounts/lenstra
    $ vault write acme/accounts/lenstra contact=remi@lenstra.fr ...

Closes #12
  • Loading branch information
remilapeyre committed May 18, 2020
1 parent e116a1f commit e4fea55
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 86 deletions.
14 changes: 8 additions & 6 deletions acme/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package acme
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/x509"
"encoding/pem"

Expand All @@ -15,7 +14,8 @@ import (
type account struct {
Email string
Registration *registration.Resource
Key *ecdsa.PrivateKey
Key crypto.PrivateKey
KeyType string
ServerURL string
Provider string
EnableHTTP01 bool
Expand Down Expand Up @@ -59,14 +59,15 @@ func getAccount(ctx context.Context, storage logical.Storage, path string) (*acc
}

block, _ := pem.Decode([]byte(d["private_key"].(string)))
privateKey, err := x509.ParseECPrivateKey(block.Bytes)
privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, err
}

return &account{
Email: d["contact"].(string),
Key: privateKey,
Email: d["contact"].(string),
Key: privateKey,
KeyType: d["key_type"].(string),
Registration: &registration.Resource{
URI: d["registration_uri"].(string),
},
Expand All @@ -79,7 +80,7 @@ func getAccount(ctx context.Context, storage logical.Storage, path string) (*acc
}

func (a *account) save(ctx context.Context, storage logical.Storage, path string, serverURL string) error {
x509Encoded, err := x509.MarshalECPrivateKey(a.Key)
x509Encoded, err := x509.MarshalPKCS8PrivateKey(a.Key)
if err != nil {
return err
}
Expand All @@ -91,6 +92,7 @@ func (a *account) save(ctx context.Context, storage logical.Storage, path string
"contact": a.GetEmail(),
"terms_of_service_agreed": a.TermsOfServiceAgreed,
"private_key": string(pemEncoded),
"key_type": a.KeyType,
"provider": a.Provider,
"enable_http_01": a.EnableHTTP01,
"enable_tls_alpn_01": a.EnableTLSALPN01,
Expand Down
80 changes: 4 additions & 76 deletions acme/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ import (
"net"
"net/http"
"os"
"reflect"
"strings"
"testing"
"time"

"github.com/hashicorp/vault/sdk/logical"
"github.com/remilapeyre/vault-acme/acme/sidecar"
"github.com/stretchr/testify/require"
)

var serverURL string
Expand Down Expand Up @@ -277,72 +277,6 @@ func TestTLSALPN01Challenge(t *testing.T) {
makeRequest(t, b, req, "")
}

func TestAccounts(t *testing.T) {
config, b := getTestConfig(t)

// Create account
req := &logical.Request{
Operation: logical.CreateOperation,
Path: "accounts/lenstra",
Storage: config.StorageView,
Data: map[string]interface{}{
"server_url": serverURL,
"contact": "remi@lenstra.fr",
"terms_of_service_agreed": true,
"provider": "exec",
},
}
resp := makeRequest(t, b, req, "")

account := map[string]interface{}{}
for k, v := range resp.Data {
account[k] = v
}
delete(resp.Data, "registration_uri")

expected := map[string]interface{}{
"contact": "remi@lenstra.fr",
"server_url": serverURL,
"terms_of_service_agreed": true,
"provider": "exec",
"enable_http_01": false,
"enable_tls_alpn_01": false,
}
assertEqual(t, expected, resp.Data)

// Read account
req = &logical.Request{
Operation: logical.ReadOperation,
Path: "accounts/lenstra",
Storage: config.StorageView,
}
resp = makeRequest(t, b, req, "")
assertEqual(t, account, resp.Data)

// Read unknown account
req = &logical.Request{
Operation: logical.ReadOperation,
Path: "accounts/foobar",
Storage: config.StorageView,
}
resp = makeRequest(t, b, req, "This account does not exists")

// Delete account
req = &logical.Request{
Operation: logical.DeleteOperation,
Path: "accounts/lenstra",
Storage: config.StorageView,
}
makeRequest(t, b, req, "")

req = &logical.Request{
Operation: logical.DeleteOperation,
Path: "accounts/foobar",
Storage: config.StorageView,
}
makeRequest(t, b, req, "This account does not exists")
}

func TestRoles(t *testing.T) {
config, b := getTestConfig(t)
createAccount(t, b, config.StorageView)
Expand Down Expand Up @@ -378,7 +312,7 @@ func TestRoles(t *testing.T) {
Data: tcase.RequestData,
}
resp := makeRequest(t, b, req, "")
assertEqual(t, tcase.ExpectedResponse, resp.Data)
require.Equal(t, tcase.ExpectedResponse, resp.Data)
}

req := &logical.Request{
Expand All @@ -400,7 +334,7 @@ func TestRoles(t *testing.T) {
Storage: config.StorageView,
}
resp := makeRequest(t, b, req, "")
assertEqual(
require.Equal(
t,
resp.Data,
map[string]interface{}{
Expand Down Expand Up @@ -470,7 +404,7 @@ func checkCreatingCerts(t *testing.T, b logical.Backend, storage logical.Storage

// Since caching is enabled, we should get the same cert when calling the
// endoint twice
assertEqual(t, first.Data, second.Data)
require.Equal(t, first.Data, second.Data)

return first, second
}
Expand Down Expand Up @@ -626,12 +560,6 @@ func makeRequest(t *testing.T, b logical.Backend, req *logical.Request, expected
return resp
}

func assertEqual(t *testing.T, expected, data map[string]interface{}) {
if !reflect.DeepEqual(expected, data) {
t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, data)
}
}

type testCase struct {
Path string
RequestData map[string]interface{}
Expand Down
44 changes: 40 additions & 4 deletions acme/path_accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,23 @@ package acme

import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"fmt"

"github.com/go-acme/lego/v3/certcrypto"
"github.com/go-acme/lego/v3/registration"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)

var keyTypes = []interface{}{
"EC256",
"EC384",
"RSA2048",
"RSA4096",
"RSA8192",
}

func pathAccounts(b *backend) *framework.Path {
return &framework.Path{
Pattern: "accounts/" + framework.GenericNameRegex("account"),
Expand All @@ -29,6 +36,11 @@ func pathAccounts(b *backend) *framework.Path {
Type: framework.TypeBool,
Default: false,
},
"key_type": &framework.FieldSchema{
Type: framework.TypeString,
Default: "EC256",
AllowedValues: keyTypes,
},
// TODO(remi): We should have a list of those so we can request certs
// for domains registred to different providers
"provider": &framework.FieldSchema{
Expand Down Expand Up @@ -57,6 +69,23 @@ func pathAccounts(b *backend) *framework.Path {
}
}

func getKeyType(t string) (certcrypto.KeyType, error) {
switch t {
case "EC256":
return certcrypto.EC256, nil
case "EC384":
return certcrypto.EC384, nil
case "RSA2048":
return certcrypto.RSA2048, nil
case "RSA4096":
return certcrypto.RSA4096, nil
case "RSA8192":
return certcrypto.RSA8192, nil
default:
return certcrypto.KeyType(""), fmt.Errorf("%q is not a supported key type", t)
}
}

func (b *backend) accountCreate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
if err := data.Validate(); err != nil {
return nil, err
Expand All @@ -68,15 +97,21 @@ func (b *backend) accountCreate(ctx context.Context, req *logical.Request, data
enableHTTP01 := data.Get("enable_http_01").(bool)
enableTLSALPN01 := data.Get("enable_tls_alpn_01").(bool)

keyType, err := getKeyType(data.Get("key_type").(string))
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}

b.Logger().Info("Generating key pair for new account")
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
privateKey, err := certcrypto.GeneratePrivateKey(keyType)
if err != nil {
return nil, errwrap.Wrapf("Failed to generate account key pair: {{err}}", err)
}

user := account{
Email: contact,
Key: privateKey,
KeyType: data.Get("key_type").(string),
ServerURL: serverURL,
Provider: provider,
EnableHTTP01: enableHTTP01,
Expand Down Expand Up @@ -125,6 +160,7 @@ func (b *backend) accountRead(ctx context.Context, req *logical.Request, data *f
"registration_uri": a.Registration.URI,
"contact": a.GetEmail(),
"terms_of_service_agreed": a.TermsOfServiceAgreed,
"key_type": a.KeyType,
"provider": a.Provider,
"enable_http_01": a.EnableHTTP01,
"enable_tls_alpn_01": a.EnableTLSALPN01,
Expand Down
123 changes: 123 additions & 0 deletions acme/path_accounts_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package acme

import (
"testing"

"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)

func TestAccounts(t *testing.T) {
config, b := getTestConfig(t)

data := map[string]interface{}{
"server_url": serverURL,
"contact": "remi@lenstra.fr",
"terms_of_service_agreed": true,
"provider": "exec",
}
expected := map[string]interface{}{
"contact": "remi@lenstra.fr",
"server_url": serverURL,
"terms_of_service_agreed": true,
"provider": "exec",
"key_type": "EC256",
"enable_http_01": false,
"enable_tls_alpn_01": false,
}

testCases := []struct {
keyTypeIn string
keyTypeOut string
}{
{
"",
"EC256",
},
{
"EC256",
"EC256",
},
{
"EC384",
"EC384",
},
{
"RSA2048",
"RSA2048",
},
{
"RSA4096",
"RSA4096",
},
{
"RSA8192",
"RSA8192",
},
}

for _, tc := range testCases {
if tc.keyTypeIn == "" {
delete(data, "key_type")
} else {
data["key_type"] = tc.keyTypeIn
}
expected["key_type"] = tc.keyTypeOut

// Create account
req := &logical.Request{
Operation: logical.CreateOperation,
Path: "accounts/lenstra",
Storage: config.StorageView,
Data: data,
}
resp := makeRequest(t, b, req, "")

delete(resp.Data, "registration_uri")
require.Equal(t, expected, resp.Data)

// Read account
req.Operation = logical.ReadOperation
resp = makeRequest(t, b, req, "")
delete(resp.Data, "registration_uri")
require.Equal(t, expected, resp.Data)

// Delete account
req.Operation = logical.ReadOperation
makeRequest(t, b, req, "")
}

// Unsupported key type
data["key_type"] = "foo"
req := &logical.Request{
Operation: logical.CreateOperation,
Path: "accounts/lenstra",
Storage: config.StorageView,
Data: data,
}
makeRequest(t, b, req, `"foo" is not a supported key type`)
}

func TestDeleteAccount(t *testing.T) {
config, b := getTestConfig(t)

req := &logical.Request{
Operation: logical.CreateOperation,
Path: "accounts/lenstra",
Storage: config.StorageView,
Data: map[string]interface{}{
"server_url": serverURL,
"contact": "remi@lenstra.fr",
"terms_of_service_agreed": true,
"provider": "exec",
},
}
makeRequest(t, b, req, "")

req.Operation = logical.DeleteOperation
makeRequest(t, b, req, "")
makeRequest(t, b, req, "This account does not exists")

req.Operation = logical.ReadOperation
makeRequest(t, b, req, "This account does not exists")
}
Loading

0 comments on commit e4fea55

Please sign in to comment.