
Commit

fix: Partial responses from the distributed storage could result in faulty missing records (#21)
viccon authored Dec 9, 2024
1 parent f2a0c00 commit 8b2931a
Showing 7 changed files with 216 additions and 21 deletions.
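
To make the caller-visible effect of this fix concrete: when the underlying data source errors but the distributed storage can serve part of a batch, GetOrFetchBatch now returns those cached records together with ErrOnlyCachedRecords instead of silently caching the unfetched IDs as missing records. The sketch below shows how a caller might handle that error; the fetchUsersThroughCache function, the "user" key prefix, and the module import path are assumptions for illustration, not part of this commit.

package example

import (
	"context"
	"errors"

	"github.com/viccon/sturdyc" // module path assumed for illustration
)

// fetchUsersThroughCache is a hypothetical caller; sturdyc.GetOrFetchBatch and
// sturdyc.ErrOnlyCachedRecords are the library identifiers touched by this commit.
func fetchUsersThroughCache(ctx context.Context, c *sturdyc.Client[string], ids []string) (map[string]string, error) {
	keyFn := c.BatchKeyFn("user")

	// Simulate a data source that is currently unavailable. The distributed
	// storage may still be able to serve some of the requested IDs.
	fetchUsers := func(ctx context.Context, ids []string) (map[string]string, error) {
		return nil, errors.New("data source unavailable")
	}

	res, err := sturdyc.GetOrFetchBatch(ctx, c, ids, keyFn, fetchUsers)
	if errors.Is(err, sturdyc.ErrOnlyCachedRecords) {
		// res contains only the records that could be served from the caches.
		// After this fix, the IDs the data source failed to return are no
		// longer written to the in-memory cache as missing records.
		return res, nil
	}
	return res, err
}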
7 changes: 5 additions & 2 deletions distribution.go
@@ -191,7 +191,7 @@ func distributedBatchFetch[V, T any](c *Client[T], keyFn KeyFn, fetchFn BatchFet
continue
}

// If distributedStaleStorage isn't enabled it means all records are fresh, otherwise checked the CreatedAt time.
// If early refreshes isn't enabled it means all records are fresh, otherwise we'll check the CreatedAt time.
if !c.distributedEarlyRefreshes || c.clock.Since(record.CreatedAt) < c.distributedRefreshAfterDuration {
// We never want to return missing records.
if !record.IsMissingRecord {
@@ -219,13 +219,16 @@ func distributedBatchFetch[V, T any](c *Client[T], keyFn KeyFn, fetchFn BatchFet

dataSourceResponses, err := fetchFn(ctx, idsToRefresh)
// In case of an error, we'll proceed with the ones we got from the distributed storage.
// NOTE: It's important that we return a specific error here, otherwise we'll potentially
// end up caching the IDs that we weren't able to retrieve from the underlying data source
// as missing records.
if err != nil {
for i := 0; i < len(stale); i++ {
c.reportDistributedStaleFallback()
}
c.log.Error(fmt.Sprintf("sturdyc: error fetching records from the underlying data source. %v", err))
maps.Copy(stale, fresh)
return stale, nil
return stale, errOnlyDistributedRecords
}

// Next, we'll want to check if we should change any of the records to be missing or perform deletions.
177 changes: 175 additions & 2 deletions distribution_test.go
@@ -3,6 +3,7 @@ package sturdyc_test
import (
"context"
"errors"
"strconv"
"sync"
"testing"
"time"
@@ -498,8 +499,8 @@ func TestDistributedStaleStorageBatch(t *testing.T) {
fetchObserver.Err(errors.New("error"))

res, err := sturdyc.GetOrFetchBatch(ctx, c, firstBatchOfIDs, keyFn, fetchObserver.FetchBatch)
if err != nil {
t.Fatalf("expected no error, got %v", err)
if !errors.Is(err, sturdyc.ErrOnlyCachedRecords) {
t.Fatalf("expected ErrOnlyCachedRecords, got %v", err)
}
for id, value := range res {
if value != "value"+id {
@@ -696,3 +697,175 @@ func TestDistributedStorageBatchConvertsToMissingRecord(t *testing.T) {
time.Sleep(50 * time.Millisecond)
fetchObserver.AssertFetchCount(t, 3)
}

func TestDistributedStorageDoesNotCachePartialResponseAsMissingRecords(t *testing.T) {
t.Parallel()

refreshAfter := time.Minute
clock := sturdyc.NewTestClock(time.Now())
ctx := context.Background()
ttl := time.Second * 30
distributedStorage := &mockStorage{}
c := sturdyc.New[string](1000, 10, ttl, 30,
sturdyc.WithNoContinuousEvictions(),
sturdyc.WithClock(clock),
sturdyc.WithMissingRecordStorage(),
sturdyc.WithDistributedStorageEarlyRefreshes(distributedStorage, refreshAfter),
)
fetchObserver := NewFetchObserver(1)

keyFn := c.BatchKeyFn("item")
batchOfIDs := []string{"1", "2", "3"}
fetchObserver.BatchResponse(batchOfIDs)
_, err := sturdyc.GetOrFetchBatch(ctx, c, batchOfIDs, keyFn, fetchObserver.FetchBatch)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}

<-fetchObserver.FetchCompleted
fetchObserver.AssertRequestedRecords(t, batchOfIDs)
fetchObserver.AssertFetchCount(t, 1)
fetchObserver.Clear()

// The keys are written asynchronously to the distributed storage.
time.Sleep(100 * time.Millisecond)
distributedStorage.assertRecords(t, batchOfIDs, keyFn)
distributedStorage.assertGetCount(t, 1)
distributedStorage.assertSetCount(t, 1)

// Next, we'll delete the records from the in-memory cache to simulate that they were evicted.
for _, id := range batchOfIDs {
c.Delete(keyFn(id))
}
if c.Size() != 0 {
t.Fatalf("expected cache size to be 0, got %d", c.Size())
}

// Now we'll add a new id to the batch that we're going to fetch. Next, we'll
// set up the fetch observer so that it errors. We should still be able to
// retrieve the records that we have in the distributed cache, and assert
// that the remaining ID should not be stored as a missing record.
secondBatchOfIDs := []string{"1", "2", "3", "4"}
fetchObserver.Err(errors.New("boom"))
res, err := sturdyc.GetOrFetchBatch(ctx, c, secondBatchOfIDs, keyFn, fetchObserver.FetchBatch)
if !errors.Is(err, sturdyc.ErrOnlyCachedRecords) {
t.Fatalf("expected ErrOnlyCachedRecords, got %v", err)
}
if len(res) != 3 {
t.Fatalf("expected 3 records, got %d", len(res))
}

<-fetchObserver.FetchCompleted
fetchObserver.AssertRequestedRecords(t, []string{"4"})
fetchObserver.AssertFetchCount(t, 2)
fetchObserver.Clear()

// The 3 records we had in the distributed cache should have been synced to
// the in-memory cache. If we had 4 records here, it would mean that we
// cached the last record as a missing record when the fetch observer errored
// out. That is not the behaviour we want.
if c.Size() != 3 {
t.Fatalf("expected 3 records, got %d", c.Size())
}
}

func TestPartialResponseForRefreshesDoesNotResultInMissingRecords(t *testing.T) {
t.Parallel()

ctx := context.Background()
capacity := 1000
numShards := 50
evictionPercentage := 10
ttl := time.Hour
minRefreshDelay := time.Minute * 5
maxRefreshDelay := time.Minute * 10
refreshRetryInterval := time.Millisecond * 10
batchSize := 10
batchBufferTimeout := time.Minute
refreshAfter := minRefreshDelay
distributedStorage := &mockStorage{}
clock := sturdyc.NewTestClock(time.Now())

c := sturdyc.New[string](capacity, numShards, ttl, evictionPercentage,
sturdyc.WithNoContinuousEvictions(),
sturdyc.WithEarlyRefreshes(minRefreshDelay, maxRefreshDelay, refreshRetryInterval),
sturdyc.WithMissingRecordStorage(),
sturdyc.WithRefreshCoalescing(batchSize, batchBufferTimeout),
sturdyc.WithDistributedStorageEarlyRefreshes(distributedStorage, refreshAfter),
sturdyc.WithClock(clock),
)

keyFn := c.BatchKeyFn("item")
ids := make([]string, 0, 100)
for i := 1; i <= 100; i++ {
ids = append(ids, strconv.Itoa(i))
}

fetchObserver := NewFetchObserver(11)
fetchObserver.BatchResponse(ids)
res, err := sturdyc.GetOrFetchBatch(ctx, c, ids, keyFn, fetchObserver.FetchBatch)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if len(res) != 100 {
t.Fatalf("expected 100 records, got %d", len(res))
}

<-fetchObserver.FetchCompleted
fetchObserver.AssertFetchCount(t, 1)
fetchObserver.AssertRequestedRecords(t, ids)
fetchObserver.Clear()

// We need to add a sleep because the keys are written asynchronously to the
// distributed storage. We expect that the distributed storage was queried
// for the ids before we went to the underlying data source, and then written
// to when it resulted in a cache miss and the data was in fact fetched.
time.Sleep(100 * time.Millisecond)
distributedStorage.assertRecords(t, ids, keyFn)
distributedStorage.assertGetCount(t, 1)
distributedStorage.assertSetCount(t, 1)

// Next, we'll move the clock past the maxRefreshDelay. This should guarantee
// that the next records we request get scheduled for a refresh. We're also
// going to add two more ids to the batch and make the fetchObserver error.
// What should happen is the following: The cache queries the distributed
// storage and sees that ID 1-100 are due for a refresh, and that ID 101 and
// 102 are missing. Hence, it queries the underlying data source for all of
// them. When the underlying data source returns an error, we should get the
// records we have in the distributed cache back. The cache should still only
// contain 100 records because we can't tell if ID 101 and 102 are missing or
// not.
clock.Add(maxRefreshDelay + time.Second)
fetchObserver.Err(errors.New("boom"))
secondBatchOfIDs := make([]string, 0, 102)
for i := 1; i <= 102; i++ {
secondBatchOfIDs = append(secondBatchOfIDs, strconv.Itoa(i))
}
res, err = sturdyc.GetOrFetchBatch(ctx, c, secondBatchOfIDs, keyFn, fetchObserver.FetchBatch)
if !errors.Is(err, sturdyc.ErrOnlyCachedRecords) {
t.Fatalf("expected ErrOnlyCachedRecords, got %v", err)
}
if len(res) != 100 {
t.Fatalf("expected 100 records, got %d", len(res))
}

// The fetch observer should be called 11 more times: 10 times for the batches
// of ids that we tried to refresh, and once for IDs 101 and 102, which we
// didn't have in the cache.
for i := 0; i < 11; i++ {
<-fetchObserver.FetchCompleted
}
fetchObserver.AssertFetchCount(t, 12)

// Assert that the distributed storage was queried when we
// tried to refresh the records we had in the memory cache.
distributedStorage.assertGetCount(t, 12)
distributedStorage.assertSetCount(t, 1)

// The in-memory cache should only have 100 records. We can't tell whether
// IDs 101 and 102 are missing or not, because the fetch observer errored out
// when we tried to fetch them.
if c.Size() != 100 {
t.Fatalf("expected cache size to be 100, got %d", c.Size())
}
}
4 changes: 4 additions & 0 deletions errors.go
@@ -3,6 +3,10 @@ package sturdyc
import "errors"

var (
// errOnlyDistributedRecords is an internal error that the cache uses to avoid
// storing records as missing when it's unable to get part of the batch from
// the underlying data source.
errOnlyDistributedRecords = errors.New("sturdyc: we were only able to get records from the distributed storage")
// ErrNotFound should be returned from a FetchFn to indicate that a record is
// missing at the underlying data source. This helps the cache to determine
// if a record should be deleted or stored as a missing record if you have
5 changes: 3 additions & 2 deletions fetch.go
@@ -2,6 +2,7 @@ package sturdyc

import (
"context"
"errors"
"maps"
)

@@ -118,15 +119,15 @@ func getFetchBatch[V, T any](ctx context.Context, c *Client[T], ids []string, ke

callBatchOpts := callBatchOpts[T, T]{ids: cacheMisses, keyFn: keyFn, fn: wrappedFetch}
response, err := callAndCacheBatch(ctx, c, callBatchOpts)
if err != nil {
if err != nil && !errors.Is(err, ErrOnlyCachedRecords) {
if len(cachedRecords) > 0 {
return cachedRecords, ErrOnlyCachedRecords
}
return cachedRecords, err
}

maps.Copy(cachedRecords, response)
return cachedRecords, nil
return cachedRecords, err
}

// GetOrFetchBatch attempts to retrieve the specified ids from the cache. If
32 changes: 21 additions & 11 deletions inflight.go
@@ -98,13 +98,21 @@ type makeBatchCallOpts[T, V any] struct {

func makeBatchCall[T, V any](ctx context.Context, c *Client[T], opts makeBatchCallOpts[T, V]) {
response, err := opts.fn(ctx, opts.ids)
if err != nil {
if err != nil && !errors.Is(err, errOnlyDistributedRecords) {
opts.call.err = err
return
}

// Check if we should store any of these IDs as a missing record.
if c.storeMissingRecords && len(response) < len(opts.ids) {
if errors.Is(err, errOnlyDistributedRecords) {
opts.call.err = ErrOnlyCachedRecords
}

// Check if we should store any of these IDs as a missing record. However, we
// don't want to do this if we only received records from the distributed
// storage. That means that the underlying data source errored for the IDs
// that we didn't have in our distributed storage, and we don't know whether
// these records are missing or not.
if c.storeMissingRecords && len(response) < len(opts.ids) && !errors.Is(err, errOnlyDistributedRecords) {
for _, id := range opts.ids {
if _, ok := response[id]; !ok {
c.StoreMissingRecord(opts.keyFn(id))
@@ -153,24 +161,26 @@ func callAndCacheBatch[V, T any](ctx context.Context, c *Client[T], opts callBat
}
c.endBatchFlight(uniqueIDs, opts.keyFn, call)
}()
batchCallOpts := makeBatchCallOpts[T, V]{
ids: uniqueIDs,
fn: opts.fn,
keyFn: opts.keyFn,
call: call,
}
batchCallOpts := makeBatchCallOpts[T, V]{ids: uniqueIDs, fn: opts.fn, keyFn: opts.keyFn, call: call}
makeBatchCall(ctx, c, batchCallOpts)
}()
}
c.inFlightBatchMutex.Unlock()

var err error
response := make(map[string]V, len(opts.ids))
for call, callIDs := range callIDs {
call.Wait()
if call.err != nil {
// The error could be ErrOnlyCachedRecords here, if we were able
// to get some of the IDs from the distributed storage.
if call.err != nil && !errors.Is(call.err, ErrOnlyCachedRecords) {
return response, call.err
}

if errors.Is(call.err, ErrOnlyCachedRecords) {
err = ErrOnlyCachedRecords
}

// We need to iterate through the values that we want from this call. The
// batch could contain a hundred IDs, but we might only want a few of them.
for _, id := range callIDs {
@@ -187,5 +197,5 @@ func callAndCacheBatch[V, T any](ctx context.Context, c *Client[T], opts callBat
}
}

return response, nil
return response, err
}
7 changes: 5 additions & 2 deletions refresh.go
@@ -22,7 +22,7 @@ func (c *Client[T]) refresh(key string, fetchFn FetchFn[T]) {
func (c *Client[T]) refreshBatch(ids []string, keyFn KeyFn, fetchFn BatchFetchFn[T]) {
c.reportBatchRefreshSize(len(ids))
response, err := fetchFn(context.Background(), ids)
if err != nil {
if err != nil && !errors.Is(err, errOnlyDistributedRecords) {
return
}

@@ -39,7 +39,10 @@ func (c *Client[T]) refreshBatch(ids []string, keyFn KeyFn, fetchFn BatchFetchFn
c.Delete(keyFn(id))
}

if c.storeMissingRecords && !okResponse {
// If we're only getting records from the distributed storage, it means that we weren't able to get
// the remaining IDs for the batch from the underlying data source. We don't want to store these
// as missing records because we don't know if they're missing or not.
if c.storeMissingRecords && !okResponse && !errors.Is(err, errOnlyDistributedRecords) {
c.StoreMissingRecord(keyFn(id))
}
}
5 changes: 3 additions & 2 deletions safe.go
@@ -2,6 +2,7 @@ package sturdyc

import (
"context"
"errors"
"fmt"
)

Expand Down Expand Up @@ -51,7 +52,7 @@ func unwrap[V, T any](val T, err error) (V, error) {
func wrapBatch[T, V any](fetchFn BatchFetchFn[V]) BatchFetchFn[T] {
return func(ctx context.Context, ids []string) (map[string]T, error) {
resV, err := fetchFn(ctx, ids)
if err != nil {
if err != nil && !errors.Is(err, errOnlyDistributedRecords) {
return map[string]T{}, err
}

@@ -64,7 +65,7 @@ func wrapBatch[T, V any](fetchFn BatchFetchFn[V]) BatchFetchFn[T] {
resT[id] = val
}

return resT, nil
return resT, err
}
}

