SortSuite.scala
@@ -18,9 +18,13 @@
package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.expressions.{BoundReference, Ascending, SortOrder}
import org.apache.spark.sql.catalyst.dsl.expressions._

import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType}

class SortSuite extends SparkPlanTest {
import TestSQLContext.implicits.localSeqToDataFrameHolder
Author:

On second thought, perhaps SparkPlanTest should have an implicit alias to this method so each subclass does not have to worry about importing it.

Owner:

Actually, what's the right way to create such an implicit alias? I don't remember the trick offhand.

Owner:

Actually, does this work:

class SparkPlanTest extends SparkFunSuite {

  /**
   * Creates a DataFrame from a local Seq of Product.
   */
  implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = {
    TestSQLContext.implicits.localSeqToDataFrameHolder(data)
  }
}

Author:

+1, that was what I was trying to describe.
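
For illustration, a minimal sketch of the resulting usage in a subclass (MyOperatorSuite, its test name, and its column names are hypothetical):

class MyOperatorSuite extends SparkPlanTest {
  test("builds a DataFrame without extra imports") {
    // The implicit inherited from SparkPlanTest wraps the Seq in a DataFrameHolder,
    // so toDF is available here without importing TestSQLContext.implicits.
    val input = Seq(("Hello", 4), ("World", 8)).toDF("a", "b")
    assert(input.count() === 2)
  }
}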


test("basic sorting using ExternalSort") {

@@ -30,16 +34,14 @@ class SortSuite extends SparkPlanTest {
("World", 8)
)

val sortOrder = Seq(
SortOrder(BoundReference(0, StringType, nullable = false), Ascending),
SortOrder(BoundReference(1, IntegerType, nullable = false), Ascending)
)

checkAnswer(
input,
(child: SparkPlan) => new ExternalSort(sortOrder, global = false, child),
input.sorted
)
input.toDF("a", "b"),
ExternalSort('a.asc :: 'b.asc :: Nil, global = false, _: SparkPlan),
input.sorted)

checkAnswer(
input.toDF("a", "b"),
ExternalSort('b.asc :: 'a.asc :: Nil, global = false, _: SparkPlan),
input.sortBy(t => (t._2, t._1)))
}
}
SparkPlanTest.scala
@@ -21,9 +21,13 @@ import scala.util.control.NonFatal
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.SparkFunSuite

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.util._

import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.{Row, DataFrame}
import org.apache.spark.sql.catalyst.util._

/**
* Base class for writing tests for individual physical operators. For an example of how this
@@ -48,6 +52,24 @@ class SparkPlanTest extends SparkFunSuite {
}
}

/**
* Runs the plan and makes sure the answer matches the expected result.
* @param input the input data to be used.
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
* the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
*/
protected def checkAnswer[A <: Product : TypeTag](
input: DataFrame,
planFunction: SparkPlan => SparkPlan,
expectedAnswer: Seq[A]): Unit = {
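// Convert each expected Product (e.g. a tuple or case class) into a Row
// so it can be compared against the plan's output.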
val expectedRows = expectedAnswer.map(Row.fromTuple)
SparkPlanTest.checkAnswer(input, planFunction, expectedRows) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
}

/**
* Runs the plan and makes sure the answer matches the expected result.
* @param input the input data to be used.
@@ -87,6 +109,23 @@ object SparkPlanTest {

val outputPlan = planFunction(input.queryExecution.sparkPlan)

// A very simple resolver to make writing tests easier. In contrast to the real resolver
// this is always case sensitive and does not try to handle scoping or complex type resolution.
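// For example, with a child whose output is (a: String, b: Int), an UnresolvedAttribute "a"
// becomes BoundReference(0, StringType, ...) and "b" becomes BoundReference(1, IntegerType, ...).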
val resolvedPlan = outputPlan transform {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).zipWithIndex.map {
case (a, i) =>
(a.name, BoundReference(i, a.dataType, a.nullable))
}.toMap

plan.transformExpressions {
case UnresolvedAttribute(Seq(u)) =>
inputMap.get(u).getOrElse {
sys.error(s"Invalid Test: Cannot resolve $u given input ${inputMap}")
}
}
}

def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
// For BigDecimal type, the Scala type has a better definition of equality test (similar to
@@ -105,7 +144,7 @@
}

val sparkAnswer: Seq[Row] = try {
outputPlan.executeCollect().toSeq
resolvedPlan.executeCollect().toSeq
} catch {
case NonFatal(e) =>
val errorMessage =