Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#461 Add the ability to set hard timeouts for Pramen operations #462

Merged
merged 3 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1944,6 +1944,11 @@ Here is an example configuration for a JDBC source:
# specified, a warning will be added to notifications.
warn.maximum.execution.time.seconds = 3600

# [Optional] You can specify the maximum amount of time the job is allowed to take.
# This is the hard timeout. The job will be killed if the timeout is breached.
# The timeout restriction applies to the full wall time of the task: validation and running.
kill.maximum.execution.time.seconds = 7200

# You can override any of source settings here
source {
minimum.records = 1000
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* Copyright 2022 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.pramen.core.exceptions

/**
  * An unchecked exception signalling that an operation did not finish within its allowed time.
  *
  * @param msg   The description of the timeout; also used as the exception message.
  * @param cause The underlying cause, or `null` (the default) when there is none.
  */
class TimeoutException(
  val msg: String,
  val cause: Throwable = null
) extends RuntimeException(msg, cause)
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ case class OperationDef(
initialSourcingDateExpression: String,
processingTimestampColumn: Option[String],
warnMaxExecutionTimeSeconds: Option[Int],
killMaxExecutionTimeSeconds: Option[Int],
schemaTransformations: Seq[TransformExpression],
filters: Seq[String],
notificationTargets: Seq[String],
Expand All @@ -65,6 +66,7 @@ object OperationDef {
val INITIAL_SOURCING_DATE_EXPR = "initial.sourcing.date.expr"
val PROCESSING_TIMESTAMP_COLUMN_KEY = "processing.timestamp.column"
val WARN_MAXIMUM_EXECUTION_TIME_SECONDS_KEY = "warn.maximum.execution.time.seconds"
val KILL_MAXIMUM_EXECUTION_TIME_SECONDS_KEY = "kill.maximum.execution.time.seconds"
val SCHEMA_TRANSFORMATIONS_KEY = "transformations"
val FILTERS_KEY = "filters"
val NOTIFICATION_TARGETS_KEY = "notification.targets"
Expand Down Expand Up @@ -95,7 +97,8 @@ object OperationDef {
val outputInfoDateExpressionOpt = ConfigUtils.getOptionString(conf, OUTPUT_INFO_DATE_EXPRESSION_KEY)
val initialSourcingDateExpressionOpt = ConfigUtils.getOptionString(conf, INITIAL_SOURCING_DATE_EXPR)
val processingTimestampColumn = ConfigUtils.getOptionString(conf, PROCESSING_TIMESTAMP_COLUMN_KEY)
val maximumExecutionTimeSeconds = ConfigUtils.getOptionInt(conf, WARN_MAXIMUM_EXECUTION_TIME_SECONDS_KEY)
val warnMaximumExecutionTimeSeconds = ConfigUtils.getOptionInt(conf, WARN_MAXIMUM_EXECUTION_TIME_SECONDS_KEY)
val killMaximumExecutionTimeSeconds = ConfigUtils.getOptionInt(conf, KILL_MAXIMUM_EXECUTION_TIME_SECONDS_KEY)
val schemaTransformations = TransformExpression.fromConfig(conf, SCHEMA_TRANSFORMATIONS_KEY, parent)
val filters = ConfigUtils.getOptListStrings(conf, FILTERS_KEY)
val notificationTargets = ConfigUtils.getOptListStrings(conf, NOTIFICATION_TARGETS_KEY)
Expand Down Expand Up @@ -138,7 +141,8 @@ object OperationDef {
outputInfoDateExpression,
initialSourcingDateExpression,
processingTimestampColumn,
maximumExecutionTimeSeconds,
warnMaximumExecutionTimeSeconds,
killMaximumExecutionTimeSeconds,
schemaTransformations,
filters,
notificationTargets,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,15 @@ import za.co.absa.pramen.core.pipeline._
import za.co.absa.pramen.core.state.PipelineState
import za.co.absa.pramen.core.utils.Emoji._
import za.co.absa.pramen.core.utils.SparkUtils._
import za.co.absa.pramen.core.utils.TimeUtils
import za.co.absa.pramen.core.utils.{ThreadUtils, TimeUtils}
import za.co.absa.pramen.core.utils.hive.HiveHelper

import java.sql.Date
import java.time.{Instant, LocalDate}
import java.util.concurrent.TimeUnit
import scala.concurrent.duration.Duration
import scala.concurrent.{ExecutionContext, Future}
import scala.util.control.NonFatal
import scala.util.{Failure, Success, Try}

abstract class TaskRunnerBase(conf: Config,
Expand Down Expand Up @@ -115,7 +118,27 @@ abstract class TaskRunnerBase(conf: Config,
/**
  * Runs a task. Performs all task logging and notification sending activities.
  *
  * If the operation defines a hard timeout (`killMaxExecutionTimeSeconds`), validation and execution
  * are wrapped in `ThreadUtils.runWithTimeout`, and any non-fatal error coming out of it is reported
  * through `failTask`. Without a hard timeout the task is validated and executed directly.
  *
  * @param task The task to run.
  * @return The resulting status of the task.
  */
protected def runTask(task: Task): RunStatus = {
  val started = Instant.now()

  // Runs validation + execution under the hard timeout, converting non-fatal errors to a failed task.
  def runWithHardTimeout(timeout: Int): RunStatus = {
    // Assigned inside the timed action and read back here once the action has completed.
    @volatile var outcome: RunStatus = null

    try {
      ThreadUtils.runWithTimeout(Duration(timeout, TimeUnit.SECONDS)) {
        log.info(s"Running ${task.job.name} with the hard timeout = $timeout seconds.")
        outcome = doValidateAndRunTask(task)
      }
      outcome
    } catch {
      case NonFatal(ex) => failTask(task, started, ex)
    }
  }

  task.job.operation.killMaxExecutionTimeSeconds
    .map(runWithHardTimeout)
    .getOrElse(doValidateAndRunTask(task))
}

protected def doValidateAndRunTask(task: Task): RunStatus = {
val started = Instant.now()
task.reason match {
case TaskRunReason.Skip(reason) =>
// This skips tasks that were skipped based on strong date constraints (e.g. attempt to run before the minimum date)
Expand Down Expand Up @@ -151,6 +174,29 @@ abstract class TaskRunnerBase(conf: Config,
onTaskCompletion(task, taskResult, isLazy = false)
}

/**
  * Fails a task. Performs all task logging and notification sending activities.
  *
  * Builds a `TaskResult` with a `RunStatus.Failed` status wrapping the error and hands it over
  * to `onTaskCompletion`.
  *
  * @param task    The task to mark as failed.
  * @param started The instant reported as the start of the run.
  * @param ex      The error that caused the failure.
  * @return The status returned by `onTaskCompletion` for the failed result.
  */
protected def failTask(task: Task, started: Instant, ex: Throwable): RunStatus = {
  val finished = Instant.now()
  val outputTable = task.job.outputTable
  val outputFormat = outputTable.format

  val isRawOutput = outputFormat match {
    case _: DataFormat.Raw => true
    case _                 => false
  }

  val failedResult = TaskResult(
    task.job.name,
    MetaTable.getMetaTableDef(outputTable),
    RunStatus.Failed(ex),
    Option(RunInfo(task.infoDate, started, finished)),
    applicationId,
    outputFormat.isTransient,
    isRawFilesJob = isRawOutput,
    Nil,
    Nil,
    Nil,
    task.job.operation.extraOptions)

  onTaskCompletion(task, failedResult, isLazy = outputFormat.isLazy)
}

/**
* Performs a pre-run check. If the check is successful, the job is validated, and then allowed to run.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright 2022 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.pramen.core.utils

import za.co.absa.pramen.core.exceptions.TimeoutException
import za.co.absa.pramen.core.utils.impl.ThreadWithException

import java.lang.Thread.UncaughtExceptionHandler
import scala.concurrent.duration.Duration

object ThreadUtils {
  /**
    * Executes an action in a separate thread with a timeout. If the timeout is breached the task is
    * interrupted (using Thread.interrupt()) and an exception is thrown to the caller.
    *
    * The thrown [[TimeoutException]] has a cause whose stack trace is the stack trace of the worker
    * thread captured just before it was interrupted.
    *
    * Any exception thrown by the action is passed to the caller.
    *
    * The timeout is resolved and validated before the action is started, so an invalid timeout never
    * leaves the action running in the background. A positive timeout shorter than one millisecond is
    * rounded up to one millisecond. A timeout of exactly zero is passed to `Thread.join()` as is,
    * which waits for the action without a time limit.
    *
    * @param timeout The task timeout. Must be finite and non-negative.
    * @param action  An action to execute.
    * @throws TimeoutException         if the action does not finish within the timeout.
    * @throws IllegalArgumentException if the timeout is negative or not finite.
    */
  @throws[TimeoutException]
  def runWithTimeout(timeout: Duration)(action: => Unit): Unit = {
    // Resolve the timeout up front: `toMillis` throws on infinite durations, and a negative value
    // would make `Thread.join()` throw after the worker thread has already been started.
    val requestedMillis = timeout.toMillis
    require(requestedMillis >= 0, s"The timeout must not be negative, got $timeout.")

    // `Thread.join(0)` means "wait forever", so a positive sub-millisecond timeout must not
    // be truncated to zero.
    val timeoutMillis = if (requestedMillis == 0L && timeout > Duration.Zero) 1L else requestedMillis

    val thread = new ThreadWithException {
      override def run(): Unit = {
        action
      }
    }

    // Exceptions escaping the action are stored on the worker thread so they can be rethrown below.
    val handler = new UncaughtExceptionHandler {
      override def uncaughtException(t: Thread, ex: Throwable): Unit = {
        thread.setException(ex)
      }
    }

    thread.setUncaughtExceptionHandler(handler)

    thread.start()
    thread.join(timeoutMillis)

    if (thread.isAlive) {
      // Capture where the worker was before interrupting it, so the failure shows what was running.
      val stackTrace = thread.getStackTrace
      thread.interrupt()

      val prettyTimeout = TimeUtils.prettyPrintElapsedTimeShort(timeoutMillis)
      val cause = new RuntimeException("The task has been interrupted by Pramen.")
      cause.setStackTrace(stackTrace)

      throw new TimeoutException(s"Timeout expired ($prettyTimeout).", cause)
    }

    thread.getException.foreach(ex => throw ex)
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright 2022 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.pramen.core.utils.impl

/**
  * A thread that can carry an exception, so that code observing the thread can retrieve it later.
  *
  * Access to the stored exception is synchronized on the thread instance.
  */
class ThreadWithException extends Thread {
  private var capturedError: Option[Throwable] = None

  /** Stores the exception; passing `null` clears the stored value. */
  def setException(ex: Throwable): Unit = {
    this.synchronized {
      capturedError = Option(ex)
    }
  }

  /** Returns the stored exception, or `None` if nothing has been stored. */
  def getException: Option[Throwable] = this.synchronized(capturedError)
}
1 change: 1 addition & 0 deletions pramen/core/src/test/resources/log4j.properties
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@ log4j.logger.za.co.absa.pramen.core.state.PipelineStateImpl=OFF
log4j.logger.za.co.absa.pramen.core.utils.ConfigUtils$=OFF
log4j.logger.za.co.absa.pramen.core.utils.JdbcNativeUtils$=OFF
log4j.logger.za.co.absa.pramen.core.utils.SparkUtils$=OFF
log4j.logger.za.co.absa.pramen.core.utils.ThreadUtils$=OFF
3 changes: 3 additions & 0 deletions pramen/core/src/test/resources/log4j2.properties
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,6 @@ logger.notificationbuilderimpl.level = OFF

logger.sparkutils.name = za.co.absa.pramen.core.utils.SparkUtils$
logger.sparkutils.level = OFF

logger.threadutils.name = za.co.absa.pramen.core.utils.ThreadUtils$
logger.threadutils.level = OFF
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ object OperationDefFactory {
initialSourcingDateExpression: String = "@runDate",
processingTimestampColumn: Option[String] = None,
warnMaxExecutionTimeSeconds: Option[Int] = None,
killMaxExecutionTimeSeconds: Option[Int] = None,
schemaTransformations: Seq[TransformExpression] = Nil,
filters: Seq[String] = Nil,
notificationTargets: Seq[String] = Nil,
Expand All @@ -55,6 +56,7 @@ object OperationDefFactory {
initialSourcingDateExpression,
processingTimestampColumn,
warnMaxExecutionTimeSeconds,
killMaxExecutionTimeSeconds,
schemaTransformations,
filters,
notificationTargets,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ class OperationDefSuite extends AnyWordSpec with TempDirFixture {
assert(op.operationType.asInstanceOf[Ingestion].sourceTables.head.metaTableName == "table1_sync")
assert(op.allowParallel)
assert(!op.alwaysAttempt)
assert(op.warnMaxExecutionTimeSeconds.isEmpty)
assert(op.killMaxExecutionTimeSeconds.isEmpty)
assert(op.notificationTargets.size == 2)
assert(op.notificationTargets.head == "hyperdrive1")
assert(op.notificationTargets(1) == "custom2")
Expand All @@ -109,6 +111,7 @@ class OperationDefSuite extends AnyWordSpec with TempDirFixture {
|output.table = "dummy_table"
|always.attempt = "true"
|warn.maximum.execution.time.seconds = 50
|kill.maximum.execution.time.seconds = 100
|
|dependencies = [
| {
Expand Down Expand Up @@ -150,6 +153,7 @@ class OperationDefSuite extends AnyWordSpec with TempDirFixture {
assert(op.dependencies(1).tables.contains("table2"))
assert(!op.dependencies(1).triggerUpdates)
assert(op.warnMaxExecutionTimeSeconds.contains(50))
assert(op.killMaxExecutionTimeSeconds.contains(100))
assert(op.schemaTransformations.length == 4)
assert(op.schemaTransformations.head.column == "A")
assert(op.schemaTransformations.head.expression.contains("cast(A as decimal(15,5))"))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright 2022 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.pramen.core.tests.utils

import org.scalatest.wordspec.AnyWordSpec
import za.co.absa.pramen.core.utils.ThreadUtils

import java.util.concurrent.TimeUnit
import scala.concurrent.duration.Duration

/** Tests for `ThreadUtils.runWithTimeout`: normal completion, timeout breach, and error propagation. */
class ThreadUtilsSuite extends AnyWordSpec {
  "runWithTimeout" should {
    "run the action normally when the timeout is not breached" in {
      val generousTimeout = Duration(10, TimeUnit.SECONDS)

      ThreadUtils.runWithTimeout(generousTimeout)(Thread.sleep(1))
    }

    "throw an exception when timeout is breached" in {
      val tinyTimeout = Duration(1, TimeUnit.MILLISECONDS)

      val ex = intercept[RuntimeException] {
        ThreadUtils.runWithTimeout(tinyTimeout)(Thread.sleep(1000))
      }
      val cause = ex.getCause

      assert(ex.getMessage.contains("Timeout expired (instantly)."))
      assert(cause != null)
      assert(cause.isInstanceOf[RuntimeException])
      assert(cause.getStackTrace.nonEmpty)
    }

    "pass the thrown exception to the caller" in {
      val generousTimeout = Duration(10, TimeUnit.SECONDS)

      val ex = intercept[IllegalStateException] {
        ThreadUtils.runWithTimeout(generousTimeout) {
          throw new IllegalStateException("test")
        }
      }

      assert(ex.getMessage == "test")
    }
  }

}
Loading