From bb463f95eea0f3cf0f9d87da3cad110c156ecdc4 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 8 Aug 2014 13:07:57 +0800 Subject: [PATCH] SwitchOnNext with backpressure support --- .../rx/internal/operators/OperatorSwitch.java | 403 +++++++++++------- .../operators/OperatorSwitchTest.java | 117 ++++- 2 files changed, 352 insertions(+), 168 deletions(-) diff --git a/rxjava-core/src/main/java/rx/internal/operators/OperatorSwitch.java b/rxjava-core/src/main/java/rx/internal/operators/OperatorSwitch.java index 1c0b143b1c..4c5c788053 100644 --- a/rxjava-core/src/main/java/rx/internal/operators/OperatorSwitch.java +++ b/rxjava-core/src/main/java/rx/internal/operators/OperatorSwitch.java @@ -19,6 +19,7 @@ import java.util.List; import rx.Observable; import rx.Observable.Operator; +import rx.Producer; import rx.Subscriber; import rx.observers.SerializedSubscriber; import rx.subscriptions.SerialSubscription; @@ -35,198 +36,272 @@ public final class OperatorSwitch implements Operator> call(final Subscriber child) { - final SerializedSubscriber s = new SerializedSubscriber(child); - final SerialSubscription ssub = new SerialSubscription(); - child.add(ssub); - - return new Subscriber>(child) { - final Object guard = new Object(); - final NotificationLite nl = NotificationLite.instance(); - /** Guarded by guard. */ - int index; - /** Guarded by guard. */ - boolean active; - /** Guarded by guard. */ - boolean mainDone; - /** Guarded by guard. */ - List queue; - /** Guarded by guard. */ - boolean emitting; - @Override - public void onNext(Observable t) { - final int id; - synchronized (guard) { - id = ++index; - active = true; - } - - Subscriber sub = new Subscriber() { + return new SwitchSubscriber(child); + } - @Override - public void onNext(T t) { - emit(t, id); - } + private static final class SwitchSubscriber extends Subscriber> { + final SerializedSubscriber s; + final SerialSubscription ssub; + final Object guard = new Object(); + final NotificationLite nl = NotificationLite.instance(); + /** Guarded by guard. */ + int index; + /** Guarded by guard. */ + boolean active; + /** Guarded by guard. */ + boolean mainDone; + /** Guarded by guard. */ + List queue; + /** Guarded by guard. */ + boolean emitting; + /** Guarded by guard. */ + InnerSubscriber currentSubscriber; + /** Guarded by guard. */ + long initialRequested; - @Override - public void onError(Throwable e) { - error(e, id); - } + volatile boolean infinite = false; - @Override - public void onCompleted() { - complete(id); - } - - }; - ssub.set(sub); - - t.unsafeSubscribe(sub); - } - - @Override - public void onError(Throwable e) { - s.onError(e); - unsubscribe(); - } + public SwitchSubscriber(Subscriber child) { + s = new SerializedSubscriber(child); + ssub = new SerialSubscription(); + child.add(ssub); + child.setProducer(new Producer(){ - @Override - public void onCompleted() { - List localQueue; - synchronized (guard) { - mainDone = true; - if (active) { + @Override + public void request(long n) { + if (infinite) { return; } - if (emitting) { - if (queue == null) { - queue = new ArrayList(); + if(n == Long.MAX_VALUE) { + infinite = true; + } + InnerSubscriber localSubscriber; + synchronized (guard) { + localSubscriber = currentSubscriber; + if (currentSubscriber == null) { + initialRequested = n; + } else { + // If n == Long.MAX_VALUE, infinite will become true. Then currentSubscriber.requested won't be used. + // Therefore we don't need to worry about overflow. + currentSubscriber.requested += n; } - queue.add(nl.completed()); - return; } - localQueue = queue; - queue = null; - emitting = true; + if (localSubscriber != null) { + localSubscriber.requestMore(n); + } } - drain(localQueue); - s.onCompleted(); - unsubscribe(); + }); + } + + @Override + public void onNext(Observable t) { + final int id; + long remainingRequest; + synchronized (guard) { + id = ++index; + active = true; + if (infinite) { + remainingRequest = Long.MAX_VALUE; + } else { + remainingRequest = currentSubscriber == null ? initialRequested : currentSubscriber.requested; + } + currentSubscriber = new InnerSubscriber(id, remainingRequest); + currentSubscriber.requested = remainingRequest; } - void emit(T value, int id) { - List localQueue; - synchronized (guard) { - if (id != index) { - return; + ssub.set(currentSubscriber); + + t.unsafeSubscribe(currentSubscriber); + } + + @Override + public void onError(Throwable e) { + s.onError(e); + unsubscribe(); + } + + @Override + public void onCompleted() { + List localQueue; + synchronized (guard) { + mainDone = true; + if (active) { + return; + } + if (emitting) { + if (queue == null) { + queue = new ArrayList(); } - if (emitting) { - if (queue == null) { - queue = new ArrayList(); - } - queue.add(value); - return; + queue.add(nl.completed()); + return; + } + localQueue = queue; + queue = null; + emitting = true; + } + drain(localQueue); + s.onCompleted(); + unsubscribe(); + } + void emit(T value, int id, InnerSubscriber innerSubscriber) { + List localQueue; + synchronized (guard) { + if (id != index) { + return; + } + if (emitting) { + if (queue == null) { + queue = new ArrayList(); } - localQueue = queue; - queue = null; - emitting = true; + innerSubscriber.requested--; + queue.add(value); + return; } - boolean once = true; - boolean skipFinal = false; - try { - do { - drain(localQueue); - if (once) { - once = false; - s.onNext(value); - } + localQueue = queue; + queue = null; + emitting = true; + } + boolean once = true; + boolean skipFinal = false; + try { + do { + drain(localQueue); + if (once) { + once = false; synchronized (guard) { - localQueue = queue; - queue = null; - if (localQueue == null) { - emitting = false; - skipFinal = true; - break; - } + innerSubscriber.requested--; } - } while (!s.isUnsubscribed()); - } finally { - if (!skipFinal) { - synchronized (guard) { + s.onNext(value); + } + synchronized (guard) { + localQueue = queue; + queue = null; + if (localQueue == null) { emitting = false; + skipFinal = true; + break; } } + } while (!s.isUnsubscribed()); + } finally { + if (!skipFinal) { + synchronized (guard) { + emitting = false; + } } } - void drain(List localQueue) { - if (localQueue == null) { + } + void drain(List localQueue) { + if (localQueue == null) { + return; + } + for (Object o : localQueue) { + if (nl.isCompleted(o)) { + s.onCompleted(); + break; + } else + if (nl.isError(o)) { + s.onError(nl.getError(o)); + break; + } else { + @SuppressWarnings("unchecked") + T t = (T)o; + s.onNext(t); + } + } + } + + void error(Throwable e, int id) { + List localQueue; + synchronized (guard) { + if (id != index) { return; } - for (Object o : localQueue) { - if (nl.isCompleted(o)) { - s.onCompleted(); - break; - } else - if (nl.isError(o)) { - s.onError(nl.getError(o)); - break; - } else { - @SuppressWarnings("unchecked") - T t = (T)o; - s.onNext(t); + if (emitting) { + if (queue == null) { + queue = new ArrayList(); } + queue.add(nl.error(e)); + return; } + + localQueue = queue; + queue = null; + emitting = true; } - - void error(Throwable e, int id) { - List localQueue; - synchronized (guard) { - if (id != index) { - return; - } - if (emitting) { - if (queue == null) { - queue = new ArrayList(); - } - queue.add(nl.error(e)); - return; - } - - localQueue = queue; - queue = null; - emitting = true; + + drain(localQueue); + s.onError(e); + unsubscribe(); + } + void complete(int id) { + List localQueue; + synchronized (guard) { + if (id != index) { + return; } - - drain(localQueue); - s.onError(e); - unsubscribe(); - } - void complete(int id) { - List localQueue; - synchronized (guard) { - if (id != index) { - return; - } - active = false; - if (!mainDone) { - return; - } - if (emitting) { - if (queue == null) { - queue = new ArrayList(); - } - queue.add(nl.completed()); - return; + active = false; + if (!mainDone) { + return; + } + if (emitting) { + if (queue == null) { + queue = new ArrayList(); } - - localQueue = queue; - queue = null; - emitting = true; + queue.add(nl.completed()); + return; } - - drain(localQueue); - s.onCompleted(); - unsubscribe(); + + localQueue = queue; + queue = null; + emitting = true; + } + + drain(localQueue); + s.onCompleted(); + unsubscribe(); + } + + final class InnerSubscriber extends Subscriber { + + /** + * The number of request that is not acknowledged. + * + * Guarded by guard. + */ + private long requested = 0; + + private final int id; + + private final long initialRequested; + + public InnerSubscriber(int id, long initialRequested) { + this.id = id; + this.initialRequested = initialRequested; + } + + @Override + public void onStart() { + requestMore(initialRequested); + } + + public void requestMore(long n) { + request(n); + } + + @Override + public void onNext(T t) { + emit(t, id, this); + } + + @Override + public void onError(Throwable e) { + error(e, id); + } + + @Override + public void onCompleted() { + complete(id); } - }; + } } - } diff --git a/rxjava-core/src/test/java/rx/internal/operators/OperatorSwitchTest.java b/rxjava-core/src/test/java/rx/internal/operators/OperatorSwitchTest.java index ddb12c78f2..88c443aa04 100644 --- a/rxjava-core/src/test/java/rx/internal/operators/OperatorSwitchTest.java +++ b/rxjava-core/src/test/java/rx/internal/operators/OperatorSwitchTest.java @@ -23,18 +23,17 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import java.util.Arrays; import java.util.concurrent.TimeUnit; import org.junit.Before; import org.junit.Test; import org.mockito.InOrder; -import rx.Observable; -import rx.Observer; -import rx.Scheduler; -import rx.Subscriber; +import rx.*; import rx.exceptions.TestException; import rx.functions.Action0; +import rx.observers.TestSubscriber; import rx.schedulers.TestScheduler; public class OperatorSwitchTest { @@ -404,4 +403,114 @@ public void call(Subscriber observer) { inOrder.verify(observer, times(1)).onCompleted(); inOrder.verifyNoMoreInteractions(); } + + @Test + public void testBackpressure() { + final Observable o1 = Observable.create(new Observable.OnSubscribe() { + @Override + public void call(final Subscriber observer) { + observer.setProducer(new Producer() { + + private int emitted = 0; + + @Override + public void request(long n) { + for(int i = 0; i < n && emitted < 10 && !observer.isUnsubscribed(); i++) { + scheduler.advanceTimeBy(5, TimeUnit.MILLISECONDS); + emitted++; + observer.onNext("a" + emitted); + } + if(emitted == 10) { + observer.onCompleted(); + } + } + }); + } + }); + final Observable o2 = Observable.create(new Observable.OnSubscribe() { + @Override + public void call(final Subscriber observer) { + observer.setProducer(new Producer() { + + private int emitted = 0; + + @Override + public void request(long n) { + for(int i = 0; i < n && emitted < 10 && !observer.isUnsubscribed(); i++) { + scheduler.advanceTimeBy(5, TimeUnit.MILLISECONDS); + emitted++; + observer.onNext("b" + emitted); + } + if(emitted == 10) { + observer.onCompleted(); + } + } + }); + } + }); + final Observable o3 = Observable.create(new Observable.OnSubscribe() { + @Override + public void call(final Subscriber observer) { + observer.setProducer(new Producer() { + + private int emitted = 0; + + @Override + public void request(long n) { + for(int i = 0; i < n && emitted < 10 && !observer.isUnsubscribed(); i++) { + emitted++; + observer.onNext("c" + emitted); + } + if(emitted == 10) { + observer.onCompleted(); + } + } + }); + } + }); + Observable> o = Observable.create(new Observable.OnSubscribe>() { + @Override + public void call(Subscriber> observer) { + publishNext(observer, 10, o1); + publishNext(observer, 20, o2); + publishNext(observer, 30, o3); + publishCompleted(observer, 30); + } + }); + final TestSubscriber testSubscriber = new TestSubscriber(); + Observable.switchOnNext(o).subscribe(new Subscriber() { + + private int requested = 0; + + @Override + public void onStart() { + requested = 3; + request(3); + } + + @Override + public void onCompleted() { + testSubscriber.onCompleted(); + } + + @Override + public void onError(Throwable e) { + testSubscriber.onError(e); + } + + @Override + public void onNext(String s) { + testSubscriber.onNext(s); + requested--; + if(requested == 0) { + requested = 3; + request(3); + } + } + }); + scheduler.advanceTimeBy(10, TimeUnit.MILLISECONDS); + testSubscriber.assertReceivedOnNext(Arrays.asList("a1", "b1", "c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10")); + testSubscriber.assertNoErrors(); + testSubscriber.assertTerminalEvent(); + } }