diff --git a/account.go b/account.go index 2853691c..f3b8d44d 100644 --- a/account.go +++ b/account.go @@ -88,11 +88,18 @@ func (*ACMEIssuer) newAccount(email string) (acme.Account, error) { // If it does not exist in storage, it will be retrieved from the ACME server and added to storage. // The account must already exist; it does not create a new account. func (am *ACMEIssuer) GetAccount(ctx context.Context, privateKeyPEM []byte) (acme.Account, error) { - account, err := am.loadAccountByKey(ctx, privateKeyPEM) - if errors.Is(err, fs.ErrNotExist) { - account, err = am.lookUpAccount(ctx, privateKeyPEM) + email := am.getEmail() + if email == "" { + if account, err := am.loadAccountByKey(ctx, privateKeyPEM); err == nil { + return account, nil + } + } else { + keyBytes, err := am.config.Storage.Load(ctx, am.storageKeyUserPrivateKey(am.CA, email)) + if err == nil && bytes.Equal(bytes.TrimSpace(keyBytes), bytes.TrimSpace(privateKeyPEM)) { + return am.loadAccount(ctx, am.CA, email) + } } - return account, err + return am.lookUpAccount(ctx, privateKeyPEM) } // loadAccountByKey loads the account with the given private key from storage, if it exists. @@ -107,9 +114,14 @@ func (am *ACMEIssuer) loadAccountByKey(ctx context.Context, privateKeyPEM []byte email := path.Base(accountFolderKey) keyBytes, err := am.config.Storage.Load(ctx, am.storageKeyUserPrivateKey(am.CA, email)) if err != nil { - return acme.Account{}, err + // Try the next account: This one is missing its private key, if it turns out to be the one we're looking + // for we will try to save it again after confirming with the ACME server. + continue } if bytes.Equal(bytes.TrimSpace(keyBytes), bytes.TrimSpace(privateKeyPEM)) { + // Found the account with the correct private key, try loading it. If this fails we we will follow + // the same procedure as if the private key was not found and confirm with the ACME server before saving + // it again. return am.loadAccount(ctx, am.CA, email) } } diff --git a/account_test.go b/account_test.go index 2fd72f93..24389c07 100644 --- a/account_test.go +++ b/account_test.go @@ -17,6 +17,7 @@ package certmagic import ( "bytes" "context" + "io/fs" "os" "path/filepath" "reflect" @@ -26,6 +27,131 @@ import ( "time" ) +// memoryStorage is an in-memory storage implementation with known contents *and* fixed iteration order for List. +type memoryStorage struct { + contents []memoryStorageItem +} + +type memoryStorageItem struct { + key string + data []byte +} + +func (m *memoryStorage) lookup(_ context.Context, key string) *memoryStorageItem { + for _, item := range m.contents { + if item.key == key { + return &item + } + } + return nil +} +func (m *memoryStorage) Delete(ctx context.Context, key string) error { + for i, item := range m.contents { + if item.key == key { + m.contents = append(m.contents[:i], m.contents[i+1:]...) + return nil + } + } + return fs.ErrNotExist +} +func (m *memoryStorage) Store(ctx context.Context, key string, value []byte) error { + m.contents = append(m.contents, memoryStorageItem{key: key, data: value}) + return nil +} +func (m *memoryStorage) Exists(ctx context.Context, key string) bool { + return m.lookup(ctx, key) != nil +} +func (m *memoryStorage) List(ctx context.Context, path string, recursive bool) ([]string, error) { + if recursive { + panic("unimplemented") + } + + result := []string{} +nextitem: + for _, item := range m.contents { + if !strings.HasPrefix(item.key, path+"/") { + continue + } + name := strings.TrimPrefix(item.key, path+"/") + if i := strings.Index(name, "/"); i >= 0 { + name = name[:i] + } + + for _, existing := range result { + if existing == name { + continue nextitem + } + } + result = append(result, name) + } + return result, nil +} +func (m *memoryStorage) Load(ctx context.Context, key string) ([]byte, error) { + if item := m.lookup(ctx, key); item != nil { + return item.data, nil + } + return nil, fs.ErrNotExist +} +func (m *memoryStorage) Stat(ctx context.Context, key string) (KeyInfo, error) { + if item := m.lookup(ctx, key); item != nil { + return KeyInfo{Key: key, Size: int64(len(item.data))}, nil + } + return KeyInfo{}, fs.ErrNotExist +} +func (m *memoryStorage) Lock(ctx context.Context, name string) error { panic("unimplemented") } +func (m *memoryStorage) Unlock(ctx context.Context, name string) error { panic("unimplemented") } + +var _ Storage = (*memoryStorage)(nil) + +type recordingStorage struct { + Storage + calls []recordedCall +} + +func (r *recordingStorage) Delete(ctx context.Context, key string) error { + r.record("Delete", key) + return r.Storage.Delete(ctx, key) +} +func (r *recordingStorage) Exists(ctx context.Context, key string) bool { + r.record("Exists", key) + return r.Storage.Exists(ctx, key) +} +func (r *recordingStorage) List(ctx context.Context, path string, recursive bool) ([]string, error) { + r.record("List", path, recursive) + return r.Storage.List(ctx, path, recursive) +} +func (r *recordingStorage) Load(ctx context.Context, key string) ([]byte, error) { + r.record("Load", key) + return r.Storage.Load(ctx, key) +} +func (r *recordingStorage) Lock(ctx context.Context, name string) error { + r.record("Lock", name) + return r.Storage.Lock(ctx, name) +} +func (r *recordingStorage) Stat(ctx context.Context, key string) (KeyInfo, error) { + r.record("Stat", key) + return r.Storage.Stat(ctx, key) +} +func (r *recordingStorage) Store(ctx context.Context, key string, value []byte) error { + r.record("Store", key) + return r.Storage.Store(ctx, key, value) +} +func (r *recordingStorage) Unlock(ctx context.Context, name string) error { + r.record("Unlock", name) + return r.Storage.Unlock(ctx, name) +} + +type recordedCall struct { + name string + args []interface{} +} + +func (r *recordingStorage) record(name string, args ...interface{}) { + r.calls = append(r.calls, recordedCall{name: name, args: args}) +} + +var _ Storage = (*recordingStorage)(nil) + func TestNewAccount(t *testing.T) { am := &ACMEIssuer{CA: dummyCA, mu: new(sync.Mutex)} testConfig := &Config{ @@ -159,6 +285,116 @@ func TestGetAccountAlreadyExists(t *testing.T) { } } +func TestGetAccountAlreadyExistsSkipsBroken(t *testing.T) { + ctx := context.Background() + + am := &ACMEIssuer{CA: dummyCA, mu: new(sync.Mutex)} + testConfig := &Config{ + Issuers: []Issuer{am}, + Storage: &memoryStorage{}, + Logger: defaultTestLogger, + certCache: new(Cache), + } + am.config = testConfig + + email := "me@foobar.com" + + // Create a "corrupted" account + am.config.Storage.Store(ctx, am.storageKeyUserReg(am.CA, "notmeatall@foobar.com"), []byte("this is not a valid account")) + + // Create the actual account + account, err := am.newAccount(email) + if err != nil { + t.Fatalf("Error creating account: %v", err) + } + err = am.saveAccount(ctx, am.CA, account) + if err != nil { + t.Fatalf("Error saving account: %v", err) + } + + // Expect to load account from disk + keyBytes, err := PEMEncodePrivateKey(account.PrivateKey) + if err != nil { + t.Fatalf("Error encoding private key: %v", err) + } + + loadedAccount, err := am.GetAccount(ctx, keyBytes) + if err != nil { + t.Fatalf("Error getting account: %v", err) + } + + // Assert keys are the same + if !privateKeysSame(account.PrivateKey, loadedAccount.PrivateKey) { + t.Error("Expected private key to be the same after loading, but it wasn't") + } + + // Assert emails are the same + if !reflect.DeepEqual(account.Contact, loadedAccount.Contact) { + t.Errorf("Expected contacts to be equal, but was '%s' before and '%s' after loading", account.Contact, loadedAccount.Contact) + } +} + +func TestGetAccountWithEmailAlreadyExists(t *testing.T) { + ctx := context.Background() + + am := &ACMEIssuer{CA: dummyCA, mu: new(sync.Mutex)} + testConfig := &Config{ + Issuers: []Issuer{am}, + Storage: &recordingStorage{Storage: &memoryStorage{}}, + Logger: defaultTestLogger, + certCache: new(Cache), + } + am.config = testConfig + + email := "me@foobar.com" + + // Set up test + account, err := am.newAccount(email) + if err != nil { + t.Fatalf("Error creating account: %v", err) + } + err = am.saveAccount(ctx, am.CA, account) + if err != nil { + t.Fatalf("Error saving account: %v", err) + } + + // Set the expected email: + am.Email = email + err = am.setEmail(ctx, true) + if err != nil { + t.Fatalf("setEmail error: %v", err) + } + + // Expect to load account from disk + keyBytes, err := PEMEncodePrivateKey(account.PrivateKey) + if err != nil { + t.Fatalf("Error encoding private key: %v", err) + } + + loadedAccount, err := am.GetAccount(ctx, keyBytes) + if err != nil { + t.Fatalf("Error getting account: %v", err) + } + + // Assert keys are the same + if !privateKeysSame(account.PrivateKey, loadedAccount.PrivateKey) { + t.Error("Expected private key to be the same after loading, but it wasn't") + } + + // Assert emails are the same + if !reflect.DeepEqual(account.Contact, loadedAccount.Contact) { + t.Errorf("Expected contacts to be equal, but was '%s' before and '%s' after loading", account.Contact, loadedAccount.Contact) + } + + // Assert that this was found without listing all accounts + rs := testConfig.Storage.(*recordingStorage) + for _, call := range rs.calls { + if call.name == "List" { + t.Error("Unexpected List call") + } + } +} + func TestGetEmailFromPackageDefault(t *testing.T) { ctx := context.Background()