diff --git a/build.gradle b/build.gradle
index 1d911966..8b1c2dbc 100644
--- a/build.gradle
+++ b/build.gradle
@@ -67,6 +67,7 @@ dependencies {
     api 'software.amazon.awssdk:athena'
     api 'software.amazon.awssdk:lambda'
     api 'software.amazon.awssdk:eventbridge'
+    api 'software.amazon.awssdk:kinesis'
 }
diff --git a/src/main/java/io/kestra/plugin/aws/kinesis/PutRecords.java b/src/main/java/io/kestra/plugin/aws/kinesis/PutRecords.java
new file mode 100644
index 00000000..0f5d92ec
--- /dev/null
+++ b/src/main/java/io/kestra/plugin/aws/kinesis/PutRecords.java
@@ -0,0 +1,267 @@
+package io.kestra.plugin.aws.kinesis;
+
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.base.Strings;
+import io.kestra.core.exceptions.IllegalVariableEvaluationException;
+import io.kestra.core.models.annotations.Example;
+import io.kestra.core.models.annotations.Plugin;
+import io.kestra.core.models.annotations.PluginProperty;
+import io.kestra.core.models.executions.metrics.Counter;
+import io.kestra.core.models.executions.metrics.Timer;
+import io.kestra.core.models.flows.State;
+import io.kestra.core.models.tasks.RunnableTask;
+import io.kestra.core.runners.RunContext;
+import io.kestra.core.serializers.FileSerde;
+import io.kestra.core.serializers.JacksonMapper;
+import io.kestra.plugin.aws.AbstractConnection;
+import io.kestra.plugin.aws.kinesis.model.Record;
+import io.reactivex.BackpressureStrategy;
+import io.reactivex.Flowable;
+import io.swagger.v3.oas.annotations.media.Schema;
+import lombok.*;
+import lombok.experimental.SuperBuilder;
+import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.services.kinesis.KinesisClient;
+import software.amazon.awssdk.services.kinesis.KinesisClientBuilder;
+import software.amazon.awssdk.services.kinesis.model.PutRecordsRequest;
+import software.amazon.awssdk.services.kinesis.model.PutRecordsRequestEntry;
+import software.amazon.awssdk.services.kinesis.model.PutRecordsResponse;
+
+import javax.validation.constraints.NotNull;
+import java.io.*;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.time.Duration;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+import static io.kestra.core.utils.Rethrow.throwFunction;
+
+@SuperBuilder
+@ToString
+@EqualsAndHashCode
+@Getter
+@NoArgsConstructor
+@Plugin(
+    examples = {
+        @Example(
+            title = "Send multiple records as maps to Amazon Kinesis Data Streams. Check the AWS API reference for the structure of the [PutRecordsRequestEntry](https://docs.aws.amazon.com/kinesis/latest/APIReference/API_PutRecordsRequestEntry.html) request payload.",
+            code = {
+                "streamName: \"mystream\"",
+                "records:",
+                "  - data: \"user sign-in event\"",
+                "    explicitHashKey: \"optional hash value overriding the partition key\"",
+                "  - data: \"user sign-out event\"",
+                "    partitionKey: \"user1\""
+            }
+        ),
+        @Example(
+            title = "Send multiple records from an internal storage ION file to Amazon Kinesis Data Streams.",
+            code = {
+                "streamName: \"mystream\"",
+                "records: kestra://myfile.ion"
+            }
+        )
+    }
+)
+@Schema(
+    title = "Send multiple records to Amazon Kinesis Data Streams."
+)
+public class PutRecords extends AbstractConnection implements RunnableTask<PutRecords.Output> {
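+    // ION mapper used to convert an inline `records` list into Record objects.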
+    private static final ObjectMapper MAPPER = JacksonMapper.ofIon()
+        .setSerializationInclusion(JsonInclude.Include.ALWAYS);
+
+    @PluginProperty(dynamic = false)
+    @NotNull
+    @Schema(
+        title = "Mark the task as failed when sending a record is unsuccessful.",
+        description = "If true, the task will fail when any record fails to be sent."
+    )
+    @Builder.Default
+    private boolean failOnUnsuccessfulRecords = true;
+
+    @PluginProperty(dynamic = true)
+    @Schema(
+        title = "The name of the stream to add records to.",
+        description = "Either `streamName` or `streamArn` must be provided."
+    )
+    private String streamName;
+
+    @PluginProperty(dynamic = true)
+    @Schema(
+        title = "The ARN of the stream to add records to.",
+        description = "Either `streamName` or `streamArn` must be provided."
+    )
+    private String streamArn;
+
+    @PluginProperty(dynamic = true)
+    @Schema(
+        title = "List of records, or an internal storage URI of a file that defines the records to be sent to Amazon Kinesis Data Streams.",
+        description = "A list of at least one record.",
+        anyOf = {String.class, Record[].class}
+    )
+    @NotNull
+    private Object records;
+
+    @Override
+    public Output run(RunContext runContext) throws Exception {
+        final long start = System.nanoTime();
+
+        List<Record> records = getRecordList(this.records, runContext);
+
+        PutRecordsResponse putRecordsResponse = putRecords(runContext, records);
+
+        // Fail if failOnUnsuccessfulRecords
+        if (failOnUnsuccessfulRecords && putRecordsResponse.failedRecordCount() > 0) {
+            var logger = runContext.logger();
+            logger.error("Response shows {} records failed: {}", putRecordsResponse.failedRecordCount(), putRecordsResponse);
+            throw new RuntimeException(String.format("Response shows %d records failed: %s", putRecordsResponse.failedRecordCount(), putRecordsResponse));
+        }
+
+        // Set metrics
+        runContext.metric(Timer.of("duration", Duration.ofNanos(System.nanoTime() - start)));
+        runContext.metric(Counter.of("failedRecordCount", putRecordsResponse.failedRecordCount()));
+        runContext.metric(Counter.of("successfulRecordCount", records.size() - putRecordsResponse.failedRecordCount()));
+        runContext.metric(Counter.of("recordCount", records.size()));
+
+        File tempFile = writeOutputFile(runContext, putRecordsResponse, records);
+        return Output.builder()
+            .uri(runContext.putTempFile(tempFile))
+            .failedRecordsCount(putRecordsResponse.failedRecordCount())
+            .recordCount(records.size())
+            .build();
+    }
+
+    private PutRecordsResponse putRecords(RunContext runContext, List<Record> records) throws IllegalVariableEvaluationException {
+        try (KinesisClient client = client(runContext)) {
+            PutRecordsRequest.Builder builder = PutRecordsRequest.builder();
+
+            if (!Strings.isNullOrEmpty(streamArn)) {
+                builder.streamARN(streamArn);
+            } else if (!Strings.isNullOrEmpty(streamName)) {
+                builder.streamName(streamName);
+            } else {
+                throw new IllegalArgumentException("Either streamName or streamArn has to be set.");
+            }
+
+            List<PutRecordsRequestEntry> putRecordsRequestEntryList = records.stream()
+                .map(throwFunction(record -> record.toPutRecordsRequestEntry(runContext)))
+                .collect(Collectors.toList());
+
+            builder.records(putRecordsRequestEntryList);
+            return client.putRecords(builder.build());
+        }
+    }
+
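+    // The `records` property accepts either a Kestra internal storage URI (String)
+    // pointing to an ION file of serialized records, or an inline list of records.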
+    private List<Record> getRecordList(Object records, RunContext runContext) throws IllegalVariableEvaluationException, URISyntaxException, IOException {
+        if (records instanceof String) {
+            URI from = new URI(runContext.render((String) records));
+            if (!from.getScheme().equals("kestra")) {
+                throw new IllegalArgumentException("Invalid records parameter, must be a Kestra internal storage URI, or a list of records.");
+            }
+            try (BufferedReader inputStream = new BufferedReader(new InputStreamReader(runContext.uriToInputStream(from)))) {
+                return Flowable.create(FileSerde.reader(inputStream, Record.class), BackpressureStrategy.BUFFER)
+                    .toList().blockingGet();
+            }
+        } else if (records instanceof List) {
+            return MAPPER.convertValue(records, new TypeReference<>() {
+            });
+        }
+
+        throw new IllegalVariableEvaluationException("Invalid records type '" + records.getClass() + "'");
+    }
+
+    private File writeOutputFile(RunContext runContext, PutRecordsResponse putRecordsResponse, List<Record> records) throws IOException {
+        // Create Output
+        File tempFile = runContext.tempFile(".ion").toFile();
+        try (var stream = new FileOutputStream(tempFile)) {
+            // The response preserves the request order, so zipping pairs each input
+            // record with its per-record result.
+            Flowable.fromIterable(records)
+                .zipWith(putRecordsResponse.records(), (record, response) -> OutputEntry.builder()
+                    .record(record)
+                    .sequenceNumber(response.sequenceNumber())
+                    .shardId(response.shardId())
+                    .errorCode(response.errorCode())
+                    .errorMessage(response.errorMessage())
+                    .build())
+                .blockingForEach(outputEntry -> FileSerde.write(stream, outputEntry));
+        }
+        return tempFile;
+    }
+
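+    // Builds the Kinesis client from the task credentials; region and endpoint
+    // are optional overrides (the tests use endpointOverride to target LocalStack).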
+    protected KinesisClient client(RunContext runContext) throws IllegalVariableEvaluationException {
+        KinesisClientBuilder builder = KinesisClient.builder()
+            .credentialsProvider(this.credentials(runContext));
+
+        if (this.region != null) {
+            builder.region(Region.of(runContext.render(this.region)));
+        }
+        if (this.endpointOverride != null) {
+            builder.endpointOverride(URI.create(runContext.render(this.endpointOverride)));
+        }
+
+        return builder.build();
+    }
+
+    @Builder
+    @Getter
+    public static class Output implements io.kestra.core.models.tasks.Output {
+
+        @Schema(
+            title = "The URI of the stored data.",
+            description = "The successfully and unsuccessfully ingested records. " +
+                "If the ingestion was successful, the output includes the record sequence number. " +
+                "Otherwise, the output provides the error code and error message for troubleshooting."
+        )
+        private URI uri;
+
+        @Schema(
+            title = "The number of failed records."
+        )
+        private int failedRecordsCount;
+
+        @Schema(
+            title = "The total number of records sent to Amazon Kinesis Data Streams."
+        )
+        private int recordCount;
+
+        @Override
+        public Optional<State.Type> finalState() {
+            return this.failedRecordsCount > 0 ? Optional.of(State.Type.WARNING) : io.kestra.core.models.tasks.Output.super.finalState();
+        }
+    }
+
+    @Builder
+    @Getter
+    public static class OutputEntry {
+        @Schema(
+            title = "The sequence number for an individual record result."
+        )
+        private final String sequenceNumber;
+
+        @Schema(
+            title = "The shard ID for an individual record result."
+        )
+        private final String shardId;
+
+        @Schema(
+            title = "The error code that indicates the failure."
+        )
+        private final String errorCode;
+
+        @Schema(
+            title = "The error message that explains the failure."
+        )
+        private final String errorMessage;
+
+        @Schema(
+            title = "The original record."
+        )
+        private final Record record;
+    }
+}
diff --git a/src/main/java/io/kestra/plugin/aws/kinesis/model/Record.java b/src/main/java/io/kestra/plugin/aws/kinesis/model/Record.java
new file mode 100644
index 00000000..f4304357
--- /dev/null
+++ b/src/main/java/io/kestra/plugin/aws/kinesis/model/Record.java
@@ -0,0 +1,59 @@
+package io.kestra.plugin.aws.kinesis.model;
+
+import com.fasterxml.jackson.annotation.JsonAlias;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.base.Strings;
+import io.kestra.core.exceptions.IllegalVariableEvaluationException;
+import io.kestra.core.models.annotations.PluginProperty;
+import io.kestra.core.runners.RunContext;
+import io.kestra.core.serializers.JacksonMapper;
+import io.swagger.v3.oas.annotations.media.Schema;
+import lombok.Builder;
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.extern.jackson.Jacksonized;
+import software.amazon.awssdk.core.SdkBytes;
+import software.amazon.awssdk.services.kinesis.model.PutRecordsRequestEntry;
+
+import javax.validation.constraints.NotNull;
+
+@Getter
+@Builder
+@EqualsAndHashCode
+@Jacksonized
+public class Record {
+    private static final ObjectMapper OBJECT_MAPPER = JacksonMapper.ofJson();
+
+    // @JsonAlias also accepts the AWS-style uppercase field names used in API payloads.
+    @Schema(title = "Determines which shard in the stream the data record is assigned to.")
+    @PluginProperty(dynamic = true)
+    @NotNull
+    @JsonAlias("PartitionKey")
+    private String partitionKey;
+
+    @Schema(title = "The optional hash value used to explicitly determine the shard the data record is assigned to, overriding the partition key hash.")
+    @PluginProperty(dynamic = true)
+    @JsonAlias("ExplicitHashKey")
+    private String explicitHashKey;
+
+    @Schema(title = "Free-form data blob to put into the record.")
+    @PluginProperty(dynamic = true)
+    @NotNull
+    @JsonAlias("Data")
+    private String data;
+
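+    // Renders the dynamic (templated) properties and maps this record to the SDK's
+    // PutRecordsRequestEntry; explicitHashKey is only set when it is non-empty.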
+    public PutRecordsRequestEntry toPutRecordsRequestEntry(RunContext runContext) throws IllegalVariableEvaluationException {
+        var partitionKey = runContext.render(this.partitionKey);
+        var explicitHashKey = runContext.render(this.explicitHashKey);
+        var data = runContext.render(this.data);
+        PutRecordsRequestEntry.Builder builder = PutRecordsRequestEntry.builder()
+            .data(SdkBytes.fromUtf8String(data))
+            .partitionKey(partitionKey);
+
+        if (!Strings.isNullOrEmpty(explicitHashKey)) {
+            builder.explicitHashKey(explicitHashKey);
+        }
+
+        return builder
+            .build();
+    }
+}
diff --git a/src/main/java/io/kestra/plugin/aws/kinesis/package-info.java b/src/main/java/io/kestra/plugin/aws/kinesis/package-info.java
new file mode 100644
index 00000000..3b195a2a
--- /dev/null
+++ b/src/main/java/io/kestra/plugin/aws/kinesis/package-info.java
@@ -0,0 +1,8 @@
+@PluginSubGroup(
+    description = "This sub-group of plugins contains tasks for using Amazon Kinesis.\n" +
+        "Amazon Kinesis is a family of services provided by Amazon Web Services (AWS) for processing and analyzing real-time streaming data at large scale.",
+    categories = {PluginSubGroup.PluginCategory.MESSAGING, PluginSubGroup.PluginCategory.CLOUD}
+)
+package io.kestra.plugin.aws.kinesis;
+
+import io.kestra.core.models.annotations.PluginSubGroup;
diff --git a/src/test/java/io/kestra/plugin/aws/kinesis/PutRecordsTest.java b/src/test/java/io/kestra/plugin/aws/kinesis/PutRecordsTest.java
new file mode 100644
index 00000000..182e742d
--- /dev/null
+++ b/src/test/java/io/kestra/plugin/aws/kinesis/PutRecordsTest.java
@@ -0,0 +1,269 @@
+package io.kestra.plugin.aws.kinesis;
+
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.kestra.core.exceptions.IllegalVariableEvaluationException;
+import io.kestra.core.runners.RunContext;
+import io.kestra.core.runners.RunContextFactory;
+import io.kestra.core.serializers.FileSerde;
+import io.kestra.core.serializers.JacksonMapper;
+import io.kestra.core.storages.StorageInterface;
+import io.kestra.plugin.aws.kinesis.model.Record;
+import io.micronaut.test.extensions.junit5.annotation.MicronautTest;
+import io.reactivex.BackpressureStrategy;
+import io.reactivex.Flowable;
+import jakarta.inject.Inject;
+import lombok.Builder;
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.extern.jackson.Jacksonized;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.testcontainers.containers.localstack.LocalStackContainer;
+import org.testcontainers.junit.jupiter.Testcontainers;
+import org.testcontainers.utility.DockerImageName;
+import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
+import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
+import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.services.kinesis.KinesisClient;
+import software.amazon.awssdk.services.kinesis.KinesisClientBuilder;
+import software.amazon.awssdk.services.kinesis.model.*;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.InputStreamReader;
+import java.net.URI;
+import java.util.List;
+
+import static io.kestra.core.utils.Rethrow.throwConsumer;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.*;
+
+@MicronautTest
+@Testcontainers
+class PutRecordsTest {
+    private static final ObjectMapper MAPPER = JacksonMapper.ofIon()
+        .setSerializationInclusion(JsonInclude.Include.ALWAYS);
+
+    protected static LocalStackContainer localstack;
+    @Inject
+    protected RunContextFactory runContextFactory;
+    @Inject
+    protected StorageInterface storageInterface;
+
+    @BeforeAll
+    static void startLocalstack() throws IllegalVariableEvaluationException, InterruptedException {
+        localstack = new LocalStackContainer(DockerImageName.parse("localstack/localstack:1.3.1"));
+        localstack.start();
+
+        KinesisClient client = client(localstack);
+        client.createStream(CreateStreamRequest.builder()
+            .streamName("streamName")
+            .streamModeDetails(StreamModeDetails.builder().streamMode(StreamMode.PROVISIONED).build())
+            .build());
+        // Poll until the stream becomes ACTIVE before any test runs.
+        DescribeStreamResponse stream = client.describeStream(DescribeStreamRequest.builder().streamName("streamName").build());
+        while (stream.streamDescription().streamStatus() != StreamStatus.ACTIVE) {
+            Thread.sleep(100);
+            stream = client.describeStream(DescribeStreamRequest.builder().streamName("streamName").build());
+        }
+    }
+
+    @AfterAll
+    static void stopLocalstack() {
+        if (localstack != null) {
+            localstack.stop();
+        }
+    }
+
+    private static List<PutRecords.OutputEntry> getOutputEntries(PutRecords put, RunContext runContext) throws Exception {
+        var output = put.run(runContext);
+        List<PutRecords.OutputEntry> outputEntries;
+        URI from = output.getUri();
+        if (!from.getScheme().equals("kestra")) {
+            throw new IllegalArgumentException("Invalid entries parameter, must be a Kestra internal storage URI, or a list of entries.");
+        }
+        try (BufferedReader inputStream = new BufferedReader(new InputStreamReader(runContext.uriToInputStream(from)))) {
+            outputEntries = Flowable.create(FileSerde.reader(inputStream, PutRecords.OutputEntry.class), BackpressureStrategy.BUFFER).toList().blockingGet();
+        }
+        return outputEntries;
+    }
+
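+    // Builds a KinesisClient pointed at the LocalStack container's endpoint,
+    // region, and static credentials.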
+    private static KinesisClient client(LocalStackContainer localstack) throws IllegalVariableEvaluationException {
+        KinesisClientBuilder builder = KinesisClient.builder()
+            .credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create(
+                localstack.getAccessKey(),
+                localstack.getSecretKey()
+            )))
+            .region(Region.of(localstack.getRegion()))
+            .endpointOverride(localstack.getEndpoint());
+
+        return builder.build();
+    }
+
+    @Test
+    void runMap() throws Exception {
+        var runContext = runContextFactory.of();
+
+        Record record = Record.builder()
+            .explicitHashKey("5")
+            .partitionKey("partitionKey")
+            .data("record")
+            .build();
+        Record record2 = Record.builder()
+            .partitionKey("partitionKey")
+            .data("record 2")
+            .build();
+        Record record3 = Record.builder()
+            .explicitHashKey("5")
+            .partitionKey("partitionKey")
+            .data("record 3")
+            .build();
+        var put = PutRecords.builder()
+            .endpointOverride(localstack.getEndpoint().toString())
+            .region(localstack.getRegion())
+            .accessKeyId(localstack.getAccessKey())
+            .secretKeyId(localstack.getSecretKey())
+            .streamName("streamName")
+            .records(List.of(record, record2, record3))
+            .build();
+
+        List<PutRecords.OutputEntry> outputEntries = getOutputEntries(put, runContext);
+
+        assertThat(outputEntries, hasSize(3));
+        assertThat(outputEntries.get(0).getSequenceNumber(), notNullValue());
+        assertThat(outputEntries.get(0).getErrorCode(), nullValue());
+        assertThat(outputEntries.get(0).getErrorMessage(), nullValue());
+        assertThat(outputEntries.get(0).getRecord(), equalTo(record));
+
+        assertThat(outputEntries.get(1).getSequenceNumber(), notNullValue());
+        assertThat(outputEntries.get(1).getErrorCode(), nullValue());
+        assertThat(outputEntries.get(1).getErrorMessage(), nullValue());
+        assertThat(outputEntries.get(1).getRecord(), equalTo(record2));
+
+        assertThat(outputEntries.get(2).getSequenceNumber(), notNullValue());
+        assertThat(outputEntries.get(2).getErrorCode(), nullValue());
+        assertThat(outputEntries.get(2).getErrorMessage(), nullValue());
+        assertThat(outputEntries.get(2).getRecord(), equalTo(record3));
+    }
+
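+    // Same scenario as runMap, but the records are first serialized to an ION file
+    // in internal storage and passed to the task as a storage URI.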
+    @Test
+    void runStorage() throws Exception {
+        var runContext = runContextFactory.of();
+
+        Record record = Record.builder()
+            .explicitHashKey("5")
+            .partitionKey("partitionKey")
+            .data("record")
+            .build();
+        Record record2 = Record.builder()
+            .partitionKey("partitionKey")
+            .data("record 2")
+            .build();
+        Record record3 = Record.builder()
+            .explicitHashKey("5")
+            .partitionKey("partitionKey")
+            .data("record 3")
+            .build();
+
+        File tempFile = runContext.tempFile(".ion").toFile();
+        try (var stream = new FileOutputStream(tempFile)) {
+            List.of(record, record2, record3).forEach(throwConsumer(r -> FileSerde.write(stream, r)));
+        }
+
+        var put = PutRecords.builder()
+            .endpointOverride(localstack.getEndpoint().toString())
+            .region(localstack.getRegion())
+            .accessKeyId(localstack.getAccessKey())
+            .secretKeyId(localstack.getSecretKey())
+            .records(runContext.putTempFile(tempFile).toString())
+            .streamName("streamName")
+            .build();
+
+        List<PutRecords.OutputEntry> outputEntries = getOutputEntries(put, runContext);
+
+        assertThat(outputEntries, hasSize(3));
+        assertThat(outputEntries.get(0).getSequenceNumber(), notNullValue());
+        assertThat(outputEntries.get(0).getErrorCode(), nullValue());
+        assertThat(outputEntries.get(0).getErrorMessage(), nullValue());
+        assertThat(outputEntries.get(0).getRecord(), equalTo(record));
+
+        assertThat(outputEntries.get(1).getSequenceNumber(), notNullValue());
+        assertThat(outputEntries.get(1).getErrorCode(), nullValue());
+        assertThat(outputEntries.get(1).getErrorMessage(), nullValue());
+        assertThat(outputEntries.get(1).getRecord(), equalTo(record2));
+
+        assertThat(outputEntries.get(2).getSequenceNumber(), notNullValue());
+        assertThat(outputEntries.get(2).getErrorCode(), nullValue());
+        assertThat(outputEntries.get(2).getErrorMessage(), nullValue());
+        assertThat(outputEntries.get(2).getRecord(), equalTo(record3));
+    }
+
+    @Test
+    void runStorageUpperCase() throws Exception {
+        var runContext = runContextFactory.of();
+
+        UpperCaseRecord record = UpperCaseRecord.builder()
+            .ExplicitHashKey("5")
+            .PartitionKey("partitionKey")
+            .Data("record")
+            .build();
+        UpperCaseRecord record2 = UpperCaseRecord.builder()
+            .PartitionKey("partitionKey")
+            .Data("record 2")
+            .build();
+        UpperCaseRecord record3 = UpperCaseRecord.builder()
+            .ExplicitHashKey("5")
+            .PartitionKey("partitionKey")
+            .Data("record 3")
+            .build();
+
+        File tempFile = runContext.tempFile(".ion").toFile();
+        try (var stream = new FileOutputStream(tempFile)) {
+            List.of(record, record2, record3).forEach(throwConsumer(r -> FileSerde.write(stream, r)));
+        }
+
+        var put = PutRecords.builder()
+            .endpointOverride(localstack.getEndpoint().toString())
+            .region(localstack.getRegion())
+            .accessKeyId(localstack.getAccessKey())
+            .secretKeyId(localstack.getSecretKey())
+            .records(runContext.putTempFile(tempFile).toString())
+            .streamName("streamName")
+            .build();
+
+        List<PutRecords.OutputEntry> outputEntries = getOutputEntries(put, runContext);
+
+        assertThat(outputEntries, hasSize(3));
+        assertThat(outputEntries.get(0).getSequenceNumber(), notNullValue());
+        assertThat(outputEntries.get(0).getErrorCode(), nullValue());
+        assertThat(outputEntries.get(0).getErrorMessage(), nullValue());
+        assertThat(outputEntries.get(0).getRecord().getData(), equalTo(record.getData()));
+
+        assertThat(outputEntries.get(1).getSequenceNumber(), notNullValue());
+        assertThat(outputEntries.get(1).getErrorCode(), nullValue());
+        assertThat(outputEntries.get(1).getErrorMessage(), nullValue());
+        assertThat(outputEntries.get(1).getRecord().getData(), equalTo(record2.getData()));
+
+        assertThat(outputEntries.get(2).getSequenceNumber(), notNullValue());
+        assertThat(outputEntries.get(2).getErrorCode(), nullValue());
+        assertThat(outputEntries.get(2).getErrorMessage(), nullValue());
+        assertThat(outputEntries.get(2).getRecord().getData(), equalTo(record3.getData()));
+    }
+
+    /**
+     * Test that users can use the AWS uppercase notation in JSON.
+     */
+    @Getter
+    @Builder
+    @EqualsAndHashCode
+    @Jacksonized
+    private static class UpperCaseRecord {
+        private String PartitionKey;
+        private String ExplicitHashKey;
+        private String Data;
+    }
+}
\ No newline at end of file