Skip to content

Commit

Permalink
Remote: Fix a race that AsyncTaskCache#Execution could be reused afte…
Browse files Browse the repository at this point in the history
…r disposed which results in CancellationException("disposed") propagated to downstream.

Also added a test case to verify the fix.

PiperOrigin-RevId: 364699975
  • Loading branch information
coeuvre authored and philwo committed Apr 20, 2021
1 parent 8e56b94 commit 6b55ea2
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CancellationException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.concurrent.GuardedBy;
Expand Down Expand Up @@ -54,7 +55,7 @@ public final class AsyncTaskCache<KeyT, ValueT> {
private final Map<KeyT, ValueT> finished;

@GuardedBy("lock")
private final Map<KeyT, Execution> inProgress;
private final Map<KeyT, Execution<ValueT>> inProgress;

public static <KeyT, ValueT> AsyncTaskCache<KeyT, ValueT> create() {
return new AsyncTaskCache<>();
Expand Down Expand Up @@ -90,18 +91,22 @@ public Single<ValueT> executeIfNot(KeyT key, Single<ValueT> task) {
return execute(key, task, false);
}

private class Execution {
private static class Execution<ValueT> {
private final AtomicBoolean isTaskDisposed = new AtomicBoolean(false);
private final Single<ValueT> task;
private final AsyncSubject<ValueT> asyncSubject = AsyncSubject.create();
private final AtomicInteger subscriberCount = new AtomicInteger(0);
private final AtomicInteger referenceCount = new AtomicInteger(0);
private final AtomicReference<Disposable> taskDisposable = new AtomicReference<>(null);

Execution(Single<ValueT> task) {
this.task = task;
}

public Single<ValueT> start() {
if (taskDisposable.get() == null) {
Single<ValueT> executeIfNot() {
checkState(!isTaskDisposed(), "disposed");

int subscribed = referenceCount.getAndIncrement();
if (taskDisposable.get() == null && subscribed == 0) {
task.subscribe(
new SingleObserver<ValueT>() {
@Override
Expand All @@ -122,27 +127,39 @@ public void onError(@NonNull Throwable e) {
});
}

return Single.fromObservable(asyncSubject)
.doOnSubscribe(d -> subscriberCount.incrementAndGet())
.doOnDispose(
() -> {
if (subscriberCount.decrementAndGet() == 0) {
Disposable d = taskDisposable.get();
if (d != null) {
d.dispose();
}
asyncSubject.onError(new CancellationException("disposed"));
}
});
return Single.fromObservable(asyncSubject);
}

boolean isTaskTerminated() {
return asyncSubject.hasComplete() || asyncSubject.hasThrowable();
}

boolean isTaskDisposed() {
return isTaskDisposed.get();
}

void tryDisposeTask() {
checkState(!isTaskDisposed(), "disposed");
checkState(!isTaskTerminated(), "terminated");

if (referenceCount.decrementAndGet() == 0) {
isTaskDisposed.set(true);
asyncSubject.onError(new CancellationException("disposed"));

Disposable d = taskDisposable.get();
if (d != null) {
d.dispose();
}
}
}
}

/** Returns count of subscribers for a task. */
public int getSubscriberCount(KeyT key) {
synchronized (lock) {
Execution execution = inProgress.get(key);
Execution<ValueT> execution = inProgress.get(key);
if (execution != null) {
return execution.subscriberCount.get();
return execution.referenceCount.get();
}
}

Expand All @@ -158,49 +175,72 @@ public int getSubscriberCount(KeyT key) {
* error if any.
*/
public Single<ValueT> execute(KeyT key, Single<ValueT> task, boolean force) {
return Single.defer(
() -> {
return Single.create(
emitter -> {
synchronized (lock) {
if (!force && finished.containsKey(key)) {
return Single.just(finished.get(key));
emitter.onSuccess(finished.get(key));
return;
}

finished.remove(key);

Execution execution =
Execution<ValueT> execution =
inProgress.computeIfAbsent(
key,
missingKey -> {
ignoredKey -> {
AtomicInteger subscribeTimes = new AtomicInteger(0);
return new Execution(
return new Execution<>(
Single.defer(
() -> {
int times = subscribeTimes.incrementAndGet();
checkState(times == 1, "Subscribed more than once to the task");
return task;
})
.doOnSuccess(
value -> {
synchronized (lock) {
finished.put(key, value);
inProgress.remove(key);
}
})
.doOnError(
error -> {
synchronized (lock) {
inProgress.remove(key);
}
})
.doOnDispose(
() -> {
synchronized (lock) {
inProgress.remove(key);
}
}));
() -> {
int times = subscribeTimes.incrementAndGet();
checkState(times == 1, "Subscribed more than once to the task");
return task;
}));
});

return execution.start();
execution
.executeIfNot()
.subscribe(
new SingleObserver<ValueT>() {
@Override
public void onSubscribe(@NonNull Disposable d) {
emitter.setCancellable(
() -> {
d.dispose();

if (!execution.isTaskTerminated()) {
synchronized (lock) {
execution.tryDisposeTask();
if (execution.isTaskDisposed()) {
inProgress.remove(key);
}
}
}
});
}

@Override
public void onSuccess(@NonNull ValueT value) {
synchronized (lock) {
finished.put(key, value);
inProgress.remove(key);
}

emitter.onSuccess(value);
}

@Override
public void onError(@NonNull Throwable e) {
synchronized (lock) {
inProgress.remove(key);
}

if (!emitter.isDisposed()) {
emitter.onError(e);
}
}
});
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,15 @@
package com.google.devtools.build.lib.remote.util;

import static com.google.common.truth.Truth.assertThat;
import static java.util.concurrent.TimeUnit.SECONDS;

import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.core.SingleEmitter;
import io.reactivex.rxjava3.observers.TestObserver;
import io.reactivex.rxjava3.plugins.RxJavaPlugins;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
Expand All @@ -32,21 +31,7 @@
@RunWith(JUnit4.class)
public class AsyncTaskCacheTest {

private final AtomicReference<Throwable> rxGlobalThrowable = new AtomicReference<>(null);

@Before
public void setUp() {
RxJavaPlugins.setErrorHandler(rxGlobalThrowable::set);
}

@After
public void tearDown() throws Throwable {
// Make sure rxjava didn't receive global errors
Throwable t = rxGlobalThrowable.getAndSet(null);
if (t != null) {
throw t;
}
}
@Rule public final RxNoGlobalErrorsRule rxNoGlobalErrorsRule = new RxNoGlobalErrorsRule();

@Test
public void execute_noSubscription_noExecution() {
Expand Down Expand Up @@ -296,4 +281,44 @@ public void execute_multipleTasks_completeOne() {
assertThat(cache.getInProgressTasks()).containsExactly("key2");
assertThat(cache.getFinishedTasks()).containsExactly("key1");
}

@Test
public void execute_executeAndDisposeLoop_noErrors() throws InterruptedException {
AsyncTaskCache<String, Long> cache = AsyncTaskCache.create();
Single<Long> task = Single.timer(1, SECONDS);
AtomicReference<Throwable> error = new AtomicReference<>(null);
AtomicInteger errorCount = new AtomicInteger(0);
int executionCount = 100;
Runnable runnable =
() -> {
try {
for (int i = 0; i < executionCount; ++i) {
TestObserver<Long> observer = cache.execute("key1", task, true).test();
observer.assertNoErrors();
observer.dispose();
}
} catch (Throwable t) {
errorCount.incrementAndGet();
error.set(t);
}
};
int threadCount = 10;
Thread[] threads = new Thread[threadCount];
for (int i = 0; i < threadCount; ++i) {
Thread thread = new Thread(runnable);
threads[i] = thread;
}

for (Thread thread : threads) {
thread.start();
}
for (Thread thread : threads) {
thread.join();
}

if (error.get() != null) {
throw new IllegalStateException(
String.format("%s/%s errors", errorCount.get(), threadCount), error.get());
}
}
}

0 comments on commit 6b55ea2

Please sign in to comment.