Commit fe9cc4e

committed
Add test.
1 parent 454acb4 commit fe9cc4e

2 files changed: +118 -0 lines changed


sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala

Lines changed: 22 additions & 0 deletions

@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow}
 import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils}
 import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
 import org.apache.spark.sql.connector.expressions._
+import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.connector.write._
 import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
@@ -344,6 +345,10 @@ class InMemoryTable(
         case exc: StreamingNotSupportedOperation => exc.throwsException()
         case s => s
       }
+
+      override def supportedCustomMetrics(): Array[CustomMetric] = {
+        Array(new InMemorySimpleCustomMetric)
+      }
     }
   }
 }
@@ -604,4 +609,21 @@ private class BufferWriter extends DataWriter[InternalRow] {
   override def abort(): Unit = {}
 
   override def close(): Unit = {}
+
+  override def currentMetricsValues(): Array[CustomTaskMetric] = {
+    val metric = new CustomTaskMetric {
+      override def name(): String = "in_memory_buffer_rows"
+
+      override def value(): Long = buffer.rows.size
+    }
+    Array(metric)
+  }
+}
+
+class InMemorySimpleCustomMetric extends CustomMetric {
+  override def name(): String = "in_memory_buffer_rows"
+  override def description(): String = "number of rows in buffer"
+  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
+    s"in-memory rows: ${taskMetrics.sum}"
+  }
 }
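Taken together, the changes above wire up DataSource V2 custom write metrics for the test table: each task's BufferWriter reports a CustomTaskMetric named "in_memory_buffer_rows", and the driver aggregates the collected per-task values through the table's InMemorySimpleCustomMetric. A minimal sketch of that hand-off, reusing the classes added in this diff (the per-task value of 3 is illustrative, not taken from a real run):

import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}

// A value reported by one write task; Spark matches it to the driver-side
// metric declared in supportedCustomMetrics() by name.
val taskMetric: CustomTaskMetric = new CustomTaskMetric {
  override def name(): String = "in_memory_buffer_rows"
  override def value(): Long = 3L // illustrative: rows buffered by a single task
}

// Driver-side aggregation over the collected per-task values.
val driverMetric: CustomMetric = new InMemorySimpleCustomMetric
assert(driverMetric.aggregateTaskMetrics(Array(taskMetric.value())) == "in-memory rows: 3")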
FileFormatDataWriterMetricSuite.scala (new file)

Lines changed: 96 additions & 0 deletions

@@ -0,0 +1,96 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.datasources

import java.util.Collections

import org.scalatest.BeforeAndAfter
import org.scalatest.time.SpanSugar._

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog}
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StructType

class FileFormatDataWriterMetricSuite
    extends QueryTest with SharedSparkSession with BeforeAndAfter {
  import testImplicits._
  import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._

  before {
    spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
  }

  after {
    spark.sessionState.catalogManager.reset()
    spark.sessionState.conf.clear()
  }

  private def testMetricOnDSv2(func: String => Unit, checker: Map[Long, String] => Unit): Unit = {
    withTable("testcat.table_name") {
      val statusStore = spark.sharedState.statusStore
      val oldCount = statusStore.executionsList().size

      val testCatalog = spark.sessionState.catalogManager.catalog("testcat").asTableCatalog

      testCatalog.createTable(
        Identifier.of(Array(), "table_name"),
        new StructType().add("i", "int"),
        Array.empty, Collections.emptyMap[String, String])

      func("testcat.table_name")

      // Wait until the new execution is started and being tracked.
      eventually(timeout(10.seconds), interval(10.milliseconds)) {
        assert(statusStore.executionsCount() >= oldCount)
      }

      // Wait for listener to finish computing the metrics for the execution.
      eventually(timeout(10.seconds), interval(10.milliseconds)) {
        assert(statusStore.executionsList().nonEmpty &&
          statusStore.executionsList().last.metricValues != null)
      }

      val execId = statusStore.executionsList().last.executionId
      val metrics = statusStore.executionMetrics(execId)
      checker(metrics)
    }
  }

  test("Report metrics from Datasource v2 write: append") {
    testMetricOnDSv2(table => {
      val df = sql("select 1 as i")
      val v2Writer = df.writeTo(table)
      v2Writer.append()
    }, metrics => {
      val customMetric = metrics.find(_._2 == "in-memory rows: 1")
      assert(customMetric.isDefined)
    })
  }

  test("Report metrics from Datasource v2 write: overwrite") {
    testMetricOnDSv2(table => {
      val df = Seq(1, 2, 3).toDF("i")
      val v2Writer = df.writeTo(table)
      v2Writer.overwrite(lit(true))
    }, metrics => {
      val customMetric = metrics.find(_._2 == "in-memory rows: 3")
      assert(customMetric.isDefined)
    })
  }
}
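For context on the checker argument: statusStore.executionMetrics(execId) returns a Map[Long, String] from metric id to the driver-side aggregated string, which is why both tests search by value rather than by a specific id. An illustrative shape of that map (the key 42L is a placeholder, not a real metric id):

// Hypothetical map shape passed to the checker; only the values are asserted on.
val metrics: Map[Long, String] = Map(42L -> "in-memory rows: 3")
assert(metrics.find(_._2 == "in-memory rows: 3").isDefined)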
