diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RecordProcessorCheckpointer.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RecordProcessorCheckpointer.java index 72e18d73d..8e3dfd735 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RecordProcessorCheckpointer.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RecordProcessorCheckpointer.java @@ -14,6 +14,9 @@ */ package com.amazonaws.services.kinesis.clientlibrary.lib.worker; +import com.amazonaws.services.kinesis.metrics.impl.MetricsHelper; +import com.amazonaws.services.kinesis.metrics.impl.ThreadSafeMetricsDelegatingScope; +import com.amazonaws.services.kinesis.metrics.interfaces.IMetricsFactory; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -50,6 +53,8 @@ class RecordProcessorCheckpointer implements IRecordProcessorCheckpointer { private SequenceNumberValidator sequenceNumberValidator; private ExtendedSequenceNumber sequenceNumberAtShardEnd; + + private IMetricsFactory metricsFactory; /** * Only has package level access, since only the Amazon Kinesis Client Library should be creating these. @@ -59,10 +64,12 @@ class RecordProcessorCheckpointer implements IRecordProcessorCheckpointer { */ RecordProcessorCheckpointer(ShardInfo shardInfo, ICheckpoint checkpoint, - SequenceNumberValidator validator) { + SequenceNumberValidator validator, + IMetricsFactory metricsFactory) { this.shardInfo = shardInfo; this.checkpoint = checkpoint; this.sequenceNumberValidator = validator; + this.metricsFactory = metricsFactory; } /** @@ -283,21 +290,33 @@ void advancePosition(ExtendedSequenceNumber extendedSequenceNumber) // just checkpoint at SHARD_END checkpointToRecord = ExtendedSequenceNumber.SHARD_END; } + + boolean unsetMetrics = false; // Don't checkpoint a value we already successfully checkpointed - if (extendedSequenceNumber != null && !extendedSequenceNumber.equals(lastCheckpointValue)) { - try { - if (LOG.isDebugEnabled()) { - LOG.debug("Setting " + shardInfo.getShardId() + ", token " + shardInfo.getConcurrencyToken() - + " checkpoint to " + checkpointToRecord); + try { + if (!MetricsHelper.isMetricsScopePresent()) { + MetricsHelper.setMetricsScope(new ThreadSafeMetricsDelegatingScope(metricsFactory.createMetrics())); + unsetMetrics = true; + } + if (extendedSequenceNumber != null && !extendedSequenceNumber.equals(lastCheckpointValue)) { + try { + if (LOG.isDebugEnabled()) { + LOG.debug("Setting " + shardInfo.getShardId() + ", token " + shardInfo.getConcurrencyToken() + + " checkpoint to " + checkpointToRecord); + } + checkpoint.setCheckpoint(shardInfo.getShardId(), checkpointToRecord, shardInfo.getConcurrencyToken()); + lastCheckpointValue = checkpointToRecord; + } catch (ThrottlingException | ShutdownException | InvalidStateException + | KinesisClientLibDependencyException e) { + throw e; + } catch (KinesisClientLibException e) { + LOG.warn("Caught exception setting checkpoint.", e); + throw new KinesisClientLibDependencyException("Caught exception while checkpointing", e); } - checkpoint.setCheckpoint(shardInfo.getShardId(), checkpointToRecord, shardInfo.getConcurrencyToken()); - lastCheckpointValue = checkpointToRecord; - } catch (ThrottlingException | ShutdownException | InvalidStateException - | KinesisClientLibDependencyException e) { - throw e; - } catch (KinesisClientLibException e) { - LOG.warn("Caught exception setting checkpoint.", e); - throw new KinesisClientLibDependencyException("Caught exception while checkpointing", e); + } + } finally { + if (unsetMetrics) { + MetricsHelper.unsetMetricsScope(); } } } diff --git a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumer.java b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumer.java index 95cc663e6..4a001b9be 100644 --- a/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumer.java +++ b/src/main/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumer.java @@ -170,7 +170,8 @@ private static final GetRecordsRetrievalStrategy makeStrategy(KinesisDataFetcher new SequenceNumberValidator( streamConfig.getStreamProxy(), shardInfo.getShardId(), - streamConfig.shouldValidateSequenceNumberBeforeCheckpointing())), + streamConfig.shouldValidateSequenceNumberBeforeCheckpointing()), + metricsFactory), leaseManager, parentShardPollIntervalMillis, cleanupLeasesOfCompletedShards, diff --git a/src/main/java/com/amazonaws/services/kinesis/metrics/impl/MetricsHelper.java b/src/main/java/com/amazonaws/services/kinesis/metrics/impl/MetricsHelper.java index 4599fbaa7..bf104cff3 100644 --- a/src/main/java/com/amazonaws/services/kinesis/metrics/impl/MetricsHelper.java +++ b/src/main/java/com/amazonaws/services/kinesis/metrics/impl/MetricsHelper.java @@ -72,13 +72,22 @@ public static IMetricsScope startScope(IMetricsFactory factory, String operation * @param scope */ public static void setMetricsScope(IMetricsScope scope) { - if (currentScope.get() != null) { + if (isMetricsScopePresent()) { throw new RuntimeException(String.format( "Metrics scope is already set for the current thread %s", Thread.currentThread().getName())); } currentScope.set(scope); } + /** + * Checks if current metricsscope is present or not. + * + * @return true if metrics scope is present, else returns false + */ + public static boolean isMetricsScopePresent() { + return currentScope.get() != null; + } + /** * Unsets the metrics scope for the current thread. */ diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RecordProcessorCheckpointerTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RecordProcessorCheckpointerTest.java index 31a1e1844..7e637457e 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RecordProcessorCheckpointerTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/RecordProcessorCheckpointerTest.java @@ -14,6 +14,13 @@ */ package com.amazonaws.services.kinesis.clientlibrary.lib.worker; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; @@ -23,7 +30,10 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; import org.mockito.Mockito; +import org.mockito.runners.MockitoJUnitRunner; import com.amazonaws.services.kinesis.clientlibrary.interfaces.ICheckpoint; import com.amazonaws.services.kinesis.clientlibrary.interfaces.IPreparedCheckpointer; @@ -31,15 +41,15 @@ import com.amazonaws.services.kinesis.clientlibrary.lib.checkpoint.SentinelCheckpoint; import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber; import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord; +import com.amazonaws.services.kinesis.metrics.impl.MetricsHelper; +import com.amazonaws.services.kinesis.metrics.impl.NullMetricsScope; +import com.amazonaws.services.kinesis.metrics.interfaces.IMetricsFactory; import com.amazonaws.services.kinesis.model.Record; -import static org.junit.Assert.fail; -import static org.mockito.Mockito.mock; -import static org.mockito.Matchers.anyString; - /** * */ +@RunWith(MockitoJUnitRunner.class) public class RecordProcessorCheckpointerTest { private String startingSequenceNumber = "13"; private ExtendedSequenceNumber startingExtendedSequenceNumber = new ExtendedSequenceNumber(startingSequenceNumber); @@ -48,6 +58,9 @@ public class RecordProcessorCheckpointerTest { private ShardInfo shardInfo; private SequenceNumberValidator sequenceNumberValidator; private String shardId = "shardId-123"; + + @Mock + IMetricsFactory metricsFactory; /** * @throws java.lang.Exception @@ -78,7 +91,7 @@ public void tearDown() throws Exception { public final void testCheckpoint() throws Exception { // First call to checkpoint RecordProcessorCheckpointer processingCheckpointer = - new RecordProcessorCheckpointer(shardInfo, checkpoint, null); + new RecordProcessorCheckpointer(shardInfo, checkpoint, null, metricsFactory); processingCheckpointer.setLargestPermittedCheckpointValue(startingExtendedSequenceNumber); processingCheckpointer.checkpoint(); Assert.assertEquals(startingExtendedSequenceNumber, checkpoint.getCheckpoint(shardId)); @@ -98,7 +111,7 @@ public final void testCheckpoint() throws Exception { @Test public final void testCheckpointRecord() throws Exception { RecordProcessorCheckpointer processingCheckpointer = - new RecordProcessorCheckpointer(shardInfo, checkpoint, sequenceNumberValidator); + new RecordProcessorCheckpointer(shardInfo, checkpoint, sequenceNumberValidator, metricsFactory); processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber); ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber("5025"); Record record = new Record().withSequenceNumber("5025"); @@ -114,7 +127,7 @@ public final void testCheckpointRecord() throws Exception { @Test public final void testCheckpointSubRecord() throws Exception { RecordProcessorCheckpointer processingCheckpointer = - new RecordProcessorCheckpointer(shardInfo, checkpoint, sequenceNumberValidator); + new RecordProcessorCheckpointer(shardInfo, checkpoint, sequenceNumberValidator, metricsFactory); processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber); ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber("5030"); Record record = new Record().withSequenceNumber("5030"); @@ -131,7 +144,7 @@ public final void testCheckpointSubRecord() throws Exception { @Test public final void testCheckpointSequenceNumber() throws Exception { RecordProcessorCheckpointer processingCheckpointer = - new RecordProcessorCheckpointer(shardInfo, checkpoint, sequenceNumberValidator); + new RecordProcessorCheckpointer(shardInfo, checkpoint, sequenceNumberValidator, metricsFactory); processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber); ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber("5035"); processingCheckpointer.setLargestPermittedCheckpointValue(extendedSequenceNumber); @@ -146,7 +159,7 @@ public final void testCheckpointSequenceNumber() throws Exception { @Test public final void testCheckpointExtendedSequenceNumber() throws Exception { RecordProcessorCheckpointer processingCheckpointer = - new RecordProcessorCheckpointer(shardInfo, checkpoint, sequenceNumberValidator); + new RecordProcessorCheckpointer(shardInfo, checkpoint, sequenceNumberValidator, metricsFactory); processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber); ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber("5040"); processingCheckpointer.setLargestPermittedCheckpointValue(extendedSequenceNumber); @@ -162,7 +175,7 @@ public final void testCheckpointExtendedSequenceNumber() throws Exception { public final void testPrepareCheckpoint() throws Exception { // First call to checkpoint RecordProcessorCheckpointer processingCheckpointer = - new RecordProcessorCheckpointer(shardInfo, checkpoint, sequenceNumberValidator); + new RecordProcessorCheckpointer(shardInfo, checkpoint, sequenceNumberValidator, metricsFactory); processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber); ExtendedSequenceNumber sequenceNumber1 = new ExtendedSequenceNumber("5001"); @@ -193,7 +206,7 @@ public final void testPrepareCheckpoint() throws Exception { @Test public final void testPrepareCheckpointRecord() throws Exception { RecordProcessorCheckpointer processingCheckpointer = - new RecordProcessorCheckpointer(shardInfo, checkpoint, sequenceNumberValidator); + new RecordProcessorCheckpointer(shardInfo, checkpoint, sequenceNumberValidator, metricsFactory); processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber); ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber("5025"); Record record = new Record().withSequenceNumber("5025"); @@ -218,7 +231,7 @@ public final void testPrepareCheckpointRecord() throws Exception { @Test public final void testPrepareCheckpointSubRecord() throws Exception { RecordProcessorCheckpointer processingCheckpointer = - new RecordProcessorCheckpointer(shardInfo, checkpoint, sequenceNumberValidator); + new RecordProcessorCheckpointer(shardInfo, checkpoint, sequenceNumberValidator, metricsFactory); processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber); ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber("5030"); Record record = new Record().withSequenceNumber("5030"); @@ -244,7 +257,7 @@ public final void testPrepareCheckpointSubRecord() throws Exception { @Test public final void testPrepareCheckpointSequenceNumber() throws Exception { RecordProcessorCheckpointer processingCheckpointer = - new RecordProcessorCheckpointer(shardInfo, checkpoint, sequenceNumberValidator); + new RecordProcessorCheckpointer(shardInfo, checkpoint, sequenceNumberValidator, metricsFactory); processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber); ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber("5035"); processingCheckpointer.setLargestPermittedCheckpointValue(extendedSequenceNumber); @@ -268,7 +281,7 @@ public final void testPrepareCheckpointSequenceNumber() throws Exception { @Test public final void testPrepareCheckpointExtendedSequenceNumber() throws Exception { RecordProcessorCheckpointer processingCheckpointer = - new RecordProcessorCheckpointer(shardInfo, checkpoint, sequenceNumberValidator); + new RecordProcessorCheckpointer(shardInfo, checkpoint, sequenceNumberValidator, metricsFactory); processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber); ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber("5040"); processingCheckpointer.setLargestPermittedCheckpointValue(extendedSequenceNumber); @@ -291,7 +304,7 @@ public final void testPrepareCheckpointExtendedSequenceNumber() throws Exception @Test public final void testMultipleOutstandingCheckpointersHappyCase() throws Exception { RecordProcessorCheckpointer processingCheckpointer = - new RecordProcessorCheckpointer(shardInfo, checkpoint, sequenceNumberValidator); + new RecordProcessorCheckpointer(shardInfo, checkpoint, sequenceNumberValidator, metricsFactory); processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber); processingCheckpointer.setLargestPermittedCheckpointValue(new ExtendedSequenceNumber("6040")); @@ -323,7 +336,7 @@ public final void testMultipleOutstandingCheckpointersHappyCase() throws Excepti @Test public final void testMultipleOutstandingCheckpointersOutOfOrder() throws Exception { RecordProcessorCheckpointer processingCheckpointer = - new RecordProcessorCheckpointer(shardInfo, checkpoint, sequenceNumberValidator); + new RecordProcessorCheckpointer(shardInfo, checkpoint, sequenceNumberValidator, metricsFactory); processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber); processingCheckpointer.setLargestPermittedCheckpointValue(new ExtendedSequenceNumber("7040")); @@ -358,7 +371,7 @@ public final void testMultipleOutstandingCheckpointersOutOfOrder() throws Except */ @Test public final void testUpdate() throws Exception { - RecordProcessorCheckpointer checkpointer = new RecordProcessorCheckpointer(shardInfo, checkpoint, null); + RecordProcessorCheckpointer checkpointer = new RecordProcessorCheckpointer(shardInfo, checkpoint, null, metricsFactory); ExtendedSequenceNumber sequenceNumber = new ExtendedSequenceNumber("10"); checkpointer.setLargestPermittedCheckpointValue(sequenceNumber); @@ -379,7 +392,7 @@ public final void testClientSpecifiedCheckpoint() throws Exception { SequenceNumberValidator validator = mock(SequenceNumberValidator.class); Mockito.doNothing().when(validator).validateSequenceNumber(anyString()); RecordProcessorCheckpointer processingCheckpointer = - new RecordProcessorCheckpointer(shardInfo, checkpoint, validator); + new RecordProcessorCheckpointer(shardInfo, checkpoint, validator, metricsFactory); // Several checkpoints we're gonna hit ExtendedSequenceNumber tooSmall = new ExtendedSequenceNumber("2"); @@ -467,7 +480,7 @@ public final void testClientSpecifiedTwoPhaseCheckpoint() throws Exception { SequenceNumberValidator validator = mock(SequenceNumberValidator.class); Mockito.doNothing().when(validator).validateSequenceNumber(anyString()); RecordProcessorCheckpointer processingCheckpointer = - new RecordProcessorCheckpointer(shardInfo, checkpoint, validator); + new RecordProcessorCheckpointer(shardInfo, checkpoint, validator, metricsFactory); // Several checkpoints we're gonna hit ExtendedSequenceNumber tooSmall = new ExtendedSequenceNumber("2"); @@ -595,7 +608,7 @@ public final void testMixedCheckpointCalls() throws Exception { for (LinkedHashMap testPlan : getMixedCallsTestPlan()) { RecordProcessorCheckpointer processingCheckpointer = - new RecordProcessorCheckpointer(shardInfo, checkpoint, validator); + new RecordProcessorCheckpointer(shardInfo, checkpoint, validator, metricsFactory); testMixedCheckpointCalls(processingCheckpointer, testPlan, CheckpointerType.CHECKPOINTER); } } @@ -615,7 +628,7 @@ public final void testMixedTwoPhaseCheckpointCalls() throws Exception { for (LinkedHashMap testPlan : getMixedCallsTestPlan()) { RecordProcessorCheckpointer processingCheckpointer = - new RecordProcessorCheckpointer(shardInfo, checkpoint, validator); + new RecordProcessorCheckpointer(shardInfo, checkpoint, validator, metricsFactory); testMixedCheckpointCalls(processingCheckpointer, testPlan, CheckpointerType.PREPARED_CHECKPOINTER); } } @@ -636,7 +649,7 @@ public final void testMixedTwoPhaseCheckpointCalls2() throws Exception { for (LinkedHashMap testPlan : getMixedCallsTestPlan()) { RecordProcessorCheckpointer processingCheckpointer = - new RecordProcessorCheckpointer(shardInfo, checkpoint, validator); + new RecordProcessorCheckpointer(shardInfo, checkpoint, validator, metricsFactory); testMixedCheckpointCalls(processingCheckpointer, testPlan, CheckpointerType.PREPARE_THEN_CHECKPOINTER); } } @@ -785,4 +798,34 @@ private void testMixedCheckpointCalls(RecordProcessorCheckpointer processingChec Assert.assertEquals(null, checkpoint.getCheckpointObject(shardId).getPendingCheckpoint()); } } + + @Test + public final void testUnsetMetricsScopeDuringCheckpointing() throws Exception { + // First call to checkpoint + RecordProcessorCheckpointer processingCheckpointer = + new RecordProcessorCheckpointer(shardInfo, checkpoint, null, metricsFactory); + ExtendedSequenceNumber sequenceNumber = new ExtendedSequenceNumber("5019"); + processingCheckpointer.setLargestPermittedCheckpointValue(sequenceNumber); + processingCheckpointer.checkpoint(); + Assert.assertEquals(sequenceNumber, checkpoint.getCheckpoint(shardId)); + verify(metricsFactory).createMetrics(); + Assert.assertFalse(MetricsHelper.isMetricsScopePresent()); + } + + @Test + public final void testSetMetricsScopeDuringCheckpointing() throws Exception { + // First call to checkpoint + RecordProcessorCheckpointer processingCheckpointer = + new RecordProcessorCheckpointer(shardInfo, checkpoint, null, metricsFactory); + NullMetricsScope scope = new NullMetricsScope(); + MetricsHelper.setMetricsScope(scope); + ExtendedSequenceNumber sequenceNumber = new ExtendedSequenceNumber("5019"); + processingCheckpointer.setLargestPermittedCheckpointValue(sequenceNumber); + processingCheckpointer.checkpoint(); + Assert.assertEquals(sequenceNumber, checkpoint.getCheckpoint(shardId)); + verify(metricsFactory, never()).createMetrics(); + Assert.assertTrue(MetricsHelper.isMetricsScopePresent()); + assertEquals(scope, MetricsHelper.getMetricsScope()); + MetricsHelper.unsetMetricsScope(); + } } diff --git a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumerTest.java b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumerTest.java index 8a91c6e6e..216d59cdb 100644 --- a/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumerTest.java +++ b/src/test/java/com/amazonaws/services/kinesis/clientlibrary/lib/worker/ShardConsumerTest.java @@ -342,7 +342,8 @@ public final void testConsumeShard() throws Exception { streamConfig.getStreamProxy(), shardInfo.getShardId(), streamConfig.shouldValidateSequenceNumberBeforeCheckpointing() - ) + ), + metricsFactory ); dataFetcher = new KinesisDataFetcher(streamConfig.getStreamProxy(), shardInfo); @@ -493,7 +494,8 @@ public final void testConsumeShardWithTransientTerminateError() throws Exception streamConfig.getStreamProxy(), shardInfo.getShardId(), streamConfig.shouldValidateSequenceNumberBeforeCheckpointing() - ) + ), + metricsFactory ); ShardConsumer consumer = @@ -621,7 +623,8 @@ public final void testConsumeShardWithInitialPositionAtTimestamp() throws Except streamConfig.getStreamProxy(), shardInfo.getShardId(), streamConfig.shouldValidateSequenceNumberBeforeCheckpointing() - ) + ), + metricsFactory ); dataFetcher = new KinesisDataFetcher(streamConfig.getStreamProxy(), shardInfo);