diff --git a/data-prepper-plugins/aws-lambda/build.gradle b/data-prepper-plugins/aws-lambda/build.gradle index a0319fabd4..b1d278cb18 100644 --- a/data-prepper-plugins/aws-lambda/build.gradle +++ b/data-prepper-plugins/aws-lambda/build.gradle @@ -63,10 +63,7 @@ task integrationTest(type: Test) { classpath = sourceSets.integrationTest.runtimeClasspath systemProperty 'log4j.configurationFile', 'src/test/resources/log4j2.properties' - - //Enable Multi-thread in tests - systemProperty 'junit.jupiter.execution.parallel.enabled', 'true' - systemProperty 'junit.jupiter.execution.parallel.mode.default', 'concurrent' + systemProperty 'tests.lambda.sink.region', System.getProperty('tests.lambda.sink.region') systemProperty 'tests.lambda.sink.functionName', System.getProperty('tests.lambda.sink.functionName') diff --git a/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java b/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java index f05ab16b2e..5ea7115bbf 100644 --- a/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java +++ b/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java @@ -3,65 +3,68 @@ * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.dataprepper.plugins.lambda.processor; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.spy; + +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Timer; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.quality.Strictness; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.expression.ExpressionEvaluator; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.codec.InputCodec; +import org.opensearch.dataprepper.model.configuration.PluginSetting; +import org.opensearch.dataprepper.model.event.DefaultEventMetadata; import org.opensearch.dataprepper.model.event.Event; -import org.opensearch.dataprepper.model.event.JacksonEvent; import org.opensearch.dataprepper.model.event.EventMetadata; -import org.opensearch.dataprepper.model.event.DefaultEventMetadata; -import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.event.JacksonEvent; import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.model.types.ByteCount; -import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; -import org.opensearch.dataprepper.expression.ExpressionEvaluator; -import org.opensearch.dataprepper.model.configuration.PluginSetting; +import org.opensearch.dataprepper.plugins.codec.json.JsonInputCodec; +import org.opensearch.dataprepper.plugins.codec.json.JsonInputCodecConfig; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; -import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; +import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; +import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions; -import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; -import org.opensearch.dataprepper.model.codec.InputCodec; -import org.opensearch.dataprepper.plugins.codec.json.JsonInputCodec; -import org.opensearch.dataprepper.plugins.codec.json.JsonInputCodecConfig; -import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; import software.amazon.awssdk.regions.Region; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.extension.ExtendWith; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; -import static org.mockito.ArgumentMatchers.any; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.lenient; - -import io.micrometer.core.instrument.Counter; -import io.micrometer.core.instrument.DistributionSummary; -import io.micrometer.core.instrument.Timer; import software.amazon.awssdk.services.lambda.model.InvokeResponse; -import java.lang.reflect.Field; + import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.List; -import java.util.HashMap; -import java.util.Map; import java.util.concurrent.TimeUnit; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + @ExtendWith(MockitoExtension.class) @MockitoSettings(strictness = Strictness.LENIENT) public class LambdaProcessorIT { + @Mock + InvocationType invocationType; private AwsCredentialsProvider awsCredentialsProvider; private LambdaProcessor lambdaProcessor; private LambdaProcessorConfig lambdaProcessorConfig; @@ -73,78 +76,45 @@ public class LambdaProcessorIT { @Mock private PluginFactory pluginFactory; @Mock + private PluginMetrics pluginMetrics; + @Mock private PluginSetting pluginSetting; @Mock private ExpressionEvaluator expressionEvaluator; @Mock - private Counter numberOfRecordsSuccessCounter; + private Counter testCounter; @Mock - private Counter numberOfRecordsFailedCounter; - @Mock - private Counter numberOfRequestsSuccessCounter; - @Mock - private Counter numberOfRequestsFailedCounter; - @Mock - private Counter sinkSuccessCounter; - @Mock - private Timer lambdaLatencyMetric; - @Mock - private DistributionSummary requestPayloadMetric; - @Mock - private DistributionSummary responsePayloadMetric; - @Mock - InvocationType invocationType; + private Timer testTimer; + private LambdaProcessor createObjectUnderTest(LambdaProcessorConfig processorConfig) { return new LambdaProcessor(pluginFactory, pluginSetting, processorConfig, awsCredentialsSupplier, expressionEvaluator); } - private void setPrivateField(Object targetObject, String fieldName, Object value) throws Exception { - Field field = targetObject.getClass().getDeclaredField(fieldName); - field.setAccessible(true); - field.set(targetObject, value); - } - - private void setPrivateFields(final LambdaProcessor lambdaProcessor) throws Exception { - setPrivateField(lambdaProcessor, "numberOfRecordsSuccessCounter", numberOfRecordsSuccessCounter); - setPrivateField(lambdaProcessor, "numberOfRecordsFailedCounter", numberOfRecordsFailedCounter); - setPrivateField(lambdaProcessor, "numberOfRequestsSuccessCounter", numberOfRequestsSuccessCounter); - setPrivateField(lambdaProcessor, "numberOfRequestsFailedCounter", numberOfRequestsFailedCounter); - setPrivateField(lambdaProcessor, "lambdaLatencyMetric", lambdaLatencyMetric); - setPrivateField(lambdaProcessor, "requestPayloadMetric", requestPayloadMetric); - setPrivateField(lambdaProcessor, "responsePayloadMetric", responsePayloadMetric); - } - @BeforeEach public void setup() { lambdaRegion = System.getProperty("tests.lambda.processor.region"); functionName = System.getProperty("tests.lambda.processor.functionName"); role = System.getProperty("tests.lambda.processor.sts_role_arn"); + pluginMetrics = mock(PluginMetrics.class); pluginSetting = mock(PluginSetting.class); when(pluginSetting.getPipelineName()).thenReturn("pipeline"); when(pluginSetting.getName()).thenReturn("name"); - numberOfRecordsSuccessCounter = mock(Counter.class); - numberOfRecordsFailedCounter = mock(Counter.class); - numberOfRequestsSuccessCounter = mock(Counter.class); - numberOfRequestsFailedCounter = mock(Counter.class); - lambdaLatencyMetric = mock(Timer.class); - requestPayloadMetric = mock(DistributionSummary.class); - responsePayloadMetric = mock(DistributionSummary.class); - try { - lenient().doAnswer(args -> { - return null; - }).when(numberOfRecordsSuccessCounter).increment(any(Double.class)); - } catch (Exception e){} + testCounter = mock(Counter.class); try { lenient().doAnswer(args -> { return null; - }).when(numberOfRecordsFailedCounter).increment(); - } catch (Exception e){} + }).when(testCounter).increment(any(Double.class)); + } catch (Exception e) { + } try { lenient().doAnswer(args -> { return null; - }).when(lambdaLatencyMetric).record(any(Long.class), any(TimeUnit.class)); - } catch (Exception e){} - + }).when(testTimer).record(any(Long.class), any(TimeUnit.class)); + } catch (Exception e) { + } + when(pluginMetrics.counter(any())).thenReturn(testCounter); + testTimer = mock(Timer.class); + when(pluginMetrics.timer(any())).thenReturn(testTimer); lambdaProcessorConfig = mock(LambdaProcessorConfig.class); expressionEvaluator = mock(ExpressionEvaluator.class); awsCredentialsProvider = DefaultCredentialsProvider.create(); @@ -201,7 +171,7 @@ public void testRequestResponseWithMatchingEventsAggregateMode(int numRecords) { List> records = createRecords(numRecords); Collection> results = lambdaProcessor.doExecute(records); assertThat(results.size(), equalTo(numRecords)); - validateResultsForAggregateMode(results ); + validateResultsForAggregateMode(results); } @ParameterizedTest @@ -242,7 +212,7 @@ public void testDifferentInvocationTypes(String invocationType) throws Exception validateStrictModeResults(results); } else { // For "Event" invocation type - assertThat(results.size(), equalTo(0)); + assertThat(results.size(), equalTo(10)); } } @@ -287,12 +257,12 @@ private void validateStrictModeResults(Collection> results) { for (int i = 0; i < resultRecords.size(); i++) { Map eventData = resultRecords.get(i).getData().toMap(); Map attr = resultRecords.get(i).getData().getMetadata().getAttributes(); - int id = (Integer)eventData.get("id"); - assertThat(eventData.get("key"+id), equalTo(id)); - String stringValue = "value"+id; - assertThat(eventData.get("keys"+id), equalTo(stringValue.toUpperCase())); - assertThat(attr.get("attr"+id), equalTo(id)); - assertThat(attr.get("attrs"+id), equalTo("attrvalue"+id)); + int id = (Integer) eventData.get("id"); + assertThat(eventData.get("key" + id), equalTo(id)); + String stringValue = "value" + id; + assertThat(eventData.get("keys" + id), equalTo(stringValue.toUpperCase())); + assertThat(attr.get("attr" + id), equalTo(id)); + assertThat(attr.get("attrs" + id), equalTo("attrvalue" + id)); } } @@ -301,11 +271,11 @@ private List> createRecords(int numRecords) { for (int i = 0; i < numRecords; i++) { Map map = new HashMap<>(); map.put("id", i); - map.put("key"+i, i); - map.put("keys"+i, "value"+i); + map.put("key" + i, i); + map.put("keys" + i, "value" + i); Map attrs = new HashMap<>(); - attrs.put("attr"+i, i); - attrs.put("attrs"+i, "attrvalue"+i); + attrs.put("attr" + i, i); + attrs.put("attrs" + i, "attrvalue" + i); EventMetadata metadata = DefaultEventMetadata.builder() .withEventType("event") .withAttributes(attrs) diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java index 65ff6af530..3518508b64 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java @@ -5,13 +5,6 @@ package org.opensearch.dataprepper.plugins.lambda.common; -import java.time.Duration; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CompletableFuture; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.model.sink.OutputCodecContext; @@ -27,76 +20,80 @@ import software.amazon.awssdk.services.lambda.model.InvokeRequest; import software.amazon.awssdk.services.lambda.model.InvokeResponse; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + public class LambdaCommonHandler { - private static final Logger LOG = LoggerFactory.getLogger(LambdaCommonHandler.class); + private static final Logger LOG = LoggerFactory.getLogger(LambdaCommonHandler.class); - private LambdaCommonHandler() { - } + private LambdaCommonHandler() { + } - public static boolean isSuccess(InvokeResponse response) { - int statusCode = response.statusCode(); - return statusCode >= 200 && statusCode < 300; - } + public static boolean isSuccess(InvokeResponse response) { + int statusCode = response.statusCode(); + return statusCode >= 200 && statusCode < 300; + } - public static void waitForFutures(Collection> futureList) { + public static void waitForFutures(Collection> futureList) { - if (!futureList.isEmpty()) { - try { - CompletableFuture.allOf(futureList.toArray(new CompletableFuture[0])).join(); - } catch (Exception e) { - LOG.warn("Exception while waiting for Lambda invocations to complete", e); - } + if (!futureList.isEmpty()) { + CompletableFuture.allOf(futureList.toArray(new CompletableFuture[0])).join(); + } } - } - private static List createBufferBatches(Collection> records, - BatchOptions batchOptions, final OutputCodecContext outputCodecContext) { + private static List createBufferBatches(Collection> records, + BatchOptions batchOptions, final OutputCodecContext outputCodecContext) { - int maxEvents = batchOptions.getThresholdOptions().getEventCount(); - ByteCount maxBytes = batchOptions.getThresholdOptions().getMaximumSize(); - String keyName = batchOptions.getKeyName(); - Duration maxCollectionDuration = batchOptions.getThresholdOptions().getEventCollectTimeOut(); + int maxEvents = batchOptions.getThresholdOptions().getEventCount(); + ByteCount maxBytes = batchOptions.getThresholdOptions().getMaximumSize(); + String keyName = batchOptions.getKeyName(); + Duration maxCollectionDuration = batchOptions.getThresholdOptions().getEventCollectTimeOut(); - Buffer currentBufferPerBatch = new InMemoryBuffer(keyName, outputCodecContext); - List batchedBuffers = new ArrayList<>(); + Buffer currentBufferPerBatch = new InMemoryBuffer(keyName, outputCodecContext); + List batchedBuffers = new ArrayList<>(); - LOG.debug("Batch size received to lambda processor: {}", records.size()); - for (Record record : records) { + LOG.debug("Batch size received to lambda processor: {}", records.size()); + for (Record record : records) { - currentBufferPerBatch.addRecord(record); - if (ThresholdCheck.checkThresholdExceed(currentBufferPerBatch, maxEvents, maxBytes, - maxCollectionDuration)) { - batchedBuffers.add(currentBufferPerBatch); - currentBufferPerBatch = new InMemoryBuffer(keyName, outputCodecContext); - } - } + currentBufferPerBatch.addRecord(record); + if (ThresholdCheck.checkThresholdExceed(currentBufferPerBatch, maxEvents, maxBytes, + maxCollectionDuration)) { + batchedBuffers.add(currentBufferPerBatch); + currentBufferPerBatch = new InMemoryBuffer(keyName, outputCodecContext); + } + } - if (currentBufferPerBatch.getEventCount() > 0) { - batchedBuffers.add(currentBufferPerBatch); + if (currentBufferPerBatch.getEventCount() > 0) { + batchedBuffers.add(currentBufferPerBatch); + } + return batchedBuffers; } - return batchedBuffers; - } - - public static Map> sendRecords( - Collection> records, - LambdaCommonConfig config, - LambdaAsyncClient lambdaAsyncClient, - final OutputCodecContext outputCodecContext) { - - List batchedBuffers = createBufferBatches(records, config.getBatchOptions(), - outputCodecContext); - - Map> bufferToFutureMap = new HashMap<>(); - LOG.debug("Batch Chunks created after threshold check: {}", batchedBuffers.size()); - for (Buffer buffer : batchedBuffers) { - InvokeRequest requestPayload = buffer.getRequestPayload(config.getFunctionName(), - config.getInvocationType().getAwsLambdaValue()); - CompletableFuture future = lambdaAsyncClient.invoke(requestPayload); - bufferToFutureMap.put(buffer, future); + + public static Map> sendRecords( + Collection> records, + LambdaCommonConfig config, + LambdaAsyncClient lambdaAsyncClient, + final OutputCodecContext outputCodecContext) { + + List batchedBuffers = createBufferBatches(records, config.getBatchOptions(), + outputCodecContext); + + Map> bufferToFutureMap = new HashMap<>(); + LOG.debug("Batch Chunks created after threshold check: {}", batchedBuffers.size()); + for (Buffer buffer : batchedBuffers) { + InvokeRequest requestPayload = buffer.getRequestPayload(config.getFunctionName(), + config.getInvocationType().getAwsLambdaValue()); + CompletableFuture future = lambdaAsyncClient.invoke(requestPayload); + bufferToFutureMap.put(buffer, future); + } + waitForFutures(bufferToFutureMap.values()); + return bufferToFutureMap; } - waitForFutures(bufferToFutureMap.values()); - return bufferToFutureMap; - } } diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/ResponseEventHandlingStrategy.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/ResponseEventHandlingStrategy.java index e27f0e1b89..b4fe7aa1a1 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/ResponseEventHandlingStrategy.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/ResponseEventHandlingStrategy.java @@ -7,10 +7,10 @@ import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; import java.util.List; public interface ResponseEventHandlingStrategy { - void handleEvents(List parsedEvents, List> originalRecords, List> resultRecords, Buffer flushedBuffer); + + List> handleEvents(List parsedEvents, List> originalRecords); } diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/AggregateResponseEventHandlingStrategy.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/AggregateResponseEventHandlingStrategy.java index b19dd8d156..0407cf7273 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/AggregateResponseEventHandlingStrategy.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/AggregateResponseEventHandlingStrategy.java @@ -9,21 +9,17 @@ import org.opensearch.dataprepper.model.event.DefaultEventHandle; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; import org.opensearch.dataprepper.plugins.lambda.common.ResponseEventHandlingStrategy; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import java.util.ArrayList; import java.util.List; public class AggregateResponseEventHandlingStrategy implements ResponseEventHandlingStrategy { - private static final Logger LOG = LoggerFactory.getLogger(AggregateResponseEventHandlingStrategy.class); - @Override - public void handleEvents(List parsedEvents, List> originalRecords, - List> resultRecords, Buffer flushedBuffer) { + public List> handleEvents(List parsedEvents, List> originalRecords) { + List> resultRecords = new ArrayList<>(); Event originalEvent = originalRecords.get(0).getData(); DefaultEventHandle eventHandle = (DefaultEventHandle) originalEvent.getEventHandle(); AcknowledgementSet originalAcknowledgementSet = eventHandle.getAcknowledgementSet(); @@ -38,5 +34,6 @@ public void handleEvents(List parsedEvents, List> originalR originalAcknowledgementSet.add(responseEvent); } } + return resultRecords; } } diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java index 85e33e192c..719a09eee6 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java @@ -5,23 +5,9 @@ package org.opensearch.dataprepper.plugins.lambda.processor; -import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY; -import static org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler.isSuccess; - import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.Timer; -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.time.Duration; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.TimeUnit; import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.expression.ExpressionEvaluator; import org.opensearch.dataprepper.metrics.PluginMetrics; @@ -49,14 +35,30 @@ import software.amazon.awssdk.services.lambda.LambdaAsyncClient; import software.amazon.awssdk.services.lambda.model.InvokeResponse; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY; +import static org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler.isSuccess; + @DataPrepperPlugin(name = "aws_lambda", pluginType = Processor.class, pluginConfigurationType = LambdaProcessorConfig.class) public class LambdaProcessor extends AbstractProcessor, Record> { - public static final String NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS = "lambdaProcessorObjectsEventsSucceeded"; - public static final String NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED = "lambdaProcessorObjectsEventsFailed"; - public static final String NUMBER_OF_SUCCESSFUL_REQUESTS_TO_LAMBDA = "lambdaProcessorNumberOfRequestsSucceeded"; - public static final String NUMBER_OF_FAILED_REQUESTS_TO_LAMBDA = "lambdaProcessorNumberOfRequestsFailed"; - public static final String LAMBDA_LATENCY_METRIC = "lambdaProcessorLatency"; + public static final String NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS = "recordsSuccessfullySentToLambda"; + public static final String NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED = "recordsFailedToSentLambda"; + public static final String NUMBER_OF_SUCCESSFUL_REQUESTS_TO_LAMBDA = "numberOfRequestsSucceeded"; + public static final String NUMBER_OF_FAILED_REQUESTS_TO_LAMBDA = "numberOfRequestsFailed"; + public static final String LAMBDA_LATENCY_METRIC = "lambdaFunctionLatency"; public static final String REQUEST_PAYLOAD_SIZE = "requestPayloadSize"; public static final String RESPONSE_PAYLOAD_SIZE = "responsePayloadSize"; @@ -80,21 +82,23 @@ public class LambdaProcessor extends AbstractProcessor, Record> doExecute(Collection> records) { for (Record record : records) { final Event event = record.getData(); // If the condition is false, add the event to resultRecords as-is - if (whenCondition != null && !expressionEvaluator.evaluateConditional(whenCondition, event)) { + if (whenCondition != null && !expressionEvaluator.evaluateConditional(whenCondition, + event)) { resultRecords.add(record); continue; } recordsToLambda.add(record); } - Map> bufferToFutureMap = LambdaCommonHandler.sendRecords( - recordsToLambda, lambdaProcessorConfig, lambdaAsyncClient, - new OutputCodecContext()); + Map> bufferToFutureMap = new HashMap<>(); + try { + bufferToFutureMap = LambdaCommonHandler.sendRecords( + recordsToLambda, lambdaProcessorConfig, lambdaAsyncClient, + new OutputCodecContext()); + } catch (Exception e) { + LOG.error(NOISY, "Error while sending records to Lambda", e); + resultRecords.addAll(addFailureTags(recordsToLambda)); + } + for (Map.Entry> entry : bufferToFutureMap.entrySet()) { CompletableFuture future = entry.getValue(); Buffer inputBuffer = entry.getKey(); @@ -162,25 +174,26 @@ public Collection> doExecute(Collection> records) { Duration latency = inputBuffer.stopLatencyWatch(); lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS); requestPayloadMetric.record(inputBuffer.getPayloadRequestSize()); - if (isSuccess(response)) { - resultRecords.addAll(convertLambdaResponseToEvent(inputBuffer, response)); - numberOfRecordsSuccessCounter.increment(inputBuffer.getEventCount()); - numberOfRequestsSuccessCounter.increment(); - if (response.payload() != null) { - responsePayloadMetric.record(response.payload().asByteArray().length); - } - continue; - } else { - LOG.error("Lambda invoke failed with error {} ", response.statusCode()); - /* fall through */ + if (!isSuccess(response)) { + String errorMessage = String.format("Lambda invoke failed with status code %s error %s ", + response.statusCode(), response.payload().asUtf8String()); + throw new RuntimeException(errorMessage); + } + + resultRecords.addAll(convertLambdaResponseToEvent(inputBuffer, response)); + numberOfRecordsSuccessCounter.increment(inputBuffer.getEventCount()); + numberOfRequestsSuccessCounter.increment(); + if (response.payload() != null) { + responsePayloadMetric.record(response.payload().asByteArray().length); } + } catch (Exception e) { - LOG.error("Exception from Lambda invocation ", e); + LOG.error(NOISY, e.getMessage(), e); /* fall through */ + numberOfRecordsFailedCounter.increment(inputBuffer.getEventCount()); + numberOfRequestsFailedCounter.increment(); + resultRecords.addAll(addFailureTags(inputBuffer.getRecords())); } - numberOfRecordsFailedCounter.increment(inputBuffer.getEventCount()); - numberOfRequestsFailedCounter.increment(); - resultRecords.addAll(addFailureTags(inputBuffer.getRecords())); } return resultRecords; } @@ -191,39 +204,36 @@ public Collection> doExecute(Collection> records) { * 2. If it is not an array, then create one event per response. */ List> convertLambdaResponseToEvent(Buffer flushedBuffer, - final InvokeResponse lambdaResponse) { + final InvokeResponse lambdaResponse) throws IOException { InputCodec responseCodec = pluginFactory.loadPlugin(InputCodec.class, codecPluginSetting); List> originalRecords = flushedBuffer.getRecords(); List parsedEvents = new ArrayList<>(); - List> resultRecords = new ArrayList<>(); SdkBytes payload = lambdaResponse.payload(); // Handle null or empty payload - if (payload == null || payload.asByteArray() == null || payload.asByteArray().length == 0) { - LOG.warn(NOISY, "Lambda response payload is null or empty, dropping the original events"); - } else { - InputStream inputStream = new ByteArrayInputStream(payload.asByteArray()); - //Convert to response codec - try { - responseCodec.parse(inputStream, record -> { - Event event = record.getData(); - parsedEvents.add(event); - }); - } catch (IOException ex) { - LOG.error("Error while trying to parse response from Lambda", ex); - throw new RuntimeException(ex); - } - if (parsedEvents.size() == 0) { - throw new RuntimeException("Lambda Response could not be parsed, returning original events"); - } + if (payload == null || payload.asByteArray().length == 0) { + LOG.warn(NOISY, + "Lambda response payload is null or empty, dropping the original events"); + return responseStrategy.handleEvents(parsedEvents, originalRecords); + } - LOG.debug("Parsed Event Size:{}, FlushedBuffer eventCount:{}, " + - "FlushedBuffer size:{}", parsedEvents.size(), flushedBuffer.getEventCount(), - flushedBuffer.getSize()); - responseStrategy.handleEvents(parsedEvents, originalRecords, resultRecords, flushedBuffer); + //Convert using response codec + InputStream inputStream = new ByteArrayInputStream(payload.asByteArray()); + responseCodec.parse(inputStream, record -> { + Event event = record.getData(); + parsedEvents.add(event); + }); + + if (parsedEvents.isEmpty()) { + throw new RuntimeException( + "Lambda Response could not be parsed, returning original events"); } - return resultRecords; + + LOG.debug("Parsed Event Size:{}, FlushedBuffer eventCount:{}, " + + "FlushedBuffer size:{}", parsedEvents.size(), flushedBuffer.getEventCount(), + flushedBuffer.getSize()); + return responseStrategy.handleEvents(parsedEvents, originalRecords); } /* diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/StrictResponseEventHandlingStrategy.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/StrictResponseEventHandlingStrategy.java index 4d6a8e9f28..2744534c8d 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/StrictResponseEventHandlingStrategy.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/StrictResponseEventHandlingStrategy.java @@ -7,20 +7,23 @@ import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; import org.opensearch.dataprepper.plugins.lambda.common.ResponseEventHandlingStrategy; +import java.util.ArrayList; import java.util.List; import java.util.Map; public class StrictResponseEventHandlingStrategy implements ResponseEventHandlingStrategy { @Override - public void handleEvents(List parsedEvents, List> originalRecords, List> resultRecords, Buffer flushedBuffer) { - if (parsedEvents.size() != flushedBuffer.getEventCount()) { - throw new RuntimeException("Response Processing Mode is configured as Strict mode but behavior is aggregate mode. Event count mismatch."); + public List> handleEvents(List parsedEvents, + List> originalRecords) { + if (parsedEvents.size() != originalRecords.size()) { + throw new RuntimeException( + "Response Processing Mode is configured as Strict mode but behavior is aggregate mode. Event count mismatch."); } + List> resultRecords = new ArrayList<>(); for (int i = 0; i < parsedEvents.size(); i++) { Event responseEvent = parsedEvents.get(i); Event originalEvent = originalRecords.get(i).getData(); @@ -37,6 +40,7 @@ public void handleEvents(List parsedEvents, List> originalR // Add updated event to resultRecords resultRecords.add(originalRecords.get(i)); } + return resultRecords; } } diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSink.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSink.java index f4e59e967d..8240943374 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSink.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSink.java @@ -5,17 +5,9 @@ package org.opensearch.dataprepper.plugins.lambda.sink; -import static org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler.isSuccess; - import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.Timer; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.TimeUnit; import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.expression.ExpressionEvaluator; import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin; @@ -42,183 +34,197 @@ import software.amazon.awssdk.services.lambda.LambdaAsyncClient; import software.amazon.awssdk.services.lambda.model.InvokeResponse; +import java.time.Duration; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY; +import static org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler.isSuccess; + @DataPrepperPlugin(name = "aws_lambda", pluginType = Sink.class, pluginConfigurationType = LambdaSinkConfig.class) public class LambdaSink extends AbstractSink> { - public static final String NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS = "lambdaSinkObjectsEventsSucceeded"; - public static final String NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED = "lambdaSinkObjectsEventsFailed"; - public static final String NUMBER_OF_SUCCESSFUL_REQUESTS_TO_LAMBDA = "lambdaSinkNumberOfRequestsSucceeded"; - public static final String NUMBER_OF_FAILED_REQUESTS_TO_LAMBDA = "lambdaSinkNumberOfRequestsFailed"; - public static final String LAMBDA_LATENCY_METRIC = "lambdaSinkLatency"; - public static final String REQUEST_PAYLOAD_SIZE = "lambdaSinkRequestPayloadSize"; - public static final String RESPONSE_PAYLOAD_SIZE = "lambdaSinkResponsePayloadSize"; - - private static final Logger LOG = LoggerFactory.getLogger(LambdaSink.class); - private static final String BUCKET = "bucket"; - private static final String KEY_PATH = "key_path_prefix"; - private final Counter numberOfRecordsSuccessCounter; - private final Counter numberOfRecordsFailedCounter; - private final Counter numberOfRequestsSuccessCounter; - private final Counter numberOfRequestsFailedCounter; - private final LambdaSinkConfig lambdaSinkConfig; - private final ExpressionEvaluator expressionEvaluator; - private final LambdaAsyncClient lambdaAsyncClient; - private final DistributionSummary responsePayloadMetric; - private final Timer lambdaLatencyMetric; - private final DistributionSummary requestPayloadMetric; - private final PluginSetting pluginSetting; - private final OutputCodecContext outputCodecContext; - private volatile boolean sinkInitialized; - private DlqPushHandler dlqPushHandler = null; - - @DataPrepperPluginConstructor - public LambdaSink(final PluginSetting pluginSetting, - final LambdaSinkConfig lambdaSinkConfig, - final PluginFactory pluginFactory, - final SinkContext sinkContext, - final AwsCredentialsSupplier awsCredentialsSupplier, - final ExpressionEvaluator expressionEvaluator - ) { - super(pluginSetting); - this.pluginSetting = pluginSetting; - sinkInitialized = Boolean.FALSE; - this.lambdaSinkConfig = lambdaSinkConfig; - this.expressionEvaluator = expressionEvaluator; - this.outputCodecContext = OutputCodecContext.fromSinkContext(sinkContext); - - this.numberOfRecordsSuccessCounter = pluginMetrics.counter( - NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS); - this.numberOfRecordsFailedCounter = pluginMetrics.counter( - NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED); - this.numberOfRequestsSuccessCounter = pluginMetrics.counter( - NUMBER_OF_SUCCESSFUL_REQUESTS_TO_LAMBDA); - this.numberOfRequestsFailedCounter = pluginMetrics.counter( - NUMBER_OF_FAILED_REQUESTS_TO_LAMBDA); - this.lambdaLatencyMetric = pluginMetrics.timer(LAMBDA_LATENCY_METRIC); - this.requestPayloadMetric = pluginMetrics.summary(REQUEST_PAYLOAD_SIZE); - this.responsePayloadMetric = pluginMetrics.summary(RESPONSE_PAYLOAD_SIZE); - ClientOptions clientOptions = lambdaSinkConfig.getClientOptions(); - if (clientOptions == null) { - clientOptions = new ClientOptions(); + public static final String NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS = "recordsSuccessfullySentToLambda"; + public static final String NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED = "recordsFailedToSentLambda"; + public static final String NUMBER_OF_SUCCESSFUL_REQUESTS_TO_LAMBDA = "numberOfRequestsSucceeded"; + public static final String NUMBER_OF_FAILED_REQUESTS_TO_LAMBDA = "numberOfRequestsFailed"; + public static final String LAMBDA_LATENCY_METRIC = "lambdaFunctionLatency"; + public static final String REQUEST_PAYLOAD_SIZE = "requestPayloadSize"; + public static final String RESPONSE_PAYLOAD_SIZE = "responsePayloadSize"; + + private static final Logger LOG = LoggerFactory.getLogger(LambdaSink.class); + private static final String BUCKET = "bucket"; + private static final String KEY_PATH = "key_path_prefix"; + private final Counter numberOfRecordsSuccessCounter; + private final Counter numberOfRecordsFailedCounter; + private final Counter numberOfRequestsSuccessCounter; + private final Counter numberOfRequestsFailedCounter; + private final LambdaSinkConfig lambdaSinkConfig; + private final ExpressionEvaluator expressionEvaluator; + private final LambdaAsyncClient lambdaAsyncClient; + private final DistributionSummary responsePayloadMetric; + private final Timer lambdaLatencyMetric; + private final DistributionSummary requestPayloadMetric; + private final PluginSetting pluginSetting; + private final OutputCodecContext outputCodecContext; + private volatile boolean sinkInitialized; + private DlqPushHandler dlqPushHandler = null; + + @DataPrepperPluginConstructor + public LambdaSink(final PluginSetting pluginSetting, + final LambdaSinkConfig lambdaSinkConfig, + final PluginFactory pluginFactory, + final SinkContext sinkContext, + final AwsCredentialsSupplier awsCredentialsSupplier, + final ExpressionEvaluator expressionEvaluator + ) { + super(pluginSetting); + this.pluginSetting = pluginSetting; + sinkInitialized = Boolean.FALSE; + this.lambdaSinkConfig = lambdaSinkConfig; + this.expressionEvaluator = expressionEvaluator; + this.outputCodecContext = OutputCodecContext.fromSinkContext(sinkContext); + + this.numberOfRecordsSuccessCounter = pluginMetrics.counter( + NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS); + this.numberOfRecordsFailedCounter = pluginMetrics.counter( + NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED); + this.numberOfRequestsSuccessCounter = pluginMetrics.counter( + NUMBER_OF_SUCCESSFUL_REQUESTS_TO_LAMBDA); + this.numberOfRequestsFailedCounter = pluginMetrics.counter( + NUMBER_OF_FAILED_REQUESTS_TO_LAMBDA); + this.lambdaLatencyMetric = pluginMetrics.timer(LAMBDA_LATENCY_METRIC); + this.requestPayloadMetric = pluginMetrics.summary(REQUEST_PAYLOAD_SIZE); + this.responsePayloadMetric = pluginMetrics.summary(RESPONSE_PAYLOAD_SIZE); + ClientOptions clientOptions = lambdaSinkConfig.getClientOptions(); + if (clientOptions == null) { + clientOptions = new ClientOptions(); + } + this.lambdaAsyncClient = LambdaClientFactory.createAsyncLambdaClient( + lambdaSinkConfig.getAwsAuthenticationOptions(), + awsCredentialsSupplier, + clientOptions + ); + if (lambdaSinkConfig.getDlqPluginSetting() != null) { + this.dlqPushHandler = new DlqPushHandler(pluginFactory, + String.valueOf(lambdaSinkConfig.getDlqPluginSetting().get(BUCKET)), + lambdaSinkConfig.getDlqStsRoleARN() + , lambdaSinkConfig.getDlqStsRegion(), + String.valueOf(lambdaSinkConfig.getDlqPluginSetting().get(KEY_PATH))); + } + } - this.lambdaAsyncClient = LambdaClientFactory.createAsyncLambdaClient( - lambdaSinkConfig.getAwsAuthenticationOptions(), - awsCredentialsSupplier, - clientOptions - ); - if (lambdaSinkConfig.getDlqPluginSetting() != null) { - this.dlqPushHandler = new DlqPushHandler(pluginFactory, - String.valueOf(lambdaSinkConfig.getDlqPluginSetting().get(BUCKET)), - lambdaSinkConfig.getDlqStsRoleARN() - , lambdaSinkConfig.getDlqStsRegion(), - String.valueOf(lambdaSinkConfig.getDlqPluginSetting().get(KEY_PATH))); + + @Override + public boolean isReady() { + return sinkInitialized; + } + + @Override + public void doInitialize() { + try { + doInitializeInternal(); + } catch (InvalidPluginConfigurationException e) { + LOG.error("Invalid plugin configuration, Hence failed to initialize s3-sink plugin."); + this.shutdown(); + throw e; + } catch (Exception e) { + LOG.error("Failed to initialize lambda plugin."); + this.shutdown(); + throw e; + } } - } - - @Override - public boolean isReady() { - return sinkInitialized; - } - - @Override - public void doInitialize() { - try { - doInitializeInternal(); - } catch (InvalidPluginConfigurationException e) { - LOG.error("Invalid plugin configuration, Hence failed to initialize s3-sink plugin."); - this.shutdown(); - throw e; - } catch (Exception e) { - LOG.error("Failed to initialize lambda plugin."); - this.shutdown(); - throw e; + private void doInitializeInternal() { + sinkInitialized = Boolean.TRUE; } - } - private void doInitializeInternal() { - sinkInitialized = Boolean.TRUE; - } + /** + * @param records Records to be output + */ + @Override + public void doOutput(final Collection> records) { + if (records.isEmpty()) { + return; + } - /** - * @param records Records to be output - */ - @Override - public void doOutput(final Collection> records) { + Map> bufferToFutureMap = new HashMap<>(); + try { + //Result from lambda is not currently processes. + bufferToFutureMap = LambdaCommonHandler.sendRecords( + records, + lambdaSinkConfig, + lambdaAsyncClient, + outputCodecContext); + } catch (Exception e) { + LOG.error("Exception while processing records ", e); + //TODO: introduce DLQ handler here before releasing the records + releaseEventHandlesPerBatch(false, records); + } - if (records.isEmpty()) { - return; + for (Map.Entry> entry : bufferToFutureMap.entrySet()) { + CompletableFuture future = entry.getValue(); + Buffer inputBuffer = entry.getKey(); + try { + InvokeResponse response = future.join(); + Duration latency = inputBuffer.stopLatencyWatch(); + lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS); + requestPayloadMetric.record(inputBuffer.getPayloadRequestSize()); + if (!isSuccess(response)) { + String errorMessage = String.format("Lambda invoke failed with status code %s error %s ", + response.statusCode(), response.payload().asUtf8String()); + throw new RuntimeException(errorMessage); + } + + releaseEventHandlesPerBatch(true, inputBuffer.getRecords()); + numberOfRecordsSuccessCounter.increment(inputBuffer.getEventCount()); + numberOfRequestsSuccessCounter.increment(); + if (response.payload() != null) { + responsePayloadMetric.record(response.payload().asByteArray().length); + } + + } catch (Exception e) { + LOG.error(NOISY, e.getMessage(), e); + numberOfRecordsFailedCounter.increment(inputBuffer.getEventCount()); + numberOfRequestsFailedCounter.increment(); + handleFailure(new RuntimeException("failed"), inputBuffer); + } + } } - //Result from lambda is not currently processes. - Map> bufferToFutureMap = LambdaCommonHandler.sendRecords( - records, - lambdaSinkConfig, - lambdaAsyncClient, - outputCodecContext); - - for (Map.Entry> entry : bufferToFutureMap.entrySet()) { - CompletableFuture future = entry.getValue(); - Buffer inputBuffer = entry.getKey(); - try { - InvokeResponse response = future.join(); - Duration latency = inputBuffer.stopLatencyWatch(); - lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS); - requestPayloadMetric.record(inputBuffer.getPayloadRequestSize()); - if (isSuccess(response)) { - releaseEventHandlesPerBatch(true, inputBuffer); - numberOfRecordsSuccessCounter.increment(inputBuffer.getEventCount()); - numberOfRequestsSuccessCounter.increment(); - if (response.payload() != null) { - responsePayloadMetric.record(response.payload().asByteArray().length); - } - continue; - } else { - LOG.error("Lambda invoke failed with error {} ", response.statusCode()); - handleFailure(new RuntimeException("failed"), inputBuffer); + + void handleFailure(Throwable throwable, Buffer flushedBuffer) { + try { + numberOfRecordsFailedCounter.increment(flushedBuffer.getEventCount()); + SdkBytes payload = flushedBuffer.getPayload(); + if (dlqPushHandler != null) { + dlqPushHandler.perform(pluginSetting, + new LambdaSinkFailedDlqData(payload, throwable.getMessage(), 0)); + releaseEventHandlesPerBatch(true, flushedBuffer.getRecords()); + } else { + releaseEventHandlesPerBatch(false, flushedBuffer.getRecords()); + } + } catch (Exception ex) { + LOG.error("Exception occurred during error handling"); + releaseEventHandlesPerBatch(false, flushedBuffer.getRecords()); } - } catch (Exception e) { - LOG.error("Exception from Lambda invocation ", e); - numberOfRecordsFailedCounter.increment(inputBuffer.getEventCount()); - numberOfRequestsFailedCounter.increment(); - handleFailure(new RuntimeException("failed"), inputBuffer); - } - } - } - - - void handleFailure(Throwable throwable, Buffer flushedBuffer) { - try { - if (flushedBuffer.getEventCount() > 0) { - numberOfRecordsFailedCounter.increment(flushedBuffer.getEventCount()); - } - - SdkBytes payload = flushedBuffer.getPayload(); - if (dlqPushHandler != null) { - dlqPushHandler.perform(pluginSetting, - new LambdaSinkFailedDlqData(payload, throwable.getMessage(), 0)); - releaseEventHandlesPerBatch(true, flushedBuffer); - } else { - releaseEventHandlesPerBatch(false, flushedBuffer); - } - } catch (Exception ex) { - LOG.error("Exception occured during error handling"); } - } - - /* - * Release events per batch - */ - private void releaseEventHandlesPerBatch(boolean success, Buffer flushedBuffer) { - List> records = flushedBuffer.getRecords(); - for (Record record : records) { - Event event = record.getData(); - if (event != null) { - EventHandle eventHandle = event.getEventHandle(); - if (eventHandle != null) { - eventHandle.release(success); + + /* + * Release events per batch + */ + private void releaseEventHandlesPerBatch(boolean success, Collection> records) { + for (Record record : records) { + Event event = record.getData(); + if (event != null) { + EventHandle eventHandle = event.getEventHandle(); + if (eventHandle != null) { + eventHandle.release(success); + } + } } - } } - } } diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandlerTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandlerTest.java index 44be8f5dd1..284f463868 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandlerTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandlerTest.java @@ -1,22 +1,9 @@ package org.opensearch.dataprepper.plugins.lambda.common; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.atLeastOnce; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CompletableFuture; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoSettings; @@ -25,114 +12,116 @@ import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.model.sink.OutputCodecContext; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; -import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; -import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; -import org.opensearch.dataprepper.plugins.lambda.common.config.LambdaCommonConfig; -import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions; +import org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessorConfig; import software.amazon.awssdk.services.lambda.LambdaAsyncClient; import software.amazon.awssdk.services.lambda.model.InvokeRequest; import software.amazon.awssdk.services.lambda.model.InvokeResponse; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.dataprepper.plugins.lambda.utils.LambdaTestSetupUtil.createLambdaConfigurationFromYaml; +import static org.opensearch.dataprepper.plugins.lambda.utils.LambdaTestSetupUtil.getSampleEventRecords; + @ExtendWith(MockitoExtension.class) @MockitoSettings(strictness = Strictness.LENIENT) class LambdaCommonHandlerTest { - @Mock - private LambdaAsyncClient lambdaAsyncClient; - - @Mock - private LambdaCommonConfig config; - - @Mock - private BatchOptions batchOptions; - - @Mock - private OutputCodecContext outputCodecContext; - - @Test - void testCheckStatusCode() { - InvokeResponse successResponse = InvokeResponse.builder().statusCode(200).build(); - InvokeResponse failureResponse = InvokeResponse.builder().statusCode(400).build(); - - assertTrue(LambdaCommonHandler.isSuccess(successResponse)); - assertFalse(LambdaCommonHandler.isSuccess(failureResponse)); - } - - @Test - void testWaitForFutures() { - List> futureList = new ArrayList<>(); - CompletableFuture future1 = new CompletableFuture<>(); - CompletableFuture future2 = new CompletableFuture<>(); - futureList.add(future1); - futureList.add(future2); - - // Simulate completion of futures - future1.complete(InvokeResponse.builder().build()); - future2.complete(InvokeResponse.builder().build()); - - LambdaCommonHandler.waitForFutures(futureList); - - assertFalse(futureList.isEmpty()); - } - - @Test - void testSendRecords() { - when(config.getBatchOptions()).thenReturn(batchOptions); - when(batchOptions.getThresholdOptions()).thenReturn(mock(ThresholdOptions.class)); - when(batchOptions.getKeyName()).thenReturn("testKey"); - when(config.getFunctionName()).thenReturn("testFunction"); - when(config.getInvocationType()).thenReturn(InvocationType.REQUEST_RESPONSE); - when(lambdaAsyncClient.invoke(any(InvokeRequest.class))) - .thenReturn( - CompletableFuture.completedFuture(InvokeResponse.builder().statusCode(200).build())); - - Event mockEvent = mock(Event.class); - when(mockEvent.toMap()).thenReturn(Collections.singletonMap("testKey", "testValue")); - List> records = Collections.singletonList(new Record<>(mockEvent)); - - Map> bufferCompletableFutureMap = LambdaCommonHandler.sendRecords( - records, config, lambdaAsyncClient, - outputCodecContext); - - assertNotNull(bufferCompletableFutureMap); - verify(lambdaAsyncClient, atLeastOnce()).invoke(any(InvokeRequest.class)); - } - - @Test - void testSendRecordsWithNullKeyName() { - when(config.getBatchOptions()).thenReturn(batchOptions); - when(batchOptions.getThresholdOptions()).thenReturn(mock(ThresholdOptions.class)); - when(batchOptions.getKeyName()).thenReturn(null); - when(config.getInvocationType()).thenReturn(InvocationType.REQUEST_RESPONSE); - when(config.getFunctionName()).thenReturn("testFunction"); - - Event mockEvent = mock(Event.class); - when(mockEvent.toMap()).thenReturn(Collections.singletonMap("testKey", "testValue")); - List> records = Collections.singletonList(new Record<>(mockEvent)); - - assertThrows(NullPointerException.class, () -> - LambdaCommonHandler.sendRecords(records, config, lambdaAsyncClient, outputCodecContext) - ); - } - - @Test - void testSendRecordsWithFailure() { - when(config.getBatchOptions()).thenReturn(batchOptions); - when(batchOptions.getThresholdOptions()).thenReturn(mock(ThresholdOptions.class)); - when(batchOptions.getKeyName()).thenReturn("testKey"); - when(config.getFunctionName()).thenReturn("testFunction"); - when(config.getInvocationType()).thenReturn(InvocationType.REQUEST_RESPONSE); - when(lambdaAsyncClient.invoke(any(InvokeRequest.class))) - .thenReturn(CompletableFuture.failedFuture(new RuntimeException("Test exception"))); - - List> records = new ArrayList<>(); - records.add(new Record<>(mock(Event.class))); - - Map> bufferCompletableFutureMap = LambdaCommonHandler.sendRecords( - records, config, lambdaAsyncClient, - outputCodecContext); - - assertNotNull(bufferCompletableFutureMap); - verify(lambdaAsyncClient, atLeastOnce()).invoke(any(InvokeRequest.class)); - } + @Mock + private LambdaAsyncClient lambdaAsyncClient; + + @Mock + private OutputCodecContext outputCodecContext; + + @Test + void testCheckStatusCode() { + InvokeResponse successResponse = InvokeResponse.builder().statusCode(200).build(); + InvokeResponse failureResponse = InvokeResponse.builder().statusCode(400).build(); + + assertTrue(LambdaCommonHandler.isSuccess(successResponse)); + assertFalse(LambdaCommonHandler.isSuccess(failureResponse)); + } + + @Test + void testWaitForFutures() { + List> futureList = new ArrayList<>(); + CompletableFuture future1 = new CompletableFuture<>(); + CompletableFuture future2 = new CompletableFuture<>(); + futureList.add(future1); + futureList.add(future2); + + // Simulate completion of futures + future1.complete(InvokeResponse.builder().build()); + future2.complete(InvokeResponse.builder().build()); + + LambdaCommonHandler.waitForFutures(futureList); + + assertFalse(futureList.isEmpty()); + } + + @ParameterizedTest + @ValueSource(strings = {"lambda-processor-success-config.yaml"}) + void testSendRecords(String configFilePath) { + LambdaProcessorConfig lambdaConfiguration = createLambdaConfigurationFromYaml(configFilePath); + when(lambdaAsyncClient.invoke(any(InvokeRequest.class))) + .thenReturn( + CompletableFuture.completedFuture(InvokeResponse.builder().statusCode(200).build())); + + int oneRandomCount = (int) (Math.random() * 1000); + List> records = getSampleEventRecords(oneRandomCount); + + Map> bufferCompletableFutureMap = LambdaCommonHandler.sendRecords( + records, lambdaConfiguration, lambdaAsyncClient, + outputCodecContext); + + assertNotNull(bufferCompletableFutureMap); + int batchSize = lambdaConfiguration.getBatchOptions().getThresholdOptions().getEventCount(); + int bufferBatchCount = (int) Math.ceil((1.0 * oneRandomCount) / batchSize); + assertEquals(bufferBatchCount, + bufferCompletableFutureMap.size()); + verify(lambdaAsyncClient, atLeastOnce()).invoke(any(InvokeRequest.class)); + } + + @ParameterizedTest + @ValueSource(strings = {"lambda-processor-null-key-name.yaml"}) + void testSendRecordsWithNullKeyName(String configFilePath) { + LambdaProcessorConfig lambdaConfiguration = createLambdaConfigurationFromYaml(configFilePath); + + Event mockEvent = mock(Event.class); + when(mockEvent.toMap()).thenReturn(Collections.singletonMap("testKey", "testValue")); + List> records = Collections.singletonList(new Record<>(mockEvent)); + + assertThrows(NullPointerException.class, () -> + LambdaCommonHandler.sendRecords(records, lambdaConfiguration, lambdaAsyncClient, outputCodecContext) + ); + } + + @ParameterizedTest + @ValueSource(strings = {"lambda-processor-success-config.yaml"}) + void testSendRecordsWithFailure(String configFilePath) { + LambdaProcessorConfig lambdaConfiguration = createLambdaConfigurationFromYaml(configFilePath); + when(lambdaAsyncClient.invoke(any(InvokeRequest.class))) + .thenReturn(CompletableFuture.failedFuture(new RuntimeException("Test exception"))); + + List> records = new ArrayList<>(); + records.add(new Record<>(mock(Event.class))); + + assertThrows(RuntimeException.class, () -> LambdaCommonHandler.sendRecords( + records, lambdaConfiguration, lambdaAsyncClient, + outputCodecContext)); + verify(lambdaAsyncClient, atLeastOnce()).invoke(any(InvokeRequest.class)); + } } diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/AggregateResponseEventHandlingStrategyTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/AggregateResponseEventHandlingStrategyTest.java index 691e72ee5c..87e78d3551 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/AggregateResponseEventHandlingStrategyTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/AggregateResponseEventHandlingStrategyTest.java @@ -5,30 +5,28 @@ package org.opensearch.dataprepper.plugins.lambda.processor; -import static org.junit.jupiter.api.Assertions.assertEquals; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import org.mockito.MockitoAnnotations; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; import org.opensearch.dataprepper.model.event.DefaultEventHandle; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + public class AggregateResponseEventHandlingStrategyTest { - @Mock - private Buffer flushedBuffer; @Mock private AcknowledgementSet acknowledgementSet; @@ -46,7 +44,6 @@ public class AggregateResponseEventHandlingStrategyTest { private Event parsedEvent2; private List> originalRecords; - private List> resultRecords; private AggregateResponseEventHandlingStrategy aggregateResponseEventHandlingStrategy; @@ -57,7 +54,6 @@ public void setUp() { // Set up original records list with a mock original event originalRecords = new ArrayList<>(); - resultRecords = new ArrayList<>(); originalRecords.add(new Record<>(originalEvent)); // Mock event handle and acknowledgement set @@ -71,7 +67,7 @@ public void testHandleEvents_AddsParsedEventsToResultRecords() { List parsedEvents = Arrays.asList(parsedEvent1, parsedEvent2); // Act - aggregateResponseEventHandlingStrategy.handleEvents(parsedEvents, originalRecords, resultRecords, flushedBuffer); + List> resultRecords = aggregateResponseEventHandlingStrategy.handleEvents(parsedEvents, originalRecords); // Assert assertEquals(2, resultRecords.size()); @@ -92,7 +88,7 @@ public void testHandleEvents_NoAcknowledgementSet_DoesNotThrowException() { when(eventHandle.getAcknowledgementSet()).thenReturn(null); // Act - aggregateResponseEventHandlingStrategy.handleEvents(parsedEvents, originalRecords, resultRecords, flushedBuffer); + List> resultRecords = aggregateResponseEventHandlingStrategy.handleEvents(parsedEvents, originalRecords); // Assert assertEquals(2, resultRecords.size()); @@ -109,7 +105,7 @@ public void testHandleEvents_EmptyParsedEvents_DoesNotAddToResultRecords() { List parsedEvents = new ArrayList<>(); // Act - aggregateResponseEventHandlingStrategy.handleEvents(parsedEvents, originalRecords, resultRecords, flushedBuffer); + List> resultRecords = aggregateResponseEventHandlingStrategy.handleEvents(parsedEvents, originalRecords); // Assert assertEquals(0, resultRecords.size()); diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java index baa8989d22..caed598787 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java @@ -4,42 +4,15 @@ */ package org.opensearch.dataprepper.plugins.lambda.processor; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyDouble; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.opensearch.dataprepper.plugins.lambda.sink.LambdaSinkTest.getSampleRecord; -import static org.opensearch.dataprepper.plugins.lambda.utils.LambdaTestSetupUtil.createLambdaConfigurationFromYaml; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.Timer; -import java.io.InputStream; -import java.lang.reflect.Field; -import java.time.Duration; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.concurrent.CompletableFuture; -import java.util.function.Consumer; - -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -47,28 +20,59 @@ import org.mockito.quality.Strictness; import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.expression.ExpressionEvaluator; -import org.opensearch.dataprepper.model.configuration.PluginSetting; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; import org.opensearch.dataprepper.model.codec.InputCodec; +import org.opensearch.dataprepper.model.configuration.PluginSetting; import org.opensearch.dataprepper.model.event.DefaultEventHandle; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.event.EventMetadata; import org.opensearch.dataprepper.model.plugin.PluginFactory; import org.opensearch.dataprepper.model.record.Record; -import org.opensearch.dataprepper.model.types.ByteCount; -import org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler; +import org.opensearch.dataprepper.plugins.codec.json.JsonInputCodec; +import org.opensearch.dataprepper.plugins.codec.json.JsonInputCodecConfig; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; +import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBuffer; import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions; import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; -import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.lambda.LambdaAsyncClient; import software.amazon.awssdk.services.lambda.model.InvokeRequest; import software.amazon.awssdk.services.lambda.model.InvokeResponse; +import java.io.InputStream; +import java.lang.reflect.Field; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.params.provider.Arguments.arguments; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyDouble; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.dataprepper.plugins.lambda.utils.LambdaTestSetupUtil.createLambdaConfigurationFromYaml; +import static org.opensearch.dataprepper.plugins.lambda.utils.LambdaTestSetupUtil.getSampleEventRecords; +import static org.opensearch.dataprepper.plugins.lambda.utils.LambdaTestSetupUtil.getSampleRecord; + @MockitoSettings(strictness = Strictness.LENIENT) public class LambdaProcessorTest { @@ -85,18 +89,12 @@ public class LambdaProcessorTest { @Mock private PluginSetting pluginSetting; - @Mock - private LambdaProcessorConfig lambdaProcessorConfig; - @Mock private AwsCredentialsSupplier awsCredentialsSupplier; @Mock private ExpressionEvaluator expressionEvaluator; - @Mock - private LambdaCommonHandler lambdaCommonHandler; - @Mock private InputCodec responseCodec; @@ -128,8 +126,14 @@ public class LambdaProcessorTest { @Mock private LambdaAsyncClient lambdaAsyncClient; - // The class under test - private LambdaProcessor lambdaProcessor; + + private static Stream getLambdaResponseConversionSamples() { + return Stream.of( + arguments("lambda-processor-success-config.yaml", null), + arguments("lambda-processor-success-config.yaml", SdkBytes.fromByteArray("{}".getBytes())), + arguments("lambda-processor-success-config.yaml", SdkBytes.fromByteArray("[]".getBytes())) + ); + } @BeforeEach public void setUp() throws Exception { @@ -150,38 +154,21 @@ public void setUp() throws Exception { when(pluginMetrics.timer(anyString())).thenReturn(lambdaLatencyMetric); */ - ClientOptions clientOptions = new ClientOptions(); - when(lambdaProcessorConfig.getClientOptions()).thenReturn(clientOptions); - when(lambdaProcessorConfig.getFunctionName()).thenReturn("test-function"); - when(lambdaProcessorConfig.getAwsAuthenticationOptions()).thenReturn( - awsAuthenticationOptions); - when(lambdaProcessorConfig.getInvocationType()).thenReturn(InvocationType.REQUEST_RESPONSE); - BatchOptions batchOptions = mock(BatchOptions.class); - ThresholdOptions thresholdOptions = mock(ThresholdOptions.class); - when(lambdaProcessorConfig.getBatchOptions()).thenReturn(batchOptions); - when(lambdaProcessorConfig.getWhenCondition()).thenReturn(null); - when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(false); // Mock AWS Authentication Options when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.US_EAST_1); when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn("testRole"); // Mock BatchOptions and ThresholdOptions - when(batchOptions.getThresholdOptions()).thenReturn(thresholdOptions); - when(thresholdOptions.getEventCount()).thenReturn(10); - when(thresholdOptions.getMaximumSize()).thenReturn(ByteCount.parse("6mb")); - when(thresholdOptions.getEventCollectTimeOut()).thenReturn(Duration.ofSeconds(30)); - when(batchOptions.getKeyName()).thenReturn("key"); // Mock PluginFactory to return the mocked responseCodec when(pluginFactory.loadPlugin(eq(InputCodec.class), any(PluginSetting.class))).thenReturn( - responseCodec); + new JsonInputCodec(new JsonInputCodecConfig())); // Instantiate the LambdaProcessor manually - lambdaProcessor = new LambdaProcessor(pluginFactory, pluginSetting, lambdaProcessorConfig, - awsCredentialsSupplier, expressionEvaluator); - populatePrivateFields(); + +// populatePrivateFields(); //setPrivateField(lambdaProcessor, "pluginMetrics", pluginMetrics); // Mock InvokeResponse when(invokeResponse.payload()).thenReturn(SdkBytes.fromUtf8String("[{\"key\":\"value\"}]")); @@ -189,24 +176,24 @@ public void setUp() throws Exception { // Mock the invoke method to return a completed future CompletableFuture invokeFuture = CompletableFuture.completedFuture( - invokeResponse); + invokeResponse); when(lambdaAsyncClient.invoke(any(InvokeRequest.class))).thenReturn(invokeFuture); // Mock Response Codec parse method - doNothing().when(responseCodec).parse(any(InputStream.class), any(Consumer.class)); +// doNothing().when(responseCodec).parse(any(InputStream.class), any(Consumer.class)); } - private void populatePrivateFields() throws Exception { + private void populatePrivateFields(LambdaProcessor lambdaProcessor) throws Exception { // Use reflection to set the private fields setPrivateField(lambdaProcessor, "numberOfRecordsSuccessCounter", - numberOfRecordsSuccessCounter); + numberOfRecordsSuccessCounter); setPrivateField(lambdaProcessor, "numberOfRequestsSuccessCounter", - numberOfRequestsSuccessCounter); + numberOfRequestsSuccessCounter); setPrivateField(lambdaProcessor, "numberOfRecordsFailedCounter", - numberOfRecordsFailedCounter); + numberOfRecordsFailedCounter); setPrivateField(lambdaProcessor, "numberOfRequestsFailedCounter", - numberOfRequestsFailedCounter); + numberOfRequestsFailedCounter); setPrivateField(lambdaProcessor, "lambdaLatencyMetric", lambdaLatencyMetric); setPrivateField(lambdaProcessor, "responsePayloadMetric", responsePayloadMetric); setPrivateField(lambdaProcessor, "requestPayloadMetric", requestPayloadMetric); @@ -215,7 +202,7 @@ private void populatePrivateFields() throws Exception { // Helper method to set private fields via reflection private void setPrivateField(Object targetObject, String fieldName, Object value) - throws Exception { + throws Exception { Field field = targetObject.getClass().getDeclaredField(fieldName); field.setAccessible(true); field.set(targetObject, value); @@ -239,10 +226,10 @@ public void testProcessorDefaults() { ClientOptions clientOptions = defaultConfig.getClientOptions(); assertNotNull(clientOptions); assertEquals(ClientOptions.DEFAULT_CONNECTION_RETRIES, - clientOptions.getMaxConnectionRetries()); + clientOptions.getMaxConnectionRetries()); assertEquals(ClientOptions.DEFAULT_API_TIMEOUT, clientOptions.getApiCallTimeout()); assertEquals(ClientOptions.DEFAULT_CONNECTION_TIMEOUT, - clientOptions.getConnectionTimeout()); + clientOptions.getConnectionTimeout()); assertEquals(ClientOptions.DEFAULT_MAXIMUM_CONCURRENCY, clientOptions.getMaxConcurrency()); assertEquals(ClientOptions.DEFAULT_BASE_DELAY, clientOptions.getBaseDelay()); assertEquals(ClientOptions.DEFAULT_MAX_BACKOFF, clientOptions.getMaxBackoff()); @@ -258,15 +245,18 @@ public void testDoExecute_WithExceptionInSendRecords(String configFileName) thro // Arrange List> records = Collections.singletonList(getSampleRecord()); LambdaProcessorConfig lambdaProcessorConfig = createLambdaConfigurationFromYaml( - configFileName); - lambdaProcessor = new LambdaProcessor(pluginFactory, pluginSetting, - lambdaProcessorConfig, - awsCredentialsSupplier, expressionEvaluator); - populatePrivateFields(); + configFileName); + LambdaProcessor lambdaProcessor = new LambdaProcessor(pluginFactory, pluginSetting, lambdaProcessorConfig, + awsCredentialsSupplier, expressionEvaluator); + populatePrivateFields(lambdaProcessor); when(lambdaAsyncClient.invoke(any(InvokeRequest.class))).thenThrow(new RuntimeException("test exception")); - Assertions.assertThrows( RuntimeException.class, () -> lambdaProcessor.doExecute(records)); + Collection> outputRecords = lambdaProcessor.doExecute(records); + assertNotNull(outputRecords); + assertEquals(1, outputRecords.size()); + Record record = outputRecords.iterator().next(); + assertEquals("[lambda_failure]", record.getData().getMetadata().getTags().toString()); } @@ -277,10 +267,9 @@ public void testDoExecute_WithExceptionDuringProcessing(String configFileName) t List> records = Collections.singletonList(getSampleRecord()); LambdaProcessorConfig lambdaProcessorConfig = createLambdaConfigurationFromYaml( configFileName); - lambdaProcessor = new LambdaProcessor(pluginFactory, pluginSetting, - lambdaProcessorConfig, + LambdaProcessor lambdaProcessor = new LambdaProcessor(pluginFactory, pluginSetting, lambdaProcessorConfig, awsCredentialsSupplier, expressionEvaluator); - populatePrivateFields(); + populatePrivateFields(lambdaProcessor); CompletableFuture invokeFuture = CompletableFuture.completedFuture( @@ -288,54 +277,74 @@ public void testDoExecute_WithExceptionDuringProcessing(String configFileName) t when(lambdaAsyncClient.invoke(any(InvokeRequest.class))).thenReturn(invokeFuture); when(invokeResponse.payload()).thenThrow(new RuntimeException("Test Exception")); + Collection> result = lambdaProcessor.doExecute(records); // Assert assertEquals(1, result.size()); verify(numberOfRecordsFailedCounter, times(1)).increment(1.0); } - @Test - public void testDoExecute_WithEmptyResponse() throws Exception { + @ParameterizedTest + @ValueSource(strings = {"lambda-processor-success-config.yaml"}) + public void testDoExecute_UnableParseResponse(String configFileName) throws Exception { // Arrange - Event event = mock(Event.class); - Record record = new Record<>(event); - List> records = Collections.singletonList(record); + int recordCount = (int) (Math.random() * 100); + List> records = getSampleEventRecords(recordCount); + InvokeResponse invokeResponse = mock(InvokeResponse.class); - // Mock Buffer to return empty payload - when(invokeResponse.payload()).thenReturn(SdkBytes.fromUtf8String("")); + // Mock Buffer to return empty payload + when(invokeResponse.payload()).thenReturn(SdkBytes.fromUtf8String("[{\"key\": \"value\"}]")); + LambdaProcessorConfig lambdaProcessorConfig = createLambdaConfigurationFromYaml(configFileName); + LambdaProcessor lambdaProcessor = new LambdaProcessor(pluginFactory, pluginSetting, lambdaProcessorConfig, + awsCredentialsSupplier, expressionEvaluator); + populatePrivateFields(lambdaProcessor); // Act Collection> result = lambdaProcessor.doExecute(records); // Assert - assertEquals(0, result.size(), "Result should be empty due to empty Lambda response."); - verify(numberOfRecordsSuccessCounter, times(1)).increment(1.0); + assertEquals(recordCount, result.size(), "Result should be empty due to empty Lambda response."); + verify(numberOfRecordsSuccessCounter, times(0)).increment(1.0); + verify(numberOfRecordsFailedCounter, times(1)).increment(recordCount); } - @Test - public void testDoExecute_WithNullResponse() throws Exception { + @ParameterizedTest + @ValueSource(strings = {"lambda-processor-success-config.yaml"}) + public void testDoExecute_WithNullResponse_get_original_records_with_tags(String configFileName) throws Exception { // Arrange - Event event = mock(Event.class); - Record record = new Record<>(event); - List> records = Collections.singletonList(record); + + List> records = getSampleEventRecords(1); // Mock Buffer to return null payload when(invokeResponse.payload()).thenReturn(null); - + LambdaProcessorConfig lambdaProcessorConfig = createLambdaConfigurationFromYaml(configFileName); + LambdaProcessor lambdaProcessor = new LambdaProcessor(pluginFactory, pluginSetting, lambdaProcessorConfig, + awsCredentialsSupplier, expressionEvaluator); + populatePrivateFields(lambdaProcessor); // Act Collection> result = lambdaProcessor.doExecute(records); // Assert - assertEquals(0, result.size(), "Result should be empty due to null Lambda response."); - verify(numberOfRecordsSuccessCounter, times(1)).increment(1.0); + assertEquals(1, result.size(), "Result should be empty due to null Lambda response."); + for (Record record : result) { + EventMetadata metadata = record.getData().getMetadata(); + assertEquals(1, metadata.getTags().size()); + assertEquals("[lambda_failure]", metadata.getTags().toString()); + } + verify(numberOfRecordsFailedCounter, times(1)).increment(1.0); + verify(numberOfRecordsSuccessCounter, times(0)).increment(0); } - @Test - public void testDoExecute_WithEmptyRecords() { + @ParameterizedTest + @ValueSource(strings = {"lambda-processor-success-config.yaml"}) + public void testDoExecute_WithEmptyRecords(String configFileName) { // Arrange Collection> records = Collections.emptyList(); // Act + LambdaProcessorConfig lambdaProcessorConfig = createLambdaConfigurationFromYaml(configFileName); + LambdaProcessor lambdaProcessor = new LambdaProcessor(pluginFactory, pluginSetting, lambdaProcessorConfig, + awsCredentialsSupplier, expressionEvaluator); Collection> result = lambdaProcessor.doExecute(records); // Assert @@ -344,8 +353,9 @@ public void testDoExecute_WithEmptyRecords() { verify(numberOfRecordsFailedCounter, never()).increment(anyDouble()); } - @Test - public void testDoExecute_WhenConditionFalse() { + @ParameterizedTest + @ValueSource(strings = {"lambda-processor-when-condition-config.yaml"}) + public void testDoExecute_WhenConditionFalse(String configFileName) { // Arrange Event event = mock(Event.class); DefaultEventHandle eventHandle = mock(DefaultEventHandle.class); @@ -355,70 +365,58 @@ public void testDoExecute_WhenConditionFalse() { Record record = new Record<>(event); Collection> records = Collections.singletonList(record); + // Instantiate the LambdaProcessor manually + LambdaProcessorConfig lambdaProcessorConfig = createLambdaConfigurationFromYaml(configFileName); + LambdaProcessor lambdaProcessor = new LambdaProcessor(pluginFactory, pluginSetting, lambdaProcessorConfig, + awsCredentialsSupplier, expressionEvaluator); // Mock condition evaluator to return false when(expressionEvaluator.evaluateConditional(anyString(), eq(event))).thenReturn(false); - when(lambdaProcessorConfig.getWhenCondition()).thenReturn("some_condition"); - - // Instantiate the LambdaProcessor manually - lambdaProcessor = new LambdaProcessor(pluginFactory, pluginSetting, lambdaProcessorConfig, - awsCredentialsSupplier, expressionEvaluator); // Act Collection> result = lambdaProcessor.doExecute(records); // Assert assertEquals(1, result.size(), - "Result should contain one record as the condition is false."); + "Result should contain one record as the condition is false."); verify(numberOfRecordsSuccessCounter, never()).increment(anyDouble()); verify(numberOfRecordsFailedCounter, never()).increment(anyDouble()); } - @Test - public void testDoExecute_SuccessfulProcessing() throws Exception { + @ParameterizedTest + @ValueSource(strings = {"lambda-processor-success-config.yaml"}) + public void testDoExecute_SuccessfulProcessing(String configFileName) throws Exception { // Arrange - Event event = mock(Event.class); - DefaultEventHandle eventHandle = mock(DefaultEventHandle.class); - AcknowledgementSet acknowledgementSet = mock(AcknowledgementSet.class); - when(event.getEventHandle()).thenReturn(eventHandle); - when(eventHandle.getAcknowledgementSet()).thenReturn(acknowledgementSet); - Record record = new Record<>(event); - Collection> records = Collections.singletonList(record); + int recordCount = 1; + Collection> records = getSampleEventRecords(recordCount); // Mock the invoke method to return a completed future - CompletableFuture invokeFuture = CompletableFuture.completedFuture( - invokeResponse); + InvokeResponse invokeResponse = InvokeResponse.builder() + .payload(SdkBytes.fromUtf8String("[{\"key1\": \"value1\", \"key2\": \"value2\"}]")) + .statusCode(200) + .build(); + CompletableFuture invokeFuture = CompletableFuture.completedFuture(invokeResponse); when(lambdaAsyncClient.invoke(any(InvokeRequest.class))).thenReturn(invokeFuture); - // Mock Buffer behavior - when(bufferMock.getEventCount()).thenReturn(0).thenReturn(1).thenReturn(0); - when(bufferMock.getRecords()).thenReturn(Collections.singletonList(record)); - - doAnswer(invocation -> { - invocation.getArgument(0); - @SuppressWarnings("unchecked") - Consumer> consumer = invocation.getArgument(1); - - // Simulate parsing by providing a mocked event - Event parsedEvent = mock(Event.class); - Record parsedRecord = new Record<>(parsedEvent); - consumer.accept(parsedRecord); - - return null; - }).when(responseCodec).parse(any(InputStream.class), any(Consumer.class)); // Act + // Instantiate the LambdaProcessor manually + LambdaProcessorConfig lambdaProcessorConfig = createLambdaConfigurationFromYaml(configFileName); + LambdaProcessor lambdaProcessor = new LambdaProcessor(pluginFactory, pluginSetting, lambdaProcessorConfig, + awsCredentialsSupplier, expressionEvaluator); + populatePrivateFields(lambdaProcessor); Collection> result = lambdaProcessor.doExecute(records); // Assert - assertEquals(1, result.size(), "Result should contain one record."); + assertEquals(recordCount, result.size(), "Result should contain one record."); verify(numberOfRecordsSuccessCounter, times(1)).increment(1.0); } - @Test - public void testConvertLambdaResponseToEvent_WithEqualEventCounts_SuccessfulProcessing() - throws Exception { + @ParameterizedTest + @ValueSource(strings = {"lambda-processor-success-config.yaml"}) + public void testConvertLambdaResponseToEvent_WithEqualEventCounts_SuccessfulProcessing(String configFileName) + throws Exception { // Arrange - when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(true); + // Mock LambdaResponse with a valid payload containing two events String payloadString = "[{\"key\":\"value1\"}, {\"key\":\"value2\"}]"; @@ -454,8 +452,11 @@ public void testConvertLambdaResponseToEvent_WithEqualEventCounts_SuccessfulProc when(bufferMock.getEventCount()).thenReturn(2); // Act + LambdaProcessorConfig lambdaProcessorConfig = createLambdaConfigurationFromYaml(configFileName); + LambdaProcessor lambdaProcessor = new LambdaProcessor(pluginFactory, pluginSetting, lambdaProcessorConfig, + awsCredentialsSupplier, expressionEvaluator); List> resultRecords = lambdaProcessor.convertLambdaResponseToEvent(bufferMock, - invokeResponse); + invokeResponse); // Assert assertEquals(2, resultRecords.size(), "ResultRecords should contain two records."); @@ -464,12 +465,13 @@ public void testConvertLambdaResponseToEvent_WithEqualEventCounts_SuccessfulProc verify(originalEvent2, never()).getMetadata(); } - @Test - public void testConvertLambdaResponseToEvent_WithUnequalEventCounts_SuccessfulProcessing() - throws Exception { + @ParameterizedTest + @ValueSource(strings = {"lambda-processor-unequal-success-config.yaml"}) + public void testConvertLambdaResponseToEvent_WithUnequalEventCounts_SuccessfulProcessing(String configFileName) + throws Exception { // Arrange // Set responseEventsMatch to false - when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(false); + // Mock LambdaResponse with a valid payload containing three events String payloadString = "[{\"key\":\"value1\"}, {\"key\":\"value2\"}, {\"key\":\"value3\"}]"; @@ -517,11 +519,38 @@ public void testConvertLambdaResponseToEvent_WithUnequalEventCounts_SuccessfulPr when(bufferMock.getEventCount()).thenReturn(2); // Act - List> resultRecords = lambdaProcessor.convertLambdaResponseToEvent(bufferMock, - invokeResponse); + LambdaProcessorConfig lambdaProcessorConfig = createLambdaConfigurationFromYaml(configFileName); + LambdaProcessor lambdaProcessor = new LambdaProcessor(pluginFactory, pluginSetting, lambdaProcessorConfig, + awsCredentialsSupplier, expressionEvaluator); + List> resultRecords = lambdaProcessor.convertLambdaResponseToEvent(bufferMock, invokeResponse); // Assert // Verify that three records are added to the result assertEquals(3, resultRecords.size(), "ResultRecords should contain three records."); } + @ParameterizedTest + @MethodSource("getLambdaResponseConversionSamples") + public void testConvertLambdaResponseToEvent_ExpectException_when_request_response_do_not_match(String configFile, SdkBytes lambdaReponse) { + // Arrange + // Set responseEventsMatch to false + LambdaProcessorConfig lambdaProcessorConfig = createLambdaConfigurationFromYaml(configFile); + LambdaProcessor localLambdaProcessor = new LambdaProcessor(pluginFactory, pluginSetting, + lambdaProcessorConfig, awsCredentialsSupplier, expressionEvaluator); + InvokeResponse invokeResponse = mock(InvokeResponse.class); + // Mock LambdaResponse with a valid payload containing three events + when(invokeResponse.payload()).thenReturn(lambdaReponse); + when(invokeResponse.statusCode()).thenReturn(200); // Success status code + + int randomCount = (int) (Math.random() * 10); + List> originalRecords = getSampleEventRecords(randomCount); + Buffer buffer = new InMemoryBuffer(lambdaProcessorConfig.getBatchOptions().getKeyName()); + for (Record originalRecord : originalRecords) { + buffer.addRecord(originalRecord); + } + // Act + assertThrows(RuntimeException.class, () -> localLambdaProcessor.convertLambdaResponseToEvent(buffer, invokeResponse), + "For Strict mode request and response size from lambda should match"); + + } + } diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/StrictResponseEventHandlingStrategyTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/StrictResponseEventHandlingStrategyTest.java index 9b3fc4b35b..3962f264b9 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/StrictResponseEventHandlingStrategyTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/StrictResponseEventHandlingStrategyTest.java @@ -5,126 +5,65 @@ package org.opensearch.dataprepper.plugins.lambda.processor; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.mockito.Mock; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.anyString; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import org.mockito.MockitoAnnotations; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; import java.util.List; -import java.util.Map; - -public class StrictResponseEventHandlingStrategyTest { - - @Mock - private Buffer flushedBuffer; - - @Mock - private Event originalEvent; - @Mock - private Event parsedEvent1; - - @Mock - private Event parsedEvent2; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.opensearch.dataprepper.plugins.lambda.utils.LambdaTestSetupUtil.getSampleEventRecords; +import static org.opensearch.dataprepper.plugins.lambda.utils.LambdaTestSetupUtil.getSampleParsedEvents; - private List> originalRecords; - private List> resultRecords; - private StrictResponseEventHandlingStrategy strictResponseEventHandlingStrategy; +public class StrictResponseEventHandlingStrategyTest { - @BeforeEach - public void setUp() { - MockitoAnnotations.openMocks(this); - strictResponseEventHandlingStrategy = new StrictResponseEventHandlingStrategy(); - // Set up original records list with mock original events - originalRecords = new ArrayList<>(); - resultRecords = new ArrayList<>(); - originalRecords.add(new Record<>(originalEvent)); - originalRecords.add(new Record<>(originalEvent)); - } + private final StrictResponseEventHandlingStrategy strictResponseEventHandlingStrategy = new StrictResponseEventHandlingStrategy(); @Test public void testHandleEvents_WithMatchingEventCount_ShouldUpdateOriginalEvents() { - // Arrange - List parsedEvents = Arrays.asList(parsedEvent1, parsedEvent2); - - // Mocking flushedBuffer to return an event count of 2 - when(flushedBuffer.getEventCount()).thenReturn(2); - // Mocking parsedEvent1 and parsedEvent2 to return sample data - Map responseData1 = new HashMap<>(); - responseData1.put("key1", "value1"); - when(parsedEvent1.toMap()).thenReturn(responseData1); + // Arrange + int oneRandomCount = (int) (Math.random() * 100); + List parsedEvents = getSampleParsedEvents(oneRandomCount); + List> originalRecords = getSampleEventRecords(oneRandomCount); - Map responseData2 = new HashMap<>(); - responseData2.put("key2", "value2"); - when(parsedEvent2.toMap()).thenReturn(responseData2); + // Before Test, make sure that they are not the same + for (int i = 0; i < oneRandomCount; i++) { + assertNotEquals(originalRecords.get(i).getData(), parsedEvents.get(i)); + } // Act - strictResponseEventHandlingStrategy.handleEvents(parsedEvents, originalRecords, resultRecords, flushedBuffer); - - // Assert - // Verify original event is cleared and then updated with response data - verify(originalEvent, times(2)).clear(); - verify(originalEvent).put("key1", "value1"); - verify(originalEvent).put("key2", "value2"); - - // Ensure resultRecords contains the original records - assertEquals(2, resultRecords.size()); - assertEquals(originalRecords.get(0), resultRecords.get(0)); - assertEquals(originalRecords.get(1), resultRecords.get(1)); + List> resultRecords = strictResponseEventHandlingStrategy.handleEvents(parsedEvents, originalRecords); + + // Before Test, make sure that they are not the same + for (int i = 0; i < oneRandomCount; i++) { + assertNotEquals(resultRecords.get(i).getData(), parsedEvents.get(i)); + } } @Test public void testHandleEvents_WithMismatchingEventCount_ShouldThrowException() { // Arrange - List parsedEvents = Arrays.asList(parsedEvent1, parsedEvent2); - - // Mocking flushedBuffer to return an event count of 3 (mismatch) - when(flushedBuffer.getEventCount()).thenReturn(3); + int firstRandomCount = (int) (Math.random() * 10); + List parsedEvents = getSampleParsedEvents(firstRandomCount); + List> originalRecords = getSampleEventRecords(firstRandomCount + 10); // Act & Assert RuntimeException exception = assertThrows(RuntimeException.class, () -> - strictResponseEventHandlingStrategy.handleEvents(parsedEvents, originalRecords, resultRecords, flushedBuffer) + strictResponseEventHandlingStrategy.handleEvents(parsedEvents, originalRecords) ); assertEquals("Response Processing Mode is configured as Strict mode but behavior is aggregate mode. Event count mismatch.", exception.getMessage()); - - // Verify original events were not cleared or modified - verify(originalEvent, never()).clear(); - verify(originalEvent, never()).put(anyString(), any()); } @Test public void testHandleEvents_EmptyParsedEvents_ShouldNotThrowException() { - // Arrange - List parsedEvents = new ArrayList<>(); - - // Mocking flushedBuffer to return an event count of 0 - when(flushedBuffer.getEventCount()).thenReturn(0); - // Act - strictResponseEventHandlingStrategy.handleEvents(parsedEvents, originalRecords, resultRecords, flushedBuffer); - - // Assert - // Verify no events were cleared or modified - verify(originalEvent, never()).clear(); - verify(originalEvent, never()).put(anyString(), any()); - + List> resultRecords = strictResponseEventHandlingStrategy.handleEvents(new ArrayList<>(), new ArrayList<>()); // Ensure resultRecords is empty assertEquals(0, resultRecords.size()); } diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkTest.java index 185e781b0b..377a385fb8 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkTest.java @@ -5,22 +5,8 @@ package org.opensearch.dataprepper.plugins.lambda.sink; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.anyString; -import static org.mockito.Mockito.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Timer; -import java.lang.reflect.Field; -import java.time.Duration; -import java.util.Collections; -import java.util.UUID; -import java.util.concurrent.atomic.AtomicLong; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; @@ -30,14 +16,7 @@ import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.codec.OutputCodec; import org.opensearch.dataprepper.model.configuration.PluginSetting; -import org.opensearch.dataprepper.model.event.Event; -import org.opensearch.dataprepper.model.event.EventHandle; -import org.opensearch.dataprepper.model.event.EventMetadata; -import org.opensearch.dataprepper.model.event.JacksonEvent; import org.opensearch.dataprepper.model.plugin.PluginFactory; -import org.opensearch.dataprepper.model.record.Record; -import org.opensearch.dataprepper.model.record.RecordMetadata; -import org.opensearch.dataprepper.model.sink.OutputCodecContext; import org.opensearch.dataprepper.model.sink.SinkContext; import org.opensearch.dataprepper.model.types.ByteCount; import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodec; @@ -51,106 +30,106 @@ import org.opensearch.dataprepper.plugins.lambda.sink.dlq.DlqPushHandler; import org.opensearch.dataprepper.plugins.lambda.sink.dlq.LambdaSinkFailedDlqData; import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.lambda.LambdaAsyncClient; -import software.amazon.awssdk.services.lambda.model.InvokeResponse; + +import java.lang.reflect.Field; +import java.time.Duration; +import java.util.Collections; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicLong; + +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.dataprepper.plugins.lambda.utils.LambdaTestSetupUtil.getSampleRecord; public class LambdaSinkTest { - @Mock - SinkContext sinkContext; - @Mock - private LambdaAsyncClient lambdaAsyncClient; - @Mock - private LambdaSinkConfig lambdaSinkConfig; - @Mock - private PluginMetrics pluginMetrics; - @Mock - private PluginFactory pluginFactory; - - private PluginSetting pluginSetting; - @Mock - private OutputCodecContext codecContext; - @Mock - private AwsCredentialsSupplier awsCredentialsSupplier; - @Mock - private DlqPushHandler dlqPushHandler; - @Mock - private ExpressionEvaluator expressionEvaluator; - @Mock - private Counter numberOfRecordsSuccessCounter; - @Mock - private Counter numberOfRecordsFailedCounter; - @Mock - private Timer lambdaLatencyMetric; - @Mock - private OutputCodec requestCodec; - @Mock - private Buffer currentBufferPerBatch; - @Mock - private Event event; - @Mock - private EventHandle eventHandle; - @Mock - private EventMetadata eventMetadata; - @Mock - private InvokeResponse invokeResponse; - - private LambdaSink lambdaSink; - - @Mock - private AwsAuthenticationOptions awsAuthenticationOptions; - - public static Record getSampleRecord() { - Event event = JacksonEvent.fromMessage(UUID.randomUUID().toString()); - return new Record<>(event, RecordMetadata.defaultMetadata()); - } - - @BeforeEach - public void setUp() { - MockitoAnnotations.openMocks(this); - - // Mock PluginMetrics counters and timers - when(pluginMetrics.counter("lambdaSinkObjectsEventsSucceeded")).thenReturn( - numberOfRecordsSuccessCounter); - when(pluginMetrics.counter("lambdaSinkObjectsEventsFailed")).thenReturn( - numberOfRecordsFailedCounter); - when(pluginMetrics.timer(anyString())).thenReturn(lambdaLatencyMetric); - when(pluginMetrics.gauge(anyString(), any(AtomicLong.class))).thenReturn(new AtomicLong()); - - // Mock lambdaSinkConfig - when(lambdaSinkConfig.getFunctionName()).thenReturn("test-function"); - when(lambdaSinkConfig.getInvocationType()).thenReturn(InvocationType.EVENT); - - // Mock BatchOptions and ThresholdOptions - BatchOptions batchOptions = mock(BatchOptions.class); - ThresholdOptions thresholdOptions = mock(ThresholdOptions.class); - when(batchOptions.getKeyName()).thenReturn("test"); - when(lambdaSinkConfig.getBatchOptions()).thenReturn(batchOptions); - when(batchOptions.getThresholdOptions()).thenReturn(thresholdOptions); - when(thresholdOptions.getEventCount()).thenReturn(10); - when(thresholdOptions.getMaximumSize()).thenReturn(ByteCount.parse("1mb")); - when(thresholdOptions.getEventCollectTimeOut()).thenReturn(Duration.ofSeconds(1)); - - // Mock JsonOutputCodec - requestCodec = mock(JsonOutputCodec.class); - when(pluginFactory.loadPlugin(eq(OutputCodec.class), any(PluginSetting.class))).thenReturn( - requestCodec); - - // Initialize bufferFactory and buffer - currentBufferPerBatch = mock(Buffer.class); - when(currentBufferPerBatch.getEventCount()).thenReturn(0); - when(lambdaSinkConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions); - when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.of("us-east-1")); - this.pluginSetting = new PluginSetting("aws_lambda", Collections.emptyMap()); - this.pluginSetting.setPipelineName(UUID.randomUUID().toString()); - this.awsAuthenticationOptions = new AwsAuthenticationOptions(); - - ClientOptions clientOptions = new ClientOptions(); - when(lambdaSinkConfig.getClientOptions()).thenReturn(clientOptions); - - this.lambdaSink = new LambdaSink(pluginSetting, lambdaSinkConfig, pluginFactory, sinkContext, - awsCredentialsSupplier, expressionEvaluator); - } + @Mock + SinkContext sinkContext; + + @Mock + private LambdaSinkConfig lambdaSinkConfig; + @Mock + private PluginMetrics pluginMetrics; + @Mock + private PluginFactory pluginFactory; + + private PluginSetting pluginSetting; + + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + @Mock + private DlqPushHandler dlqPushHandler; + @Mock + private ExpressionEvaluator expressionEvaluator; + @Mock + private Counter numberOfRecordsSuccessCounter; + @Mock + private Counter numberOfRecordsFailedCounter; + @Mock + private Timer lambdaLatencyMetric; + @Mock + private OutputCodec requestCodec; + @Mock + private Buffer currentBufferPerBatch; + + private LambdaSink lambdaSink; + + @Mock + private AwsAuthenticationOptions awsAuthenticationOptions; + + + @BeforeEach + public void setUp() { + MockitoAnnotations.openMocks(this); + + // Mock PluginMetrics counters and timers + when(pluginMetrics.counter("lambdaSinkObjectsEventsSucceeded")).thenReturn( + numberOfRecordsSuccessCounter); + when(pluginMetrics.counter("lambdaSinkObjectsEventsFailed")).thenReturn( + numberOfRecordsFailedCounter); + when(pluginMetrics.timer(anyString())).thenReturn(lambdaLatencyMetric); + when(pluginMetrics.gauge(anyString(), any(AtomicLong.class))).thenReturn(new AtomicLong()); + + // Mock lambdaSinkConfig + when(lambdaSinkConfig.getFunctionName()).thenReturn("test-function"); + when(lambdaSinkConfig.getInvocationType()).thenReturn(InvocationType.EVENT); + + // Mock BatchOptions and ThresholdOptions + BatchOptions batchOptions = mock(BatchOptions.class); + ThresholdOptions thresholdOptions = mock(ThresholdOptions.class); + when(batchOptions.getKeyName()).thenReturn("test"); + when(lambdaSinkConfig.getBatchOptions()).thenReturn(batchOptions); + when(batchOptions.getThresholdOptions()).thenReturn(thresholdOptions); + when(thresholdOptions.getEventCount()).thenReturn(10); + when(thresholdOptions.getMaximumSize()).thenReturn(ByteCount.parse("1mb")); + when(thresholdOptions.getEventCollectTimeOut()).thenReturn(Duration.ofSeconds(1)); + + // Mock JsonOutputCodec + requestCodec = mock(JsonOutputCodec.class); + when(pluginFactory.loadPlugin(eq(OutputCodec.class), any(PluginSetting.class))).thenReturn( + requestCodec); + + // Initialize bufferFactory and buffer + currentBufferPerBatch = mock(Buffer.class); + when(currentBufferPerBatch.getEventCount()).thenReturn(0); + when(lambdaSinkConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions); + when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.of("us-east-1")); + this.pluginSetting = new PluginSetting("aws_lambda", Collections.emptyMap()); + this.pluginSetting.setPipelineName(UUID.randomUUID().toString()); + this.awsAuthenticationOptions = new AwsAuthenticationOptions(); + + ClientOptions clientOptions = new ClientOptions(); + when(lambdaSinkConfig.getClientOptions()).thenReturn(clientOptions); + + this.lambdaSink = new LambdaSink(pluginSetting, lambdaSinkConfig, pluginFactory, sinkContext, + awsCredentialsSupplier, expressionEvaluator); + } /* @Test @@ -183,40 +162,40 @@ public void testOutput_SuccessfulProcessing() throws Exception { */ - // Helper method to set private fields via reflection - private void setPrivateField(Object targetObject, String fieldName, Object value) { - try { - Field field = targetObject.getClass().getDeclaredField(fieldName); - field.setAccessible(true); - field.set(targetObject, value); - } catch (Exception e) { - throw new RuntimeException(e); + // Helper method to set private fields via reflection + private void setPrivateField(Object targetObject, String fieldName, Object value) { + try { + Field field = targetObject.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + field.set(targetObject, value); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Test + public void testHandleFailure_WithDlq() { + Throwable throwable = new RuntimeException("Test Exception"); + Buffer buffer = new InMemoryBuffer(UUID.randomUUID().toString()); + buffer.addRecord(getSampleRecord()); + setPrivateField(lambdaSink, "dlqPushHandler", dlqPushHandler); + setPrivateField(lambdaSink, "numberOfRecordsFailedCounter", numberOfRecordsFailedCounter); + lambdaSink.handleFailure(throwable, buffer); + verify(numberOfRecordsFailedCounter, times(1)).increment(1.0); + verify(dlqPushHandler, times(1)).perform(eq(pluginSetting), any(LambdaSinkFailedDlqData.class)); + } + + @Test + public void testHandleFailure_WithoutDlq() { + Throwable throwable = new RuntimeException("Test Exception"); + Buffer buffer = new InMemoryBuffer(UUID.randomUUID().toString()); + buffer.addRecord(getSampleRecord()); + when(lambdaSinkConfig.getDlqPluginSetting()).thenReturn(null); + setPrivateField(lambdaSink, "numberOfRecordsFailedCounter", numberOfRecordsFailedCounter); + lambdaSink.handleFailure(throwable, buffer); + verify(numberOfRecordsFailedCounter, times(1)).increment(1); + verify(dlqPushHandler, never()).perform(any(), any()); } - } - - @Test - public void testHandleFailure_WithDlq() { - Throwable throwable = new RuntimeException("Test Exception"); - Buffer buffer = new InMemoryBuffer(UUID.randomUUID().toString()); - buffer.addRecord(getSampleRecord()); - setPrivateField(lambdaSink, "dlqPushHandler", dlqPushHandler); - setPrivateField(lambdaSink, "numberOfRecordsFailedCounter", numberOfRecordsFailedCounter); - lambdaSink.handleFailure(throwable, buffer); - verify(numberOfRecordsFailedCounter, times(1)).increment(1.0); - verify(dlqPushHandler, times(1)).perform(eq(pluginSetting), any(LambdaSinkFailedDlqData.class)); - } - - @Test - public void testHandleFailure_WithoutDlq() { - Throwable throwable = new RuntimeException("Test Exception"); - Buffer buffer = new InMemoryBuffer(UUID.randomUUID().toString()); - buffer.addRecord(getSampleRecord()); - when(lambdaSinkConfig.getDlqPluginSetting()).thenReturn(null); - setPrivateField(lambdaSink, "numberOfRecordsFailedCounter", numberOfRecordsFailedCounter); - lambdaSink.handleFailure(throwable, buffer); - verify(numberOfRecordsFailedCounter, times(1)).increment(1); - verify(dlqPushHandler, never()).perform(any(), any()); - } /* @Test diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaTestSetupUtil.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaTestSetupUtil.java index 93551acda6..1453c30e46 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaTestSetupUtil.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaTestSetupUtil.java @@ -4,39 +4,71 @@ import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; import com.fasterxml.jackson.dataformat.yaml.YAMLGenerator; import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; -import java.io.IOException; -import java.io.InputStream; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.event.JacksonEvent; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.record.RecordMetadata; import org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessorConfig; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; + public class LambdaTestSetupUtil { - private static final Logger log = LoggerFactory.getLogger(LambdaTestSetupUtil.class); + private static final Logger log = LoggerFactory.getLogger(LambdaTestSetupUtil.class); + + public static ObjectMapper getObjectMapper() { + return new ObjectMapper( + new YAMLFactory().enable(YAMLGenerator.Feature.USE_PLATFORM_LINE_BREAKS)).registerModule( + new JavaTimeModule()); + } - public static ObjectMapper getObjectMapper() { - return new ObjectMapper( - new YAMLFactory().enable(YAMLGenerator.Feature.USE_PLATFORM_LINE_BREAKS)).registerModule( - new JavaTimeModule()); - } + private static InputStream getResourceAsStream(String resourceName) { + InputStream inputStream = Thread.currentThread().getContextClassLoader() + .getResourceAsStream(resourceName); + if (inputStream == null) { + inputStream = LambdaTestSetupUtil.class.getResourceAsStream("/" + resourceName); + } + return inputStream; + } - private static InputStream getResourceAsStream(String resourceName) { - InputStream inputStream = Thread.currentThread().getContextClassLoader() - .getResourceAsStream(resourceName); - if (inputStream == null) { - inputStream = LambdaTestSetupUtil.class.getResourceAsStream("/" + resourceName); + public static LambdaProcessorConfig createLambdaConfigurationFromYaml(String fileName) { + ObjectMapper objectMapper = getObjectMapper(); + try (InputStream inputStream = getResourceAsStream(fileName)) { + return objectMapper.readValue(inputStream, LambdaProcessorConfig.class); + } catch (IOException ex) { + log.error("Failed to parse pipeline Yaml", ex); + throw new RuntimeException(ex); + } } - return inputStream; - } - - public static LambdaProcessorConfig createLambdaConfigurationFromYaml(String fileName) { - ObjectMapper objectMapper = getObjectMapper(); - try (InputStream inputStream = getResourceAsStream(fileName)) { - return objectMapper.readValue(inputStream, LambdaProcessorConfig.class); - } catch (IOException ex) { - log.error("Failed to parse pipeline Yaml", ex); - throw new RuntimeException(ex); + + public static Record getSampleRecord() { + return new Record<>(getSampleEvent(), RecordMetadata.defaultMetadata()); + } + + public static Event getSampleEvent() { + return JacksonEvent.fromMessage(UUID.randomUUID().toString()); + } + + public static List> getSampleEventRecords(int count) { + List> originalRecords = new ArrayList<>(); + for (int i = 0; i < count; i++) { + originalRecords.add(getSampleRecord()); + } + return originalRecords; + } + + public static List getSampleParsedEvents(int count) { + List originalRecords = new ArrayList<>(); + for (int i = 0; i < count; i++) { + originalRecords.add(getSampleEvent()); + } + return originalRecords; } - } } diff --git a/data-prepper-plugins/aws-lambda/src/test/resources/lambda-processor-null-key-name.yaml b/data-prepper-plugins/aws-lambda/src/test/resources/lambda-processor-null-key-name.yaml new file mode 100644 index 0000000000..e04db57448 --- /dev/null +++ b/data-prepper-plugins/aws-lambda/src/test/resources/lambda-processor-null-key-name.yaml @@ -0,0 +1,13 @@ +function_name: "lambdaProcessorTest" +response_events_match: true +tags_on_failure: [ "lambda_failure" ] +batch: + key_name: + threshold: + event_count: 100 + maximum_size: 1mb + event_collect_timeout: 335 +aws: + region: "us-east-1" + sts_role_arn: "arn:aws:iam::1234567890:role/sample-pipeine-role" + diff --git a/data-prepper-plugins/aws-lambda/src/test/resources/lambda-processor-unequal-success-config.yaml b/data-prepper-plugins/aws-lambda/src/test/resources/lambda-processor-unequal-success-config.yaml new file mode 100644 index 0000000000..c1f843ce78 --- /dev/null +++ b/data-prepper-plugins/aws-lambda/src/test/resources/lambda-processor-unequal-success-config.yaml @@ -0,0 +1,13 @@ +function_name: "lambdaProcessorTest" +response_events_match: false +tags_on_failure: [ "lambda_failure" ] +batch: + key_name: "osi_key" + threshold: + event_count: 100 + maximum_size: 1mb + event_collect_timeout: 335 +aws: + region: "us-east-1" + sts_role_arn: "arn:aws:iam::1234567890:role/sample-pipeine-role" + diff --git a/data-prepper-plugins/aws-lambda/src/test/resources/lambda-processor-when-condition-config.yaml b/data-prepper-plugins/aws-lambda/src/test/resources/lambda-processor-when-condition-config.yaml new file mode 100644 index 0000000000..318ab7363d --- /dev/null +++ b/data-prepper-plugins/aws-lambda/src/test/resources/lambda-processor-when-condition-config.yaml @@ -0,0 +1,14 @@ +function_name: "lambdaProcessorTest" +response_events_match: true +tags_on_failure: [ "lambda_failure" ] +lambda_when: "some-condition" +batch: + key_name: "osi_key" + threshold: + event_count: 100 + maximum_size: 1mb + event_collect_timeout: 335 +aws: + region: "us-east-1" + sts_role_arn: "arn:aws:iam::1234567890:role/sample-pipeine-role" +