Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OnSubscribeRedo - fix race conditions #2930

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 224 additions & 36 deletions src/main/java/rx/internal/operators/OnSubscribeRedo.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;

import rx.Notification;
import rx.Observable;
Expand All @@ -47,13 +46,14 @@
import rx.functions.Action0;
import rx.functions.Func1;
import rx.functions.Func2;
import rx.observers.Subscribers;
import rx.schedulers.Schedulers;
import rx.subjects.PublishSubject;
import rx.subjects.BehaviorSubject;
import rx.subscriptions.SerialSubscription;

public final class OnSubscribeRedo<T> implements OnSubscribe<T> {

static final Func1<Observable<? extends Notification<?>>, Observable<?>> REDO_INIFINITE = new Func1<Observable<? extends Notification<?>>, Observable<?>>() {
static final Func1<Observable<? extends Notification<?>>, Observable<?>> REDO_INFINITE = new Func1<Observable<? extends Notification<?>>, Observable<?>>() {
@Override
public Observable<?> call(Observable<? extends Notification<?>> ts) {
return ts.map(new Func1<Notification<?>, Notification<?>>() {
Expand Down Expand Up @@ -120,7 +120,7 @@ public Notification<Integer> call(Notification<Integer> n, Notification<?> term)
}

public static <T> Observable<T> retry(Observable<T> source) {
return retry(source, REDO_INIFINITE);
return retry(source, REDO_INFINITE);
}

public static <T> Observable<T> retry(Observable<T> source, final long count) {
Expand All @@ -144,7 +144,7 @@ public static <T> Observable<T> repeat(Observable<T> source) {
}

public static <T> Observable<T> repeat(Observable<T> source, Scheduler scheduler) {
return repeat(source, REDO_INIFINITE, scheduler);
return repeat(source, REDO_INFINITE, scheduler);
}

public static <T> Observable<T> repeat(Observable<T> source, final long count) {
Expand Down Expand Up @@ -172,10 +172,10 @@ public static <T> Observable<T> redo(Observable<T> source, Func1<? super Observa
return create(new OnSubscribeRedo<T>(source, notificationHandler, false, false, scheduler));
}

private Observable<T> source;
private final Observable<T> source;
private final Func1<? super Observable<? extends Notification<?>>, ? extends Observable<?>> controlHandlerFunction;
private boolean stopOnComplete;
private boolean stopOnError;
private final boolean stopOnComplete;
private final boolean stopOnError;
private final Scheduler scheduler;

private OnSubscribeRedo(Observable<T> source, Func1<? super Observable<? extends Notification<?>>, ? extends Observable<?>> f, boolean stopOnComplete, boolean stopOnError,
Expand All @@ -189,20 +189,31 @@ private OnSubscribeRedo(Observable<T> source, Func1<? super Observable<? extends

@Override
public void call(final Subscriber<? super T> child) {
final AtomicBoolean isLocked = new AtomicBoolean(true);

// when true is a marker to say we are ready to resubscribe to source
final AtomicBoolean resumeBoundary = new AtomicBoolean(true);

// incremented when requests are made, decremented when requests are fulfilled
final AtomicLong consumerCapacity = new AtomicLong(0l);
final AtomicReference<Producer> currentProducer = new AtomicReference<Producer>();

final Scheduler.Worker worker = scheduler.createWorker();
child.add(worker);

final SerialSubscription sourceSubscriptions = new SerialSubscription();
child.add(sourceSubscriptions);

final PublishSubject<Notification<?>> terminals = PublishSubject.create();

// use a subject to receive terminals (onCompleted and onError signals) from
// the source observable. We use a BehaviorSubject because subscribeToSource
// may emit a terminal before the restarts observable (transformed terminals)
// is subscribed
final BehaviorSubject<Notification<?>> terminals = BehaviorSubject.create();
final Subscriber<Notification<?>> dummySubscriber = Subscribers.empty();
// subscribe immediately so the last emission will be replayed to the next
// subscriber (which is the one we care about)
terminals.subscribe(dummySubscriber);

final ProducerArbiter arbiter = new ProducerArbiter();

final Action0 subscribeToSource = new Action0() {
@Override
public void call() {
Expand All @@ -212,11 +223,11 @@ public void call() {

Subscriber<T> terminalDelegatingSubscriber = new Subscriber<T>() {
boolean done;

@Override
public void onCompleted() {
if (!done) {
done = true;
currentProducer.set(null);
unsubscribe();
terminals.onNext(Notification.createOnCompleted());
}
Expand All @@ -226,7 +237,6 @@ public void onCompleted() {
public void onError(Throwable e) {
if (!done) {
done = true;
currentProducer.set(null);
unsubscribe();
terminals.onNext(Notification.createOnError(e));
}
Expand All @@ -235,20 +245,30 @@ public void onError(Throwable e) {
@Override
public void onNext(T v) {
if (!done) {
if (consumerCapacity.get() != Long.MAX_VALUE) {
consumerCapacity.decrementAndGet();
}
child.onNext(v);
decrementConsumerCapacity();
arbiter.produced(1);
}
}

private void decrementConsumerCapacity() {
// use a CAS loop because we don't want to decrement the
// value if it is Long.MAX_VALUE
while (true) {
long cc = consumerCapacity.get();
if (cc != Long.MAX_VALUE) {
if (consumerCapacity.compareAndSet(cc, cc - 1)) {
break;
}
} else {
break;
}
}
}

@Override
public void setProducer(Producer producer) {
currentProducer.set(producer);
long c = consumerCapacity.get();
if (c > 0) {
producer.request(c);
}
arbiter.setProducer(producer);
}
};
// new subscription each time so if it unsubscribes itself it does not prevent retries
Expand Down Expand Up @@ -278,12 +298,11 @@ public void onError(Throwable e) {

@Override
public void onNext(Notification<?> t) {
if (t.isOnCompleted() && stopOnComplete)
child.onCompleted();
else if (t.isOnError() && stopOnError)
child.onError(t.getThrowable());
else {
isLocked.set(false);
if (t.isOnCompleted() && stopOnComplete) {
filteredTerminals.onCompleted();
} else if (t.isOnError() && stopOnError) {
filteredTerminals.onError(t.getThrowable());
} else {
filteredTerminals.onNext(t);
}
}
Expand Down Expand Up @@ -313,10 +332,15 @@ public void onError(Throwable e) {

@Override
public void onNext(Object t) {
if (!isLocked.get() && !child.isUnsubscribed()) {
if (!child.isUnsubscribed()) {
// perform a best endeavours check on consumerCapacity
// with the intent of only resubscribing immediately
// if there is outstanding capacity
if (consumerCapacity.get() > 0) {
worker.schedule(subscribeToSource);
} else {
// set this to true so that on next request
// subscribeToSource will be scheduled
resumeBoundary.compareAndSet(false, true);
}
}
Expand All @@ -334,16 +358,180 @@ public void setProducer(Producer producer) {

@Override
public void request(final long n) {
long c = BackpressureUtils.getAndAddRequest(consumerCapacity, n);
Producer producer = currentProducer.get();
if (producer != null) {
producer.request(n);
} else
if (c == 0 && resumeBoundary.compareAndSet(true, false)) {
worker.schedule(subscribeToSource);
if (n > 0) {
BackpressureUtils.getAndAddRequest(consumerCapacity, n);
arbiter.request(n);
if (resumeBoundary.compareAndSet(true, false))
worker.schedule(subscribeToSource);
}
}
});

}

/**
* Between when the source subscription finishes and the next subscription starts requests may arrive.
* ProducerArbiter keeps track of all requests made and all arriving emissions so that when setProducer
* is called for a new subscription the appropriate number of requests are made to the new producer.
*/
static final class ProducerArbiter implements Producer {
/** Guarded by this. */
boolean emitting;
/** The current producer. Accessed while emitting. */
Producer currentProducer;
/** The current requested count. */
long requested;

long missedRequested;
Producer missedProducer;
long missedProd;

@Override
public void request(long n) {
if (n <= 0) {
return;
}
Producer mp;
long mprod;
synchronized (this) {
if (emitting) {
missedRequested += n;
return;
}
emitting = true;
mp = missedProducer;
mprod = missedProd;

missedProducer = null;
missedProd = 0L;
}

boolean skipFinal = false;
try {
emit(n, mp, mprod);
drainLoop();
skipFinal = true;
} finally {
if (!skipFinal) {
synchronized (this) {
emitting = false;
}
}
}
}
void setProducer(Producer p) {
if (p == null) {
throw new NullPointerException();
}

long mreq;
long mprod;
synchronized (this) {
if (emitting) {
missedProducer = p;
return;
}
emitting = true;
mreq = missedRequested;
mprod = missedProd;

missedRequested = 0L;
missedProd = 0L;
}

boolean skipFinal = false;
try {
emit(mreq, p, mprod);
drainLoop();
skipFinal = true;
} finally {
if (!skipFinal) {
synchronized (this) {
emitting = false;
}
}
}
}
void produced(long n) {
if (n <= 0) {
throw new IllegalArgumentException(n + " produced?!");
}

long mreq;
Producer mp;
synchronized (this) {
if (emitting) {
missedProd += n;
return;
}
emitting = true;
mreq = missedRequested;
mp = missedProducer;

missedRequested = 0L;
missedProducer = null;
}

boolean skipFinal = false;
try {
emit(mreq, mp, n);
drainLoop();
skipFinal = true;
} finally {
if (!skipFinal) {
synchronized (this) {
emitting = false;
}
}
}
}
void drainLoop() {
for (;;) {
long mreq;
long mprod;
Producer mp;
synchronized (this) {
mreq = missedRequested;
mprod = missedProd;
mp = missedProducer;
if (mreq == 0L && mp == null && mprod == 0L) {
emitting = false;
return;
}
missedRequested = 0L;
missedProd = 0L;
missedProducer = null;
}
emit(mreq, mp, mprod);
}
}
void emit(long mreq, Producer mp, long mprod) {
boolean newMp = false;
if (mp != null) {
newMp = true;
currentProducer = mp;
} else {
mp = currentProducer;
}

long u = requested + mreq;
if (u < 0) {
u = Long.MAX_VALUE;
} else
if (u != Long.MAX_VALUE) {
u -= mprod;
if (u < 0) {
throw new IllegalStateException("More produced than requested");
}
}
requested = u;

if (mreq > 0 && mp != null) {
mp.request(mreq);
} else
if (newMp && u > 0) {
mp.request(u);
}
}
}
}
Loading