Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use the email configuration in the ACME issuer to "pin" an account to a key #283

Merged
merged 3 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions account.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,18 @@ func (*ACMEIssuer) newAccount(email string) (acme.Account, error) {
// If it does not exist in storage, it will be retrieved from the ACME server and added to storage.
// The account must already exist; it does not create a new account.
func (am *ACMEIssuer) GetAccount(ctx context.Context, privateKeyPEM []byte) (acme.Account, error) {
account, err := am.loadAccountByKey(ctx, privateKeyPEM)
if errors.Is(err, fs.ErrNotExist) {
account, err = am.lookUpAccount(ctx, privateKeyPEM)
email := am.getEmail()
if email == "" {
if account, err := am.loadAccountByKey(ctx, privateKeyPEM); err == nil {
return account, nil
}
} else {
keyBytes, err := am.config.Storage.Load(ctx, am.storageKeyUserPrivateKey(am.CA, email))
if err == nil && bytes.Equal(bytes.TrimSpace(keyBytes), bytes.TrimSpace(privateKeyPEM)) {
return am.loadAccount(ctx, am.CA, email)
}
}
return account, err
return am.lookUpAccount(ctx, privateKeyPEM)
}

// loadAccountByKey loads the account with the given private key from storage, if it exists.
Expand All @@ -107,9 +114,14 @@ func (am *ACMEIssuer) loadAccountByKey(ctx context.Context, privateKeyPEM []byte
email := path.Base(accountFolderKey)
keyBytes, err := am.config.Storage.Load(ctx, am.storageKeyUserPrivateKey(am.CA, email))
if err != nil {
return acme.Account{}, err
// Try the next account: This one is missing its private key, if it turns out to be the one we're looking
// for we will try to save it again after confirming with the ACME server.
continue
}
if bytes.Equal(bytes.TrimSpace(keyBytes), bytes.TrimSpace(privateKeyPEM)) {
// Found the account with the correct private key, try loading it. If this fails we we will follow
// the same procedure as if the private key was not found and confirm with the ACME server before saving
// it again.
return am.loadAccount(ctx, am.CA, email)
}
}
Expand Down
236 changes: 236 additions & 0 deletions account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package certmagic
import (
"bytes"
"context"
"io/fs"
"os"
"path/filepath"
"reflect"
Expand All @@ -26,6 +27,131 @@ import (
"time"
)

// memoryStorage is an in-memory storage implementation with known contents *and* fixed iteration order for List.
type memoryStorage struct {
contents []memoryStorageItem
}
Comment on lines +30 to +33
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW: Apart from needing this for controlling the iteration order of List, this might also be faster in general than using an actual file system for unit tests. And, there's good references that say that I/O in unit tests should be banned. :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved 💯


type memoryStorageItem struct {
key string
data []byte
}

func (m *memoryStorage) lookup(_ context.Context, key string) *memoryStorageItem {
for _, item := range m.contents {
if item.key == key {
return &item
}
}
return nil
}
func (m *memoryStorage) Delete(ctx context.Context, key string) error {
for i, item := range m.contents {
if item.key == key {
m.contents = append(m.contents[:i], m.contents[i+1:]...)
return nil
}
}
return fs.ErrNotExist
}
func (m *memoryStorage) Store(ctx context.Context, key string, value []byte) error {
m.contents = append(m.contents, memoryStorageItem{key: key, data: value})
return nil
}
func (m *memoryStorage) Exists(ctx context.Context, key string) bool {
return m.lookup(ctx, key) != nil
}
func (m *memoryStorage) List(ctx context.Context, path string, recursive bool) ([]string, error) {
if recursive {
panic("unimplemented")
}

result := []string{}
nextitem:
for _, item := range m.contents {
if !strings.HasPrefix(item.key, path+"/") {
continue
}
name := strings.TrimPrefix(item.key, path+"/")
if i := strings.Index(name, "/"); i >= 0 {
name = name[:i]
}

for _, existing := range result {
if existing == name {
continue nextitem
}
}
result = append(result, name)
}
return result, nil
}
func (m *memoryStorage) Load(ctx context.Context, key string) ([]byte, error) {
if item := m.lookup(ctx, key); item != nil {
return item.data, nil
}
return nil, fs.ErrNotExist
}
func (m *memoryStorage) Stat(ctx context.Context, key string) (KeyInfo, error) {
if item := m.lookup(ctx, key); item != nil {
return KeyInfo{Key: key, Size: int64(len(item.data))}, nil
}
return KeyInfo{}, fs.ErrNotExist
}
func (m *memoryStorage) Lock(ctx context.Context, name string) error { panic("unimplemented") }
func (m *memoryStorage) Unlock(ctx context.Context, name string) error { panic("unimplemented") }

var _ Storage = (*memoryStorage)(nil)

type recordingStorage struct {
Storage
calls []recordedCall
}

func (r *recordingStorage) Delete(ctx context.Context, key string) error {
r.record("Delete", key)
return r.Storage.Delete(ctx, key)
}
func (r *recordingStorage) Exists(ctx context.Context, key string) bool {
r.record("Exists", key)
return r.Storage.Exists(ctx, key)
}
func (r *recordingStorage) List(ctx context.Context, path string, recursive bool) ([]string, error) {
r.record("List", path, recursive)
return r.Storage.List(ctx, path, recursive)
}
func (r *recordingStorage) Load(ctx context.Context, key string) ([]byte, error) {
r.record("Load", key)
return r.Storage.Load(ctx, key)
}
func (r *recordingStorage) Lock(ctx context.Context, name string) error {
r.record("Lock", name)
return r.Storage.Lock(ctx, name)
}
func (r *recordingStorage) Stat(ctx context.Context, key string) (KeyInfo, error) {
r.record("Stat", key)
return r.Storage.Stat(ctx, key)
}
func (r *recordingStorage) Store(ctx context.Context, key string, value []byte) error {
r.record("Store", key)
return r.Storage.Store(ctx, key, value)
}
func (r *recordingStorage) Unlock(ctx context.Context, name string) error {
r.record("Unlock", name)
return r.Storage.Unlock(ctx, name)
}

type recordedCall struct {
name string
args []interface{}
}

func (r *recordingStorage) record(name string, args ...interface{}) {
r.calls = append(r.calls, recordedCall{name: name, args: args})
}

var _ Storage = (*recordingStorage)(nil)

func TestNewAccount(t *testing.T) {
am := &ACMEIssuer{CA: dummyCA, mu: new(sync.Mutex)}
testConfig := &Config{
Expand Down Expand Up @@ -159,6 +285,116 @@ func TestGetAccountAlreadyExists(t *testing.T) {
}
}

func TestGetAccountAlreadyExistsSkipsBroken(t *testing.T) {
ctx := context.Background()

am := &ACMEIssuer{CA: dummyCA, mu: new(sync.Mutex)}
testConfig := &Config{
Issuers: []Issuer{am},
Storage: &memoryStorage{},
Logger: defaultTestLogger,
certCache: new(Cache),
}
am.config = testConfig

email := "me@foobar.com"

// Create a "corrupted" account
am.config.Storage.Store(ctx, am.storageKeyUserReg(am.CA, "notmeatall@foobar.com"), []byte("this is not a valid account"))

// Create the actual account
account, err := am.newAccount(email)
if err != nil {
t.Fatalf("Error creating account: %v", err)
}
err = am.saveAccount(ctx, am.CA, account)
if err != nil {
t.Fatalf("Error saving account: %v", err)
}

// Expect to load account from disk
keyBytes, err := PEMEncodePrivateKey(account.PrivateKey)
if err != nil {
t.Fatalf("Error encoding private key: %v", err)
}

loadedAccount, err := am.GetAccount(ctx, keyBytes)
if err != nil {
t.Fatalf("Error getting account: %v", err)
}

// Assert keys are the same
if !privateKeysSame(account.PrivateKey, loadedAccount.PrivateKey) {
t.Error("Expected private key to be the same after loading, but it wasn't")
}

// Assert emails are the same
if !reflect.DeepEqual(account.Contact, loadedAccount.Contact) {
t.Errorf("Expected contacts to be equal, but was '%s' before and '%s' after loading", account.Contact, loadedAccount.Contact)
}
}

func TestGetAccountWithEmailAlreadyExists(t *testing.T) {
ctx := context.Background()

am := &ACMEIssuer{CA: dummyCA, mu: new(sync.Mutex)}
testConfig := &Config{
Issuers: []Issuer{am},
Storage: &recordingStorage{Storage: &memoryStorage{}},
Logger: defaultTestLogger,
certCache: new(Cache),
}
am.config = testConfig

email := "me@foobar.com"

// Set up test
account, err := am.newAccount(email)
if err != nil {
t.Fatalf("Error creating account: %v", err)
}
err = am.saveAccount(ctx, am.CA, account)
if err != nil {
t.Fatalf("Error saving account: %v", err)
}

// Set the expected email:
am.Email = email
err = am.setEmail(ctx, true)
if err != nil {
t.Fatalf("setEmail error: %v", err)
}

// Expect to load account from disk
keyBytes, err := PEMEncodePrivateKey(account.PrivateKey)
if err != nil {
t.Fatalf("Error encoding private key: %v", err)
}

loadedAccount, err := am.GetAccount(ctx, keyBytes)
if err != nil {
t.Fatalf("Error getting account: %v", err)
}

// Assert keys are the same
if !privateKeysSame(account.PrivateKey, loadedAccount.PrivateKey) {
t.Error("Expected private key to be the same after loading, but it wasn't")
}

// Assert emails are the same
if !reflect.DeepEqual(account.Contact, loadedAccount.Contact) {
t.Errorf("Expected contacts to be equal, but was '%s' before and '%s' after loading", account.Contact, loadedAccount.Contact)
}

// Assert that this was found without listing all accounts
rs := testConfig.Storage.(*recordingStorage)
for _, call := range rs.calls {
if call.name == "List" {
t.Error("Unexpected List call")
}
}
}

func TestGetEmailFromPackageDefault(t *testing.T) {
ctx := context.Background()

Expand Down
Loading