Fix race with eviction when reading from FileCache
The previous implementation had an inherent race condition where a
zero-reference-count IndexInput read from the cache could be evicted
before the IndexInput was cloned (and therefore before its reference
count was incremented). Since IndexInputs are stateful, using an
instance after eviction has closed it is very bad. The
least-recently-used semantics meant that in a properly configured
system this was unlikely: accessing a zero-reference-count item moves
it to most-recently used, making it the least likely to be evicted.
However, there was still a latent bug that was possible to encounter
(see issue opensearch-project#6295).

The only way to fix this, as far as I can see, is to change the cache
behavior so that fetching an item from the cache atomically increments
its reference count. This also led to a change to TransferManager to
ensure that all requests for an item ultimately read through the cache,
eliminating any possibility of a race. I have implemented some
concurrent unit tests that put the cache into a worst-case thrashing
scenario to ensure that concurrent access never closes an IndexInput
while it is still being used.

Signed-off-by: Andrew Ross <andrross@amazon.com>
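
In code terms, the race and the fix look roughly like this. This is a
minimal sketch, assuming a RefCountedCache<Path, IndexInput> like the
one in the diff below; the class and method names are hypothetical and
null handling is omitted for brevity.

import java.nio.file.Path;

import org.apache.lucene.store.IndexInput;

class RaceSketch {
    private RefCountedCache<Path, IndexInput> cache;

    // Before the fix, get() was a plain lookup: the entry's reference
    // count stayed at zero, so another thread could evict the entry and
    // close `origin` between the lookup and the clone.
    IndexInput unsafeFetch(Path path) {
        IndexInput origin = cache.get(path); // still evictable here
        return origin.clone();               // may clone a closed input
    }

    // After the fix, get() increments the reference count atomically
    // under the cache lock, pinning the entry until the caller releases it.
    IndexInput safeFetch(Path path) {
        IndexInput origin = cache.get(path); // pinned: refCount >= 1
        try {
            return origin.clone();           // clone is created while pinned
        } finally {
            cache.decRef(path);              // release the pin taken by get()
        }
    }
}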
andrross committed Mar 9, 2023
1 parent 8c0e5d0 commit c958df2
Showing 13 changed files with 326 additions and 483 deletions.
@@ -13,7 +13,6 @@
import org.opensearch.index.store.remote.utils.cache.SegmentedCache;
import org.opensearch.index.store.remote.utils.cache.stats.CacheStats;
import java.nio.file.Path;
import java.util.Map;
import java.util.function.BiFunction;

/**
@@ -45,15 +44,11 @@ public long capacity() {
return theCache.capacity();
}

@Override
public CachedIndexInput put(Path filePath, CachedIndexInput indexInput) {
return theCache.put(filePath, indexInput);
}

@Override
public void putAll(Map<? extends Path, ? extends CachedIndexInput> m) {
theCache.putAll(m);
}

@Override
public CachedIndexInput computeIfPresent(
Path key,
@@ -84,11 +79,6 @@ public void remove(final Path filePath) {
theCache.remove(filePath);
}

@Override
public void removeAll(Iterable<? extends Path> keys) {
theCache.removeAll(keys);
}

@Override
public void clear() {
theCache.clear();

This file was deleted.

@@ -22,7 +22,10 @@
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Objects;

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;

/**
* This acts as entry point to fetch {@link BlobFetchRequest} and return actual {@link IndexInput}. Utilizes the BlobContainer interface to
@@ -33,13 +36,12 @@
public class TransferManager {
private static final Logger logger = LogManager.getLogger(TransferManager.class);

private final ConcurrentMap<Path, CountDownLatch> latchMap = new ConcurrentHashMap<>();
private final BlobContainer blobContainer;
private final ConcurrentInvocationLinearizer<Path, IndexInput> invocationLinearizer;
private final FileCache fileCache;

public TransferManager(final BlobContainer blobContainer, final FileCache fileCache) {
this.blobContainer = blobContainer;
this.invocationLinearizer = new ConcurrentInvocationLinearizer<>();
this.fileCache = fileCache;
}

@@ -49,22 +51,13 @@ public TransferManager(final BlobContainer blobContainer, final FileCache fileCache) {
* @return future of IndexInput augmented with internal caching maintenance tasks
*/
public IndexInput fetchBlob(BlobFetchRequest blobFetchRequest) throws InterruptedException, IOException {
final IndexInput indexInput = invocationLinearizer.linearize(
blobFetchRequest.getFilePath(),
p -> fetchOriginBlob(blobFetchRequest)
);
return indexInput.clone();
return fetchBlobInternal(blobFetchRequest);
}

/**
* Fetches the "origin" IndexInput from the cache, downloading it first if it is
* not already cached. This instance must be cloned before using. This method is
* accessed through the ConcurrentInvocationLinearizer so read-check-write is
* acceptable here
*/
private IndexInput fetchOriginBlob(BlobFetchRequest blobFetchRequest) throws IOException {
private IndexInput fetchBlobInternal(BlobFetchRequest blobFetchRequest) throws InterruptedException, IOException {
final Path key = blobFetchRequest.getFilePath();
// check if the origin is already in block cache
IndexInput origin = fileCache.computeIfPresent(blobFetchRequest.getFilePath(), (path, cachedIndexInput) -> {
IndexInput origin = fileCache.computeIfPresent(key, (path, cachedIndexInput) -> {
if (cachedIndexInput.isClosed()) {
// if it's already in the file cache, but closed, open it and replace the original one
try {
@@ -81,20 +74,38 @@ private IndexInput fetchOriginBlob(BlobFetchRequest blobFetchRequest) throws IOException {
return cachedIndexInput;
});

if (Objects.isNull(origin)) {
// origin is not in file cache, download origin

// open new origin
IndexInput downloaded = downloadBlockLocally(blobFetchRequest);

// refcount = 0 at the beginning
FileCachedIndexInput newOrigin = new FileCachedIndexInput(fileCache, blobFetchRequest.getFilePath(), downloaded);
if (origin == null) {
final CountDownLatch existingLatch = latchMap.putIfAbsent(key, new CountDownLatch(1));
if (existingLatch != null) {
// Another thread is downloading the same resource. Wait for it
// to complete then make a recursive call to fetch it from the
// cache.
existingLatch.await();
return fetchBlobInternal(blobFetchRequest);
} else {
// Origin is not in file cache, download origin and put in cache
// We've effectively taken a lock for this key by inserting a
// latch into the concurrent map, so we must be sure to remove it
// and count it down before leaving.
try {
IndexInput downloaded = downloadBlockLocally(blobFetchRequest);
FileCachedIndexInput newOrigin = new FileCachedIndexInput(fileCache, blobFetchRequest.getFilePath(), downloaded);
fileCache.put(key, newOrigin);
origin = newOrigin;
} finally {
latchMap.remove(key).countDown();
}
}
}

// put origin into file cache
fileCache.put(blobFetchRequest.getFilePath(), newOrigin);
origin = newOrigin;
// Origin was either retrieved from the cache or newly added, either
// way the reference count has been incremented by one. We can only
// decrement this reference _after_ creating the clone to be returned.
try {
return origin.clone();
} finally {
fileCache.decRef(key);
}
return origin;
}

private IndexInput downloadBlockLocally(BlobFetchRequest blobFetchRequest) throws IOException {
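
The latch-per-key logic in fetchBlobInternal generalizes to any "load
once, share the result" problem. Below is a standalone sketch under
assumed names (OncePerKeyLoader, expensiveLoad — not part of the
commit): the first thread to install a latch for a key performs the
load, while concurrent callers wait on the latch and then retry the
lookup. Note the retry rather than passing a value through the latch:
just as in fetchBlobInternal, the waiter re-reads the cache, which also
covers the case where the loaded entry has since been evicted.

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.function.Function;

class OncePerKeyLoader<K, V> {
    private final ConcurrentMap<K, CountDownLatch> latches = new ConcurrentHashMap<>();
    private final ConcurrentMap<K, V> cache = new ConcurrentHashMap<>();

    V load(K key, Function<K, V> expensiveLoad) throws InterruptedException {
        final V cached = cache.get(key);
        if (cached != null) {
            return cached;
        }
        final CountDownLatch existing = latches.putIfAbsent(key, new CountDownLatch(1));
        if (existing != null) {
            existing.await();                 // another thread owns the load
            return load(key, expensiveLoad);  // retry; value is likely cached now
        }
        // We own the "lock" for this key: we must remove the latch and
        // count it down before leaving, or waiters would block forever.
        try {
            final V value = expensiveLoad.apply(key);
            cache.put(key, value);
            return value;
        } finally {
            latches.remove(key).countDown();
        }
    }
}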
@@ -17,7 +17,6 @@
import org.opensearch.index.store.remote.utils.cache.stats.StatsCounter;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.BiFunction;
@@ -137,9 +136,7 @@ public V get(K key) {
return null;
}
// hit
if (node.evictable()) {
lru.moveToBack(node);
}
incRef(key);
statsCounter.recordHits(key, 1);
return node.value;
} finally {
@@ -172,20 +169,18 @@ public V put(K key, V value) {
if (node.refCount > 0) {
activeUsage += weightDiff;
}
if (node.evictable()) {
lru.moveToBack(node);
}
usage += weightDiff;
// call listeners
statsCounter.recordReplacement();
listener.onRemoval(new RemovalNotification<>(key, oldValue, RemovalReason.REPLACED));
incRef(key);
evict();
return oldValue;
} else {
Node<K, V> newNode = new Node<>(key, value, weight);
data.put(key, newNode);
lru.add(newNode);
usage += weight;
incRef(key);
evict();
return null;
}
@@ -194,12 +189,6 @@ public V put(K key, V value) {
}
}

@Override
public void putAll(Map<? extends K, ? extends V> m) {
for (Map.Entry<? extends K, ? extends V> e : m.entrySet())
put(e.getKey(), e.getValue());
}

@Override
public V computeIfPresent(K key, BiFunction<? super K, ? super V, ? extends V> remappingFunction) {
Objects.requireNonNull(key);
@@ -219,9 +208,6 @@ public V computeIfPresent(K key, BiFunction<? super K, ? super V, ? extends V> remappingFunction) {

// update usage
final long weightDiff = weight - oldWeight;
if (node.evictable()) {
lru.moveToBack(node);
}

if (node.refCount > 0) {
activeUsage += weightDiff;
@@ -233,6 +219,7 @@ public V computeIfPresent(K key, BiFunction<? super K, ? super V, ? extends V> remappingFunction) {
statsCounter.recordReplacement();
listener.onRemoval(new RemovalNotification<>(node.key, oldValue, RemovalReason.REPLACED));
}
incRef(key);
evict();
return v;
} else {
@@ -280,13 +267,6 @@ public void remove(K key) {
}
}

@Override
public void removeAll(Iterable<? extends K> keys) {
for (K key : keys) {
remove(key);
}
}

@Override
public void clear() {
final ReentrantLock lock = this.lock;
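
The get, put, and computeIfPresent changes above all route through
incRef(key), which is why the moveToBack calls could be deleted: a
referenced entry is no longer in the LRU list at all. A toy model of
that invariant follows — illustrative only, not the actual LRUCache
(which works over the lock, lru, refCount, weight, and activeUsage
fields visible in the diff): an entry with a nonzero reference count
leaves the eviction order entirely, so it can never be chosen as a
victim, and it re-enters the order as most-recently used only when its
count drops back to zero.

import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;

class RefCountedLru<K, V> {
    private static final class Node<V> {
        V value;
        int refCount;
    }

    private final Map<K, Node<V>> data = new HashMap<>();
    private final LinkedHashSet<K> evictionOrder = new LinkedHashSet<>(); // oldest first

    synchronized V get(K key) {
        final Node<V> node = data.get(key);
        if (node == null) {
            return null;
        }
        if (node.refCount++ == 0) {
            evictionOrder.remove(key); // pinned: no longer evictable
        }
        return node.value;
    }

    synchronized void decRef(K key) {
        final Node<V> node = data.get(key);
        if (node != null && node.refCount > 0 && --node.refCount == 0) {
            evictionOrder.add(key); // evictable again, as most-recently used
        }
    }

    synchronized K evictOldest() {
        final K victim = evictionOrder.isEmpty() ? null : evictionOrder.iterator().next();
        if (victim != null) {
            evictionOrder.remove(victim);
            data.remove(victim); // only zero-refCount entries ever get here
        }
        return victim;
    }
}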
@@ -10,22 +10,22 @@

import org.opensearch.index.store.remote.utils.cache.stats.CacheStats;

import java.util.Map;
import java.util.function.BiFunction;

/**
* Custom cache which supports typical cache operations (put, get, ...) as well as reference counting per individual key, which might
* change eviction behavior
* @param <K> type of the key
* @param <V> type of th value
* @param <V> type of the value
*
* @opensearch.internal
*/
public interface RefCountedCache<K, V> {

/**
* Returns the value associated with {@code key} in this cache, or {@code null} if there is no
* cached value for {@code key}.
* cached value for {@code key}. Retrieving an item automatically increases its reference
* count.
*/
V get(K key);

@@ -35,17 +35,12 @@ public interface RefCountedCache<K, V> {
*/
V put(K key, V value);

/**
* Copies all the mappings from the specified map to the cache. The effect of this call is
* equivalent to that of calling {@code put(k, v)} on this map once for each mapping from key
* {@code k} to value {@code v} in the specified map. The behavior of this operation is undefined
* if the specified map is modified while the operation is in progress.
*/
void putAll(Map<? extends K, ? extends V> m);

/**
* If the specified key is already associated with a value, attempts to update its value using the given mapping
* function and enters the new value into this map unless null.
* function and enters the new value. If the mapping function returns null the item is removed from the
* cache, regardless of its reference count. If the mapping function returns non-null the value is updated.
* The new entry will have the reference count of the previous entry plus one, as this method automatically
* increases the reference count by one when it returns the newly mapped value.
*
* If the specified key is NOT already associated with a value, return null without applying the mapping function.
*
@@ -54,17 +49,12 @@
V computeIfPresent(K key, BiFunction<? super K, ? super V, ? extends V> remappingFunction);

/**
* Discards any cached value for key {@code key}.
* Discards any cached value for key {@code key}, regardless of reference count.
*/
void remove(K key);

/**
* Discards any cached values for keys {@code keys}.
*/
void removeAll(Iterable<? extends K> keys);

/**
* Discards all entries in the cache.
* Discards all entries in the cache, regardless of reference count.
*/
void clear();

@@ -83,6 +73,11 @@ public interface RefCountedCache<K, V> {
*/
void decRef(K key);

/**
* Removes all cache entries with a reference count of zero, regardless of current capacity.
*
* @return The total weight of all removed entries.
*/
long prune();

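
Putting the interface contract together: every successful get(), put(),
or computeIfPresent() hands back a pinned value, and each pin must be
paired with exactly one decRef(). A hypothetical caller (names assumed;
not part of the commit) might look like this:

import java.util.function.BiFunction;
import java.util.function.Function;

final class CacheContractExample {
    // computeIfPresent() returns the remapped value with its reference
    // count already incremented, so the caller owns exactly one
    // reference and must release it with decRef() when done.
    static <K, V, R> R remapAndUse(
        RefCountedCache<K, V> cache,
        K key,
        BiFunction<? super K, ? super V, ? extends V> remapper,
        Function<? super V, ? extends R> use
    ) {
        final V value = cache.computeIfPresent(key, remapper);
        if (value == null) {
            return null; // key was absent, or the remapper removed the entry
        }
        try {
            return use.apply(value); // pinned: cannot be evicted during use
        } finally {
            cache.decRef(key); // release the reference taken by the call
        }
    }
}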