From b9d41cd79f05f6c420d070ad07cdfa8f853fd461 Mon Sep 17 00:00:00 2001 From: Nathan Kronenfeld Date: Sat, 14 Oct 2017 23:04:16 -0400 Subject: [PATCH 01/15] Separate out the portion of SharedSQLContext that requires a FunSuite from the part that works with just any old test suite. --- .../spark/sql/catalyst/plans/PlanTest.scala | 151 +-------- .../sql/catalyst/plans/PlanTestBase.scala | 177 +++++++++++ .../apache/spark/sql/test/SQLTestUtils.scala | 269 +--------------- .../spark/sql/test/SQLTestUtilsBase.scala | 289 ++++++++++++++++++ .../spark/sql/test/ShareSparkSession.scala | 119 ++++++++ .../spark/sql/test/SharedSQLContext.scala | 80 +---- 6 files changed, 595 insertions(+), 490 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTestBase.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtilsBase.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/ShareSparkSession.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 10bdfafd6f933..2cb0671c57bee 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -18,158 +18,9 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.internal.SQLConf /** * Provides helper methods for comparing plans. */ -trait PlanTest extends SparkFunSuite with PredicateHelper { - - // TODO(gatorsmile): remove this from PlanTest and all the analyzer rules - protected def conf = SQLConf.get - - /** - * Since attribute references are given globally unique ids during analysis, - * we must normalize them to check if two different queries are identical. - */ - protected def normalizeExprIds(plan: LogicalPlan) = { - plan transformAllExpressions { - case s: ScalarSubquery => - s.copy(exprId = ExprId(0)) - case e: Exists => - e.copy(exprId = ExprId(0)) - case l: ListQuery => - l.copy(exprId = ExprId(0)) - case a: AttributeReference => - AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) - case a: Alias => - Alias(a.child, a.name)(exprId = ExprId(0)) - case ae: AggregateExpression => - ae.copy(resultId = ExprId(0)) - } - } - - /** - * Normalizes plans: - * - Filter the filter conditions that appear in a plan. For instance, - * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2) - * etc., will all now be equivalent. - * - Sample the seed will replaced by 0L. - * - Join conditions will be resorted by hashCode. 
- */ - protected def normalizePlan(plan: LogicalPlan): LogicalPlan = { - plan transform { - case Filter(condition: Expression, child: LogicalPlan) => - Filter(splitConjunctivePredicates(condition).map(rewriteEqual).sortBy(_.hashCode()) - .reduce(And), child) - case sample: Sample => - sample.copy(seed = 0L) - case Join(left, right, joinType, condition) if condition.isDefined => - val newCondition = - splitConjunctivePredicates(condition.get).map(rewriteEqual).sortBy(_.hashCode()) - .reduce(And) - Join(left, right, joinType, Some(newCondition)) - } - } - - /** - * Rewrite [[EqualTo]] and [[EqualNullSafe]] operator to keep order. The following cases will be - * equivalent: - * 1. (a = b), (b = a); - * 2. (a <=> b), (b <=> a). - */ - private def rewriteEqual(condition: Expression): Expression = condition match { - case eq @ EqualTo(l: Expression, r: Expression) => - Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo) - case eq @ EqualNullSafe(l: Expression, r: Expression) => - Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe) - case _ => condition // Don't reorder. - } - - /** Fails the test if the two plans do not match */ - protected def comparePlans( - plan1: LogicalPlan, - plan2: LogicalPlan, - checkAnalysis: Boolean = true): Unit = { - if (checkAnalysis) { - // Make sure both plan pass checkAnalysis. - SimpleAnalyzer.checkAnalysis(plan1) - SimpleAnalyzer.checkAnalysis(plan2) - } - - val normalized1 = normalizePlan(normalizeExprIds(plan1)) - val normalized2 = normalizePlan(normalizeExprIds(plan2)) - if (normalized1 != normalized2) { - fail( - s""" - |== FAIL: Plans do not match === - |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} - """.stripMargin) - } - } - - /** Fails the test if the two expressions do not match */ - protected def compareExpressions(e1: Expression, e2: Expression): Unit = { - comparePlans(Filter(e1, OneRowRelation()), Filter(e2, OneRowRelation()), checkAnalysis = false) - } - - /** Fails the test if the join order in the two plans do not match */ - protected def compareJoinOrder(plan1: LogicalPlan, plan2: LogicalPlan) { - val normalized1 = normalizePlan(normalizeExprIds(plan1)) - val normalized2 = normalizePlan(normalizeExprIds(plan2)) - if (!sameJoinPlan(normalized1, normalized2)) { - fail( - s""" - |== FAIL: Plans do not match === - |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} - """.stripMargin) - } - } - - /** Consider symmetry for joins when comparing plans. */ - private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { - (plan1, plan2) match { - case (j1: Join, j2: Join) => - (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) || - (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)) - case (p1: Project, p2: Project) => - p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child) - case _ => - plan1 == plan2 - } - } - - /** - * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL - * configurations. 
- */ - protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val conf = SQLConf.get - val (keys, values) = pairs.unzip - val currentValues = keys.map { key => - if (conf.contains(key)) { - Some(conf.getConfString(key)) - } else { - None - } - } - (keys, values).zipped.foreach { (k, v) => - if (SQLConf.staticConfKeys.contains(k)) { - throw new AnalysisException(s"Cannot modify the value of a static config: $k") - } - conf.setConfString(k, v) - } - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => conf.setConfString(key, value) - case (key, None) => conf.unsetConf(key) - } - } - } +trait PlanTest extends SparkFunSuite with PlanTestBase { } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTestBase.scala new file mode 100644 index 0000000000000..264abda6f690a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTestBase.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans + +import org.scalatest.Suite + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf + +/** + * Provides helper methods for comparing plans, but without the overhead of + * mandating a FunSuite. + */ +trait PlanTestBase extends PredicateHelper { self: Suite => + + // TODO(gatorsmile): remove this from PlanTest and all the analyzer rules + protected def conf = SQLConf.get + + /** + * Since attribute references are given globally unique ids during analysis, + * we must normalize them to check if two different queries are identical. + */ + protected def normalizeExprIds(plan: LogicalPlan) = { + plan transformAllExpressions { + case s: ScalarSubquery => + s.copy(exprId = ExprId(0)) + case e: Exists => + e.copy(exprId = ExprId(0)) + case l: ListQuery => + l.copy(exprId = ExprId(0)) + case a: AttributeReference => + AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) + case a: Alias => + Alias(a.child, a.name)(exprId = ExprId(0)) + case ae: AggregateExpression => + ae.copy(resultId = ExprId(0)) + } + } + + /** + * Normalizes plans: + * - Filter the filter conditions that appear in a plan. For instance, + * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2) + * etc., will all now be equivalent. 
+ * - Sample the seed will replaced by 0L. + * - Join conditions will be resorted by hashCode. + */ + protected def normalizePlan(plan: LogicalPlan): LogicalPlan = { + plan transform { + case Filter(condition: Expression, child: LogicalPlan) => + Filter(splitConjunctivePredicates(condition).map(rewriteEqual).sortBy(_.hashCode()) + .reduce(And), child) + case sample: Sample => + sample.copy(seed = 0L) + case Join(left, right, joinType, condition) if condition.isDefined => + val newCondition = + splitConjunctivePredicates(condition.get).map(rewriteEqual).sortBy(_.hashCode()) + .reduce(And) + Join(left, right, joinType, Some(newCondition)) + } + } + + /** + * Rewrite [[EqualTo]] and [[EqualNullSafe]] operator to keep order. The following cases will be + * equivalent: + * 1. (a = b), (b = a); + * 2. (a <=> b), (b <=> a). + */ + private def rewriteEqual(condition: Expression): Expression = condition match { + case eq @ EqualTo(l: Expression, r: Expression) => + Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo) + case eq @ EqualNullSafe(l: Expression, r: Expression) => + Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe) + case _ => condition // Don't reorder. + } + + /** Fails the test if the two plans do not match */ + protected def comparePlans( + plan1: LogicalPlan, + plan2: LogicalPlan, + checkAnalysis: Boolean = true): Unit = { + if (checkAnalysis) { + // Make sure both plan pass checkAnalysis. + SimpleAnalyzer.checkAnalysis(plan1) + SimpleAnalyzer.checkAnalysis(plan2) + } + + val normalized1 = normalizePlan(normalizeExprIds(plan1)) + val normalized2 = normalizePlan(normalizeExprIds(plan2)) + if (normalized1 != normalized2) { + fail( + s""" + |== FAIL: Plans do not match === + |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} + """.stripMargin) + } + } + + /** Fails the test if the two expressions do not match */ + protected def compareExpressions(e1: Expression, e2: Expression): Unit = { + comparePlans(Filter(e1, OneRowRelation()), Filter(e2, OneRowRelation()), checkAnalysis = false) + } + + /** Fails the test if the join order in the two plans do not match */ + protected def compareJoinOrder(plan1: LogicalPlan, plan2: LogicalPlan) { + val normalized1 = normalizePlan(normalizeExprIds(plan1)) + val normalized2 = normalizePlan(normalizeExprIds(plan2)) + if (!sameJoinPlan(normalized1, normalized2)) { + fail( + s""" + |== FAIL: Plans do not match === + |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} + """.stripMargin) + } + } + + /** Consider symmetry for joins when comparing plans. */ + private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { + (plan1, plan2) match { + case (j1: Join, j2: Join) => + (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) || + (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)) + case (p1: Project, p2: Project) => + p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child) + case _ => + plan1 == plan2 + } + } + + /** + * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL + * configurations. 
+ */ + protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val conf = SQLConf.get + val (keys, values) = pairs.unzip + val currentValues = keys.map { key => + if (conf.contains(key)) { + Some(conf.getConfString(key)) + } else { + None + } + } + (keys, values).zipped.foreach { (k, v) => + if (SQLConf.staticConfKeys.contains(k)) { + throw new AnalysisException(s"Cannot modify the value of a static config: $k") + } + conf.setConfString(k, v) + } + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => conf.setConfString(key, value) + case (key, None) => conf.unsetConf(key) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index a14a1441a4313..7650f5c112d5c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -17,33 +17,17 @@ package org.apache.spark.sql.test -import java.io.File -import java.net.URI -import java.nio.file.Files -import java.util.{Locale, UUID} - -import scala.concurrent.duration._ -import scala.language.implicitConversions import scala.util.control.NonFatal -import org.apache.hadoop.fs.Path -import org.scalatest.BeforeAndAfterAll -import org.scalatest.concurrent.Eventually - import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.analysis.NoSuchTableException -import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.FilterExec -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.{UninterruptibleThread, Utils} +import org.apache.spark.util.UninterruptibleThread /** - * Helper trait that should be extended by all SQL test suites. + * Helper trait that should be extended by all SQL test suites within the Spark + * code base. * * This allows subclasses to plugin a custom `SQLContext`. It comes with test data * prepared in advance as well as all implicit conversions used extensively by dataframes. @@ -52,39 +36,10 @@ import org.apache.spark.util.{UninterruptibleThread, Utils} * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. */ -private[sql] trait SQLTestUtils - extends SparkFunSuite with Eventually - with BeforeAndAfterAll - with SQLTestData - with PlanTest { self => - - protected def sparkContext = spark.sparkContext - +private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with PlanTest { // Whether to materialize all test data before the first test is run private var loadTestDataBeforeTests = false - // Shorthand for running a query using our SQLContext - protected lazy val sql = spark.sql _ - - /** - * A helper object for importing SQL implicits. - * - * Note that the alternative of importing `spark.implicits._` is not possible here. - * This is because we create the `SQLContext` immediately before the first test is run, - * but the implicits import is needed in the constructor. 
- */ - protected object testImplicits extends SQLImplicits { - protected override def _sqlContext: SQLContext = self.spark.sqlContext - } - - /** - * Materialize the test data immediately after the `SQLContext` is set up. - * This is necessary if the data is accessed by name but not through direct reference. - */ - protected def setupTestData(): Unit = { - loadTestDataBeforeTests = true - } - protected override def beforeAll(): Unit = { super.beforeAll() if (loadTestDataBeforeTests) { @@ -92,209 +47,12 @@ private[sql] trait SQLTestUtils } } - protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - SparkSession.setActiveSession(spark) - super.withSQLConf(pairs: _*)(f) - } - /** - * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If - * a file/directory is created there by `f`, it will be delete after `f` returns. - * - * @todo Probably this method should be moved to a more general place - */ - protected def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - - /** - * Copy file in jar's resource to a temp file, then pass it to `f`. - * This function is used to make `f` can use the path of temp file(e.g. file:/), instead of - * path of jar's resource which starts with 'jar:file:/' - */ - protected def withResourceTempPath(resourcePath: String)(f: File => Unit): Unit = { - val inputStream = - Thread.currentThread().getContextClassLoader.getResourceAsStream(resourcePath) - withTempDir { dir => - val tmpFile = new File(dir, "tmp") - Files.copy(inputStream, tmpFile.toPath) - f(tmpFile) - } - } - - /** - * Waits for all tasks on all executors to be finished. - */ - protected def waitForTasksToFinish(): Unit = { - eventually(timeout(10.seconds)) { - assert(spark.sparkContext.statusTracker - .getExecutorInfos.map(_.numRunningTasks()).sum == 0) - } - } - - /** - * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` - * returns. - * - * @todo Probably this method should be moved to a more general place - */ - protected def withTempDir(f: File => Unit): Unit = { - val dir = Utils.createTempDir().getCanonicalFile - try f(dir) finally { - // wait for all tasks to finish before deleting files - waitForTasksToFinish() - Utils.deleteRecursively(dir) - } - } - - /** - * Creates the specified number of temporary directories, which is then passed to `f` and will be - * deleted after `f` returns. - */ - protected def withTempPaths(numPaths: Int)(f: Seq[File] => Unit): Unit = { - val files = Array.fill[File](numPaths)(Utils.createTempDir().getCanonicalFile) - try f(files) finally { - // wait for all tasks to finish before deleting files - waitForTasksToFinish() - files.foreach(Utils.deleteRecursively) - } - } - - /** - * Drops functions after calling `f`. A function is represented by (functionName, isTemporary). - */ - protected def withUserDefinedFunction(functions: (String, Boolean)*)(f: => Unit): Unit = { - try { - f - } catch { - case cause: Throwable => throw cause - } finally { - // If the test failed part way, we don't want to mask the failure by failing to remove - // temp tables that never got created. 
- functions.foreach { case (functionName, isTemporary) => - val withTemporary = if (isTemporary) "TEMPORARY" else "" - spark.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName") - assert( - !spark.sessionState.catalog.functionExists(FunctionIdentifier(functionName)), - s"Function $functionName should have been dropped. But, it still exists.") - } - } - } - - /** - * Drops temporary table `tableName` after calling `f`. - */ - protected def withTempView(tableNames: String*)(f: => Unit): Unit = { - try f finally { - // If the test failed part way, we don't want to mask the failure by failing to remove - // temp tables that never got created. - try tableNames.foreach(spark.catalog.dropTempView) catch { - case _: NoSuchTableException => - } - } - } - - /** - * Drops table `tableName` after calling `f`. - */ - protected def withTable(tableNames: String*)(f: => Unit): Unit = { - try f finally { - tableNames.foreach { name => - spark.sql(s"DROP TABLE IF EXISTS $name") - } - } - } - - /** - * Drops view `viewName` after calling `f`. - */ - protected def withView(viewNames: String*)(f: => Unit): Unit = { - try f finally { - viewNames.foreach { name => - spark.sql(s"DROP VIEW IF EXISTS $name") - } - } - } - - /** - * Creates a temporary database and switches current database to it before executing `f`. This - * database is dropped after `f` returns. - * - * Note that this method doesn't switch current database before executing `f`. - */ - protected def withTempDatabase(f: String => Unit): Unit = { - val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" - - try { - spark.sql(s"CREATE DATABASE $dbName") - } catch { case cause: Throwable => - fail("Failed to create temporary database", cause) - } - - try f(dbName) finally { - if (spark.catalog.currentDatabase == dbName) { - spark.sql(s"USE $DEFAULT_DATABASE") - } - spark.sql(s"DROP DATABASE $dbName CASCADE") - } - } - - /** - * Drops database `dbName` after calling `f`. - */ - protected def withDatabase(dbNames: String*)(f: => Unit): Unit = { - try f finally { - dbNames.foreach { name => - spark.sql(s"DROP DATABASE IF EXISTS $name CASCADE") - } - spark.sql(s"USE $DEFAULT_DATABASE") - } - } - - /** - * Enables Locale `language` before executing `f`, then switches back to the default locale of JVM - * after `f` returns. - */ - protected def withLocale(language: String)(f: => Unit): Unit = { - val originalLocale = Locale.getDefault - try { - // Add Locale setting - Locale.setDefault(new Locale(language)) - f - } finally { - Locale.setDefault(originalLocale) - } - } - - /** - * Activates database `db` before executing `f`, then switches back to `default` database after - * `f` returns. - */ - protected def activateDatabase(db: String)(f: => Unit): Unit = { - spark.sessionState.catalog.setCurrentDatabase(db) - try f finally spark.sessionState.catalog.setCurrentDatabase("default") - } - - /** - * Strip Spark-side filtering in order to check if a datasource filters rows correctly. - */ - protected def stripSparkFilter(df: DataFrame): DataFrame = { - val schema = df.schema - val withoutFilters = df.queryExecution.sparkPlan.transform { - case FilterExec(_, child) => child - } - - spark.internalCreateDataFrame(withoutFilters.execute(), schema) - } - - /** - * Turn a logical plan into a `DataFrame`. This should be removed once we have an easier - * way to construct `DataFrame` directly out of local data without relying on implicits. + * Materialize the test data immediately after the `SQLContext` is set up. 
+ * This is necessary if the data is accessed by name but not through direct reference. */ - protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { - Dataset.ofRows(spark, plan) + protected def setupTestData(): Unit = { + loadTestDataBeforeTests = true } /** @@ -352,17 +110,6 @@ private[sql] trait SQLTestUtils test(name) { runOnThread() } } } - - /** - * This method is used to make the given path qualified, when a path - * does not contain a scheme, this path will not be changed after the default - * FileSystem is changed. - */ - def makeQualifiedPath(path: String): URI = { - val hadoopPath = new Path(path) - val fs = hadoopPath.getFileSystem(spark.sessionState.newHadoopConf()) - fs.makeQualified(hadoopPath).toUri - } } private[sql] object SQLTestUtils { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtilsBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtilsBase.scala new file mode 100644 index 0000000000000..63373283adf62 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtilsBase.scala @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import java.io.File +import java.net.URI +import java.nio.file.Files +import java.util.{Locale, UUID} + +import scala.concurrent.duration._ +import scala.language.implicitConversions + +import org.apache.hadoop.fs.Path +import org.scalatest.{BeforeAndAfterAll, Suite} +import org.scalatest.concurrent.Eventually + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE +import org.apache.spark.sql.catalyst.plans.PlanTestBase +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.FilterExec +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils + +/** + * Helper trait that can be extended by all external SQL test suites. + * + * This allows subclasses to plugin a custom `SQLContext`. + * To use implicit methods, import `testImplicits._` instead of through the `SQLContext`. + * + * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is + * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. + */ +private[sql] trait SQLTestUtilsBase + extends Eventually + with BeforeAndAfterAll + with SQLTestData + with PlanTestBase { self: Suite => + + protected def sparkContext = spark.sparkContext + + // Shorthand for running a query using our SQLContext + protected lazy val sql = spark.sql _ + + /** + * A helper object for importing SQL implicits. 
+ * + * Note that the alternative of importing `spark.implicits._` is not possible here. + * This is because we create the `SQLContext` immediately before the first test is run, + * but the implicits import is needed in the constructor. + */ + protected object testImplicits extends SQLImplicits { + protected override def _sqlContext: SQLContext = self.spark.sqlContext + } + + protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + SparkSession.setActiveSession(spark) + super.withSQLConf(pairs: _*)(f) + } + + /** + * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If + * a file/directory is created there by `f`, it will be delete after `f` returns. + * + * @todo Probably this method should be moved to a more general place + */ + protected def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + /** + * Copy file in jar's resource to a temp file, then pass it to `f`. + * This function is used to make `f` can use the path of temp file(e.g. file:/), instead of + * path of jar's resource which starts with 'jar:file:/' + */ + protected def withResourceTempPath(resourcePath: String)(f: File => Unit): Unit = { + val inputStream = + Thread.currentThread().getContextClassLoader.getResourceAsStream(resourcePath) + withTempDir { dir => + val tmpFile = new File(dir, "tmp") + Files.copy(inputStream, tmpFile.toPath) + f(tmpFile) + } + } + + /** + * Waits for all tasks on all executors to be finished. + */ + protected def waitForTasksToFinish(): Unit = { + eventually(timeout(10.seconds)) { + assert(spark.sparkContext.statusTracker + .getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } + } + + /** + * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` + * returns. + * + * @todo Probably this method should be moved to a more general place + */ + protected def withTempDir(f: File => Unit): Unit = { + val dir = Utils.createTempDir().getCanonicalFile + try f(dir) finally { + // wait for all tasks to finish before deleting files + waitForTasksToFinish() + Utils.deleteRecursively(dir) + } + } + + /** + * Creates the specified number of temporary directories, which is then passed to `f` and will be + * deleted after `f` returns. + */ + protected def withTempPaths(numPaths: Int)(f: Seq[File] => Unit): Unit = { + val files = Array.fill[File](numPaths)(Utils.createTempDir().getCanonicalFile) + try f(files) finally { + // wait for all tasks to finish before deleting files + waitForTasksToFinish() + files.foreach(Utils.deleteRecursively) + } + } + + /** + * Drops functions after calling `f`. A function is represented by (functionName, isTemporary). + */ + protected def withUserDefinedFunction(functions: (String, Boolean)*)(f: => Unit): Unit = { + try { + f + } catch { + case cause: Throwable => throw cause + } finally { + // If the test failed part way, we don't want to mask the failure by failing to remove + // temp tables that never got created. + functions.foreach { case (functionName, isTemporary) => + val withTemporary = if (isTemporary) "TEMPORARY" else "" + spark.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName") + assert( + !spark.sessionState.catalog.functionExists(FunctionIdentifier(functionName)), + s"Function $functionName should have been dropped. But, it still exists.") + } + } + } + + /** + * Drops temporary table `tableName` after calling `f`. 
+ */ + protected def withTempView(tableNames: String*)(f: => Unit): Unit = { + try f finally { + // If the test failed part way, we don't want to mask the failure by failing to remove + // temp tables that never got created. + try tableNames.foreach(spark.catalog.dropTempView) catch { + case _: NoSuchTableException => + } + } + } + + /** + * Drops table `tableName` after calling `f`. + */ + protected def withTable(tableNames: String*)(f: => Unit): Unit = { + try f finally { + tableNames.foreach { name => + spark.sql(s"DROP TABLE IF EXISTS $name") + } + } + } + + /** + * Drops view `viewName` after calling `f`. + */ + protected def withView(viewNames: String*)(f: => Unit): Unit = { + try f finally { + viewNames.foreach { name => + spark.sql(s"DROP VIEW IF EXISTS $name") + } + } + } + + /** + * Creates a temporary database and switches current database to it before executing `f`. This + * database is dropped after `f` returns. + * + * Note that this method doesn't switch current database before executing `f`. + */ + protected def withTempDatabase(f: String => Unit): Unit = { + val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" + + try { + spark.sql(s"CREATE DATABASE $dbName") + } catch { case cause: Throwable => + fail("Failed to create temporary database", cause) + } + + try f(dbName) finally { + if (spark.catalog.currentDatabase == dbName) { + spark.sql(s"USE $DEFAULT_DATABASE") + } + spark.sql(s"DROP DATABASE $dbName CASCADE") + } + } + + /** + * Drops database `dbName` after calling `f`. + */ + protected def withDatabase(dbNames: String*)(f: => Unit): Unit = { + try f finally { + dbNames.foreach { name => + spark.sql(s"DROP DATABASE IF EXISTS $name CASCADE") + } + spark.sql(s"USE $DEFAULT_DATABASE") + } + } + + /** + * Enables Locale `language` before executing `f`, then switches back to the default locale of JVM + * after `f` returns. + */ + protected def withLocale(language: String)(f: => Unit): Unit = { + val originalLocale = Locale.getDefault + try { + // Add Locale setting + Locale.setDefault(new Locale(language)) + f + } finally { + Locale.setDefault(originalLocale) + } + } + + /** + * Activates database `db` before executing `f`, then switches back to `default` database after + * `f` returns. + */ + protected def activateDatabase(db: String)(f: => Unit): Unit = { + spark.sessionState.catalog.setCurrentDatabase(db) + try f finally spark.sessionState.catalog.setCurrentDatabase("default") + } + + /** + * Strip Spark-side filtering in order to check if a datasource filters rows correctly. + */ + protected def stripSparkFilter(df: DataFrame): DataFrame = { + val schema = df.schema + val withoutFilters = df.queryExecution.sparkPlan.transform { + case FilterExec(_, child) => child + } + + spark.internalCreateDataFrame(withoutFilters.execute(), schema) + } + + /** + * Turn a logical plan into a `DataFrame`. This should be removed once we have an easier + * way to construct `DataFrame` directly out of local data without relying on implicits. + */ + protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { + Dataset.ofRows(spark, plan) + } + + + /** + * This method is used to make the given path qualified, when a path + * does not contain a scheme, this path will not be changed after the default + * FileSystem is changed. 
+ */ + def makeQualifiedPath(path: String): URI = { + val hadoopPath = new Path(path) + val fs = hadoopPath.getFileSystem(spark.sessionState.newHadoopConf()) + fs.makeQualified(hadoopPath).toUri + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/ShareSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/ShareSparkSession.scala new file mode 100644 index 0000000000000..e0568a3c5c99f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/ShareSparkSession.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import scala.concurrent.duration._ + +import org.scalatest.{BeforeAndAfterEach, Suite} +import org.scalatest.concurrent.Eventually + +import org.apache.spark.{DebugFilesystem, SparkConf} +import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.internal.SQLConf + +/** + * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. + */ +trait SharedSparkSession + extends SQLTestUtilsBase + with BeforeAndAfterEach + with Eventually { self: Suite => + + protected def sparkConf = { + new SparkConf() + .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) + .set("spark.unsafe.exceptionOnMemoryLeak", "true") + .set(SQLConf.CODEGEN_FALLBACK.key, "false") + } + + /** + * The [[TestSparkSession]] to use for all tests in this suite. + * + * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local + * mode with the default test configurations. + */ + private var _spark: TestSparkSession = null + + /** + * The [[TestSparkSession]] to use for all tests in this suite. + */ + protected implicit def spark: SparkSession = _spark + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + */ + protected implicit def sqlContext: SQLContext = _spark.sqlContext + + protected def createSparkSession: TestSparkSession = { + new TestSparkSession(sparkConf) + } + + /** + * Initialize the [[TestSparkSession]]. Generally, this is just called from + * beforeAll; however, in test using styles other than FunSuite, there is + * often code that relies on the session between test group constructs and + * the actual tests, which may need this session. It is purely a semantic + * difference, but semantically, it makes more sense to call + * 'initializeSession' between a 'describe' and an 'it' call than it does to + * call 'beforeAll'. + */ + protected def initializeSession(): Unit = { + SparkSession.sqlListener.set(null) + if (_spark == null) { + _spark = createSparkSession + } + } + + /** + * Make sure the [[TestSparkSession]] is initialized before any tests are run. 
+ */ + protected override def beforeAll(): Unit = { + initializeSession() + + // Ensure we have initialized the context before calling parent code + super.beforeAll() + } + + /** + * Stop the underlying [[org.apache.spark.SparkContext]], if any. + */ + protected override def afterAll(): Unit = { + super.afterAll() + if (_spark != null) { + _spark.sessionState.catalog.reset() + _spark.stop() + _spark = null + } + } + + protected override def beforeEach(): Unit = { + super.beforeEach() + DebugFilesystem.clearOpenStreams() + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Clear all persistent datasets after each test + spark.sharedState.cacheManager.clearCache() + // files can be closed from other threads, so wait a bit + // normally this doesn't take more than 1s + eventually(timeout(10.seconds)) { + DebugFilesystem.assertNoOpenStreams() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index cd8d0708d8a32..4eab6b9c43db6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,86 +17,8 @@ package org.apache.spark.sql.test -import scala.concurrent.duration._ - -import org.scalatest.BeforeAndAfterEach -import org.scalatest.concurrent.Eventually - -import org.apache.spark.{DebugFilesystem, SparkConf} -import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.sql.internal.SQLConf - -/** - * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. - */ -trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventually { - - protected def sparkConf = { - new SparkConf() - .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) - .set("spark.unsafe.exceptionOnMemoryLeak", "true") - .set(SQLConf.CODEGEN_FALLBACK.key, "false") - } - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - * - * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local - * mode with the default test configurations. - */ - private var _spark: TestSparkSession = null - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - */ - protected implicit def spark: SparkSession = _spark - - /** - * The [[TestSQLContext]] to use for all tests in this suite. - */ - protected implicit def sqlContext: SQLContext = _spark.sqlContext - - protected def createSparkSession: TestSparkSession = { - new TestSparkSession(sparkConf) - } - - /** - * Initialize the [[TestSparkSession]]. - */ +trait SharedSQLContext extends SQLTestUtils with SharedSparkSession { protected override def beforeAll(): Unit = { - SparkSession.sqlListener.set(null) - if (_spark == null) { - _spark = createSparkSession - } - // Ensure we have initialized the context before calling parent code super.beforeAll() } - - /** - * Stop the underlying [[org.apache.spark.SparkContext]], if any. 
- */ - protected override def afterAll(): Unit = { - super.afterAll() - if (_spark != null) { - _spark.sessionState.catalog.reset() - _spark.stop() - _spark = null - } - } - - protected override def beforeEach(): Unit = { - super.beforeEach() - DebugFilesystem.clearOpenStreams() - } - - protected override def afterEach(): Unit = { - super.afterEach() - // Clear all persistent datasets after each test - spark.sharedState.cacheManager.clearCache() - // files can be closed from other threads, so wait a bit - // normally this doesn't take more than 1s - eventually(timeout(10.seconds)) { - DebugFilesystem.assertNoOpenStreams() - } - } } From 0d4bd97247a2d083c7de55663703b38a34298c9c Mon Sep 17 00:00:00 2001 From: Nathan Kronenfeld Date: Sun, 15 Oct 2017 11:57:09 -0400 Subject: [PATCH 02/15] Fix typo in trait name --- .../test/{ShareSparkSession.scala => SharedSparkSession.scala} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename sql/core/src/test/scala/org/apache/spark/sql/test/{ShareSparkSession.scala => SharedSparkSession.scala} (100%) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/ShareSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala similarity index 100% rename from sql/core/src/test/scala/org/apache/spark/sql/test/ShareSparkSession.scala rename to sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala From 83c44f1c24619e906af48180d0aace38587aa88d Mon Sep 17 00:00:00 2001 From: Nathan Kronenfeld Date: Sun, 15 Oct 2017 11:57:42 -0400 Subject: [PATCH 03/15] Add simple tests for each non-FunSuite test style --- .../spark/sql/test/GenericFlatSpecSuite.scala | 45 ++++++++++++++++ .../spark/sql/test/GenericFunSpecSuite.scala | 47 +++++++++++++++++ .../spark/sql/test/GenericWordSpecSuite.scala | 51 +++++++++++++++++++ 3 files changed, 143 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala new file mode 100644 index 0000000000000..6179585a0d39a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.test + +import org.scalatest.FlatSpec + +/** + * The purpose of this suite is to make sure that generic FlatSpec-based scala + * tests work with a shared spark session + */ +class GenericFlatSpecSuite extends FlatSpec with SharedSparkSession { + import testImplicits._ + initializeSession() + val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + + "A Simple Dataset" should "have the specified number of elements" in { + assert(8 === ds.count) + } + it should "have the specified number of unique elements" in { + assert(8 === ds.distinct.count) + } + it should "have the specified number of elements in each column" in { + assert(8 === ds.select("_1").count) + assert(8 === ds.select("_2").count) + } + it should "have the correct number of distinct elements in each column" in { + assert(8 === ds.select("_1").distinct.count) + assert(4 === ds.select("_2").distinct.count) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala new file mode 100644 index 0000000000000..15139ee8b3047 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.scalatest.FunSpec + +/** + * The purpose of this suite is to make sure that generic FunSpec-based scala + * tests work with a shared spark session + */ +class GenericFunSpecSuite extends FunSpec with SharedSparkSession { + import testImplicits._ + initializeSession() + val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + + describe("Simple Dataset") { + it("should have the specified number of elements") { + assert(8 === ds.count) + } + it("should have the specified number of unique elements") { + assert(8 === ds.distinct.count) + } + it("should have the specified number of elements in each column") { + assert(8 === ds.select("_1").count) + assert(8 === ds.select("_2").count) + } + it("should have the correct number of distinct elements in each column") { + assert(8 === ds.select("_1").distinct.count) + assert(4 === ds.select("_2").distinct.count) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala new file mode 100644 index 0000000000000..b6548bf95fec8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.scalatest.WordSpec + +/** + * The purpose of this suite is to make sure that generic WordSpec-based scala + * tests work with a shared spark session + */ +class GenericWordSpecSuite extends WordSpec with SharedSparkSession { + import testImplicits._ + initializeSession() + val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + + "A Simple Dataset" when { + "looked at as complete rows" should { + "have the specified number of elements" in { + assert(8 === ds.count) + } + "have the specified number of unique elements" in { + assert(8 === ds.distinct.count) + } + } + "refined to specific columns" should { + "have the specified number of elements in each column" in { + assert(8 === ds.select("_1").count) + assert(8 === ds.select("_2").count) + } + "have the correct number of distinct elements in each column" in { + assert(8 === ds.select("_1").distinct.count) + assert(4 === ds.select("_2").distinct.count) + } + } + } +} From e460612ec6f36e62d8d21d88c2344378ecba581a Mon Sep 17 00:00:00 2001 From: Nathan Kronenfeld Date: Sun, 15 Oct 2017 12:20:44 -0400 Subject: [PATCH 04/15] Document testing possibilities --- docs/testing.md | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 docs/testing.md diff --git a/docs/testing.md b/docs/testing.md new file mode 100644 index 0000000000000..e3c9ba1f68cac --- /dev/null +++ b/docs/testing.md @@ -0,0 +1,42 @@ +--- +layout: global +title: Testing with Spark +description: How to write unit tests for Spark and for Spark-based applications +--- + +* This will become a table of contents (this text will be scraped). +{:toc} + +# Overview + +*************************************************************************************************** + +# Primary testing classes + +1. org.apache.spark.SharedSparkContext + Includes a SparkContext for use by the test suite +1. org.apache.spark.SparkFunSuite + Standardizes reporting of test and test suite names +1. org.apache.spark.sql.test.SharedSQLContext + Includes a SQLContext for use by the test suite + Incorporates SparkFunSuite +1. org.apache.spark.sql.test.SharedSparkSession + Includes a SparkSession for use by the test suite + +*************************************************************************************************** + +# Testing Spark + +All internal tests of Spark code should derive, directly or indirectly, from SparkFunSuite, so as to standardize the reporting of suite and test name logging. + +Tests that require a SparkContext should derive from SharedSparkContext also. + +SharedSQLContext already derives from SparkFunSuite, so may be extended directly by tests requiring a SQLContext. 
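
For example, a minimal internal suite might look like the following sketch (the suite name and test data are illustrative, not part of the Spark code base):

    import org.apache.spark.sql.test.SharedSQLContext

    class MyInternalSQLSuite extends SharedSQLContext {
      import testImplicits._

      test("a locally created Dataset can be counted") {
        assert(3 === Seq(1, 2, 3).toDS.count)
      }
    }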
+ +SharedSparkSession does not derive from SparkFunSuite, so should not be extended directly for any internal spark tests. + +*************************************************************************************************** + +# Testing code that uses Spark + +External applications that use Spark may extend SharedSparkContext or SharedSparkSession. It should be noted that, in SharedSparkSession, the SparkSession isn't initialized by default until beforeAll - which has not been called before the code that is inside outer test grouping blocks (like 'describe'), but outside actual test cases. To use a SparkSession in these areas of code, one must call initializeSession first. \ No newline at end of file From 0ee2aadf29b681b23bed356b14038525574204a5 Mon Sep 17 00:00:00 2001 From: Nathan Kronenfeld Date: Wed, 18 Oct 2017 19:46:44 -0400 Subject: [PATCH 05/15] Better documentation of testing procedures --- docs/index.md | 3 +- docs/testing.md | 104 +++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 91 insertions(+), 16 deletions(-) diff --git a/docs/index.md b/docs/index.md index b867c972b4b48..9360676619c86 100644 --- a/docs/index.md +++ b/docs/index.md @@ -92,7 +92,8 @@ options for deployment: * [Structured Streaming](structured-streaming-programming-guide.html): processing structured data streams with relation queries (using Datasets and DataFrames, newer API than DStreams) * [Spark Streaming](streaming-programming-guide.html): processing data streams using DStreams (old API) * [MLlib](ml-guide.html): applying machine learning algorithms -* [GraphX](graphx-programming-guide.html): processing graphs +* [GraphX](graphx-programming-guide.html): processing graphs +* [Unit Testing](testing.html): unit testing spark **API Docs:** diff --git a/docs/testing.md b/docs/testing.md index e3c9ba1f68cac..3ce34be40b0da 100644 --- a/docs/testing.md +++ b/docs/testing.md @@ -13,30 +13,104 @@ description: How to write unit tests for Spark and for Spark-based applications # Primary testing classes -1. org.apache.spark.SharedSparkContext - Includes a SparkContext for use by the test suite -1. org.apache.spark.SparkFunSuite +1. `org.apache.spark.SharedSparkContext` + Includes a `SparkContext` for use by the test suite +1. `org.apache.spark.SparkFunSuite` Standardizes reporting of test and test suite names -1. org.apache.spark.sql.test.SharedSQLContext - Includes a SQLContext for use by the test suite - Incorporates SparkFunSuite -1. org.apache.spark.sql.test.SharedSparkSession - Includes a SparkSession for use by the test suite +1. `org.apache.spark.sql.test.SharedSQLContext` + Includes a `SQLContext` for use by the test suite + Incorporates `SparkFunSuite` +1. `org.apache.spark.sql.test.SharedSparkSession` + Includes a `SparkSession` for use by the test suite *************************************************************************************************** -# Testing Spark +# Unit testing Spark -All internal tests of Spark code should derive, directly or indirectly, from SparkFunSuite, so as to standardize the reporting of suite and test name logging. +All internal tests of Spark code should derive, directly or indirectly, from `SparkFunSuite`, so as to standardize the reporting of suite and test name logging. -Tests that require a SparkContext should derive from SharedSparkContext also. +Tests that require a `SparkContext` should derive from `SharedSparkContext` also. -SharedSQLContext already derives from SparkFunSuite, so may be extended directly by tests requiring a SQLContext. 
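
For example, a minimal internal suite might look like the following sketch (the suite name and test data are illustrative):

{% highlight scala %}
class MyInternalSQLSuite extends SharedSQLContext {
  import testImplicits._

  test("a locally created Dataset can be counted") {
    assert(3 === Seq(1, 2, 3).toDS.count)
  }
}
{% endhighlight %}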
+`SharedSQLContext` already derives from `SparkFunSuite`, so may be extended directly by tests requiring a `SQLContext`. -SharedSparkSession does not derive from SparkFunSuite, so should not be extended directly for any internal spark tests. +`SharedSparkSession` does not derive from `SparkFunSuite`, so should not be extended directly for any internal Spark tests. *************************************************************************************************** -# Testing code that uses Spark +# Unit testing code that uses Spark using ScalaTest -External applications that use Spark may extend SharedSparkContext or SharedSparkSession. It should be noted that, in SharedSparkSession, the SparkSession isn't initialized by default until beforeAll - which has not been called before the code that is inside outer test grouping blocks (like 'describe'), but outside actual test cases. To use a SparkSession in these areas of code, one must call initializeSession first. \ No newline at end of file +External applications that use Spark may extend `SharedSparkContext` or `SharedSparkSession`. These classes support various testing styles: + +## FunSuite style +(% highlight scala %} +class MySparkTest extends FunSuite with SharedSparkContext { + test("A parallelized RDD should be able to count its elements") { + assert(4 === sc.parallelize(Seq(1, 2, 3, 4)).count) + } +} +{% endhighlight %} + +## FunSpec style +(% highlight scala %} +class MySparkTest extends FunSpec with SharedSparkContext { + describe("A parallelized RDD") { + it("should be able to count its elements") { + assert(4 === sc.parallelize(Seq(1, 2, 3, 4)).count) + } + } +} +{% endhighlight %} + +## FlatSpec style +(% highlight scala %} +class MySparkTest extends FlatSpec with SharedSparkContext { + "A parallelized RDD" should "be able to count its elements" in { + assert(4 === sc.parallelize(Seq(1, 2, 3, 4)).count) + } +} +{% endhighlight %} + +## WordSpec style +(% highlight scala %} +class MySparkTest extends WordSpec with SharedSparkContext { + "A parallelized RDD" when { + "created" should { + "be able to count its elements" in { + assert(4 === sc.parallelize(Seq(1, 2, 3, 4)).count) + } + } + } +} +{% endhighlight %} + +# Context and Session initialization + +It should be noted that, in `SharedSparkContext`, the `SparkContext` (`sc`) isn't initialized until `beforeAll` is called. When using several testing styles, such as `FunSpec`, it is not uncommon to initialize shared resources inside a describe block (or equivalent), but outside an `it` block - i.e., in the registration phase. `beforeAll`, however, isn't called until after the registration phase, and just before the test phase. Therefore, an `initializeContext` call is exposed so that users can make sure the context is initialized in these blocks. 
For example:

{% highlight scala %}
class MySparkTest extends FunSpec with SharedSparkContext {
  describe("A parallelized RDD") {
    initializeContext()
    val rdd = sc.parallelize(Seq(1, 2, 3, 4))
    it("should be able to count its elements") {
      assert(4 === rdd.count)
    }
  }
}
{% endhighlight %}

Similarly, in `SharedSparkSession`, there is an `initializeSession` call for the same purpose:

{% highlight scala %}
class MySparkTest extends FunSpec with SharedSparkSession {
  describe("A simple Dataset") {
    initializeSession()
    import testImplicits._

    val dataset = Seq(1, 2, 3, 4).toDS
    it("should be able to count its elements") {
      assert(4 === dataset.count)
    }
  }
}
{% endhighlight %}

From 802a958b640067b99fda0b2c8587dea5b8000495 Mon Sep 17 00:00:00 2001
From: Nathan Kronenfeld
Date: Wed, 18 Oct 2017 19:46:58 -0400
Subject: [PATCH 06/15] Same initialization issue in SharedSparkContext as is in SharedSparkSession

---
 .../org/apache/spark/SharedSparkContext.scala | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
index 6aedcb1271ff6..8f20427cf9d67 100644
--- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
+++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
@@ -29,10 +29,25 @@ trait SharedSparkContext extends BeforeAndAfterAll with BeforeAndAfterEach { sel

 var conf = new SparkConf(false)

+ /**
+ * Initialize the [[SparkContext]]. Generally, this is just called from
+ * beforeAll; however, in tests using styles other than FunSuite, there is
+ * often code that runs between test group constructs and the actual tests,
+ * and that code may need this context. It is purely a semantic difference,
+ * but it makes more sense to call 'initializeContext' between a 'describe'
+ * and an 'it' call than it does to call 'beforeAll'.
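+ * For example (a sketch of a FunSpec-style suite; the names are illustrative):
+ * {{{
+ * class MyRDDSpec extends FunSpec with SharedSparkContext {
+ *   describe("A parallelized RDD") {
+ *     initializeContext()
+ *     val rdd = sc.parallelize(Seq(1, 2, 3))
+ *     it("should be able to count its elements") {
+ *       assert(3 === rdd.count)
+ *     }
+ *   }
+ * }
+ * }}}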
+ */ + protected def initializeContext(): Unit = { + if (null == _sc) { + _sc = new SparkContext( + "local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) + } + } + override def beforeAll() { super.beforeAll() - _sc = new SparkContext( - "local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) + initializeContext() } override def afterAll() { From 4218b86d5a8ff2321232ff38ed3e1b217ff7db2a Mon Sep 17 00:00:00 2001 From: Nathan Kronenfeld Date: Sun, 22 Oct 2017 23:49:39 -0400 Subject: [PATCH 07/15] Remove documentation of testing --- docs/index.md | 3 +- docs/testing.md | 116 ------------------------------------------------ 2 files changed, 1 insertion(+), 118 deletions(-) delete mode 100644 docs/testing.md diff --git a/docs/index.md b/docs/index.md index 9360676619c86..b867c972b4b48 100644 --- a/docs/index.md +++ b/docs/index.md @@ -92,8 +92,7 @@ options for deployment: * [Structured Streaming](structured-streaming-programming-guide.html): processing structured data streams with relation queries (using Datasets and DataFrames, newer API than DStreams) * [Spark Streaming](streaming-programming-guide.html): processing data streams using DStreams (old API) * [MLlib](ml-guide.html): applying machine learning algorithms -* [GraphX](graphx-programming-guide.html): processing graphs -* [Unit Testing](testing.html): unit testing spark +* [GraphX](graphx-programming-guide.html): processing graphs **API Docs:** diff --git a/docs/testing.md b/docs/testing.md deleted file mode 100644 index 3ce34be40b0da..0000000000000 --- a/docs/testing.md +++ /dev/null @@ -1,116 +0,0 @@ ---- -layout: global -title: Testing with Spark -description: How to write unit tests for Spark and for Spark-based applications ---- - -* This will become a table of contents (this text will be scraped). -{:toc} - -# Overview - -*************************************************************************************************** - -# Primary testing classes - -1. `org.apache.spark.SharedSparkContext` - Includes a `SparkContext` for use by the test suite -1. `org.apache.spark.SparkFunSuite` - Standardizes reporting of test and test suite names -1. `org.apache.spark.sql.test.SharedSQLContext` - Includes a `SQLContext` for use by the test suite - Incorporates `SparkFunSuite` -1. `org.apache.spark.sql.test.SharedSparkSession` - Includes a `SparkSession` for use by the test suite - -*************************************************************************************************** - -# Unit testing Spark - -All internal tests of Spark code should derive, directly or indirectly, from `SparkFunSuite`, so as to standardize the reporting of suite and test name logging. - -Tests that require a `SparkContext` should derive from `SharedSparkContext` also. - -`SharedSQLContext` already derives from `SparkFunSuite`, so may be extended directly by tests requiring a `SQLContext`. - -`SharedSparkSession` does not derive from `SparkFunSuite`, so should not be extended directly for any internal Spark tests. - -*************************************************************************************************** - -# Unit testing code that uses Spark using ScalaTest - -External applications that use Spark may extend `SharedSparkContext` or `SharedSparkSession`. 
These classes support various testing styles: - -## FunSuite style -{% highlight scala %} -class MySparkTest extends FunSuite with SharedSparkContext { - test("A parallelized RDD should be able to count its elements") { - assert(4 === sc.parallelize(Seq(1, 2, 3, 4)).count) - } -} -{% endhighlight %} - -## FunSpec style -{% highlight scala %} -class MySparkTest extends FunSpec with SharedSparkContext { - describe("A parallelized RDD") { - it("should be able to count its elements") { - assert(4 === sc.parallelize(Seq(1, 2, 3, 4)).count) - } - } -} -{% endhighlight %} - -## FlatSpec style -{% highlight scala %} -class MySparkTest extends FlatSpec with SharedSparkContext { - "A parallelized RDD" should "be able to count its elements" in { - assert(4 === sc.parallelize(Seq(1, 2, 3, 4)).count) - } -} -{% endhighlight %} - -## WordSpec style -{% highlight scala %} -class MySparkTest extends WordSpec with SharedSparkContext { - "A parallelized RDD" when { - "created" should { - "be able to count its elements" in { - assert(4 === sc.parallelize(Seq(1, 2, 3, 4)).count) - } - } - } -} -{% endhighlight %} - -# Context and Session initialization - -It should be noted that, in `SharedSparkContext`, the `SparkContext` (`sc`) isn't initialized until `beforeAll` is called. When using several testing styles, such as `FunSpec`, it is not uncommon to initialize shared resources inside a describe block (or equivalent), but outside an `it` block - i.e., in the registration phase. `beforeAll`, however, isn't called until after the registration phase, and just before the test phase. Therefore, an `initializeContext` call is exposed so that users can make sure the context is initialized in these blocks. For example: - -{% highlight scala %} -class MySparkTest extends FunSpec with SharedSparkContext { - describe("A parallelized RDD") { - initializeContext() - val rdd = sc.parallelize(Seq(1, 2, 3, 4)) - it("should be able to count its elements") { - assert(4 === rdd.count) - } - } -} -{% endhighlight %} - -Similarly, in `SharedSparkSession`, there is an `initializeSession` call for the same purpose: - -{% highlight scala %} -class MySparkTest extends FunSpec with SharedSparkSession { - describe("A simple Dataset") { - initializeSession() - import testImplicits._ - - val dataset = Seq(1, 2, 3, 4).toDS - it("should be able to count its elements") { - assert(4 === dataset.count) - } - } -} -{% endhighlight %} From 2d927e94f627919ac1546b47072276b23d3e8da2 Mon Sep 17 00:00:00 2001 From: Nathan Kronenfeld Date: Tue, 24 Oct 2017 00:37:48 -0400 Subject: [PATCH 08/15] Move base versions of PlanTest and SQLTestUtils into the same file as where they came from, in an attempt to make diffs simpler --- .../spark/sql/catalyst/plans/PlanTest.scala | 158 ++++++++++ .../sql/catalyst/plans/PlanTestBase.scala | 177 ----------- .../apache/spark/sql/test/SQLTestUtils.scala | 268 ++++++++++++++++ .../spark/sql/test/SQLTestUtilsBase.scala | 289 ------------------ 4 files changed, 426 insertions(+), 466 deletions(-) delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTestBase.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtilsBase.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 2cb0671c57bee..f584fdba348d3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -17,10 +17,168 @@ package org.apache.spark.sql.catalyst.plans +import org.scalatest.Suite + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf /** * Provides helper methods for comparing plans. */ trait PlanTest extends SparkFunSuite with PlanTestBase { } + +/** + * Provides helper methods for comparing plans, but without the overhead of + * mandating a FunSuite. + */ +trait PlanTestBase extends PredicateHelper { self: Suite => + + // TODO(gatorsmile): remove this from PlanTest and all the analyzer rules + protected def conf = SQLConf.get + + /** + * Since attribute references are given globally unique ids during analysis, + * we must normalize them to check if two different queries are identical. + */ + protected def normalizeExprIds(plan: LogicalPlan) = { + plan transformAllExpressions { + case s: ScalarSubquery => + s.copy(exprId = ExprId(0)) + case e: Exists => + e.copy(exprId = ExprId(0)) + case l: ListQuery => + l.copy(exprId = ExprId(0)) + case a: AttributeReference => + AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) + case a: Alias => + Alias(a.child, a.name)(exprId = ExprId(0)) + case ae: AggregateExpression => + ae.copy(resultId = ExprId(0)) + } + } + + /** + * Normalizes plans: + * - Filter the filter conditions that appear in a plan. For instance, + * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2) + * etc., will all now be equivalent. + * - Sample the seed will replaced by 0L. + * - Join conditions will be resorted by hashCode. + */ + protected def normalizePlan(plan: LogicalPlan): LogicalPlan = { + plan transform { + case Filter(condition: Expression, child: LogicalPlan) => + Filter(splitConjunctivePredicates(condition).map(rewriteEqual).sortBy(_.hashCode()) + .reduce(And), child) + case sample: Sample => + sample.copy(seed = 0L) + case Join(left, right, joinType, condition) if condition.isDefined => + val newCondition = + splitConjunctivePredicates(condition.get).map(rewriteEqual).sortBy(_.hashCode()) + .reduce(And) + Join(left, right, joinType, Some(newCondition)) + } + } + + /** + * Rewrite [[EqualTo]] and [[EqualNullSafe]] operator to keep order. The following cases will be + * equivalent: + * 1. (a = b), (b = a); + * 2. (a <=> b), (b <=> a). + */ + private def rewriteEqual(condition: Expression): Expression = condition match { + case eq @ EqualTo(l: Expression, r: Expression) => + Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo) + case eq @ EqualNullSafe(l: Expression, r: Expression) => + Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe) + case _ => condition // Don't reorder. + } + + /** Fails the test if the two plans do not match */ + protected def comparePlans( + plan1: LogicalPlan, + plan2: LogicalPlan, + checkAnalysis: Boolean = true): Unit = { + if (checkAnalysis) { + // Make sure both plan pass checkAnalysis. 
+ SimpleAnalyzer.checkAnalysis(plan1) + SimpleAnalyzer.checkAnalysis(plan2) + } + + val normalized1 = normalizePlan(normalizeExprIds(plan1)) + val normalized2 = normalizePlan(normalizeExprIds(plan2)) + if (normalized1 != normalized2) { + fail( + s""" + |== FAIL: Plans do not match === + |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} + """.stripMargin) + } + } + + /** Fails the test if the two expressions do not match */ + protected def compareExpressions(e1: Expression, e2: Expression): Unit = { + comparePlans(Filter(e1, OneRowRelation()), Filter(e2, OneRowRelation()), checkAnalysis = false) + } + + /** Fails the test if the join order in the two plans do not match */ + protected def compareJoinOrder(plan1: LogicalPlan, plan2: LogicalPlan) { + val normalized1 = normalizePlan(normalizeExprIds(plan1)) + val normalized2 = normalizePlan(normalizeExprIds(plan2)) + if (!sameJoinPlan(normalized1, normalized2)) { + fail( + s""" + |== FAIL: Plans do not match === + |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} + """.stripMargin) + } + } + + /** Consider symmetry for joins when comparing plans. */ + private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { + (plan1, plan2) match { + case (j1: Join, j2: Join) => + (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) || + (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)) + case (p1: Project, p2: Project) => + p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child) + case _ => + plan1 == plan2 + } + } + + /** + * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL + * configurations. + */ + protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val conf = SQLConf.get + val (keys, values) = pairs.unzip + val currentValues = keys.map { key => + if (conf.contains(key)) { + Some(conf.getConfString(key)) + } else { + None + } + } + (keys, values).zipped.foreach { (k, v) => + if (SQLConf.staticConfKeys.contains(k)) { + throw new AnalysisException(s"Cannot modify the value of a static config: $k") + } + conf.setConfString(k, v) + } + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => conf.setConfString(key, value) + case (key, None) => conf.unsetConf(key) + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTestBase.scala deleted file mode 100644 index 264abda6f690a..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTestBase.scala +++ /dev/null @@ -1,177 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.plans - -import org.scalatest.Suite - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.internal.SQLConf - -/** - * Provides helper methods for comparing plans, but without the overhead of - * mandating a FunSuite. - */ -trait PlanTestBase extends PredicateHelper { self: Suite => - - // TODO(gatorsmile): remove this from PlanTest and all the analyzer rules - protected def conf = SQLConf.get - - /** - * Since attribute references are given globally unique ids during analysis, - * we must normalize them to check if two different queries are identical. - */ - protected def normalizeExprIds(plan: LogicalPlan) = { - plan transformAllExpressions { - case s: ScalarSubquery => - s.copy(exprId = ExprId(0)) - case e: Exists => - e.copy(exprId = ExprId(0)) - case l: ListQuery => - l.copy(exprId = ExprId(0)) - case a: AttributeReference => - AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) - case a: Alias => - Alias(a.child, a.name)(exprId = ExprId(0)) - case ae: AggregateExpression => - ae.copy(resultId = ExprId(0)) - } - } - - /** - * Normalizes plans: - * - Filter the filter conditions that appear in a plan. For instance, - * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2) - * etc., will all now be equivalent. - * - Sample the seed will replaced by 0L. - * - Join conditions will be resorted by hashCode. - */ - protected def normalizePlan(plan: LogicalPlan): LogicalPlan = { - plan transform { - case Filter(condition: Expression, child: LogicalPlan) => - Filter(splitConjunctivePredicates(condition).map(rewriteEqual).sortBy(_.hashCode()) - .reduce(And), child) - case sample: Sample => - sample.copy(seed = 0L) - case Join(left, right, joinType, condition) if condition.isDefined => - val newCondition = - splitConjunctivePredicates(condition.get).map(rewriteEqual).sortBy(_.hashCode()) - .reduce(And) - Join(left, right, joinType, Some(newCondition)) - } - } - - /** - * Rewrite [[EqualTo]] and [[EqualNullSafe]] operator to keep order. The following cases will be - * equivalent: - * 1. (a = b), (b = a); - * 2. (a <=> b), (b <=> a). - */ - private def rewriteEqual(condition: Expression): Expression = condition match { - case eq @ EqualTo(l: Expression, r: Expression) => - Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo) - case eq @ EqualNullSafe(l: Expression, r: Expression) => - Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe) - case _ => condition // Don't reorder. - } - - /** Fails the test if the two plans do not match */ - protected def comparePlans( - plan1: LogicalPlan, - plan2: LogicalPlan, - checkAnalysis: Boolean = true): Unit = { - if (checkAnalysis) { - // Make sure both plan pass checkAnalysis. 
- SimpleAnalyzer.checkAnalysis(plan1) - SimpleAnalyzer.checkAnalysis(plan2) - } - - val normalized1 = normalizePlan(normalizeExprIds(plan1)) - val normalized2 = normalizePlan(normalizeExprIds(plan2)) - if (normalized1 != normalized2) { - fail( - s""" - |== FAIL: Plans do not match === - |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} - """.stripMargin) - } - } - - /** Fails the test if the two expressions do not match */ - protected def compareExpressions(e1: Expression, e2: Expression): Unit = { - comparePlans(Filter(e1, OneRowRelation()), Filter(e2, OneRowRelation()), checkAnalysis = false) - } - - /** Fails the test if the join order in the two plans do not match */ - protected def compareJoinOrder(plan1: LogicalPlan, plan2: LogicalPlan) { - val normalized1 = normalizePlan(normalizeExprIds(plan1)) - val normalized2 = normalizePlan(normalizeExprIds(plan2)) - if (!sameJoinPlan(normalized1, normalized2)) { - fail( - s""" - |== FAIL: Plans do not match === - |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} - """.stripMargin) - } - } - - /** Consider symmetry for joins when comparing plans. */ - private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { - (plan1, plan2) match { - case (j1: Join, j2: Join) => - (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) || - (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)) - case (p1: Project, p2: Project) => - p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child) - case _ => - plan1 == plan2 - } - } - - /** - * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL - * configurations. - */ - protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val conf = SQLConf.get - val (keys, values) = pairs.unzip - val currentValues = keys.map { key => - if (conf.contains(key)) { - Some(conf.getConfString(key)) - } else { - None - } - } - (keys, values).zipped.foreach { (k, v) => - if (SQLConf.staticConfKeys.contains(k)) { - throw new AnalysisException(s"Cannot modify the value of a static config: $k") - } - conf.setConfString(k, v) - } - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => conf.setConfString(key, value) - case (key, None) => conf.unsetConf(key) - } - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 7650f5c112d5c..e5695088a3462 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -17,13 +17,32 @@ package org.apache.spark.sql.test +import java.io.File +import java.net.URI +import java.nio.file.Files +import java.util.{Locale, UUID} + +import scala.concurrent.duration._ +import scala.language.implicitConversions import scala.util.control.NonFatal +import org.apache.hadoop.fs.Path +import org.scalatest.{BeforeAndAfterAll, Suite} +import org.scalatest.concurrent.Eventually + import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.PlanTestBase +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import 
org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.FilterExec +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.UninterruptibleThread +import org.apache.spark.util.Utils /** * Helper trait that should be extended by all SQL test suites within the Spark @@ -154,3 +173,252 @@ private[sql] object SQLTestUtils { } } } + +/** + * Helper trait that can be extended by all external SQL test suites. + * + * This allows subclasses to plugin a custom `SQLContext`. + * To use implicit methods, import `testImplicits._` instead of through the `SQLContext`. + * + * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is + * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. + */ +private[sql] trait SQLTestUtilsBase + extends Eventually + with BeforeAndAfterAll + with SQLTestData + with PlanTestBase { self: Suite => + + protected def sparkContext = spark.sparkContext + + // Shorthand for running a query using our SQLContext + protected lazy val sql = spark.sql _ + + /** + * A helper object for importing SQL implicits. + * + * Note that the alternative of importing `spark.implicits._` is not possible here. + * This is because we create the `SQLContext` immediately before the first test is run, + * but the implicits import is needed in the constructor. + */ + protected object testImplicits extends SQLImplicits { + protected override def _sqlContext: SQLContext = self.spark.sqlContext + } + + protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + SparkSession.setActiveSession(spark) + super.withSQLConf(pairs: _*)(f) + } + + /** + * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If + * a file/directory is created there by `f`, it will be delete after `f` returns. + * + * @todo Probably this method should be moved to a more general place + */ + protected def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + /** + * Copy file in jar's resource to a temp file, then pass it to `f`. + * This function is used to make `f` can use the path of temp file(e.g. file:/), instead of + * path of jar's resource which starts with 'jar:file:/' + */ + protected def withResourceTempPath(resourcePath: String)(f: File => Unit): Unit = { + val inputStream = + Thread.currentThread().getContextClassLoader.getResourceAsStream(resourcePath) + withTempDir { dir => + val tmpFile = new File(dir, "tmp") + Files.copy(inputStream, tmpFile.toPath) + f(tmpFile) + } + } + + /** + * Waits for all tasks on all executors to be finished. + */ + protected def waitForTasksToFinish(): Unit = { + eventually(timeout(10.seconds)) { + assert(spark.sparkContext.statusTracker + .getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } + } + + /** + * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` + * returns. + * + * @todo Probably this method should be moved to a more general place + */ + protected def withTempDir(f: File => Unit): Unit = { + val dir = Utils.createTempDir().getCanonicalFile + try f(dir) finally { + // wait for all tasks to finish before deleting files + waitForTasksToFinish() + Utils.deleteRecursively(dir) + } + } + + /** + * Creates the specified number of temporary directories, which is then passed to `f` and will be + * deleted after `f` returns. 
+ */ + protected def withTempPaths(numPaths: Int)(f: Seq[File] => Unit): Unit = { + val files = Array.fill[File](numPaths)(Utils.createTempDir().getCanonicalFile) + try f(files) finally { + // wait for all tasks to finish before deleting files + waitForTasksToFinish() + files.foreach(Utils.deleteRecursively) + } + } + + /** + * Drops functions after calling `f`. A function is represented by (functionName, isTemporary). + */ + protected def withUserDefinedFunction(functions: (String, Boolean)*)(f: => Unit): Unit = { + try { + f + } catch { + case cause: Throwable => throw cause + } finally { + // If the test failed part way, we don't want to mask the failure by failing to remove + // temp tables that never got created. + functions.foreach { case (functionName, isTemporary) => + val withTemporary = if (isTemporary) "TEMPORARY" else "" + spark.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName") + assert( + !spark.sessionState.catalog.functionExists(FunctionIdentifier(functionName)), + s"Function $functionName should have been dropped. But, it still exists.") + } + } + } + + /** + * Drops temporary table `tableName` after calling `f`. + */ + protected def withTempView(tableNames: String*)(f: => Unit): Unit = { + try f finally { + // If the test failed part way, we don't want to mask the failure by failing to remove + // temp tables that never got created. + try tableNames.foreach(spark.catalog.dropTempView) catch { + case _: NoSuchTableException => + } + } + } + + /** + * Drops table `tableName` after calling `f`. + */ + protected def withTable(tableNames: String*)(f: => Unit): Unit = { + try f finally { + tableNames.foreach { name => + spark.sql(s"DROP TABLE IF EXISTS $name") + } + } + } + + /** + * Drops view `viewName` after calling `f`. + */ + protected def withView(viewNames: String*)(f: => Unit): Unit = { + try f finally { + viewNames.foreach { name => + spark.sql(s"DROP VIEW IF EXISTS $name") + } + } + } + + /** + * Creates a temporary database and switches current database to it before executing `f`. This + * database is dropped after `f` returns. + * + * Note that this method doesn't switch current database before executing `f`. + */ + protected def withTempDatabase(f: String => Unit): Unit = { + val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" + + try { + spark.sql(s"CREATE DATABASE $dbName") + } catch { case cause: Throwable => + fail("Failed to create temporary database", cause) + } + + try f(dbName) finally { + if (spark.catalog.currentDatabase == dbName) { + spark.sql(s"USE $DEFAULT_DATABASE") + } + spark.sql(s"DROP DATABASE $dbName CASCADE") + } + } + + /** + * Drops database `dbName` after calling `f`. + */ + protected def withDatabase(dbNames: String*)(f: => Unit): Unit = { + try f finally { + dbNames.foreach { name => + spark.sql(s"DROP DATABASE IF EXISTS $name CASCADE") + } + spark.sql(s"USE $DEFAULT_DATABASE") + } + } + + /** + * Enables Locale `language` before executing `f`, then switches back to the default locale of JVM + * after `f` returns. + */ + protected def withLocale(language: String)(f: => Unit): Unit = { + val originalLocale = Locale.getDefault + try { + // Add Locale setting + Locale.setDefault(new Locale(language)) + f + } finally { + Locale.setDefault(originalLocale) + } + } + + /** + * Activates database `db` before executing `f`, then switches back to `default` database after + * `f` returns. 
+ */ + protected def activateDatabase(db: String)(f: => Unit): Unit = { + spark.sessionState.catalog.setCurrentDatabase(db) + try f finally spark.sessionState.catalog.setCurrentDatabase("default") + } + + /** + * Strip Spark-side filtering in order to check if a datasource filters rows correctly. + */ + protected def stripSparkFilter(df: DataFrame): DataFrame = { + val schema = df.schema + val withoutFilters = df.queryExecution.sparkPlan.transform { + case FilterExec(_, child) => child + } + + spark.internalCreateDataFrame(withoutFilters.execute(), schema) + } + + /** + * Turn a logical plan into a `DataFrame`. This should be removed once we have an easier + * way to construct `DataFrame` directly out of local data without relying on implicits. + */ + protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { + Dataset.ofRows(spark, plan) + } + + + /** + * This method is used to make the given path qualified, when a path + * does not contain a scheme, this path will not be changed after the default + * FileSystem is changed. + */ + def makeQualifiedPath(path: String): URI = { + val hadoopPath = new Path(path) + val fs = hadoopPath.getFileSystem(spark.sessionState.newHadoopConf()) + fs.makeQualified(hadoopPath).toUri + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtilsBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtilsBase.scala deleted file mode 100644 index 63373283adf62..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtilsBase.scala +++ /dev/null @@ -1,289 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.test - -import java.io.File -import java.net.URI -import java.nio.file.Files -import java.util.{Locale, UUID} - -import scala.concurrent.duration._ -import scala.language.implicitConversions - -import org.apache.hadoop.fs.Path -import org.scalatest.{BeforeAndAfterAll, Suite} -import org.scalatest.concurrent.Eventually - -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.analysis.NoSuchTableException -import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE -import org.apache.spark.sql.catalyst.plans.PlanTestBase -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.FilterExec -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.Utils - -/** - * Helper trait that can be extended by all external SQL test suites. - * - * This allows subclasses to plugin a custom `SQLContext`. - * To use implicit methods, import `testImplicits._` instead of through the `SQLContext`. 
- * - * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is - * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. - */ -private[sql] trait SQLTestUtilsBase - extends Eventually - with BeforeAndAfterAll - with SQLTestData - with PlanTestBase { self: Suite => - - protected def sparkContext = spark.sparkContext - - // Shorthand for running a query using our SQLContext - protected lazy val sql = spark.sql _ - - /** - * A helper object for importing SQL implicits. - * - * Note that the alternative of importing `spark.implicits._` is not possible here. - * This is because we create the `SQLContext` immediately before the first test is run, - * but the implicits import is needed in the constructor. - */ - protected object testImplicits extends SQLImplicits { - protected override def _sqlContext: SQLContext = self.spark.sqlContext - } - - protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - SparkSession.setActiveSession(spark) - super.withSQLConf(pairs: _*)(f) - } - - /** - * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If - * a file/directory is created there by `f`, it will be delete after `f` returns. - * - * @todo Probably this method should be moved to a more general place - */ - protected def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - - /** - * Copy file in jar's resource to a temp file, then pass it to `f`. - * This function is used to make `f` can use the path of temp file(e.g. file:/), instead of - * path of jar's resource which starts with 'jar:file:/' - */ - protected def withResourceTempPath(resourcePath: String)(f: File => Unit): Unit = { - val inputStream = - Thread.currentThread().getContextClassLoader.getResourceAsStream(resourcePath) - withTempDir { dir => - val tmpFile = new File(dir, "tmp") - Files.copy(inputStream, tmpFile.toPath) - f(tmpFile) - } - } - - /** - * Waits for all tasks on all executors to be finished. - */ - protected def waitForTasksToFinish(): Unit = { - eventually(timeout(10.seconds)) { - assert(spark.sparkContext.statusTracker - .getExecutorInfos.map(_.numRunningTasks()).sum == 0) - } - } - - /** - * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` - * returns. - * - * @todo Probably this method should be moved to a more general place - */ - protected def withTempDir(f: File => Unit): Unit = { - val dir = Utils.createTempDir().getCanonicalFile - try f(dir) finally { - // wait for all tasks to finish before deleting files - waitForTasksToFinish() - Utils.deleteRecursively(dir) - } - } - - /** - * Creates the specified number of temporary directories, which is then passed to `f` and will be - * deleted after `f` returns. - */ - protected def withTempPaths(numPaths: Int)(f: Seq[File] => Unit): Unit = { - val files = Array.fill[File](numPaths)(Utils.createTempDir().getCanonicalFile) - try f(files) finally { - // wait for all tasks to finish before deleting files - waitForTasksToFinish() - files.foreach(Utils.deleteRecursively) - } - } - - /** - * Drops functions after calling `f`. A function is represented by (functionName, isTemporary). 
- */ - protected def withUserDefinedFunction(functions: (String, Boolean)*)(f: => Unit): Unit = { - try { - f - } catch { - case cause: Throwable => throw cause - } finally { - // If the test failed part way, we don't want to mask the failure by failing to remove - // temp tables that never got created. - functions.foreach { case (functionName, isTemporary) => - val withTemporary = if (isTemporary) "TEMPORARY" else "" - spark.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName") - assert( - !spark.sessionState.catalog.functionExists(FunctionIdentifier(functionName)), - s"Function $functionName should have been dropped. But, it still exists.") - } - } - } - - /** - * Drops temporary table `tableName` after calling `f`. - */ - protected def withTempView(tableNames: String*)(f: => Unit): Unit = { - try f finally { - // If the test failed part way, we don't want to mask the failure by failing to remove - // temp tables that never got created. - try tableNames.foreach(spark.catalog.dropTempView) catch { - case _: NoSuchTableException => - } - } - } - - /** - * Drops table `tableName` after calling `f`. - */ - protected def withTable(tableNames: String*)(f: => Unit): Unit = { - try f finally { - tableNames.foreach { name => - spark.sql(s"DROP TABLE IF EXISTS $name") - } - } - } - - /** - * Drops view `viewName` after calling `f`. - */ - protected def withView(viewNames: String*)(f: => Unit): Unit = { - try f finally { - viewNames.foreach { name => - spark.sql(s"DROP VIEW IF EXISTS $name") - } - } - } - - /** - * Creates a temporary database and switches current database to it before executing `f`. This - * database is dropped after `f` returns. - * - * Note that this method doesn't switch current database before executing `f`. - */ - protected def withTempDatabase(f: String => Unit): Unit = { - val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" - - try { - spark.sql(s"CREATE DATABASE $dbName") - } catch { case cause: Throwable => - fail("Failed to create temporary database", cause) - } - - try f(dbName) finally { - if (spark.catalog.currentDatabase == dbName) { - spark.sql(s"USE $DEFAULT_DATABASE") - } - spark.sql(s"DROP DATABASE $dbName CASCADE") - } - } - - /** - * Drops database `dbName` after calling `f`. - */ - protected def withDatabase(dbNames: String*)(f: => Unit): Unit = { - try f finally { - dbNames.foreach { name => - spark.sql(s"DROP DATABASE IF EXISTS $name CASCADE") - } - spark.sql(s"USE $DEFAULT_DATABASE") - } - } - - /** - * Enables Locale `language` before executing `f`, then switches back to the default locale of JVM - * after `f` returns. - */ - protected def withLocale(language: String)(f: => Unit): Unit = { - val originalLocale = Locale.getDefault - try { - // Add Locale setting - Locale.setDefault(new Locale(language)) - f - } finally { - Locale.setDefault(originalLocale) - } - } - - /** - * Activates database `db` before executing `f`, then switches back to `default` database after - * `f` returns. - */ - protected def activateDatabase(db: String)(f: => Unit): Unit = { - spark.sessionState.catalog.setCurrentDatabase(db) - try f finally spark.sessionState.catalog.setCurrentDatabase("default") - } - - /** - * Strip Spark-side filtering in order to check if a datasource filters rows correctly. 
- */ - protected def stripSparkFilter(df: DataFrame): DataFrame = { - val schema = df.schema - val withoutFilters = df.queryExecution.sparkPlan.transform { - case FilterExec(_, child) => child - } - - spark.internalCreateDataFrame(withoutFilters.execute(), schema) - } - - /** - * Turn a logical plan into a `DataFrame`. This should be removed once we have an easier - * way to construct `DataFrame` directly out of local data without relying on implicits. - */ - protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { - Dataset.ofRows(spark, plan) - } - - - /** - * This method is used to make the given path qualified, when a path - * does not contain a scheme, this path will not be changed after the default - * FileSystem is changed. - */ - def makeQualifiedPath(path: String): URI = { - val hadoopPath = new Path(path) - val fs = hadoopPath.getFileSystem(spark.sessionState.newHadoopConf()) - fs.makeQualified(hadoopPath).toUri - } -} From 38a83c081b2f9e28bea6321994fc1a0a0c43f252 Mon Sep 17 00:00:00 2001 From: Nathan Kronenfeld Date: Wed, 25 Oct 2017 10:42:15 -0400 Subject: [PATCH 09/15] Comment line length should be 100 --- .../scala/org/apache/spark/SharedSparkContext.scala | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala index 8f20427cf9d67..214b1abe1a408 100644 --- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala @@ -30,13 +30,11 @@ trait SharedSparkContext extends BeforeAndAfterAll with BeforeAndAfterEach { sel var conf = new SparkConf(false) /** - * Initialize the [[SparkContext]]. Generally, this is just called from - * beforeAll; however, in test using styles other than FunSuite, there is - * often code that relies on the session between test group constructs and - * the actual tests, which may need this session. It is purely a semantic - * difference, but semantically, it makes more sense to call - * 'initializeContext' between a 'describe' and an 'it' call than it does to - * call 'beforeAll'. + * Initialize the [[SparkContext]]. Generally, this is just called from beforeAll; however, in + * test using styles other than FunSuite, there is often code that relies on the session between + * test group constructs and the actual tests, which may need this session. It is purely a + * semantic difference, but semantically, it makes more sense to call 'initializeContext' between + * a 'describe' and an 'it' call than it does to call 'beforeAll'. 
*/ protected def initializeContext(): Unit = { if (null == _sc) { From 241459a8a4c554877e381fe8306d086ab5b1b152 Mon Sep 17 00:00:00 2001 From: Nathan Kronenfeld Date: Wed, 25 Oct 2017 10:43:51 -0400 Subject: [PATCH 10/15] Move SQLTestUtils object to the end of the file --- .../apache/spark/sql/test/SQLTestUtils.scala | 86 +++++++++---------- 1 file changed, 43 insertions(+), 43 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index e5695088a3462..b4248b74f50ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -131,49 +131,6 @@ private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with } } -private[sql] object SQLTestUtils { - - def compareAnswers( - sparkAnswer: Seq[Row], - expectedAnswer: Seq[Row], - sort: Boolean): Option[String] = { - def prepareAnswer(answer: Seq[Row]): Seq[Row] = { - // Converts data to types that we can do equality comparison using Scala collections. - // For BigDecimal type, the Scala type has a better definition of equality test (similar to - // Java's java.math.BigDecimal.compareTo). - // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for - // equality test. - // This function is copied from Catalyst's QueryTest - val converted: Seq[Row] = answer.map { s => - Row.fromSeq(s.toSeq.map { - case d: java.math.BigDecimal => BigDecimal(d) - case b: Array[Byte] => b.toSeq - case o => o - }) - } - if (sort) { - converted.sortBy(_.toString()) - } else { - converted - } - } - if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { - val errorMessage = - s""" - | == Results == - | ${sideBySide( - s"== Expected Answer - ${expectedAnswer.size} ==" +: - prepareAnswer(expectedAnswer).map(_.toString()), - s"== Actual Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} - """.stripMargin - Some(errorMessage) - } else { - None - } - } -} - /** * Helper trait that can be extended by all external SQL test suites. * @@ -422,3 +379,46 @@ private[sql] trait SQLTestUtilsBase fs.makeQualified(hadoopPath).toUri } } + +private[sql] object SQLTestUtils { + + def compareAnswers( + sparkAnswer: Seq[Row], + expectedAnswer: Seq[Row], + sort: Boolean): Option[String] = { + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. 
+ // This function is copied from Catalyst's QueryTest + val converted: Seq[Row] = answer.map { s => + Row.fromSeq(s.toSeq.map { + case d: java.math.BigDecimal => BigDecimal(d) + case b: Array[Byte] => b.toSeq + case o => o + }) + } + if (sort) { + converted.sortBy(_.toString()) + } else { + converted + } + } + if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { + val errorMessage = + s""" + | == Results == + | ${sideBySide( + s"== Expected Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString()), + s"== Actual Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} + """.stripMargin + Some(errorMessage) + } else { + None + } + } +} From 24fc4a324008b2acfcf5a2617eb7cc320565e83c Mon Sep 17 00:00:00 2001 From: Nathan Kronenfeld Date: Wed, 25 Oct 2017 11:00:07 -0400 Subject: [PATCH 11/15] fix scalastyle error (whitespace at end of line) --- core/src/test/scala/org/apache/spark/SharedSparkContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala index 214b1abe1a408..1aa1c421d792e 100644 --- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala @@ -31,7 +31,7 @@ trait SharedSparkContext extends BeforeAndAfterAll with BeforeAndAfterEach { sel /** * Initialize the [[SparkContext]]. Generally, this is just called from beforeAll; however, in - * test using styles other than FunSuite, there is often code that relies on the session between + * test using styles other than FunSuite, there is often code that relies on the session between * test group constructs and the actual tests, which may need this session. It is purely a * semantic difference, but semantically, it makes more sense to call 'initializeContext' between * a 'describe' and an 'it' call than it does to call 'beforeAll'. From e4763d977cffbe7ef362a859c229b74b3cdf4ef3 Mon Sep 17 00:00:00 2001 From: Nathan Kronenfeld Date: Wed, 25 Oct 2017 22:27:07 -0400 Subject: [PATCH 12/15] Remove extraneous curly brackets around empty PlanTest body --- .../scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index f584fdba348d3..82c5307d54360 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -31,8 +31,7 @@ import org.apache.spark.sql.internal.SQLConf /** * Provides helper methods for comparing plans. 
*/ -trait PlanTest extends SparkFunSuite with PlanTestBase { -} +trait PlanTest extends SparkFunSuite with PlanTestBase /** * Provides helper methods for comparing plans, but without the overhead of From 6c0b0d569ae1d779fd9253da0c7e97d12634063c Mon Sep 17 00:00:00 2001 From: Nathan Kronenfeld Date: Wed, 25 Oct 2017 23:24:31 -0400 Subject: [PATCH 13/15] Remove extraneous beforeAll and brackets from SharedSQLContext --- .../scala/org/apache/spark/sql/test/SharedSQLContext.scala | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 4eab6b9c43db6..4d578e21f5494 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,8 +17,4 @@ package org.apache.spark.sql.test -trait SharedSQLContext extends SQLTestUtils with SharedSparkSession { - protected override def beforeAll(): Unit = { - super.beforeAll() - } -} +trait SharedSQLContext extends SQLTestUtils with SharedSparkSession From 565c598e89299b8c1473d76249ab732abebdb661 Mon Sep 17 00:00:00 2001 From: Nathan Kronenfeld Date: Thu, 9 Nov 2017 01:39:30 -0500 Subject: [PATCH 14/15] Make sure no spark sessions are active outside tests --- .../org/apache/spark/sql/test/GenericFlatSpecSuite.scala | 4 ++-- .../scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala | 4 ++-- .../org/apache/spark/sql/test/GenericWordSpecSuite.scala | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala index 6179585a0d39a..990be789338b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala @@ -25,8 +25,8 @@ import org.scalatest.FlatSpec */ class GenericFlatSpecSuite extends FlatSpec with SharedSparkSession { import testImplicits._ - initializeSession() - val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + + def ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS "A Simple Dataset" should "have the specified number of elements" in { assert(8 === ds.count) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala index 15139ee8b3047..8656ec7d1b86c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala @@ -25,8 +25,8 @@ import org.scalatest.FunSpec */ class GenericFunSpecSuite extends FunSpec with SharedSparkSession { import testImplicits._ - initializeSession() - val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + + def ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS describe("Simple Dataset") { it("should have the specified number of elements") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala index b6548bf95fec8..c451358cd7151 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala @@ -25,8 +25,8 @@ import org.scalatest.WordSpec */ class GenericWordSpecSuite extends WordSpec with SharedSparkSession { import testImplicits._ - initializeSession() - val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + + def ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS "A Simple Dataset" when { "looked at as complete rows" should { From 12a1d37ec721a556592cae3c5aff129b6a0663d0 Mon Sep 17 00:00:00 2001 From: Nathan Kronenfeld Date: Thu, 9 Nov 2017 17:31:34 -0500 Subject: [PATCH 15/15] Fix scalastyle errors --- .../org/apache/spark/sql/test/GenericFlatSpecSuite.scala | 4 +++- .../scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala | 4 +++- .../org/apache/spark/sql/test/GenericWordSpecSuite.scala | 4 +++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala index 990be789338b5..14ac479e89754 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.test import org.scalatest.FlatSpec +import org.apache.spark.sql.Dataset + /** * The purpose of this suite is to make sure that generic FlatSpec-based scala * tests work with a shared spark session @@ -26,7 +28,7 @@ import org.scalatest.FlatSpec class GenericFlatSpecSuite extends FlatSpec with SharedSparkSession { import testImplicits._ - def ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + private def ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS "A Simple Dataset" should "have the specified number of elements" in { assert(8 === ds.count) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala index 8656ec7d1b86c..e8971e36d112d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.test import org.scalatest.FunSpec +import org.apache.spark.sql.Dataset + /** * The purpose of this suite is to make sure that generic FunSpec-based scala * tests work with a shared spark session @@ -26,7 +28,7 @@ import org.scalatest.FunSpec class GenericFunSpecSuite extends FunSpec with SharedSparkSession { import testImplicits._ - def ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + private def ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS describe("Simple Dataset") { it("should have the specified number of elements") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala index c451358cd7151..44655a5345ca4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.test import org.scalatest.WordSpec +import org.apache.spark.sql.Dataset + /** * The purpose of this suite is to make sure that generic WordSpec-based scala * tests work with a 
shared spark session @@ -26,7 +28,7 @@ import org.scalatest.WordSpec class GenericWordSpecSuite extends WordSpec with SharedSparkSession { import testImplicits._ - def ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + private def ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS "A Simple Dataset" when { "looked at as complete rows" should {