
Commit f381c3d

chore: extract comparison tool from fuzzer
1 parent 3631b54 commit f381c3d

File tree

1 file changed (+73, -49)


fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala

Lines changed: 73 additions & 49 deletions
@@ -21,13 +21,22 @@ package org.apache.comet.fuzz
 
 import java.io.{BufferedWriter, FileWriter, PrintWriter, StringWriter}
 
-import scala.collection.mutable.WrappedArray
+import scala.collection.mutable
 import scala.io.Source
 
 import org.apache.spark.sql.{Row, SparkSession}
 
 object QueryRunner {
 
+  def createOutputMdFile(): BufferedWriter = {
+    val outputFilename = s"results-${System.currentTimeMillis()}.md"
+    // scalastyle:off println
+    println(s"Writing results to $outputFilename")
+    // scalastyle:on println
+
+    new BufferedWriter(new FileWriter(outputFilename))
+  }
+
   def runQueries(
       spark: SparkSession,
       numFiles: Int,
@@ -39,12 +48,7 @@
     var cometFailureCount = 0
     var cometSuccessCount = 0
 
-    val outputFilename = s"results-${System.currentTimeMillis()}.md"
-    // scalastyle:off println
-    println(s"Writing results to $outputFilename")
-    // scalastyle:on println
-
-    val w = new BufferedWriter(new FileWriter(outputFilename))
+    val w = createOutputMdFile()
 
     // register input files
     for (i <- 0 until numFiles) {
@@ -76,42 +80,13 @@
         val cometRows = df.collect()
         val cometPlan = df.queryExecution.executedPlan.toString
 
-        var success = true
-        if (sparkRows.length == cometRows.length) {
-          var i = 0
-          while (i < sparkRows.length) {
-            val l = sparkRows(i)
-            val r = cometRows(i)
-            assert(l.length == r.length)
-            for (j <- 0 until l.length) {
-              if (!same(l(j), r(j))) {
-                success = false
-                showSQL(w, sql)
-                showPlans(w, sparkPlan, cometPlan)
-                w.write(s"First difference at row $i:\n")
-                w.write("Spark: `" + formatRow(l) + "`\n")
-                w.write("Comet: `" + formatRow(r) + "`\n")
-                i = sparkRows.length
-              }
-            }
-            i += 1
-          }
-        } else {
-          success = false
-          showSQL(w, sql)
-          showPlans(w, sparkPlan, cometPlan)
-          w.write(
-            s"[ERROR] Spark produced ${sparkRows.length} rows and " +
-              s"Comet produced ${cometRows.length} rows.\n")
-        }
-
-        // check that the plan contains Comet operators
-        if (!cometPlan.contains("Comet")) {
-          success = false
-          showSQL(w, sql)
-          showPlans(w, sparkPlan, cometPlan)
-          w.write("[ERROR] Comet did not accelerate any part of the plan\n")
-        }
+        val success = QueryComparison.assertSameRows(
+          sparkRows,
+          cometRows,
+          sqlText = sql,
+          sparkPlan,
+          cometPlan,
+          output = w)
 
         if (success) {
           cometSuccessCount += 1
@@ -123,7 +98,7 @@
           case e: Exception =>
             // the query worked in Spark but failed in Comet, so this is likely a bug in Comet
            cometFailureCount += 1
-            showSQL(w, sql)
+            QueryComparison.showSQL(w, sql)
             w.write("### Spark Plan\n")
             w.write(s"```\n$sparkPlan\n```\n")
 
@@ -145,7 +120,7 @@
           // we expect many generated queries to be invalid
           invalidQueryCount += 1
           if (showFailedSparkQueries) {
-            showSQL(w, sql)
+            QueryComparison.showSQL(w, sql)
             w.write(s"Query failed in Spark: ${e.getMessage}\n")
           }
       }
@@ -161,6 +136,56 @@
       querySource.close()
     }
   }
+}
+
+object QueryComparison {
+  def assertSameRows(
+      sparkRows: Array[Row],
+      cometRows: Array[Row],
+      sqlText: String,
+      sparkPlan: String,
+      cometPlan: String,
+      output: BufferedWriter): Boolean = {
+    var success = true
+    if (sparkRows.length == cometRows.length) {
+      var i = 0
+      while (i < sparkRows.length) {
+        val l = sparkRows(i)
+        val r = cometRows(i)
+        assert(l.length == r.length)
+        for (j <- 0 until l.length) {
+          if (!same(l(j), r(j))) {
+            success = false
+            showSQL(output, sqlText)
+            showPlans(output, sparkPlan, cometPlan)
+            output.write(s"First difference at row $i:\n")
+            output.write("Spark: `" + formatRow(l) + "`\n")
+            output.write("Comet: `" + formatRow(r) + "`\n")
+            i = sparkRows.length
+          }
+        }
+        i += 1
+      }
+    } else {
+      success = false
+      showSQL(output, sqlText)
+      showPlans(output, sparkPlan, cometPlan)
+      output.write(
+        s"[ERROR] Spark produced ${sparkRows.length} rows and " +
+          s"Comet produced ${cometRows.length} rows.\n")
+    }
+
+    // check that the plan contains Comet operators
+    if (!cometPlan.contains("Comet")) {
+      success = false
+      showSQL(output, sqlText)
+      showPlans(output, sparkPlan, cometPlan)
+      output.write("[ERROR] Comet did not accelerate any part of the plan\n")
+    }
+
+    success
+
+  }
 
   private def same(l: Any, r: Any): Boolean = {
     if (l == null || r == null) {
@@ -179,7 +204,7 @@
       case (a: Double, b: Double) => (a - b).abs <= 0.000001
       case (a: Array[_], b: Array[_]) =>
         a.length == b.length && a.zip(b).forall(x => same(x._1, x._2))
-      case (a: WrappedArray[_], b: WrappedArray[_]) =>
+      case (a: mutable.WrappedArray[_], b: mutable.WrappedArray[_]) =>
         a.length == b.length && a.zip(b).forall(x => same(x._1, x._2))
       case (a: Row, b: Row) =>
         val aa = a.toSeq
@@ -192,7 +217,7 @@
   private def format(value: Any): String = {
     value match {
       case null => "NULL"
-      case v: WrappedArray[_] => s"[${v.map(format).mkString(",")}]"
+      case v: mutable.WrappedArray[_] => s"[${v.map(format).mkString(",")}]"
      case v: Array[Byte] => s"[${v.mkString(",")}]"
       case r: Row => formatRow(r)
       case other => other.toString
@@ -203,7 +228,7 @@
     row.toSeq.map(format).mkString(",")
   }
 
-  private def showSQL(w: BufferedWriter, sql: String, maxLength: Int = 120): Unit = {
+  def showSQL(w: BufferedWriter, sql: String, maxLength: Int = 120): Unit = {
     w.write("## SQL\n")
     w.write("```\n")
     val words = sql.split(" ")
@@ -229,5 +254,4 @@
     w.write("### Comet Plan\n")
     w.write(s"```\n$cometPlan\n```\n")
   }
-
 }
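
With the comparison logic extracted into a standalone QueryComparison object, it can also be driven outside the fuzzer's main loop. The sketch below is illustrative only and not part of this commit: the table name, the query text, and the use of the spark.comet.enabled session config to toggle Comet between the two runs are assumptions.

import org.apache.spark.sql.SparkSession

import org.apache.comet.fuzz.{QueryComparison, QueryRunner}

object CompareOneQuery {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("compare-one-query").getOrCreate()

    // Hypothetical query; a real run would first register the fuzzer's
    // generated Parquet files as temp views, as QueryRunner.runQueries does.
    val sql = "SELECT c0, COUNT(*) FROM test0 GROUP BY c0"

    // Baseline run with Comet disabled (assumes the session-level flag
    // takes effect for subsequent queries).
    spark.conf.set("spark.comet.enabled", "false")
    val sparkDf = spark.sql(sql)
    val sparkRows = sparkDf.collect()
    val sparkPlan = sparkDf.queryExecution.executedPlan.toString

    // Second run with Comet enabled.
    spark.conf.set("spark.comet.enabled", "true")
    val cometDf = spark.sql(sql)
    val cometRows = cometDf.collect()
    val cometPlan = cometDf.queryExecution.executedPlan.toString

    // Reuse the extracted helpers: createOutputMdFile() opens the markdown
    // report, and assertSameRows() writes any mismatch into it and returns
    // whether the two result sets agreed.
    val w = QueryRunner.createOutputMdFile()
    try {
      val ok = QueryComparison.assertSameRows(
        sparkRows,
        cometRows,
        sqlText = sql,
        sparkPlan,
        cometPlan,
        output = w)
      println(s"Results match: $ok")
    } finally {
      w.close()
    }
  }
}

Note that assertSameRows reports a failure not only on mismatched rows but also when the executed plan contains no Comet operators, so a query that returns identical results without being accelerated still ends up in the report.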
