From 0fd3a07a29a58e763b6e0a0ad26ca860efdbb979 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 29 Oct 2014 19:19:24 -0700 Subject: [PATCH 01/17] Draft of data sources API --- .../apache/spark/sql/catalyst/package.scala | 4 + .../spark/sql/catalyst/types/dataTypes.scala | 2 +- .../org/apache/spark/sql/SQLContext.scala | 14 +- .../spark/sql/execution/ExistingRDD.scala | 6 - .../spark/sql/execution/SparkStrategies.scala | 28 +++- .../apache/spark/sql/execution/commands.scala | 35 ++++- .../apache/spark/sql/json/JSONRelation.scala | 52 ++++++++ .../apache/spark/sql/sources/DataSource.scala | 21 +++ .../spark/sql/sources/LogicalRelation.scala | 56 ++++++++ .../org/apache/spark/sql/sources/ddl.scala | 110 +++++++++++++++ .../apache/spark/sql/sources/package.scala | 95 +++++++++++++ .../apache/spark/sql/CachedTableSuite.scala | 12 -- .../org/apache/spark/sql/QueryTest.scala | 30 ++++- .../org/apache/spark/sql/json/JsonSuite.scala | 26 ++++ .../spark/sql/sources/DataSourceTest.scala | 34 +++++ .../spark/sql/sources/FilteredScanSuite.scala | 117 ++++++++++++++++ .../spark/sql/sources/PrunedScanSuite.scala | 102 ++++++++++++++ .../spark/sql/sources/TableScanSuite.scala | 126 ++++++++++++++++++ .../apache/spark/sql/hive/HiveContext.scala | 2 +- 19 files changed, 844 insertions(+), 28 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/sources/DataSource.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala index bdd07bbeb2230..a38079ced34b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +/** + * Catalyst is a library for manipulating relational query plans. All classes in catalyst are + * considered an internal API to Spark SQL and are subject to change between minor releases. + */ package object catalyst { /** * A JVM-global lock that should be used to prevent thread safety issues when using things in diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index b9cf37d53ffd2..932614720fa8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -389,7 +389,7 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT * @param dataType The data type of this field. * @param nullable Indicates if values of this field can be `null` values. 
*/ -case class StructField(name: String, dataType: DataType, nullable: Boolean) { +case class StructField(name: String, dataType: DataType, nullable: Boolean = true) { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a41a500c9a5d0..c4a564d5c514d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.types.DataType import org.apache.spark.sql.execution.{SparkStrategies, _} import org.apache.spark.sql.json._ import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.sources.{BaseRelation, DDLParser, LogicalRelation} /** * :: AlphaComponent :: @@ -69,13 +70,19 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer + @transient + protected[sql] val ddlParser = new DDLParser + @transient protected[sql] val sqlParser = { val fallback = new catalyst.SqlParser new catalyst.SparkSQLParser(fallback(_)) } - protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser(sql) + protected[sql] def parseSql(sql: String): LogicalPlan = { + ddlParser(sql).getOrElse(sqlParser(sql)) + } + protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } @@ -105,6 +112,10 @@ class SQLContext(@transient val sparkContext: SparkContext) LogicalRDD(ScalaReflection.attributesFor[A], RDDConversions.productToRowRdd(rdd))(self)) } + implicit def baseRelationToSchemaRDD(baseRelation: BaseRelation): SchemaRDD = { + logicalPlanToSparkQuery(LogicalRelation(baseRelation)) + } + /** * :: DeveloperApi :: * Creates a [[SchemaRDD]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD. 
@@ -295,6 +306,7 @@ class SQLContext(@transient val sparkContext: SparkContext) val strategies: Seq[Strategy] = CommandStrategy(self) :: + DataSources :: TakeOrdered :: HashAggregation :: LeftSemiJoin :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 04c51a1ee4b97..d64c5af89ec99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -50,12 +50,6 @@ object RDDConversions { } } } - - /* - def toLogicalPlan[A <: Product : TypeTag](productRdd: RDD[A]): LogicalPlan = { - LogicalRDD(ScalaReflection.attributesFor[A], productToRowRdd(productRdd)) - } - */ } case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 79e4ddb8c4f5d..efbcbebe6e66f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{SQLContext, execution} +import org.apache.spark.sql.{sources, SQLContext, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ @@ -252,6 +252,31 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + object DataSources extends Strategy { + import sources._ + + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: FilteredScan)) => + pruneFilterProject( + projectList, + filters, + identity[Seq[Expression]], // All filters still need to be evaluated + a => PhysicalRDD(a, t.buildScan(a.map(l.attributeMap), filters))) :: Nil + + case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: PrunedScan)) => + pruneFilterProject( + projectList, + filters, + identity[Seq[Expression]], // All filters still need to be evaluated. + a => PhysicalRDD(a, t.buildScan(a.map(l.attributeMap)))) :: Nil + + case l @ LogicalRelation(t: TableScan) => + PhysicalRDD(l.output, t.buildScan()) :: Nil + + case _ => Nil + } + } + // Can we automate these 'pass through' operations? 
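
[Illustrative usage sketch, not part of the patch.] The ddlParser fall-through added to SQLContext.parseSql and the DataSources strategy above combine so that a data source can be registered and queried entirely from SQL. The snippet below mirrors the JSON DDL test added later in this series; the file path, table name, and the SparkContext `sc` are assumed:

    import org.apache.spark.sql.SQLContext

    val sqlContext = new SQLContext(sc)  // assumes an existing SparkContext `sc`

    // Parsed by the new DDLParser; "org.apache.spark.sql.json" is not a class, so
    // CreateTableUsing falls back to org.apache.spark.sql.json.DefaultSource and
    // registers the resulting relation as a temporary table.
    sqlContext.sql(
      """
        |CREATE TEMPORARY TABLE people
        |USING org.apache.spark.sql.json
        |OPTIONS (fileName '/tmp/people.json')
      """.stripMargin)

    // Planned by the TableScan branch of the DataSources strategy above.
    sqlContext.sql("SELECT * FROM people").collect()

Programmatic callers can instead construct a BaseRelation directly, rely on the new baseRelationToSchemaRDD implicit, and call registerTempTable on the resulting SchemaRDD.
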
object BasicOperators extends Strategy { def numPartitions = self.numPartitions @@ -304,6 +329,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case class CommandStrategy(context: SQLContext) extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case r: RunnableCommand => ExecutedCommand(r) :: Nil case logical.SetCommand(kv) => Seq(execution.SetCommand(kv, plan.output)(context)) case logical.ExplainCommand(logicalPlan, extended) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 5859eba408ee1..53f2f58393a05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -21,10 +21,12 @@ import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Row, Attribute} +import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.{Row, SQLConf, SQLContext} +import org.apache.spark.sql.{SQLConf, SQLContext} +// TODO: DELETE ME... trait Command { this: SparkPlan => @@ -44,6 +46,35 @@ trait Command { override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1) } +// TODO: Replace command with runnable command. +trait RunnableCommand extends logical.Command { + self: Product => + + def output: Seq[Attribute] + def run(sqlContext: SQLContext): Seq[Row] +} + +case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan { + /** + * A concrete command should override this lazy field to wrap up any side effects caused by the + * command or any other computation that should be evaluated exactly once. The value of this field + * can be used as the contents of the corresponding RDD generated from the physical plan of this + * command. + * + * The `execute()` method of all the physical command classes should reference `sideEffectResult` + * so that the command can be executed eagerly right after the command query is created. + */ + protected[sql] lazy val sideEffectResult: Seq[Row] = cmd.run(sqlContext) + + override def output = cmd.output + + def children = Nil + + override def executeCollect(): Array[Row] = sideEffectResult.toArray + + override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1) +} + /** * :: DeveloperApi :: */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala new file mode 100644 index 0000000000000..bf7ba34de8d66 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.json + +import org.apache.spark.sql.{SQLConf, SQLContext} +import org.apache.spark.sql.sources._ + +private[sql] class DefaultSource extends RelationProvider { + /** Returns a new base relation with the given parameters. */ + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + val fileName = + parameters.getOrElse("fileName", sys.error(s"Option 'fileName' not specified")) + + val samplingRatio = + parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + + JSONRelation(fileName, samplingRatio)(sqlContext) + } +} + +private[sql] case class JSONRelation(fileName: String, samplingRatio: Double)( + @transient val sqlContext: SQLContext) + extends BaseRelation with TableScan { + + private def baseRDD = sqlContext.sparkContext.textFile(fileName) + + override val schema = + JsonRDD.inferSchema( + baseRDD, + samplingRatio, + sqlContext.columnNameOfCorruptRecord) + + override def buildScan() = + JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.columnNameOfCorruptRecord) +} \ No newline at end of file diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSource.scala new file mode 100644 index 0000000000000..f03ff136c98b9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSource.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.sources + +import org.apache.spark.sql.SQLContext + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala new file mode 100644 index 0000000000000..338a555a4016b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.sources + +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Expression, Attribute} +import org.apache.spark.sql.catalyst.plans.logical.{Statistics, LogicalPlan} +import org.apache.spark.sql.catalyst.trees.LeafNode + +/** + * Used to link a [[BaseRelation]] in to a logical query plan. + */ +private[sql] case class LogicalRelation(relation: BaseRelation) + extends LogicalPlan + with LeafNode[LogicalPlan] + with MultiInstanceRelation { + + val output = relation.schema.toAttributes + + // Logical Relations are distinct if they have different output for the sake of transformations. + override def equals(other: Any) = other match { + case l @ LogicalRelation(otherRelation) => relation == otherRelation && output == l.output + case _ => false + } + + override def sameResult(otherPlan: LogicalPlan) = otherPlan match { + case LogicalRelation(otherRelation) => relation == otherRelation + case _ => false + } + + @transient override lazy val statistics = Statistics( + // TODO: Allow datasources to provide statistics as well. + sizeInBytes = BigInt(relation.sqlContext.defaultSizeInBytes) + ) + + /** Used to lookup original attribute capitalization */ + val attributeMap = AttributeMap(output.map(o => (o, o))) + + def newInstance() = LogicalRelation(relation).asInstanceOf[this.type] + + override def simpleString = s"Relation[${output.mkString(",")}] $relation" +} \ No newline at end of file diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala new file mode 100644 index 0000000000000..05463208c156d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources + +import org.apache.spark.Logging +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.util.Utils + +import scala.language.implicitConversions +import scala.util.parsing.combinator.lexical.StdLexical +import scala.util.parsing.combinator.syntactical.StandardTokenParsers +import scala.util.parsing.combinator.PackratParsers + +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.SqlLexical + +/** + * A parser for foreign DDL commands. + */ +private[sql] class DDLParser extends StandardTokenParsers with PackratParsers with Logging { + + def apply(input: String): Option[LogicalPlan] = { + phrase(ddl)(new lexical.Scanner(input)) match { + case Success(r, x) => Some(r) + case x => + logDebug(s"Not recognized as DDL: $x") + None + } + } + + protected case class Keyword(str: String) + + protected implicit def asParser(k: Keyword): Parser[String] = + lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) + + protected val CREATE = Keyword("CREATE") + protected val TEMPORARY = Keyword("TEMPORARY") + protected val TABLE = Keyword("TABLE") + protected val USING = Keyword("USING") + protected val OPTIONS = Keyword("OPTIONS") + + // Use reflection to find the reserved words defined in this class. + protected val reservedWords = + this.getClass + .getMethods + .filter(_.getReturnType == classOf[Keyword]) + .map(_.invoke(this).asInstanceOf[Keyword].str) + + override val lexical = new SqlLexical(reservedWords) + + protected lazy val ddl: Parser[LogicalPlan] = createTable + + /** + * CREATE FOREIGN TEMPORARY TABLE avroTable + * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro") + */ + protected lazy val createTable: Parser[LogicalPlan] = + CREATE ~> TEMPORARY ~> TABLE ~> ident ~ + USING ~ className ~ + OPTIONS ~ options ^^ { + case tableName ~ _ ~ provider ~ _ ~ opts => + CreateTableUsing(tableName, provider, opts) + } + + protected lazy val options: Parser[Map[String, String]] = + "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } + + protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} + + protected lazy val pair: Parser[(String, String)] = ident ~ stringLit ^^ { case k ~ v => (k,v) } +} + +private[sql] case class CreateTableUsing( + tableName: String, + provider: String, + options: Map[String, String]) extends RunnableCommand { + + def run(sqlContext: SQLContext) = { + val loader = Utils.getContextOrSparkClassLoader + val clazz: Class[_] = try Utils.getContextOrSparkClassLoader.loadClass(provider) catch { + case cnf: java.lang.ClassNotFoundException => + try Utils.getContextOrSparkClassLoader.loadClass(provider + ".DefaultSource") catch { + case cnf: java.lang.ClassNotFoundException => + sys.error(s"Failed to load class for data source: $provider") + } + } + val dataSource = clazz.newInstance().asInstanceOf[org.apache.spark.sql.sources.RelationProvider] + val relation = dataSource.createRelation(sqlContext, options) + + sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName) + Seq.empty + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala new file mode 100644 index 0000000000000..5b61df3e5d199 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala @@ -0,0 +1,95 
@@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute} + +/** + * A set of APIs for adding data sources to Spark SQL. + */ +package object sources { + + /** + * Implemented by objects that produce relations for a specific kind of data source. When + * Spark SQL is given a DDL operation with a USING clause specified, this interface is used to + * pass in the parameters specified by a user. + * + * Users may specify the fully qualified class name of a given data source. When that class is + * not found Spark SQL will append the class name `DefaultSource` to the path, allowing for + * less verbose invocation. For example, 'org.apache.spark.sql.json' would resolve to the + * data source 'org.apache.spark.sql.json.DefaultSource' + */ + trait RelationProvider { + /** Returns a new base relation with the given parameters. */ + def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation + } + + /** + * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must + * be able to produce the schema of their data in the form of a [[StructType]] In order to be + * executed, a BaseRelation must also mix in at least one of the Scan traits. + * + * BaseRelations must also define a equality function that only returns true when the two + * instances will return the same data. This equality function is used when determining when + * it is safe to substitute cached results for a given relation. + */ + @DeveloperApi + abstract class BaseRelation { + def sqlContext: SQLContext + def schema: StructType + } + + /** + * Mixed into a BaseRelation that can produce all of its tuples as an RDD of Row objects. + */ + @DeveloperApi + trait TableScan { + self: BaseRelation => + + def buildScan(): RDD[Row] + } + + /** + * Mixed into a BaseRelation that can eliminate unneeded columns before producing an RDD + * containing all of its tuples as Row objects. + */ + @DeveloperApi + trait PrunedScan { + self: BaseRelation => + + def buildScan(requiredColumns: Seq[Attribute]): RDD[Row] + } + + /** + * Mixed into a BaseRelation that can eliminate unneeded columns and filter using selected + * predicates before producing an RDD containing all matching tuples as Row objects. + * + * The pushed down filters are currently purely an optimization as they will all be evaluated + * again. This means it is safe to use them with methods that produce false positives such + * as filtering partitions based on a bloom filter. 
+ */ + @DeveloperApi + trait FilteredScan { + self: BaseRelation => + + def buildScan( + requiredColumns: Seq[Attribute], + filters: Seq[Expression]): RDD[Row] + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 1a5d87d5240e9..44a2961b27eda 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -27,18 +27,6 @@ case class BigData(s: String) class CachedTableSuite extends QueryTest { TestData // Load test tables. - def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { - val planWithCaching = query.queryExecution.withCachedData - val cachedData = planWithCaching collect { - case cached: InMemoryRelation => cached - } - - assert( - cachedData.size == numCachedTables, - s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + - planWithCaching) - } - def rddIdOf(tableName: String): Int = { val executedPlan = table(tableName).queryExecution.executedPlan executedPlan.collect { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 042f61f5a4113..3d9f0cbf80fe7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -19,8 +19,10 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.columnar.InMemoryRelation class QueryTest extends PlanTest { + /** * Runs the plan and makes sure the answer contains all of the keywords, or the * none of keywords are listed in the answer @@ -78,11 +80,31 @@ class QueryTest extends PlanTest { |${rdd.queryExecution.executedPlan} |== Results == |${sideBySide( - s"== Correct Answer - ${convertedAnswer.size} ==" +: - prepareAnswer(convertedAnswer).map(_.toString), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} + s"== Correct Answer - ${convertedAnswer.size} ==" +: + prepareAnswer(convertedAnswer).map(_.toString), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} """.stripMargin) } } + + def sqlTest(sqlString: String, expectedAnswer: Any)(implicit sqlContext: SQLContext): Unit = { + test(sqlString) { + checkAnswer(sqlContext.sql(sqlString), expectedAnswer) + } + } + + /** Asserts that a given SchemaRDD will be executed using the given number of cached results. 
*/ + def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { + val planWithCaching = query.queryExecution.withCachedData + val cachedData = planWithCaching collect { + case cached: InMemoryRelation => cached + } + + assert( + cachedData.size == numCachedTables, + s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + + planWithCaching) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index ce6184f5d8c9d..cb01345c3c282 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -545,6 +545,32 @@ class JsonSuite extends QueryTest { ) } + test("Loading a JSON dataset from a text file with SQL") { + val file = getTempFilePath("json") + val path = file.toString + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + + sql( + s""" + |CREATE TEMPORARY TABLE jsonTableSQL + |USING org.apache.spark.sql.json + |OPTIONS ( + | fileName '$path' + |) + """.stripMargin) + + checkAnswer( + sql("select * from jsonTableSQL"), + (BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.") :: Nil + ) + } + test("Applying schemas") { val file = getTempFilePath("json") val path = file.toString diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala new file mode 100644 index 0000000000000..9626252e742e5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -0,0 +1,34 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.Analyzer +import org.apache.spark.sql.test.TestSQLContext +import org.scalatest.BeforeAndAfter + +abstract class DataSourceTest extends QueryTest with BeforeAndAfter { + // Case sensitivity is not configurable yet, but we want to test some edge cases. 
+ // TODO: Remove when it is configurable + implicit val caseInsensisitiveContext = new SQLContext(TestSQLContext.sparkContext) { + @transient + override protected[sql] lazy val analyzer: Analyzer = + new Analyzer(catalog, functionRegistry, caseSensitive = false) + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala new file mode 100644 index 0000000000000..3de45ab60300b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -0,0 +1,117 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql.catalyst.expressions.{Row => _, _} +import org.apache.spark.sql._ + +class FilteredScanSource extends RelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + SimpleFilteredScan(parameters("from").toInt, parameters("to").toInt)(sqlContext) + } +} + +case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) + extends BaseRelation with FilteredScan { + + override def schema = + StructType( + StructField("a", IntegerType, nullable = false) :: + StructField("b", IntegerType, nullable = false) :: Nil) + + override def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]) = { + val rowBuilders = requiredColumns.map(_.name).map { + case "a" => (i: Int) => Seq(i) + case "b" => (i: Int) => Seq(i * 2) + } + + val filter = filters.collect { + case Seq(EqualTo(a: AttributeReference, l: Literal)) if a.name == "a" => + (i: Int) => i == l.value + case Seq(EqualTo(l: Literal, a: AttributeReference)) if a.name == "a" => + (i: Int) => i == l.value + }.headOption.getOrElse((_: Int) => true) + + sqlContext.sparkContext.parallelize(from to to).filter(filter).map(i => + Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty))) + } +} + +class FilteredScanSuite extends DataSourceTest { + + import caseInsensisitiveContext._ + + before { + sql( + """ + |CREATE TEMPORARY TABLE oneToTenFiltered + |USING org.apache.spark.sql.sources.FilteredScanSource + |OPTIONS ( + | from '1', + | to '10' + |) + """.stripMargin) + } + + sqlTest( + "SELECT * FROM oneToTenFiltered", + (1 to 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT a, b FROM oneToTenFiltered", + (1 to 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT b, a FROM oneToTenFiltered", + (1 to 10).map(i => Row(i * 2, i)).toSeq) + + sqlTest( + "SELECT a FROM oneToTenFiltered", + (1 to 10).map(i => Row(i)).toSeq) + + sqlTest( + "SELECT b FROM oneToTenFiltered", + (1 to 10).map(i => Row(i * 2)).toSeq) + + sqlTest( + "SELECT a * 2 FROM oneToTenFiltered", + (1 to 
10).map(i => Row(i * 2)).toSeq) + + sqlTest( + "SELECT A AS b FROM oneToTenFiltered", + (1 to 10).map(i => Row(i)).toSeq) + + sqlTest( + "SELECT x.b, y.a FROM oneToTenFiltered x JOIN oneToTenFiltered y ON x.a = y.b", + (1 to 5).map(i => Row(i * 4, i)).toSeq) + + sqlTest( + "SELECT x.a, y.b FROM oneToTenFiltered x JOIN oneToTenFiltered y ON x.a = y.b", + (2 to 10 by 2).map(i => Row(i, i)).toSeq) + + sqlTest( + "SELECT * FROM oneToTenFiltered WHERE a = 1", + Seq(1).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT * FROM oneToTenFiltered WHERE b = 2", + Seq(1).map(i => Row(i, i * 2)).toSeq) +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala new file mode 100644 index 0000000000000..276d3815bf744 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -0,0 +1,102 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql.catalyst.expressions.{Row => _, _} +import org.apache.spark.sql._ + +class PrunedScanSource extends RelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + SimplePrunedScan(parameters("from").toInt, parameters("to").toInt)(sqlContext) + } +} + +case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) + extends BaseRelation with PrunedScan { + + override def schema = + StructType( + StructField("a", IntegerType, nullable = false) :: + StructField("b", IntegerType, nullable = false) :: Nil) + + override def buildScan(requiredColumns: Seq[Attribute]) = { + val rowBuilders = requiredColumns.map(_.name).map { + case "a" => (i: Int) => Seq(i) + case "b" => (i: Int) => Seq(i * 2) + } + + sqlContext.sparkContext.parallelize(from to to).map(i => + Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty))) + } +} + +class PrunedScanSuite extends DataSourceTest { + import caseInsensisitiveContext._ + + before { + sql( + """ + |CREATE TEMPORARY TABLE oneToTenPruned + |USING org.apache.spark.sql.sources.PrunedScanSource + |OPTIONS ( + | from '1', + | to '10' + |) + """.stripMargin) + } + + sqlTest( + "SELECT * FROM oneToTenPruned", + (1 to 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT a, b FROM oneToTenPruned", + (1 to 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT b, a FROM oneToTenPruned", + (1 to 10).map(i => Row(i * 2, i)).toSeq) + + sqlTest( + "SELECT a FROM oneToTenPruned", + (1 to 10).map(i => Row(i)).toSeq) + + sqlTest( + "SELECT b FROM oneToTenPruned", + (1 to 10).map(i => Row(i * 2)).toSeq) + + sqlTest( + "SELECT a * 2 FROM oneToTenPruned", + (1 to 10).map(i => Row(i * 
2)).toSeq) + + sqlTest( + "SELECT A AS b FROM oneToTenPruned", + (1 to 10).map(i => Row(i)).toSeq) + + sqlTest( + "SELECT x.b, y.a FROM oneToTenPruned x JOIN oneToTenPruned y ON x.a = y.b", + (1 to 5).map(i => Row(i * 4, i)).toSeq) + + sqlTest( + "SELECT x.a, y.b FROM oneToTenPruned x JOIN oneToTenPruned y ON x.a = y.b", + (2 to 10 by 2).map(i => Row(i, i)).toSeq) + +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala new file mode 100644 index 0000000000000..5eb30827e7314 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -0,0 +1,126 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql.catalyst.expressions.{Row => _, _} +import org.apache.spark.sql._ + +class DefaultSource extends SimpleScanSource + +class SimpleScanSource extends RelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + SimpleScan(parameters("from").toInt, parameters("to").toInt)(sqlContext) + } +} + +case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) + extends BaseRelation with TableScan { + + override def schema = + StructType(StructField("i", IntegerType, nullable = false) :: Nil) + + override def buildScan() = sqlContext.sparkContext.parallelize(from to to).map(Row(_)) +} + +class TableScanSuite extends DataSourceTest { + import caseInsensisitiveContext._ + + before { + sql( + """ + |CREATE TEMPORARY TABLE oneToTen + |USING org.apache.spark.sql.sources.SimpleScanSource + |OPTIONS ( + | from '1', + | to '10' + |) + """.stripMargin) + } + + sqlTest( + "SELECT * FROM oneToTen", + (1 to 10).map(Row(_)).toSeq) + + sqlTest( + "SELECT i FROM oneToTen", + (1 to 10).map(Row(_)).toSeq) + + sqlTest( + "SELECT i FROM oneToTen WHERE i < 5", + (1 to 4).map(Row(_)).toSeq) + + sqlTest( + "SELECT i * 2 FROM oneToTen", + (1 to 10).map(i => Row(i * 2)).toSeq) + + sqlTest( + "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1", + (2 to 10).map(i => Row(i, i - 1)).toSeq) + + + test("Caching") { + // Cached Query Execution + cacheTable("oneToTen") + assertCached(sql("SELECT * FROM oneToTen")) + checkAnswer( + sql("SELECT * FROM oneToTen"), + (1 to 10).map(Row(_)).toSeq) + + assertCached(sql("SELECT i FROM oneToTen")) + checkAnswer( + sql("SELECT i FROM oneToTen"), + (1 to 10).map(Row(_)).toSeq) + + assertCached(sql("SELECT i FROM oneToTen WHERE i < 5")) + checkAnswer( + sql("SELECT i FROM oneToTen WHERE i < 5"), + (1 to 4).map(Row(_)).toSeq) + + assertCached(sql("SELECT i * 2 FROM oneToTen")) + checkAnswer( + sql("SELECT i * 2 FROM oneToTen"), + (1 to 10).map(i => 
Row(i * 2)).toSeq) + + assertCached(sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2) + checkAnswer( + sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), + (2 to 10).map(i => Row(i, i - 1)).toSeq) + + // Verify uncaching + uncacheTable("oneToTen") + assertCached(sql("SELECT * FROM oneToTen"), 0) + } + + test("defaultSource") { + sql( + """ + |CREATE TEMPORARY TABLE oneToTenDef + |USING org.apache.spark.sql.sources + |OPTIONS ( + | from '1', + | to '10' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM oneToTenDef"), + (1 to 10).map(Row(_)).toSeq) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index fad4091d48a89..29693d04c8a40 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -95,7 +95,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { if (dialect == "sql") { super.sql(sqlText) } else if (dialect == "hiveql") { - new SchemaRDD(this, HiveQl.parseSql(sqlText)) + new SchemaRDD(this, ddlParser(sqlText).getOrElse(HiveQl.parseSql(sqlText))) } else { sys.error(s"Unsupported SQL dialect: $dialect. Try 'sql' or 'hiveql'") } From 2957875057c9d4e79c1aacf518e85716fa29245f Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 29 Oct 2014 19:22:14 -0700 Subject: [PATCH 02/17] add override --- .../main/scala/org/apache/spark/sql/execution/commands.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 53f2f58393a05..e658e6fc4d5d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -68,7 +68,7 @@ case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan { override def output = cmd.output - def children = Nil + override def children = Nil override def executeCollect(): Array[Row] = sideEffectResult.toArray From 360cb30728f2b01acced3fc8b696f1f8ebcb5891 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 30 Oct 2014 13:39:44 -0700 Subject: [PATCH 03/17] style and java api --- .../scala/org/apache/spark/sql/api/java/JavaSQLContext.scala | 5 +++++ .../main/scala/org/apache/spark/sql/json/JSONRelation.scala | 2 +- .../scala/org/apache/spark/sql/sources/LogicalRelation.scala | 2 +- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 082ae03eef03f..2b46dbd3a805f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.sql.json.JsonRDD +import org.apache.spark.sql.sources.{LogicalRelation, BaseRelation} import org.apache.spark.sql.types.util.DataTypeConversions import org.apache.spark.sql.{SQLContext, StructType => SStructType} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow} @@ -39,6 +40,10 @@ class 
JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { def this(sparkContext: JavaSparkContext) = this(new SQLContext(sparkContext.sc)) + def baseRelationToSchemaRDD(baseRelation: BaseRelation): JavaSchemaRDD = { + new JavaSchemaRDD(sqlContext, LogicalRelation(baseRelation)) + } + /** * Executes a SQL query using Spark, returning the result as a SchemaRDD. The dialect that is * used for SQL parsing can be configured with 'spark.sql.dialect'. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index bf7ba34de8d66..0df655b94e4f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -49,4 +49,4 @@ private[sql] case class JSONRelation(fileName: String, samplingRatio: Double)( override def buildScan() = JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.columnNameOfCorruptRecord) -} \ No newline at end of file +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala index 338a555a4016b..26a8c4d3df06e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala @@ -53,4 +53,4 @@ private[sql] case class LogicalRelation(relation: BaseRelation) def newInstance() = LogicalRelation(relation).asInstanceOf[this.type] override def simpleString = s"Relation[${output.mkString(",")}] $relation" -} \ No newline at end of file +} From de3b68c4d28c7cf603c886b99476a2510ce307e8 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 30 Oct 2014 17:12:23 -0700 Subject: [PATCH 04/17] Remove empty file --- .../apache/spark/sql/sources/DataSource.scala | 21 ------------------- 1 file changed, 21 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/sources/DataSource.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSource.scala deleted file mode 100644 index f03ff136c98b9..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSource.scala +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.sql.sources - -import org.apache.spark.sql.SQLContext - - From 3e06776fa141c727081aa5ceb906efb2e6961c20 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 30 Oct 2014 17:24:58 -0700 Subject: [PATCH 05/17] remove line wraps --- .../scala/org/apache/spark/sql/json/JSONRelation.scala | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 0df655b94e4f0..c15ce124d9614 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -25,11 +25,8 @@ private[sql] class DefaultSource extends RelationProvider { override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { - val fileName = - parameters.getOrElse("fileName", sys.error(s"Option 'fileName' not specified")) - - val samplingRatio = - parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + val fileName = parameters.getOrElse("fileName", sys.error("Option 'fileName' not specified")) + val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) JSONRelation(fileName, samplingRatio)(sqlContext) } From 0d74bcf965d36dc01e03ea28662c5cf18aec4b0b Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 30 Oct 2014 17:26:54 -0700 Subject: [PATCH 06/17] Add documention on object life cycle --- .../src/main/scala/org/apache/spark/sql/sources/package.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala index 5b61df3e5d199..7ac558edfa0e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala @@ -34,6 +34,8 @@ package object sources { * not found Spark SQL will append the class name `DefaultSource` to the path, allowing for * less verbose invocation. For example, 'org.apache.spark.sql.json' would resolve to the * data source 'org.apache.spark.sql.json.DefaultSource' + * + * A new instance of this class with be instantiated each time a DDL call is made. */ trait RelationProvider { /** Returns a new base relation with the given parameters. */ From 34f836ad799bc470a58c2603ddc4fc2b5fadc121 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 30 Oct 2014 17:27:37 -0700 Subject: [PATCH 07/17] Make @DeveloperApi --- .../src/main/scala/org/apache/spark/sql/sources/package.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala index 7ac558edfa0e6..fac989de03de4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala @@ -37,6 +37,7 @@ package object sources { * * A new instance of this class with be instantiated each time a DDL call is made. */ + @DeveloperApi trait RelationProvider { /** Returns a new base relation with the given parameters. 
*/ def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation From b06914625d709f0f1ad7a1be8ce6a2cff38644f1 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 30 Oct 2014 22:09:08 -0700 Subject: [PATCH 08/17] traits => abstract classes --- .../org/apache/spark/sql/json/JSONRelation.scala | 2 +- .../apache/spark/sql/sources/LogicalRelation.scala | 2 +- .../scala/org/apache/spark/sql/sources/package.scala | 12 +++--------- .../apache/spark/sql/sources/FilteredScanSuite.scala | 2 +- .../apache/spark/sql/sources/PrunedScanSuite.scala | 2 +- .../apache/spark/sql/sources/TableScanSuite.scala | 2 +- 6 files changed, 8 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index c15ce124d9614..1ee10c11e94ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -34,7 +34,7 @@ private[sql] class DefaultSource extends RelationProvider { private[sql] case class JSONRelation(fileName: String, samplingRatio: Double)( @transient val sqlContext: SQLContext) - extends BaseRelation with TableScan { + extends TableScan { private def baseRDD = sqlContext.sparkContext.textFile(fileName) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala index 26a8c4d3df06e..b183b6573e32a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala @@ -29,7 +29,7 @@ private[sql] case class LogicalRelation(relation: BaseRelation) with LeafNode[LogicalPlan] with MultiInstanceRelation { - val output = relation.schema.toAttributes + override val output = relation.schema.toAttributes // Logical Relations are distinct if they have different output for the sake of transformations. override def equals(other: Any) = other match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala index fac989de03de4..96f71f222f7b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala @@ -62,9 +62,7 @@ package object sources { * Mixed into a BaseRelation that can produce all of its tuples as an RDD of Row objects. */ @DeveloperApi - trait TableScan { - self: BaseRelation => - + abstract class TableScan extends BaseRelation { def buildScan(): RDD[Row] } @@ -73,9 +71,7 @@ package object sources { * containing all of its tuples as Row objects. */ @DeveloperApi - trait PrunedScan { - self: BaseRelation => - + abstract class PrunedScan extends BaseRelation { def buildScan(requiredColumns: Seq[Attribute]): RDD[Row] } @@ -88,9 +84,7 @@ package object sources { * as filtering partitions based on a bloom filter. 
*/ @DeveloperApi - trait FilteredScan { - self: BaseRelation => - + abstract class FilteredScan extends BaseRelation { def buildScan( requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 3de45ab60300b..0724e37a1572e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -29,7 +29,7 @@ class FilteredScanSource extends RelationProvider { } case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) - extends BaseRelation with FilteredScan { + extends FilteredScan { override def schema = StructType( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index 276d3815bf744..48caf502490d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -29,7 +29,7 @@ class PrunedScanSource extends RelationProvider { } case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) - extends BaseRelation with PrunedScan { + extends PrunedScan { override def schema = StructType( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 5eb30827e7314..a665ac07a0c1b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -31,7 +31,7 @@ class SimpleScanSource extends RelationProvider { } case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) - extends BaseRelation with TableScan { + extends TableScan { override def schema = StructType(StructField("i", IntegerType, nullable = false) :: Nil) From 22963efddc59099a14d49228654db8ba0760f87b Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 30 Oct 2014 22:17:36 -0700 Subject: [PATCH 09/17] package objects compile wierdly... --- .../apache/spark/sql/sources/interfaces.scala | 86 +++++++++++++++++++ .../apache/spark/sql/sources/package.scala | 72 +--------------- 2 files changed, 87 insertions(+), 71 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala new file mode 100644 index 0000000000000..6e829647f0dae --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -0,0 +1,86 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.apache.spark.sql.sources + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute} + +/** + * Implemented by objects that produce relations for a specific kind of data source. When + * Spark SQL is given a DDL operation with a USING clause specified, this interface is used to + * pass in the parameters specified by a user. + * + * Users may specify the fully qualified class name of a given data source. When that class is + * not found Spark SQL will append the class name `DefaultSource` to the path, allowing for + * less verbose invocation. For example, 'org.apache.spark.sql.json' would resolve to the + * data source 'org.apache.spark.sql.json.DefaultSource' + * + * A new instance of this class with be instantiated each time a DDL call is made. + */ +@DeveloperApi +trait RelationProvider { + /** Returns a new base relation with the given parameters. */ + def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation +} + +/** + * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must + * be able to produce the schema of their data in the form of a [[StructType]] In order to be + * executed, a BaseRelation must also mix in at least one of the Scan traits. + * + * BaseRelations must also define a equality function that only returns true when the two + * instances will return the same data. This equality function is used when determining when + * it is safe to substitute cached results for a given relation. + */ +@DeveloperApi +abstract class BaseRelation { + def sqlContext: SQLContext + def schema: StructType +} + +/** + * Mixed into a BaseRelation that can produce all of its tuples as an RDD of Row objects. + */ +@DeveloperApi +abstract class TableScan extends BaseRelation { + def buildScan(): RDD[Row] +} + +/** + * Mixed into a BaseRelation that can eliminate unneeded columns before producing an RDD + * containing all of its tuples as Row objects. + */ +@DeveloperApi +abstract class PrunedScan extends BaseRelation { + def buildScan(requiredColumns: Seq[Attribute]): RDD[Row] +} + +/** + * Mixed into a BaseRelation that can eliminate unneeded columns and filter using selected + * predicates before producing an RDD containing all matching tuples as Row objects. + * + * The pushed down filters are currently purely an optimization as they will all be evaluated + * again. This means it is safe to use them with methods that produce false positives such + * as filtering partitions based on a bloom filter. 
+ */ +@DeveloperApi +abstract class FilteredScan extends BaseRelation { + def buildScan( + requiredColumns: Seq[Attribute], + filters: Seq[Expression]): RDD[Row] +} \ No newline at end of file diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala index 96f71f222f7b5..8393c510f4f6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala @@ -16,77 +16,7 @@ */ package org.apache.spark.sql -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute} - /** * A set of APIs for adding data sources to Spark SQL. */ -package object sources { - - /** - * Implemented by objects that produce relations for a specific kind of data source. When - * Spark SQL is given a DDL operation with a USING clause specified, this interface is used to - * pass in the parameters specified by a user. - * - * Users may specify the fully qualified class name of a given data source. When that class is - * not found Spark SQL will append the class name `DefaultSource` to the path, allowing for - * less verbose invocation. For example, 'org.apache.spark.sql.json' would resolve to the - * data source 'org.apache.spark.sql.json.DefaultSource' - * - * A new instance of this class with be instantiated each time a DDL call is made. - */ - @DeveloperApi - trait RelationProvider { - /** Returns a new base relation with the given parameters. */ - def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation - } - - /** - * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must - * be able to produce the schema of their data in the form of a [[StructType]] In order to be - * executed, a BaseRelation must also mix in at least one of the Scan traits. - * - * BaseRelations must also define a equality function that only returns true when the two - * instances will return the same data. This equality function is used when determining when - * it is safe to substitute cached results for a given relation. - */ - @DeveloperApi - abstract class BaseRelation { - def sqlContext: SQLContext - def schema: StructType - } - - /** - * Mixed into a BaseRelation that can produce all of its tuples as an RDD of Row objects. - */ - @DeveloperApi - abstract class TableScan extends BaseRelation { - def buildScan(): RDD[Row] - } - - /** - * Mixed into a BaseRelation that can eliminate unneeded columns before producing an RDD - * containing all of its tuples as Row objects. - */ - @DeveloperApi - abstract class PrunedScan extends BaseRelation { - def buildScan(requiredColumns: Seq[Attribute]): RDD[Row] - } - - /** - * Mixed into a BaseRelation that can eliminate unneeded columns and filter using selected - * predicates before producing an RDD containing all matching tuples as Row objects. - * - * The pushed down filters are currently purely an optimization as they will all be evaluated - * again. This means it is safe to use them with methods that produce false positives such - * as filtering partitions based on a bloom filter. 
- */ - @DeveloperApi - abstract class FilteredScan extends BaseRelation { - def buildScan( - requiredColumns: Seq[Attribute], - filters: Seq[Expression]): RDD[Row] - } -} +package object sources From 5545491aef2984e9563eb6c8ee0dbacec752820c Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 31 Oct 2014 12:16:57 -0700 Subject: [PATCH 10/17] Address comments --- .../org/apache/spark/sql/json/JSONRelation.scala | 2 +- .../spark/sql/sources/LogicalRelation.scala | 8 +++----- .../scala/org/apache/spark/sql/sources/ddl.scala | 12 +++++------- .../apache/spark/sql/sources/interfaces.scala | 16 ++++++++-------- .../org/apache/spark/sql/hive/HiveContext.scala | 1 + 5 files changed, 18 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 1ee10c11e94ce..0e18ce63bfe50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.json -import org.apache.spark.sql.{SQLConf, SQLContext} +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.sources._ private[sql] class DefaultSource extends RelationProvider { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala index b183b6573e32a..82a2cf8402f8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala @@ -17,16 +17,14 @@ package org.apache.spark.sql.sources import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Expression, Attribute} -import org.apache.spark.sql.catalyst.plans.logical.{Statistics, LogicalPlan} -import org.apache.spark.sql.catalyst.trees.LeafNode +import org.apache.spark.sql.catalyst.expressions.AttributeMap +import org.apache.spark.sql.catalyst.plans.logical.{Statistics, LeafNode, LogicalPlan} /** * Used to link a [[BaseRelation]] in to a logical query plan. 
*/ private[sql] case class LogicalRelation(relation: BaseRelation) - extends LogicalPlan - with LeafNode[LogicalPlan] + extends LeafNode with MultiInstanceRelation { override val output = relation.schema.toAttributes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 05463208c156d..9168ca2fc6fec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -72,15 +72,13 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro") */ protected lazy val createTable: Parser[LogicalPlan] = - CREATE ~> TEMPORARY ~> TABLE ~> ident ~ - USING ~ className ~ - OPTIONS ~ options ^^ { - case tableName ~ _ ~ provider ~ _ ~ opts => + CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ { + case tableName ~ provider ~ opts => CreateTableUsing(tableName, provider, opts) } protected lazy val options: Parser[Map[String, String]] = - "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } + "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} @@ -94,9 +92,9 @@ private[sql] case class CreateTableUsing( def run(sqlContext: SQLContext) = { val loader = Utils.getContextOrSparkClassLoader - val clazz: Class[_] = try Utils.getContextOrSparkClassLoader.loadClass(provider) catch { + val clazz: Class[_] = try loader.loadClass(provider) catch { case cnf: java.lang.ClassNotFoundException => - try Utils.getContextOrSparkClassLoader.loadClass(provider + ".DefaultSource") catch { + try loader.loadClass(provider + ".DefaultSource") catch { case cnf: java.lang.ClassNotFoundException => sys.error(s"Failed to load class for data source: $provider") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 6e829647f0dae..18756a8da922e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext, StructType} import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute} /** @@ -40,8 +41,9 @@ trait RelationProvider { /** * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must - * be able to produce the schema of their data in the form of a [[StructType]] In order to be - * executed, a BaseRelation must also mix in at least one of the Scan traits. + * be able to produce the schema of their data in the form of a [[StructType]] Concrete + * implementation should inherit from one of the descendant `Scan` classes, which also define + * abstract methods for execution. * * BaseRelations must also define a equality function that only returns true when the two * instances will return the same data. This equality function is used when determining when @@ -54,7 +56,7 @@ abstract class BaseRelation { } /** - * Mixed into a BaseRelation that can produce all of its tuples as an RDD of Row objects. 
+ * A BaseRelation that can produce all of its tuples as an RDD of Row objects. */ @DeveloperApi abstract class TableScan extends BaseRelation { @@ -62,7 +64,7 @@ abstract class TableScan extends BaseRelation { } /** - * Mixed into a BaseRelation that can eliminate unneeded columns before producing an RDD + * A BaseRelation that can eliminate unneeded columns before producing an RDD * containing all of its tuples as Row objects. */ @DeveloperApi @@ -71,7 +73,7 @@ abstract class PrunedScan extends BaseRelation { } /** - * Mixed into a BaseRelation that can eliminate unneeded columns and filter using selected + * A BaseRelation that can eliminate unneeded columns and filter using selected * predicates before producing an RDD containing all matching tuples as Row objects. * * The pushed down filters are currently purely an optimization as they will all be evaluated @@ -80,7 +82,5 @@ abstract class PrunedScan extends BaseRelation { */ @DeveloperApi abstract class FilteredScan extends BaseRelation { - def buildScan( - requiredColumns: Seq[Attribute], - filters: Seq[Expression]): RDD[Row] + def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] } \ No newline at end of file diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 29693d04c8a40..37ba92aca2f96 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -326,6 +326,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val hiveContext = self override val strategies: Seq[Strategy] = Seq( + DataSources, CommandStrategy(self), HiveCommandStrategy(self), TakeOrdered, From 7d948aed0d0cc97a1a175f672ba2ca37f0c3e538 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 31 Oct 2014 14:27:58 -0700 Subject: [PATCH 11/17] Fix equality of AttributeReference. 
--- .../spark/sql/catalyst/expressions/namedExpressions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index fe13a661f6f7a..4da1d894eacfd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -119,7 +119,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea extends Attribute with trees.LeafNode[Expression] { override def equals(other: Any) = other match { - case ar: AttributeReference => exprId == ar.exprId && dataType == ar.dataType + case ar: AttributeReference => name == ar.name && exprId == ar.exprId && dataType == ar.dataType case _ => false } From 70da6d941b222385aa437bf7bfe6cabddad09cf8 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 31 Oct 2014 14:29:35 -0700 Subject: [PATCH 12/17] Modify API to ease binary compatibility and interop with Java --- .../sql/catalyst/planning/QueryPlanner.scala | 20 ++-- .../org/apache/spark/sql/SQLContext.scala | 4 +- .../spark/sql/execution/SparkStrategies.scala | 27 +----- .../scala/org/apache/spark/sql/package.scala | 8 ++ .../sql/sources/DataSourceStrategy.scala | 96 +++++++++++++++++++ .../apache/spark/sql/sources/filters.scala | 22 +++++ .../apache/spark/sql/sources/interfaces.scala | 8 +- .../spark/sql/sources/FilteredScanSuite.scala | 52 ++++++++-- .../spark/sql/sources/PrunedScanSuite.scala | 5 +- .../spark/sql/sources/TableScanSuite.scala | 1 - .../apache/spark/sql/hive/HiveContext.scala | 4 +- .../spark/sql/hive/HiveStrategies.scala | 2 +- 12 files changed, 194 insertions(+), 55 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 5839c9f7c43ef..51b5699affed5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -21,6 +21,15 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode +/** + * Given a [[plans.logical.LogicalPlan LogicalPlan]], returns a list of `PhysicalPlan`s that can + * be used for execution. If this strategy does not apply to the give logical operation then an + * empty list should be returned. + */ +abstract class GenericStrategy[PhysicalPlan <: TreeNode[PhysicalPlan]] extends Logging { + def apply(plan: LogicalPlan): Seq[PhysicalPlan] +} + /** * Abstract class for transforming [[plans.logical.LogicalPlan LogicalPlan]]s into physical plans. 
* Child classes are responsible for specifying a list of [[Strategy]] objects that each of which @@ -35,16 +44,7 @@ import org.apache.spark.sql.catalyst.trees.TreeNode */ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] { /** A list of execution strategies that can be used by the planner */ - def strategies: Seq[Strategy] - - /** - * Given a [[plans.logical.LogicalPlan LogicalPlan]], returns a list of `PhysicalPlan`s that can - * be used for execution. If this strategy does not apply to the give logical operation then an - * empty list should be returned. - */ - abstract protected class Strategy extends Logging { - def apply(plan: LogicalPlan): Seq[PhysicalPlan] - } + def strategies: Seq[GenericStrategy[PhysicalPlan]] /** * Returns a placeholder for a physical plan that executes `plan`. This placeholder will be diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index c4a564d5c514d..99069e10c0a32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.types.DataType import org.apache.spark.sql.execution.{SparkStrategies, _} import org.apache.spark.sql.json._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.sources.{BaseRelation, DDLParser, LogicalRelation} +import org.apache.spark.sql.sources.{DataSourceStrategy, BaseRelation, DDLParser, LogicalRelation} /** * :: AlphaComponent :: @@ -306,7 +306,7 @@ class SQLContext(@transient val sparkContext: SparkContext) val strategies: Seq[Strategy] = CommandStrategy(self) :: - DataSources :: + DataSourceStrategy :: TakeOrdered :: HashAggregation :: LeftSemiJoin :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index efbcbebe6e66f..2cd3063bc3097 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{sources, SQLContext, execution} +import org.apache.spark.sql.{SQLContext, Strategy, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ @@ -252,31 +252,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - object DataSources extends Strategy { - import sources._ - - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: FilteredScan)) => - pruneFilterProject( - projectList, - filters, - identity[Seq[Expression]], // All filters still need to be evaluated - a => PhysicalRDD(a, t.buildScan(a.map(l.attributeMap), filters))) :: Nil - - case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: PrunedScan)) => - pruneFilterProject( - projectList, - filters, - identity[Seq[Expression]], // All filters still need to be evaluated. - a => PhysicalRDD(a, t.buildScan(a.map(l.attributeMap)))) :: Nil - - case l @ LogicalRelation(t: TableScan) => - PhysicalRDD(l.output, t.buildScan()) :: Nil - - case _ => Nil - } - } - // Can we automate these 'pass through' operations? 
object BasicOperators extends Strategy { def numPartitions = self.numPartitions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index e98d151286818..fec27463f4820 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -18,6 +18,7 @@ package org.apache.spark import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.execution.SparkPlan /** * Allows the execution of relational queries, including those expressed in SQL using Spark. @@ -414,4 +415,11 @@ package object sql { */ @DeveloperApi val StructField = catalyst.types.StructField + + /** + * Converts a logical plan into zero or more SparkPlans. + */ + @DeveloperApi + type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala new file mode 100644 index 0000000000000..dae3720287087 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan + +/** + * A Strategy for planning scans over data sources defined using the sources API. 
+ */ +private[sql] object DataSourceStrategy extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: FilteredScan)) => + pruneFilterProject( + l, + projectList, + filters, + (a, f) => t.buildScan(a, f)) :: Nil + + case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: PrunedScan)) => + pruneFilterProject( + l, + projectList, + filters, + (a, _) => t.buildScan(a)) :: Nil + + case l @ LogicalRelation(t: TableScan) => + execution.PhysicalRDD(l.output, t.buildScan()) :: Nil + + case _ => Nil + } + + protected def pruneFilterProject( + relation: LogicalRelation, + projectList: Seq[NamedExpression], + filterPredicates: Seq[Expression], + scanBuilder: (Array[String], Array[Filter]) => RDD[Row]) = { + + val projectSet = AttributeSet(projectList.flatMap(_.references)) + val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) + val filterCondition = filterPredicates.reduceLeftOption(And) + + val pushedFilters = selectFilters(filterPredicates.map { _ transform { + case a: AttributeReference => relation.attributeMap(a) // Match original case of attributes. + }}).toArray + + if (AttributeSet(projectList.map(_.toAttribute)) == projectSet && + filterSet.subsetOf(projectSet)) { + // When it is possible to just use column pruning to get the right projection and + // when the columns of this projection are enough to evaluate all filter conditions, + // just do a scan followed by a filter, with no extra project. + val requestedColumns = + projectList.asInstanceOf[Seq[Attribute]] // Safe due to if above. + .map(relation.attributeMap) // Match original case of attributes. + .map(_.name) + .toArray + + val scan = + execution.PhysicalRDD( + projectList.map(_.toAttribute), + scanBuilder(requestedColumns, pushedFilters)) + filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) + } else { + val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq + val columnNames = requestedColumns.map(_.name).toArray + + val scan = execution.PhysicalRDD(requestedColumns, scanBuilder(columnNames, pushedFilters)) + execution.Project(projectList, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) + } + } + + protected def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect { + case expressions.EqualTo(a: Attribute, Literal(v, _)) => EqualTo(a.name, v) + } +} \ No newline at end of file diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala new file mode 100644 index 0000000000000..17c978a304764 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +abstract sealed class Filter + +case class EqualTo(attribute: String, value: Any) extends Filter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 18756a8da922e..82a6075b223c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -42,7 +42,7 @@ trait RelationProvider { /** * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must * be able to produce the schema of their data in the form of a [[StructType]] Concrete - * implementation should inherit from one of the descendant `Scan` classes, which also define + * implementation should inherit from one of the descendant `Scan` classes, which define various * abstract methods for execution. * * BaseRelations must also define a equality function that only returns true when the two @@ -69,7 +69,7 @@ abstract class TableScan extends BaseRelation { */ @DeveloperApi abstract class PrunedScan extends BaseRelation { - def buildScan(requiredColumns: Seq[Attribute]): RDD[Row] + def buildScan(requiredColumns: Array[String]): RDD[Row] } /** @@ -82,5 +82,5 @@ abstract class PrunedScan extends BaseRelation { */ @DeveloperApi abstract class FilteredScan extends BaseRelation { - def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] -} \ No newline at end of file + def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 0724e37a1572e..db62e150e4e61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.sources -import org.apache.spark.sql.catalyst.expressions.{Row => _, _} +import scala.language.existentials + import org.apache.spark.sql._ class FilteredScanSource extends RelationProvider { @@ -36,17 +37,16 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL StructField("a", IntegerType, nullable = false) :: StructField("b", IntegerType, nullable = false) :: Nil) - override def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]) = { - val rowBuilders = requiredColumns.map(_.name).map { + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = { + val rowBuilders = requiredColumns.map { case "a" => (i: Int) => Seq(i) case "b" => (i: Int) => Seq(i * 2) } + FiltersPushed.list = filters + val filter = filters.collect { - case Seq(EqualTo(a: AttributeReference, l: Literal)) if a.name == "a" => - (i: Int) => i == l.value - case Seq(EqualTo(l: Literal, a: AttributeReference)) if a.name == "a" => - (i: Int) => i == l.value + case EqualTo("a", v) => (a: Int) => a == v }.headOption.getOrElse((_: Int) => true) sqlContext.sparkContext.parallelize(from to to).filter(filter).map(i => @@ -54,6 +54,11 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL } } +// A hack for better error messages when filter pushdown fails. 
+object FiltersPushed { + var list: Seq[Filter] = Nil +} + class FilteredScanSuite extends DataSourceTest { import caseInsensisitiveContext._ @@ -110,8 +115,41 @@ class FilteredScanSuite extends DataSourceTest { "SELECT * FROM oneToTenFiltered WHERE a = 1", Seq(1).map(i => Row(i, i * 2)).toSeq) + sqlTest( + "SELECT * FROM oneToTenFiltered WHERE A = 1", + Seq(1).map(i => Row(i, i * 2)).toSeq) + sqlTest( "SELECT * FROM oneToTenFiltered WHERE b = 2", Seq(1).map(i => Row(i, i * 2)).toSeq) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1) + testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1) + testPushDown("SELECT b FROM oneToTenFiltered WHERE A = 1", 1) + testPushDown("SELECT a, b FROM oneToTenFiltered WHERE A = 1", 1) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 1", 1) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 5", 10) + testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10) + + def testPushDown(sqlString: String, expectedCount: Int): Unit = { + test(s"PushDown Returns $expectedCount: $sqlString") { + val queryExecution = sql(sqlString).queryExecution + val rawPlan = queryExecution.executedPlan.collect { + case p: execution.PhysicalRDD => p + } match { + case Seq(p) => p + case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") + } + val rawCount = rawPlan.execute().count() + + if (rawCount != expectedCount) { + fail( + s"Wrong # of results for pushed filter. Got $rawCount, Expected $expectedCount\n" + + s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" + + queryExecution) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index 48caf502490d5..a7fac84636a1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.sources -import org.apache.spark.sql.catalyst.expressions.{Row => _, _} import org.apache.spark.sql._ class PrunedScanSource extends RelationProvider { @@ -36,8 +35,8 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo StructField("a", IntegerType, nullable = false) :: StructField("b", IntegerType, nullable = false) :: Nil) - override def buildScan(requiredColumns: Seq[Attribute]) = { - val rowBuilders = requiredColumns.map(_.name).map { + override def buildScan(requiredColumns: Array[String]) = { + val rowBuilders = requiredColumns.map { case "a" => (i: Int) => Seq(i) case "b" => (i: Int) => Seq(i * 2) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index a665ac07a0c1b..b254b0620c779 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.sources -import org.apache.spark.sql.catalyst.expressions.{Row => _, _} import org.apache.spark.sql._ class DefaultSource extends SimpleScanSource diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 37ba92aca2f96..91b16803ef5ea 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -21,6 +21,8 @@ import java.io.{BufferedReader, File, InputStreamReader, PrintStream} import java.sql.{Date, Timestamp} import java.util.{ArrayList => JArrayList} +import org.apache.spark.sql.sources.DataSourceStrategy + import scala.collection.JavaConversions._ import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} @@ -326,7 +328,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val hiveContext = self override val strategies: Seq[Strategy] = Seq( - DataSources, + DataSourceStrategy, CommandStrategy(self), HiveCommandStrategy(self), TakeOrdered, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 3207ad81d9571..989740c8d43b6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.{DescribeCommand, OutputFaker, SparkPlan} import org.apache.spark.sql.hive import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.{SQLContext, SchemaRDD} +import org.apache.spark.sql.{SQLContext, SchemaRDD, Strategy} import scala.collection.JavaConversions._ From a70d602c445d8df21cdfba9f4499d6df36d48ada Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 31 Oct 2014 15:01:05 -0700 Subject: [PATCH 13/17] Fix style, more tests, FilteredSuite => PrunedFilteredSuite --- .../sql/sources/DataSourceStrategy.scala | 7 ++-- .../apache/spark/sql/sources/interfaces.scala | 2 +- .../spark/sql/sources/FilteredScanSuite.scala | 6 ++-- .../spark/sql/sources/PrunedScanSuite.scala | 36 +++++++++++++++++++ 4 files changed, 44 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index dae3720287087..08fbeb99282d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.SparkPlan */ private[sql] object DataSourceStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: FilteredScan)) => + case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: PrunedFilteredScan)) => pruneFilterProject( l, projectList, @@ -65,7 +65,8 @@ private[sql] object DataSourceStrategy extends Strategy { case a: AttributeReference => relation.attributeMap(a) // Match original case of attributes. 
}}).toArray - if (AttributeSet(projectList.map(_.toAttribute)) == projectSet && + if (projectList.map(_.toAttribute) == projectList && + projectSet.size == projectList.size && filterSet.subsetOf(projectSet)) { // When it is possible to just use column pruning to get the right projection and // when the columns of this projection are enough to evaluate all filter conditions, @@ -93,4 +94,4 @@ private[sql] object DataSourceStrategy extends Strategy { protected def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect { case expressions.EqualTo(a: Attribute, Literal(v, _)) => EqualTo(a.name, v) } -} \ No newline at end of file +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 82a6075b223c3..ac3bf9d8e1a21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -81,6 +81,6 @@ abstract class PrunedScan extends BaseRelation { * as filtering partitions based on a bloom filter. */ @DeveloperApi -abstract class FilteredScan extends BaseRelation { +abstract class PrunedFilteredScan extends BaseRelation { def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index db62e150e4e61..b20445eb8bc1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -30,7 +30,7 @@ class FilteredScanSource extends RelationProvider { } case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) - extends FilteredScan { + extends PrunedFilteredScan { override def schema = StructType( @@ -146,8 +146,8 @@ class FilteredScanSuite extends DataSourceTest { if (rawCount != expectedCount) { fail( s"Wrong # of results for pushed filter. 
Got $rawCount, Expected $expectedCount\n" + - s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" + - queryExecution) + s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" + + queryExecution) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index a7fac84636a1d..fee2e22611cdc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -77,6 +77,10 @@ class PrunedScanSuite extends DataSourceTest { "SELECT a FROM oneToTenPruned", (1 to 10).map(i => Row(i)).toSeq) + sqlTest( + "SELECT a, a FROM oneToTenPruned", + (1 to 10).map(i => Row(i, i)).toSeq) + sqlTest( "SELECT b FROM oneToTenPruned", (1 to 10).map(i => Row(i * 2)).toSeq) @@ -97,5 +101,37 @@ class PrunedScanSuite extends DataSourceTest { "SELECT x.a, y.b FROM oneToTenPruned x JOIN oneToTenPruned y ON x.a = y.b", (2 to 10 by 2).map(i => Row(i, i)).toSeq) + testPruning("SELECT * FROM oneToTenPruned", "a", "b") + testPruning("SELECT a, b FROM oneToTenPruned", "a", "b") + testPruning("SELECT b, a FROM oneToTenPruned", "b", "a") + testPruning("SELECT b, b FROM oneToTenPruned", "b") + testPruning("SELECT a FROM oneToTenPruned", "a") + testPruning("SELECT b FROM oneToTenPruned", "b") + + def testPruning(sqlString: String, expectedColumns: String*): Unit = { + test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") { + val queryExecution = sql(sqlString).queryExecution + val rawPlan = queryExecution.executedPlan.collect { + case p: execution.PhysicalRDD => p + } match { + case Seq(p) => p + case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") + } + val rawColumns = rawPlan.output.map(_.name) + val rawOutput = rawPlan.execute().first() + + if (rawColumns != expectedColumns) { + fail( + s"Wrong column names. Got $rawColumns, Expected $expectedColumns\n" + + s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" + + queryExecution) + } + + if (rawOutput.size != expectedColumns.size) { + fail(s"Wrong output row. Got $rawOutput\n$queryExecution") + } + } + } + } From e3e690e28c31e1cf4852f7d484fb9fcdb7fa0ea1 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 31 Oct 2014 15:17:01 -0700 Subject: [PATCH 14/17] Add hook for extraStrategies --- .../main/scala/org/apache/spark/sql/SQLContext.scala | 11 ++++++++++- .../scala/org/apache/spark/sql/hive/HiveContext.scala | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 99069e10c0a32..314198347f96d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -295,6 +295,14 @@ class SQLContext(@transient val sparkContext: SparkContext) def table(tableName: String): SchemaRDD = new SchemaRDD(this, catalog.lookupRelation(None, tableName)) + /** + * :: DeveloperApi :: + * Allows extra strategies to be injected into the query planner at runtime. Note this API + * should be consider experimental and is not intended to be stable across releases. 
+ */ + @DeveloperApi + var extraStrategies: Seq[Strategy] = Nil + protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext: SparkContext = self.sparkContext @@ -305,6 +313,7 @@ class SQLContext(@transient val sparkContext: SparkContext) def numPartitions = self.numShufflePartitions val strategies: Seq[Strategy] = + extraStrategies ++ ( CommandStrategy(self) :: DataSourceStrategy :: TakeOrdered :: @@ -315,7 +324,7 @@ class SQLContext(@transient val sparkContext: SparkContext) ParquetOperations :: BasicOperators :: CartesianProduct :: - BroadcastNestedLoopJoin :: Nil + BroadcastNestedLoopJoin :: Nil) /** * Used to build table scan operators where complex projection and filtering are done using diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 91b16803ef5ea..dafd05943c8a3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -327,7 +327,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val hivePlanner = new SparkPlanner with HiveStrategies { val hiveContext = self - override val strategies: Seq[Strategy] = Seq( + override val strategies: Seq[Strategy] = extraStrategies ++ Seq( DataSourceStrategy, CommandStrategy(self), HiveCommandStrategy(self), From 5b4790168910a84e1926f2d638d871770151a62e Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sat, 1 Nov 2014 16:11:04 -0700 Subject: [PATCH 15/17] Remove sealed, more filter types --- .../sql/sources/DataSourceStrategy.scala | 15 ++++++++++ .../apache/spark/sql/sources/filters.scala | 6 +++- .../spark/sql/sources/FilteredScanSuite.scala | 29 ++++++++++++++++--- 3 files changed, 45 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index 08fbeb99282d7..9b8c6a56b94b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -93,5 +93,20 @@ private[sql] object DataSourceStrategy extends Strategy { protected def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect { case expressions.EqualTo(a: Attribute, Literal(v, _)) => EqualTo(a.name, v) + case expressions.EqualTo(Literal(v, _), a: Attribute) => EqualTo(a.name, v) + + case expressions.GreaterThan(a: Attribute, Literal(v, _)) => GreaterThan(a.name, v) + case expressions.GreaterThan(Literal(v, _), a: Attribute) => LessThan(a.name, v) + + case expressions.LessThan(a: Attribute, Literal(v, _)) => LessThan(a.name, v) + case expressions.LessThan(Literal(v, _), a: Attribute) => GreaterThan(a.name, v) + + case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) => + GreaterThanOrEqual(a.name, v) + case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) => + LessThanOrEqual(a.name, v) + + case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => LessThanOrEqual(a.name, v) + case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => GreaterThanOrEqual(a.name, v) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 17c978a304764..e72a2aeb8f310 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.sources -abstract sealed class Filter +abstract class Filter case class EqualTo(attribute: String, value: Any) extends Filter +case class GreaterThan(attribute: String, value: Any) extends Filter +case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter +case class LessThan(attribute: String, value: Any) extends Filter +case class LessThanOrEqual(attribute: String, value: Any) extends Filter diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index b20445eb8bc1d..8b2f1591d5bf3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -45,11 +45,17 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL FiltersPushed.list = filters - val filter = filters.collect { + val filterFunctions = filters.collect { case EqualTo("a", v) => (a: Int) => a == v - }.headOption.getOrElse((_: Int) => true) + case LessThan("a", v: Int) => (a: Int) => a < v + case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v + case GreaterThan("a", v: Int) => (a: Int) => a > v + case GreaterThanOrEqual("a", v: Int) => (a: Int) => a >= v + } + + def eval(a: Int) = !filterFunctions.map(_(a)).contains(false) - sqlContext.sparkContext.parallelize(from to to).filter(filter).map(i => + sqlContext.sparkContext.parallelize(from to to).filter(eval).map(i => Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty))) } } @@ -128,8 +134,23 @@ class FilteredScanSuite extends DataSourceTest { testPushDown("SELECT b FROM oneToTenFiltered WHERE A = 1", 1) testPushDown("SELECT a, b FROM oneToTenFiltered WHERE A = 1", 1) testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 1", 1) + testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 = a", 1) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1", 9) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a >= 2", 9) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 < a", 9) + testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 <= a", 9) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 > a", 0) + testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 >= a", 2) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 1", 0) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a <= 2", 2) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1 AND a < 10", 8) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0) - testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 5", 10) testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10) def testPushDown(sqlString: String, expectedCount: Int): Unit = { From 1d41bb50e8f5bc889b9019716462db8358be78f8 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sat, 1 Nov 2014 16:13:38 -0700 Subject: [PATCH 16/17] unify argument names --- .../src/main/scala/org/apache/spark/sql/json/JSONRelation.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 0e18ce63bfe50..fc70c183437f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -25,7 
+25,7 @@ private[sql] class DefaultSource extends RelationProvider { override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { - val fileName = parameters.getOrElse("fileName", sys.error("Option 'fileName' not specified")) + val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified")) val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) JSONRelation(fileName, samplingRatio)(sqlContext) From ab2c31f86352580b8fce64a397d4f07c1b41df59 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sat, 1 Nov 2014 16:42:39 -0700 Subject: [PATCH 17/17] fix test --- .../src/test/scala/org/apache/spark/sql/json/JsonSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index cb01345c3c282..f5a12a55a60e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -555,7 +555,7 @@ class JsonSuite extends QueryTest { |CREATE TEMPORARY TABLE jsonTableSQL |USING org.apache.spark.sql.json |OPTIONS ( - | fileName '$path' + | path '$path' |) """.stripMargin)
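For reference, a minimal data source written against the final shape of this API (the same pattern exercised by TableScanSuite and FilteredScanSuite in the patches above) might look like the sketch below. The package org.example.range, the class names, and the lower/upper option keys are assumptions for illustration only and do not appear anywhere in the patch set.

    // Hypothetical data source exposing the integer range [lower, upper] as a
    // two-column table, with column pruning and filter pushdown on column "a".
    // All names and option keys here are illustrative, not part of the patches.
    package org.example.range

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql._
    import org.apache.spark.sql.sources._

    class DefaultSource extends RelationProvider {
      // "USING org.example.range" resolves to org.example.range.DefaultSource.
      override def createRelation(
          sqlContext: SQLContext,
          parameters: Map[String, String]): BaseRelation = {
        SimpleRangeScan(parameters("lower").toInt, parameters("upper").toInt)(sqlContext)
      }
    }

    case class SimpleRangeScan(lower: Int, upper: Int)(@transient val sqlContext: SQLContext)
      extends PrunedFilteredScan {

      override def schema =
        StructType(
          StructField("a", IntegerType, nullable = false) ::
          StructField("b", IntegerType, nullable = false) :: Nil)

      override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
        // Produce only the requested columns, in the requested order.
        val builders = requiredColumns.map {
          case "a" => (i: Int) => Seq(i)
          case "b" => (i: Int) => Seq(i * 2)
        }
        // Pushed filters are purely an optimization: Spark SQL re-evaluates every
        // predicate after the scan, so unhandled filters can simply be ignored.
        val predicates = filters.collect {
          case EqualTo("a", v: Int)            => (i: Int) => i == v
          case GreaterThan("a", v: Int)        => (i: Int) => i > v
          case GreaterThanOrEqual("a", v: Int) => (i: Int) => i >= v
          case LessThan("a", v: Int)           => (i: Int) => i < v
          case LessThanOrEqual("a", v: Int)    => (i: Int) => i <= v
        }
        sqlContext.sparkContext.parallelize(lower to upper)
          .filter(i => predicates.forall(_(i)))
          .map(i => Row.fromSeq(builders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty)))
      }
    }

Such a source is registered and queried through the new DDL support in the same form JsonSuite uses above (table and option names again illustrative):

    sqlContext.sql(
      """
        |CREATE TEMPORARY TABLE rangeTable
        |USING org.example.range
        |OPTIONS (lower '1', upper '10')
      """.stripMargin)

    // Planned by DataSourceStrategy; "a >= 5" is pushed down as GreaterThanOrEqual("a", 5).
    sqlContext.sql("SELECT a, b FROM rangeTable WHERE a >= 5").collect()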