Skip to content

Commit

Permalink
[fix][ml] Fix race conditions in RangeCache (apache#22789)
Browse files Browse the repository at this point in the history
(cherry picked from commit c39f9f8)
(cherry picked from commit 9a99e45)
lhotari authored and srinath-ctds committed Jun 7, 2024
1 parent 0ee3687 commit 6812af7
Showing 3 changed files with 254 additions and 94 deletions.
Original file line number Diff line number Diff line change
@@ -27,9 +27,10 @@
import org.apache.bookkeeper.client.api.LedgerEntry;
import org.apache.bookkeeper.mledger.Entry;
import org.apache.bookkeeper.mledger.util.AbstractCASReferenceCounted;
import org.apache.bookkeeper.mledger.util.RangeCache;

public final class EntryImpl extends AbstractCASReferenceCounted implements Entry, Comparable<EntryImpl>,
ReferenceCounted {
RangeCache.ValueWithKeyValidation<PositionImpl> {

private static final Recycler<EntryImpl> RECYCLER = new Recycler<EntryImpl>() {
@Override
@@ -200,4 +201,8 @@ protected void deallocate() {
recyclerHandle.recycle(this);
}

/**
 * Tells whether this entry belongs to the given position.
 *
 * @param key the position to validate this entry against
 * @return {@code true} when {@code key} compares as equal to this entry's ledger id and entry id
 */
@Override
public boolean matchesKey(PositionImpl key) {
    final int comparison = key.compareTo(ledgerId, entryId);
    return comparison == 0;
}
}
Original file line number Diff line number Diff line change
@@ -19,31 +19,134 @@
package org.apache.bookkeeper.mledger.util;

import static com.google.common.base.Preconditions.checkArgument;
import com.google.common.base.Predicate;
import io.netty.util.IllegalReferenceCountException;
import io.netty.util.Recycler;
import io.netty.util.Recycler.Handle;
import io.netty.util.ReferenceCounted;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentNavigableMap;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.bookkeeper.mledger.util.RangeCache.ValueWithKeyValidation;
import org.apache.commons.lang3.tuple.Pair;

/**
* Special type of cache where get() and delete() operations can be done over a range of keys.
* The implementation avoids locks and synchronization and relies on ConcurrentSkipListMap for storing the entries.
* Since there are no locks, there must be a way to ensure that a single entry in the cache is removed
* exactly once. Removing an entry multiple times would result in the entries of the cache being released too
* early, while they could still be in use.
*
* @param <Key>
* Cache key. Needs to be Comparable
* @param <Value>
* Cache value
*/
public class RangeCache<Key extends Comparable<Key>, Value extends ReferenceCounted> {
public class RangeCache<Key extends Comparable<Key>, Value extends ValueWithKeyValidation<Key>> {
/**
 * Contract for values stored in the cache: a reference counted value that can tell whether it belongs to a
 * given key. The cache uses this check when inserting a value and again when reading or removing it.
 *
 * @param <T> the key type the value is validated against
 */
public interface ValueWithKeyValidation<T> extends ReferenceCounted {
/**
 * Checks whether this value corresponds to the given key.
 *
 * @param key the key to validate against
 * @return true if this value matches the key, false otherwise
 */
boolean matchesKey(T key);
}

// Sorted map from key to the cached value
private final ConcurrentNavigableMap<Key, Value> entries;
private final ConcurrentNavigableMap<Key, IdentityWrapper<Key, Value>> entries;
private AtomicLong size; // Total size of values stored in cache
private final Weighter<Value> weighter; // Weighter object used to extract the size from values
private final TimestampExtractor<Value> timestampExtractor; // Extract the timestamp associated with a value

/**
 * Wrapper around the value to store in Map. This is needed to ensure that a specific instance can be removed from
 * the map by calling the {@link Map#remove(Object, Object)} method. Certain race conditions could result in the
 * wrong value being removed from the map. The instances of this class are recycled to avoid creating new objects.
 *
 * <p>Equality is identity based, so {@code remove(key, wrapper)} only succeeds for the exact wrapper instance
 * that is currently mapped.
 *
 * @param <K> type of the key the wrapper was created for
 * @param <V> type of the wrapped value
 */
private static class IdentityWrapper<K, V> {
    // The pool is shared by all parameterizations, hence the wildcard element type.
    private static final Recycler<IdentityWrapper<?, ?>> RECYCLER = new Recycler<IdentityWrapper<?, ?>>() {
        @Override
        protected IdentityWrapper<?, ?> newObject(Handle<IdentityWrapper<?, ?>> recyclerHandle) {
            return new IdentityWrapper<Object, Object>(recyclerHandle);
        }
    };

    private final Handle<IdentityWrapper<?, ?>> recyclerHandle;
    private K key;
    private V value;

    private IdentityWrapper(Handle<IdentityWrapper<?, ?>> recyclerHandle) {
        this.recyclerHandle = recyclerHandle;
    }

    /**
     * Obtains a wrapper from the recycler and initializes it with the given key and value.
     *
     * @param key   the key the value is stored under
     * @param value the value to wrap
     * @return an initialized wrapper instance
     */
    @SuppressWarnings("unchecked")
    static <K, V> IdentityWrapper<K, V> create(K key, V value) {
        // Safe: a pooled instance carries no key/value until it is initialized right here.
        IdentityWrapper<K, V> identityWrapper = (IdentityWrapper<K, V>) RECYCLER.get();
        identityWrapper.key = key;
        identityWrapper.value = value;
        return identityWrapper;
    }

    /**
     * @return the key this wrapper was created for, or {@code null} once the wrapper has been recycled
     */
    K getKey() {
        return key;
    }

    /**
     * @return the wrapped value, or {@code null} once the wrapper has been recycled
     */
    V getValue() {
        return value;
    }

    /**
     * Clears the wrapper and hands it back to the recycler.
     */
    void recycle() {
        // Clear the key as well as the value: a pooled wrapper must not keep the key reachable, and a cleared
        // key makes a reader that compares getKey() with the map key see the wrapper as recycled.
        key = null;
        value = null;
        recyclerHandle.recycle(this);
    }

    @Override
    public boolean equals(Object o) {
        // only match exact identity of the value
        return this == o;
    }

    @Override
    public int hashCode() {
        return Objects.hashCode(key);
    }
}

/**
 * Mutable object to store the number of entries and the total size removed from the cache. The instances
 * are recycled to avoid creating new instances.
 */
private static class RemovalCounters {
    private static final Recycler<RemovalCounters> RECYCLER = new Recycler<RemovalCounters>() {
        @Override
        protected RemovalCounters newObject(Handle<RemovalCounters> recyclerHandle) {
            return new RemovalCounters(recyclerHandle);
        }
    };

    private final Handle<RemovalCounters> recyclerHandle;
    int removedEntries;
    long removedSize;

    private RemovalCounters(Handle<RemovalCounters> recyclerHandle) {
        this.recyclerHandle = recyclerHandle;
    }

    /**
     * Obtains an instance from the recycler with both counters set to zero.
     *
     * @return a zeroed counters instance
     */
    static RemovalCounters create() {
        RemovalCounters results = RECYCLER.get();
        results.removedEntries = 0;
        results.removedSize = 0;
        return results;
    }

    /**
     * Resets the counters and returns this instance to the recycler. The instance must not be used afterwards.
     */
    void recycle() {
        removedEntries = 0;
        removedSize = 0;
        recyclerHandle.recycle(this);
    }

    /**
     * Records the removal of a single entry.
     *
     * @param size the size of the removed entry, added to the removed size total
     */
    void entryRemoved(long size) {
        removedSize += size;
        removedEntries++;
    }
}

/**
* Construct a new RangeCache with the default Weighter.
*/
@@ -68,18 +171,23 @@ public RangeCache(Weighter<Value> weighter, TimestampExtractor<Value> timestampE
* Insert.
*
* @param key
* @param value
* ref counted value with at least 1 ref to pass on the cache
* @param value ref counted value with at least 1 ref to pass on the cache
* @return whether the entry was inserted in the cache
*/
public boolean put(Key key, Value value) {
// retain value so that it's not released before we put it in the cache and calculate the weight
value.retain();
try {
if (entries.putIfAbsent(key, value) == null) {
if (!value.matchesKey(key)) {
throw new IllegalArgumentException("Value '" + value + "' does not match key '" + key + "'");
}
IdentityWrapper<Key, Value> newWrapper = IdentityWrapper.create(key, value);
if (entries.putIfAbsent(key, newWrapper) == null) {
size.addAndGet(weighter.getSize(value));
return true;
} else {
// recycle the new wrapper as it was not used
newWrapper.recycle();
return false;
}
} finally {
@@ -91,16 +199,37 @@ public boolean exists(Key key) {
return key != null ? entries.containsKey(key) : true;
}

/**
* Get the value associated with the key and increment the reference count of it.
* The caller is responsible for releasing the reference.
*/
public Value get(Key key) {
Value value = entries.get(key);
if (value == null) {
return getValue(key, entries.get(key));
}

private Value getValue(Key key, IdentityWrapper<Key, Value> valueWrapper) {
if (valueWrapper == null) {
return null;
} else {
if (valueWrapper.getKey() != key) {
// the wrapper has been recycled and contains another key
return null;
}
Value value = valueWrapper.getValue();
try {
value.retain();
} catch (IllegalReferenceCountException e) {
// Value was already deallocated
return null;
}
// check that the value matches the key and that there's at least 2 references to it since
// the cache should be holding one reference and a new reference was just added in this method
if (value.refCnt() > 1 && value.matchesKey(key)) {
return value;
} catch (Throwable t) {
// Value was already destroyed between get() and retain()
} else {
// Value or IdentityWrapper was recycled and already contains another value
// release the reference added in this method
value.release();
return null;
}
}
@@ -118,12 +247,10 @@ public Collection<Value> getRange(Key first, Key last) {
List<Value> values = new ArrayList();

// Return the values of the entries found in cache
for (Value value : entries.subMap(first, true, last, true).values()) {
try {
value.retain();
for (Map.Entry<Key, IdentityWrapper<Key, Value>> entry : entries.subMap(first, true, last, true).entrySet()) {
Value value = getValue(entry.getKey(), entry.getValue());
if (value != null) {
values.add(value);
} catch (Throwable t) {
// Value was already destroyed between get() and retain()
}
}

@@ -138,25 +265,65 @@ public Collection<Value> getRange(Key first, Key last) {
* @return a pair containing the number of removed entries and the total size of the removed entries
*/
public Pair<Integer, Long> removeRange(Key first, Key last, boolean lastInclusive) {
Map<Key, Value> subMap = entries.subMap(first, true, last, lastInclusive);
RemovalCounters counters = RemovalCounters.create();
Map<Key, IdentityWrapper<Key, Value>> subMap = entries.subMap(first, true, last, lastInclusive);
for (Map.Entry<Key, IdentityWrapper<Key, Value>> entry : subMap.entrySet()) {
removeEntry(entry, counters);
}
return handleRemovalResult(counters);
}

int removedEntries = 0;
long removedSize = 0;
/**
 * Outcome of a single removeEntry call. Callers that remove entries in a loop use it to decide
 * whether to stop iterating.
 */
enum RemoveEntryResult {
// processing of the entry ran through the removal step
ENTRY_REMOVED,
// the entry was skipped without being removed (e.g. its wrapper was recycled or its value was already
// released); the caller can carry on with its loop
CONTINUE_LOOP,
// the removal condition rejected the value; the caller should stop iterating
BREAK_LOOP;
}

for (Key key : subMap.keySet()) {
Value value = entries.remove(key);
if (value == null) {
continue;
}
/**
 * Removes the given entry without any additional removal condition.
 *
 * @param entry    the map entry to remove
 * @param counters accumulator for the number and size of removed entries
 * @return the outcome of the removal attempt
 */
private RemoveEntryResult removeEntry(Map.Entry<Key, IdentityWrapper<Key, Value>> entry, RemovalCounters counters) {
    // every value qualifies for removal
    final Predicate<Value> acceptAll = ignored -> true;
    return removeEntry(entry, counters, acceptAll);
}

removedSize += weighter.getSize(value);
/**
 * Removes a single entry from the cache, provided that the entry is still valid and the removal condition
 * accepts its value. The cache's reference to the value is released only when this call is the one that
 * removed the mapping, so a value is never released twice by concurrent removals.
 *
 * @param entry           the map entry to remove
 * @param counters        accumulator that is updated when the entry is actually removed
 * @param removeCondition condition the value must satisfy for the removal to proceed
 * @return {@link RemoveEntryResult#BREAK_LOOP} when the condition rejected the value,
 *         {@link RemoveEntryResult#ENTRY_REMOVED} when this call removed the entry, and
 *         {@link RemoveEntryResult#CONTINUE_LOOP} when the entry was stale, already released or removed
 *         concurrently by another thread
 */
private RemoveEntryResult removeEntry(Map.Entry<Key, IdentityWrapper<Key, Value>> entry, RemovalCounters counters,
        Predicate<Value> removeCondition) {
    Key key = entry.getKey();
    IdentityWrapper<Key, Value> identityWrapper = entry.getValue();
    if (identityWrapper.getKey() != key) {
        // the wrapper has been recycled and contains another key
        return RemoveEntryResult.CONTINUE_LOOP;
    }
    Value value = identityWrapper.getValue();
    if (value == null) {
        // the wrapper was recycled by another thread after the key check above
        return RemoveEntryResult.CONTINUE_LOOP;
    }
    try {
        // add extra retain to avoid value being released while we are removing it
        value.retain();
    } catch (IllegalReferenceCountException e) {
        // Value was already released
        return RemoveEntryResult.CONTINUE_LOOP;
    }
    try {
        if (!removeCondition.test(value)) {
            return RemoveEntryResult.BREAK_LOOP;
        }
        // check that the value hasn't been recycled in between
        // there should be at least 2 references since this method adds one and the cache should have one
        // it is valid that the value contains references even after the key has been removed from the cache
        if (value.refCnt() > 1 && value.matchesKey(key) && entries.remove(key, identityWrapper)) {
            identityWrapper.recycle();
            counters.entryRemoved(weighter.getSize(value));
            // remove the cache reference
            value.release();
            return RemoveEntryResult.ENTRY_REMOVED;
        }
        // nothing was removed by this call; report that instead of claiming a removal
        return RemoveEntryResult.CONTINUE_LOOP;
    } finally {
        // remove the extra retain
        value.release();
    }
}

size.addAndGet(-removedSize);

return Pair.of(removedEntries, removedSize);
/**
 * Applies the accumulated removal counters to the cache size and converts them into the result pair.
 * The counters instance is recycled and must not be used after this call.
 *
 * @param counters the counters collected while removing entries
 * @return a pair of the number of removed entries and the total removed size
 */
private Pair<Integer, Long> handleRemovalResult(RemovalCounters counters) {
    final int entryCount = counters.removedEntries;
    final long totalSize = counters.removedSize;
    size.addAndGet(-totalSize);
    final Pair<Integer, Long> outcome = Pair.of(entryCount, totalSize);
    counters.recycle();
    return outcome;
}

/**
@@ -166,24 +333,15 @@ public Pair<Integer, Long> removeRange(Key first, Key last, boolean lastInclusiv
*/
public Pair<Integer, Long> evictLeastAccessedEntries(long minSize) {
checkArgument(minSize > 0);

long removedSize = 0;
int removedEntries = 0;

while (removedSize < minSize) {
Map.Entry<Key, Value> entry = entries.pollFirstEntry();
RemovalCounters counters = RemovalCounters.create();
while (counters.removedSize < minSize) {
Map.Entry<Key, IdentityWrapper<Key, Value>> entry = entries.firstEntry();
if (entry == null) {
break;
}

Value value = entry.getValue();
++removedEntries;
removedSize += weighter.getSize(value);
value.release();
removeEntry(entry, counters);
}

size.addAndGet(-removedSize);
return Pair.of(removedEntries, removedSize);
return handleRemovalResult(counters);
}

/**
@@ -192,27 +350,18 @@ public Pair<Integer, Long> evictLeastAccessedEntries(long minSize) {
* @return a pair containing the number of removed entries and the total size of the removed entries
*/
public Pair<Integer, Long> evictLEntriesBeforeTimestamp(long maxTimestamp) {
long removedSize = 0;
int removedCount = 0;

RemovalCounters counters = RemovalCounters.create();
while (true) {
Map.Entry<Key, Value> entry = entries.firstEntry();
if (entry == null || timestampExtractor.getTimestamp(entry.getValue()) > maxTimestamp) {
Map.Entry<Key, IdentityWrapper<Key, Value>> entry = entries.firstEntry();
if (entry == null) {
break;
}
Value value = entry.getValue();
boolean removeHits = entries.remove(entry.getKey(), value);
if (!removeHits) {
if (removeEntry(entry, counters, value -> timestampExtractor.getTimestamp(value) <= maxTimestamp)
== RemoveEntryResult.BREAK_LOOP) {
break;
}

removedSize += weighter.getSize(value);
removedCount++;
value.release();
}

size.addAndGet(-removedSize);
return Pair.of(removedCount, removedSize);
return handleRemovalResult(counters);
}

/**
@@ -231,23 +380,16 @@ public long getSize() {
*
* @return size of removed entries
*/
public synchronized Pair<Integer, Long> clear() {
long removedSize = 0;
int removedCount = 0;

public Pair<Integer, Long> clear() {
RemovalCounters counters = RemovalCounters.create();
while (true) {
Map.Entry<Key, Value> entry = entries.pollFirstEntry();
Map.Entry<Key, IdentityWrapper<Key, Value>> entry = entries.firstEntry();
if (entry == null) {
break;
}
Value value = entry.getValue();
removedSize += weighter.getSize(value);
removedCount++;
value.release();
removeEntry(entry, counters);
}

size.getAndAdd(-removedSize);
return Pair.of(removedCount, removedSize);
return handleRemovalResult(counters);
}

/**
Original file line number Diff line number Diff line change
@@ -23,25 +23,30 @@
import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;

import com.google.common.collect.Lists;
import io.netty.util.AbstractReferenceCounted;
import io.netty.util.ReferenceCounted;
import org.apache.commons.lang3.tuple.Pair;
import org.testng.annotations.Test;
import java.util.UUID;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import lombok.Cleanup;
import org.apache.commons.lang3.tuple.Pair;
import org.testng.annotations.Test;

public class RangeCacheTest {

class RefString extends AbstractReferenceCounted implements ReferenceCounted {
class RefString extends AbstractReferenceCounted implements RangeCache.ValueWithKeyValidation<Integer> {
String s;
Integer matchingKey;

RefString(String s) {
this(s, null);
}

RefString(String s, Integer matchingKey) {
super();
this.s = s;
this.matchingKey = matchingKey != null ? matchingKey : Integer.parseInt(s);
setRefCnt(1);
}

@@ -65,6 +70,11 @@ public boolean equals(Object obj) {

return false;
}

@Override
public boolean matchesKey(Integer key) {
return matchingKey.equals(key);
}
}

@Test
@@ -119,8 +129,8 @@ public void simple() {
public void customWeighter() {
RangeCache<Integer, RefString> cache = new RangeCache<>(value -> value.s.length(), x -> 0);

cache.put(0, new RefString("zero"));
cache.put(1, new RefString("one"));
cache.put(0, new RefString("zero", 0));
cache.put(1, new RefString("one", 1));

assertEquals(cache.getSize(), 7);
assertEquals(cache.getNumberOfEntries(), 2);
@@ -132,9 +142,9 @@ public void customTimeExtraction() {
RangeCache<Integer, RefString> cache = new RangeCache<>(value -> value.s.length(), x -> x.s.length());

cache.put(1, new RefString("1"));
cache.put(2, new RefString("22"));
cache.put(3, new RefString("333"));
cache.put(4, new RefString("4444"));
cache.put(22, new RefString("22"));
cache.put(333, new RefString("333"));
cache.put(4444, new RefString("4444"));

assertEquals(cache.getSize(), 10);
assertEquals(cache.getNumberOfEntries(), 4);
@@ -151,20 +161,20 @@ public void customTimeExtraction() {
public void doubleInsert() {
RangeCache<Integer, RefString> cache = new RangeCache<>();

RefString s0 = new RefString("zero");
RefString s0 = new RefString("zero", 0);
assertEquals(s0.refCnt(), 1);
assertTrue(cache.put(0, s0));
assertEquals(s0.refCnt(), 1);

cache.put(1, new RefString("one"));
cache.put(1, new RefString("one", 1));

assertEquals(cache.getSize(), 2);
assertEquals(cache.getNumberOfEntries(), 2);
RefString s = cache.get(1);
assertEquals(s.s, "one");
assertEquals(s.refCnt(), 2);

RefString s1 = new RefString("uno");
RefString s1 = new RefString("uno", 1);
assertEquals(s1.refCnt(), 1);
assertFalse(cache.put(1, s1));
assertEquals(s1.refCnt(), 1);
@@ -201,10 +211,10 @@ public void getRange() {
public void eviction() {
RangeCache<Integer, RefString> cache = new RangeCache<>(value -> value.s.length(), x -> 0);

cache.put(0, new RefString("zero"));
cache.put(1, new RefString("one"));
cache.put(2, new RefString("two"));
cache.put(3, new RefString("three"));
cache.put(0, new RefString("zero", 0));
cache.put(1, new RefString("one", 1));
cache.put(2, new RefString("two", 2));
cache.put(3, new RefString("three", 3));

// This should remove the LRU entries: 0, 1 whose combined size is 7
assertEquals(cache.evictLeastAccessedEntries(5), Pair.of(2, (long) 7));
@@ -276,20 +286,23 @@ public void evictions() {
}

@Test
public void testInParallel() {
RangeCache<String, RefString> cache = new RangeCache<>(value -> value.s.length(), x -> 0);
ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
executor.scheduleWithFixedDelay(cache::clear, 10, 10, TimeUnit.MILLISECONDS);
for (int i = 0; i < 1000; i++) {
cache.put(UUID.randomUUID().toString(), new RefString("zero"));
public void testPutWhileClearIsCalledConcurrently() {
RangeCache<Integer, RefString> cache = new RangeCache<>(value -> value.s.length(), x -> 0);
int numberOfThreads = 4;
@Cleanup("shutdownNow")
ScheduledExecutorService executor = Executors.newScheduledThreadPool(numberOfThreads);
for (int i = 0; i < numberOfThreads; i++) {
executor.scheduleWithFixedDelay(cache::clear, 0, 1, TimeUnit.MILLISECONDS);
}
for (int i = 0; i < 100000; i++) {
cache.put(i, new RefString(String.valueOf(i)));
}
executor.shutdown();
}

@Test
public void testPutSameObj() {
RangeCache<Integer, RefString> cache = new RangeCache<>(value -> value.s.length(), x -> 0);
RefString s0 = new RefString("zero");
RefString s0 = new RefString("zero", 0);
assertEquals(s0.refCnt(), 1);
assertTrue(cache.put(0, s0));
assertFalse(cache.put(0, s0));

0 comments on commit 6812af7

Please sign in to comment.