Added connection & read timeout configuration for pushgateway sink
Arnovsky committed Jul 1, 2024
1 parent e2c9c9d commit 642fe9e
Showing 7 changed files with 118 additions and 82 deletions.
4 changes: 4 additions & 0 deletions docs/Flight_recorder_mode_PrometheusPushgatewaySink.md
@@ -33,6 +33,10 @@ Configuration - PushGatewaySink parameters:
Example: --conf spark.sparkmeasure.pushgateway=localhost:9091
--conf spark.sparkmeasure.pushgateway.jobname=JOBNAME // default value is pushgateway
Example: --conf spark.sparkmeasure.pushgateway.jobname=myjob1
--conf spark.sparkmeasure.pushgateway.http.connection.timeout=TIME_IN_MS // default value is 5000
Example: --conf spark.sparkmeasure.pushgateway.http.connection.timeout=150
--conf spark.sparkmeasure.pushgateway.http.read.timeout=TIME_IN_MS // default value is 5000
Example: --conf spark.sparkmeasure.pushgateway.http.read.timeout=150
```
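
The same settings can also be applied programmatically when building the SparkSession. A minimal sketch, assuming the sink is registered through `spark.extraListeners` as in the usual flight-recorder setup; the timeout values below are illustrative:

```scala
import org.apache.spark.sql.SparkSession

// Register the Pushgateway sink and tune both HTTP timeouts
// (equivalent to the --conf flags shown above; values are illustrative).
val spark = SparkSession.builder()
  .appName("PushGatewaySinkTimeoutExample")
  .config("spark.extraListeners", "ch.cern.sparkmeasure.PushGatewaySink")
  .config("spark.sparkmeasure.pushgateway", "localhost:9091")
  .config("spark.sparkmeasure.pushgateway.jobname", "myjob1")
  .config("spark.sparkmeasure.pushgateway.http.connection.timeout", "2000") // ms
  .config("spark.sparkmeasure.pushgateway.http.read.timeout", "2000")       // ms
  .getOrCreate()
```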

## Use case
42 changes: 20 additions & 22 deletions src/main/scala/ch/cern/sparkmeasure/PushGateway.scala
@@ -1,7 +1,8 @@
package ch.cern.sparkmeasure

import java.net.{URL, URLEncoder, HttpURLConnection}
import org.slf4j.LoggerFactory
import org.slf4j.{Logger, LoggerFactory}

import java.net.{HttpURLConnection, URL, URLEncoder}

/**
* PushGateway
@@ -30,25 +31,23 @@ import org.slf4j.LoggerFactory


/**
* serverIPnPort: String with prometheus pushgateway hostIP:Port,
* metricsJob: job name
* config: case class with all required configuration for pushgateway,
*/
case class PushGateway(serverIPnPort: String, metricsJob: String) {

lazy val logger = LoggerFactory.getLogger(this.getClass.getName)
case class PushGateway(config: PushgatewayConfig) {
lazy val logger: Logger = LoggerFactory.getLogger(this.getClass.getName)

var urlJob = s"DefaultJob"
private var urlJob = s"DefaultJob"
try {
urlJob = URLEncoder.encode(metricsJob, s"UTF-8")
urlJob = URLEncoder.encode(config.jobName, s"UTF-8")
} catch {
case uee: java.io.UnsupportedEncodingException =>
logger.error(s"metricsJob '$metricsJob' cannot be url encoded")
case _: java.io.UnsupportedEncodingException =>
logger.error(s"metricsJob '${config.jobName}' cannot be url encoded")
}
val urlBase = s"http://" + serverIPnPort + s"/metrics/job/" + urlJob + s"/instance/sparkMeasure"
val urlBase = s"http://${config.serverIPnPort}/metrics/job/$urlJob/instance/sparkMeasure"

val requestMethod = s"POST"
val connectTimeout = 5000 // milliseconds
val readTimeout = 5000 // milliseconds
private val requestMethod = s"POST"
private val connectTimeout = config.connectionTimeoutMs
private val readTimeout = config.readTimeoutMs


/**
@@ -127,11 +126,11 @@ case class PushGateway(serverIPnPort: String, metricsJob: String) {
val urlFull = urlBase + s"/type/" + urlType + s"/" + urlLabelName + s"/" + urlLabelValue

try {
val connection = (new URL(urlFull)).openConnection.asInstanceOf[HttpURLConnection]
val connection = new URL(urlFull).openConnection.asInstanceOf[HttpURLConnection]
connection.setConnectTimeout(connectTimeout)
connection.setReadTimeout(readTimeout)
connection.setRequestMethod(requestMethod)
connection.setRequestProperty("Content-Type","text/plain; version=0.0.4")
connection.setRequestProperty("Content-Type", "text/plain; version=0.0.4")
connection.setDoOutput(true)

val outputStream = connection.getOutputStream
@@ -141,19 +140,18 @@
outputStream.close();
}

val responseCode = connection.getResponseCode()
val responseMessage = connection.getResponseMessage()
val responseCode = connection.getResponseCode
val responseMessage = connection.getResponseMessage
connection.disconnect();
if (responseCode != 200 && responseCode != 202) // 200 and 202 Accepted, 400 Bad Request
logger.error(s"Data sent error, url: '$urlFull', response: $responseCode '$responseMessage'")
} catch {
case ste: java.net.SocketTimeoutException =>
println("java.net.SocketTimeoutException")
logger.error(s"Data sent error, url: '$urlFull', " + ste.getMessage())
logger.error(s"Data sent error, url: '$urlFull', " + ste.getMessage)
case ioe: java.io.IOException =>
println("java.io.IOException")
logger.error(s"Data sent error, url: '$urlFull', " + ioe.getMessage())
logger.error(s"Data sent error, url: '$urlFull', " + ioe.getMessage)
}

}
}
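
The refactored constructor above takes a single `PushgatewayConfig` case class in place of the old `(serverIPnPort, metricsJob)` pair. Its definition sits in one of the changed files not expanded in this view; a minimal sketch, consistent with the fields referenced here and with the documented 5000 ms defaults, could look like:

```scala
package ch.cern.sparkmeasure

// Hypothetical sketch only: the real definition lives elsewhere in this commit.
// Field names match the references in PushGateway.scala; the defaults mirror
// the documented 5000 ms connection/read timeouts.
case class PushgatewayConfig(
  serverIPnPort: String,
  jobName: String,
  connectionTimeoutMs: Int = 5000,
  readTimeoutMs: Int = 5000
)
```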
5 changes: 3 additions & 2 deletions src/main/scala/ch/cern/sparkmeasure/PushGatewaySink.scala
@@ -31,8 +31,9 @@ class PushGatewaySink(conf: SparkConf) extends SparkListener {
logger.warn("Custom monitoring listener with Prometheus Push Gateway sink initializing. Now attempting to connect to the Push Gateway")

// Initialize PushGateway connection
val (url, job) = Utils.parsePushGatewayConfig(conf, logger)
val gateway = PushGateway(url, job)
private val gateway: PushGateway = PushGateway(
Utils.parsePushGatewayConfig(conf, logger)
)

var appId: String = SparkSession.getActiveSession match {
case Some(sparkSession) => sparkSession.sparkContext.applicationId
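
`PushGatewaySink` now hands the parsed configuration straight to `PushGateway`. The updated `Utils.parsePushGatewayConfig` is not expanded in this view; a hedged sketch of what it plausibly does with the new keys (key names are taken from the documentation above, while defaults and error handling are assumptions):

```scala
import org.apache.spark.SparkConf
import org.slf4j.Logger

// Illustrative sketch only; the real implementation lives in Utils.scala in this commit.
def parsePushGatewayConfig(conf: SparkConf, logger: Logger): PushgatewayConfig = {
  val serverIPnPort = conf.get("spark.sparkmeasure.pushgateway", "")
  val jobName = conf.get("spark.sparkmeasure.pushgateway.jobname", "pushgateway")
  val connectionTimeoutMs =
    conf.getInt("spark.sparkmeasure.pushgateway.http.connection.timeout", 5000)
  val readTimeoutMs =
    conf.getInt("spark.sparkmeasure.pushgateway.http.read.timeout", 5000)
  logger.info(s"Pushgateway sink: '$serverIPnPort', job: '$jobName', " +
    s"connect timeout: $connectionTimeoutMs ms, read timeout: $readTimeoutMs ms")
  PushgatewayConfig(serverIPnPort, jobName, connectionTimeoutMs, readTimeoutMs)
}
```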
53 changes: 30 additions & 23 deletions src/main/scala/ch/cern/sparkmeasure/StageMetrics.scala
@@ -4,12 +4,12 @@ import org.apache.spark.sql.{DataFrame, SparkSession}
import org.slf4j.LoggerFactory

import scala.collection.JavaConverters.mapAsJavaMap
import scala.collection.mutable.{ListBuffer, LinkedHashMap}
import scala.math.{min, max}
import scala.collection.mutable.{LinkedHashMap, ListBuffer}
import scala.math.{max, min}

/**
* Stage Metrics: collects stage-level metrics with Stage granularity
* and provides aggregation and reporting functions for the end-user
* Stage Metrics: collects stage-level metrics with Stage granularity
* and provides aggregation and reporting functions for the end-user
*
* Example:
* val stageMetrics = ch.cern.sparkmeasure.StageMetrics(spark)
@@ -39,7 +39,7 @@ case class StageMetrics(sparkSession: SparkSession) {

// Marks the beginning of data collection
def begin(): Long = {
listenerStage.stageMetricsData.clear() // clear previous data to reduce memory footprint
listenerStage.stageMetricsData.clear() // clear previous data to reduce memory footprint
beginSnapshot = System.currentTimeMillis()
endSnapshot = beginSnapshot
beginSnapshot
@@ -63,7 +63,7 @@

// Compute basic aggregations of the Stage metrics for the metrics report
// also filter on the time boundaries for the report
def aggregateStageMetrics() : LinkedHashMap[String, Long] = {
def aggregateStageMetrics(): LinkedHashMap[String, Long] = {

val agg = Utils.zeroMetricsStage()
var submissionTime = Long.MaxValue
@@ -108,14 +108,14 @@ }
}

// Transforms aggregateStageMetrics output into a Java Map, needed by the Python API
def aggregateStageMetricsJavaMap() : java.util.Map[String, Long] = {
def aggregateStageMetricsJavaMap(): java.util.Map[String, Long] = {
mapAsJavaMap(aggregateStageMetrics())
}

// Extracts stages and their duration
def stagesDuration() : LinkedHashMap[Int, Long] = {
def stagesDuration(): LinkedHashMap[Int, Long] = {

val stages : LinkedHashMap[Int, Long] = LinkedHashMap.empty[Int,Long]
val stages: LinkedHashMap[Int, Long] = LinkedHashMap.empty[Int, Long]
for (metrics <- listenerStage.stageMetricsData.sortBy(_.stageId)
if (metrics.submissionTime >= beginSnapshot && metrics.completionTime <= endSnapshot)) {
stages += (metrics.stageId -> metrics.stageDuration)
@@ -166,10 +166,12 @@ case class StageMetrics(sparkSession: SparkSession) {
// between the end of the job and the time the last metrics value is received
// if you receive the error message java.util.NoSuchElementException: key not found:
// retry to run the report after a few seconds
def reportMemory(): String = {
def reportMemory(): String = {

var result = ListBuffer[String]()
val stages = {for (metrics <- listenerStage.stageMetricsData) yield metrics.stageId}.sorted
val stages = {
for (metrics <- listenerStage.stageMetricsData) yield metrics.stageId
}.sorted

// Additional details on executor (memory) metrics
result = result :+ "\nAdditional stage-level executor metrics (memory usage info):\n"
@@ -194,7 +196,7 @@
if (executorMaxVal != "driver") {
s" on executor $executorMaxVal"
} else {
""
""
}
result = result :+ (messageHead + messageTail)
}
@@ -235,14 +237,14 @@ case class StageMetrics(sparkSession: SparkSession) {
s"max(completionTime) - min(submissionTime) as elapsedTime, sum(stageDuration) as stageDuration , " +
s"sum(executorRunTime) as executorRunTime, sum(executorCpuTime) as executorCpuTime, " +
s"sum(executorDeserializeTime) as executorDeserializeTime, sum(executorDeserializeCpuTime) as executorDeserializeCpuTime, " +
s"sum(resultSerializationTime) as resultSerializationTime, sum(jvmGCTime) as jvmGCTime, "+
s"sum(resultSerializationTime) as resultSerializationTime, sum(jvmGCTime) as jvmGCTime, " +
s"sum(shuffleFetchWaitTime) as shuffleFetchWaitTime, sum(shuffleWriteTime) as shuffleWriteTime, " +
s"max(resultSize) as resultSize, " +
s"sum(diskBytesSpilled) as diskBytesSpilled, sum(memoryBytesSpilled) as memoryBytesSpilled, " +
s"max(peakExecutionMemory) as peakExecutionMemory, sum(recordsRead) as recordsRead, sum(bytesRead) as bytesRead, " +
s"sum(recordsWritten) as recordsWritten, sum(bytesWritten) as bytesWritten, " +
s"sum(shuffleRecordsRead) as shuffleRecordsRead, sum(shuffleTotalBlocksFetched) as shuffleTotalBlocksFetched, "+
s"sum(shuffleLocalBlocksFetched) as shuffleLocalBlocksFetched, sum(shuffleRemoteBlocksFetched) as shuffleRemoteBlocksFetched, "+
s"sum(shuffleRecordsRead) as shuffleRecordsRead, sum(shuffleTotalBlocksFetched) as shuffleTotalBlocksFetched, " +
s"sum(shuffleLocalBlocksFetched) as shuffleLocalBlocksFetched, sum(shuffleRemoteBlocksFetched) as shuffleRemoteBlocksFetched, " +
s"sum(shuffleTotalBytesRead) as shuffleTotalBytesRead, sum(shuffleLocalBytesRead) as shuffleLocalBytesRead, " +
s"sum(shuffleRemoteBytesRead) as shuffleRemoteBytesRead, sum(shuffleRemoteBytesReadToDisk) as shuffleRemoteBytesReadToDisk, " +
s"sum(shuffleBytesWritten) as shuffleBytesWritten, sum(shuffleRecordsWritten) as shuffleRecordsWritten " +
@@ -269,10 +271,10 @@
result = result :+ "Aggregated Spark stage metrics:"
val cols = aggregateDF.columns
result = result :+ (cols zip aggregateValues)
.map{
case(n:String, v:Long) => Utils.prettyPrintValues(n, v)
case(n: String, null) => n + " => null"
case(_,_) => ""
.map {
case (n: String, v: Long) => Utils.prettyPrintValues(n, v)
case (n: String, null) => n + " => null"
case (_, _) => ""
}.mkString("\n")
} else {
result = result :+ " no data to report "
@@ -314,13 +316,18 @@ case class StageMetrics(sparkSession: SparkSession) {
val aggregatedMetrics = aggregateStageMetrics()

/** Prepare a summary of the stage metrics for Prometheus. */
val pushGateway = PushGateway(serverIPnPort, metricsJob)
var str_metrics = s""
val pushGateway = PushGateway(
PushgatewayConfig(
serverIPnPort = serverIPnPort,
jobName = metricsJob
)
)

var str_metrics = s""
aggregatedMetrics.foreach {
case (metric: String, value: Long) =>
str_metrics += pushGateway.validateMetric(metric.toLowerCase()) + s" " + value.toString + s"\n"
}
str_metrics += pushGateway.validateMetric(metric.toLowerCase()) + s" " + value.toString + s"\n"
}

/** Send stage metrics to Prometheus. */
val metricsType = s"stage"
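
For interactive use, both `StageMetrics` and `TaskMetrics` still expose `sendReportPrometheus`, which now builds the `PushGateway` from a `PushgatewayConfig` carrying the default 5000 ms timeouts, as shown above. A hedged usage sketch (the argument list beyond the server and job name, and the exact pushed metric names, are assumptions):

```scala
import ch.cern.sparkmeasure.StageMetrics

// Measure a query and push the aggregated stage metrics to a local Pushgateway.
val stageMetrics = StageMetrics(spark)
stageMetrics.runAndMeasure {
  spark.sql("select count(*) from range(1000) cross join range(1000)").show()
}
// Pushes lines of "<metric_name> <value>" to the configured Pushgateway endpoint.
stageMetrics.sendReportPrometheus("localhost:9091", "myjob1")
```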
9 changes: 7 additions & 2 deletions src/main/scala/ch/cern/sparkmeasure/TaskMetrics.scala
@@ -219,9 +219,14 @@ case class TaskMetrics(sparkSession: SparkSession) {
val aggregatedMetrics = aggregateTaskMetrics()

/** Prepare a summary of the task metrics for Prometheus. */
val pushGateway = PushGateway(serverIPnPort, metricsJob)
var str_metrics = s""
val pushGateway = PushGateway(
PushgatewayConfig(
serverIPnPort = serverIPnPort,
jobName = metricsJob
)
)

var str_metrics = s""
aggregatedMetrics.foreach {
case (metric: String, value: Long) =>
str_metrics += pushGateway.validateMetric(metric.toLowerCase()) + s" " + value.toString + s"\n"