Skip to content

Commit

Permalink
Publisher#flatMapConcatIterable may skip emitting items
Browse files Browse the repository at this point in the history
Motivation:
Publisher#flatMapConcatIterable may not emit some items due to race
conditions and visibility issues. The iterator state is written to
outside the scope of holding the lock. After a drain loop completes
we may request 1 more iterator. However it is possible the thread
emitting holds the lock while another thread invokes onNext(t).
The emitting thread may not see the iterator, and instead see
`EmptyIterator.instance()` and cause it to request 1 more item, but
then the not-visible iterator contents won't be emitted.

Modifications:
- Make FlatMapIterableSubscriber iterator state volatile and atomically
update it. There is only ever 1 valid iterator because only 1 outstanding
demand is issued only after the current iterator `!hasNext()`. The iterator
state is re-read on each drain loop, and the terminal condition must atomically
set to EmptyIterator.
  • Loading branch information
Scottmitch committed Nov 15, 2024
1 parent af01383 commit d1ddd48
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.NoSuchElementException;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import java.util.concurrent.atomic.AtomicLongFieldUpdater;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
import java.util.function.Function;
import javax.annotation.Nullable;

Expand Down Expand Up @@ -57,6 +58,9 @@ private static final class FlatMapIterableSubscriber<T, U> implements Subscriber
@SuppressWarnings("rawtypes")
private static final AtomicIntegerFieldUpdater<FlatMapIterableSubscriber> emittingUpdater =
AtomicIntegerFieldUpdater.newUpdater(FlatMapIterableSubscriber.class, "emitting");
@SuppressWarnings("rawtypes")
private static final AtomicReferenceFieldUpdater<FlatMapIterableSubscriber, Iterator> iterUpdater =
AtomicReferenceFieldUpdater.newUpdater(FlatMapIterableSubscriber.class, Iterator.class, "iterator");
private final Function<? super T, ? extends Iterable<? extends U>> mapper;
private final Subscriber<? super U> target;
@Nullable
Expand All @@ -74,7 +78,7 @@ private static final class FlatMapIterableSubscriber<T, U> implements Subscriber
* <p>
* Visibility and thread safety provided by {@link #emitting}.
*/
private Iterator<? extends U> currentIterator = emptyIterator();
private volatile Iterator<? extends U> iterator = emptyIterator();
@SuppressWarnings("unused")
private volatile long requestN;
@SuppressWarnings("unused")
Expand All @@ -98,8 +102,9 @@ public void onSubscribe(Subscription s) {
public void onNext(T u) {
// If Function.apply(...) throws we just propagate it to the caller which is responsible to terminate
// its subscriber and cancel the subscription.
currentIterator = requireNonNull(mapper.apply(u).iterator(),
() -> "Iterator from mapper " + mapper + " is null");
// Safe to assign because we only ever have demand outstanding of 1, so we never
// should concurrently access nextIterator or have multiple iterators being valid at any given time.
iterator = requireNonNull(mapper.apply(u).iterator(), () -> "Iterator from mapper " + mapper + " is null");
tryDrainIterator(ErrorHandlingStrategyInDrain.Throw);
}

Expand Down Expand Up @@ -151,8 +156,9 @@ public void cancel() {

private void doCancel() {
assert sourceSubscription != null;
final Iterator<? extends U> currentIterator = this.currentIterator;
this.currentIterator = EmptyIterator.instance();
@SuppressWarnings("unchecked")
final Iterator<? extends U> currentIterator =
(Iterator<? extends U>) iterUpdater.getAndSet(this, EmptyIterator.instance());
try {
tryClose(currentIterator);
} finally {
Expand Down Expand Up @@ -181,13 +187,14 @@ private void tryDrainIterator(ErrorHandlingStrategyInDrain errorHandlingStrategy
if (!tryAcquireLock(emittingUpdater, this)) {
break;
}
Iterator<? extends U> currIter = iterator;
long currRequestN = this.requestN;
final long initialRequestN = currRequestN;
try {
try {
while ((hasNext = currentIterator.hasNext()) && currRequestN > 0) {
while ((hasNext = currIter.hasNext()) && currRequestN > 0) {
--currRequestN;
target.onNext(currentIterator.next());
target.onNext(currIter.next());
}
} catch (Throwable cause) {
switch (errorHandlingStrategyInDrain) {
Expand All @@ -199,16 +206,16 @@ private void tryDrainIterator(ErrorHandlingStrategyInDrain errorHandlingStrategy
case Propagate:
terminated = true;
safeOnError(target, cause);
tryClose(currentIterator);
tryClose(currIter);
return; // hard return to avoid potential for duplicate terminal events
case Throw:
// since we only request 1 at a time we maybe holding requestN demand, in this case we
// discard the current iterator and request 1 more from upstream (if there is demand).
hasNext = false;
thrown = true;
final Iterator<? extends U> currentIterator = this.currentIterator;
this.currentIterator = EmptyIterator.instance();
tryClose(currentIterator);
currIter = EmptyIterator.instance();
iterator = currIter;
tryClose(currIter);
// let the exception propagate so the upstream source can do the cleanup.
throw cause;
default:
Expand All @@ -234,15 +241,26 @@ private void tryDrainIterator(ErrorHandlingStrategyInDrain errorHandlingStrategy
// We have been cancelled while we held the lock, do the cancel operation.
doCancel();
}
} else if (terminalNotification == null && !hasNext && currRequestN > 0 &&
(currentIterator != EmptyIterator.instance() || thrown)) {
// We only request 1 at a time, and therefore we don't have any outstanding demand, so
// we will not be getting an onNext call, so we write to the currentIterator variable
// here before we unlock emitting so visibility to other threads should be taken care of
// by the write to emitting below (and later read).
currentIterator = EmptyIterator.instance();
if (sourceSubscription != null) {
sourceSubscription.request(1);
} else if (terminalNotification == null && !hasNext) {
for (;;) {
final Iterator<? extends U> nextIter = iterator;
if (nextIter == currIter && currRequestN > 0 &&
(currIter != EmptyIterator.instance() || thrown)) {
// We only request 1 at a time, and therefore we don't have outstanding demand.
// We will not be getting an onNext call, so we write the currIter variable
// before we unlock emitting so visibility to other threads is taken care of
// by the write to emitting below (and later read).
if (iterUpdater.compareAndSet(this, currIter, EmptyIterator.instance())) {
if (sourceSubscription != null) {
sourceSubscription.request(1);
}
break;
}
} else {
// if nextIter != currIter -> outer loop will re-read "iterator" state and
// attempt to drain from it.
break;
}
}
}
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package io.servicetalk.concurrent.api;

import io.servicetalk.concurrent.BlockingIterable;
import io.servicetalk.concurrent.BlockingIterator;
import io.servicetalk.concurrent.PublisherSource.Processor;
import io.servicetalk.concurrent.PublisherSource.Subscriber;
import io.servicetalk.concurrent.PublisherSource.Subscription;
Expand All @@ -26,6 +27,7 @@
import org.junit.jupiter.api.extension.RegisterExtension;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
Expand All @@ -38,6 +40,7 @@
import static io.servicetalk.concurrent.api.Processors.newPublisherProcessor;
import static io.servicetalk.concurrent.api.Publisher.failed;
import static io.servicetalk.concurrent.api.Publisher.from;
import static io.servicetalk.concurrent.api.Publisher.fromIterable;
import static io.servicetalk.concurrent.api.SourceAdapters.fromSource;
import static io.servicetalk.concurrent.api.SourceAdapters.toSource;
import static io.servicetalk.concurrent.internal.DeliberateException.DELIBERATE_EXCEPTION;
Expand All @@ -46,8 +49,10 @@
import static java.util.Collections.singletonList;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.function.Function.identity;
import static java.util.stream.IntStream.range;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.sameInstance;
Expand Down Expand Up @@ -422,6 +427,20 @@ void exceptionFromSubscriptionRequestNIsPropagated() {
assertThat(subscriber.awaitOnError(), is(DELIBERATE_EXCEPTION));
}

@Test
void testFlatMapConcatIterable() throws Exception {
try (BlockingIterator<Integer> iterable = fromIterable(() -> range(0, 10000).iterator())
.publishOn(Executors.global())
.flatMapConcatIterable(Collections::singletonList)
.toIterable()
.iterator()) {
int expected = 0;
while (iterable.hasNext()) {
assertThat(iterable.next(), equalTo(expected++));
}
}
}

private void verifyTermination(boolean success) {
if (success) {
publisher.onComplete();
Expand Down

0 comments on commit d1ddd48

Please sign in to comment.