@@ -21,13 +21,22 @@ package org.apache.comet.fuzz
2121
2222import java .io .{BufferedWriter , FileWriter , PrintWriter , StringWriter }
2323
24- import scala .collection .mutable . WrappedArray
24+ import scala .collection .mutable
2525import scala .io .Source
2626
2727import org .apache .spark .sql .{Row , SparkSession }
2828
2929object QueryRunner {
3030
31+ def createOutputMdFile (): BufferedWriter = {
32+ val outputFilename = s " results- ${System .currentTimeMillis()}.md "
33+ // scalastyle:off println
34+ println(s " Writing results to $outputFilename" )
35+ // scalastyle:on println
36+
37+ new BufferedWriter (new FileWriter (outputFilename))
38+ }
39+
3140 def runQueries (
3241 spark : SparkSession ,
3342 numFiles : Int ,
@@ -39,12 +48,7 @@ object QueryRunner {
3948 var cometFailureCount = 0
4049 var cometSuccessCount = 0
4150
42- val outputFilename = s " results- ${System .currentTimeMillis()}.md "
43- // scalastyle:off println
44- println(s " Writing results to $outputFilename" )
45- // scalastyle:on println
46-
47- val w = new BufferedWriter (new FileWriter (outputFilename))
51+ val w = createOutputMdFile()
4852
4953 // register input files
5054 for (i <- 0 until numFiles) {
@@ -76,42 +80,13 @@ object QueryRunner {
7680 val cometRows = df.collect()
7781 val cometPlan = df.queryExecution.executedPlan.toString
7882
79- var success = true
80- if (sparkRows.length == cometRows.length) {
81- var i = 0
82- while (i < sparkRows.length) {
83- val l = sparkRows(i)
84- val r = cometRows(i)
85- assert(l.length == r.length)
86- for (j <- 0 until l.length) {
87- if (! same(l(j), r(j))) {
88- success = false
89- showSQL(w, sql)
90- showPlans(w, sparkPlan, cometPlan)
91- w.write(s " First difference at row $i: \n " )
92- w.write(" Spark: `" + formatRow(l) + " `\n " )
93- w.write(" Comet: `" + formatRow(r) + " `\n " )
94- i = sparkRows.length
95- }
96- }
97- i += 1
98- }
99- } else {
100- success = false
101- showSQL(w, sql)
102- showPlans(w, sparkPlan, cometPlan)
103- w.write(
104- s " [ERROR] Spark produced ${sparkRows.length} rows and " +
105- s " Comet produced ${cometRows.length} rows. \n " )
106- }
107-
108- // check that the plan contains Comet operators
109- if (! cometPlan.contains(" Comet" )) {
110- success = false
111- showSQL(w, sql)
112- showPlans(w, sparkPlan, cometPlan)
113- w.write(" [ERROR] Comet did not accelerate any part of the plan\n " )
114- }
83+ val success = QueryComparison .assertSameRows(
84+ sparkRows,
85+ cometRows,
86+ sqlText = sql,
87+ sparkPlan,
88+ cometPlan,
89+ output = w)
11590
11691 if (success) {
11792 cometSuccessCount += 1
@@ -123,7 +98,7 @@ object QueryRunner {
12398 case e : Exception =>
12499 // the query worked in Spark but failed in Comet, so this is likely a bug in Comet
125100 cometFailureCount += 1
126- showSQL(w, sql)
101+ QueryComparison . showSQL(w, sql)
127102 w.write(" ### Spark Plan\n " )
128103 w.write(s " ``` \n $sparkPlan\n ``` \n " )
129104
@@ -145,7 +120,7 @@ object QueryRunner {
145120 // we expect many generated queries to be invalid
146121 invalidQueryCount += 1
147122 if (showFailedSparkQueries) {
148- showSQL(w, sql)
123+ QueryComparison . showSQL(w, sql)
149124 w.write(s " Query failed in Spark: ${e.getMessage}\n " )
150125 }
151126 }
@@ -161,6 +136,56 @@ object QueryRunner {
161136 querySource.close()
162137 }
163138 }
139+ }
140+
141+ object QueryComparison {
142+ def assertSameRows (
143+ sparkRows : Array [Row ],
144+ cometRows : Array [Row ],
145+ sqlText : String ,
146+ sparkPlan : String ,
147+ cometPlan : String ,
148+ output : BufferedWriter ): Boolean = {
149+ var success = true
150+ if (sparkRows.length == cometRows.length) {
151+ var i = 0
152+ while (i < sparkRows.length) {
153+ val l = sparkRows(i)
154+ val r = cometRows(i)
155+ assert(l.length == r.length)
156+ for (j <- 0 until l.length) {
157+ if (! same(l(j), r(j))) {
158+ success = false
159+ showSQL(output, sqlText)
160+ showPlans(output, sparkPlan, cometPlan)
161+ output.write(s " First difference at row $i: \n " )
162+ output.write(" Spark: `" + formatRow(l) + " `\n " )
163+ output.write(" Comet: `" + formatRow(r) + " `\n " )
164+ i = sparkRows.length
165+ }
166+ }
167+ i += 1
168+ }
169+ } else {
170+ success = false
171+ showSQL(output, sqlText)
172+ showPlans(output, sparkPlan, cometPlan)
173+ output.write(
174+ s " [ERROR] Spark produced ${sparkRows.length} rows and " +
175+ s " Comet produced ${cometRows.length} rows. \n " )
176+ }
177+
178+ // check that the plan contains Comet operators
179+ if (! cometPlan.contains(" Comet" )) {
180+ success = false
181+ showSQL(output, sqlText)
182+ showPlans(output, sparkPlan, cometPlan)
183+ output.write(" [ERROR] Comet did not accelerate any part of the plan\n " )
184+ }
185+
186+ success
187+
188+ }
164189
165190 private def same (l : Any , r : Any ): Boolean = {
166191 if (l == null || r == null ) {
@@ -179,7 +204,7 @@ object QueryRunner {
179204 case (a : Double , b : Double ) => (a - b).abs <= 0.000001
180205 case (a : Array [_], b : Array [_]) =>
181206 a.length == b.length && a.zip(b).forall(x => same(x._1, x._2))
182- case (a : WrappedArray [_], b : WrappedArray [_]) =>
207+ case (a : mutable. WrappedArray [_], b : mutable. WrappedArray [_]) =>
183208 a.length == b.length && a.zip(b).forall(x => same(x._1, x._2))
184209 case (a : Row , b : Row ) =>
185210 val aa = a.toSeq
@@ -192,7 +217,7 @@ object QueryRunner {
192217 private def format (value : Any ): String = {
193218 value match {
194219 case null => " NULL"
195- case v : WrappedArray [_] => s " [ ${v.map(format).mkString(" ," )}] "
220+ case v : mutable. WrappedArray [_] => s " [ ${v.map(format).mkString(" ," )}] "
196221 case v : Array [Byte ] => s " [ ${v.mkString(" ," )}] "
197222 case r : Row => formatRow(r)
198223 case other => other.toString
@@ -203,7 +228,7 @@ object QueryRunner {
203228 row.toSeq.map(format).mkString(" ," )
204229 }
205230
206- private def showSQL (w : BufferedWriter , sql : String , maxLength : Int = 120 ): Unit = {
231+ def showSQL (w : BufferedWriter , sql : String , maxLength : Int = 120 ): Unit = {
207232 w.write(" ## SQL\n " )
208233 w.write(" ```\n " )
209234 val words = sql.split(" " )
@@ -229,5 +254,4 @@ object QueryRunner {
229254 w.write(" ### Comet Plan\n " )
230255 w.write(s " ``` \n $cometPlan\n ``` \n " )
231256 }
232-
233257}
0 commit comments