diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherConcatMapIterable.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherConcatMapIterable.java index 61a7a79a63..3ab6612f7e 100644 --- a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherConcatMapIterable.java +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherConcatMapIterable.java @@ -23,6 +23,7 @@ import java.util.NoSuchElementException; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import java.util.function.Function; import javax.annotation.Nullable; @@ -57,6 +58,9 @@ private static final class FlatMapIterableSubscriber implements Subscriber @SuppressWarnings("rawtypes") private static final AtomicIntegerFieldUpdater emittingUpdater = AtomicIntegerFieldUpdater.newUpdater(FlatMapIterableSubscriber.class, "emitting"); + @SuppressWarnings("rawtypes") + private static final AtomicReferenceFieldUpdater iterUpdater = + AtomicReferenceFieldUpdater.newUpdater(FlatMapIterableSubscriber.class, Iterator.class, "iterator"); private final Function> mapper; private final Subscriber target; @Nullable @@ -74,7 +78,7 @@ private static final class FlatMapIterableSubscriber implements Subscriber *

* Visibility and thread safety provided by {@link #emitting}. */ - private Iterator currentIterator = emptyIterator(); + private volatile Iterator iterator = emptyIterator(); @SuppressWarnings("unused") private volatile long requestN; @SuppressWarnings("unused") @@ -98,8 +102,9 @@ public void onSubscribe(Subscription s) { public void onNext(T u) { // If Function.apply(...) throws we just propagate it to the caller which is responsible to terminate // its subscriber and cancel the subscription. - currentIterator = requireNonNull(mapper.apply(u).iterator(), - () -> "Iterator from mapper " + mapper + " is null"); + // Safe to assign because we only ever have demand outstanding of 1, so we never + // should concurrently access nextIterator or have multiple iterators being valid at any given time. + iterator = requireNonNull(mapper.apply(u).iterator(), () -> "Iterator from mapper " + mapper + " is null"); tryDrainIterator(ErrorHandlingStrategyInDrain.Throw); } @@ -151,8 +156,9 @@ public void cancel() { private void doCancel() { assert sourceSubscription != null; - final Iterator currentIterator = this.currentIterator; - this.currentIterator = EmptyIterator.instance(); + @SuppressWarnings("unchecked") + final Iterator currentIterator = + (Iterator) iterUpdater.getAndSet(this, EmptyIterator.instance()); try { tryClose(currentIterator); } finally { @@ -181,13 +187,14 @@ private void tryDrainIterator(ErrorHandlingStrategyInDrain errorHandlingStrategy if (!tryAcquireLock(emittingUpdater, this)) { break; } + Iterator currIter = iterator; long currRequestN = this.requestN; final long initialRequestN = currRequestN; try { try { - while ((hasNext = currentIterator.hasNext()) && currRequestN > 0) { + while ((hasNext = currIter.hasNext()) && currRequestN > 0) { --currRequestN; - target.onNext(currentIterator.next()); + target.onNext(currIter.next()); } } catch (Throwable cause) { switch (errorHandlingStrategyInDrain) { @@ -199,16 +206,15 @@ private void tryDrainIterator(ErrorHandlingStrategyInDrain errorHandlingStrategy case Propagate: terminated = true; safeOnError(target, cause); - tryClose(currentIterator); + tryClose(currIter); return; // hard return to avoid potential for duplicate terminal events case Throw: // since we only request 1 at a time we maybe holding requestN demand, in this case we // discard the current iterator and request 1 more from upstream (if there is demand). hasNext = false; thrown = true; - final Iterator currentIterator = this.currentIterator; - this.currentIterator = EmptyIterator.instance(); - tryClose(currentIterator); + tryClose(currIter); + iterator = currIter = EmptyIterator.instance(); // let the exception propagate so the upstream source can do the cleanup. throw cause; default: @@ -235,15 +241,16 @@ private void tryDrainIterator(ErrorHandlingStrategyInDrain errorHandlingStrategy doCancel(); } } else if (terminalNotification == null && !hasNext && currRequestN > 0 && - (currentIterator != EmptyIterator.instance() || thrown)) { - // We only request 1 at a time, and therefore we don't have any outstanding demand, so - // we will not be getting an onNext call, so we write to the currentIterator variable - // here before we unlock emitting so visibility to other threads should be taken care of - // by the write to emitting below (and later read). - currentIterator = EmptyIterator.instance(); - if (sourceSubscription != null) { - sourceSubscription.request(1); - } + (currIter != EmptyIterator.instance() || thrown) && + // We only request 1 at a time, and therefore we don't have outstanding demand. + // We will not be getting an onNext call concurrently, but the onNext(..) call may + // be on a different thread outside the emitting lock. For this reason we do a CAS + // to ensure the currIter read at the beginning of the outer loop is still the + // current iterator. If the CAS fails the outer loop will re-read iterator and try + // to emit if items are present and demand allows it. + iterUpdater.compareAndSet(this, currIter, EmptyIterator.instance()) && + sourceSubscription != null) { + sourceSubscription.request(1); } } finally { // The lock must be released after we interact with the subscription for thread safety diff --git a/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/PublisherConcatMapIterableTest.java b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/PublisherConcatMapIterableTest.java index f5f5386a39..2d3186e149 100644 --- a/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/PublisherConcatMapIterableTest.java +++ b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/PublisherConcatMapIterableTest.java @@ -16,6 +16,7 @@ package io.servicetalk.concurrent.api; import io.servicetalk.concurrent.BlockingIterable; +import io.servicetalk.concurrent.BlockingIterator; import io.servicetalk.concurrent.PublisherSource.Processor; import io.servicetalk.concurrent.PublisherSource.Subscriber; import io.servicetalk.concurrent.PublisherSource.Subscription; @@ -26,6 +27,7 @@ import org.junit.jupiter.api.extension.RegisterExtension; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.CyclicBarrier; @@ -38,6 +40,7 @@ import static io.servicetalk.concurrent.api.Processors.newPublisherProcessor; import static io.servicetalk.concurrent.api.Publisher.failed; import static io.servicetalk.concurrent.api.Publisher.from; +import static io.servicetalk.concurrent.api.Publisher.fromIterable; import static io.servicetalk.concurrent.api.SourceAdapters.fromSource; import static io.servicetalk.concurrent.api.SourceAdapters.toSource; import static io.servicetalk.concurrent.internal.DeliberateException.DELIBERATE_EXCEPTION; @@ -46,8 +49,10 @@ import static java.util.Collections.singletonList; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.function.Function.identity; +import static java.util.stream.IntStream.range; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.sameInstance; @@ -422,6 +427,20 @@ void exceptionFromSubscriptionRequestNIsPropagated() { assertThat(subscriber.awaitOnError(), is(DELIBERATE_EXCEPTION)); } + @Test + void concurrencyEmitsInOrder() throws Exception { + try (BlockingIterator iterable = fromIterable(() -> range(0, 10_000).iterator()) + .publishOn(Executors.global()) + .flatMapConcatIterable(Collections::singletonList) + .toIterable() + .iterator()) { + int expected = 0; + while (iterable.hasNext()) { + assertThat(iterable.next(), equalTo(expected++)); + } + } + } + private void verifyTermination(boolean success) { if (success) { publisher.onComplete();