diff --git a/src/main/java/rx/internal/operators/OperatorGroupBy.java b/src/main/java/rx/internal/operators/OperatorGroupBy.java index 39cfae4457..722a047884 100644 --- a/src/main/java/rx/internal/operators/OperatorGroupBy.java +++ b/src/main/java/rx/internal/operators/OperatorGroupBy.java @@ -15,10 +15,12 @@ */ package rx.internal.operators; +import java.util.Map; import java.util.Queue; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLongFieldUpdater; @@ -34,6 +36,7 @@ import rx.functions.Func1; import rx.observables.GroupedObservable; import rx.subjects.Subject; +import rx.subscriptions.Subscriptions; /** * Groups the items emitted by an Observable according to a specified criterion, and emits these @@ -76,6 +79,13 @@ static final class GroupBySubscriber extends Subscriber { final Func1 elementSelector; final Subscriber> child; + // We should not call `unsubscribe()` until `groups.isEmpty() && child.isUnsubscribed()` is true. + // Use `WIP_FOR_UNSUBSCRIBE_UPDATER` to monitor these statuses and call `unsubscribe()` properly. + // Should check both when `child.unsubscribe` is called and any group is removed. + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater WIP_FOR_UNSUBSCRIBE_UPDATER = AtomicIntegerFieldUpdater.newUpdater(GroupBySubscriber.class, "wipForUnsubscribe"); + volatile int wipForUnsubscribe = 1; + public GroupBySubscriber( Func1 keySelector, Func1 elementSelector, @@ -84,6 +94,16 @@ public GroupBySubscriber( this.keySelector = keySelector; this.elementSelector = elementSelector; this.child = child; + child.add(Subscriptions.create(new Action0() { + + @Override + public void call() { + if (WIP_FOR_UNSUBSCRIBE_UPDATER.decrementAndGet(self) == 0) { + self.unsubscribe(); + } + } + + })); } private static class GroupState { @@ -107,7 +127,13 @@ public Observer getObserver() { private static final NotificationLite nl = NotificationLite.instance(); volatile int completionEmitted; - volatile int terminated; + + private static final int UNTERMINATED = 0; + private static final int TERMINATED_WITH_COMPLETED = 1; + private static final int TERMINATED_WITH_ERROR = 2; + + // Must be one of `UNTERMINATED`, `TERMINATED_WITH_COMPLETED`, `TERMINATED_WITH_ERROR` + volatile int terminated = UNTERMINATED; @SuppressWarnings("rawtypes") static final AtomicIntegerFieldUpdater COMPLETION_EMITTED_UPDATER = AtomicIntegerFieldUpdater.newUpdater(GroupBySubscriber.class, "completionEmitted"); @@ -130,7 +156,7 @@ public void onStart() { @Override public void onCompleted() { - if (TERMINATED_UPDATER.compareAndSet(this, 0, 1)) { + if (TERMINATED_UPDATER.compareAndSet(this, UNTERMINATED, TERMINATED_WITH_COMPLETED)) { // if we receive onCompleted from our parent we onComplete children // for each group check if it is ready to accept more events if so pass the oncomplete through else buffer it. for (GroupState group : groups.values()) { @@ -138,7 +164,7 @@ public void onCompleted() { } // special case (no groups emitted ... or all unsubscribed) - if (groups.size() == 0) { + if (groups.isEmpty()) { // we must track 'completionEmitted' seperately from 'completed' since `completeInner` can result in childObserver.onCompleted() being emitted if (COMPLETION_EMITTED_UPDATER.compareAndSet(this, 0, 1)) { child.onCompleted(); @@ -149,9 +175,19 @@ public void onCompleted() { @Override public void onError(Throwable e) { - if (TERMINATED_UPDATER.compareAndSet(this, 0, 1)) { - // we immediately tear everything down if we receive an error - child.onError(e); + if (TERMINATED_UPDATER.compareAndSet(this, UNTERMINATED, TERMINATED_WITH_ERROR)) { + // It's safe to access all groups and emit the error. + // onNext and onError are in sequence so no group will be created in the loop. + for (GroupState group : groups.values()) { + emitItem(group, nl.error(e)); + } + try { + // we immediately tear everything down if we receive an error + child.onError(e); + } finally { + // We have not chained the subscribers, so need to call it explicitly. + unsubscribe(); + } } } @@ -187,7 +223,9 @@ public void onNext(T t) { } group = createNewGroup(key); } - emitItem(group, nl.next(t)); + if (group != null) { + emitItem(group, nl.next(t)); + } } catch (Throwable e) { onError(OnErrorThrowable.addValueAsLastCause(e, t)); } @@ -236,6 +274,11 @@ public void onCompleted() { @Override public void onError(Throwable e) { o.onError(e); + // eagerly cleanup instead of waiting for unsubscribe + if (once.compareAndSet(false, true)) { + // done once per instance, either onComplete or onUnSubscribe + cleanupGroup(key); + } } @Override @@ -250,7 +293,17 @@ public void onNext(T t) { } }); - GroupState putIfAbsent = groups.putIfAbsent(key, groupState); + GroupState putIfAbsent; + for (;;) { + int wip = wipForUnsubscribe; + if (wip <= 0) { + return null; + } + if (WIP_FOR_UNSUBSCRIBE_UPDATER.compareAndSet(this, wip, wip + 1)) { + putIfAbsent = groups.putIfAbsent(key, groupState); + break; + } + } if (putIfAbsent != null) { // this shouldn't happen (because we receive onNext sequentially) and would mean we have a bug throw new IllegalStateException("Group already existed while creating a new one"); @@ -264,7 +317,7 @@ private void cleanupGroup(Object key) { GroupState removed; removed = groups.remove(key); if (removed != null) { - if (removed.buffer.size() > 0) { + if (!removed.buffer.isEmpty()) { BUFFERED_COUNT.addAndGet(self, -removed.buffer.size()); } completeInner(); @@ -342,15 +395,14 @@ private void drainIfPossible(GroupState groupState) { } private void completeInner() { - // if we have no outstanding groups (all completed or unsubscribe) and terminated/unsubscribed on outer - if (groups.size() == 0 && (terminated == 1 || child.isUnsubscribed())) { + // A group is removed, so check if we need to call `unsubscribe` + if (WIP_FOR_UNSUBSCRIBE_UPDATER.decrementAndGet(this) == 0) { + // It means `groups.isEmpty() && child.isUnsubscribed()` is true + unsubscribe(); + } else if (groups.isEmpty() && terminated == TERMINATED_WITH_COMPLETED) { + // if we have no outstanding groups (all completed or unsubscribe) and terminated on outer // completionEmitted ensures we only emit onCompleted once if (COMPLETION_EMITTED_UPDATER.compareAndSet(this, 0, 1)) { - - if (child.isUnsubscribed()) { - // if the entire groupBy has been unsubscribed and children are completed we will propagate the unsubscribe up. - unsubscribe(); - } child.onCompleted(); } } diff --git a/src/test/java/rx/internal/operators/OperatorGroupByTest.java b/src/test/java/rx/internal/operators/OperatorGroupByTest.java index 9bbed5d04d..42023508d3 100644 --- a/src/test/java/rx/internal/operators/OperatorGroupByTest.java +++ b/src/test/java/rx/internal/operators/OperatorGroupByTest.java @@ -47,6 +47,7 @@ import rx.Observable.OnSubscribe; import rx.Observer; import rx.Subscriber; +import rx.Subscription; import rx.exceptions.TestException; import rx.functions.Action0; import rx.functions.Action1; @@ -1385,4 +1386,72 @@ public void call(String s) { assertEquals(null, key[0]); assertEquals(Arrays.asList("a", "b", "c"), values); } + + @Test + public void testGroupByUnsubscribe() { + final Subscription s = mock(Subscription.class); + Observable o = Observable.create( + new OnSubscribe() { + @Override + public void call(Subscriber subscriber) { + subscriber.add(s); + } + } + ); + o.groupBy(new Func1() { + + @Override + public Integer call(Integer integer) { + return null; + } + }).subscribe().unsubscribe(); + verify(s).unsubscribe(); + } + + @Test + public void testGroupByShouldPropagateError() { + final Throwable e = new RuntimeException("Oops"); + final TestSubscriber inner1 = new TestSubscriber(); + final TestSubscriber inner2 = new TestSubscriber(); + + final TestSubscriber> outer + = new TestSubscriber>(new Subscriber>() { + + @Override + public void onCompleted() { + } + + @Override + public void onError(Throwable e) { + } + + @Override + public void onNext(GroupedObservable o) { + if (o.getKey() == 0) { + o.subscribe(inner1); + } else { + o.subscribe(inner2); + } + } + }); + Observable.create( + new OnSubscribe() { + @Override + public void call(Subscriber subscriber) { + subscriber.onNext(0); + subscriber.onNext(1); + subscriber.onError(e); + } + } + ).groupBy(new Func1() { + + @Override + public Integer call(Integer i) { + return i % 2; + } + }).subscribe(outer); + assertEquals(Arrays.asList(e), outer.getOnErrorEvents()); + assertEquals(Arrays.asList(e), inner1.getOnErrorEvents()); + assertEquals(Arrays.asList(e), inner2.getOnErrorEvents()); + } }