diff --git a/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala b/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala
index 9a37e0221b27..89e927e5784d 100644
--- a/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala
+++ b/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala
@@ -25,7 +25,12 @@ import java.io.{File, FileOutputStream, OutputStream}
 abstract class BenchmarkBase {
   var output: Option[OutputStream] = None
 
-  def benchmark(): Unit
+  /**
+   * Main process of the whole benchmark.
+   * Implementations of this method are supposed to use the wrapper method `runBenchmark`
+   * for each benchmark scenario.
+   */
+  def runBenchmarkSuite(): Unit
 
   final def runBenchmark(benchmarkName: String)(func: => Any): Unit = {
     val separator = "=" * 96
@@ -46,7 +51,7 @@ abstract class BenchmarkBase {
       output = Some(new FileOutputStream(file))
     }
 
-    benchmark()
+    runBenchmarkSuite()
 
     output.foreach { o =>
       if (o != null) {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala
index 1a2216ea070c..6c1d58089867 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
  */
 object UDTSerializationBenchmark extends BenchmarkBase {
 
-  override def benchmark(): Unit = {
+  override def runBenchmarkSuite(): Unit = {
     runBenchmark("VectorUDT de/serialization") {
 
       val iters = 1e2.toInt
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala
index cbe723fd11c6..e7a99485cdf0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala
@@ -41,7 +41,7 @@ object UnsafeProjectionBenchmark extends BenchmarkBase {
     (1 to numRows).map(_ => encoder.toRow(generator().asInstanceOf[Row]).copy()).toArray
   }
 
-  override def benchmark(): Unit = {
+  override def runBenchmarkSuite(): Unit = {
     runBenchmark("unsafe projection") {
       val iters = 1024 * 16
       val numRows = 1024 * 16
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala
index 296ae104a94a..86e0df2fea35 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala
@@ -44,7 +44,7 @@ import org.apache.spark.unsafe.map.BytesToBytesMap
  */
 object AggregateBenchmark extends SqlBasedBenchmark {
 
-  override def benchmark(): Unit = {
+  override def runBenchmarkSuite(): Unit = {
     runBenchmark("aggregate without grouping") {
       val N = 500L << 22
       codegenBenchmark("agg w/o group", N) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala
index 7cdf653e3869..cf05ca336171 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala
@@ -198,7 +198,7 @@ object FilterPushdownBenchmark extends BenchmarkBase with SQLHelper {
     }
   }
 
-  override def benchmark(): Unit = {
+  override def runBenchmarkSuite(): Unit = {
     runBenchmark("Pushdown for many distinct value case") {
       withTempPath { dir =>
         withTempTable("orcTable", "parquetTable") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala
index 8b275188f06d..83edf73abfae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala
@@ -36,7 +36,7 @@ object PrimitiveArrayBenchmark extends BenchmarkBase {
     .config("spark.sql.autoBroadcastJoinThreshold", 1)
     .getOrCreate()
 
-  override def benchmark(): Unit = {
+  override def runBenchmarkSuite(): Unit = {
     runBenchmark("Write primitive arrays in dataset") {
       writeDatasetArray(4)
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
index 958a06440214..9a54e2320b80 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
@@ -119,7 +119,7 @@
     benchmark.run()
   }
 
-  override def benchmark(): Unit = {
+  override def runBenchmarkSuite(): Unit = {
     runBenchmark("radix sort") {
       sortBenchmark()
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala
index ff0e4acd3127..0f9079744a22 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala
@@ -233,7 +233,7 @@ object CompressionSchemeBenchmark extends BenchmarkBase with AllCompressionSchem
     runDecodeBenchmark("STRING Decode", iters, count, STRING, testData)
   }
 
-  override def benchmark(): Unit = {
+  override def runBenchmarkSuite(): Unit = {
     runBenchmark("Compression Scheme Benchmark") {
       bitEncodingBenchmark(1024)
       shortEncodingBenchmark(1024)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala
index df6ab14e661c..f311465e582a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala
@@ -443,7 +443,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase {
     benchmark.run
   }
 
-  override def benchmark(): Unit = {
+  override def runBenchmarkSuite(): Unit = {
     runBenchmark("Int Read/Write") {
       intAccess(1024 * 40)
     }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala
index 0bb5e8c14159..870ad4818eb2 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala
@@ -336,7 +336,7 @@ object OrcReadBenchmark extends BenchmarkBase with SQLHelper {
     }
   }
 
-  override def benchmark(): Unit = {
+  override def runBenchmarkSuite(): Unit = {
     runBenchmark("SQL Single Numeric Column Scan") {
       Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { dataType =>
         numericScanBenchmark(1024 * 1024 * 15, dataType)
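For reference, the contract introduced in BenchmarkBase above looks like this from an implementor's point of view. The object below is a hypothetical sketch, not part of the patch: it only combines the two signatures the patch defines, `def runBenchmarkSuite(): Unit` and `final def runBenchmark(benchmarkName: String)(func: => Any): Unit`.

    // Hypothetical example (not in this patch): under the new contract, a
    // benchmark overrides runBenchmarkSuite() as its entry point and wraps
    // each scenario in runBenchmark(name) { ... }, which prints the "=" * 96
    // separator header and routes output through BenchmarkBase's optional
    // OutputStream when SPARK_GENERATE_BENCHMARK_FILES=1.
    object ExampleBenchmark extends BenchmarkBase {
      override def runBenchmarkSuite(): Unit = {
        runBenchmark("example scenario") {
          // scenario body: build test data, run the measured code
        }
      }
    }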