diff --git a/pkg/executor/internal/applycache/apply_cache.go b/pkg/executor/internal/applycache/apply_cache.go index cb5ee694f3bce..e3740b7601536 100644 --- a/pkg/executor/internal/applycache/apply_cache.go +++ b/pkg/executor/internal/applycache/apply_cache.go @@ -67,6 +67,12 @@ func (c *ApplyCache) put(key applyCacheKey, val kvcache.Value) { c.cache.Put(key, val) } +func (c *ApplyCache) removeOldest() (kvcache.Key, kvcache.Value, bool) { + c.lock.Lock() + defer c.lock.Unlock() + return c.cache.RemoveOldest() +} + // Get gets a cache item according to cache key. It's thread-safe. func (c *ApplyCache) Get(key applyCacheKey) (*chunk.List, error) { value, hit := c.get(key) @@ -84,7 +90,7 @@ func (c *ApplyCache) Set(key applyCacheKey, value *chunk.List) (bool, error) { return false, nil } for mem+c.memTracker.BytesConsumed() > c.memCapacity { - evictedKey, evictedValue, evicted := c.cache.RemoveOldest() + evictedKey, evictedValue, evicted := c.removeOldest() if !evicted { return false, nil } diff --git a/pkg/executor/internal/applycache/apply_cache_test.go b/pkg/executor/internal/applycache/apply_cache_test.go index b20495a2a7369..e40148fcc6528 100644 --- a/pkg/executor/internal/applycache/apply_cache_test.go +++ b/pkg/executor/internal/applycache/apply_cache_test.go @@ -17,6 +17,7 @@ package applycache import ( "strconv" "strings" + "sync" "testing" "github.com/pingcap/tidb/pkg/parser/mysql" @@ -77,3 +78,61 @@ func TestApplyCache(t *testing.T) { require.NoError(t, err) require.Nil(t, result) } + +func TestApplyCacheConcurrent(t *testing.T) { + ctx := mock.NewContext() + ctx.GetSessionVars().MemQuotaApplyCache = 100 + applyCache, err := NewApplyCache(ctx) + require.NoError(t, err) + + fields := []*types.FieldType{types.NewFieldType(mysql.TypeLonglong)} + value := make([]*chunk.List, 2) + key := make([][]byte, 2) + for i := 0; i < 2; i++ { + value[i] = chunk.NewList(fields, 1, 1) + srcChunk := chunk.NewChunkWithCapacity(fields, 1) + srcChunk.AppendInt64(0, int64(i)) + srcRow := srcChunk.GetRow(0) + value[i].AppendRow(srcRow) + key[i] = []byte(strings.Repeat(strconv.Itoa(i), 100)) + + // TODO: *chunk.List.GetMemTracker().BytesConsumed() is not accurate, fix it later. + require.Equal(t, int64(100), applyCacheKVMem(key[i], value[i])) + } + + applyCache.Set(key[0], value[0]) + var wg sync.WaitGroup + wg.Add(2) + var func1 = func() { + for i := 0; i < 100; i++ { + for { + result, err := applyCache.Get(key[0]) + require.NoError(t, err) + if result != nil { + applyCache.Set(key[1], value[1]) + break + } + } + } + wg.Done() + } + var func2 = func() { + for i := 0; i < 100; i++ { + for { + result, err := applyCache.Get(key[1]) + require.NoError(t, err) + if result != nil { + applyCache.Set(key[0], value[0]) + break + } + } + } + wg.Done() + } + go func1() + go func2() + wg.Wait() + result, err := applyCache.Get(key[0]) + require.NoError(t, err) + require.NotNil(t, result) +}