diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java index f386d48c8..3f76af36e 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java @@ -20,17 +20,12 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import com.google.common.annotations.VisibleForTesting; -import io.reactivex.Flowable; -import io.reactivex.Scheduler; -import io.reactivex.schedulers.Schedulers; import lombok.AccessLevel; import lombok.Getter; import lombok.NonNull; @@ -44,7 +39,6 @@ import software.amazon.kinesis.metrics.MetricsCollectingTaskDecorator; import software.amazon.kinesis.metrics.MetricsFactory; import software.amazon.kinesis.retrieval.RecordsPublisher; -import software.amazon.kinesis.retrieval.RetryableRetrievalException; /** * Responsible for consuming data records of a (specified) shard. @@ -60,7 +54,6 @@ public class ShardConsumer { public static final int MAX_TIME_BETWEEN_REQUEST_RESPONSE = 35000; private final RecordsPublisher recordsPublisher; private final ExecutorService executorService; - private final Scheduler scheduler; private final ShardInfo shardInfo; private final ShardConsumerArgument shardConsumerArgument; @NonNull @@ -72,9 +65,6 @@ public class ShardConsumer { private ConsumerTask currentTask; private TaskOutcome taskOutcome; - private final AtomicReference processFailure = new AtomicReference<>(null); - private final AtomicReference dispatchFailure = new AtomicReference<>(null); - private CompletableFuture stateChangeFuture; private boolean needsInitialization = true; @@ -94,7 +84,7 @@ public class ShardConsumer { private volatile ShutdownReason shutdownReason; private volatile ShutdownNotification shutdownNotification; - private final InternalSubscriber subscriber; + private final ShardConsumerSubscriber subscriber; public ShardConsumer(RecordsPublisher recordsPublisher, ExecutorService executorService, ShardInfo shardInfo, Optional logWarningForTaskAfterMillis, ShardConsumerArgument shardConsumerArgument, @@ -119,8 +109,7 @@ public ShardConsumer(RecordsPublisher recordsPublisher, ExecutorService executor this.taskExecutionListener = taskExecutionListener; this.currentState = initialState; this.taskMetricsDecorator = taskMetricsDecorator; - scheduler = Schedulers.from(executorService); - subscriber = new InternalSubscriber(); + subscriber = new ShardConsumerSubscriber(recordsPublisher, executorService, bufferSize, this); this.bufferSize = bufferSize; if (this.shardInfo.isCompleted()) { @@ -128,64 +117,8 @@ public ShardConsumer(RecordsPublisher recordsPublisher, ExecutorService executor } } - private void startSubscriptions() { - Flowable.fromPublisher(recordsPublisher).subscribeOn(scheduler).observeOn(scheduler, true, bufferSize) - .subscribe(subscriber); - } - - private final Object lockObject = new Object(); - private Instant lastRequestTime = null; - - private class InternalSubscriber implements Subscriber { - - private Subscription subscription; - private volatile Instant lastDataArrival; - - @Override - public void onSubscribe(Subscription s) { - subscription = s; - subscription.request(1); - } - - @Override - public void onNext(ProcessRecordsInput input) { - try { - synchronized (lockObject) { - lastRequestTime = null; - } - lastDataArrival = Instant.now(); - handleInput(input.toBuilder().cacheExitTime(Instant.now()).build(), subscription); - } catch (Throwable t) { - log.warn("{}: Caught exception from handleInput", shardInfo.shardId(), t); - dispatchFailure.set(t); - } finally { - subscription.request(1); - synchronized (lockObject) { - lastRequestTime = Instant.now(); - } - } - } - - @Override - public void onError(Throwable t) { - log.warn("{}: onError(). Cancelling subscription, and marking self as failed.", shardInfo.shardId(), t); - subscription.cancel(); - processFailure.set(t); - } - - @Override - public void onComplete() { - log.debug("{}: onComplete(): Received onComplete. Activity should be triggered externally", shardInfo.shardId()); - } - public void cancel() { - if (subscription != null) { - subscription.cancel(); - } - } - } - - private synchronized void handleInput(ProcessRecordsInput input, Subscription subscription) { + synchronized void handleInput(ProcessRecordsInput input, Subscription subscription) { if (isShutdownRequested()) { subscription.cancel(); return; @@ -240,50 +173,15 @@ public void executeLifecycle() { Throwable healthCheck() { logNoDataRetrievedAfterTime(); logLongRunningTask(); - Throwable failure = processFailure.get(); - if (!processFailure.compareAndSet(failure, null) && failure != null) { - log.error("{}: processFailure was updated while resetting, this shouldn't happen. " + - "Will retry on next health check", shardInfo.shardId()); - return null; - } + Throwable failure = subscriber.healthCheck(MAX_TIME_BETWEEN_REQUEST_RESPONSE); + if (failure != null) { - String logMessage = String.format("%s: Failure occurred in retrieval. Restarting data requests", shardInfo.shardId()); - if (failure instanceof RetryableRetrievalException) { - log.debug(logMessage, failure.getCause()); - } else { - log.warn(logMessage, failure); - } - startSubscriptions(); return failure; } - Throwable expectedDispatchFailure = dispatchFailure.get(); - if (expectedDispatchFailure != null) { - if (!dispatchFailure.compareAndSet(expectedDispatchFailure, null)) { - log.info("{}: Unable to reset the dispatch failure, this can happen if the record processor is failing aggressively.", shardInfo.shardId()); - return null; - } - log.warn("Exception occurred while dispatching incoming data. The incoming data has been skipped", expectedDispatchFailure); - return expectedDispatchFailure; - } - synchronized (lockObject) { - if (lastRequestTime != null) { - Instant now = Instant.now(); - Duration timeSinceLastResponse = Duration.between(lastRequestTime, now); - if (timeSinceLastResponse.toMillis() > MAX_TIME_BETWEEN_REQUEST_RESPONSE) { - log.error( - "{}: Last request was dispatched at {}, but no response as of {} ({}). Cancelling subscription, and restarting.", - shardInfo.shardId(), lastRequestTime, now, timeSinceLastResponse); - if (subscriber != null) { - subscriber.cancel(); - } - // - // Set the last request time to now, we specifically don't null it out since we want it to trigger a - // restart if the subscription still doesn't start producing. - // - lastRequestTime = Instant.now(); - startSubscriptions(); - } - } + Throwable dispatchFailure = subscriber.getAndResetDispatchFailure(); + if (dispatchFailure != null) { + log.warn("Exception occurred while dispatching incoming data. The incoming data has been skipped", dispatchFailure); + return dispatchFailure; } return null; @@ -306,10 +204,10 @@ String longRunningTaskMessage(Duration taken) { private void logNoDataRetrievedAfterTime() { logWarningForTaskAfterMillis.ifPresent(value -> { - Instant lastDataArrival = subscriber.lastDataArrival; + Instant lastDataArrival = subscriber.lastDataArrival(); if (lastDataArrival != null) { Instant now = Instant.now(); - Duration timeSince = Duration.between(subscriber.lastDataArrival, now); + Duration timeSince = Duration.between(subscriber.lastDataArrival(), now); if (timeSince.toMillis() > value) { log.warn("Last time data arrived: {} ({})", lastDataArrival, timeSince); } @@ -335,7 +233,7 @@ private void logLongRunningTask() { @VisibleForTesting void subscribe() { - startSubscriptions(); + subscriber.startSubscriptions(); } @VisibleForTesting diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumerSubscriber.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumerSubscriber.java new file mode 100644 index 000000000..bd89d2564 --- /dev/null +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumerSubscriber.java @@ -0,0 +1,183 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Amazon Software License (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/asl/ + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ +package software.amazon.kinesis.lifecycle; + +import com.google.common.annotations.VisibleForTesting; +import io.reactivex.Flowable; +import io.reactivex.Scheduler; +import io.reactivex.schedulers.Schedulers; +import lombok.AccessLevel; +import lombok.Getter; +import lombok.experimental.Accessors; +import lombok.extern.slf4j.Slf4j; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.kinesis.retrieval.RecordsPublisher; +import software.amazon.kinesis.retrieval.RecordsRetrieved; +import software.amazon.kinesis.retrieval.RetryableRetrievalException; + +import java.time.Duration; +import java.time.Instant; +import java.util.concurrent.ExecutorService; + +@Slf4j +@Accessors(fluent = true) +class ShardConsumerSubscriber implements Subscriber { + + private final RecordsPublisher recordsPublisher; + private final Scheduler scheduler; + private final int bufferSize; + private final ShardConsumer shardConsumer; + + @VisibleForTesting + final Object lockObject = new Object(); + private Instant lastRequestTime = null; + private RecordsRetrieved lastAccepted = null; + + private Subscription subscription; + @Getter + private volatile Instant lastDataArrival; + @Getter + private volatile Throwable dispatchFailure; + @Getter(AccessLevel.PACKAGE) + private volatile Throwable retrievalFailure; + + ShardConsumerSubscriber(RecordsPublisher recordsPublisher, ExecutorService executorService, int bufferSize, + ShardConsumer shardConsumer) { + this.recordsPublisher = recordsPublisher; + this.scheduler = Schedulers.from(executorService); + this.bufferSize = bufferSize; + this.shardConsumer = shardConsumer; + } + + void startSubscriptions() { + synchronized (lockObject) { + if (lastAccepted != null) { + recordsPublisher.restartFrom(lastAccepted); + } + Flowable.fromPublisher(recordsPublisher).subscribeOn(scheduler).observeOn(scheduler, true, bufferSize) + .subscribe(this); + } + } + + Throwable healthCheck(long maxTimeBetweenRequests) { + Throwable result = restartIfFailed(); + if (result == null) { + restartIfRequestTimerExpired(maxTimeBetweenRequests); + } + return result; + } + + Throwable getAndResetDispatchFailure() { + synchronized (lockObject) { + Throwable failure = dispatchFailure; + dispatchFailure = null; + return failure; + } + } + + private Throwable restartIfFailed() { + Throwable oldFailure = null; + if (retrievalFailure != null) { + synchronized (lockObject) { + String logMessage = String.format("%s: Failure occurred in retrieval. Restarting data requests", shardConsumer.shardInfo().shardId()); + if (retrievalFailure instanceof RetryableRetrievalException) { + log.debug(logMessage, retrievalFailure.getCause()); + } else { + log.warn(logMessage, retrievalFailure); + } + oldFailure = retrievalFailure; + retrievalFailure = null; + } + startSubscriptions(); + } + + return oldFailure; + } + + private void restartIfRequestTimerExpired(long maxTimeBetweenRequests) { + synchronized (lockObject) { + if (lastRequestTime != null) { + Instant now = Instant.now(); + Duration timeSinceLastResponse = Duration.between(lastRequestTime, now); + if (timeSinceLastResponse.toMillis() > maxTimeBetweenRequests) { + log.error( + "{}: Last request was dispatched at {}, but no response as of {} ({}). Cancelling subscription, and restarting.", + shardConsumer.shardInfo().shardId(), lastRequestTime, now, timeSinceLastResponse); + cancel(); + // + // Set the last request time to now, we specifically don't null it out since we want it to + // trigger a + // restart if the subscription still doesn't start producing. + // + lastRequestTime = Instant.now(); + startSubscriptions(); + } + } + } + } + + @Override + public void onSubscribe(Subscription s) { + subscription = s; + subscription.request(1); + } + + @Override + public void onNext(RecordsRetrieved input) { + try { + synchronized (lockObject) { + lastRequestTime = null; + } + lastDataArrival = Instant.now(); + shardConsumer.handleInput(input.processRecordsInput().toBuilder().cacheExitTime(Instant.now()).build(), + subscription); + + } catch (Throwable t) { + log.warn("{}: Caught exception from handleInput", shardConsumer.shardInfo().shardId(), t); + synchronized (lockObject) { + dispatchFailure = t; + } + } finally { + subscription.request(1); + synchronized (lockObject) { + lastAccepted = input; + lastRequestTime = Instant.now(); + } + } + } + + @Override + public void onError(Throwable t) { + synchronized (lockObject) { + log.warn("{}: onError(). Cancelling subscription, and marking self as failed.", + shardConsumer.shardInfo().shardId(), t); + subscription.cancel(); + retrievalFailure = t; + } + } + + @Override + public void onComplete() { + log.debug("{}: onComplete(): Received onComplete. Activity should be triggered externally", + shardConsumer.shardInfo().shardId()); + } + + public void cancel() { + if (subscription != null) { + subscription.cancel(); + } + } +} diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RecordsPublisher.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RecordsPublisher.java index 87e881a47..98c2e77c8 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RecordsPublisher.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RecordsPublisher.java @@ -24,7 +24,7 @@ /** * Provides a record publisher that will retrieve records from Kinesis for processing */ -public interface RecordsPublisher extends Publisher { +public interface RecordsPublisher extends Publisher { /** * Initializes the publisher with where to start processing. If there is a stored sequence number the publisher will * begin from that sequence number, otherwise it will use the initial position. @@ -35,6 +35,12 @@ public interface RecordsPublisher extends Publisher { * if there is no sequence number the initial position to use */ void start(ExtendedSequenceNumber extendedSequenceNumber, InitialPositionInStreamExtended initialPositionInStreamExtended); + + /** + * Restart from the last accepted and processed + * @param recordsRetrieved the processRecordsInput to restart from + */ + void restartFrom(RecordsRetrieved recordsRetrieved); /** diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RecordsRetrieved.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RecordsRetrieved.java new file mode 100644 index 000000000..d58336e94 --- /dev/null +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/RecordsRetrieved.java @@ -0,0 +1,27 @@ +/* + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Amazon Software License (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/asl/ + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ +package software.amazon.kinesis.retrieval; + +import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; + +public interface RecordsRetrieved { + + /** + * Retrieves the records that have been received via one of the publishers + * + * @return the processRecordsInput received + */ + ProcessRecordsInput processRecordsInput(); +} diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java index c199eeca1..638becf1d 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisher.java @@ -24,8 +24,10 @@ import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; +import lombok.Data; import lombok.NonNull; import lombok.RequiredArgsConstructor; +import lombok.experimental.Accessors; import lombok.extern.slf4j.Slf4j; import software.amazon.awssdk.core.async.SdkPublisher; import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; @@ -42,6 +44,7 @@ import software.amazon.kinesis.retrieval.IteratorBuilder; import software.amazon.kinesis.retrieval.KinesisClientRecord; import software.amazon.kinesis.retrieval.RecordsPublisher; +import software.amazon.kinesis.retrieval.RecordsRetrieved; import software.amazon.kinesis.retrieval.RetryableRetrievalException; import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber; @@ -67,7 +70,7 @@ public class FanOutRecordsPublisher implements RecordsPublisher { private InitialPositionInStreamExtended initialPositionInStreamExtended; private boolean isFirstConnection = true; - private Subscriber subscriber; + private Subscriber subscriber; private long availableQueueSpace = 0; @Override @@ -93,6 +96,24 @@ public void shutdown() { } } + @Override + public void restartFrom(RecordsRetrieved recordsRetrieved) { + synchronized (lockObject) { + if (flow != null) { + // + // The flow should not be running at this time + // + flow.cancel(); + } + flow = null; + if (!(recordsRetrieved instanceof FanoutRecordsRetrieved)) { + throw new IllegalArgumentException( + "Provided ProcessRecordsInput not created from the FanOutRecordsPublisher"); + } + currentSequenceNumber = ((FanoutRecordsRetrieved) recordsRetrieved).continuationSequenceNumber(); + } + } + private boolean hasValidSubscriber() { return subscriber != null; } @@ -174,8 +195,10 @@ private void handleFlowError(Throwable t) { log.debug( "{}: Could not call SubscribeToShard successfully because shard no longer exists. Marking shard for completion.", shardId); + FanoutRecordsRetrieved response = new FanoutRecordsRetrieved( + ProcessRecordsInput.builder().records(Collections.emptyList()).isAtShardEnd(true).build(), null); subscriber - .onNext(ProcessRecordsInput.builder().records(Collections.emptyList()).isAtShardEnd(true).build()); + .onNext(response); subscriber.onComplete(); } else { subscriber.onError(t); @@ -257,9 +280,10 @@ private void recordsReceived(RecordFlow triggeringFlow, SubscribeToShardEvent re ProcessRecordsInput input = ProcessRecordsInput.builder().cacheEntryTime(Instant.now()) .millisBehindLatest(recordBatchEvent.millisBehindLatest()) .isAtShardEnd(recordBatchEvent.continuationSequenceNumber() == null).records(records).build(); + FanoutRecordsRetrieved recordsRetrieved = new FanoutRecordsRetrieved(input, recordBatchEvent.continuationSequenceNumber()); try { - subscriber.onNext(input); + subscriber.onNext(recordsRetrieved); // // Only advance the currentSequenceNumber if we successfully dispatch the last received input // @@ -311,7 +335,7 @@ private void onComplete(RecordFlow triggeringFlow) { } @Override - public void subscribe(Subscriber s) { + public void subscribe(Subscriber s) { synchronized (lockObject) { if (subscriber != null) { log.error( @@ -444,6 +468,19 @@ public void onComplete() { }); } + @Accessors(fluent = true) + @Data + static class FanoutRecordsRetrieved implements RecordsRetrieved { + + private final ProcessRecordsInput processRecordsInput; + private final String continuationSequenceNumber; + + @Override + public ProcessRecordsInput processRecordsInput() { + return processRecordsInput; + } + } + @RequiredArgsConstructor @Slf4j static class RecordFlow implements SubscribeToShardResponseHandler { diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/BlockingRecordsPublisher.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/BlockingRecordsPublisher.java index 8fd68b804..834cf9c3a 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/BlockingRecordsPublisher.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/BlockingRecordsPublisher.java @@ -27,6 +27,7 @@ import software.amazon.kinesis.retrieval.GetRecordsRetrievalStrategy; import software.amazon.kinesis.retrieval.KinesisClientRecord; import software.amazon.kinesis.retrieval.RecordsPublisher; +import software.amazon.kinesis.retrieval.RecordsRetrieved; import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber; /** @@ -38,7 +39,7 @@ public class BlockingRecordsPublisher implements RecordsPublisher { private final int maxRecordsPerCall; private final GetRecordsRetrievalStrategy getRecordsRetrievalStrategy; - private Subscriber subscriber; + private Subscriber subscriber; public BlockingRecordsPublisher(final int maxRecordsPerCall, final GetRecordsRetrievalStrategy getRecordsRetrievalStrategy) { @@ -70,7 +71,12 @@ public void shutdown() { } @Override - public void subscribe(Subscriber s) { + public void subscribe(Subscriber s) { subscriber = s; } + + @Override + public void restartFrom(RecordsRetrieved recordsRetrieved) { + throw new UnsupportedOperationException(); + } } diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/KinesisDataFetcher.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/KinesisDataFetcher.java index 06843cc67..0fbb06b4e 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/KinesisDataFetcher.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/KinesisDataFetcher.java @@ -227,6 +227,12 @@ public void restartIterator() { advanceIteratorTo(lastKnownSequenceNumber, initialPositionInStream); } + public void resetIterator(String shardIterator, String sequenceNumber, InitialPositionInStreamExtended initialPositionInStream) { + this.nextIterator = shardIterator; + this.lastKnownSequenceNumber = sequenceNumber; + this.initialPositionInStream = initialPositionInStream; + } + private GetRecordsResponse getRecords(@NonNull final String nextIterator) { final AWSExceptionManager exceptionManager = createExceptionManager(); GetRecordsRequest request = KinesisRequestsBuilder.getRecordsRequestBuilder().shardIterator(nextIterator) diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisher.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisher.java index 15a564dfc..8e9a56f2d 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisher.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisher.java @@ -20,14 +20,19 @@ import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.ReentrantReadWriteLock; import java.util.stream.Collectors; +import com.google.common.annotations.VisibleForTesting; import org.apache.commons.lang3.Validate; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; +import lombok.Data; import lombok.NonNull; +import lombok.experimental.Accessors; import lombok.extern.slf4j.Slf4j; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.cloudwatch.model.StandardUnit; @@ -44,6 +49,7 @@ import software.amazon.kinesis.retrieval.GetRecordsRetrievalStrategy; import software.amazon.kinesis.retrieval.KinesisClientRecord; import software.amazon.kinesis.retrieval.RecordsPublisher; +import software.amazon.kinesis.retrieval.RecordsRetrieved; import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber; /** @@ -58,7 +64,8 @@ @KinesisClientInternalApi public class PrefetchRecordsPublisher implements RecordsPublisher { private static final String EXPIRED_ITERATOR_METRIC = "ExpiredIterator"; - LinkedBlockingQueue getRecordsResultQueue; + @VisibleForTesting + LinkedBlockingQueue getRecordsResultQueue; private int maxPendingProcessRecordsInput; private int maxByteSize; private int maxRecordsCount; @@ -75,9 +82,15 @@ public class PrefetchRecordsPublisher implements RecordsPublisher { private final KinesisDataFetcher dataFetcher; private final String shardId; - private Subscriber subscriber; + private Subscriber subscriber; private final AtomicLong requestedResponses = new AtomicLong(0); + private String highestSequenceNumber; + private InitialPositionInStreamExtended initialPositionInStreamExtended; + + private final ReentrantReadWriteLock resetLock = new ReentrantReadWriteLock(); + private boolean wasReset = false; + /** * Constructor for the PrefetchRecordsPublisher. This cache prefetches records from Kinesis and stores them in a * LinkedBlockingQueue. @@ -124,6 +137,8 @@ public void start(ExtendedSequenceNumber extendedSequenceNumber, InitialPosition throw new IllegalStateException("ExecutorService has been shutdown."); } + this.initialPositionInStreamExtended = initialPositionInStreamExtended; + highestSequenceNumber = extendedSequenceNumber.sequenceNumber(); dataFetcher.initialize(extendedSequenceNumber, initialPositionInStreamExtended); if (!started) { @@ -133,7 +148,7 @@ public void start(ExtendedSequenceNumber extendedSequenceNumber, InitialPosition started = true; } - ProcessRecordsInput getNextResult() { + RecordsRetrieved getNextResult() { if (executorService.isShutdown()) { throw new IllegalStateException("Shutdown has been called on the cache, can't accept new requests."); } @@ -141,14 +156,16 @@ ProcessRecordsInput getNextResult() { if (!started) { throw new IllegalStateException("Cache has not been initialized, make sure to call start."); } - ProcessRecordsInput result = null; + PrefetchRecordsRetrieved result = null; try { - result = getRecordsResultQueue.take().toBuilder().cacheExitTime(Instant.now()).build(); - prefetchCounters.removed(result); + result = getRecordsResultQueue.take().prepareForPublish(); + prefetchCounters.removed(result.processRecordsInput); requestedResponses.decrementAndGet(); + } catch (InterruptedException e) { log.error("Interrupted while getting records from the cache", e); } + return result; } @@ -160,7 +177,28 @@ public void shutdown() { } @Override - public void subscribe(Subscriber s) { + public void restartFrom(RecordsRetrieved recordsRetrieved) { + if (!(recordsRetrieved instanceof PrefetchRecordsRetrieved)) { + throw new IllegalArgumentException( + "Provided RecordsRetrieved was not produced by the PrefetchRecordsPublisher"); + } + PrefetchRecordsRetrieved prefetchRecordsRetrieved = (PrefetchRecordsRetrieved) recordsRetrieved; + resetLock.writeLock().lock(); + try { + getRecordsResultQueue.clear(); + prefetchCounters.reset(); + + highestSequenceNumber = prefetchRecordsRetrieved.lastBatchSequenceNumber(); + dataFetcher.resetIterator(prefetchRecordsRetrieved.shardIterator(), highestSequenceNumber, + initialPositionInStreamExtended); + wasReset = true; + } finally { + resetLock.writeLock().unlock(); + } + } + + @Override + public void subscribe(Subscriber s) { subscriber = s; subscriber.onSubscribe(new Subscription() { @Override @@ -176,9 +214,22 @@ public void cancel() { }); } - private void addArrivedRecordsInput(ProcessRecordsInput processRecordsInput) throws InterruptedException { - getRecordsResultQueue.put(processRecordsInput); - prefetchCounters.added(processRecordsInput); + private void addArrivedRecordsInput(PrefetchRecordsRetrieved recordsRetrieved) throws InterruptedException { + wasReset = false; + while (!getRecordsResultQueue.offer(recordsRetrieved, idleMillisBetweenCalls, TimeUnit.MILLISECONDS)) { + // + // Unlocking the read lock, and then reacquiring the read lock, should allow any waiters on the write lock a + // chance to run. If the write lock is acquired by restartFrom than the readLock will now block until + // restartFrom(...) has completed. This is to ensure that if a reset has occurred we know to discard the + // data we received and start a new fetch of data. + // + resetLock.readLock().unlock(); + resetLock.readLock().lock(); + if (wasReset) { + throw new PositionResetException(); + } + } + prefetchCounters.added(recordsRetrieved.processRecordsInput); } private synchronized void drainQueueForRequests() { @@ -187,6 +238,34 @@ private synchronized void drainQueueForRequests() { } } + @Accessors(fluent = true) + @Data + static class PrefetchRecordsRetrieved implements RecordsRetrieved { + + final ProcessRecordsInput processRecordsInput; + final String lastBatchSequenceNumber; + final String shardIterator; + + PrefetchRecordsRetrieved prepareForPublish() { + return new PrefetchRecordsRetrieved(processRecordsInput.toBuilder().cacheExitTime(Instant.now()).build(), + lastBatchSequenceNumber, shardIterator); + } + + } + + private String calculateHighestSequenceNumber(ProcessRecordsInput processRecordsInput) { + String result = this.highestSequenceNumber; + if (processRecordsInput.records() != null && !processRecordsInput.records().isEmpty()) { + result = processRecordsInput.records().get(processRecordsInput.records().size() - 1).sequenceNumber(); + } + return result; + } + + private static class PositionResetException extends RuntimeException { + + } + + private class DefaultGetRecordsCacheDaemon implements Runnable { volatile boolean isShutdown = false; @@ -197,57 +276,78 @@ public void run() { log.warn("Prefetch thread was interrupted."); break; } - MetricsScope scope = MetricsUtil.createMetricsWithOperation(metricsFactory, operation); - if (prefetchCounters.shouldGetNewRecords()) { - try { - sleepBeforeNextCall(); - GetRecordsResponse getRecordsResult = getRecordsRetrievalStrategy.getRecords(maxRecordsPerCall); - lastSuccessfulCall = Instant.now(); - - final List records = getRecordsResult.records().stream() - .map(KinesisClientRecord::fromRecord).collect(Collectors.toList()); - ProcessRecordsInput processRecordsInput = ProcessRecordsInput.builder() - .records(records) - .millisBehindLatest(getRecordsResult.millisBehindLatest()) - .cacheEntryTime(lastSuccessfulCall) - .isAtShardEnd(getRecordsRetrievalStrategy.getDataFetcher().isShardEndReached()) - .build(); - addArrivedRecordsInput(processRecordsInput); - drainQueueForRequests(); - } catch (InterruptedException e) { - log.info("Thread was interrupted, indicating shutdown was called on the cache."); - } catch (ExpiredIteratorException e) { - log.info("ShardId {}: records threw ExpiredIteratorException - restarting" - + " after greatest seqNum passed to customer", shardId, e); - - scope.addData(EXPIRED_ITERATOR_METRIC, 1, StandardUnit.COUNT, MetricsLevel.SUMMARY); - - dataFetcher.restartIterator(); - } catch (SdkClientException e) { - log.error("Exception thrown while fetching records from Kinesis", e); - } catch (Throwable e) { - log.error("Unexpected exception was thrown. This could probably be an issue or a bug." + - " Please search for the exception/error online to check what is going on. If the " + - "issue persists or is a recurring problem, feel free to open an issue on, " + - "https://github.com/awslabs/amazon-kinesis-client.", e); - } finally { - MetricsUtil.endScope(scope); - } - } else { - // - // Consumer isn't ready to receive new records will allow prefetch counters to pause - // - try { - prefetchCounters.waitForConsumer(); - } catch (InterruptedException ie) { - log.info("Thread was interrupted while waiting for the consumer. " + - "Shutdown has probably been started"); - } + + resetLock.readLock().lock(); + try { + makeRetrievalAttempt(); + } catch(PositionResetException pre) { + log.debug("Position was reset while attempting to add item to queue."); + } finally { + resetLock.readLock().unlock(); } + + } callShutdownOnStrategy(); } + private void makeRetrievalAttempt() { + MetricsScope scope = MetricsUtil.createMetricsWithOperation(metricsFactory, operation); + if (prefetchCounters.shouldGetNewRecords()) { + try { + sleepBeforeNextCall(); + GetRecordsResponse getRecordsResult = getRecordsRetrievalStrategy.getRecords(maxRecordsPerCall); + lastSuccessfulCall = Instant.now(); + + final List records = getRecordsResult.records().stream() + .map(KinesisClientRecord::fromRecord).collect(Collectors.toList()); + ProcessRecordsInput processRecordsInput = ProcessRecordsInput.builder() + .records(records) + .millisBehindLatest(getRecordsResult.millisBehindLatest()) + .cacheEntryTime(lastSuccessfulCall) + .isAtShardEnd(getRecordsRetrievalStrategy.getDataFetcher().isShardEndReached()) + .build(); + + highestSequenceNumber = calculateHighestSequenceNumber(processRecordsInput); + PrefetchRecordsRetrieved recordsRetrieved = new PrefetchRecordsRetrieved(processRecordsInput, + highestSequenceNumber, getRecordsResult.nextShardIterator()); + highestSequenceNumber = recordsRetrieved.lastBatchSequenceNumber; + addArrivedRecordsInput(recordsRetrieved); + drainQueueForRequests(); + } catch (PositionResetException pse) { + throw pse; + } catch (InterruptedException e) { + log.info("Thread was interrupted, indicating shutdown was called on the cache."); + } catch (ExpiredIteratorException e) { + log.info("ShardId {}: records threw ExpiredIteratorException - restarting" + + " after greatest seqNum passed to customer", shardId, e); + + scope.addData(EXPIRED_ITERATOR_METRIC, 1, StandardUnit.COUNT, MetricsLevel.SUMMARY); + + dataFetcher.restartIterator(); + } catch (SdkClientException e) { + log.error("Exception thrown while fetching records from Kinesis", e); + } catch (Throwable e) { + log.error("Unexpected exception was thrown. This could probably be an issue or a bug." + + " Please search for the exception/error online to check what is going on. If the " + + "issue persists or is a recurring problem, feel free to open an issue on, " + + "https://github.com/awslabs/amazon-kinesis-client.", e); + } finally { + MetricsUtil.endScope(scope); + } + } else { + // + // Consumer isn't ready to receive new records will allow prefetch counters to pause + // + try { + prefetchCounters.waitForConsumer(); + } catch (InterruptedException ie) { + log.info("Thread was interrupted while waiting for the consumer. " + + "Shutdown has probably been started"); + } + } + } + private void callShutdownOnStrategy() { if (!getRecordsRetrievalStrategy.isShutdown()) { getRecordsRetrievalStrategy.shutdown(); @@ -302,6 +402,11 @@ public synchronized boolean shouldGetNewRecords() { return size < maxRecordsCount && byteSize < maxByteSize; } + void reset() { + size = 0; + byteSize = 0; + } + @Override public String toString() { return String.format("{ Requests: %d, Records: %d, Bytes: %d }", getRecordsResultQueue.size(), size, diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerSubscriberTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerSubscriberTest.java new file mode 100644 index 000000000..21b004516 --- /dev/null +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerSubscriberTest.java @@ -0,0 +1,447 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Amazon Software License (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/asl/ + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ +package software.amazon.kinesis.lifecycle; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.junit.Assert.assertThat; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.argThat; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static software.amazon.kinesis.utils.ProcessRecordsInputMatcher.eqProcessRecordsInput; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.stream.Stream; + +import org.apache.commons.lang3.StringUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.runners.MockitoJUnitRunner; +import org.mockito.stubbing.Answer; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import com.google.common.util.concurrent.ThreadFactoryBuilder; + +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import software.amazon.kinesis.common.InitialPositionInStreamExtended; +import software.amazon.kinesis.leases.ShardInfo; +import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; +import software.amazon.kinesis.retrieval.KinesisClientRecord; +import software.amazon.kinesis.retrieval.RecordsPublisher; +import software.amazon.kinesis.retrieval.RecordsRetrieved; +import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber; + +@Slf4j +@RunWith(MockitoJUnitRunner.class) +public class ShardConsumerSubscriberTest { + + private final Object processedNotifier = new Object(); + + private static final String TERMINAL_MARKER = "Terminal"; + + @Mock + private ShardConsumer shardConsumer; + @Mock + private RecordsRetrieved recordsRetrieved; + + private ProcessRecordsInput processRecordsInput; + private TestPublisher recordsPublisher; + + private ExecutorService executorService; + private int bufferSize = 8; + + private ShardConsumerSubscriber subscriber; + + @Rule + public TestName testName = new TestName(); + + @Before + public void before() { + executorService = Executors.newFixedThreadPool(8, new ThreadFactoryBuilder() + .setNameFormat("test-" + testName.getMethodName() + "-%04d").setDaemon(true).build()); + recordsPublisher = new TestPublisher(); + + ShardInfo shardInfo = new ShardInfo("shard-001", "", Collections.emptyList(), + ExtendedSequenceNumber.TRIM_HORIZON); + when(shardConsumer.shardInfo()).thenReturn(shardInfo); + + processRecordsInput = ProcessRecordsInput.builder().records(Collections.emptyList()) + .cacheEntryTime(Instant.now()).build(); + + subscriber = new ShardConsumerSubscriber(recordsPublisher, executorService, bufferSize, shardConsumer); + when(recordsRetrieved.processRecordsInput()).thenReturn(processRecordsInput); + } + + @After + public void after() { + executorService.shutdownNow(); + } + + @Test + public void singleItemTest() throws Exception { + addItemsToReturn(1); + + setupNotifierAnswer(1); + + synchronized (processedNotifier) { + subscriber.startSubscriptions(); + processedNotifier.wait(5000); + } + + verify(shardConsumer).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)), any(Subscription.class)); + } + + @Test + public void multipleItemTest() throws Exception { + addItemsToReturn(100); + + setupNotifierAnswer(recordsPublisher.responses.size()); + + synchronized (processedNotifier) { + subscriber.startSubscriptions(); + processedNotifier.wait(5000); + } + + verify(shardConsumer, times(100)).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)), + any(Subscription.class)); + } + + @Test + public void consumerErrorSkipsEntryTest() throws Exception { + addItemsToReturn(20); + + Throwable testException = new Throwable("ShardConsumerError"); + + doAnswer(new Answer() { + int expectedInvocations = recordsPublisher.responses.size(); + + @Override + public Object answer(InvocationOnMock invocation) throws Throwable { + expectedInvocations--; + if (expectedInvocations == 10) { + throw testException; + } + if (expectedInvocations <= 0) { + synchronized (processedNotifier) { + processedNotifier.notifyAll(); + } + } + return null; + } + }).when(shardConsumer).handleInput(any(ProcessRecordsInput.class), any(Subscription.class)); + + synchronized (processedNotifier) { + subscriber.startSubscriptions(); + processedNotifier.wait(5000); + } + + assertThat(subscriber.getAndResetDispatchFailure(), equalTo(testException)); + assertThat(subscriber.getAndResetDispatchFailure(), nullValue()); + + verify(shardConsumer, times(20)).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)), + any(Subscription.class)); + + } + + @Test + public void onErrorStopsProcessingTest() throws Exception { + Throwable expected = new Throwable("Wheee"); + addItemsToReturn(10); + recordsPublisher.add(new ResponseItem(expected)); + addItemsToReturn(10); + + setupNotifierAnswer(10); + + synchronized (processedNotifier) { + subscriber.startSubscriptions(); + processedNotifier.wait(5000); + } + + for (int attempts = 0; attempts < 10; attempts++) { + if (subscriber.retrievalFailure() != null) { + break; + } + Thread.sleep(10); + } + + verify(shardConsumer, times(10)).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)), + any(Subscription.class)); + assertThat(subscriber.retrievalFailure(), equalTo(expected)); + } + + @Test + public void restartAfterErrorTest() throws Exception { + Throwable expected = new Throwable("whee"); + addItemsToReturn(9); + RecordsRetrieved edgeRecord = mock(RecordsRetrieved.class); + when(edgeRecord.processRecordsInput()).thenReturn(processRecordsInput); + recordsPublisher.add(new ResponseItem(edgeRecord)); + recordsPublisher.add(new ResponseItem(expected)); + addItemsToReturn(10); + + setupNotifierAnswer(10); + + synchronized (processedNotifier) { + subscriber.startSubscriptions(); + processedNotifier.wait(5000); + } + + for (int attempts = 0; attempts < 10; attempts++) { + if (subscriber.retrievalFailure() != null) { + break; + } + Thread.sleep(100); + } + + setupNotifierAnswer(10); + + synchronized (processedNotifier) { + assertThat(subscriber.healthCheck(100000), equalTo(expected)); + processedNotifier.wait(5000); + } + + assertThat(recordsPublisher.restartedFrom, equalTo(edgeRecord)); + verify(shardConsumer, times(20)).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)), + any(Subscription.class)); + } + + @Test + public void restartAfterRequestTimerExpiresTest() throws Exception { + + executorService = Executors.newFixedThreadPool(1, new ThreadFactoryBuilder() + .setNameFormat("test-" + testName.getMethodName() + "-%04d").setDaemon(true).build()); + + subscriber = new ShardConsumerSubscriber(recordsPublisher, executorService, bufferSize, shardConsumer); + addUniqueItem(1); + addTerminalMarker(1); + + CyclicBarrier barrier = new CyclicBarrier(2); + + List received = new ArrayList<>(); + doAnswer(a -> { + ProcessRecordsInput input = a.getArgumentAt(0, ProcessRecordsInput.class); + received.add(input); + if (input.records().stream().anyMatch(r -> StringUtils.startsWith(r.partitionKey(), TERMINAL_MARKER))) { + synchronized (processedNotifier) { + processedNotifier.notifyAll(); + } + } + return null; + }).when(shardConsumer).handleInput(any(ProcessRecordsInput.class), any(Subscription.class)); + + synchronized (processedNotifier) { + subscriber.startSubscriptions(); + processedNotifier.wait(5000); + } + + synchronized (processedNotifier) { + executorService.execute(() -> { + try { + // + // Notify the test as soon as we have started executing, then wait on the post add barrier. + // + synchronized (processedNotifier) { + processedNotifier.notifyAll(); + } + barrier.await(); + } catch (Exception e) { + log.error("Exception while blocking thread", e); + } + }); + // + // Wait for our blocking thread to control the thread in the executor. + // + processedNotifier.wait(5000); + } + + Stream.iterate(2, i -> i + 1).limit(97).forEach(this::addUniqueItem); + + addTerminalMarker(2); + + synchronized (processedNotifier) { + assertThat(subscriber.healthCheck(1), nullValue()); + barrier.await(500, TimeUnit.MILLISECONDS); + + processedNotifier.wait(5000); + } + + verify(shardConsumer, times(100)).handleInput(argThat(eqProcessRecordsInput(processRecordsInput)), + any(Subscription.class)); + + assertThat(received.size(), equalTo(recordsPublisher.responses.size())); + Stream.iterate(0, i -> i + 1).limit(received.size()).forEach(i -> assertThat(received.get(i), + eqProcessRecordsInput(recordsPublisher.responses.get(i).recordsRetrieved.processRecordsInput()))); + + } + + private void addUniqueItem(int id) { + RecordsRetrieved r = mock(RecordsRetrieved.class, "Record-" + id); + ProcessRecordsInput input = ProcessRecordsInput.builder().cacheEntryTime(Instant.now()) + .records(Collections.singletonList(KinesisClientRecord.builder().partitionKey("Record-" + id).build())) + .build(); + when(r.processRecordsInput()).thenReturn(input); + recordsPublisher.add(new ResponseItem(r)); + } + + private ProcessRecordsInput addTerminalMarker(int id) { + RecordsRetrieved terminalResponse = mock(RecordsRetrieved.class, TERMINAL_MARKER + "-" + id); + ProcessRecordsInput terminalInput = ProcessRecordsInput.builder() + .records(Collections + .singletonList(KinesisClientRecord.builder().partitionKey(TERMINAL_MARKER + "-" + id).build())) + .cacheEntryTime(Instant.now()).build(); + when(terminalResponse.processRecordsInput()).thenReturn(terminalInput); + recordsPublisher.add(new ResponseItem(terminalResponse)); + + return terminalInput; + } + + private void addItemsToReturn(int count) { + Stream.iterate(0, i -> i + 1).limit(count) + .forEach(i -> recordsPublisher.add(new ResponseItem(recordsRetrieved))); + } + + private void setupNotifierAnswer(int expected) { + doAnswer(new Answer() { + int seen = expected; + + @Override + public Object answer(InvocationOnMock invocation) throws Throwable { + seen--; + if (seen == 0) { + synchronized (processedNotifier) { + processedNotifier.notifyAll(); + } + } + return null; + } + }).when(shardConsumer).handleInput(any(ProcessRecordsInput.class), any(Subscription.class)); + } + + private class ResponseItem { + private final RecordsRetrieved recordsRetrieved; + private final Throwable throwable; + private int throwCount = 1; + + public ResponseItem(@NonNull RecordsRetrieved recordsRetrieved) { + this.recordsRetrieved = recordsRetrieved; + this.throwable = null; + } + + public ResponseItem(@NonNull Throwable throwable) { + this.throwable = throwable; + this.recordsRetrieved = null; + } + } + + private class TestPublisher implements RecordsPublisher { + + private final LinkedList responses = new LinkedList<>(); + private volatile long requested = 0; + private int currentIndex = 0; + private Subscriber subscriber; + private RecordsRetrieved restartedFrom; + + void add(ResponseItem... toAdd) { + responses.addAll(Arrays.asList(toAdd)); + send(); + } + + void send() { + send(0); + } + + synchronized void send(long toRequest) { + requested += toRequest; + while (requested > 0 && currentIndex < responses.size()) { + ResponseItem item = responses.get(currentIndex); + currentIndex++; + if (item.recordsRetrieved != null) { + subscriber.onNext(item.recordsRetrieved); + } else { + if (item.throwCount > 0) { + item.throwCount--; + subscriber.onError(item.throwable); + } else { + continue; + } + } + requested--; + } + } + + @Override + public void start(ExtendedSequenceNumber extendedSequenceNumber, + InitialPositionInStreamExtended initialPositionInStreamExtended) { + + } + + @Override + public void restartFrom(RecordsRetrieved recordsRetrieved) { + restartedFrom = recordsRetrieved; + currentIndex = -1; + for (int i = 0; i < responses.size(); i++) { + ResponseItem item = responses.get(i); + if (recordsRetrieved.equals(item.recordsRetrieved)) { + currentIndex = i + 1; + break; + } + } + + } + + @Override + public void shutdown() { + + } + + @Override + public void subscribe(Subscriber s) { + subscriber = s; + s.onSubscribe(new Subscription() { + @Override + public void request(long n) { + send(n); + } + + @Override + public void cancel() { + requested = 0; + } + }); + } + } + +} \ No newline at end of file diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java index 114d4d47c..39c867c61 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java @@ -70,6 +70,7 @@ import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; import software.amazon.kinesis.lifecycle.events.TaskExecutionListenerInput; import software.amazon.kinesis.retrieval.RecordsPublisher; +import software.amazon.kinesis.retrieval.RecordsRetrieved; import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber; /** @@ -161,7 +162,7 @@ private class TestPublisher implements RecordsPublisher { final CyclicBarrier barrier = new CyclicBarrier(2); final CyclicBarrier requestBarrier = new CyclicBarrier(2); - Subscriber subscriber; + Subscriber subscriber; final Subscription subscription = mock(Subscription.class); TestPublisher() { @@ -193,7 +194,7 @@ public void shutdown() { } @Override - public void subscribe(Subscriber s) { + public void subscribe(Subscriber s) { subscriber = s; subscriber.onSubscribe(subscription); try { @@ -203,6 +204,11 @@ public void subscribe(Subscriber s) { } } + @Override + public void restartFrom(RecordsRetrieved recordsRetrieved) { + + } + public void awaitSubscription() throws InterruptedException, BrokenBarrierException { barrier.await(); barrier.reset(); @@ -219,10 +225,10 @@ public void awaitInitialSetup() throws InterruptedException, BrokenBarrierExcept } public void publish() { - publish(processRecordsInput); + publish(() -> processRecordsInput); } - public void publish(ProcessRecordsInput input) { + public void publish(RecordsRetrieved input) { subscriber.onNext(input); } } diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java index 50896f990..18ea93b50 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/fanout/FanOutRecordsPublisherTest.java @@ -45,6 +45,7 @@ import software.amazon.kinesis.common.InitialPositionInStreamExtended; import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; import software.amazon.kinesis.retrieval.KinesisClientRecord; +import software.amazon.kinesis.retrieval.RecordsRetrieved; import software.amazon.kinesis.retrieval.RetryableRetrievalException; import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber; @@ -62,7 +63,7 @@ public class FanOutRecordsPublisherTest { @Mock private Subscription subscription; @Mock - private Subscriber subscriber; + private Subscriber subscriber; private SubscribeToShardEvent batchEvent; @@ -80,7 +81,7 @@ public void simpleTest() throws Exception { List receivedInput = new ArrayList<>(); - source.subscribe(new Subscriber() { + source.subscribe(new Subscriber() { Subscription subscription; @Override @@ -90,8 +91,8 @@ public void onSubscribe(Subscription s) { } @Override - public void onNext(ProcessRecordsInput input) { - receivedInput.add(input); + public void onNext(RecordsRetrieved input) { + receivedInput.add(input.processRecordsInput()); subscription.request(1); } @@ -147,7 +148,7 @@ public void largeRequestTest() throws Exception { List receivedInput = new ArrayList<>(); - source.subscribe(new Subscriber() { + source.subscribe(new Subscriber() { Subscription subscription; @Override @@ -157,8 +158,8 @@ public void onSubscribe(Subscription s) { } @Override - public void onNext(ProcessRecordsInput input) { - receivedInput.add(input); + public void onNext(RecordsRetrieved input) { + receivedInput.add(input.processRecordsInput()); subscription.request(1); } @@ -206,7 +207,7 @@ public void testResourceNotFoundForShard() { ArgumentCaptor flowCaptor = ArgumentCaptor .forClass(FanOutRecordsPublisher.RecordFlow.class); - ArgumentCaptor inputCaptor = ArgumentCaptor.forClass(ProcessRecordsInput.class); + ArgumentCaptor inputCaptor = ArgumentCaptor.forClass(RecordsRetrieved.class); source.subscribe(subscriber); @@ -219,7 +220,7 @@ public void testResourceNotFoundForShard() { verify(subscriber).onNext(inputCaptor.capture()); verify(subscriber).onComplete(); - ProcessRecordsInput input = inputCaptor.getValue(); + ProcessRecordsInput input = inputCaptor.getValue().processRecordsInput(); assertThat(input.isAtShardEnd(), equalTo(true)); assertThat(input.records().isEmpty(), equalTo(true)); } @@ -325,7 +326,7 @@ private void verifyRecords(List clientRecordsList, List { + private static class NonFailingSubscriber implements Subscriber { final List received = new ArrayList<>(); Subscription subscription; @@ -336,8 +337,8 @@ public void onSubscribe(Subscription s) { } @Override - public void onNext(ProcessRecordsInput input) { - received.add(input); + public void onNext(RecordsRetrieved input) { + received.add(input.processRecordsInput()); subscription.request(1); } diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherIntegrationTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherIntegrationTest.java index 96943e241..1526cb222 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherIntegrationTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherIntegrationTest.java @@ -121,13 +121,13 @@ public void testRollingCache() { getRecordsCache.start(extendedSequenceNumber, initialPosition); sleep(IDLE_MILLIS_BETWEEN_CALLS); - ProcessRecordsInput processRecordsInput1 = getRecordsCache.getNextResult(); + ProcessRecordsInput processRecordsInput1 = getRecordsCache.getNextResult().processRecordsInput(); assertTrue(processRecordsInput1.records().isEmpty()); assertEquals(processRecordsInput1.millisBehindLatest(), new Long(1000)); assertNotNull(processRecordsInput1.cacheEntryTime()); - ProcessRecordsInput processRecordsInput2 = getRecordsCache.getNextResult(); + ProcessRecordsInput processRecordsInput2 = getRecordsCache.getNextResult().processRecordsInput(); assertNotEquals(processRecordsInput1, processRecordsInput2); } @@ -139,8 +139,8 @@ public void testFullCache() { assertEquals(getRecordsCache.getRecordsResultQueue.size(), MAX_SIZE); - ProcessRecordsInput processRecordsInput1 = getRecordsCache.getNextResult(); - ProcessRecordsInput processRecordsInput2 = getRecordsCache.getNextResult(); + ProcessRecordsInput processRecordsInput1 = getRecordsCache.getNextResult().processRecordsInput(); + ProcessRecordsInput processRecordsInput2 = getRecordsCache.getNextResult().processRecordsInput(); assertNotEquals(processRecordsInput1, processRecordsInput2); } @@ -179,9 +179,9 @@ public void testDifferentShardCaches() { sleep(IDLE_MILLIS_BETWEEN_CALLS); - ProcessRecordsInput p1 = getRecordsCache.getNextResult(); + ProcessRecordsInput p1 = getRecordsCache.getNextResult().processRecordsInput(); - ProcessRecordsInput p2 = recordsPublisher2.getNextResult(); + ProcessRecordsInput p2 = recordsPublisher2.getNextResult().processRecordsInput(); assertNotEquals(p1, p2); assertTrue(p1.records().isEmpty()); @@ -207,7 +207,7 @@ public DataFetcherResult answer(final InvocationOnMock invocationOnMock) throws getRecordsCache.start(extendedSequenceNumber, initialPosition); sleep(IDLE_MILLIS_BETWEEN_CALLS); - ProcessRecordsInput processRecordsInput = getRecordsCache.getNextResult(); + ProcessRecordsInput processRecordsInput = getRecordsCache.getNextResult().processRecordsInput(); assertNotNull(processRecordsInput); assertTrue(processRecordsInput.records().isEmpty()); diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherTest.java index 7fb82ea6b..94373fe0c 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherTest.java @@ -24,17 +24,21 @@ import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; +import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.atLeast; -import static org.mockito.Mockito.atMost; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static software.amazon.kinesis.utils.ProcessRecordsInputMatcher.eqProcessRecordsInput; import java.time.Duration; +import java.time.Instant; import java.util.ArrayList; +import java.util.Iterator; import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -42,15 +46,18 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.IntStream; +import java.util.stream.Stream; +import org.apache.commons.lang3.StringUtils; import org.junit.After; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; -import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; import org.mockito.runners.MockitoJUnitRunner; +import org.mockito.stubbing.Answer; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; @@ -66,6 +73,7 @@ import software.amazon.kinesis.metrics.NullMetricsFactory; import software.amazon.kinesis.retrieval.GetRecordsRetrievalStrategy; import software.amazon.kinesis.retrieval.KinesisClientRecord; +import software.amazon.kinesis.retrieval.RecordsRetrieved; import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber; /** @@ -92,7 +100,7 @@ public class PrefetchRecordsPublisherTest { private List records; private ExecutorService executorService; - private LinkedBlockingQueue spyQueue; + private LinkedBlockingQueue spyQueue; private PrefetchRecordsPublisher getRecordsCache; private String operation = "ProcessTask"; private GetRecordsResponse getRecordsResponse; @@ -131,7 +139,7 @@ record = Record.builder().data(createByteBufferWithSize(SIZE_512_KB)).build(); .map(KinesisClientRecord::fromRecord).collect(Collectors.toList()); getRecordsCache.start(sequenceNumber, initialPosition); - ProcessRecordsInput result = getRecordsCache.getNextResult(); + ProcessRecordsInput result = getRecordsCache.getNextResult().processRecordsInput(); assertEquals(expectedRecords, result.records()); @@ -200,7 +208,7 @@ record = Record.builder().data(createByteBufferWithSize(1024)).build(); .map(KinesisClientRecord::fromRecord).collect(Collectors.toList()); getRecordsCache.start(sequenceNumber, initialPosition); - ProcessRecordsInput processRecordsInput = getRecordsCache.getNextResult(); + ProcessRecordsInput processRecordsInput = getRecordsCache.getNextResult().processRecordsInput(); verify(executorService).execute(any()); assertEquals(expectedRecords, processRecordsInput.records()); @@ -209,7 +217,7 @@ record = Record.builder().data(createByteBufferWithSize(1024)).build(); sleep(2000); - ProcessRecordsInput processRecordsInput2 = getRecordsCache.getNextResult(); + ProcessRecordsInput processRecordsInput2 = getRecordsCache.getNextResult().processRecordsInput(); assertNotEquals(processRecordsInput, processRecordsInput2); assertEquals(expectedRecords, processRecordsInput2.records()); assertNotEquals(processRecordsInput2.timeSpentInCache(), Duration.ZERO); @@ -276,7 +284,7 @@ public void testNoDeadlockOnFullQueue() { Object lock = new Object(); - Subscriber subscriber = new Subscriber() { + Subscriber subscriber = new Subscriber() { Subscription sub; @Override @@ -286,7 +294,7 @@ public void onSubscribe(Subscription s) { } @Override - public void onNext(ProcessRecordsInput processRecordsInput) { + public void onNext(RecordsRetrieved recordsRetrieved) { receivedItems.incrementAndGet(); if (receivedItems.get() >= expectedItems) { synchronized (lock) { @@ -325,6 +333,87 @@ public void onComplete() { assertThat(receivedItems.get(), equalTo(expectedItems)); } + @Test + public void testResetClearsRemainingData() { + List responses = Stream.iterate(0, i -> i + 1).limit(10).map(i -> { + Record record = Record.builder().partitionKey("record-" + i).sequenceNumber("seq-" + i) + .data(SdkBytes.fromByteArray(new byte[] { 1, 2, 3 })).approximateArrivalTimestamp(Instant.now()) + .build(); + String nextIterator = "shard-iter-" + (i + 1); + return GetRecordsResponse.builder().records(record).nextShardIterator(nextIterator).build(); + }).collect(Collectors.toList()); + + RetrieverAnswer retrieverAnswer = new RetrieverAnswer(responses); + + when(getRecordsRetrievalStrategy.getRecords(anyInt())).thenAnswer(retrieverAnswer); + doAnswer(a -> { + String resetTo = a.getArgumentAt(0, String.class); + retrieverAnswer.resetIteratorTo(resetTo); + return null; + }).when(dataFetcher).resetIterator(anyString(), anyString(), any()); + + getRecordsCache.start(sequenceNumber, initialPosition); + + RecordsRetrieved lastProcessed = getRecordsCache.getNextResult(); + RecordsRetrieved expected = getRecordsCache.getNextResult(); + + // + // Skip some of the records the cache + // + getRecordsCache.getNextResult(); + getRecordsCache.getNextResult(); + + verify(getRecordsRetrievalStrategy, atLeast(2)).getRecords(anyInt()); + + while(getRecordsCache.getRecordsResultQueue.remainingCapacity() > 0) { + Thread.yield(); + } + + getRecordsCache.restartFrom(lastProcessed); + RecordsRetrieved postRestart = getRecordsCache.getNextResult(); + + assertThat(postRestart.processRecordsInput(), eqProcessRecordsInput(expected.processRecordsInput())); + verify(dataFetcher).resetIterator(eq(responses.get(0).nextShardIterator()), + eq(responses.get(0).records().get(0).sequenceNumber()), any()); + + } + + private static class RetrieverAnswer implements Answer { + + private final List responses; + private Iterator iterator; + + public RetrieverAnswer(List responses) { + this.responses = responses; + this.iterator = responses.iterator(); + } + + public void resetIteratorTo(String nextIterator) { + Iterator newIterator = responses.iterator(); + while(newIterator.hasNext()) { + GetRecordsResponse current = newIterator.next(); + if (StringUtils.equals(nextIterator, current.nextShardIterator())) { + if (!newIterator.hasNext()) { + iterator = responses.iterator(); + } else { + newIterator.next(); + iterator = newIterator; + } + break; + } + } + } + + @Override + public GetRecordsResponse answer(InvocationOnMock invocation) throws Throwable { + GetRecordsResponse response = iterator.next(); + if (!iterator.hasNext()) { + iterator = responses.iterator(); + } + return response; + } + } + @After public void shutdown() { getRecordsCache.shutdown(); @@ -340,4 +429,5 @@ private void sleep(long millis) { private SdkBytes createByteBufferWithSize(int size) { return SdkBytes.fromByteArray(new byte[size]); } + } diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/utils/ProcessRecordsInputMatcher.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/utils/ProcessRecordsInputMatcher.java new file mode 100644 index 000000000..76763ebdb --- /dev/null +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/utils/ProcessRecordsInputMatcher.java @@ -0,0 +1,79 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Amazon Software License (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/asl/ + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.kinesis.utils; + +import lombok.Data; +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.TypeSafeDiagnosingMatcher; +import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.nullValue; + +public class ProcessRecordsInputMatcher extends TypeSafeDiagnosingMatcher { + + private final ProcessRecordsInput template; + private final Map matchers = new HashMap<>(); + + public static ProcessRecordsInputMatcher eqProcessRecordsInput(ProcessRecordsInput expected) { + return new ProcessRecordsInputMatcher(expected); + } + + public ProcessRecordsInputMatcher(ProcessRecordsInput template) { + matchers.put("cacheEntryTime", + nullOrEquals(template.cacheEntryTime(), ProcessRecordsInput::cacheEntryTime)); + matchers.put("checkpointer", nullOrEquals(template.checkpointer(), ProcessRecordsInput::checkpointer)); + matchers.put("isAtShardEnd", nullOrEquals(template.isAtShardEnd(), ProcessRecordsInput::isAtShardEnd)); + matchers.put("millisBehindLatest", + nullOrEquals(template.millisBehindLatest(), ProcessRecordsInput::millisBehindLatest)); + matchers.put("records", nullOrEquals(template.records(), ProcessRecordsInput::records)); + + this.template = template; + } + + private static MatcherData nullOrEquals(Object item, Function accessor) { + if (item == null) { + return new MatcherData(nullValue(), accessor); + } + return new MatcherData(equalTo(item), accessor); + } + + @Override + protected boolean matchesSafely(ProcessRecordsInput item, Description mismatchDescription) { + return matchers.entrySet().stream() + .filter(e -> e.getValue().matcher.matches(e.getValue().accessor.apply(item))).anyMatch(e -> { + mismatchDescription.appendText(e.getKey()).appendText(" "); + e.getValue().matcher.describeMismatch(e.getValue().accessor.apply(item), mismatchDescription); + return true; + }); + } + + @Override + public void describeTo(Description description) { + matchers.forEach((k, v) -> description.appendText(k).appendText(" ").appendDescriptionOf(v.matcher)); + } + + @Data + private static class MatcherData { + private final Matcher matcher; + private final Function accessor; + } +}