diff --git a/account_test.go b/account_test.go index 45697ece..2504d1b0 100644 --- a/account_test.go +++ b/account_test.go @@ -29,17 +29,17 @@ import ( "go.uber.org/zap" ) -// memoryStorage is an in-memory storage implementation with known contents *and* fixed iteration order for List. -type memoryStorage struct { - contents []memoryStorageItem +// testingMemoryStorage is an in-memory storage implementation with known contents *and* fixed iteration order for List. +type testingMemoryStorage struct { + contents []testingMemoryStorageItem } -type memoryStorageItem struct { +type testingMemoryStorageItem struct { key string data []byte } -func (m *memoryStorage) lookup(_ context.Context, key string) *memoryStorageItem { +func (m *testingMemoryStorage) lookup(_ context.Context, key string) *testingMemoryStorageItem { for _, item := range m.contents { if item.key == key { return &item @@ -47,7 +47,7 @@ func (m *memoryStorage) lookup(_ context.Context, key string) *memoryStorageItem } return nil } -func (m *memoryStorage) Delete(ctx context.Context, key string) error { +func (m *testingMemoryStorage) Delete(ctx context.Context, key string) error { for i, item := range m.contents { if item.key == key { m.contents = append(m.contents[:i], m.contents[i+1:]...) @@ -56,14 +56,14 @@ func (m *memoryStorage) Delete(ctx context.Context, key string) error { } return fs.ErrNotExist } -func (m *memoryStorage) Store(ctx context.Context, key string, value []byte) error { - m.contents = append(m.contents, memoryStorageItem{key: key, data: value}) +func (m *testingMemoryStorage) Store(ctx context.Context, key string, value []byte) error { + m.contents = append(m.contents, testingMemoryStorageItem{key: key, data: value}) return nil } -func (m *memoryStorage) Exists(ctx context.Context, key string) bool { +func (m *testingMemoryStorage) Exists(ctx context.Context, key string) bool { return m.lookup(ctx, key) != nil } -func (m *memoryStorage) List(ctx context.Context, path string, recursive bool) ([]string, error) { +func (m *testingMemoryStorage) List(ctx context.Context, path string, recursive bool) ([]string, error) { if recursive { panic("unimplemented") } @@ -88,22 +88,22 @@ nextitem: } return result, nil } -func (m *memoryStorage) Load(ctx context.Context, key string) ([]byte, error) { +func (m *testingMemoryStorage) Load(ctx context.Context, key string) ([]byte, error) { if item := m.lookup(ctx, key); item != nil { return item.data, nil } return nil, fs.ErrNotExist } -func (m *memoryStorage) Stat(ctx context.Context, key string) (KeyInfo, error) { +func (m *testingMemoryStorage) Stat(ctx context.Context, key string) (KeyInfo, error) { if item := m.lookup(ctx, key); item != nil { return KeyInfo{Key: key, Size: int64(len(item.data))}, nil } return KeyInfo{}, fs.ErrNotExist } -func (m *memoryStorage) Lock(ctx context.Context, name string) error { panic("unimplemented") } -func (m *memoryStorage) Unlock(ctx context.Context, name string) error { panic("unimplemented") } +func (m *testingMemoryStorage) Lock(ctx context.Context, name string) error { panic("unimplemented") } +func (m *testingMemoryStorage) Unlock(ctx context.Context, name string) error { panic("unimplemented") } -var _ Storage = (*memoryStorage)(nil) +var _ Storage = (*testingMemoryStorage)(nil) type recordingStorage struct { Storage @@ -293,7 +293,7 @@ func TestGetAccountAlreadyExistsSkipsBroken(t *testing.T) { am := &ACMEIssuer{CA: dummyCA, Logger: zap.NewNop(), mu: new(sync.Mutex)} testConfig := &Config{ Issuers: []Issuer{am}, - Storage: &memoryStorage{}, + Storage: &testingMemoryStorage{}, Logger: defaultTestLogger, certCache: new(Cache), } @@ -342,7 +342,7 @@ func TestGetAccountWithEmailAlreadyExists(t *testing.T) { am := &ACMEIssuer{CA: dummyCA, Logger: zap.NewNop(), mu: new(sync.Mutex)} testConfig := &Config{ Issuers: []Issuer{am}, - Storage: &recordingStorage{Storage: &memoryStorage{}}, + Storage: &recordingStorage{Storage: &testingMemoryStorage{}}, Logger: defaultTestLogger, certCache: new(Cache), } diff --git a/memorystorage.go b/memorystorage.go new file mode 100644 index 00000000..1b04cc56 --- /dev/null +++ b/memorystorage.go @@ -0,0 +1,200 @@ +// Copyright 2015 Matthew Holt +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package certmagic + +import ( + "context" + "os" + "path" + "strings" + "sync" + "time" + + "golang.org/x/sync/semaphore" +) + +type storageEntry struct { + i KeyInfo + d []byte +} + +// memoryStorage is a Storage implemention that exists only in memory +// it is intended for testing and one-time-deploys where no persistence is needed +type memoryStorage struct { + m map[string]storageEntry + mu sync.RWMutex + + kmu *keyMutex +} + +func NewMemoryStorage() Storage { + return &memoryStorage{ + m: map[string]storageEntry{}, + kmu: newKeyMutex(), + } +} + +// Exists returns true if key exists in s. +func (s *memoryStorage) Exists(ctx context.Context, key string) bool { + ans, err := s.List(ctx, key, true) + if err != nil { + return false + } + return len(ans) != 0 +} + +// Store saves value at key. +func (s *memoryStorage) Store(_ context.Context, key string, value []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + s.m[key] = storageEntry{ + i: KeyInfo{ + Key: key, + Modified: time.Now(), + Size: int64(len(value)), + IsTerminal: true, + }, + d: value, + } + return nil +} + +// Load retrieves the value at key. +func (s *memoryStorage) Load(_ context.Context, key string) ([]byte, error) { + s.mu.Lock() + defer s.mu.Unlock() + val, ok := s.m[key] + if !ok { + return nil, os.ErrNotExist + } + return val.d, nil +} + +// Delete deletes the value at key. +func (s *memoryStorage) Delete(_ context.Context, key string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.m, key) + return nil +} + +// List returns all keys that match prefix. +func (s *memoryStorage) List(ctx context.Context, prefix string, recursive bool) ([]string, error) { + s.mu.Lock() + defer s.mu.Unlock() + return s.list(ctx, prefix, recursive) +} + +func (s *memoryStorage) list(ctx context.Context, prefix string, recursive bool) ([]string, error) { + var keyList []string + + keys := make([]string, 0, len(s.m)) + for k := range s.m { + if strings.HasPrefix(k, prefix) { + keys = append(keys, k) + } + } + // adapted from https://github.com/pberkel/caddy-storage-redis/blob/main/storage.go#L369 + // Iterate over each child key + for _, k := range keys { + // Directory keys will have a "/" suffix + trimmedKey := strings.TrimSuffix(k, "/") + // Reconstruct the full path of child key + fullPathKey := path.Join(prefix, trimmedKey) + // If current key is a directory + if recursive && k != trimmedKey { + // Recursively traverse all child directories + childKeys, err := s.list(ctx, fullPathKey, recursive) + if err != nil { + return keyList, err + } + keyList = append(keyList, childKeys...) + } else { + keyList = append(keyList, fullPathKey) + } + } + + return keys, nil +} + +// Stat returns information about key. +func (s *memoryStorage) Stat(_ context.Context, key string) (KeyInfo, error) { + s.mu.Lock() + defer s.mu.Unlock() + val, ok := s.m[key] + if !ok { + return KeyInfo{}, os.ErrNotExist + } + return val.i, nil +} + +// Lock obtains a lock named by the given name. It blocks +// until the lock can be obtained or an error is returned. +func (s *memoryStorage) Lock(ctx context.Context, name string) error { + return s.kmu.LockKey(ctx, name) +} + +// Unlock releases the lock for name. +func (s *memoryStorage) Unlock(_ context.Context, name string) error { + return s.kmu.UnlockKey(name) +} + +func (s *memoryStorage) String() string { + return "memoryStorage" +} + +// Interface guard +var _ Storage = (*memoryStorage)(nil) + +type keyMutex struct { + m map[string]*semaphore.Weighted + mu sync.Mutex +} + +func newKeyMutex() *keyMutex { + return &keyMutex{ + m: map[string]*semaphore.Weighted{}, + } +} + +func (km *keyMutex) LockKey(ctx context.Context, id string) error { + select { + case <-ctx.Done(): + // as a special case, caddy allows for the cancelled context to be used for a trylock. + if km.mutex(id).TryAcquire(1) { + return nil + } + return ctx.Err() + default: + return km.mutex(id).Acquire(ctx, 1) + } +} + +// Releases the lock associated with the specified ID. +func (km *keyMutex) UnlockKey(id string) error { + km.mutex(id).Release(1) + return nil +} + +func (km *keyMutex) mutex(id string) *semaphore.Weighted { + km.mu.Lock() + defer km.mu.Unlock() + val, ok := km.m[id] + if !ok { + val = semaphore.NewWeighted(1) + km.m[id] = val + } + return val +} diff --git a/memorystorage_test.go b/memorystorage_test.go new file mode 100644 index 00000000..a12e8f77 --- /dev/null +++ b/memorystorage_test.go @@ -0,0 +1,72 @@ +package certmagic_test + +import ( + "bytes" + "context" + "os" + "testing" + + "github.com/caddyserver/certmagic" + "github.com/caddyserver/certmagic/internal/testutil" +) + +func TestMemoryStorageStoreLoad(t *testing.T) { + ctx := context.Background() + tmpDir, err := os.MkdirTemp(os.TempDir(), "certmagic*") + testutil.RequireNoError(t, err, "allocating tmp dir") + defer os.RemoveAll(tmpDir) + s := certmagic.NewMemoryStorage() + err = s.Store(ctx, "foo", []byte("bar")) + testutil.RequireNoError(t, err) + dat, err := s.Load(ctx, "foo") + testutil.RequireNoError(t, err) + testutil.RequireEqualValues(t, dat, []byte("bar")) +} + +func TestMemoryStorageStoreLoadRace(t *testing.T) { + ctx := context.Background() + tmpDir, err := os.MkdirTemp(os.TempDir(), "certmagic*") + testutil.RequireNoError(t, err, "allocating tmp dir") + defer os.RemoveAll(tmpDir) + s := certmagic.NewMemoryStorage() + a := bytes.Repeat([]byte("a"), 4096*1024) + b := bytes.Repeat([]byte("b"), 4096*1024) + err = s.Store(ctx, "foo", a) + testutil.RequireNoError(t, err) + done := make(chan struct{}) + go func() { + err := s.Store(ctx, "foo", b) + testutil.RequireNoError(t, err) + close(done) + }() + dat, err := s.Load(ctx, "foo") + <-done + testutil.RequireNoError(t, err) + testutil.RequireEqualValues(t, 4096*1024, len(dat)) +} + +func TestMemoryStorageWriteLock(t *testing.T) { + ctx := context.Background() + tmpDir, err := os.MkdirTemp(os.TempDir(), "certmagic*") + testutil.RequireNoError(t, err, "allocating tmp dir") + defer os.RemoveAll(tmpDir) + s := certmagic.NewMemoryStorage() + // cctx is a cancelled ctx. so if we can't immediately get the lock, it will fail + cctx, cn := context.WithCancel(ctx) + cn() + // should success + err = s.Lock(cctx, "foo") + testutil.RequireNoError(t, err) + // should fail + err = s.Lock(cctx, "foo") + testutil.RequireError(t, err) + + err = s.Unlock(cctx, "foo") + testutil.RequireNoError(t, err) + // shouldn't fail + err = s.Lock(cctx, "foo") + testutil.RequireNoError(t, err) + + err = s.Unlock(cctx, "foo") + testutil.RequireNoError(t, err) +}