diff --git a/core/pom.xml b/core/pom.xml index bd739e53411a..ead33cd14d52 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -235,7 +235,7 @@ org.easymock - easymock + easymockclassextension test diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 5a8310090890..dc2db66df60e 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -25,7 +25,7 @@ import scala.language.postfixOps import scala.util.Random import org.scalatest.{BeforeAndAfter, FunSuite} -import org.scalatest.concurrent.Eventually +import org.scalatest.concurrent.{PatienceConfiguration, Eventually} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ @@ -76,7 +76,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo tester.assertCleanup() // Verify that shuffles can be re-executed after cleaning up - assert(rdd.collect().toList === collected) + assert(rdd.collect().toList.equals(collected)) } test("cleanup broadcast") { @@ -285,7 +285,7 @@ class CleanerTester( sc.cleaner.get.attachListener(cleanerListener) /** Assert that all the stuff has been cleaned up */ - def assertCleanup()(implicit waitTimeout: Eventually.Timeout) { + def assertCleanup()(implicit waitTimeout: PatienceConfiguration.Timeout) { try { eventually(waitTimeout, interval(100 millis)) { assert(isAllCleanedUp) diff --git a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala index 29d428aa7dc4..47df00050c1e 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala @@ -23,11 +23,11 @@ class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with Netty shuffle mode. - override def beforeAll(configMap: Map[String, Any]) { + override def beforeAll() { System.setProperty("spark.shuffle.use.netty", "true") } - override def afterAll(configMap: Map[String, Any]) { + override def afterAll() { System.setProperty("spark.shuffle.use.netty", "false") } } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index fdbed45efec7..87bfce3470dd 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -275,8 +275,9 @@ class RDDSuite extends FunSuite with SharedSparkContext { // we can optionally shuffle to keep the upstream parallel val coalesced5 = data.coalesce(1, shuffle = true) - assert(coalesced5.dependencies.head.rdd.dependencies.head.rdd.asInstanceOf[ShuffledRDD[_, _, _]] != - null) + val isEquals = coalesced5.dependencies.head.rdd.dependencies.head.rdd. 
+ asInstanceOf[ShuffledRDD[_, _, _]] != null + assert(isEquals) // when shuffling, we can increase the number of partitions val coalesced6 = data.coalesce(20, shuffle = true) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index d172dd1ac8e1..7e901f8e9158 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -23,7 +23,7 @@ import scala.language.reflectiveCalls import akka.actor._ import akka.testkit.{ImplicitSender, TestKit, TestActorRef} -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.{BeforeAndAfter, FunSuiteLike} import org.apache.spark._ import org.apache.spark.rdd.RDD @@ -37,7 +37,7 @@ class BuggyDAGEventProcessActor extends Actor { } } -class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with FunSuite +class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with FunSuiteLike with ImplicitSender with BeforeAndAfter with LocalSparkContext { val conf = new SparkConf diff --git a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala index 6a5653ed2fb5..c1c605cdb487 100644 --- a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala @@ -105,7 +105,8 @@ class TimeStampedHashMapSuite extends FunSuite { map("k1") = strongRef map("k2") = "v2" map("k3") = "v3" - assert(map("k1") === strongRef) + val isEquals = map("k1") == strongRef + assert(isEquals) // clear strong reference to "k1" strongRef = null diff --git a/pom.xml b/pom.xml index c4a1b5093ec4..1e464a2886e3 100644 --- a/pom.xml +++ b/pom.xml @@ -458,25 +458,31 @@ org.scalatest scalatest_${scala.binary.version} - 1.9.1 + 2.1.5 test org.easymock - easymock + easymockclassextension 3.1 test org.mockito mockito-all - 1.8.5 + 1.9.0 test org.scalacheck scalacheck_${scala.binary.version} - 1.10.0 + 1.11.3 + test + + + junit + junit + 4.10 test @@ -778,6 +784,7 @@ -unchecked -deprecation -feature + -language:postfixOps -Xms1024m diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index deafbc5aad28..c0e3bbaf9053 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -270,16 +270,17 @@ object SparkBuild extends Build { */ libraryDependencies ++= Seq( - "io.netty" % "netty-all" % "4.0.17.Final", - "org.eclipse.jetty" % "jetty-server" % jettyVersion, - "org.eclipse.jetty" % "jetty-util" % jettyVersion, - "org.eclipse.jetty" % "jetty-plus" % jettyVersion, - "org.eclipse.jetty" % "jetty-security" % jettyVersion, - "org.scalatest" %% "scalatest" % "1.9.1" % "test", - "org.scalacheck" %% "scalacheck" % "1.10.0" % "test", - "com.novocode" % "junit-interface" % "0.10" % "test", - "org.easymock" % "easymock" % "3.1" % "test", - "org.mockito" % "mockito-all" % "1.8.5" % "test" + "io.netty" % "netty-all" % "4.0.17.Final", + "org.eclipse.jetty" % "jetty-server" % jettyVersion, + "org.eclipse.jetty" % "jetty-util" % jettyVersion, + "org.eclipse.jetty" % "jetty-plus" % jettyVersion, + "org.eclipse.jetty" % "jetty-security" % jettyVersion, + "org.scalatest" %% "scalatest" % "2.1.5" % "test", + "org.scalacheck" %% "scalacheck" % "1.11.3" % "test", + "com.novocode" % "junit-interface" % "0.10" % "test", + "org.easymock" % "easymockclassextension" % "3.1" % "test", + 
"org.mockito" % "mockito-all" % "1.9.0" % "test", + "junit" % "junit" % "4.10" % "test" ), testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), @@ -476,7 +477,6 @@ object SparkBuild extends Build { // this non-deterministically. TODO: FIX THIS. parallelExecution in Test := false, libraryDependencies ++= Seq( - "org.scalatest" %% "scalatest" % "1.9.1" % "test", "com.typesafe" %% "scalalogging-slf4j" % "1.0.1" ) ) diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 95460aa20533..95e179383292 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -57,12 +57,14 @@ class ReplSuite extends FunSuite { } def assertContains(message: String, output: String) { - assert(output.contains(message), + val isContain = output.contains(message) + assert(isContain, "Interpreter output did not contain '" + message + "':\n" + output) } def assertDoesNotContain(message: String, output: String) { - assert(!output.contains(message), + val isContain = output.contains(message) + assert(!isContain, "Interpreter output contained '" + message + "':\n" + output) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index d05c9652753e..3299e86b8594 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference} import org.apache.spark.sql.catalyst.types.StringType /** @@ -26,23 +26,25 @@ import org.apache.spark.sql.catalyst.types.StringType */ abstract class Command extends LeafNode { self: Product => - def output: Seq[Attribute] = Seq.empty // TODO: SPARK-2081 should fix this + def output: Seq[Attribute] = Seq.empty } /** * Returned for commands supported by a given parser, but not catalyst. In general these are DDL * commands that are passed directly to another system. */ -case class NativeCommand(cmd: String) extends Command +case class NativeCommand(cmd: String) extends Command { + override def output = + Seq(BoundReference(0, AttributeReference("result", StringType, nullable = false)())) +} /** * Commands of the form "SET (key) (= value)". */ case class SetCommand(key: Option[String], value: Option[String]) extends Command { override def output = Seq( - AttributeReference("key", StringType, nullable = false)(), - AttributeReference("value", StringType, nullable = false)() - ) + BoundReference(0, AttributeReference("key", StringType, nullable = false)()), + BoundReference(1, AttributeReference("value", StringType, nullable = false)())) } /** @@ -50,11 +52,11 @@ case class SetCommand(key: Option[String], value: Option[String]) extends Comman * actually performing the execution. */ case class ExplainCommand(plan: LogicalPlan) extends Command { - override def output = Seq(AttributeReference("plan", StringType, nullable = false)()) + override def output = + Seq(BoundReference(0, AttributeReference("plan", StringType, nullable = false)())) } /** * Returned for the "CACHE TABLE tableName" and "UNCACHE TABLE tableName" command. 
*/ case class CacheCommand(tableName: String, doCache: Boolean) extends Command - diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 0cada785b663..1f67c80e5490 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -161,7 +161,7 @@ class FilterPushdownSuite extends OptimizerTest { comparePlans(optimized, correctAnswer) } - + test("joins: push down left outer join #1") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 38fc6b41f06c..378ff5453111 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.{ScalaReflection, dsl} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.optimizer.Optimizer -import org.apache.spark.sql.catalyst.plans.logical.{SetCommand, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.columnar.InMemoryRelation @@ -147,14 +147,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - def sql(sqlText: String): SchemaRDD = { - val result = new SchemaRDD(this, parseSql(sqlText)) - // We force query optimization to happen right away instead of letting it happen lazily like - // when using the query DSL. This is so DDL commands behave as expected. This is only - // generates the RDD lineage for DML queries, but do not perform any execution. - result.queryExecution.toRdd - result - } + def sql(sqlText: String): SchemaRDD = new SchemaRDD(this, parseSql(sqlText)) /** Returns the specified table as a SchemaRDD */ def table(tableName: String): SchemaRDD = @@ -220,17 +213,21 @@ class SQLContext(@transient val sparkContext: SparkContext) * final desired output requires complex expressions to be evaluated or when columns can be * further eliminated out after filtering has been done. * + * The `prunePushedDownFilters` parameter is used to remove those filters that can be optimized + * away by the filter pushdown optimization. + * * The required attributes for both filtering and expression evaluation are passed to the * provided `scanBuilder` function so that it can avoid unnecessary column materialization. */ def pruneFilterProject( projectList: Seq[NamedExpression], filterPredicates: Seq[Expression], + prunePushedDownFilters: Seq[Expression] => Seq[Expression], scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = { val projectSet = projectList.flatMap(_.references).toSet val filterSet = filterPredicates.flatMap(_.references).toSet - val filterCondition = filterPredicates.reduceLeftOption(And) + val filterCondition = prunePushedDownFilters(filterPredicates).reduceLeftOption(And) // Right now we still use a projection even if the only evaluation is applying an alias // to a column. Since this is a no-op, it could be avoided. 
However, using this @@ -255,8 +252,7 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] val planner = new SparkPlanner @transient - protected[sql] lazy val emptyResult = - sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1) + protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[Row], 1) /** * Prepares a planned SparkPlan for execution by binding references to specific ordinals, and @@ -276,22 +272,6 @@ class SQLContext(@transient val sparkContext: SparkContext) protected abstract class QueryExecution { def logical: LogicalPlan - def eagerlyProcess(plan: LogicalPlan): RDD[Row] = plan match { - case SetCommand(key, value) => - // Only this case needs to be executed eagerly. The other cases will - // be taken care of when the actual results are being extracted. - // In the case of HiveContext, sqlConf is overridden to also pass the - // pair into its HiveConf. - if (key.isDefined && value.isDefined) { - set(key.get, value.get) - } - // It doesn't matter what we return here, since this is only used - // to force the evaluation to happen eagerly. To query the results, - // one must use SchemaRDD operations to extract them. - emptyResult - case _ => executedPlan.execute() - } - lazy val analyzed = analyzer(logical) lazy val optimizedPlan = optimizer(analyzed) // TODO: Don't just pick the first one... @@ -299,12 +279,7 @@ class SQLContext(@transient val sparkContext: SparkContext) lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) /** Internal version of the RDD. Avoids copies and has no schema */ - lazy val toRdd: RDD[Row] = { - logical match { - case s: SetCommand => eagerlyProcess(s) - case _ => executedPlan.execute() - } - } + lazy val toRdd: RDD[Row] = executedPlan.execute() protected def stringOrError[A](f: => A): String = try f.toString catch { case e: Throwable => e.toString } @@ -326,7 +301,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * TODO: We only support primitive types, add support for nested types. 
*/ private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = { - val schema = rdd.first.map { case (fieldName, obj) => + val schema = rdd.first().map { case (fieldName, obj) => val dataType = obj.getClass match { case c: Class[_] if c == classOf[java.lang.String] => StringType case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 7ad8edf5a5a6..821ac850ac3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -97,7 +97,7 @@ import java.util.{Map => JMap} @AlphaComponent class SchemaRDD( @transient val sqlContext: SQLContext, - @transient protected[spark] val logicalPlan: LogicalPlan) + @transient val baseLogicalPlan: LogicalPlan) extends RDD[Row](sqlContext.sparkContext, Nil) with SchemaRDDLike { def baseSchemaRDD = this diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index 3a895e15a450..656be965a8fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -20,13 +20,14 @@ package org.apache.spark.sql import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.SparkLogicalPlan /** * Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java) */ private[sql] trait SchemaRDDLike { @transient val sqlContext: SQLContext - @transient protected[spark] val logicalPlan: LogicalPlan + @transient val baseLogicalPlan: LogicalPlan private[sql] def baseSchemaRDD: SchemaRDD @@ -48,7 +49,17 @@ private[sql] trait SchemaRDDLike { */ @transient @DeveloperApi - lazy val queryExecution = sqlContext.executePlan(logicalPlan) + lazy val queryExecution = sqlContext.executePlan(baseLogicalPlan) + + @transient protected[spark] val logicalPlan: LogicalPlan = baseLogicalPlan match { + // For various commands (like DDL) and queries with side effects, we force query optimization to + // happen right away to let these side effects take place eagerly. 
+ case _: Command | _: InsertIntoTable | _: InsertIntoCreatedTable | _: WriteToFile => + queryExecution.toRdd + SparkLogicalPlan(queryExecution.executedPlan) + case _ => + baseLogicalPlan + } override def toString = s"""${super.toString} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index 22f57b758dd0..aff6ffe9f347 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -37,7 +37,7 @@ import org.apache.spark.storage.StorageLevel */ class JavaSchemaRDD( @transient val sqlContext: SQLContext, - @transient protected[spark] val logicalPlan: LogicalPlan) + @transient val baseLogicalPlan: LogicalPlan) extends JavaRDDLike[Row, JavaRDD[Row]] with SchemaRDDLike { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 1039be531520..2233216a6ec5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{SQLConf, SQLContext, execution} +import org.apache.spark.sql.{SQLContext, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ @@ -157,12 +157,36 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { InsertIntoParquetTable(relation, planLater(child), overwrite=true)(sparkContext) :: Nil case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) => InsertIntoParquetTable(table, planLater(child), overwrite)(sparkContext) :: Nil - case PhysicalOperation(projectList, filters, relation: ParquetRelation) => - // TODO: Should be pushing down filters as well. + case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => + val prunePushedDownFilters = + if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) { + (filters: Seq[Expression]) => { + filters.filter { filter => + // Note: filters cannot be pushed down to Parquet if they contain more complex + // expressions than simple "Attribute cmp Literal" comparisons. Here we remove + // all filters that have been pushed down. Note that a predicate such as + // "(A AND B) OR C" can result in "A OR C" being pushed down. + val recordFilter = ParquetFilters.createFilter(filter) + if (!recordFilter.isDefined) { + // First case: the pushdown did not result in any record filter. + true + } else { + // Second case: a record filter was created; here we are conservative in + // the sense that even if "A" was pushed and we check for "A AND B" we + // still want to keep "A AND B" in the higher-level filter, not just "B". 
+ !ParquetFilters.findExpression(recordFilter.get, filter).isDefined + } + } + } + } else { + identity[Seq[Expression]] _ + } pruneFilterProject( projectList, filters, - ParquetTableScan(_, relation, None)(sparkContext)) :: Nil + prunePushedDownFilters, + ParquetTableScan(_, relation, filters)(sparkContext)) :: Nil + case _ => Nil } } @@ -225,12 +249,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case class CommandStrategy(context: SQLContext) extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.SetCommand(key, value) => - Seq(execution.SetCommandPhysical(key, value, plan.output)(context)) + Seq(execution.SetCommand(key, value, plan.output)(context)) case logical.ExplainCommand(child) => val executedPlan = context.executePlan(child).executedPlan - Seq(execution.ExplainCommandPhysical(executedPlan, plan.output)(context)) + Seq(execution.ExplainCommand(executedPlan, plan.output)(context)) case logical.CacheCommand(tableName, cache) => - Seq(execution.CacheCommandPhysical(tableName, cache)(context)) + Seq(execution.CacheCommand(tableName, cache)(context)) case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index be26d19e6686..0377290af592 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -22,45 +22,69 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute} +trait Command { + /** + * A concrete command should override this lazy field to wrap up any side effects caused by the + * command or any other computation that should be evaluated exactly once. The value of this field + * can be used as the contents of the corresponding RDD generated from the physical plan of this + * command. + * + * The `execute()` method of all the physical command classes should reference `sideEffectResult` + * so that the command can be executed eagerly right after the command query is created. + */ + protected[sql] lazy val sideEffectResult: Seq[Any] = Seq.empty[Any] +} + /** * :: DeveloperApi :: */ @DeveloperApi -case class SetCommandPhysical(key: Option[String], value: Option[String], output: Seq[Attribute]) - (@transient context: SQLContext) extends LeafNode { - def execute(): RDD[Row] = (key, value) match { - // Set value for key k; the action itself would - // have been performed in QueryExecution eagerly. - case (Some(k), Some(v)) => context.emptyResult +case class SetCommand( + key: Option[String], value: Option[String], output: Seq[Attribute])( + @transient context: SQLContext) + extends LeafNode with Command { + + override protected[sql] lazy val sideEffectResult: Seq[(String, String)] = (key, value) match { + // Set value for key k. + case (Some(k), Some(v)) => + context.set(k, v) + Array(k -> v) + // Query the value bound to key k. - case (Some(k), None) => - val resultString = context.getOption(k) match { - case Some(v) => s"$k=$v" - case None => s"$k is undefined" - } - context.sparkContext.parallelize(Seq(new GenericRow(Array[Any](resultString))), 1) + case (Some(k), _) => + Array(k -> context.getOption(k).getOrElse("")) + // Query all key-value pairs that are set in the SQLConf of the context. 
case (None, None) => - val pairs = context.getAll - val rows = pairs.map { case (k, v) => - new GenericRow(Array[Any](s"$k=$v")) - }.toSeq - // Assume config parameters can fit into one split (machine) ;) - context.sparkContext.parallelize(rows, 1) - // The only other case is invalid semantics and is impossible. - case _ => context.emptyResult + context.getAll + + case _ => + throw new IllegalArgumentException() } + + def execute(): RDD[Row] = { + val rows = sideEffectResult.map { case (k, v) => new GenericRow(Array[Any](k, v)) } + context.sparkContext.parallelize(rows, 1) + } + + override def otherCopyArgs = context :: Nil } /** * :: DeveloperApi :: */ @DeveloperApi -case class ExplainCommandPhysical(child: SparkPlan, output: Seq[Attribute]) - (@transient context: SQLContext) extends UnaryNode { +case class ExplainCommand( + child: SparkPlan, output: Seq[Attribute])( + @transient context: SQLContext) + extends UnaryNode with Command { + + // Actually "EXPLAIN" command doesn't cause any side effect. + override protected[sql] lazy val sideEffectResult: Seq[String] = this.toString.split("\n") + def execute(): RDD[Row] = { - val planString = new GenericRow(Array[Any](child.toString)) - context.sparkContext.parallelize(Seq(planString)) + val explanation = sideEffectResult.mkString("\n") + context.sparkContext.parallelize(Seq(new GenericRow(Array[Any](explanation))), 1) } override def otherCopyArgs = context :: Nil @@ -70,19 +94,20 @@ case class ExplainCommandPhysical(child: SparkPlan, output: Seq[Attribute]) * :: DeveloperApi :: */ @DeveloperApi -case class CacheCommandPhysical(tableName: String, doCache: Boolean)(@transient context: SQLContext) - extends LeafNode { +case class CacheCommand(tableName: String, doCache: Boolean)(@transient context: SQLContext) + extends LeafNode with Command { - lazy val commandSideEffect = { + override protected[sql] lazy val sideEffectResult = { if (doCache) { context.cacheTable(tableName) } else { context.uncacheTable(tableName) } + Seq.empty[Any] } override def execute(): RDD[Row] = { - commandSideEffect + sideEffectResult context.emptyResult } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala new file mode 100644 index 000000000000..052b0a919671 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -0,0 +1,436 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.parquet + +import org.apache.hadoop.conf.Configuration + +import parquet.filter._ +import parquet.filter.ColumnPredicates._ +import parquet.column.ColumnReader + +import com.google.common.io.BaseEncoding + +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.expressions.{Predicate => CatalystPredicate} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.SparkSqlSerializer + +object ParquetFilters { + val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter" + // set this to false if pushdown should be disabled + val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.hints.parquetFilterPushdown" + + def createRecordFilter(filterExpressions: Seq[Expression]): UnboundRecordFilter = { + val filters: Seq[CatalystFilter] = filterExpressions.collect { + case (expression: Expression) if createFilter(expression).isDefined => + createFilter(expression).get + } + if (filters.length > 0) filters.reduce(AndRecordFilter.and) else null + } + + def createFilter(expression: Expression): Option[CatalystFilter] = { + def createEqualityFilter( + name: String, + literal: Literal, + predicate: CatalystPredicate) = literal.dataType match { + case BooleanType => + ComparisonFilter.createBooleanFilter(name, literal.value.asInstanceOf[Boolean], predicate) + case IntegerType => + ComparisonFilter.createIntFilter( + name, + (x: Int) => x == literal.value.asInstanceOf[Int], + predicate) + case LongType => + ComparisonFilter.createLongFilter( + name, + (x: Long) => x == literal.value.asInstanceOf[Long], + predicate) + case DoubleType => + ComparisonFilter.createDoubleFilter( + name, + (x: Double) => x == literal.value.asInstanceOf[Double], + predicate) + case FloatType => + ComparisonFilter.createFloatFilter( + name, + (x: Float) => x == literal.value.asInstanceOf[Float], + predicate) + case StringType => + ComparisonFilter.createStringFilter(name, literal.value.asInstanceOf[String], predicate) + } + def createLessThanFilter( + name: String, + literal: Literal, + predicate: CatalystPredicate) = literal.dataType match { + case IntegerType => + ComparisonFilter.createIntFilter( + name, + (x: Int) => x < literal.value.asInstanceOf[Int], + predicate) + case LongType => + ComparisonFilter.createLongFilter( + name, + (x: Long) => x < literal.value.asInstanceOf[Long], + predicate) + case DoubleType => + ComparisonFilter.createDoubleFilter( + name, + (x: Double) => x < literal.value.asInstanceOf[Double], + predicate) + case FloatType => + ComparisonFilter.createFloatFilter( + name, + (x: Float) => x < literal.value.asInstanceOf[Float], + predicate) + } + def createLessThanOrEqualFilter( + name: String, + literal: Literal, + predicate: CatalystPredicate) = literal.dataType match { + case IntegerType => + ComparisonFilter.createIntFilter( + name, + (x: Int) => x <= literal.value.asInstanceOf[Int], + predicate) + case LongType => + ComparisonFilter.createLongFilter( + name, + (x: Long) => x <= literal.value.asInstanceOf[Long], + predicate) + case DoubleType => + ComparisonFilter.createDoubleFilter( + name, + (x: Double) => x <= literal.value.asInstanceOf[Double], + predicate) + case FloatType => + ComparisonFilter.createFloatFilter( + name, + (x: Float) => x <= literal.value.asInstanceOf[Float], + predicate) + } + // TODO: combine these two types somehow? 
+ def createGreaterThanFilter( + name: String, + literal: Literal, + predicate: CatalystPredicate) = literal.dataType match { + case IntegerType => + ComparisonFilter.createIntFilter( + name, + (x: Int) => x > literal.value.asInstanceOf[Int], + predicate) + case LongType => + ComparisonFilter.createLongFilter( + name, + (x: Long) => x > literal.value.asInstanceOf[Long], + predicate) + case DoubleType => + ComparisonFilter.createDoubleFilter( + name, + (x: Double) => x > literal.value.asInstanceOf[Double], + predicate) + case FloatType => + ComparisonFilter.createFloatFilter( + name, + (x: Float) => x > literal.value.asInstanceOf[Float], + predicate) + } + def createGreaterThanOrEqualFilter( + name: String, + literal: Literal, + predicate: CatalystPredicate) = literal.dataType match { + case IntegerType => + ComparisonFilter.createIntFilter( + name, (x: Int) => x >= literal.value.asInstanceOf[Int], + predicate) + case LongType => + ComparisonFilter.createLongFilter( + name, + (x: Long) => x >= literal.value.asInstanceOf[Long], + predicate) + case DoubleType => + ComparisonFilter.createDoubleFilter( + name, + (x: Double) => x >= literal.value.asInstanceOf[Double], + predicate) + case FloatType => + ComparisonFilter.createFloatFilter( + name, + (x: Float) => x >= literal.value.asInstanceOf[Float], + predicate) + } + + /** + * TODO: we currently only filter on non-nullable (Parquet REQUIRED) attributes until + * https://github.com/Parquet/parquet-mr/issues/371 + * has been resolved. + */ + expression match { + case p @ Or(left: Expression, right: Expression) + if createFilter(left).isDefined && createFilter(right).isDefined => { + // If either side of this Or-predicate is empty then this means + // it contains a more complex comparison than between attribute and literal + // (e.g., it contained a CAST). The only safe thing to do is then to disregard + // this disjunction, which could be contained in a conjunction. If it stands + // alone then it is also safe to drop it, since a Null return value of this + // function is interpreted as having no filters at all. + val leftFilter = createFilter(left).get + val rightFilter = createFilter(right).get + Some(new OrFilter(leftFilter, rightFilter)) + } + case p @ And(left: Expression, right: Expression) => { + // This treats nested conjunctions; since either side of the conjunction + // may contain more complex filter expressions we may actually generate + // strictly weaker filter predicates in the process. 
+ val leftFilter = createFilter(left) + val rightFilter = createFilter(right) + (leftFilter, rightFilter) match { + case (None, Some(filter)) => Some(filter) + case (Some(filter), None) => Some(filter) + case (_, _) => + Some(new AndFilter(leftFilter.get, rightFilter.get)) + } + } + case p @ Equals(left: Literal, right: NamedExpression) if !right.nullable => + Some(createEqualityFilter(right.name, left, p)) + case p @ Equals(left: NamedExpression, right: Literal) if !left.nullable => + Some(createEqualityFilter(left.name, right, p)) + case p @ LessThan(left: Literal, right: NamedExpression) if !right.nullable => + Some(createLessThanFilter(right.name, left, p)) + case p @ LessThan(left: NamedExpression, right: Literal) if !left.nullable => + Some(createLessThanFilter(left.name, right, p)) + case p @ LessThanOrEqual(left: Literal, right: NamedExpression) if !right.nullable => + Some(createLessThanOrEqualFilter(right.name, left, p)) + case p @ LessThanOrEqual(left: NamedExpression, right: Literal) if !left.nullable => + Some(createLessThanOrEqualFilter(left.name, right, p)) + case p @ GreaterThan(left: Literal, right: NamedExpression) if !right.nullable => + Some(createGreaterThanFilter(right.name, left, p)) + case p @ GreaterThan(left: NamedExpression, right: Literal) if !left.nullable => + Some(createGreaterThanFilter(left.name, right, p)) + case p @ GreaterThanOrEqual(left: Literal, right: NamedExpression) if !right.nullable => + Some(createGreaterThanOrEqualFilter(right.name, left, p)) + case p @ GreaterThanOrEqual(left: NamedExpression, right: Literal) if !left.nullable => + Some(createGreaterThanOrEqualFilter(left.name, right, p)) + case _ => None + } + } + + /** + * Note: Inside the Hadoop API we only have access to `Configuration`, not to + * [[org.apache.spark.SparkContext]], so we cannot use broadcasts to convey + * the actual filter predicate. + */ + def serializeFilterExpressions(filters: Seq[Expression], conf: Configuration): Unit = { + if (filters.length > 0) { + val serialized: Array[Byte] = SparkSqlSerializer.serialize(filters) + val encoded: String = BaseEncoding.base64().encode(serialized) + conf.set(PARQUET_FILTER_DATA, encoded) + } + } + + /** + * Note: Inside the Hadoop API we only have access to `Configuration`, not to + * [[org.apache.spark.SparkContext]], so we cannot use broadcasts to convey + * the actual filter predicate. + */ + def deserializeFilterExpressions(conf: Configuration): Seq[Expression] = { + val data = conf.get(PARQUET_FILTER_DATA) + if (data != null) { + val decoded: Array[Byte] = BaseEncoding.base64().decode(data) + SparkSqlSerializer.deserialize(decoded) + } else { + Seq() + } + } + + /** + * Try to find the given expression in the tree of filters in order to + * determine whether it is safe to remove it from the higher level filters. Note + * that strictly speaking we could stop the search whenever an expression is found + * that contains this expression as subexpression (e.g., when searching for "a" + * and "(a or c)" is found) but we don't care about optimizations here since the + * filter tree is assumed to be small. + * + * @param filter The [[org.apache.spark.sql.parquet.CatalystFilter]] to expand + * and search + * @param expression The expression to look for + * @return An optional [[org.apache.spark.sql.parquet.CatalystFilter]] that + * contains the expression. 
+ */ + def findExpression( + filter: CatalystFilter, + expression: Expression): Option[CatalystFilter] = filter match { + case f @ OrFilter(_, leftFilter, rightFilter, _) => + if (f.predicate == expression) { + Some(f) + } else { + val left = findExpression(leftFilter, expression) + if (left.isDefined) left else findExpression(rightFilter, expression) + } + case f @ AndFilter(_, leftFilter, rightFilter, _) => + if (f.predicate == expression) { + Some(f) + } else { + val left = findExpression(leftFilter, expression) + if (left.isDefined) left else findExpression(rightFilter, expression) + } + case f @ ComparisonFilter(_, _, predicate) => + if (predicate == expression) Some(f) else None + case _ => None + } +} + +abstract private[parquet] class CatalystFilter( + @transient val predicate: CatalystPredicate) extends UnboundRecordFilter + +private[parquet] case class ComparisonFilter( + val columnName: String, + private var filter: UnboundRecordFilter, + @transient override val predicate: CatalystPredicate) + extends CatalystFilter(predicate) { + override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = { + filter.bind(readers) + } +} + +private[parquet] case class OrFilter( + private var filter: UnboundRecordFilter, + @transient val left: CatalystFilter, + @transient val right: CatalystFilter, + @transient override val predicate: Or) + extends CatalystFilter(predicate) { + def this(l: CatalystFilter, r: CatalystFilter) = + this( + OrRecordFilter.or(l, r), + l, + r, + Or(l.predicate, r.predicate)) + + override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = { + filter.bind(readers) + } +} + +private[parquet] case class AndFilter( + private var filter: UnboundRecordFilter, + @transient val left: CatalystFilter, + @transient val right: CatalystFilter, + @transient override val predicate: And) + extends CatalystFilter(predicate) { + def this(l: CatalystFilter, r: CatalystFilter) = + this( + AndRecordFilter.and(l, r), + l, + r, + And(l.predicate, r.predicate)) + + override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = { + filter.bind(readers) + } +} + +private[parquet] object ComparisonFilter { + def createBooleanFilter( + columnName: String, + value: Boolean, + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + ColumnRecordFilter.column( + columnName, + ColumnPredicates.applyFunctionToBoolean( + new BooleanPredicateFunction { + def functionToApply(input: Boolean): Boolean = input == value + } + )), + predicate) + + def createStringFilter( + columnName: String, + value: String, + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + ColumnRecordFilter.column( + columnName, + ColumnPredicates.applyFunctionToString ( + new ColumnPredicates.PredicateFunction[String] { + def functionToApply(input: String): Boolean = input == value + } + )), + predicate) + + def createIntFilter( + columnName: String, + func: Int => Boolean, + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + ColumnRecordFilter.column( + columnName, + ColumnPredicates.applyFunctionToInteger( + new IntegerPredicateFunction { + def functionToApply(input: Int) = func(input) + } + )), + predicate) + + def createLongFilter( + columnName: String, + func: Long => Boolean, + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + ColumnRecordFilter.column( + columnName, + ColumnPredicates.applyFunctionToLong( + new LongPredicateFunction { + def 
functionToApply(input: Long) = func(input) + } + )), + predicate) + + def createDoubleFilter( + columnName: String, + func: Double => Boolean, + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + ColumnRecordFilter.column( + columnName, + ColumnPredicates.applyFunctionToDouble( + new DoublePredicateFunction { + def functionToApply(input: Double) = func(input) + } + )), + predicate) + + def createFloatFilter( + columnName: String, + func: Float => Boolean, + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + ColumnRecordFilter.column( + columnName, + ColumnPredicates.applyFunctionToFloat( + new FloatPredicateFunction { + def functionToApply(input: Float) = func(input) + } + )), + predicate) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index f825ca3c028e..65ba1246fbf9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -27,26 +27,27 @@ import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat, FileOutputCommitter} -import parquet.hadoop.{ParquetInputFormat, ParquetOutputFormat} +import parquet.hadoop.{ParquetRecordReader, ParquetInputFormat, ParquetOutputFormat} +import parquet.hadoop.api.ReadSupport import parquet.hadoop.util.ContextUtil import parquet.io.InvalidRecordException import parquet.schema.MessageType -import org.apache.spark.{SerializableWritable, SparkContext, TaskContext} +import org.apache.spark.{Logging, SerializableWritable, SparkContext, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} /** * Parquet table scan operator. Imports the file that backs the given - * [[ParquetRelation]] as a RDD[Row]. + * [[org.apache.spark.sql.parquet.ParquetRelation]] as a ``RDD[Row]``. */ case class ParquetTableScan( // note: output cannot be transient, see // https://issues.apache.org/jira/browse/SPARK-1367 output: Seq[Attribute], relation: ParquetRelation, - columnPruningPred: Option[Expression])( + columnPruningPred: Seq[Expression])( @transient val sc: SparkContext) extends LeafNode { @@ -62,18 +63,30 @@ case class ParquetTableScan( for (path <- fileList if !path.getName.startsWith("_")) { NewFileInputFormat.addInputPath(job, path) } + + // Store Parquet schema in `Configuration` conf.set( RowReadSupport.PARQUET_ROW_REQUESTED_SCHEMA, ParquetTypesConverter.convertFromAttributes(output).toString) - // TODO: think about adding record filters - /* Comments regarding record filters: it would be nice to push down as much filtering - to Parquet as possible. However, currently it seems we cannot pass enough information - to materialize an (arbitrary) Catalyst [[Predicate]] inside Parquet's - ``FilteredRecordReader`` (via Configuration, for example). Simple - filter-rows-by-column-values however should be supported. 
- */ - sc.newAPIHadoopRDD(conf, classOf[ParquetInputFormat[Row]], classOf[Void], classOf[Row]) - .map(_._2) + + // Store record filtering predicate in `Configuration` + // Note 1: the input format ignores all predicates that cannot be expressed + // as simple column predicate filters in Parquet. Here we just record + // the whole pruning predicate. + // Note 2: you can disable filter predicate pushdown by setting + // "spark.sql.hints.parquetFilterPushdown" to false inside SparkConf. + if (columnPruningPred.length > 0 && + sc.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) { + ParquetFilters.serializeFilterExpressions(columnPruningPred, conf) + } + + sc.newAPIHadoopRDD( + conf, + classOf[org.apache.spark.sql.parquet.FilteringParquetRowInputFormat], + classOf[Void], + classOf[Row]) + .map(_._2) + .filter(_ != null) // Parquet's record filters may produce null values } override def otherCopyArgs = sc :: Nil @@ -184,10 +197,19 @@ case class InsertIntoParquetTable( override def otherCopyArgs = sc :: Nil - // based on ``saveAsNewAPIHadoopFile`` in [[PairRDDFunctions]] - // TODO: Maybe PairRDDFunctions should use Product2 instead of Tuple2? - // .. then we could use the default one and could use [[MutablePair]] - // instead of ``Tuple2`` + /** + * Stores the given Row RDD as a Hadoop file. + * + * Note: We cannot use ``saveAsNewAPIHadoopFile`` from [[org.apache.spark.rdd.PairRDDFunctions]] + * together with [[org.apache.spark.util.MutablePair]] because ``PairRDDFunctions`` uses + * ``Tuple2`` and not ``Product2``. Also, we want to allow appending files to an existing + * directory and need to determine which was the largest written file index before starting to + * write. + * + * @param rdd The [[org.apache.spark.rdd.RDD]] to writer + * @param path The directory to write to. + * @param conf A [[org.apache.hadoop.conf.Configuration]]. + */ private def saveAsHadoopFile( rdd: RDD[Row], path: String, @@ -244,8 +266,10 @@ case class InsertIntoParquetTable( } } -// TODO: this will be able to append to directories it created itself, not necessarily -// to imported ones +/** + * TODO: this will be able to append to directories it created itself, not necessarily + * to imported ones. + */ private[parquet] class AppendingParquetOutputFormat(offset: Int) extends parquet.hadoop.ParquetOutputFormat[Row] { // override to accept existing directories as valid output directory @@ -262,6 +286,30 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) } } +/** + * We extend ParquetInputFormat in order to have more control over which + * RecordFilter we want to use. 
+ */ +private[parquet] class FilteringParquetRowInputFormat + extends parquet.hadoop.ParquetInputFormat[Row] with Logging { + override def createRecordReader( + inputSplit: InputSplit, + taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = { + val readSupport: ReadSupport[Row] = new RowReadSupport() + + val filterExpressions = + ParquetFilters.deserializeFilterExpressions(ContextUtil.getConfiguration(taskAttemptContext)) + if (filterExpressions.length > 0) { + logInfo(s"Pushing down predicates for RecordFilter: ${filterExpressions.mkString(", ")}") + new ParquetRecordReader[Row]( + readSupport, + ParquetFilters.createRecordFilter(filterExpressions)) + } else { + new ParquetRecordReader[Row](readSupport) + } + } +} + private[parquet] object FileSystemHelper { def listFiles(pathStr: String, conf: Configuration): Seq[Path] = { val origPath = new Path(pathStr) @@ -278,7 +326,9 @@ private[parquet] object FileSystemHelper { fs.listStatus(path).map(_.getPath) } - // finds the maximum taskid in the output file names at the given path + /** + * Finds the maximum taskid in the output file names at the given path. + */ def findMaxTaskId(pathStr: String, conf: Configuration): Int = { val files = FileSystemHelper.listFiles(pathStr, conf) // filename pattern is part-r-.parquet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala index f37976f7313c..46c717298564 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala @@ -19,15 +19,34 @@ package org.apache.spark.sql.parquet import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce.Job +import parquet.example.data.{GroupWriter, Group} +import parquet.example.data.simple.SimpleGroup import parquet.hadoop.ParquetWriter -import parquet.hadoop.util.ContextUtil +import parquet.hadoop.api.WriteSupport +import parquet.hadoop.api.WriteSupport.WriteContext +import parquet.io.api.RecordConsumer import parquet.schema.{MessageType, MessageTypeParser} -import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.util.Utils +// Write support class for nested groups: ParquetWriter initializes GroupWriteSupport +// with an empty configuration (it is after all not intended to be used in this way?) +// and members are private so we need to make our own in order to pass the schema +// to the writer. 
+private class TestGroupWriteSupport(schema: MessageType) extends WriteSupport[Group] { + var groupWriter: GroupWriter = null + override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { + groupWriter = new GroupWriter(recordConsumer, schema) + } + override def init(configuration: Configuration): WriteContext = { + new WriteContext(schema, new java.util.HashMap[String, String]()) + } + override def write(record: Group) { + groupWriter.write(record) + } +} + private[sql] object ParquetTestData { val testSchema = @@ -43,7 +62,7 @@ private[sql] object ParquetTestData { // field names for test assertion error messages val testSchemaFieldNames = Seq( "myboolean:Boolean", - "mtint:Int", + "myint:Int", "mystring:String", "mylong:Long", "myfloat:Float", @@ -58,6 +77,18 @@ private[sql] object ParquetTestData { |} """.stripMargin + val testFilterSchema = + """ + |message myrecord { + |required boolean myboolean; + |required int32 myint; + |required binary mystring; + |required int64 mylong; + |required float myfloat; + |required double mydouble; + |} + """.stripMargin + // field names for test assertion error messages val subTestSchemaFieldNames = Seq( "myboolean:Boolean", @@ -65,36 +96,57 @@ private[sql] object ParquetTestData { ) val testDir = Utils.createTempDir() + val testFilterDir = Utils.createTempDir() lazy val testData = new ParquetRelation(testDir.toURI.toString) def writeFile() = { testDir.delete val path: Path = new Path(new Path(testDir.toURI), new Path("part-r-0.parquet")) - val job = new Job() - val configuration: Configuration = ContextUtil.getConfiguration(job) val schema: MessageType = MessageTypeParser.parseMessageType(testSchema) + val writeSupport = new TestGroupWriteSupport(schema) + val writer = new ParquetWriter[Group](path, writeSupport) - val writeSupport = new RowWriteSupport() - writeSupport.setSchema(schema, configuration) - val writer = new ParquetWriter(path, writeSupport) for(i <- 0 until 15) { - val data = new Array[Any](6) + val record = new SimpleGroup(schema) if (i % 3 == 0) { - data.update(0, true) + record.add(0, true) } else { - data.update(0, false) + record.add(0, false) } if (i % 5 == 0) { - data.update(1, 5) + record.add(1, 5) + } + record.add(2, "abc") + record.add(3, i.toLong << 33) + record.add(4, 2.5F) + record.add(5, 4.5D) + writer.write(record) + } + writer.close() + } + + def writeFilterFile(records: Int = 200) = { + // for microbenchmark use: records = 300000000 + testFilterDir.delete + val path: Path = new Path(new Path(testFilterDir.toURI), new Path("part-r-0.parquet")) + val schema: MessageType = MessageTypeParser.parseMessageType(testFilterSchema) + val writeSupport = new TestGroupWriteSupport(schema) + val writer = new ParquetWriter[Group](path, writeSupport) + + for(i <- 0 to records) { + val record = new SimpleGroup(schema) + if (i % 4 == 0) { + record.add(0, true) } else { - data.update(1, null) // optional + record.add(0, false) } - data.update(2, "abc") - data.update(3, i.toLong << 33) - data.update(4, 2.5F) - data.update(5, 4.5D) - writer.write(new GenericRow(data.toArray)) + record.add(1, i) + record.add(2, i.toString) + record.add(3, i.toLong) + record.add(4, i.toFloat + 0.5f) + record.add(5, i.toDouble + 0.5d) + writer.write(record) } writer.close() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c1fc99f07743..e9360b0fc791 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -141,7 +141,7 @@ class SQLQuerySuite extends QueryTest { sql("SELECT AVG(a),b FROM largeAndSmallInts group by b"), Seq((2147483645.0,1),(2.0,2))) } - + test("count") { checkAnswer( sql("SELECT COUNT(*) FROM testData2"), @@ -332,7 +332,7 @@ class SQLQuerySuite extends QueryTest { (3, "C"), (4, "D"))) } - + test("system function upper()") { checkAnswer( sql("SELECT n,UPPER(l) FROM lowerCaseData"), @@ -349,7 +349,7 @@ class SQLQuerySuite extends QueryTest { (2, "ABC"), (3, null))) } - + test("system function lower()") { checkAnswer( sql("SELECT N,LOWER(L) FROM upperCaseData"), @@ -382,25 +382,25 @@ class SQLQuerySuite extends QueryTest { sql(s"SET $testKey=$testVal") checkAnswer( sql("SET"), - Seq(Seq(s"$testKey=$testVal")) + Seq(Seq(testKey, testVal)) ) sql(s"SET ${testKey + testKey}=${testVal + testVal}") checkAnswer( sql("set"), Seq( - Seq(s"$testKey=$testVal"), - Seq(s"${testKey + testKey}=${testVal + testVal}")) + Seq(testKey, testVal), + Seq(testKey + testKey, testVal + testVal)) ) // "set key" checkAnswer( sql(s"SET $testKey"), - Seq(Seq(s"$testKey=$testVal")) + Seq(Seq(testKey, testVal)) ) checkAnswer( sql(s"SET $nonexistentKey"), - Seq(Seq(s"$nonexistentKey is undefined")) + Seq(Seq(nonexistentKey, "")) ) clear() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 64aacabe10ef..9810520bb9ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -17,25 +17,25 @@ package org.apache.spark.sql.parquet -import java.io.File - -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} import org.apache.hadoop.fs.{Path, FileSystem} import org.apache.hadoop.mapreduce.Job import parquet.hadoop.ParquetFileWriter -import parquet.schema.MessageTypeParser import parquet.hadoop.util.ContextUtil +import parquet.schema.MessageTypeParser import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.getTempFilePath -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.TestData +import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.expressions.Equals +import org.apache.spark.sql.catalyst.types.IntegerType import org.apache.spark.util.Utils -import org.apache.spark.sql.catalyst.types.{StringType, IntegerType, DataType} -import org.apache.spark.sql.{parquet, SchemaRDD} // Implicits import org.apache.spark.sql.test.TestSQLContext._ @@ -56,7 +56,7 @@ case class OptionalReflectData( doubleField: Option[Double], booleanField: Option[Boolean]) -class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll { +class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { import TestData._ TestData // Load test data tables. 
@@ -64,12 +64,16 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll { override def beforeAll() { ParquetTestData.writeFile() + ParquetTestData.writeFilterFile() testRDD = parquetFile(ParquetTestData.testDir.toString) testRDD.registerAsTable("testsource") + parquetFile(ParquetTestData.testFilterDir.toString) + .registerAsTable("testfiltersource") } override def afterAll() { Utils.deleteRecursively(ParquetTestData.testDir) + Utils.deleteRecursively(ParquetTestData.testFilterDir) // here we should also unregister the table?? } @@ -120,7 +124,7 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll { val scanner = new ParquetTableScan( ParquetTestData.testData.output, ParquetTestData.testData, - None)(TestSQLContext.sparkContext) + Seq())(TestSQLContext.sparkContext) val projected = scanner.pruneColumns(ParquetTypesConverter .convertToAttributes(MessageTypeParser .parseMessageType(ParquetTestData.subTestSchema))) @@ -196,7 +200,6 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll { assert(true) } - test("insert (appending) to same table via Scala API") { sql("INSERT INTO testsource SELECT * FROM testsource") val double_rdd = sql("SELECT * FROM testsource").collect() @@ -239,5 +242,125 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll { Utils.deleteRecursively(file) assert(true) } -} + test("create RecordFilter for simple predicates") { + val attribute1 = new AttributeReference("first", IntegerType, false)() + val predicate1 = new Equals(attribute1, new Literal(1, IntegerType)) + val filter1 = ParquetFilters.createFilter(predicate1) + assert(filter1.isDefined) + assert(filter1.get.predicate == predicate1, "predicates do not match") + assert(filter1.get.isInstanceOf[ComparisonFilter]) + val cmpFilter1 = filter1.get.asInstanceOf[ComparisonFilter] + assert(cmpFilter1.columnName == "first", "column name incorrect") + + val predicate2 = new LessThan(attribute1, new Literal(4, IntegerType)) + val filter2 = ParquetFilters.createFilter(predicate2) + assert(filter2.isDefined) + assert(filter2.get.predicate == predicate2, "predicates do not match") + assert(filter2.get.isInstanceOf[ComparisonFilter]) + val cmpFilter2 = filter2.get.asInstanceOf[ComparisonFilter] + assert(cmpFilter2.columnName == "first", "column name incorrect") + + val predicate3 = new And(predicate1, predicate2) + val filter3 = ParquetFilters.createFilter(predicate3) + assert(filter3.isDefined) + assert(filter3.get.predicate == predicate3, "predicates do not match") + assert(filter3.get.isInstanceOf[AndFilter]) + + val predicate4 = new Or(predicate1, predicate2) + val filter4 = ParquetFilters.createFilter(predicate4) + assert(filter4.isDefined) + assert(filter4.get.predicate == predicate4, "predicates do not match") + assert(filter4.get.isInstanceOf[OrFilter]) + + val attribute2 = new AttributeReference("second", IntegerType, false)() + val predicate5 = new GreaterThan(attribute1, attribute2) + val badfilter = ParquetFilters.createFilter(predicate5) + assert(badfilter.isDefined === false) + } + + test("test filter by predicate pushdown") { + for(myval <- Seq("myint", "mylong", "mydouble", "myfloat")) { + println(s"testing field $myval") + val query1 = sql(s"SELECT * FROM testfiltersource WHERE $myval < 150 AND $myval >= 100") + assert( + query1.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result1 = query1.collect() + 
assert(result1.size === 50) + assert(result1(0)(1) === 100) + assert(result1(49)(1) === 149) + val query2 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 150 AND $myval <= 200") + assert( + query2.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result2 = query2.collect() + assert(result2.size === 50) + if (myval == "myint" || myval == "mylong") { + assert(result2(0)(1) === 151) + assert(result2(49)(1) === 200) + } else { + assert(result2(0)(1) === 150) + assert(result2(49)(1) === 199) + } + } + for(myval <- Seq("myint", "mylong")) { + val query3 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 190 OR $myval < 10") + assert( + query3.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result3 = query3.collect() + assert(result3.size === 20) + assert(result3(0)(1) === 0) + assert(result3(9)(1) === 9) + assert(result3(10)(1) === 191) + assert(result3(19)(1) === 200) + } + for(myval <- Seq("mydouble", "myfloat")) { + val result4 = + if (myval == "mydouble") { + val query4 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 190.5 OR $myval < 10.0") + assert( + query4.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + query4.collect() + } else { + // CASTs are problematic. Here myfloat will be casted to a double and it seems there is + // currently no way to specify float constants in SqlParser? + sql(s"SELECT * FROM testfiltersource WHERE $myval > 190.5 OR $myval < 10").collect() + } + assert(result4.size === 20) + assert(result4(0)(1) === 0) + assert(result4(9)(1) === 9) + assert(result4(10)(1) === 191) + assert(result4(19)(1) === 200) + } + val query5 = sql(s"SELECT * FROM testfiltersource WHERE myboolean = true AND myint < 40") + assert( + query5.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val booleanResult = query5.collect() + assert(booleanResult.size === 10) + for(i <- 0 until 10) { + if (!booleanResult(i).getBoolean(0)) { + fail(s"Boolean value in result row $i not true") + } + if (booleanResult(i).getInt(1) != i * 4) { + fail(s"Int value in result row $i should be ${4*i}") + } + } + val query6 = sql("SELECT * FROM testfiltersource WHERE mystring = \"100\"") + assert( + query6.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val stringResult = query6.collect() + assert(stringResult.size === 1) + assert(stringResult(0).getString(2) == "100", "stringvalue incorrect") + assert(stringResult(0).getInt(1) === 100) + } + + test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { + val query = sql(s"SELECT mystring FROM testfiltersource WHERE myint < 10") + assert(query.collect().size === 10) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 9cd13f6ae0d5..96e0ec513633 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -15,8 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql -package hive +package org.apache.spark.sql.hive import java.io.{BufferedReader, File, InputStreamReader, PrintStream} import java.util.{ArrayList => JArrayList} @@ -32,12 +31,13 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Analyzer, OverrideCatalog} -import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.QueryExecutionException +import org.apache.spark.sql.execution.{Command => PhysicalCommand} /** * Starts up an instance of hive where metadata is stored locally. An in-process metadata data is @@ -71,14 +71,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /** * Executes a query expressed in HiveQL using Spark, returning the result as a SchemaRDD. */ - def hiveql(hqlQuery: String): SchemaRDD = { - val result = new SchemaRDD(this, HiveQl.parseSql(hqlQuery)) - // We force query optimization to happen right away instead of letting it happen lazily like - // when using the query DSL. This is so DDL commands behave as expected. This is only - // generates the RDD lineage for DML queries, but does not perform any execution. - result.queryExecution.toRdd - result - } + def hiveql(hqlQuery: String): SchemaRDD = new SchemaRDD(this, HiveQl.parseSql(hqlQuery)) /** An alias for `hiveql`. */ def hql(hqlQuery: String): SchemaRDD = hiveql(hqlQuery) @@ -164,7 +157,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /** * Runs the specified SQL query using Hive. */ - protected def runSqlHive(sql: String): Seq[String] = { + protected[sql] def runSqlHive(sql: String): Seq[String] = { val maxResults = 100000 val results = runHive(sql, 100000) // It is very confusing when you only get back some of the results... 
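The HiveContext hunk above turns hiveql() into a thin wrapper that only parses the statement; running DDL and other native statements is now the job of the physical command operators (see the NativeCommand operator added further down), whose sideEffectResult is a lazy val. A minimal sketch of the resulting caller-side behaviour under those assumptions; the HiveContext handle and table name are illustrative:

```scala
// Illustrative only: assumes a HiveContext named `hive` is in scope.
// hql()/hiveql() now just parse the statement and wrap it in a SchemaRDD; the
// physical command operator computes sideEffectResult lazily, so the DDL is
// issued at most once for this SchemaRDD even if it is evaluated repeatedly.
val ddl = hive.hql("CREATE TABLE IF NOT EXISTS demo_kv (key INT, value STRING)")
ddl.count()
ddl.count()  // does not re-run the CREATE TABLE; see the "Exactly once" test below
```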
@@ -228,6 +221,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override val strategies: Seq[Strategy] = Seq( CommandStrategy(self), + HiveCommandStrategy(self), TakeOrdered, ParquetOperations, InMemoryScans, @@ -252,25 +246,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override lazy val optimizedPlan = optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed))) - override lazy val toRdd: RDD[Row] = { - def processCmd(cmd: String): RDD[Row] = { - val output = runSqlHive(cmd) - if (output.size == 0) { - emptyResult - } else { - val asRows = output.map(r => new GenericRow(r.split("\t").asInstanceOf[Array[Any]])) - sparkContext.parallelize(asRows, 1) - } - } - - logical match { - case s: SetCommand => eagerlyProcess(s) - case _ => analyzed match { - case NativeCommand(cmd) => processCmd(cmd) - case _ => executedPlan.execute().map(_.copy()) - } - } - } + override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy()) protected val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, @@ -298,7 +274,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { struct.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ))=> + case (seq: Seq[_], ArrayType(typ)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") case (map: Map[_,_], MapType(kType, vType)) => map.map { @@ -314,10 +290,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * Returns the result as a hive compatible sequence of strings. For native commands, the * execution is simply passed back to Hive. */ - def stringResult(): Seq[String] = analyzed match { - case NativeCommand(cmd) => runSqlHive(cmd) - case ExplainCommand(plan) => executePlan(plan).toString.split("\n") - case query => + def stringResult(): Seq[String] = executedPlan match { + case command: PhysicalCommand => + command.sideEffectResult.map(_.toString) + + case other => val result: Seq[Seq[Any]] = toRdd.collect().toSeq // We need the types so we can output struct field names val types = analyzed.output.map(_.dataType) @@ -328,8 +305,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override def simpleString: String = logical match { - case _: NativeCommand => "" - case _: SetCommand => "" + case _: NativeCommand => "" + case _: SetCommand => "" case _ => executedPlan.toString } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index d1aa8c868cb1..0ac0ee9071f3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -70,9 +70,18 @@ private[hive] trait HiveStrategies { pruneFilterProject( projectList, otherPredicates, + identity[Seq[Expression]], HiveTableScan(_, relation, pruningPredicates.reduceLeftOption(And))(hiveContext)) :: Nil case _ => Nil } } + + case class HiveCommandStrategy(context: HiveContext) extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.NativeCommand(sql) => + NativeCommand(sql, plan.output)(context) :: Nil + case _ => Nil + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index d199097e06ed..9386008d02d5 100644 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -58,7 +58,6 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) { // By clearing the port we force Spark to pick a new one. This allows us to rerun tests // without restarting the JVM. - System.clearProperty("spark.driver.port") System.clearProperty("spark.hostPort") override lazy val warehousePath = getTempFilePath("sparkHiveWarehouse").getCanonicalPath diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala index 29b4b9b006e4..a83923144916 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala @@ -32,14 +32,15 @@ import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Serializer} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred._ +import org.apache.spark import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{BooleanType, DataType} import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive._ -import org.apache.spark.{TaskContext, SparkException} import org.apache.spark.util.MutablePair +import org.apache.spark.{TaskContext, SparkException} /* Implicits */ import scala.collection.JavaConversions._ @@ -57,7 +58,7 @@ case class HiveTableScan( attributes: Seq[Attribute], relation: MetastoreRelation, partitionPruningPred: Option[Expression])( - @transient val sc: HiveContext) + @transient val context: HiveContext) extends LeafNode with HiveInspectors { @@ -75,7 +76,7 @@ case class HiveTableScan( } @transient - val hadoopReader = new HadoopTableReader(relation.tableDesc, sc) + val hadoopReader = new HadoopTableReader(relation.tableDesc, context) /** * The hive object inspector for this table, which can be used to extract values from the @@ -156,7 +157,7 @@ case class HiveTableScan( hiveConf.set(serdeConstants.LIST_COLUMNS, columnInternalNames) } - addColumnMetadataToConf(sc.hiveconf) + addColumnMetadataToConf(context.hiveconf) @transient def inputRdd = if (!relation.hiveQlTable.isPartitioned) { @@ -428,3 +429,26 @@ case class InsertIntoHiveTable( sc.sparkContext.makeRDD(Nil, 1) } } + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class NativeCommand( + sql: String, output: Seq[Attribute])( + @transient context: HiveContext) + extends LeafNode with Command { + + override protected[sql] lazy val sideEffectResult: Seq[String] = context.runSqlHive(sql) + + override def execute(): RDD[spark.sql.Row] = { + if (sideEffectResult.size == 0) { + context.emptyResult + } else { + val rows = sideEffectResult.map(r => new GenericRow(Array[Any](r))) + context.sparkContext.parallelize(rows, 1) + } + } + + override def otherCopyArgs = context :: Nil +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 357c7e654bd2..24c929ff7430 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -24,6 +24,7 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen} import 
org.apache.spark.sql.Logging import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.{NativeCommand => LogicalNativeCommand} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.hive.test.TestHive @@ -141,7 +142,7 @@ abstract class HiveComparisonTest // Hack: Hive simply prints the result of a SET command to screen, // and does not return it as a query answer. case _: SetCommand => Seq("0") - case _: NativeCommand => answer.filterNot(nonDeterministicLine).filterNot(_ == "") + case _: LogicalNativeCommand => answer.filterNot(nonDeterministicLine).filterNot(_ == "") case _: ExplainCommand => answer case plan => if (isSorted(plan)) answer else answer.sorted } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 3581617c269a..ee194dbcb77b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -172,7 +172,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "case_sensitivity", // Flaky test, Hive sometimes returns different set of 10 rows. - "lateral_view_outer" + "lateral_view_outer", + + // After stop taking the `stringOrError` route, exceptions are thrown from these cases. + // See SPARK-2129 for details. + "join_view", + "mergejoins_mixed" ) /** @@ -476,7 +481,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "join_reorder3", "join_reorder4", "join_star", - "join_view", "lateral_view", "lateral_view_cp", "lateral_view_ppd", @@ -507,7 +511,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "merge1", "merge2", "mergejoins", - "mergejoins_mixed", "multigroupby_singlemr", "multi_insert_gby", "multi_insert_gby3", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 6c239b02ed09..0d656c556965 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.Row -import org.apache.spark.sql.hive.test.TestHive._ +import scala.util.Try + import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.{SchemaRDD, execution, Row} /** * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. 
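The import rename in HiveComparisonTest above is needed because this patch adds a physical operator that is also named NativeCommand (in org.apache.spark.sql.hive.execution), so the logical plan node of the same name must be aliased to stay unambiguous. A small illustrative helper, not part of the patch, showing the same pattern:

```scala
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, NativeCommand => LogicalNativeCommand}

// Illustrative helper: extract the raw statement carried by a logical
// NativeCommand node while the physical hive.execution.NativeCommand
// introduced by this patch keeps its unqualified name.
def nativeSql(plan: LogicalPlan): Option[String] = plan match {
  case LogicalNativeCommand(sql) => Some(sql)
  case _ => None
}
```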
@@ -162,16 +164,60 @@ class HiveQuerySuite extends HiveComparisonTest { hql("SELECT * FROM src").toString } + private val explainCommandClassName = + classOf[execution.ExplainCommand].getSimpleName.stripSuffix("$") + + def isExplanation(result: SchemaRDD) = { + val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } + explanation.size == 1 && explanation.head.startsWith(explainCommandClassName) + } + test("SPARK-1704: Explain commands as a SchemaRDD") { hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + val rdd = hql("explain select key, count(value) from src group by key") - assert(rdd.collect().size == 1) - assert(rdd.toString.contains("ExplainCommand")) - assert(rdd.filter(row => row.toString.contains("ExplainCommand")).collect().size == 0, - "actual contents of the result should be the plans of the query to be explained") + assert(isExplanation(rdd)) + TestHive.reset() } + test("Query Hive native command execution result") { + val tableName = "test_native_commands" + + val q0 = hql(s"DROP TABLE IF EXISTS $tableName") + assert(q0.count() == 0) + + val q1 = hql(s"CREATE TABLE $tableName(key INT, value STRING)") + assert(q1.count() == 0) + + val q2 = hql("SHOW TABLES") + val tables = q2.select('result).collect().map { case Row(table: String) => table } + assert(tables.contains(tableName)) + + val q3 = hql(s"DESCRIBE $tableName") + assertResult(Array(Array("key", "int", "None"), Array("value", "string", "None"))) { + q3.select('result).collect().map { case Row(fieldDesc: String) => + fieldDesc.split("\t").map(_.trim) + } + } + + val q4 = hql(s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key") + assert(isExplanation(q4)) + + TestHive.reset() + } + + test("Exactly once semantics for DDL and command statements") { + val tableName = "test_exactly_once" + val q0 = hql(s"CREATE TABLE $tableName(key INT, value STRING)") + + // If the table was not created, the following assertion would fail + assert(Try(table(tableName)).isSuccess) + + // If the CREATE TABLE command got executed again, the following assertion would fail + assert(Try(q0.count()).isSuccess) + } + test("parse HQL set commands") { // Adapted from its SQL counterpart. val testKey = "spark.sql.key.usedfortestonly" @@ -195,52 +241,69 @@ class HiveQuerySuite extends HiveComparisonTest { test("SET commands semantics for a HiveContext") { // Adapted from its SQL counterpart. val testKey = "spark.sql.key.usedfortestonly" - var testVal = "test.val.0" + val testVal = "test.val.0" val nonexistentKey = "nonexistent" - def fromRows(row: Array[Row]): Array[String] = row.map(_.getString(0)) + def rowsToPairs(rows: Array[Row]) = rows.map { case Row(key: String, value: String) => + key -> value + } clear() // "set" itself returns all config variables currently specified in SQLConf. 
- assert(hql("set").collect().size == 0) + assert(hql("SET").collect().size == 0) + + assertResult(Array(testKey -> testVal)) { + rowsToPairs(hql(s"SET $testKey=$testVal").collect()) + } - // "set key=val" - hql(s"SET $testKey=$testVal") - assert(fromRows(hql("SET").collect()) sameElements Array(s"$testKey=$testVal")) assert(hiveconf.get(testKey, "") == testVal) + assertResult(Array(testKey -> testVal)) { + rowsToPairs(hql("SET").collect()) + } hql(s"SET ${testKey + testKey}=${testVal + testVal}") - assert(fromRows(hql("SET").collect()) sameElements - Array( - s"$testKey=$testVal", - s"${testKey + testKey}=${testVal + testVal}")) assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) + assertResult(Array(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { + rowsToPairs(hql("SET").collect()) + } // "set key" - assert(fromRows(hql(s"SET $testKey").collect()) sameElements - Array(s"$testKey=$testVal")) - assert(fromRows(hql(s"SET $nonexistentKey").collect()) sameElements - Array(s"$nonexistentKey is undefined")) + assertResult(Array(testKey -> testVal)) { + rowsToPairs(hql(s"SET $testKey").collect()) + } + + assertResult(Array(nonexistentKey -> "")) { + rowsToPairs(hql(s"SET $nonexistentKey").collect()) + } // Assert that sql() should have the same effects as hql() by repeating the above using sql(). clear() - assert(sql("set").collect().size == 0) + assert(sql("SET").collect().size == 0) + + assertResult(Array(testKey -> testVal)) { + rowsToPairs(sql(s"SET $testKey=$testVal").collect()) + } - sql(s"SET $testKey=$testVal") - assert(fromRows(sql("SET").collect()) sameElements Array(s"$testKey=$testVal")) assert(hiveconf.get(testKey, "") == testVal) + assertResult(Array(testKey -> testVal)) { + rowsToPairs(sql("SET").collect()) + } sql(s"SET ${testKey + testKey}=${testVal + testVal}") - assert(fromRows(sql("SET").collect()) sameElements - Array( - s"$testKey=$testVal", - s"${testKey + testKey}=${testVal + testVal}")) assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) + assertResult(Array(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { + rowsToPairs(sql("SET").collect()) + } - assert(fromRows(sql(s"SET $testKey").collect()) sameElements - Array(s"$testKey=$testVal")) - assert(fromRows(sql(s"SET $nonexistentKey").collect()) sameElements - Array(s"$nonexistentKey is undefined")) + assertResult(Array(testKey -> testVal)) { + rowsToPairs(sql(s"SET $testKey").collect()) + } + + assertResult(Array(nonexistentKey -> "")) { + rowsToPairs(sql(s"SET $nonexistentKey").collect()) + } + + clear() } // Put tests that depend on specific Hive settings before these last two test, diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 04925886c39e..ff6d86c8f81a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -92,9 +92,9 @@ class BasicOperationsSuite extends TestSuiteBase { assert(second.size === 5) assert(third.size === 5) - assert(first.flatten.toSet === (1 to 100).toSet) - assert(second.flatten.toSet === (101 to 200).toSet) - assert(third.flatten.toSet === (201 to 300).toSet) + assert(first.flatten.toSet.equals((1 to 100).toSet) ) + assert(second.flatten.toSet.equals((101 to 200).toSet)) + assert(third.flatten.toSet.equals((201 to 300).toSet)) } test("repartition (fewer partitions)") { @@ -111,9 +111,9 
@@ class BasicOperationsSuite extends TestSuiteBase { assert(second.size === 2) assert(third.size === 2) - assert(first.flatten.toSet === (1 to 100).toSet) - assert(second.flatten.toSet === (101 to 200).toSet) - assert(third.flatten.toSet === (201 to 300).toSet) + assert(first.flatten.toSet.equals((1 to 100).toSet)) + assert(second.flatten.toSet.equals( (101 to 200).toSet)) + assert(third.flatten.toSet.equals((201 to 300).toSet)) } test("groupByKey") {
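A closing note on the BasicOperationsSuite hunks above: switching from the === assertion operator to plain .equals does not change what is being verified, since both reduce to structural equality for Scala sets. A standalone sketch:

```scala
// Standalone sketch: both assertion styles check the same structural equality.
val first = (1 to 100).map(Seq(_))                    // stand-in for collected output
assert(first.flatten.toSet.equals((1 to 100).toSet))  // style used after this change
assert(first.flatten.toSet == (1 to 100).toSet)       // equivalent plain comparison
```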