Write.java
@@ -24,6 +24,7 @@
 import org.apache.spark.sql.connector.catalog.Table;
 import org.apache.spark.sql.connector.catalog.TableCapability;
 import org.apache.spark.sql.connector.metric.CustomMetric;
+import org.apache.spark.sql.connector.metric.CustomTaskMetric;
 import org.apache.spark.sql.connector.write.streaming.StreamingWrite;
 
 /**
@@ -76,4 +77,14 @@ default StreamingWrite toStreaming() {
   default CustomMetric[] supportedCustomMetrics() {
     return new CustomMetric[]{};
   }
+
+  /**
+   * Returns an array of custom metrics whose values are collected at the driver side only.
+   * Note that these metrics must be included in the supported custom metrics reported by
+   * `supportedCustomMetrics`.
+   */
+  default CustomTaskMetric[] reportDriverMetrics() {
+    return new CustomTaskMetric[]{};
+  }
 }
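
The two methods form a pair: a value returned by `reportDriverMetrics()` is only surfaced if a metric with the same name is declared in `supportedCustomMetrics()`, because the write exec node looks the name up in its registered SQL metrics. A minimal sketch of a connector `Write` using the new hook; `FilesPlannedMetric`, the `files_planned` name, and `MyWrite` are illustrative, not part of the Spark API:

```scala
import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric, CustomTaskMetric}
import org.apache.spark.sql.connector.write.{BatchWrite, Write}

// Hypothetical driver-side metric: how many files the driver planned to commit.
class FilesPlannedMetric extends CustomSumMetric {
  override def name(): String = "files_planned"
  override def description(): String = "number of files planned on the driver"
}

// Hypothetical Write implementation wiring up both sides of the contract.
class MyWrite(filesPlanned: Long) extends Write {
  override def toBatch(): BatchWrite = ??? // connector-specific, elided

  // Declare the metric so the exec node registers an SQLMetric for it...
  override def supportedCustomMetrics(): Array[CustomMetric] =
    Array(new FilesPlannedMetric)

  // ...and report the driver-side value under the same name after commit.
  override def reportDriverMetrics(): Array[CustomTaskMetric] =
    Array(new CustomTaskMetric {
      override def name(): String = "files_planned"
      override def value(): Long = filesPlanned
    })
}
```

Declaring a metric but never reporting it is harmless (it just keeps its initial value); reporting a name that was never declared fails the driver-side `metrics(name)` lookup.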
InMemoryBaseTable.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow,
 import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils}
 import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
 import org.apache.spark.sql.connector.expressions._
-import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
+import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric, CustomTaskMetric}
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.connector.read.colstats.{ColumnStatistics, Histogram, HistogramBin}
 import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning}
@@ -512,7 +512,11 @@ abstract class InMemoryBaseTable(
     }
 
     override def supportedCustomMetrics(): Array[CustomMetric] = {
-      Array(new InMemorySimpleCustomMetric)
+      Array(new InMemorySimpleCustomMetric, new InMemoryCustomDriverMetric)
     }
+
+    override def reportDriverMetrics(): Array[CustomTaskMetric] = {
+      Array(new InMemoryCustomDriverTaskMetric(rows.size))
+    }
   }
 }
@@ -754,3 +758,13 @@ class InMemorySimpleCustomMetric extends CustomMetric {
     s"in-memory rows: ${taskMetrics.sum}"
   }
 }
+
+class InMemoryCustomDriverMetric extends CustomSumMetric {
+  override def name(): String = "number_of_rows_from_driver"
+  override def description(): String = "number of rows from driver"
+}
+
+class InMemoryCustomDriverTaskMetric(value: Long) extends CustomTaskMetric {
+  override def name(): String = "number_of_rows_from_driver"
+  override def value(): Long = value
+}
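
`InMemoryCustomDriverMetric` extends `CustomSumMetric`, the convenience base class from `org.apache.spark.sql.connector.metric` whose aggregation simply sums the reported values, so the single driver-reported `rows.size` surfaces unchanged in the UI. A rough Scala sketch of what the base class contributes (the real class is written in Java; `SumLikeMetric` is just an illustrative name):

```scala
import org.apache.spark.sql.connector.metric.CustomMetric

// Sketch: CustomSumMetric leaves name() and description() abstract and
// aggregates by summing whatever per-task (or driver) values arrive.
abstract class SumLikeMetric extends CustomMetric {
  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String =
    taskMetrics.sum.toString
}
```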
WriteToDataSourceV2Exec.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.connector.metric.CustomMetric
 import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeltaWrite, DeltaWriter, PhysicalWriteInfoImpl, Write, WriterCommitMessage}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{SparkPlan, SQLExecution, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric, SQLMetrics}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{LongAccumulator, Utils}
@@ -341,9 +341,22 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec {
 
   override protected def run(): Seq[InternalRow] = {
     val writtenRows = writeWithV2(write.toBatch)
+    postDriverMetrics()
     refreshCache()
     writtenRows
   }
+
+  protected def postDriverMetrics(): Unit = {
+    val driverSQLMetrics = write.reportDriverMetrics().map { customTaskMetric =>
+      val metric = metrics(customTaskMetric.name())
+      metric.set(customTaskMetric.value())
+      metric
+    }
+
+    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+    SQLMetrics.postDriverMetricUpdates(sparkContext, executionId,
+      driverSQLMetrics.toImmutableArraySeq)
+  }
 }
 
 /**
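
`postDriverMetrics()` runs after the batch write completes, so the reported values reflect the committed state; `metrics` is the node's name-to-`SQLMetric` map built from `supportedCustomMetrics()`, which is why names missing from that array would fail the lookup. Because no task carries these values, they are pushed to the UI through `SQLMetrics.postDriverMetricUpdates`, which behaves roughly like the following sketch (reconstructed from memory of the existing helper; details may differ):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates

// Sketch: driver-side metric values never travel through task accumulator
// updates, so they are posted directly to the listener bus, keyed by the
// SQL execution id, for the SQL UI to attach to the plan node.
def postDriverMetricUpdates(
    sc: SparkContext, executionId: String, metrics: Seq[SQLMetric]): Unit = {
  // executionId is null when the plan runs outside a tracked SQL execution.
  if (executionId != null) {
    sc.listenerBus.post(SparkListenerDriverAccumUpdates(
      executionId.toLong, metrics.map(m => m.id -> m.value)))
  }
}
```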
InMemoryTableMetricSuite.scala
@@ -42,7 +42,7 @@ class InMemoryTableMetricSuite
     spark.sessionState.conf.clear()
   }
 
-  private def testMetricOnDSv2(func: String => Unit, checker: Map[Long, String] => Unit): Unit = {
+  private def testMetricOnDSv2(func: String => Unit, checker: Map[String, String] => Unit): Unit = {
     withTable("testcat.table_name") {
       val statusStore = spark.sharedState.statusStore
       val oldCount = statusStore.executionsList().size
@@ -67,8 +67,14 @@ class InMemoryTableMetricSuite
           statusStore.executionsList().last.metricValues != null)
       }
 
-      val execId = statusStore.executionsList().last.executionId
-      val metrics = statusStore.executionMetrics(execId)
+      val exec = statusStore.executionsList().last
+      val execId = exec.executionId
+      val sqlMetrics = exec.metrics.map { metric =>
+        metric.accumulatorId -> metric.name
+      }.toMap
+      val metrics = statusStore.executionMetrics(execId).map { case (k, v) =>
+        sqlMetrics(k) -> v
+      }
       checker(metrics)
     }
   }
@@ -79,8 +85,8 @@
       val v2Writer = df.writeTo(table)
       v2Writer.append()
     }, metrics => {
-      val customMetric = metrics.find(_._2 == "in-memory rows: 1")
-      assert(customMetric.isDefined)
+      assert(metrics.get("number of rows in buffer").contains("in-memory rows: 1"))
+      assert(metrics.get("number of rows from driver").contains("1"))
     })
   }
 
@@ -90,8 +96,8 @@
       val v2Writer = df.writeTo(table)
       v2Writer.overwrite(lit(true))
     }, metrics => {
-      val customMetric = metrics.find(_._2 == "in-memory rows: 3")
-      assert(customMetric.isDefined)
+      assert(metrics.get("number of rows in buffer").contains("in-memory rows: 3"))
+      assert(metrics.get("number of rows from driver").contains("3"))
     })
   }
 }
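
`statusStore.executionMetrics` is keyed by accumulator ID, so the updated helper joins through `exec.metrics` to let the tests assert on metric display names rather than scanning metric values. An end-to-end sketch of what the suite exercises, assuming a session where `testcat` is wired to the in-memory test catalog as this suite's setup does (catalog class, table name, and provider are illustrative):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[2]")
  // Assumption: the test-only in-memory catalog used by this suite.
  .config("spark.sql.catalog.testcat",
    "org.apache.spark.sql.connector.catalog.InMemoryTableCatalog")
  .getOrCreate()

spark.sql("CREATE TABLE testcat.table_name (i bigint) USING foo")
spark.range(3).toDF("i").writeTo("testcat.table_name").append()

// The SQL UI's write node now shows both metrics:
//   number of rows in buffer    -> "in-memory rows: 3"  (aggregated from tasks)
//   number of rows from driver  -> "3"                  (posted by the driver)
```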