diff --git a/README.md b/README.md
index 31a22bf..3a58c45 100644
--- a/README.md
+++ b/README.md
@@ -148,7 +148,8 @@ Referring $SPARK_HOME to the Spark installation directory.
 | awsUseInstanceProfile | true | Use Instance Profile Credentials if none of credentials provided |
 | kinesis.executor.recordMaxBufferedTime | 1000 (millis) | Specify the maximum buffered time of a record |
 | kinesis.executor.maxConnections | 1 | Specify the maximum connections to Kinesis |
-| kinesis.executor.aggregationEnabled | true | Specify if records should be aggregated before sending them to Kinesis |
+| kinesis.executor.aggregationEnabled | true | Specify if records should be aggregated before sending them to Kinesis |
+| kinesis.executor.flushWaitTimeMillis | 100 (millis) | Specify the wait time between flush attempts while draining buffered records on task end |
 
 ## Roadmap
 * We need to migrate to DataSource V2 APIs for MicroBatchExecution.
diff --git a/src/main/scala/org/apache/spark/sql/kinesis/KinesisSourceProvider.scala b/src/main/scala/org/apache/spark/sql/kinesis/KinesisSourceProvider.scala
index 6102562..cad76e9 100644
--- a/src/main/scala/org/apache/spark/sql/kinesis/KinesisSourceProvider.scala
+++ b/src/main/scala/org/apache/spark/sql/kinesis/KinesisSourceProvider.scala
@@ -238,6 +238,7 @@ private[kinesis] object KinesisSourceProvider extends Logging {
   private[kinesis] val SINK_RECORD_MAX_BUFFERED_TIME = "kinesis.executor.recordmaxbufferedtime"
   private[kinesis] val SINK_MAX_CONNECTIONS = "kinesis.executor.maxconnections"
   private[kinesis] val SINK_AGGREGATION_ENABLED = "kinesis.executor.aggregationenabled"
+  private[kinesis] val SINK_FLUSH_WAIT_TIME_MILLIS = "kinesis.executor.flushwaittimemillis"
 
   private[kinesis] def getKinesisPosition(
@@ -266,6 +267,8 @@ private[kinesis] object KinesisSourceProvider extends Logging {
   private[kinesis] val DEFAULT_SINK_MAX_CONNECTIONS: String = "1"
 
   private[kinesis] val DEFAULT_SINK_AGGREGATION: String = "true"
+
+  private[kinesis] val DEFAULT_FLUSH_WAIT_TIME_MILLIS: String = "100"
 }
diff --git a/src/main/scala/org/apache/spark/sql/kinesis/KinesisWriteTask.scala b/src/main/scala/org/apache/spark/sql/kinesis/KinesisWriteTask.scala
index f05b34e..c4dcbfe 100644
--- a/src/main/scala/org/apache/spark/sql/kinesis/KinesisWriteTask.scala
+++ b/src/main/scala/org/apache/spark/sql/kinesis/KinesisWriteTask.scala
@@ -18,6 +18,8 @@ package org.apache.spark.sql.kinesis
 
 import java.nio.ByteBuffer
 
+import scala.util.Try
+
 import com.amazonaws.services.kinesis.producer.{KinesisProducer, UserRecordResult}
 import com.google.common.util.concurrent.{FutureCallback, Futures}
 
@@ -34,9 +36,19 @@ private[kinesis] class KinesisWriteTask(producerConfiguration: Map[String, Strin
   private val streamName = producerConfiguration.getOrElse(
     KinesisSourceProvider.SINK_STREAM_NAME_KEY, "")
 
+  private val flushWaitTimeMillis = Try(producerConfiguration.getOrElse(
+    KinesisSourceProvider.SINK_FLUSH_WAIT_TIME_MILLIS,
+    KinesisSourceProvider.DEFAULT_FLUSH_WAIT_TIME_MILLIS).toLong).filter(_ > 0).getOrElse {
+    throw new IllegalArgumentException(
+      s"${KinesisSourceProvider.SINK_FLUSH_WAIT_TIME_MILLIS} has to be a positive integer")
+  }
+
+  @volatile private var failedWrite: Throwable = _
+
+
   def execute(iterator: Iterator[InternalRow]): Unit = {
     producer = CachedKinesisProducer.getOrCreate(producerConfiguration)
-    while (iterator.hasNext) {
+    while (iterator.hasNext && failedWrite == null) {
       val currentRow = iterator.next()
       val projectedRow = projection(currentRow)
       val partitionKey = projectedRow.getString(0)
@@ -54,7 +66,10 @@ private[kinesis] class KinesisWriteTask(producerConfiguration: Map[String, Strin
     val kinesisCallBack = new FutureCallback[UserRecordResult]() {
 
       override def onFailure(t: Throwable): Unit = {
-        logError(s"Writing to $streamName failed due to ${t.getCause}")
+        if (failedWrite == null && t != null) {
+          failedWrite = t
+          logError(s"Writing to $streamName failed due to ${t.getCause}")
+        }
       }
 
       override def onSuccess(result: UserRecordResult): Unit = {
@@ -68,13 +83,34 @@ private[kinesis] class KinesisWriteTask(producerConfiguration: Map[String, Strin
     sentSeqNumbers
   }
 
-  def close(): Unit = {
+  private def flushRecordsIfNecessary(): Unit = {
     if (producer != null) {
-      producer.flush()
-      producer = null
+      while (producer.getOutstandingRecordsCount > 0) {
+        try {
+          producer.flush()
+          Thread.sleep(flushWaitTimeMillis)
+          checkForErrors()
+        } catch {
+          case _: InterruptedException =>
+            // Ignore interrupts and keep draining; write failures surface via checkForErrors().
+        }
+      }
     }
   }
 
+  def checkForErrors(): Unit = {
+    if (failedWrite != null) {
+      throw failedWrite
+    }
+  }
+
+  def close(): Unit = {
+    checkForErrors()
+    flushRecordsIfNecessary()
+    checkForErrors()
+    producer = null
+  }
+
   private def createProjection: UnsafeProjection = {
     val partitionKeyExpression = inputSchema
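
---

For context, a minimal usage sketch of the new option (not part of the patch). It assumes the connector's `kinesis` sink format with the `streamName` and `endpointUrl` options documented in the README; the stream name, endpoint, and checkpoint path below are placeholders, and credentials are assumed to come from an instance profile (the `awsUseInstanceProfile` default):

```scala
import org.apache.spark.sql.SparkSession

object KinesisSinkFlushExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("kinesis-sink-flush-example")
      .getOrCreate()

    // Toy source: the built-in rate source emits (timestamp, value) rows.
    val input = spark.readStream
      .format("rate")
      .load()
      .selectExpr(
        "CAST(value AS STRING) AS partitionKey", // the sink projects a partitionKey column
        "CAST(value AS STRING) AS data")         // and a data column

    val query = input.writeStream
      .format("kinesis")
      .option("streamName", "example-stream")                     // placeholder
      .option("endpointUrl", "https://kinesis.us-east-1.amazonaws.com")
      .option("checkpointLocation", "/tmp/kinesis-sink-example")  // placeholder
      // New option from this patch: wait time (ms) between flush attempts
      // while draining buffered records on task end; defaults to 100.
      .option("kinesis.executor.flushWaitTimeMillis", "250")
      .start()

    query.awaitTermination()
  }
}
```

The camelCase spelling above follows the README table; like the other sink options, it is expected to resolve to the lowercase `SINK_FLUSH_WAIT_TIME_MILLIS` key defined in `KinesisSourceProvider`.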