diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 67740c316647..3081ff935f04 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -22,7 +22,6 @@ import org.scalatest.Suite
 import org.scalatest.Tag
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
@@ -57,7 +56,7 @@ trait CodegenInterpretedPlanTest extends PlanTest {
  * Provides helper methods for comparing plans, but without the overhead of
  * mandating a FunSuite.
  */
-trait PlanTestBase extends PredicateHelper { self: Suite =>
+trait PlanTestBase extends PredicateHelper with SQLHelper { self: Suite =>
 
   // TODO(gatorsmile): remove this from PlanTest and all the analyzer rules
   protected def conf = SQLConf.get
@@ -174,32 +173,4 @@ trait PlanTestBase extends PredicateHelper { self: Suite =>
       plan1 == plan2
     }
   }
-
-  /**
-   * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL
-   * configurations.
-   */
-  protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
-    val conf = SQLConf.get
-    val (keys, values) = pairs.unzip
-    val currentValues = keys.map { key =>
-      if (conf.contains(key)) {
-        Some(conf.getConfString(key))
-      } else {
-        None
-      }
-    }
-    (keys, values).zipped.foreach { (k, v) =>
-      if (SQLConf.staticConfKeys.contains(k)) {
-        throw new AnalysisException(s"Cannot modify the value of a static config: $k")
-      }
-      conf.setConfString(k, v)
-    }
-    try f finally {
-      keys.zip(currentValues).foreach {
-        case (key, Some(value)) => conf.setConfString(key, value)
-        case (key, None) => conf.unsetConf(key)
-      }
-    }
-  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SQLHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SQLHelper.scala
new file mode 100644
index 000000000000..4d869d79ad59
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SQLHelper.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.plans
+
+import java.io.File
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.Utils
+
+trait SQLHelper {
+
+  /**
+   * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL
+   * configurations.
+   */
+  protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
+    val conf = SQLConf.get
+    val (keys, values) = pairs.unzip
+    val currentValues = keys.map { key =>
+      if (conf.contains(key)) {
+        Some(conf.getConfString(key))
+      } else {
+        None
+      }
+    }
+    (keys, values).zipped.foreach { (k, v) =>
+      if (SQLConf.staticConfKeys.contains(k)) {
+        throw new AnalysisException(s"Cannot modify the value of a static config: $k")
+      }
+      conf.setConfString(k, v)
+    }
+    try f finally {
+      keys.zip(currentValues).foreach {
+        case (key, Some(value)) => conf.setConfString(key, value)
+        case (key, None) => conf.unsetConf(key)
+      }
+    }
+  }
+
+  /**
+   * Generates a temporary path without creating the actual file/directory, then passes it to `f`.
+   * If a file/directory is created there by `f`, it will be deleted after `f` returns.
+   */
+  protected def withTempPath(f: File => Unit): Unit = {
+    val path = Utils.createTempDir()
+    path.delete()
+    try f(path) finally Utils.deleteRecursively(path)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala
index cf9bda2fb1ff..51a7f9f1ef09 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala
@@ -19,17 +19,17 @@ package org.apache.spark.sql.execution.benchmark
 import java.io.File
 
 import scala.collection.JavaConverters._
-import scala.util.{Random, Try}
+import scala.util.Random
 
 import org.apache.spark.SparkConf
 import org.apache.spark.benchmark.Benchmark
 import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.execution.datasources.parquet.{SpecificParquetRecordReaderBase, VectorizedParquetRecordReader}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnVector
-import org.apache.spark.util.Utils
 
 
 /**
@@ -37,7 +37,7 @@ import org.apache.spark.util.Utils
  * To run this:
  *  spark-submit --class
  */
-object DataSourceReadBenchmark {
+object DataSourceReadBenchmark extends SQLHelper {
   val conf = new SparkConf()
     .setAppName("DataSourceReadBenchmark")
     // Since `spark.master` always exists, overrides this value
@@ -54,27 +54,10 @@ object DataSourceReadBenchmark {
   spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true")
   spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true")
 
-  def withTempPath(f: File => Unit): Unit = {
-    val path = Utils.createTempDir()
-    path.delete()
-    try f(path) finally Utils.deleteRecursively(path)
-  }
-
   def withTempTable(tableNames: String*)(f: => Unit): Unit = {
     try f finally tableNames.foreach(spark.catalog.dropTempView)
   }
 
-  def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
-    val (keys, values) = pairs.unzip
-    val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption)
-    (keys, values).zipped.foreach(spark.conf.set)
-    try f finally {
-      keys.zip(currentValues).foreach {
-        case (key, Some(value)) => spark.conf.set(key, value)
-        case (key, None) => spark.conf.unset(key)
-      }
-    }
-  }
   private def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = {
     val testDf = if (partition.isDefined) {
       df.write.partitionBy(partition.get)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala
index 3b7f10783b64..7cdf653e3869 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala
@@ -19,16 +19,16 @@ package org.apache.spark.sql.execution.benchmark
 
 import java.io.File
 
-import scala.util.{Random, Try}
+import scala.util.Random
 
 import org.apache.spark.SparkConf
 import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
 import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.functions.monotonically_increasing_id
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
 import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, TimestampType}
-import org.apache.spark.util.Utils
 
 /**
  * Benchmark to measure read performance with Filter pushdown.
@@ -40,7 +40,7 @@ import org.apache.spark.util.Utils
  * Results will be written to "benchmarks/FilterPushdownBenchmark-results.txt".
 * }}}
 */
-object FilterPushdownBenchmark extends BenchmarkBase {
+object FilterPushdownBenchmark extends BenchmarkBase with SQLHelper {
 
   private val conf = new SparkConf()
     .setAppName(this.getClass.getSimpleName)
@@ -60,28 +60,10 @@ object FilterPushdownBenchmark extends BenchmarkBase {
 
   private val spark = SparkSession.builder().config(conf).getOrCreate()
 
-  def withTempPath(f: File => Unit): Unit = {
-    val path = Utils.createTempDir()
-    path.delete()
-    try f(path) finally Utils.deleteRecursively(path)
-  }
-
   def withTempTable(tableNames: String*)(f: => Unit): Unit = {
     try f finally tableNames.foreach(spark.catalog.dropTempView)
   }
 
-  def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
-    val (keys, values) = pairs.unzip
-    val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption)
-    (keys, values).zipped.foreach(spark.conf.set)
-    try f finally {
-      keys.zip(currentValues).foreach {
-        case (key, Some(value)) => spark.conf.set(key, value)
-        case (key, None) => spark.conf.unset(key)
-      }
-    }
-  }
-
   private def prepareTable(
       dir: File, numRows: Int, width: Int, useStringForValue: Boolean): Unit = {
     import spark.implicits._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala
index 6d319eb723d9..5d1a874999c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala
@@ -16,21 +16,19 @@
  */
 package org.apache.spark.sql.execution.datasources.csv
 
-import java.io.File
-
 import org.apache.spark.SparkConf
 import org.apache.spark.benchmark.Benchmark
 import org.apache.spark.sql.{Column, Row, SparkSession}
+import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
 
 /**
  * Benchmark to measure CSV read/write performance.
 * To run this:
 *  spark-submit --class --jars
 */
-object CSVBenchmarks {
+object CSVBenchmarks extends SQLHelper {
   val conf = new SparkConf()
 
   val spark = SparkSession.builder
@@ -40,12 +38,6 @@ object CSVBenchmarks {
     .getOrCreate()
   import spark.implicits._
 
-  def withTempPath(f: File => Unit): Unit = {
-    val path = Utils.createTempDir()
-    path.delete()
-    try f(path) finally Utils.deleteRecursively(path)
-  }
-
   def quotedValuesBenchmark(rowsNum: Int, numIters: Int): Unit = {
     val benchmark = new Benchmark(s"Parsing quoted values", rowsNum)
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala
index e40cb9b50148..368318ab38cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala
@@ -21,16 +21,16 @@ import java.io.File
 import org.apache.spark.SparkConf
 import org.apache.spark.benchmark.Benchmark
 import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
 
 /**
 * The benchmarks aims to measure performance of JSON parsing when encoding is set and isn't.
 * To run this:
 *  spark-submit --class --jars
 */
-object JSONBenchmarks {
+object JSONBenchmarks extends SQLHelper {
   val conf = new SparkConf()
 
   val spark = SparkSession.builder
@@ -40,13 +40,6 @@ object JSONBenchmarks {
     .getOrCreate()
   import spark.implicits._
 
-  def withTempPath(f: File => Unit): Unit = {
-    val path = Utils.createTempDir()
-    path.delete()
-    try f(path) finally Utils.deleteRecursively(path)
-  }
-
-
  def schemaInferring(rowsNum: Int): Unit = {
    val benchmark = new Benchmark("JSON schema inferring", rowsNum)
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala
index fe59cb25d500..cbac1c13cdd3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala
@@ -25,12 +25,12 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs._
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.catalyst.util.quietly
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.util.Utils
 
-abstract class CheckpointFileManagerTests extends SparkFunSuite {
+abstract class CheckpointFileManagerTests extends SparkFunSuite with SQLHelper {
 
   def createManager(path: Path): CheckpointFileManager
 
@@ -88,12 +88,6 @@ abstract class CheckpointFileManagerTests extends SparkFunSuite {
       fm.delete(path) // should not throw exception
     }
   }
-
-  protected def withTempPath(f: File => Unit): Unit = {
-    val path = Utils.createTempDir()
-    path.delete()
-    try f(path) finally Utils.deleteRecursively(path)
-  }
 }
 
 class CheckpointFileManagerSuite extends SparkFunSuite with SharedSparkSession {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 2fb8f70a2079..6b03d1e5b766 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -40,7 +40,6 @@ import org.apache.spark.sql.catalyst.plans.PlanTestBase
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.FilterExec
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.UninterruptibleThread
 import org.apache.spark.util.Utils
 
@@ -167,18 +166,6 @@ private[sql] trait SQLTestUtilsBase
     super.withSQLConf(pairs: _*)(f)
   }
 
-  /**
-   * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If
-   * a file/directory is created there by `f`, it will be delete after `f` returns.
-   *
-   * @todo Probably this method should be moved to a more general place
-   */
-  protected def withTempPath(f: File => Unit): Unit = {
-    val path = Utils.createTempDir()
-    path.delete()
-    try f(path) finally Utils.deleteRecursively(path)
-  }
-
   /**
    * Copy file in jar's resource to a temp file, then pass it to `f`.
    * This function is used to make `f` can use the path of temp file(e.g. file:/), instead of
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala
index 0eab7d1ea8e8..49de007df382 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala
@@ -19,14 +19,15 @@ package org.apache.spark.sql.hive.orc
 
 import java.io.File
 
-import scala.util.{Random, Try}
+import scala.util.Random
 
 import org.apache.spark.SparkConf
 import org.apache.spark.benchmark.Benchmark
 import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
+
 
 /**
  * Benchmark to measure ORC read performance.
@@ -34,7 +35,7 @@ import org.apache.spark.util.Utils
 * This is in `sql/hive` module in order to compare `sql/core` and `sql/hive` ORC data sources.
 */
 // scalastyle:off line.size.limit
-object OrcReadBenchmark {
+object OrcReadBenchmark extends SQLHelper {
   val conf = new SparkConf()
   conf.set("orc.compression", "snappy")
 
@@ -47,28 +48,10 @@ object OrcReadBenchmark {
   // Set default configs. Individual cases will change them if necessary.
   spark.conf.set(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key, "true")
 
-  def withTempPath(f: File => Unit): Unit = {
-    val path = Utils.createTempDir()
-    path.delete()
-    try f(path) finally Utils.deleteRecursively(path)
-  }
-
   def withTempTable(tableNames: String*)(f: => Unit): Unit = {
     try f finally tableNames.foreach(spark.catalog.dropTempView)
  }
 
-  def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
-    val (keys, values) = pairs.unzip
-    val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption)
-    (keys, values).zipped.foreach(spark.conf.set)
-    try f finally {
-      keys.zip(currentValues).foreach {
-        case (key, Some(value)) => spark.conf.set(key, value)
-        case (key, None) => spark.conf.unset(key)
-      }
-    }
-  }
-
   private val NATIVE_ORC_FORMAT = classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName
   private val HIVE_ORC_FORMAT = classOf[org.apache.spark.sql.hive.orc.OrcFileFormat].getCanonicalName
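For reference, a minimal usage sketch of the consolidated trait (not part of the patch above; `ExampleBenchmark` is hypothetical and assumes the catalyst test classes, where `SQLHelper` lives, are on the classpath). It only uses the two helpers the patch moves into `SQLHelper`, `withSQLConf` and `withTempPath`:

// Hypothetical illustration only; ExampleBenchmark is not part of this change.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.internal.SQLConf

object ExampleBenchmark extends SQLHelper {
  private val spark = SparkSession.builder().master("local[1]").getOrCreate()

  def run(): Unit = {
    // Overrides the conf only for the duration of the block; the previous value is restored after.
    withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
      // withTempPath hands the block a path that does not exist yet and deletes it on exit.
      withTempPath { dir =>
        spark.range(10).write.parquet(dir.getCanonicalPath)
        spark.read.parquet(dir.getCanonicalPath).count()
      }
    }
  }
}

Because the helpers now come from a single mixin, a suite or benchmark object no longer needs its own copy of these methods, which is exactly the duplication the diff removes from each of the touched files.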