Enhance index monitor to terminate streaming job on consecutive errors (#346)

* Add error counter and terminate logic in index monitor

Signed-off-by: Chen Dai <daichen@amazon.com>

* Add new Spark conf for max error count and interval

Signed-off-by: Chen Dai <daichen@amazon.com>

* Add new Spark conf for initial delay too

Signed-off-by: Chen Dai <daichen@amazon.com>

* Update user manual

Signed-off-by: Chen Dai <daichen@amazon.com>

---------

Signed-off-by: Chen Dai <daichen@amazon.com>
dai-chen authored May 17, 2024
1 parent 300fedf commit 9de4f28
Showing 5 changed files with 141 additions and 29 deletions.
docs/index.md: 3 additions & 0 deletions
@@ -519,6 +519,9 @@ In the index mapping, the `_meta` and `properties` field stores meta and schema info
- `spark.flint.index.hybridscan.enabled`: default is false.
- `spark.flint.index.checkpoint.mandatory`: default is true.
- `spark.datasource.flint.socket_timeout_millis`: default value is 60000.
- `spark.flint.monitor.initialDelaySeconds`: Initial delay in seconds before starting the monitoring task. Default value is 15.
- `spark.flint.monitor.intervalSeconds`: Interval in seconds for scheduling the monitoring task. Default value is 60.
- `spark.flint.monitor.maxErrorCount`: Maximum number of consecutive errors allowed before stopping the monitoring task. Default value is 5.
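For reference, a minimal sketch of how these monitor settings could be supplied when building the Spark session; the values below are illustrative overrides, not part of this change, and the same keys can also be passed with `--conf` on spark-submit.

```scala
import org.apache.spark.sql.SparkSession

// Illustrative only: override the Flint index monitor settings documented above.
// Defaults are 15 (initial delay), 60 (interval), and 5 (max error count).
val spark = SparkSession
  .builder()
  .appName("flint-monitor-config-example")
  .config("spark.flint.monitor.initialDelaySeconds", "30")
  .config("spark.flint.monitor.intervalSeconds", "120")
  .config("spark.flint.monitor.maxErrorCount", "10")
  .getOrCreate()
```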

#### Data Type Mapping

FlintSparkConf.scala
@@ -154,6 +154,18 @@ object FlintSparkConf {
.doc("Checkpoint location for incremental refresh index will be mandatory if enabled")
.createWithDefault("true")

val MONITOR_INITIAL_DELAY_SECONDS = FlintConfig("spark.flint.monitor.initialDelaySeconds")
.doc("Initial delay in seconds before starting the monitoring task")
.createWithDefault("15")

val MONITOR_INTERVAL_SECONDS = FlintConfig("spark.flint.monitor.intervalSeconds")
.doc("Interval in seconds for scheduling the monitoring task")
.createWithDefault("60")

val MONITOR_MAX_ERROR_COUNT = FlintConfig("spark.flint.monitor.maxErrorCount")
.doc("Maximum number of consecutive errors allowed in index monitor")
.createWithDefault("5")

val SOCKET_TIMEOUT_MILLIS =
FlintConfig(s"spark.datasource.flint.${FlintOptions.SOCKET_TIMEOUT_MILLIS}")
.datasourceOption()
@@ -223,6 +235,12 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable

def isCheckpointMandatory: Boolean = CHECKPOINT_MANDATORY.readFrom(reader).toBoolean

def monitorInitialDelaySeconds(): Int = MONITOR_INITIAL_DELAY_SECONDS.readFrom(reader).toInt

def monitorIntervalSeconds(): Int = MONITOR_INTERVAL_SECONDS.readFrom(reader).toInt

def monitorMaxErrorCount(): Int = MONITOR_MAX_ERROR_COUNT.readFrom(reader).toInt

/**
* spark.sql.session.timeZone
*/
FlintSparkIndexMonitor.scala
@@ -16,6 +16,7 @@ import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.flint.newDaemonThreadPoolScheduledExecutor

/**
@@ -34,43 +35,32 @@ class FlintSparkIndexMonitor(
dataSourceName: String)
extends Logging {

/** Task execution initial delay in seconds */
private val INITIAL_DELAY_SECONDS = FlintSparkConf().monitorInitialDelaySeconds()

/** Task execution interval in seconds */
private val INTERVAL_SECONDS = FlintSparkConf().monitorIntervalSeconds()

/** Max error count allowed */
private val MAX_ERROR_COUNT = FlintSparkConf().monitorMaxErrorCount()

/**
* Start monitoring task on the given Flint index.
*
* @param indexName
* Flint index name
*/
def startMonitor(indexName: String): Unit = {
-    val task = FlintSparkIndexMonitor.executor.scheduleWithFixedDelay(
-      () => {
-        logInfo(s"Scheduler trigger index monitor task for $indexName")
-        try {
-          if (isStreamingJobActive(indexName)) {
-            logInfo("Streaming job is still active")
-            flintClient
-              .startTransaction(indexName, dataSourceName)
-              .initialLog(latest => latest.state == REFRESHING)
-              .finalLog(latest => latest) // timestamp will update automatically
-              .commit(_ => {})
-          } else {
-            logError("Streaming job is not active. Cancelling monitor task")
-            flintClient
-              .startTransaction(indexName, dataSourceName)
-              .initialLog(_ => true)
-              .finalLog(latest => latest.copy(state = FAILED))
-              .commit(_ => {})
+    logInfo(s"""Starting index monitor for $indexName with configuration:
+         | - Initial delay: $INITIAL_DELAY_SECONDS seconds
+         | - Interval: $INTERVAL_SECONDS seconds
+         | - Max error count: $MAX_ERROR_COUNT
+         |""".stripMargin)

-            stopMonitor(indexName)
-            logInfo("Index monitor task is cancelled")
-          }
-        } catch {
-          case e: Throwable =>
-            logError("Failed to update index log entry", e)
-            MetricsUtil.incrementCounter(MetricConstants.STREAMING_HEARTBEAT_FAILED_METRIC)
-        }
-      },
-      15, // Delay to ensure final logging is complete first, otherwise version conflicts
-      60, // TODO: make interval configurable
+    val task = FlintSparkIndexMonitor.executor.scheduleWithFixedDelay(
+      new FlintSparkIndexMonitorTask(indexName),
+      INITIAL_DELAY_SECONDS, // Delay to ensure final logging is complete first, otherwise version conflicts
+      INTERVAL_SECONDS,
TimeUnit.SECONDS)

FlintSparkIndexMonitor.indexMonitorTracker.put(indexName, task)
@@ -92,8 +82,68 @@ class FlintSparkIndexMonitor(
}
}

/**
* Index monitor task that encapsulates the execution logic and tracks the number of
* consecutive errors.
*
* @param indexName
* Flint index name
*/
private class FlintSparkIndexMonitorTask(indexName: String) extends Runnable {

/** The number of consecutive errors */
private var errorCnt = 0

override def run(): Unit = {
logInfo(s"Scheduler trigger index monitor task for $indexName")
try {
if (isStreamingJobActive(indexName)) {
logInfo("Streaming job is still active")
flintClient
.startTransaction(indexName, dataSourceName)
.initialLog(latest => latest.state == REFRESHING)
.finalLog(latest => latest) // timestamp will update automatically
.commit(_ => {})
} else {
logError("Streaming job is not active. Cancelling monitor task")
flintClient
.startTransaction(indexName, dataSourceName)
.initialLog(_ => true)
.finalLog(latest => latest.copy(state = FAILED))
.commit(_ => {})

stopMonitor(indexName)
logInfo("Index monitor task is cancelled")
}
errorCnt = 0 // Reset counter if no error
} catch {
case e: Throwable =>
errorCnt += 1
logError(s"Failed to update index log entry, consecutive errors: $errorCnt", e)
MetricsUtil.incrementCounter(MetricConstants.STREAMING_HEARTBEAT_FAILED_METRIC)

// Stop streaming job and its monitor if max error count reached
if (errorCnt >= MAX_ERROR_COUNT) {
logInfo(s"Terminating streaming job and index monitor for $indexName")
stopStreamingJob(indexName)
stopMonitor(indexName)
logInfo(s"Streaming job and index monitor terminated")
}
}
}
}

private def isStreamingJobActive(indexName: String): Boolean =
spark.streams.active.exists(_.name == indexName)

private def stopStreamingJob(indexName: String): Unit = {
val job = spark.streams.active.find(_.name == indexName)
if (job.isDefined) {
job.get.stop()
} else {
logWarning("Refreshing job not found")
}
}
}

object FlintSparkIndexMonitor extends Logging {
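For readers unfamiliar with the scheduling pattern that FlintSparkIndexMonitorTask relies on, the sketch below shows the same idea in isolation: a task submitted with scheduleWithFixedDelay counts consecutive failures, resets the counter on success, and cancels itself once a threshold is reached. This is a standalone illustration, not Flint code; names such as ConsecutiveErrorDemo and doHeartbeat are invented for the example.

```scala
import java.util.concurrent.{Executors, ScheduledFuture, TimeUnit}
import java.util.concurrent.atomic.AtomicReference

// Standalone illustration of the consecutive-error termination pattern.
object ConsecutiveErrorDemo extends App {
  val maxErrorCount = 3
  val executor = Executors.newSingleThreadScheduledExecutor()
  val scheduled = new AtomicReference[ScheduledFuture[_]]()

  // Placeholder for the real work (e.g. the heartbeat/transaction update in the monitor task)
  def doHeartbeat(): Unit = throw new RuntimeException("simulated failure")

  val task: Runnable = new Runnable {
    private var errorCnt = 0 // number of consecutive errors

    override def run(): Unit = {
      try {
        doHeartbeat()
        errorCnt = 0 // reset counter on success
      } catch {
        case e: Throwable =>
          errorCnt += 1
          println(s"Task failed, consecutive errors: $errorCnt (${e.getMessage})")
          if (errorCnt >= maxErrorCount) {
            // Give up: cancel the scheduled task; the real monitor also stops the streaming job here
            Option(scheduled.get()).foreach(_.cancel(true))
            executor.shutdown()
          }
      }
    }
  }

  scheduled.set(executor.scheduleWithFixedDelay(task, 1, 1, TimeUnit.SECONDS))
}
```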
FlintSparkConfSuite.scala
@@ -13,6 +13,7 @@ import org.opensearch.flint.core.http.FlintRetryOptions._
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

import org.apache.spark.FlintSuite
import org.apache.spark.sql.flint.config.FlintSparkConf.{MONITOR_INITIAL_DELAY_SECONDS, MONITOR_INTERVAL_SECONDS, MONITOR_MAX_ERROR_COUNT}

class FlintSparkConfSuite extends FlintSuite {
test("test spark conf") {
@@ -84,6 +85,24 @@ class FlintSparkConfSuite extends FlintSuite {
overrideConf.flintOptions().getBatchBytes shouldBe 4 * 1024 * 1024
}

test("test index monitor options") {
val defaultConf = FlintSparkConf()
defaultConf.monitorInitialDelaySeconds() shouldBe 15
defaultConf.monitorIntervalSeconds() shouldBe 60
defaultConf.monitorMaxErrorCount() shouldBe 5

withSparkConf(MONITOR_INITIAL_DELAY_SECONDS.key, MONITOR_INTERVAL_SECONDS.key, MONITOR_MAX_ERROR_COUNT.key) {
setFlintSparkConf(MONITOR_INITIAL_DELAY_SECONDS, 5)
setFlintSparkConf(MONITOR_INTERVAL_SECONDS, 30)
setFlintSparkConf(MONITOR_MAX_ERROR_COUNT, 10)

val overrideConf = FlintSparkConf()
overrideConf.monitorInitialDelaySeconds() shouldBe 5
overrideConf.monitorIntervalSeconds() shouldBe 30
overrideConf.monitorMaxErrorCount() shouldBe 10
}
}

/**
* Delete index `indexNames` after calling `f`.
*/
FlintSparkIndexMonitorITSuite.scala
@@ -19,6 +19,7 @@ import org.opensearch.flint.OpenSearchTransactionSuite
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
import org.scalatest.matchers.should.Matchers

import org.apache.spark.sql.flint.config.FlintSparkConf.MONITOR_MAX_ERROR_COUNT
import org.apache.spark.sql.flint.newDaemonThreadPoolScheduledExecutor

class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matchers {
@@ -40,6 +41,9 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matchers {
realExecutor.scheduleWithFixedDelay(invocation.getArgument(0), 5, 1, TimeUnit.SECONDS)
}).when(FlintSparkIndexMonitor.executor)
.scheduleWithFixedDelay(any[Runnable], any[Long], any[Long], any[TimeUnit])

// Set max error count higher to avoid impact on transient error test case
setFlintSparkConf(MONITOR_MAX_ERROR_COUNT, 10)
}

override def beforeEach(): Unit = {
@@ -128,6 +132,24 @@ class FlintSparkIndexMonitorITSuite extends OpenSearchTransactionSuite with Matchers {
}
}

test("monitor task and streaming job should terminate if exception occurred consistently") {
val task = FlintSparkIndexMonitor.indexMonitorTracker(testFlintIndex)

// Block write on metadata log index
setWriteBlockOnMetadataLogIndex(true)
waitForMonitorTaskRun()

// Both monitor task and streaming job should stop once the max error count (10) is reached
10 times { (_, _) =>
  // assert nothing; just wait for enough task executions
}

task.isCancelled shouldBe true
spark.streams.active.exists(_.name == testFlintIndex) shouldBe false
}

private def getLatestTimestamp: (Long, Long) = {
val latest = latestLogEntry(testLatestId)
(latest("jobStartTime").asInstanceOf[Long], latest("lastUpdateTime").asInstanceOf[Long])
