From 8b2931a161b1e1aff5eecff18f2dd92db6217d18 Mon Sep 17 00:00:00 2001
From: Victor Conner
Date: Mon, 9 Dec 2024 09:20:21 +0100
Subject: [PATCH] fix: Partial responses from the distributed storage could
 result in faulty missing records (#21)

---
 distribution.go      |   7 +-
 distribution_test.go | 177 ++++++++++++++++++++++++++++++++++++++++++-
 errors.go            |   4 +
 fetch.go             |   5 +-
 inflight.go          |  32 +++++---
 refresh.go           |   7 +-
 safe.go              |   5 +-
 7 files changed, 216 insertions(+), 21 deletions(-)

diff --git a/distribution.go b/distribution.go
index ec67fd5..8d49397 100644
--- a/distribution.go
+++ b/distribution.go
@@ -191,7 +191,7 @@ func distributedBatchFetch[V, T any](c *Client[T], keyFn KeyFn, fetchFn BatchFet
 			continue
 		}
 
-		// If distributedStaleStorage isn't enabled it means all records are fresh, otherwise checked the CreatedAt time.
+		// If early refreshes aren't enabled it means all records are fresh, otherwise we'll check the CreatedAt time.
 		if !c.distributedEarlyRefreshes || c.clock.Since(record.CreatedAt) < c.distributedRefreshAfterDuration {
 			// We never want to return missing records.
 			if !record.IsMissingRecord {
@@ -219,13 +219,16 @@ func distributedBatchFetch[V, T any](c *Client[T], keyFn KeyFn, fetchFn BatchFet
 
 	dataSourceResponses, err := fetchFn(ctx, idsToRefresh)
 	// In case of an error, we'll proceed with the ones we got from the distributed storage.
+	// NOTE: It's important that we return a specific error here, otherwise we'll potentially
+	// end up caching the IDs that we weren't able to retrieve from the underlying data source
+	// as missing records.
 	if err != nil {
 		for i := 0; i < len(stale); i++ {
 			c.reportDistributedStaleFallback()
 		}
 		c.log.Error(fmt.Sprintf("sturdyc: error fetching records from the underlying data source. %v", err))
 		maps.Copy(stale, fresh)
-		return stale, nil
+		return stale, errOnlyDistributedRecords
 	}
 
 	// Next, we'll want to check if we should change any of the records to be missing or perform deletions.
diff --git a/distribution_test.go b/distribution_test.go
index 5b97aa3..b6fa9da 100644
--- a/distribution_test.go
+++ b/distribution_test.go
@@ -3,6 +3,7 @@ package sturdyc_test
 import (
 	"context"
 	"errors"
+	"strconv"
 	"sync"
 	"testing"
 	"time"
@@ -498,8 +499,8 @@ func TestDistributedStaleStorageBatch(t *testing.T) {
 
 	fetchObserver.Err(errors.New("error"))
 	res, err := sturdyc.GetOrFetchBatch(ctx, c, firstBatchOfIDs, keyFn, fetchObserver.FetchBatch)
-	if err != nil {
-		t.Fatalf("expected no error, got %v", err)
+	if !errors.Is(err, sturdyc.ErrOnlyCachedRecords) {
+		t.Fatalf("expected ErrOnlyCachedRecords, got %v", err)
 	}
 	for id, value := range res {
 		if value != "value"+id {
@@ -696,3 +697,175 @@ func TestDistributedStorageBatchConvertsToMissingRecord(t *testing.T) {
 	time.Sleep(50 * time.Millisecond)
 	fetchObserver.AssertFetchCount(t, 3)
 }
+
+func TestDistributedStorageDoesNotCachePartialResponseAsMissingRecords(t *testing.T) {
+	t.Parallel()
+
+	refreshAfter := time.Minute
+	clock := sturdyc.NewTestClock(time.Now())
+	ctx := context.Background()
+	ttl := time.Second * 30
+	distributedStorage := &mockStorage{}
+	c := sturdyc.New[string](1000, 10, ttl, 30,
+		sturdyc.WithNoContinuousEvictions(),
+		sturdyc.WithClock(clock),
+		sturdyc.WithMissingRecordStorage(),
+		sturdyc.WithDistributedStorageEarlyRefreshes(distributedStorage, refreshAfter),
+	)
+	fetchObserver := NewFetchObserver(1)
+
+	keyFn := c.BatchKeyFn("item")
+	batchOfIDs := []string{"1", "2", "3"}
+	fetchObserver.BatchResponse(batchOfIDs)
+	_, err := sturdyc.GetOrFetchBatch(ctx, c, batchOfIDs, keyFn, fetchObserver.FetchBatch)
+	if err != nil {
+		t.Fatalf("expected no error, got %v", err)
+	}
+
+	<-fetchObserver.FetchCompleted
+	fetchObserver.AssertRequestedRecords(t, batchOfIDs)
+	fetchObserver.AssertFetchCount(t, 1)
+	fetchObserver.Clear()
+
+	// The keys are written asynchronously to the distributed storage.
+	time.Sleep(100 * time.Millisecond)
+	distributedStorage.assertRecords(t, batchOfIDs, keyFn)
+	distributedStorage.assertGetCount(t, 1)
+	distributedStorage.assertSetCount(t, 1)
+
+	// Next, we'll delete the records from the in-memory cache to simulate that they were evicted.
+	for _, id := range batchOfIDs {
+		c.Delete(keyFn(id))
+	}
+	if c.Size() != 0 {
+		t.Fatalf("expected cache size to be 0, got %d", c.Size())
+	}
+
+	// Now we'll add a new id to the batch that we're going to fetch. Next, we'll
+	// set up the fetch observer so that it errors. We should still be able to
+	// retrieve the records that we have in the distributed cache, and assert
+	// that the remaining ID is not stored as a missing record.
+	secondBatchOfIDs := []string{"1", "2", "3", "4"}
+	fetchObserver.Err(errors.New("boom"))
+	res, err := sturdyc.GetOrFetchBatch(ctx, c, secondBatchOfIDs, keyFn, fetchObserver.FetchBatch)
+	if !errors.Is(err, sturdyc.ErrOnlyCachedRecords) {
+		t.Fatalf("expected ErrOnlyCachedRecords, got %v", err)
+	}
+	if len(res) != 3 {
+		t.Fatalf("expected 3 records, got %d", len(res))
+	}
+
+	<-fetchObserver.FetchCompleted
+	fetchObserver.AssertRequestedRecords(t, []string{"4"})
+	fetchObserver.AssertFetchCount(t, 2)
+	fetchObserver.Clear()
+
+	// The 3 records we had in the distributed cache should have been synced to
+	// the in-memory cache. If we had 4 records here, it would mean that we
+	// cached the last record as a missing record when the fetch observer errored
+	// out. That is not the behaviour we want.
+	if c.Size() != 3 {
+		t.Fatalf("expected 3 records, got %d", c.Size())
+	}
+}
+
+func TestPartialResponseForRefreshesDoesNotResultInMissingRecords(t *testing.T) {
+	t.Parallel()
+
+	ctx := context.Background()
+	capacity := 1000
+	numShards := 50
+	evictionPercentage := 10
+	ttl := time.Hour
+	minRefreshDelay := time.Minute * 5
+	maxRefreshDelay := time.Minute * 10
+	refreshRetryInterval := time.Millisecond * 10
+	batchSize := 10
+	batchBufferTimeout := time.Minute
+	refreshAfter := minRefreshDelay
+	distributedStorage := &mockStorage{}
+	clock := sturdyc.NewTestClock(time.Now())
+
+	c := sturdyc.New[string](capacity, numShards, ttl, evictionPercentage,
+		sturdyc.WithNoContinuousEvictions(),
+		sturdyc.WithEarlyRefreshes(minRefreshDelay, maxRefreshDelay, refreshRetryInterval),
+		sturdyc.WithMissingRecordStorage(),
+		sturdyc.WithRefreshCoalescing(batchSize, batchBufferTimeout),
+		sturdyc.WithDistributedStorageEarlyRefreshes(distributedStorage, refreshAfter),
+		sturdyc.WithClock(clock),
+	)
+
+	keyFn := c.BatchKeyFn("item")
+	ids := make([]string, 0, 100)
+	for i := 1; i <= 100; i++ {
+		ids = append(ids, strconv.Itoa(i))
+	}
+
+	fetchObserver := NewFetchObserver(11)
+	fetchObserver.BatchResponse(ids)
+	res, err := sturdyc.GetOrFetchBatch(ctx, c, ids, keyFn, fetchObserver.FetchBatch)
+	if err != nil {
+		t.Fatalf("expected no error, got %v", err)
+	}
+	if len(res) != 100 {
+		t.Fatalf("expected 100 records, got %d", len(res))
+	}
+
+	<-fetchObserver.FetchCompleted
+	fetchObserver.AssertFetchCount(t, 1)
+	fetchObserver.AssertRequestedRecords(t, ids)
+	fetchObserver.Clear()
+
+	// We need to add a sleep because the keys are written asynchronously to the
+	// distributed storage. We expect that the distributed storage was queried
+	// for the ids before we went to the underlying data source, and then written
+	// to when it resulted in a cache miss and the data was in fact fetched.
+	time.Sleep(100 * time.Millisecond)
+	distributedStorage.assertRecords(t, ids, keyFn)
+	distributedStorage.assertGetCount(t, 1)
+	distributedStorage.assertSetCount(t, 1)
+
+	// Next, we'll move the clock past the maxRefreshDelay. This should guarantee
+	// that the next records we request get scheduled for a refresh. We're also
+	// going to add two more ids to the batch and make the fetchObserver error.
+	// What should happen is the following: The cache queries the distributed
+	// storage and sees that IDs 1-100 are due for a refresh, and that IDs 101 and
+	// 102 are missing. Hence, it queries the underlying data source for all of
+	// them. When the underlying data source returns an error, we should get the
+	// records we have in the distributed cache back. The cache should still only
+	// contain 100 records because we can't tell if IDs 101 and 102 are missing or
+	// not.
+	clock.Add(maxRefreshDelay + time.Second)
+	fetchObserver.Err(errors.New("boom"))
+	secondBatchOfIDs := make([]string, 0, 102)
+	for i := 1; i <= 102; i++ {
+		secondBatchOfIDs = append(secondBatchOfIDs, strconv.Itoa(i))
+	}
+	res, err = sturdyc.GetOrFetchBatch(ctx, c, secondBatchOfIDs, keyFn, fetchObserver.FetchBatch)
+	if !errors.Is(err, sturdyc.ErrOnlyCachedRecords) {
+		t.Fatalf("expected ErrOnlyCachedRecords, got %v", err)
+	}
+	if len(res) != 100 {
+		t.Fatalf("expected 100 records, got %d", len(res))
+	}
+
+	// The fetch observer should be called 11 times. 10 times for the batches of
+	// ids that we tried to refresh, and once for IDs 101 and 102, which we didn't
+	// have in the cache.
+	for i := 0; i < 11; i++ {
+		<-fetchObserver.FetchCompleted
+	}
+	fetchObserver.AssertFetchCount(t, 12)
+
+	// Assert that the distributed storage was queried when we
+	// tried to refresh the records we had in the memory cache.
+	distributedStorage.assertGetCount(t, 12)
+	distributedStorage.assertSetCount(t, 1)
+
+	// The in-memory cache should only have 100 records because we can't tell
+	// whether IDs 101 and 102 are missing; the fetch observer errored out
+	// when we tried to fetch them.
+	if c.Size() != 100 {
+		t.Fatalf("expected cache size to be 100, got %d", c.Size())
+	}
+}
diff --git a/errors.go b/errors.go
index 9f022c1..160c610 100644
--- a/errors.go
+++ b/errors.go
@@ -3,6 +3,10 @@ package sturdyc
 import "errors"
 
 var (
+	// errOnlyDistributedRecords is an internal error that the cache uses to
+	// avoid storing records as missing when it's unable to get part of the
+	// batch from the underlying data source.
+	errOnlyDistributedRecords = errors.New("sturdyc: we were only able to get records from the distributed storage")
 	// ErrNotFound should be returned from a FetchFn to indicate that a record is
 	// missing at the underlying data source. This helps the cache to determine
 	// if a record should be deleted or stored as a missing record if you have
diff --git a/fetch.go b/fetch.go
index 09f24fb..2ce31ca 100644
--- a/fetch.go
+++ b/fetch.go
@@ -2,6 +2,7 @@ package sturdyc
 
 import (
 	"context"
+	"errors"
 	"maps"
 )
 
@@ -118,7 +119,7 @@ func getFetchBatch[V, T any](ctx context.Context, c *Client[T], ids []string, ke
 
 	callBatchOpts := callBatchOpts[T, T]{ids: cacheMisses, keyFn: keyFn, fn: wrappedFetch}
 	response, err := callAndCacheBatch(ctx, c, callBatchOpts)
-	if err != nil {
+	if err != nil && !errors.Is(err, ErrOnlyCachedRecords) {
 		if len(cachedRecords) > 0 {
 			return cachedRecords, ErrOnlyCachedRecords
 		}
@@ -126,7 +127,7 @@ func getFetchBatch[V, T any](ctx context.Context, c *Client[T], ids []string, ke
 	}
 
 	maps.Copy(cachedRecords, response)
-	return cachedRecords, nil
+	return cachedRecords, err
 }
 
 // GetOrFetchBatch attempts to retrieve the specified ids from the cache. If
diff --git a/inflight.go b/inflight.go
index 098813d..5368aa4 100644
--- a/inflight.go
+++ b/inflight.go
@@ -98,13 +98,21 @@ type makeBatchCallOpts[T, V any] struct {
 
 func makeBatchCall[T, V any](ctx context.Context, c *Client[T], opts makeBatchCallOpts[T, V]) {
 	response, err := opts.fn(ctx, opts.ids)
-	if err != nil {
+	if err != nil && !errors.Is(err, errOnlyDistributedRecords) {
 		opts.call.err = err
 		return
 	}
 
-	// Check if we should store any of these IDs as a missing record.
-	if c.storeMissingRecords && len(response) < len(opts.ids) {
+	if errors.Is(err, errOnlyDistributedRecords) {
+		opts.call.err = ErrOnlyCachedRecords
+	}
+
+	// Check if we should store any of these IDs as a missing record. However, we
+	// don't want to do this if we only received records from the distributed
+	// storage. That means that the underlying data source errored for the IDs
+	// that we didn't have in our distributed storage, and we don't know whether
+	// these records are missing or not.
+	if c.storeMissingRecords && len(response) < len(opts.ids) && !errors.Is(err, errOnlyDistributedRecords) {
 		for _, id := range opts.ids {
 			if _, ok := response[id]; !ok {
 				c.StoreMissingRecord(opts.keyFn(id))
@@ -153,24 +161,26 @@ func callAndCacheBatch[V, T any](ctx context.Context, c *Client[T], opts callBat
 				}
 				c.endBatchFlight(uniqueIDs, opts.keyFn, call)
 			}()
-			batchCallOpts := makeBatchCallOpts[T, V]{
-				ids:   uniqueIDs,
-				fn:    opts.fn,
-				keyFn: opts.keyFn,
-				call:  call,
-			}
+			batchCallOpts := makeBatchCallOpts[T, V]{ids: uniqueIDs, fn: opts.fn, keyFn: opts.keyFn, call: call}
 			makeBatchCall(ctx, c, batchCallOpts)
 		}()
 	}
 	c.inFlightBatchMutex.Unlock()
 
+	var err error
 	response := make(map[string]V, len(opts.ids))
 	for call, callIDs := range callIDs {
 		call.Wait()
-		if call.err != nil {
+		// It could be only cached records here if we were able
+		// to get some of the IDs from the distributed storage.
+		if call.err != nil && !errors.Is(call.err, ErrOnlyCachedRecords) {
 			return response, call.err
 		}
+		if errors.Is(call.err, ErrOnlyCachedRecords) {
+			err = ErrOnlyCachedRecords
+		}
+
 		// We need to iterate through the values that we want from this call. The
 		// batch could contain a hundred IDs, but we might only want a few of them.
 		for _, id := range callIDs {
@@ -187,5 +197,5 @@ func callAndCacheBatch[V, T any](ctx context.Context, c *Client[T], opts callBat
 		}
 	}
 
-	return response, nil
+	return response, err
 }
diff --git a/refresh.go b/refresh.go
index 00abb6d..a4e3eba 100644
--- a/refresh.go
+++ b/refresh.go
@@ -22,7 +22,7 @@ func (c *Client[T]) refresh(key string, fetchFn FetchFn[T]) {
 func (c *Client[T]) refreshBatch(ids []string, keyFn KeyFn, fetchFn BatchFetchFn[T]) {
 	c.reportBatchRefreshSize(len(ids))
 	response, err := fetchFn(context.Background(), ids)
-	if err != nil {
+	if err != nil && !errors.Is(err, errOnlyDistributedRecords) {
 		return
 	}
 
@@ -39,7 +39,10 @@ func (c *Client[T]) refreshBatch(ids []string, keyFn KeyFn, fetchFn BatchFetchFn
 			c.Delete(keyFn(id))
 		}
 
-		if c.storeMissingRecords && !okResponse {
+		// If we're only getting records from the distributed storage, it means that we weren't able to get
+		// the remaining IDs for the batch from the underlying data source. We don't want to store these
+		// as missing records because we don't know if they're missing or not.
+		if c.storeMissingRecords && !okResponse && !errors.Is(err, errOnlyDistributedRecords) {
 			c.StoreMissingRecord(keyFn(id))
 		}
 	}
diff --git a/safe.go b/safe.go
index a6ece0d..f5d86a4 100644
--- a/safe.go
+++ b/safe.go
@@ -2,6 +2,7 @@ package sturdyc
 
 import (
 	"context"
+	"errors"
 	"fmt"
 )
 
@@ -51,7 +52,7 @@ func unwrap[V, T any](val T, err error) (V, error) {
 func wrapBatch[T, V any](fetchFn BatchFetchFn[V]) BatchFetchFn[T] {
 	return func(ctx context.Context, ids []string) (map[string]T, error) {
 		resV, err := fetchFn(ctx, ids)
-		if err != nil {
+		if err != nil && !errors.Is(err, errOnlyDistributedRecords) {
 			return map[string]T{}, err
 		}
 
@@ -64,7 +65,7 @@ func wrapBatch[T, V any](fetchFn BatchFetchFn[V]) BatchFetchFn[T] {
 			resT[id] = val
 		}
 
-		return resT, nil
+		return resT, err
 	}
 }
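
With this change, a batch call that could only be served from cached or distributed records surfaces sturdyc.ErrOnlyCachedRecords instead of silently treating the unfetched IDs as missing. Below is a minimal caller-side sketch of how that error can be handled. It is an illustration rather than part of the patch: the Product type, the fetchProducts function, the configuration values, and the import path are hypothetical stand-ins.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/viccon/sturdyc" // NOTE: assumed import path; adjust to wherever you import sturdyc from.
)

// Product and fetchProducts are hypothetical stand-ins for an application's
// own record type and underlying data source.
type Product struct {
	Name string
}

func fetchProducts(ctx context.Context, ids []string) (map[string]Product, error) {
	// Imagine a database or HTTP call here. Returning an error simulates the
	// data source being unavailable.
	return nil, errors.New("data source unavailable")
}

func main() {
	// Capacity, shard count, TTL, and eviction percentage are arbitrary example values.
	cache := sturdyc.New[Product](1000, 10, time.Minute, 10)
	keyFn := cache.BatchKeyFn("products")

	records, err := sturdyc.GetOrFetchBatch(context.Background(), cache, []string{"1", "2", "3"}, keyFn, fetchProducts)
	switch {
	case errors.Is(err, sturdyc.ErrOnlyCachedRecords):
		// The data source failed, but the cache (in-memory or, if configured,
		// distributed) could still serve some records. IDs absent from the
		// response were not written as missing records.
		fmt.Println("partial response:", len(records))
	case err != nil:
		// Nothing cached and the data source failed.
		fmt.Println("hard failure:", err)
	default:
		fmt.Println("full response:", len(records))
	}
}

A caller that treats ErrOnlyCachedRecords as a soft failure can keep serving partial data while the underlying data source recovers.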