From 8d3d6028a84cffe8662c65627ad2c1f2e8f4fd77 Mon Sep 17 00:00:00 2001 From: Scott Mitchell Date: Fri, 25 Aug 2023 15:56:29 -0700 Subject: [PATCH] Add Publisher.replay Motivation: Publisher.replay provides the ability to keep state that is preserved for multiple subscribers and across resubscribes. --- .../concurrent/api/MulticastPublisher.java | 253 ++++++++++---- .../servicetalk/concurrent/api/Publisher.java | 80 ++++- .../concurrent/api/ReplayAccumulator.java | 48 +++ .../concurrent/api/ReplayPublisher.java | 210 ++++++++++++ .../concurrent/api/ReplayStrategies.java | 244 ++++++++++++++ .../concurrent/api/ReplayStrategy.java | 67 ++++ .../concurrent/api/ReplayStrategyBuilder.java | 155 +++++++++ .../api/MulticastPublisherTest.java | 98 +++--- .../concurrent/api/ReplayPublisherTest.java | 309 ++++++++++++++++++ .../tck/PublisherReplayTckTest.java | 28 ++ 10 files changed, 1380 insertions(+), 112 deletions(-) create mode 100644 servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ReplayAccumulator.java create mode 100644 servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ReplayPublisher.java create mode 100644 servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ReplayStrategies.java create mode 100644 servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ReplayStrategy.java create mode 100644 servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ReplayStrategyBuilder.java create mode 100644 servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/ReplayPublisherTest.java create mode 100644 servicetalk-concurrent-reactivestreams/src/test/java/io/servicetalk/concurrent/reactivestreams/tck/PublisherReplayTckTest.java diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/MulticastPublisher.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/MulticastPublisher.java index bde8cbd065..65476d16f5 100644 --- a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/MulticastPublisher.java +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/MulticastPublisher.java @@ -19,6 +19,7 @@ import io.servicetalk.concurrent.internal.ArrayUtils; import io.servicetalk.concurrent.internal.DelayedSubscription; import io.servicetalk.concurrent.internal.RejectedSubscribeException; +import io.servicetalk.concurrent.internal.TerminalNotification; import io.servicetalk.context.api.ContextMap; import org.slf4j.Logger; @@ -32,6 +33,7 @@ import java.util.function.Function; import javax.annotation.Nullable; +import static io.servicetalk.concurrent.api.Completable.completed; import static io.servicetalk.concurrent.api.PublishAndSubscribeOnPublishers.deliverOnSubscribeAndOnError; import static io.servicetalk.concurrent.internal.ConcurrentUtils.releaseLock; import static io.servicetalk.concurrent.internal.ConcurrentUtils.tryAcquireLock; @@ -47,8 +49,10 @@ import static java.util.Objects.requireNonNull; import static java.util.concurrent.atomic.AtomicReferenceFieldUpdater.newUpdater; -final class MulticastPublisher extends AbstractNoHandleSubscribePublisher { +class MulticastPublisher extends AbstractNoHandleSubscribePublisher { private static final Logger LOGGER = LoggerFactory.getLogger(MulticastPublisher.class); + static final int DEFAULT_MULTICAST_QUEUE_LIMIT = 64; + static final Function DEFAULT_MULTICAST_TERM_RESUB = t -> completed(); private static final Subscriber[] EMPTY_SUBSCRIBERS = new Subscriber[0]; @SuppressWarnings("rawtypes") private static final AtomicReferenceFieldUpdater @@ -61,7 +65,8 @@ final class MulticastPublisher extends AbstractNoHandleSubscribePublisher private final int minSubscribers; private final boolean exactlyMinSubscribers; private final boolean cancelUpstream; - private volatile State state; + @Nullable + volatile State state; MulticastPublisher(Publisher original, int minSubscribers, boolean exactlyMinSubscribers, boolean cancelUpstream, int maxQueueSize, Function terminalResubscribe) { @@ -76,16 +81,30 @@ final class MulticastPublisher extends AbstractNoHandleSubscribePublisher this.exactlyMinSubscribers = exactlyMinSubscribers; this.cancelUpstream = cancelUpstream; this.terminalResubscribe = requireNonNull(terminalResubscribe); - state = new State(maxQueueSize, minSubscribers); + } + + static MulticastPublisher newMulticastPublisher( + Publisher original, int minSubscribers, boolean exactlyMinSubscribers, boolean cancelUpstream, + int maxQueueSize, Function terminalResubscribe) { + MulticastPublisher publisher = new MulticastPublisher<>(original, minSubscribers, exactlyMinSubscribers, + cancelUpstream, minSubscribers, terminalResubscribe); + publisher.resetState(maxQueueSize, minSubscribers); + return publisher; } @Override - void handleSubscribe(Subscriber subscriber, ContextMap contextMap, + final void handleSubscribe(Subscriber subscriber, ContextMap contextMap, AsyncContextProvider contextProvider) { - state.addSubscriber(subscriber, contextMap, contextProvider); + final State cState = state; + assert cState != null; + cState.addNewSubscriber(subscriber, contextMap, contextProvider); } - private final class State extends MulticastRootSubscriber> implements Subscriber { + void resetState(int maxQueueSize, int minSubscribers) { + state = new State(maxQueueSize, minSubscribers); + } + + class State extends MulticastRootSubscriber> implements Subscriber { private final DefaultPriorityQueue> demandQueue; volatile int subscribeCount; @SuppressWarnings("unchecked") @@ -96,8 +115,35 @@ private final class State extends MulticastRootSubscriber(comparingLong(sub -> sub.priorityQueueValue), minSubscribers); } + @Nullable + @Override + final TerminalSubscriber addSubscriber(final MulticastFixedSubscriber subscriber, + @Nullable ContextMap contextMap, + AsyncContextProvider contextProvider) { + for (;;) { + final Subscriber[] currSubs = subscribers; + if (currSubs.length == 1 && currSubs[0] instanceof TerminalSubscriber) { + return (TerminalSubscriber) currSubs[0]; + } else { + @SuppressWarnings("unchecked") + Subscriber[] newSubs = (Subscriber[]) + Array.newInstance(Subscriber.class, currSubs.length + 1); + System.arraycopy(currSubs, 0, newSubs, 0, currSubs.length); + newSubs[currSubs.length] = subscriber; + if (newSubscribersUpdater.compareAndSet(this, currSubs, newSubs)) { + if (contextMap != null) { + // This operator has special behavior where it chooses to use the AsyncContext and signal + // offloader from the last subscribe operation. + original.delegateSubscribe(this, contextMap, contextProvider); + } + return null; + } + } + } + } + @Override - boolean removeSubscriber(final MulticastFixedSubscriber subscriber) { + final boolean removeSubscriber(final MulticastFixedSubscriber subscriber) { for (;;) { final Subscriber[] currSubs = subscribers; @SuppressWarnings("deprecation") @@ -118,8 +164,13 @@ boolean removeSubscriber(final MulticastFixedSubscriber subscriber) { if (newSubscribersUpdater.compareAndSet(this, currSubs, newSubs)) { if (cancelUpstream && newSubs.length == 0) { // Reset the state when all subscribers have cancelled to allow for re-subscribe. - state = new State(maxQueueSize, minSubscribers); - return true; + try { + resetState(maxQueueSize, minSubscribers); + return true; + } catch (Throwable cause) { + LOGGER.warn("unexpected exception creating new state {}", MulticastPublisher.this, + cause); + } } return false; } @@ -127,7 +178,7 @@ boolean removeSubscriber(final MulticastFixedSubscriber subscriber) { } @Override - long processRequestEvent(MulticastFixedSubscriber subscriber, final long n) { + final long processRequestEvent(MulticastFixedSubscriber subscriber, final long n) { assert n > 0; final MulticastFixedSubscriber oldMin = demandQueue.peek(); final long oldValue = subscriber.priorityQueueValue; @@ -141,7 +192,7 @@ long processRequestEvent(MulticastFixedSubscriber subscriber, final long n) { } @Override - long processCancelEvent(MulticastFixedSubscriber subscriber) { + final long processCancelEvent(MulticastFixedSubscriber subscriber) { MulticastFixedSubscriber min = demandQueue.peek(); if (!demandQueue.removeTyped(subscriber)) { return -1; @@ -159,7 +210,13 @@ long processCancelEvent(MulticastFixedSubscriber subscriber) { } @Override - void processSubscribeEvent(MulticastFixedSubscriber subscriber) { + boolean processSubscribeEvent(MulticastFixedSubscriber subscriber, + @Nullable TerminalSubscriber terminalSubscriber) { + if (terminalSubscriber != null) { + // Directly terminate the underlying subscriber to avoid queuing and MulticastFixedSubscriber rules. + terminalSubscriber.terminate(subscriber.subscriber); + return false; + } // Initialize the new subscriber's priorityQueueValue to the current minimum demand value to keep // outstanding demand bounded to maxQueueSize. final MulticastFixedSubscriber currMin = demandQueue.peek(); @@ -167,10 +224,21 @@ void processSubscribeEvent(MulticastFixedSubscriber subscriber) { subscriber.priorityQueueValue = subscriber.initPriorityQueueValue = currMin.priorityQueueValue; } demandQueue.add(subscriber); + return true; + } + + @Override + void processTerminal(final TerminalNotification terminalNotification) { + throw new UnsupportedOperationException("terminal queuing not supported. terminal=" + terminalNotification); } - void addSubscriber(Subscriber subscriber, ContextMap contextMap, - AsyncContextProvider contextProvider) { + @Override + void processOnNextEvent(final Object wrapped) { + throw new UnsupportedOperationException("onNext queuing not supported. wrapped=" + wrapped); + } + + final void addNewSubscriber(Subscriber subscriber, ContextMap contextMap, + AsyncContextProvider contextProvider) { final int sCount = subscribeCountUpdater.incrementAndGet(this); if (exactlyMinSubscribers && sCount > minSubscribers) { deliverOnSubscribeAndOnError(subscriber, contextMap, contextProvider, @@ -179,32 +247,26 @@ void addSubscriber(Subscriber subscriber, ContextMap contextMap, } MulticastFixedSubscriber multiSubscriber = new MulticastFixedSubscriber<>(this, subscriber, contextMap, contextProvider, sCount); - for (;;) { - final Subscriber[] currSubs = subscribers; - if (currSubs.length == 1 && currSubs[0] instanceof TerminalSubscriber) { - ((TerminalSubscriber) currSubs[0]).terminate(subscriber); - break; - } else { - @SuppressWarnings("unchecked") - Subscriber[] newSubs = (Subscriber[]) - Array.newInstance(Subscriber.class, currSubs.length + 1); - System.arraycopy(currSubs, 0, newSubs, 0, currSubs.length); - newSubs[currSubs.length] = multiSubscriber; - if (newSubscribersUpdater.compareAndSet(this, currSubs, newSubs)) { - addSubscriber(multiSubscriber); - if (sCount == minSubscribers) { - // This operator has special behavior where it chooses to use the AsyncContext and signal - // offloader from the last subscribe operation. - original.delegateSubscribe(this, contextMap, contextProvider); - } - break; + if (tryAcquireLock(subscriptionLockUpdater, this)) { + try { + // This operator has special behavior where it chooses to use the AsyncContext and signal + // offloader from the last subscribe operation. + processSubscribeEventInternal(multiSubscriber, sCount == minSubscribers ? contextMap : null, + contextProvider); + } finally { + if (!releaseLock(subscriptionLockUpdater, this)) { + processSubscriptionEvents(); } } + } else { + subscriptionEvents.add(new SubscribeEvent<>(multiSubscriber, + sCount == minSubscribers ? contextMap : null, contextProvider)); + processSubscriptionEvents(); } } @Override - public void onSubscribe(final Subscription subscription) { + public final void onSubscribe(final Subscription subscription) { onSubscribe0(subscription); } @@ -232,8 +294,14 @@ public void onComplete() { } private void onTerminal(@Nullable Throwable t, BiConsumer, Throwable> terminator) { - safeTerminalStateReset(t).whenFinally(() -> - state = new State(maxQueueSize, minSubscribers)).subscribe(); + safeTerminalStateReset(t).whenFinally(() -> { + try { + resetState(maxQueueSize, minSubscribers); + } catch (Throwable cause) { + LOGGER.warn("unexpected exception from terminal resubscribe Completable {}", + MulticastPublisher.this, cause); + } + }).subscribe(); @SuppressWarnings("unchecked") final Subscriber[] newSubs = (Subscriber[]) Array.newInstance(Subscriber.class, 1); @@ -269,10 +337,10 @@ private Completable safeTerminalStateReset(@Nullable Throwable t) { private abstract static class MulticastRootSubscriber> { @SuppressWarnings("rawtypes") - private static final AtomicIntegerFieldUpdater subscriptionLockUpdater = + static final AtomicIntegerFieldUpdater subscriptionLockUpdater = AtomicIntegerFieldUpdater.newUpdater(MulticastRootSubscriber.class, "subscriptionLock"); private final DelayedSubscription delayedSubscription = new DelayedSubscription(); - private final Queue subscriptionEvents = newUnboundedMpscQueue(8); + final Queue subscriptionEvents = newUnboundedMpscQueue(8); final int maxQueueSize; @SuppressWarnings("unused") @@ -282,6 +350,20 @@ private abstract static class MulticastRootSubscriber + * Invocation while {@link #subscriptionLock} is held. + * @param subscriber The {@link Subscriber} to remove. + * @param contextMap The context map to used when subscribing upstream, or {@code null} if should not subscribe. + * @param contextProvider The context provider to used when subscribing upstream. + * @return {@code null} if {@code subscriber} was added to the list, or non-{@code null} if not added to the + * because there was previously a terminal event. + */ + @Nullable + abstract TerminalSubscriber addSubscriber(T subscriber, @Nullable ContextMap contextMap, + AsyncContextProvider contextProvider); + /** * Remove a {@link Subscriber} from the underlying collection, and stop delivering signals to it. *

@@ -318,30 +400,46 @@ private abstract static class MulticastRootSubscriber * Invocation while {@link #subscriptionLock} is held. - * @param subscriber The subscriber which was passed to {@link #addSubscriber(MulticastLeafSubscriber)}. + * @param subscriber The subscriber which was passed to + * @param terminalSubscriber {@code null} if the {@code subscriber} was added to the list or non-{@code null} + * if a terminal event has occurred, and this method MUST eventually deliver the terminal signal to + * {@code subscriber}. + * {@link #addSubscriber(MulticastLeafSubscriber, ContextMap, AsyncContextProvider)}. + * @return {@code false} to stop handling this processor and break out early (e.g. can happen if + * {@code terminalSubscriber} is non-{@code null} no signals are queued and the terminal is delivered). + * {@code true} to unblock the {@code subscriber}'s signal queue and keep processing events. */ - abstract void processSubscribeEvent(T subscriber); + abstract boolean processSubscribeEvent(T subscriber, @Nullable TerminalSubscriber terminalSubscriber); - final void onSubscribe0(final Subscription subscription) { - delayedSubscription.delayedSubscription(subscription); + /** + * Invoked if a terminal signal for {@link State} couldn't be delivered inline. + *

+ * Invocation while {@link #subscriptionLock} is held. + * @param terminalNotification The terminal signal that is queued. + */ + abstract void processTerminal(TerminalNotification terminalNotification); + + /** + * Invoked if an {@link State#onNext(Object)} couldn't be delivered inline. + *

+ * Invocation while {@link #subscriptionLock} is held. + * @param wrapped The signal that is queued. + */ + abstract void processOnNextEvent(Object wrapped); + + /** + * Callback indicating upstream {@link Subscription} has been cancelled. + *

+ * Invocation while {@link #subscriptionLock} is held. + */ + void upstreamCancelled() { } - final void addSubscriber(T subscriber) { - if (tryAcquireLock(subscriptionLockUpdater, this)) { - try { - processSubscribeEventInternal(subscriber); - } finally { - if (!releaseLock(subscriptionLockUpdater, this)) { - processSubscriptionEvents(); - } - } - } else { - subscriptionEvents.add(new SubscribeEvent<>(subscriber)); - processSubscriptionEvents(); - } + final void onSubscribe0(final Subscription subscription) { + delayedSubscription.delayedSubscription(subscription); } final void request(T subscriber, long n) { @@ -387,7 +485,7 @@ private void requestUpstream(long n) { delayedSubscription.request(n); } - private void processSubscriptionEvents() { + final void processSubscriptionEvents() { boolean tryAcquire = true; Throwable delayedCause = null; while (tryAcquire && tryAcquireLock(subscriptionLockUpdater, this)) { @@ -406,11 +504,15 @@ private void processSubscriptionEvents() { } else if (event instanceof SubscribeEvent) { @SuppressWarnings("unchecked") final SubscribeEvent sEvent = (SubscribeEvent) event; - processSubscribeEventInternal(sEvent.subscriber); - } else { + processSubscribeEventInternal(sEvent.subscriber, sEvent.contextMap, sEvent.contextProvider); + } else if (event instanceof CancelEvent) { @SuppressWarnings("unchecked") final CancelEvent cEvent = (CancelEvent) event; processCancelEventInternal(cEvent.subscriber, cEvent.cancelUpstream); + } else if (event instanceof TerminalNotification) { + processTerminal((TerminalNotification) event); + } else { + processOnNextEvent(event); } } if (toRequest != 0) { @@ -431,16 +533,23 @@ private void processCancelEventInternal(T subscriber, boolean cancelUpstream) { final long result = processCancelEvent(subscriber); if (result >= 0) { if (cancelUpstream) { - delayedSubscription.cancel(); + try { + delayedSubscription.cancel(); + } finally { + upstreamCancelled(); + } } else if (result > 0) { requestUpstream(result); } } } - private void processSubscribeEventInternal(T subscriber) { - processSubscribeEvent(subscriber); - + void processSubscribeEventInternal(T subscriber, @Nullable ContextMap contextMap, + AsyncContextProvider contextProvider) { + TerminalSubscriber terminalSubscriber = addSubscriber(subscriber, contextMap, contextProvider); + if (!processSubscribeEvent(subscriber, terminalSubscriber)) { + return; + } try { // Note we invoke onSubscribe AFTER the subscribers array and demandQueue state is set // because the subscription methods depend upon this state. This may result in onNext(), @@ -456,11 +565,17 @@ private void processSubscribeEventInternal(T subscriber) { } } - private static final class SubscribeEvent> { + static final class SubscribeEvent> { private final T subscriber; + @Nullable + private final ContextMap contextMap; + private final AsyncContextProvider contextProvider; - private SubscribeEvent(final T subscriber) { + SubscribeEvent(final T subscriber, @Nullable final ContextMap contextMap, + final AsyncContextProvider contextProvider) { this.subscriber = subscriber; + this.contextMap = contextMap; + this.contextProvider = contextProvider; } } @@ -486,7 +601,7 @@ private CancelEvent(final T subscriber, final boolean cancelUpstream) { } } - private static final class MulticastFixedSubscriber extends MulticastLeafSubscriber implements Node { + static final class MulticastFixedSubscriber extends MulticastLeafSubscriber implements Node { private final int index; private final MulticastPublisher.State root; private final Subscriber subscriber; @@ -554,15 +669,15 @@ public String toString() { } } - private static final class TerminalSubscriber implements Subscriber { + static final class TerminalSubscriber implements Subscriber { @Nullable - private final Throwable terminalError; + final Throwable terminalError; private TerminalSubscriber(@Nullable final Throwable terminalError) { this.terminalError = terminalError; } - private void terminate(Subscriber sub) { + void terminate(Subscriber sub) { if (terminalError == null) { deliverCompleteFromSource(sub); } else { diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/Publisher.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/Publisher.java index 129b7ac380..6c2b511720 100644 --- a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/Publisher.java +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/Publisher.java @@ -56,6 +56,9 @@ import static io.servicetalk.concurrent.api.EmptyPublisher.emptyPublisher; import static io.servicetalk.concurrent.api.Executors.global; import static io.servicetalk.concurrent.api.FilterPublisher.newDistinctSupplier; +import static io.servicetalk.concurrent.api.MulticastPublisher.DEFAULT_MULTICAST_QUEUE_LIMIT; +import static io.servicetalk.concurrent.api.MulticastPublisher.DEFAULT_MULTICAST_TERM_RESUB; +import static io.servicetalk.concurrent.api.MulticastPublisher.newMulticastPublisher; import static io.servicetalk.concurrent.api.NeverPublisher.neverPublisher; import static io.servicetalk.concurrent.api.PublisherDoOnUtils.doOnCancelSupplier; import static io.servicetalk.concurrent.api.PublisherDoOnUtils.doOnCompleteSupplier; @@ -63,6 +66,7 @@ import static io.servicetalk.concurrent.api.PublisherDoOnUtils.doOnNextSupplier; import static io.servicetalk.concurrent.api.PublisherDoOnUtils.doOnRequestSupplier; import static io.servicetalk.concurrent.api.PublisherDoOnUtils.doOnSubscribeSupplier; +import static io.servicetalk.concurrent.api.ReplayPublisher.newReplayPublisher; import static io.servicetalk.concurrent.internal.SubscriberUtils.deliverErrorFromSource; import static io.servicetalk.utils.internal.DurationUtils.toNanos; import static java.util.Objects.requireNonNull; @@ -2991,7 +2995,7 @@ public final Publisher> groupToMany( */ @Deprecated public final Publisher multicastToExactly(int expectedSubscribers) { - return multicastToExactly(expectedSubscribers, 64); + return multicastToExactly(expectedSubscribers, DEFAULT_MULTICAST_QUEUE_LIMIT); } /** @@ -3023,7 +3027,7 @@ public final Publisher multicastToExactly(int expectedSubscribers) { */ @Deprecated public final Publisher multicastToExactly(int expectedSubscribers, int queueLimit) { - return new MulticastPublisher<>(this, expectedSubscribers, true, true, queueLimit, t -> completed()); + return newMulticastPublisher(this, expectedSubscribers, true, true, queueLimit, t -> completed()); } /** @@ -3082,7 +3086,7 @@ public final Publisher multicast(int minSubscribers) { * @see ReactiveX multicast operator */ public final Publisher multicast(int minSubscribers, boolean cancelUpstream) { - return multicast(minSubscribers, 64, cancelUpstream); + return multicast(minSubscribers, DEFAULT_MULTICAST_QUEUE_LIMIT, cancelUpstream); } /** @@ -3145,7 +3149,7 @@ public final Publisher multicast(int minSubscribers, int queueLimit) { * @see ReactiveX multicast operator */ public final Publisher multicast(int minSubscribers, int queueLimit, boolean cancelUpstream) { - return multicast(minSubscribers, queueLimit, cancelUpstream, t -> completed()); + return multicast(minSubscribers, queueLimit, cancelUpstream, DEFAULT_MULTICAST_TERM_RESUB); } /** @@ -3224,7 +3228,73 @@ public final Publisher multicast(int minSubscribers, int queueLimit, */ public final Publisher multicast(int minSubscribers, int queueLimit, boolean cancelUpstream, Function terminalResubscribe) { - return new MulticastPublisher<>(this, minSubscribers, false, cancelUpstream, queueLimit, terminalResubscribe); + return newMulticastPublisher(this, minSubscribers, false, cancelUpstream, queueLimit, terminalResubscribe); + } + + /** + * Similar to {@link #multicast(int)} in that multiple downstream {@link Subscriber}s are enabled on the returned + * {@link Publisher} but also retains {@code history} of the most recently emitted signals from + * {@link Subscriber#onNext(Object)} which are emitted to new downstream {@link Subscriber}s before emitting new + * signals. + * @param history max number of items to retain which can be delivered to new subscribers. + * @return A {@link Publisher} that allows for multiple downstream subscribers and emits the previous + * {@code history} {@link Subscriber#onNext(Object)} signals to each new subscriber. + * @see ReactiveX replay operator + * @see ReplayStrategies#historyBuilder(int) + * @see #replay(ReplayStrategy) + */ + public final Publisher replay(int history) { + return replay(ReplayStrategies.historyBuilder(history).build()); + } + + /** + * Similar to {@link #multicast(int)} in that multiple downstream {@link Subscriber}s are enabled on the returned + * {@link Publisher} but also retains {@code history} of the most recently emitted signals + * from {@link Subscriber#onNext(Object)} which are emitted to new downstream {@link Subscriber}s before emitting + * new signals. Each item is only retained for {@code ttl} duration of time. + * @param history max number of items to retain which can be delivered to new subscribers. + * @param ttl duration each element will be retained before being removed. + * @param executor used to enforce the {@code ttl} argument. + * @return A {@link Publisher} that allows for multiple downstream subscribers and emits the previous + * {@code history} {@link Subscriber#onNext(Object)} signals to each new subscriber. + * @see ReactiveX replay operator + * @see ReplayStrategies#historyTtlBuilder(int, Duration, io.servicetalk.concurrent.Executor) + * @see #replay(ReplayStrategy) + */ + public final Publisher replay(int history, Duration ttl, io.servicetalk.concurrent.Executor executor) { + return replay(ReplayStrategies.historyTtlBuilder(history, ttl, executor).build()); + } + + /** + * Similar to {@link #multicast(int)} in that multiple downstream {@link Subscriber}s are enabled on the returned + * {@link Publisher} but will also retain some history of {@link Subscriber#onNext(Object)} signals + * according to the {@link ReplayAccumulator} {@code accumulatorSupplier}. + * @param accumulatorSupplier supplies a {@link ReplayAccumulator} on each subscribe to upstream that can retain + * history of {@link Subscriber#onNext(Object)} signals to deliver to new downstream subscribers. + * @return A {@link Publisher} that allows for multiple downstream subscribers that can retain + * history of {@link Subscriber#onNext(Object)} signals to deliver to new downstream subscribers. + * @see ReactiveX replay operator + * @see #replay(ReplayStrategy) + */ + public final Publisher replay(Supplier> accumulatorSupplier) { + return replay(new ReplayStrategyBuilder<>(accumulatorSupplier).build()); + } + + /** + * Similar to {@link #multicast(int)} in that multiple downstream {@link Subscriber}s are enabled on the returned + * {@link Publisher} but will also retain some history of {@link Subscriber#onNext(Object)} signals + * according to the {@link ReplayStrategy} {@code replayStrategy}. + * @param replayStrategy a {@link ReplayStrategy} that determines the replay behavior and history retention logic. + * @return A {@link Publisher} that allows for multiple downstream subscribers that can retain + * history of {@link Subscriber#onNext(Object)} signals to deliver to new downstream subscribers. + * @see ReactiveX replay operator + * @see ReplayStrategyBuilder + * @see ReplayStrategies + */ + public final Publisher replay(ReplayStrategy replayStrategy) { + return newReplayPublisher(this, replayStrategy.accumulatorSupplier(), replayStrategy.minSubscribers(), + replayStrategy.cancelUpstream(), replayStrategy.queueLimitHint(), + replayStrategy.terminalResubscribe()); } /** diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ReplayAccumulator.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ReplayAccumulator.java new file mode 100644 index 0000000000..0beb05a792 --- /dev/null +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ReplayAccumulator.java @@ -0,0 +1,48 @@ +/* + * Copyright © 2023 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.concurrent.api; + +import io.servicetalk.concurrent.PublisherSource.Subscriber; + +import java.util.function.Consumer; +import javax.annotation.Nullable; + +/** + * Accumulates signals for the {@link Publisher} replay operator. + * @param The type of data to accumulate. + */ +public interface ReplayAccumulator { + /** + * Called on each {@link Subscriber#onNext(Object)} and intended to accumulate the signal so that new + * {@link Subscriber}s will see this value via {@link #deliverAccumulation(Consumer)}. + *

+ * This method won't be called concurrently, but should return quickly to minimize performance impacts. + * @param t An {@link Subscriber#onNext(Object)} to accumulate. + */ + void accumulate(@Nullable T t); + + /** + * Called to deliver the signals from {@link #accumulate(Object)} to new {@code consumer}. + * @param consumer The consumer of the signals previously aggregated via {@link #accumulate(Object)}. + */ + void deliverAccumulation(Consumer consumer); + + /** + * Called if the accumulation can be cancelled and any asynchronous resources can be cleaned up (e.g. timers). + */ + default void cancelAccumulation() { + } +} diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ReplayPublisher.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ReplayPublisher.java new file mode 100644 index 0000000000..9d92930fd3 --- /dev/null +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ReplayPublisher.java @@ -0,0 +1,210 @@ +/* + * Copyright © 2023 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.concurrent.api; + +import io.servicetalk.concurrent.internal.TerminalNotification; + +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.function.Function; +import java.util.function.Supplier; +import javax.annotation.Nullable; + +import static io.servicetalk.concurrent.api.SubscriberApiUtils.unwrapNullUnchecked; +import static io.servicetalk.concurrent.api.SubscriberApiUtils.wrapNull; +import static io.servicetalk.concurrent.internal.ConcurrentUtils.releaseLock; +import static io.servicetalk.concurrent.internal.ConcurrentUtils.tryAcquireLock; +import static io.servicetalk.concurrent.internal.SubscriberUtils.safeOnComplete; +import static io.servicetalk.concurrent.internal.SubscriberUtils.safeOnError; +import static io.servicetalk.concurrent.internal.TerminalNotification.complete; +import static io.servicetalk.concurrent.internal.TerminalNotification.error; +import static io.servicetalk.utils.internal.ThrowableUtils.addSuppressed; +import static java.util.Objects.requireNonNull; + +final class ReplayPublisher extends MulticastPublisher { + @SuppressWarnings("rawtypes") + private static final AtomicLongFieldUpdater signalQueuedUpdater = + AtomicLongFieldUpdater.newUpdater(ReplayPublisher.ReplayState.class, "signalsQueued"); + private final Supplier> accumulatorSupplier; + + private ReplayPublisher( + Publisher original, Supplier> accumulatorSupplier, int minSubscribers, + boolean cancelUpstream, int maxQueueSize, Function terminalResubscribe) { + super(original, minSubscribers, false, cancelUpstream, maxQueueSize, terminalResubscribe); + this.accumulatorSupplier = requireNonNull(accumulatorSupplier); + } + + static MulticastPublisher newReplayPublisher( + Publisher original, Supplier> accumulatorSupplier, int minSubscribers, + boolean cancelUpstream, int maxQueueSize, Function terminalResubscribe) { + ReplayPublisher publisher = new ReplayPublisher<>(original, accumulatorSupplier, minSubscribers, + cancelUpstream, minSubscribers, terminalResubscribe); + publisher.resetState(maxQueueSize, minSubscribers); + return publisher; + } + + @Override + void resetState(int maxQueueSize, int minSubscribers) { + state = new ReplayState(maxQueueSize, minSubscribers, accumulatorSupplier.get()); + } + + private final class ReplayState extends MulticastPublisher.State { + private final ReplayAccumulator accumulator; + /** + * We could check {@link #subscriptionEvents} is empty, but there are events outside of {@link Subscriber} + * signals in this queue that we don't care about in terms of preserving order, so we keep this count instead + * to only queue when necessary. + */ + volatile long signalsQueued; + + ReplayState(final int maxQueueSize, final int minSubscribers, + ReplayAccumulator accumulator) { + super(maxQueueSize, minSubscribers); + this.accumulator = requireNonNull(accumulator); + } + + @Override + public void onNext(@Nullable final T t) { + // signalsQueued must be 0 or else items maybe delivered out of order. The value will only be increased + // on the Subscriber thread (no concurrency) and decreased on the draining thread. Optimistically check + // the value here and worst case if the queue has been drained of signals and this thread hasn't yet + // observed the value we will queue but still see correct ordering. + if (signalsQueued == 0 && tryAcquireLock(subscriptionLockUpdater, this)) { + try { + // All subscribers must either see this direct onNext signal, or see it through the accumulator. + // Therefore, we accumulate and deliver onNext while locked to avoid either delivering the signal + // twice (accumulator, addSubscriber, and onNext) or not at all (missed due to concurrency). + accumulator.accumulate(t); + super.onNext(t); + } finally { + if (!releaseLock(subscriptionLockUpdater, this)) { + processSubscriptionEvents(); + } + } + } else { + queueOnNext(t); + } + } + + @Override + public void onError(final Throwable t) { + if (signalsQueued == 0 && tryAcquireLock(subscriptionLockUpdater, this)) { + try { + super.onError(t); + } finally { + if (!releaseLock(subscriptionLockUpdater, this)) { + processSubscriptionEvents(); + } + } + } else { + queueTerminal(error(t)); + } + } + + @Override + public void onComplete() { + if (signalsQueued == 0 && tryAcquireLock(subscriptionLockUpdater, this)) { + try { + super.onComplete(); + } finally { + if (!releaseLock(subscriptionLockUpdater, this)) { + processSubscriptionEvents(); + } + } + } else { + queueTerminal(complete()); + } + } + + @Override + void processOnNextEvent(Object wrapped) { + // subscriptionLockUpdater is held + signalQueuedUpdater.decrementAndGet(this); + final T unwrapped = unwrapNullUnchecked(wrapped); + accumulator.accumulate(unwrapped); + super.onNext(unwrapped); + } + + @Override + void processTerminal(TerminalNotification terminalNotification) { + // subscriptionLockUpdater is held + signalQueuedUpdater.decrementAndGet(this); + if (terminalNotification.cause() != null) { + super.onError(terminalNotification.cause()); + } else { + super.onComplete(); + } + } + + @Override + boolean processSubscribeEvent(MulticastFixedSubscriber subscriber, + @Nullable TerminalSubscriber terminalSubscriber) { + // subscriptionLockUpdater is held + if (terminalSubscriber == null) { + // Only call the super class if no terminal event. We don't want the super class to terminate + // the subscriber because we need to deliver any accumulated signals, and we also don't want to + // track state in demandQueue because it isn't necessary to manage upstream demand, and we don't want + // to hold a reference to the subscriber unnecessarily. + super.processSubscribeEvent(subscriber, null); + } + Throwable caughtCause = null; + try { + // It's safe to call onNext before onSubscribe bcz the base class expects onSubscribe to be async and + // queues/reorders events to preserve ReactiveStreams semantics. + accumulator.deliverAccumulation(subscriber::onNext); + } catch (Throwable cause) { + caughtCause = cause; + } finally { + if (terminalSubscriber != null) { + if (caughtCause != null) { + if (terminalSubscriber.terminalError != null) { + // Use caughtCause as original otherwise we keep appending to the cached Throwable. + safeOnError(subscriber, addSuppressed(caughtCause, terminalSubscriber.terminalError)); + } else { + safeOnError(subscriber, caughtCause); + } + } else if (terminalSubscriber.terminalError != null) { + safeOnError(subscriber, terminalSubscriber.terminalError); + } else { + safeOnComplete(subscriber); + } + } else if (caughtCause != null) { + safeOnError(subscriber, caughtCause); + } + } + // Even if we terminated we always want to continue processing to trigger onSubscriber and allow queued + // signals from above to be processed when demand arrives. + return true; + } + + @Override + void upstreamCancelled() { + // subscriptionLockUpdater is held + accumulator.cancelAccumulation(); + } + + private void queueOnNext(@Nullable T t) { + signalQueuedUpdater.incrementAndGet(this); + subscriptionEvents.add(wrapNull(t)); + processSubscriptionEvents(); + } + + private void queueTerminal(TerminalNotification terminalNotification) { + signalQueuedUpdater.incrementAndGet(this); + subscriptionEvents.add(terminalNotification); + processSubscriptionEvents(); + } + } +} diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ReplayStrategies.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ReplayStrategies.java new file mode 100644 index 0000000000..e642750305 --- /dev/null +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ReplayStrategies.java @@ -0,0 +1,244 @@ +/* + * Copyright © 2023 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.concurrent.api; + +import io.servicetalk.concurrent.Cancellable; +import io.servicetalk.concurrent.Executor; + +import java.time.Duration; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Consumer; +import javax.annotation.Nullable; + +import static io.servicetalk.concurrent.api.SubscriberApiUtils.unwrapNullUnchecked; +import static io.servicetalk.concurrent.api.SubscriberApiUtils.wrapNull; +import static io.servicetalk.concurrent.internal.EmptySubscriptions.EMPTY_SUBSCRIPTION_NO_THROW; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static java.util.concurrent.atomic.AtomicReferenceFieldUpdater.newUpdater; + +/** + * Utilities to customize {@link ReplayStrategy}. + */ +public final class ReplayStrategies { + private ReplayStrategies() { + } + + /** + * Create a {@link ReplayStrategyBuilder} using the history strategy. + * @param history max number of items to retain which can be delivered to new subscribers. + * @param The type of {@link ReplayStrategyBuilder}. + * @return a {@link ReplayStrategyBuilder} using the history strategy. + */ + public static ReplayStrategyBuilder historyBuilder(int history) { + return new ReplayStrategyBuilder<>(() -> new MostRecentReplayAccumulator<>(history)); + } + + /** + * Create a {@link ReplayStrategyBuilder} using the history and TTL strategy. + * @param history max number of items to retain which can be delivered to new subscribers. + * @param ttl duration each element will be retained before being removed. + * @param executor used to enforce the {@code ttl} argument. + * @param The type of {@link ReplayStrategyBuilder}. + * @return a {@link ReplayStrategyBuilder} using the history and TTL strategy. + */ + public static ReplayStrategyBuilder historyTtlBuilder(int history, Duration ttl, Executor executor) { + return new ReplayStrategyBuilder<>(() -> new MostRecentTimeLimitedReplayAccumulator<>(history, ttl, executor)); + } + + private static final class MostRecentReplayAccumulator implements ReplayAccumulator { + private final int maxItems; + private final Deque list = new ArrayDeque<>(); + + MostRecentReplayAccumulator(final int maxItems) { + if (maxItems <= 0) { + throw new IllegalArgumentException("maxItems: " + maxItems + "(expected >0)"); + } + this.maxItems = maxItems; + } + + @Override + public void accumulate(@Nullable final T t) { + if (list.size() >= maxItems) { + list.pop(); + } + list.add(wrapNull(t)); + } + + @Override + public void deliverAccumulation(final Consumer consumer) { + for (Object item : list) { + consumer.accept(unwrapNullUnchecked(item)); + } + } + } + + private static final class MostRecentTimeLimitedReplayAccumulator implements ReplayAccumulator { + @SuppressWarnings("rawtypes") + private static final AtomicLongFieldUpdater stateSizeUpdater = + AtomicLongFieldUpdater.newUpdater(MostRecentTimeLimitedReplayAccumulator.class, "stateSize"); + @SuppressWarnings("rawtypes") + private static final AtomicReferenceFieldUpdater + timerCancellableUpdater = newUpdater(MostRecentTimeLimitedReplayAccumulator.class, Cancellable.class, + "timerCancellable"); + private final Executor executor; + private final Queue> items; + private final long ttlNanos; + private final int maxItems; + /** + * Provide atomic state for size of {@link #items} and also for visibility between the threads consuming and + * producing. The atomically incrementing "state" ensures that any modifications from the producer thread + * are visible from the consumer thread and we never "miss" a timer schedule event if the queue becomes empty. + */ + private volatile long stateSize; + @Nullable + private volatile Cancellable timerCancellable; + + MostRecentTimeLimitedReplayAccumulator(final int maxItems, final Duration ttl, final Executor executor) { + if (ttl.isNegative()) { + throw new IllegalArgumentException("ttl: " + ttl + "(expected non-negative)"); + } + if (maxItems <= 0) { + throw new IllegalArgumentException("maxItems: " + maxItems + "(expected >0)"); + } + this.executor = requireNonNull(executor); + this.ttlNanos = ttl.toNanos(); + this.maxItems = maxItems; + items = new ConcurrentLinkedQueue<>(); // SpMc + } + + @Override + public void accumulate(@Nullable final T t) { + // We may exceed max items in the queue but this method isn't invoked concurrently, so we only go over by + // at most 1 item. + items.add(new TimeStampSignal<>(executor.currentTime(NANOSECONDS), t)); + for (;;) { + final long currentStateSize = stateSize; + final int currentSize = getSize(currentStateSize); + final int nextState = getState(currentStateSize) + 1; + if (currentSize >= maxItems) { + if (stateSizeUpdater.compareAndSet(this, currentStateSize, + buildStateSize(nextState, currentSize))) { + items.poll(); + break; + } + } else if (stateSizeUpdater.compareAndSet(this, currentStateSize, + buildStateSize(nextState, currentSize + 1))) { + if (currentSize == 0) { + schedulerTimer(ttlNanos); + } + break; + } + } + } + + @Override + public void deliverAccumulation(final Consumer consumer) { + for (TimeStampSignal timeStampSignal : items) { + consumer.accept(timeStampSignal.signal); + } + } + + @Override + public void cancelAccumulation() { + final Cancellable cancellable = timerCancellableUpdater.getAndSet(this, EMPTY_SUBSCRIPTION_NO_THROW); + if (cancellable != null) { + cancellable.cancel(); + } + } + + private static int getSize(long stateSize) { + return (int) stateSize; + } + + private static int getState(long stateSize) { + return (int) (stateSize >>> 32); + } + + private static long buildStateSize(int state, int size) { + return (((long) state) << 32) | size; + } + + private void schedulerTimer(long nanos) { + for (;;) { + final Cancellable currentCancellable = timerCancellable; + if (currentCancellable == EMPTY_SUBSCRIPTION_NO_THROW) { + break; + } else { + final Cancellable nextCancellable = executor.schedule(this::expireSignals, nanos, NANOSECONDS); + if (timerCancellableUpdater.compareAndSet(this, currentCancellable, nextCancellable)) { + // Current logic only has 1 timer outstanding at any give time so cancellation of + // the current cancellable shouldn't be necessary but do it for completeness. + if (currentCancellable != null) { + currentCancellable.cancel(); + } + break; + } else { + nextCancellable.cancel(); + } + } + } + } + + private void expireSignals() { + final long nanoTime = executor.currentTime(NANOSECONDS); + TimeStampSignal item; + for (;;) { + // read stateSize before peek, so if we poll from the queue we are sure to see the correct + // state relative to items in the queue. + final long currentStateSize = stateSize; + item = items.peek(); + if (item == null) { + break; + } else if (nanoTime - item.timeStamp >= ttlNanos) { + final int currentSize = getSize(currentStateSize); + if (stateSizeUpdater.compareAndSet(this, currentStateSize, + buildStateSize(getState(currentStateSize) + 1, currentSize - 1))) { + // When we add: we add to the queue we add first, then CAS sizeState. + // When we remove: we CAS the atomic state first, then poll. + // This avoids removing a non-expired item because if the "add" thread is running faster and + // already polled "item" the CAS will fail, and we will try again on the next loop iteration. + items.poll(); + if (currentSize == 1) { + // a new timer task will be scheduled after addition if this is the case. break to avoid + // multiple timer tasks running concurrently. + break; + } + } + } else { + schedulerTimer(ttlNanos - (nanoTime - item.timeStamp)); + break; // elements sorted in increasing time, break when first non-expired entry found. + } + } + } + } + + private static final class TimeStampSignal { + final long timeStamp; + @Nullable + final T signal; + + private TimeStampSignal(final long timeStamp, @Nullable final T signal) { + this.timeStamp = timeStamp; + this.signal = signal; + } + } +} diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ReplayStrategy.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ReplayStrategy.java new file mode 100644 index 0000000000..3e7a2e25ef --- /dev/null +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ReplayStrategy.java @@ -0,0 +1,67 @@ +/* + * Copyright © 2023 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.concurrent.api; + +import io.servicetalk.concurrent.PublisherSource.Subscriber; + +import java.util.function.Function; +import java.util.function.Supplier; + +/** + * Used to customize the strategy for the {@link Publisher} replay operator. + * @param The type of data. + */ +public interface ReplayStrategy { + /** + * Get the minimum number of downstream subscribers before subscribing upstream. + * @return the minimum number of downstream subscribers before subscribing upstream. + */ + int minSubscribers(); + + /** + * Get a {@link Supplier} that provides the {@link ReplayAccumulator} on each upstream subscribe. + * @return a {@link Supplier} that provides the {@link ReplayAccumulator} on each upstream subscribe. + */ + Supplier> accumulatorSupplier(); + + /** + * Determine if all the downstream subscribers cancel, should upstream be cancelled. + * @return {@code true} if all the downstream subscribers cancel, should upstream be cancelled. {@code false} + * will not cancel upstream if all downstream subscribers cancel. + */ + boolean cancelUpstream(); + + /** + * Get a hint to limit the number of elements which will be queued for each {@link Subscriber} in order to + * compensate for unequal demand and late subscribers. + * @return a hint to limit the number of elements which will be queued for each {@link Subscriber} in order to + * compensate for unequal demand and late subscribers. + */ + int queueLimitHint(); + + /** + * Get a {@link Function} that is invoked when a terminal signal arrives from upstream and determines when state + * is reset to allow for upstream resubscribe. + * @return A {@link Function} that is invoked when a terminal signal arrives from upstream, and + * returns a {@link Completable} whose termination resets the state of the returned {@link Publisher} and allows + * for downstream resubscribing. The argument to this function is as follows: + *
    + *
  • {@code null} if upstream terminates with {@link Subscriber#onComplete()}
  • + *
  • otherwise the {@link Throwable} from {@link Subscriber#onError(Throwable)}
  • + *
+ */ + Function terminalResubscribe(); +} diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ReplayStrategyBuilder.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ReplayStrategyBuilder.java new file mode 100644 index 0000000000..4a030a5ae7 --- /dev/null +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ReplayStrategyBuilder.java @@ -0,0 +1,155 @@ +/* + * Copyright © 2023 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.concurrent.api; + +import io.servicetalk.concurrent.PublisherSource.Subscriber; + +import java.util.function.Function; +import java.util.function.Supplier; + +import static io.servicetalk.concurrent.api.Completable.never; +import static io.servicetalk.concurrent.api.MulticastPublisher.DEFAULT_MULTICAST_QUEUE_LIMIT; +import static java.util.Objects.requireNonNull; + +/** + * A builder of {@link ReplayStrategy}. + * @param The type of data for {@link ReplayStrategy}. + */ +public final class ReplayStrategyBuilder { + private int minSubscribers = 1; + private final Supplier> accumulatorSupplier; + private boolean cancelUpstream; + private int queueLimitHint = DEFAULT_MULTICAST_QUEUE_LIMIT; + private Function terminalResubscribe = t -> never(); + + /** + * Create a new instance. + * @param accumulatorSupplier provides the {@link ReplayAccumulator} to use on each subscribe to upstream. + */ + public ReplayStrategyBuilder(Supplier> accumulatorSupplier) { + this.accumulatorSupplier = requireNonNull(accumulatorSupplier); + } + + /** + * Set the minimum number of downstream subscribers before subscribing upstream. + * @param minSubscribers the minimum number of downstream subscribers before subscribing upstream. + * @return {@code this}. + */ + public ReplayStrategyBuilder minSubscribers(int minSubscribers) { + if (minSubscribers <= 0) { + throw new IllegalArgumentException("minSubscribers: " + minSubscribers + " (expected >0)"); + } + this.minSubscribers = minSubscribers; + return this; + } + + /** + * Determine if all the downstream subscribers cancel, should upstream be cancelled. + * @param cancelUpstream {@code true} if all the downstream subscribers cancel, should upstream be cancelled. + * {@code false} will not cancel upstream if all downstream subscribers cancel. + * @return {@code this}. + */ + public ReplayStrategyBuilder cancelUpstream(boolean cancelUpstream) { + this.cancelUpstream = cancelUpstream; + return this; + } + + /** + * Set a hint to limit the number of elements which will be queued for each {@link Subscriber} in order to + * compensate for unequal demand and late subscribers. + * @param queueLimitHint a hint to limit the number of elements which will be queued for each {@link Subscriber} in + * order to compensate for unequal demand and late subscribers. + * @return {@code this}. + */ + public ReplayStrategyBuilder queueLimitHint(int queueLimitHint) { + if (queueLimitHint < 1) { + throw new IllegalArgumentException("maxQueueSize: " + queueLimitHint + " (expected >1)"); + } + this.queueLimitHint = queueLimitHint; + return this; + } + + /** + * Set a {@link Function} that is invoked when a terminal signal arrives from upstream and determines when state + * is reset to allow for upstream resubscribe. + * @param terminalResubscribe A {@link Function} that is invoked when a terminal signal arrives from upstream, and + * returns a {@link Completable} whose termination resets the state of the returned {@link Publisher} and allows + * for downstream resubscribing. The argument to this function is as follows: + *
    + *
  • {@code null} if upstream terminates with {@link Subscriber#onComplete()}
  • + *
  • otherwise the {@link Throwable} from {@link Subscriber#onError(Throwable)}
  • + *
+ * @return {@code this}. + */ + public ReplayStrategyBuilder terminalResubscribe( + Function terminalResubscribe) { + this.terminalResubscribe = requireNonNull(terminalResubscribe); + return this; + } + + /** + * Build the {@link ReplayStrategy}. + * @return the {@link ReplayStrategy}. + */ + public ReplayStrategy build() { + return new DefaultReplayStrategy<>(minSubscribers, accumulatorSupplier, cancelUpstream, queueLimitHint, + terminalResubscribe); + } + + private static final class DefaultReplayStrategy implements ReplayStrategy { + private final int minSubscribers; + private final Supplier> accumulatorSupplier; + private final boolean cancelUpstream; + private final int queueLimitHint; + private final Function terminalResubscribe; + + private DefaultReplayStrategy( + final int minSubscribers, final Supplier> accumulatorSupplier, + final boolean cancelUpstream, final int queueLimitHint, + final Function terminalResubscribe) { + this.minSubscribers = minSubscribers; + this.accumulatorSupplier = accumulatorSupplier; + this.cancelUpstream = cancelUpstream; + this.queueLimitHint = queueLimitHint; + this.terminalResubscribe = terminalResubscribe; + } + + @Override + public int minSubscribers() { + return minSubscribers; + } + + @Override + public Supplier> accumulatorSupplier() { + return accumulatorSupplier; + } + + @Override + public boolean cancelUpstream() { + return cancelUpstream; + } + + @Override + public int queueLimitHint() { + return queueLimitHint; + } + + @Override + public Function terminalResubscribe() { + return terminalResubscribe; + } + } +} diff --git a/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/MulticastPublisherTest.java b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/MulticastPublisherTest.java index 6250b6345d..2572f28cbf 100644 --- a/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/MulticastPublisherTest.java +++ b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/MulticastPublisherTest.java @@ -40,6 +40,7 @@ import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Stream; @@ -65,11 +66,11 @@ import static org.mockito.Mockito.verify; class MulticastPublisherTest { - private TestPublisher source; - private TestPublisherSubscriber subscriber1; - private TestPublisherSubscriber subscriber2; - private TestPublisherSubscriber subscriber3; - private TestSubscription subscription; + TestPublisher source; + TestPublisherSubscriber subscriber1; + TestPublisherSubscriber subscriber2; + TestPublisherSubscriber subscriber3; + TestSubscription subscription; @BeforeEach void setUp() { @@ -83,10 +84,31 @@ void setUp() { subscriber3 = new TestPublisherSubscriber<>(); } + Publisher applyOperator(Publisher source, int minSubscribers) { + return source.multicast(minSubscribers); + } + + Publisher applyOperator(Publisher source, int minSubscribers, boolean cancelUpstream) { + return source.multicast(minSubscribers, cancelUpstream); + } + + Publisher applyOperator(Publisher source, int minSubscribers, int queueLimit, + Function terminalResubscribe) { + return source.multicast(minSubscribers, queueLimit, terminalResubscribe); + } + + Publisher applyOperator(Publisher source, int minSubscribers, int queueLimit) { + return source.multicast(minSubscribers, queueLimit); + } + + Publisher applyOperator(Publisher source, int minSubscribers, int queueLimit, boolean cancelUpstream) { + return source.multicast(minSubscribers, queueLimit, cancelUpstream); + } + @ParameterizedTest @ValueSource(booleans = {true, false}) void singleSubscriber(boolean onError) { - toSource(source.multicast(1)).subscribe(subscriber1); + toSource(applyOperator(source, 1)).subscribe(subscriber1); subscriber1.awaitSubscription(); singleSourceTerminate(onError); } @@ -94,7 +116,7 @@ void singleSubscriber(boolean onError) { @ParameterizedTest @ValueSource(booleans = {true, false}) void singleSubscriberData(boolean onError) throws InterruptedException { - toSource(source.multicast(1)).subscribe(subscriber1); + toSource(applyOperator(source, 1)).subscribe(subscriber1); subscriber1.awaitSubscription().request(1); subscription.awaitRequestN(1); source.onNext(1); @@ -105,7 +127,7 @@ void singleSubscriberData(boolean onError) throws InterruptedException { @ParameterizedTest @ValueSource(booleans = {true, false}) void singleSubscriberMultipleData(boolean onError) throws InterruptedException { - toSource(source.multicast(1)).subscribe(subscriber1); + toSource(applyOperator(source, 1)).subscribe(subscriber1); Subscription localSubscription = subscriber1.awaitSubscription(); localSubscription.request(1); subscription.awaitRequestN(1); @@ -131,7 +153,7 @@ private void singleSourceTerminate(boolean onError) { @ParameterizedTest @ValueSource(booleans = {true, false}) void singleSubscriberCancel(boolean cancelUpstream) throws InterruptedException { - toSource(source.multicast(1, cancelUpstream)).subscribe(subscriber1); + toSource(applyOperator(source, 1, cancelUpstream)).subscribe(subscriber1); Subscription subscription1 = subscriber1.awaitSubscription(); subscription1.cancel(); subscription1.cancel(); // multiple cancels should be safe. @@ -145,7 +167,7 @@ void singleSubscriberCancel(boolean cancelUpstream) throws InterruptedException @ParameterizedTest @ValueSource(booleans = {true, false}) void twoSubscribersOneCancelsMultipleTimes(boolean cancelUpstream) throws InterruptedException { - Publisher publisher = source.multicast(2, cancelUpstream); + Publisher publisher = applyOperator(source, 2, cancelUpstream); toSource(publisher).subscribe(subscriber1); toSource(publisher).subscribe(subscriber2); Cancellable subscription1 = subscriber1.awaitSubscription(); @@ -163,7 +185,7 @@ void twoSubscribersOneCancelsMultipleTimes(boolean cancelUpstream) throws Interr @Test void singleSubscriberCancelStillDeliversData() throws InterruptedException { - Publisher publisher = source.multicast(1, false); + Publisher publisher = applyOperator(source, 1, false); toSource(publisher).subscribe(subscriber1); Subscription subscription1 = subscriber1.awaitSubscription(); subscription1.request(1); @@ -188,7 +210,7 @@ void singleSubscriberCancelStillDeliversData() throws InterruptedException { @ParameterizedTest @ValueSource(booleans = {true, false}) void subscriberCancelThenRequestIsNoop(boolean cancelUpstream) throws InterruptedException { - Publisher publisher = source.multicast(2, cancelUpstream); + Publisher publisher = applyOperator(source, 2, cancelUpstream); toSource(publisher).subscribe(subscriber1); Subscription subscription1 = subscriber1.awaitSubscription(); assertThat(subscription.requested(), is(0L)); @@ -224,7 +246,7 @@ void subscriberCancelThenRequestIsNoop(boolean cancelUpstream) throws Interrupte @ParameterizedTest @ValueSource(booleans = {true, false}) void twoSubscribersNoData(boolean onError) { - Publisher publisher = source.multicast(2); + Publisher publisher = applyOperator(source, 2); toSource(publisher).subscribe(subscriber1); subscriber1.awaitSubscription(); assertThat(subscription.requested(), is(0L)); @@ -236,7 +258,7 @@ void twoSubscribersNoData(boolean onError) { @ParameterizedTest @ValueSource(booleans = {true, false}) void twoSubscribersData(boolean onError) throws InterruptedException { - Publisher publisher = source.multicast(2); + Publisher publisher = applyOperator(source, 2); toSource(publisher).subscribe(subscriber1); subscriber1.awaitSubscription().request(1); assertThat(subscription.requested(), is(0L)); @@ -252,7 +274,7 @@ void twoSubscribersData(boolean onError) throws InterruptedException { @ParameterizedTest @ValueSource(booleans = {true, false}) void twoSubscribersMultipleData(boolean onError) throws InterruptedException { - Publisher publisher = source.multicast(2); + Publisher publisher = applyOperator(source, 2); toSource(publisher).subscribe(subscriber1); Subscription localSubscription1 = subscriber1.awaitSubscription(); localSubscription1.request(1); @@ -273,7 +295,7 @@ void twoSubscribersMultipleData(boolean onError) throws InterruptedException { twoSubscribersTerminate(onError); } - private void twoSubscribersTerminate(boolean onError) { + void twoSubscribersTerminate(boolean onError) { if (onError) { source.onError(DELIBERATE_EXCEPTION); assertThat(subscriber1.awaitOnError(), is(DELIBERATE_EXCEPTION)); @@ -288,7 +310,7 @@ private void twoSubscribersTerminate(boolean onError) { @ParameterizedTest @ValueSource(booleans = {true, false}) void twoSubscribersAfterTerminalData(boolean onError) throws InterruptedException { - Publisher publisher = source.multicast(1, 10, t -> never()); + Publisher publisher = applyOperator(source, 1, 10, t -> never()); toSource(publisher).subscribe(subscriber1); subscriber1.awaitSubscription().request(1); subscription.awaitRequestN(1); @@ -313,7 +335,7 @@ void twoSubscribersAfterTerminalData(boolean onError) throws InterruptedExceptio @ParameterizedTest @MethodSource("twoSubscribersInvalidRequestNParams") void twoSubscribersInvalidRequestN(long invalidN, boolean firstSubscription) { - Publisher publisher = Publisher.from(1).multicast(2); + Publisher publisher = applyOperator(Publisher.from(1), 2); toSource(publisher).subscribe(subscriber1); Subscription subscription1 = subscriber1.awaitSubscription(); toSource(publisher).subscribe(subscriber2); @@ -340,7 +362,7 @@ private static Stream twoSubscribersInvalidRequestNParams() { @ParameterizedTest @MethodSource("trueFalseStream") void twoSubscribersCancel(boolean firstSubscription, boolean cancelUpstream) throws InterruptedException { - Publisher publisher = source.multicast(2, cancelUpstream); + Publisher publisher = applyOperator(source, 2, cancelUpstream); toSource(publisher).subscribe(subscriber1); Subscription subscription1 = subscriber1.awaitSubscription(); toSource(publisher).subscribe(subscriber2); @@ -392,7 +414,7 @@ private static Stream trueFalseStream() { @ParameterizedTest @ValueSource(booleans = {true, false}) void threeSubscribersOneLateNoQueueData(boolean onError) throws InterruptedException { - Publisher publisher = source.multicast(2); + Publisher publisher = applyOperator(source, 2); toSource(publisher).subscribe(subscriber1); Subscription localSubscription1 = subscriber1.awaitSubscription(); localSubscription1.request(1); @@ -420,7 +442,7 @@ void threeSubscribersOneLateNoQueueData(boolean onError) throws InterruptedExcep @ParameterizedTest @ValueSource(booleans = {true, false}) void threeSubscribersOneLateQueueData(boolean onError) throws InterruptedException { - Publisher publisher = source.multicast(2); + Publisher publisher = applyOperator(source, 2); toSource(publisher).subscribe(subscriber1); toSource(publisher).subscribe(subscriber2); Subscription localSubscription1 = subscriber1.awaitSubscription(); @@ -444,7 +466,7 @@ void threeSubscribersOneLateQueueData(boolean onError) throws InterruptedExcepti @Test void cancelMinSubscriberRequestsMore() throws InterruptedException { - Publisher publisher = source.multicast(1); + Publisher publisher = applyOperator(source, 1); toSource(publisher).subscribe(subscriber1); Subscription localSubscription1 = subscriber1.awaitSubscription(); toSource(publisher).subscribe(subscriber2); @@ -458,7 +480,7 @@ void cancelMinSubscriberRequestsMore() throws InterruptedException { @Test void cancelMinSubscriberRespectsQueueLimit() throws InterruptedException { final int queueLimit = 64; - Publisher publisher = source.multicast(2, queueLimit); + Publisher publisher = applyOperator(source, 2, queueLimit); toSource(publisher).subscribe(subscriber1); Subscription localSubscription1 = subscriber1.awaitSubscription(); localSubscription1.request(10); @@ -521,7 +543,7 @@ void threeSubscribersOneCancelRequestsUpstream(boolean cancel2First) throws Inte @ParameterizedTest @MethodSource("trueFalseStream") void threeSubscribersOneLateAfterCancel(boolean cancelMax, boolean cancelUpstream) throws InterruptedException { - Publisher publisher = source.multicast(2, cancelUpstream); + Publisher publisher = applyOperator(source, 2, cancelUpstream); toSource(publisher).subscribe(subscriber1); Subscription localSubscription1 = subscriber1.awaitSubscription(); localSubscription1.request(5); @@ -560,7 +582,7 @@ void threeSubscribersOneLateAfterCancel(boolean cancelMax, boolean cancelUpstrea subscriber3.awaitOnComplete(); } - private void threeSubscribersTerminate(boolean onError) { + void threeSubscribersTerminate(boolean onError) { if (onError) { source.onError(DELIBERATE_EXCEPTION); assertThat(subscriber1.awaitOnError(), is(DELIBERATE_EXCEPTION)); @@ -576,7 +598,7 @@ private void threeSubscribersTerminate(boolean onError) { @Test void inlineRequestFromOnSubscribeToMultipleSubscribers() { - Publisher publisher = Publisher.from(1).multicast(2); + Publisher publisher = applyOperator(Publisher.from(1), 2); @SuppressWarnings("unchecked") Subscriber sub1 = mock(Subscriber.class); @SuppressWarnings("unchecked") @@ -603,8 +625,8 @@ void inlineRequestFromOnSubscribeToMultipleSubscribers() { @ParameterizedTest @ValueSource(booleans = {true, false}) void onErrorFromSubscriptionRequestToMultipleSubscribers(boolean onError) { - Publisher multicast = new TerminateFromOnSubscribePublisher(onError ? - error(DELIBERATE_EXCEPTION) : complete()).multicast(2); + Publisher multicast = applyOperator(new TerminateFromOnSubscribePublisher(onError ? + error(DELIBERATE_EXCEPTION) : complete()), 2); toSource(multicast).subscribe(subscriber1); toSource(multicast).subscribe(subscriber2); subscriber1.awaitSubscription().request(1); @@ -621,7 +643,7 @@ void onErrorFromSubscriptionRequestToMultipleSubscribers(boolean onError) { @ParameterizedTest @MethodSource("reentrySubscriberRequestCountIsCorrectParams") void reentrySubscriberOrderingCorrect(boolean firstReentry, boolean secondReentry) { - Publisher multicast = fromSource(new ReentryPublisher(0, 4)).multicast(2); + Publisher multicast = applyOperator(fromSource(new ReentryPublisher(0, 4)), 2); toSource(multicast.beforeOnNext(n -> { if (firstReentry) { subscriber1.awaitSubscription().request(1); @@ -654,7 +676,7 @@ void reentrySubscriberOrderingCorrect(boolean firstReentry, boolean secondReentr @ParameterizedTest @MethodSource("reentrySubscriberRequestCountIsCorrectParams") void reentrySubscriberRequestCountIsCorrect(boolean firstReentry, boolean secondReentry) { - Publisher multicast = source.multicast(2); + Publisher multicast = applyOperator(source, 2); toSource(multicast.whenOnNext(n -> { if (firstReentry) { subscriber1.awaitSubscription().request(1); @@ -697,7 +719,7 @@ private static Stream reentrySubscriberRequestCountIsCorrectParams() @ParameterizedTest @ValueSource(booleans = {true, false}) void reentryAndMultiQueueSupportsNull(boolean requestReentry) throws InterruptedException { - Publisher multicast = source.multicast(1); + Publisher multicast = applyOperator(source, 1); AtomicBoolean onNextCalled = new AtomicBoolean(); toSource(multicast).subscribe(subscriber1); subscriber1.awaitSubscription().request(3); @@ -726,7 +748,7 @@ void reentryAndMultiQueueSupportsNull(boolean requestReentry) throws Interrupted void reentryAsyncData() throws Exception { Executor executor = Executors.newCachedThreadExecutor(); try { - Publisher multicast = Publisher.from(1, 2, 3).publishOn(executor).multicast(2); + Publisher multicast = applyOperator(Publisher.from(1, 2, 3).publishOn(executor), 2); AtomicBoolean onNextCalled = new AtomicBoolean(); toSource(multicast.afterOnNext(n -> { if (onNextCalled.compareAndSet(false, true)) { @@ -749,7 +771,7 @@ void reentryAsyncData() throws Exception { @Test void replenishRequestNInMaxQueueIncrementsLongMax() { - Publisher multicast = source.multicast(2, 3); + Publisher multicast = applyOperator(source, 2, 3); toSource(multicast).subscribe(subscriber1); toSource(multicast).subscribe(subscriber2); @@ -770,7 +792,7 @@ void replenishRequestNInMaxQueueIncrementsRange() throws Exception { return list; }).toFuture().get(); - Publisher multi = original.multicast(2, 5); + Publisher multi = applyOperator(original, 2, 5); List first = new ArrayList<>(); List second = new ArrayList<>(); multi.forEach(first::add); @@ -783,7 +805,7 @@ void replenishRequestNInMaxQueueIncrementsRange() throws Exception { @Test void concurrentRequestN() throws InterruptedException { final int expectedSubscribers = 50; - Publisher multicast = source.multicast(expectedSubscribers); + Publisher multicast = applyOperator(source, expectedSubscribers); @SuppressWarnings("unchecked") TestPublisherSubscriber[] subscribers = (TestPublisherSubscriber[]) new TestPublisherSubscriber[expectedSubscribers]; @@ -821,7 +843,7 @@ void concurrentRequestN() throws InterruptedException { @Test void concurrentRequestNAndOnNext() throws BrokenBarrierException, InterruptedException { final int expectedSubscribers = 400; - Publisher multicast = source.multicast(expectedSubscribers); + Publisher multicast = applyOperator(source, expectedSubscribers); @SuppressWarnings("unchecked") TestPublisherSubscriber[] subscribers = (TestPublisherSubscriber[]) new TestPublisherSubscriber[expectedSubscribers]; @@ -881,7 +903,7 @@ void concurrentRequestNAndOnNext() throws BrokenBarrierException, InterruptedExc @MethodSource("trueFalseStream") void threeConcurrentLateSubscriber(boolean cancelEarlySub, boolean cancelUpstream) throws Exception { final int expectedSignals = 1000; - Publisher publisher = source.multicast(2, expectedSignals, cancelUpstream); + Publisher publisher = applyOperator(source, 2, expectedSignals, cancelUpstream); toSource(publisher).subscribe(subscriber1); Subscription subscription1 = subscriber1.awaitSubscription(); subscription1.request(expectedSignals); @@ -943,7 +965,7 @@ void threeConcurrentLateSubscriber(boolean cancelEarlySub, boolean cancelUpstrea @Test void twoConcurrentSubscriptions() throws Exception { final int expectedSignals = 1000; - Publisher publisher = source.multicast(2, expectedSignals); + Publisher publisher = applyOperator(source, 2, expectedSignals); toSource(publisher).subscribe(subscriber1); Subscription subscription1 = subscriber1.awaitSubscription(); toSource(publisher).subscribe(subscriber2); diff --git a/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/ReplayPublisherTest.java b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/ReplayPublisherTest.java new file mode 100644 index 0000000000..235b6bac3d --- /dev/null +++ b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/ReplayPublisherTest.java @@ -0,0 +1,309 @@ +/* + * Copyright © 2023 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.concurrent.api; + +import io.servicetalk.concurrent.test.internal.TestPublisherSubscriber; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.ValueSource; + +import java.time.Duration; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.Future; +import java.util.function.Consumer; +import java.util.function.Function; +import javax.annotation.Nullable; + +import static io.servicetalk.concurrent.api.SourceAdapters.toSource; +import static io.servicetalk.concurrent.internal.DeliberateException.DELIBERATE_EXCEPTION; +import static java.time.Duration.ofMillis; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; + +final class ReplayPublisherTest extends MulticastPublisherTest { + private final TestPublisherSubscriber subscriber4 = new TestPublisherSubscriber<>(); + private final TestExecutor executor = new TestExecutor(); + + @AfterEach + void tearDown() throws Exception { + executor.closeAsync().toFuture().get(); + } + + @Override + Publisher applyOperator(Publisher source, int minSubscribers) { + return source.replay(new ReplayStrategyBuilder(EmptyReplayAccumulator::emptyAccumulator) + .minSubscribers(minSubscribers).build()); + } + + @Override + Publisher applyOperator(Publisher source, int minSubscribers, boolean cancelUpstream) { + return source.replay(new ReplayStrategyBuilder(EmptyReplayAccumulator::emptyAccumulator) + .cancelUpstream(cancelUpstream) + .minSubscribers(minSubscribers).build()); + } + + @Override + Publisher applyOperator(Publisher source, int minSubscribers, int queueLimit, + Function terminalResubscribe) { + return source.replay(new ReplayStrategyBuilder(EmptyReplayAccumulator::emptyAccumulator) + .queueLimitHint(queueLimit) + .terminalResubscribe(terminalResubscribe) + .minSubscribers(minSubscribers).build()); + } + + @Override + Publisher applyOperator(Publisher source, int minSubscribers, int queueLimit) { + return source.replay(new ReplayStrategyBuilder(EmptyReplayAccumulator::emptyAccumulator) + .queueLimitHint(queueLimit) + .minSubscribers(minSubscribers).build()); + } + + @Override + Publisher applyOperator(Publisher source, int minSubscribers, int queueLimit, boolean cancelUpstream) { + return source.replay(new ReplayStrategyBuilder(EmptyReplayAccumulator::emptyAccumulator) + .queueLimitHint(queueLimit) + .cancelUpstream(cancelUpstream) + .minSubscribers(minSubscribers).build()); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void twoSubscribersHistory(boolean onError) { + Publisher publisher = source.replay(2); + toSource(publisher).subscribe(subscriber1); + subscriber1.awaitSubscription().request(4); + assertThat(subscription.requested(), is(4L)); + source.onNext(1, 2, null); + assertThat(subscriber1.takeOnNext(3), contains(1, 2, null)); + + toSource(publisher).subscribe(subscriber2); + subscriber2.awaitSubscription().request(4); + assertThat(subscription.requested(), is(4L)); + + assertThat(subscriber2.takeOnNext(2), contains(2, null)); + + source.onNext(4); + assertThat(subscriber1.takeOnNext(), is(4)); + assertThat(subscriber2.takeOnNext(), is(4)); + + twoSubscribersTerminate(onError); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void subscribeAfterTerminalDeliversHistory(boolean onError) { + Publisher publisher = source.replay(2); + toSource(publisher).subscribe(subscriber1); + subscriber1.awaitSubscription().request(4); + assertThat(subscription.requested(), is(4L)); + source.onNext(1, 2, 3); + assertThat(subscriber1.takeOnNext(3), contains(1, 2, 3)); + if (onError) { + source.onError(DELIBERATE_EXCEPTION); + assertThat(subscriber1.awaitOnError(), is(DELIBERATE_EXCEPTION)); + } else { + source.onComplete(); + subscriber1.awaitOnComplete(); + } + + toSource(publisher).subscribe(subscriber2); + subscriber2.awaitSubscription().request(4); + assertThat(subscriber2.takeOnNext(2), contains(2, 3)); + if (onError) { + assertThat(subscriber2.awaitOnError(), is(DELIBERATE_EXCEPTION)); + } else { + subscriber2.awaitOnComplete(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void threeSubscribersSum(boolean onError) { + Publisher publisher = source.replay(SumReplayAccumulator::new); + toSource(publisher).subscribe(subscriber1); + subscriber1.awaitSubscription().request(4); + assertThat(subscription.requested(), is(4L)); + source.onNext(1, 2, 3); + assertThat(subscriber1.takeOnNext(3), contains(1, 2, 3)); + + toSource(publisher).subscribe(subscriber2); + subscriber2.awaitSubscription().request(4); + assertThat(subscription.requested(), is(4L)); + + assertThat(subscriber2.takeOnNext(), equalTo(6)); + + source.onNext(4); + assertThat(subscriber1.takeOnNext(), is(4)); + assertThat(subscriber2.takeOnNext(), is(4)); + + toSource(publisher).subscribe(subscriber3); + subscriber3.awaitSubscription().request(4); + assertThat(subscription.requested(), is(4L)); + assertThat(subscriber3.takeOnNext(), equalTo(10)); + + subscriber1.awaitSubscription().request(1); + assertThat(subscription.requested(), is(5L)); + source.onNext(5); + + assertThat(subscriber1.takeOnNext(), is(5)); + assertThat(subscriber2.takeOnNext(), is(5)); + assertThat(subscriber3.takeOnNext(), is(5)); + + threeSubscribersTerminate(onError); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void threeSubscribersTTL(boolean onError) { + final Duration ttl = ofMillis(2); + Publisher publisher = source.replay(2, ttl, executor); + toSource(publisher).subscribe(subscriber1); + subscriber1.awaitSubscription().request(4); + assertThat(subscription.requested(), is(4L)); + source.onNext(1, 2); + executor.advanceTimeBy(1, MILLISECONDS); + source.onNext((Integer) null); + assertThat(subscriber1.takeOnNext(3), contains(1, 2, null)); + + toSource(publisher).subscribe(subscriber2); + subscriber2.awaitSubscription().request(4); + assertThat(subscriber2.takeOnNext(2), contains(2, null)); + + executor.advanceTimeBy(1, MILLISECONDS); + toSource(publisher).subscribe(subscriber3); + subscriber3.awaitSubscription().request(4); + assertThat(subscriber3.takeOnNext(), equalTo(null)); + + source.onNext(4); + assertThat(subscriber1.takeOnNext(), equalTo(4)); + assertThat(subscriber2.takeOnNext(), equalTo(4)); + assertThat(subscriber3.takeOnNext(), equalTo(4)); + + subscriber1.awaitSubscription().request(10); + subscriber2.awaitSubscription().request(10); + subscriber3.awaitSubscription().request(10); + executor.advanceTimeBy(ttl.toMillis(), MILLISECONDS); + toSource(publisher).subscribe(subscriber4); + subscriber4.awaitSubscription().request(4); + assertThat(subscriber4.pollOnNext(10, MILLISECONDS), nullValue()); + + threeSubscribersTerminate(onError); + } + + @ParameterizedTest(name = "{displayName} [{index}] expectedSubscribers={0} expectedSum={1}") + @CsvSource(value = {"500,500", "50,50", "50,500", "500,50"}) + void concurrentSubscribes(final int expectedSubscribers, final long expectedSum) throws Exception { + Publisher replay = source.replay(SumReplayAccumulator::new); + CyclicBarrier startBarrier = new CyclicBarrier(expectedSubscribers + 1); + Completable[] completables = new Completable[expectedSubscribers]; + @SuppressWarnings("unchecked") + TestPublisherSubscriber[] subscribers = (TestPublisherSubscriber[]) + new TestPublisherSubscriber[expectedSubscribers]; + Executor executor = Executors.newCachedThreadExecutor(); + try { + for (int i = 0; i < subscribers.length; ++i) { + final TestPublisherSubscriber currSubscriber = new TestPublisherSubscriber<>(); + subscribers[i] = currSubscriber; + completables[i] = executor.submit(() -> { + try { + startBarrier.await(); + toSource(replay).subscribe(currSubscriber); + currSubscriber.awaitSubscription().request(expectedSum); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + + Future future = Completable.mergeAll(completables.length, completables).toFuture(); + startBarrier.await(); + for (int i = 0; i < expectedSum; ++i) { + subscription.awaitRequestN(i + 1); + source.onNext(1); + } + + future.get(); + source.onComplete(); // deliver terminal after all requests have been delivered. + + for (final TestPublisherSubscriber currSubscriber : subscribers) { + int numOnNext = 0; + long currSum = 0; + while (currSum < expectedSum) { + Integer next = currSubscriber.takeOnNext(); + ++numOnNext; + if (next != null) { + currSum += next; + } + } + try { + assertThat(currSum, equalTo(expectedSum)); + currSubscriber.awaitOnComplete(); + } catch (Throwable cause) { + throw new AssertionError("failure numOnNext=" + numOnNext, cause); + } + } + + subscription.awaitRequestN(expectedSum); + assertThat(subscription.isCancelled(), is(false)); + } finally { + executor.closeAsync().toFuture().get(); + } + } + + private static final class EmptyReplayAccumulator implements ReplayAccumulator { + static final ReplayAccumulator INSTANCE = new EmptyReplayAccumulator<>(); + + private EmptyReplayAccumulator() { + } + + @SuppressWarnings("unchecked") + static ReplayAccumulator emptyAccumulator() { + return (ReplayAccumulator) INSTANCE; + } + + @Override + public void accumulate(@Nullable final T t) { + } + + @Override + public void deliverAccumulation(final Consumer consumer) { + } + } + + private static final class SumReplayAccumulator implements ReplayAccumulator { + private int sum; + + @Override + public void accumulate(@Nullable final Integer integer) { + if (integer != null) { + sum += integer; + } + } + + @Override + public void deliverAccumulation(final Consumer consumer) { + if (sum != 0) { + consumer.accept(sum); + } + } + } +} diff --git a/servicetalk-concurrent-reactivestreams/src/test/java/io/servicetalk/concurrent/reactivestreams/tck/PublisherReplayTckTest.java b/servicetalk-concurrent-reactivestreams/src/test/java/io/servicetalk/concurrent/reactivestreams/tck/PublisherReplayTckTest.java new file mode 100644 index 0000000000..29a71f807b --- /dev/null +++ b/servicetalk-concurrent-reactivestreams/src/test/java/io/servicetalk/concurrent/reactivestreams/tck/PublisherReplayTckTest.java @@ -0,0 +1,28 @@ +/* + * Copyright © 2023 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.concurrent.reactivestreams.tck; + +import io.servicetalk.concurrent.api.Publisher; + +import org.testng.annotations.Test; + +@Test +public class PublisherReplayTckTest extends AbstractPublisherOperatorTckTest { + @Override + protected Publisher composePublisher(Publisher publisher, int elements) { + return publisher.replay(1); + } +}