From a2b5408d2635e62738741cb5dbeaa5159487d6d2 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 8 Jul 2014 22:04:20 -0700 Subject: [PATCH 01/28] WIP: Code generation with scala reflection. --- project/SparkBuild.scala | 3 + .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../catalyst/analysis/HiveTypeCoercion.scala | 2 +- .../catalyst/expressions/BoundAttribute.scala | 1 + .../catalyst/expressions/GeneratedRow.scala | 26 + .../sql/catalyst/expressions/Projection.scala | 19 +- .../spark/sql/catalyst/expressions/Row.scala | 28 ++ .../sql/catalyst/expressions/ScalaUdf.scala | 1 + .../expressions/codegen/CodeGenerator.scala | 444 ++++++++++++++++++ .../codegen/GenerateMutableProjection.scala | 83 ++++ .../codegen/GenerateOrdering.scala | 84 ++++ .../codegen/GeneratePredicate.scala | 55 +++ .../codegen/GenerateProjection.scala | 224 +++++++++ .../expressions/codegen/package.scala | 82 ++++ .../sql/catalyst/expressions/package.scala | 12 +- .../expressions/stringOperations.scala | 10 + .../sql/catalyst/rules/RuleExecutor.scala | 4 +- .../spark/sql/catalyst/types/dataTypes.scala | 16 +- .../sql/catalyst/CodeGenerationSuite.scala | 0 .../ExpressionEvaluationSuite.scala | 39 +- .../GeneratedEvaluationSuite.scala | 112 +++++ .../org/apache/spark/sql/SQLContext.scala | 2 + .../spark/sql/execution/Aggregate.scala | 7 +- .../apache/spark/sql/execution/Exchange.scala | 6 +- .../apache/spark/sql/execution/Generate.scala | 5 +- .../spark/sql/execution/SparkPlan.scala | 18 +- .../spark/sql/execution/SparkStrategies.scala | 14 +- .../spark/sql/execution/aggregates.scala | 175 +++++++ .../spark/sql/execution/basicOperators.scala | 15 +- .../apache/spark/sql/execution/joins.scala | 27 +- .../sql/parquet/ParquetTableOperations.scala | 1 + .../org/apache/spark/sql/hive/HiveQl.scala | 6 +- .../hive/execution/ScriptTransformation.scala | 2 +- .../org/apache/spark/sql/hive/hiveUdfs.scala | 4 +- .../sql/hive/execution/HiveQuerySuite.scala | 26 +- 35 files changed, 1493 insertions(+), 64 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GeneratedRow.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CodeGenerationSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 599714233c18..d5f571cf94c0 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -495,6 +495,9 @@ object SparkBuild extends Build { // assumptions about the the expression ids being contiguous. Running tests in parallel breaks // this non-deterministically. TODO: FIX THIS. 
parallelExecution in Test := false,
+      addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.0-M8" cross CrossVersion.full),
+      libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "scala-compiler" % v),
+      libraryDependencies += "org.scalamacros" %% "quasiquotes" % "2.0.0-M8",
       libraryDependencies ++= Seq(
         "com.typesafe" %% "scalalogging-slf4j" % "1.0.1"
       )
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index c7188469bfb8..9571a57e8f5d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -158,8 +158,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
    */
   object ImplicitGenerate extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-      case Project(Seq(Alias(g: Generator, _)), child) =>
-        Generate(g, join = false, outer = false, None, child)
+      case Project(Seq(Alias(g: Generator, alias)), child) =>
+        Generate(g, join = false, outer = false, Some(alias), child)
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 76ddeba9cb31..8e011acaab7e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -259,7 +259,7 @@ trait HiveTypeCoercion {
       case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0)))
       // Turn true into 1, and false into 0 if casting boolean into other types.
       case Cast(e, dataType) if e.dataType == BooleanType =>
-        Cast(If(e, Literal(1), Literal(0)), dataType)
+        If(e, Cast(Literal(1), dataType), Cast(Literal(0), dataType))
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 9ce1f0105646..5d6c95eba391 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -62,6 +62,7 @@ class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] {
     plan.transform {
+      case n: NoBind => n.asInstanceOf[TreeNode]
       case leafNode if leafNode.children.isEmpty => leafNode
       case unaryNode if unaryNode.children.size == 1 => unaryNode.transformExpressions { case e =>
         bindReference(e, unaryNode.children.head.output)
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GeneratedRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GeneratedRow.scala
new file mode 100644
index 000000000000..2147921d0d9d
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GeneratedRow.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst._ + +import org.apache.spark.sql.catalyst.types._ + +object CodeGeneration + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 2c71d2c7b356..53cae54eb8e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -22,7 +22,7 @@ package org.apache.spark.sql.catalyst.expressions * new row. If the schema of the input row is specified, then the given expression will be bound to * that schema. */ -class Projection(expressions: Seq[Expression]) extends (Row => Row) { +class InterpretedProjection(expressions: Seq[Expression]) extends (Row => Row) { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) @@ -40,7 +40,7 @@ class Projection(expressions: Seq[Expression]) extends (Row => Row) { } /** - * Converts a [[Row]] to another Row given a sequence of expression that define each column of th + * Converts a [[Row]] to another Row given a sequence of expression that define each column of the * new row. If the schema of the input row is specified, then the given expression will be bound to * that schema. * @@ -50,14 +50,19 @@ class Projection(expressions: Seq[Expression]) extends (Row => Row) { * has been called on the [[Iterator]] that produced it. Instead, the user must call `Row.copy()` * and hold on to the returned [[Row]] before calling `next()`. */ -case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row) { +case class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) private[this] val exprArray = expressions.toArray - private[this] val mutableRow = new GenericMutableRow(exprArray.size) + private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.size) def currentValue: Row = mutableRow + def target(row: MutableRow): MutableProjection = { + mutableRow = row + this + } + def apply(input: Row): Row = { var i = 0 while (i < exprArray.length) { @@ -76,6 +81,12 @@ class JoinedRow extends Row { private[this] var row1: Row = _ private[this] var row2: Row = _ + def this(left: Row, right: Row) = { + this() + row1 = left + row2 = right + } + /** Updates this JoinedRow to used point at two new base rows. Returns itself. 
*/ def apply(r1: Row, r2: Row): Row = { row1 = r1 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index 74ae723686cf..1a2ac7285b98 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -180,6 +180,34 @@ class GenericRow(protected[catalyst] val values: Array[Any]) extends Row { values(i).asInstanceOf[String] } + override def hashCode(): Int = { + var result: Int = 37 + + var i = 0 + while (i < values.length) { + val update: Int = + if (isNullAt(i)) { + 0 + } else { + apply(i) match { + case b: Boolean => if (b) 0 else 1 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case other => other.hashCode() + } + } + result = 37 * result + update + i += 1 + } + result + } + def copy() = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 5e089f7618e0..acddf5e9c700 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -29,6 +29,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi override def eval(input: Row): Any = { children.size match { + case 0 => function.asInstanceOf[() => Any]() case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input)) case 2 => function.asInstanceOf[(Any, Any) => Any]( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala new file mode 100644 index 000000000000..a7afd3932dfc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -0,0 +1,444 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import scala.language.existentials + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types._ + +/** + * A base class for generators of byte code that performs expression evaluation. 
Includes helpers
+ * for referring to Catalyst types and building trees that perform evaluation of individual
+ * expressions.
+ */
+abstract class CodeGenerator extends Logging {
+  import scala.reflect.runtime.{universe => ru}
+  import scala.reflect.runtime.universe._
+
+  import scala.tools.reflect.ToolBox
+
+  val toolBox = runtimeMirror(getClass.getClassLoader).mkToolBox()
+
+  // TODO: Use typetags?
+  val rowType = tq"org.apache.spark.sql.catalyst.expressions.Row"
+  val mutableRowType = tq"org.apache.spark.sql.catalyst.expressions.MutableRow"
+  val genericRowType = tq"org.apache.spark.sql.catalyst.expressions.GenericRow"
+  val genericMutableRowType = tq"org.apache.spark.sql.catalyst.expressions.GenericMutableRow"
+
+  val projectionType = tq"org.apache.spark.sql.catalyst.expressions.Projection"
+  val mutableProjectionType = tq"org.apache.spark.sql.catalyst.expressions.MutableProjection"
+
+  private val curId = new java.util.concurrent.atomic.AtomicInteger()
+  private val javaSeparator = "$"
+
+  /**
+   * Returns a term name that is unique within this instance of a `CodeGenerator`.
+   *
+   * (Since we aren't in a macro context we do not seem to have access to the built-in `freshName`
+   * function.)
+   */
+  protected def freshName(prefix: String): TermName = {
+    newTermName(s"$prefix$javaSeparator${curId.getAndIncrement}")
+  }
+
+  /**
+   * Scala ASTs for evaluating an [[Expression]] given a [[Row]] of input.
+   *
+   * @param code The sequence of statements required to evaluate the expression.
+   * @param nullTerm A term that holds a boolean value representing whether the expression evaluated
+   *                 to null.
+   * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not
+   *                      valid if `nullTerm` is set to `true`.
+   * @param objectTerm A possibly boxed version of the result of evaluating this expression.
+   */
+  protected case class EvaluatedExpression(
+      code: Seq[Tree],
+      nullTerm: TermName,
+      primitiveTerm: TermName,
+      objectTerm: TermName) {
+
+    def withObjectTerm = ???
+  }
+
+  /**
+   * Given an expression tree returns the code required to determine both if the result is NULL
+   * as well as the code required to compute the value.
+   */
+  def expressionEvaluator(e: Expression): EvaluatedExpression = {
+    val primitiveTerm = freshName("primitiveTerm")
+    val nullTerm = freshName("nullTerm")
+    val objectTerm = freshName("objectTerm")
+
+    implicit class Evaluate1(e: Expression) {
+      def castOrNull(f: TermName => Tree, dataType: DataType): Seq[Tree] = {
+        val eval = expressionEvaluator(e)
+        eval.code ++
+        q"""
+          val $nullTerm = ${eval.nullTerm}
+          val $primitiveTerm =
+            if($nullTerm)
+              ${defaultPrimitive(dataType)}
+            else
+              ${f(eval.primitiveTerm)}
+        """.children
+      }
+    }
+
+    implicit class Evaluate2(expressions: (Expression, Expression)) {
+
+      /**
+       * Shorthand for generating binary evaluation code, which depends on two sub-evaluations of
+       * the same type. If either of the sub-expressions is null, the result of this computation
+       * is assumed to be null.
+       *
+       * @param f a function from two primitive term names to a tree that evaluates them.
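+       *
+       * For illustration only (editor's sketch, not part of the original patch), the
+       * [[Add]] case below uses this helper as:
+       * {{{
+       *   case Add(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" }
+       * }}}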
+       */
+      def evaluate(f: (TermName, TermName) => Tree): Seq[Tree] =
+        evaluateAs(expressions._1.dataType)(f)
+
+      def evaluateAs(resultType: DataType)(f: (TermName, TermName) => Tree): Seq[Tree] = {
+        require(expressions._1.dataType == expressions._2.dataType,
+          s"${expressions._1.dataType} != ${expressions._2.dataType}")
+
+        val eval1 = expressionEvaluator(expressions._1)
+        val eval2 = expressionEvaluator(expressions._2)
+        val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm)
+
+        eval1.code ++ eval2.code ++
+        q"""
+          val $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm}
+          val $primitiveTerm: ${termForType(resultType)} =
+            if($nullTerm) {
+              ${defaultPrimitive(resultType)}
+            } else {
+              $resultCode.asInstanceOf[${termForType(resultType)}]
+            }
+        """.children : Seq[Tree]
+      }
+    }
+
+    val inputTuple = newTermName(s"i")
+
+    // TODO: Skip generation of null handling code when expressions are not nullable.
+    val primitiveEvaluation: PartialFunction[Expression, Seq[Tree]] = {
+      case b @ BoundReference(ordinal, _) =>
+        q"""
+          val $nullTerm: Boolean = $inputTuple.isNullAt($ordinal)
+          val $primitiveTerm: ${termForType(b.dataType)} =
+            if($nullTerm)
+              ${defaultPrimitive(e.dataType)}
+            else
+              ${getColumn(inputTuple, b.dataType, ordinal)}
+        """.children
+
+      case expressions.Literal(null, dataType) =>
+        q"""
+          val $nullTerm = true
+          val $primitiveTerm: ${termForType(dataType)} = null.asInstanceOf[${termForType(dataType)}]
+        """.children
+
+      case expressions.Literal(value: Boolean, dataType) =>
+        q"""
+          val $nullTerm = ${value == null}
+          val $primitiveTerm: ${termForType(dataType)} = $value
+        """.children
+
+      case expressions.Literal(value: String, dataType) =>
+        q"""
+          val $nullTerm = ${value == null}
+          val $primitiveTerm: ${termForType(dataType)} = $value
+        """.children
+
+      case expressions.Literal(value: Int, dataType) =>
+        q"""
+          val $nullTerm = ${value == null}
+          val $primitiveTerm: ${termForType(dataType)} = $value
+        """.children
+
+      case expressions.Literal(value: Long, dataType) =>
+        q"""
+          val $nullTerm = ${value == null}
+          val $primitiveTerm: ${termForType(dataType)} = $value
+        """.children
+
+      case Cast(e @ BinaryType(), StringType) =>
+        val eval = expressionEvaluator(e)
+        eval.code ++
+        q"""
+          val $nullTerm = ${eval.nullTerm}
+          val $primitiveTerm =
+            if($nullTerm)
+              ${defaultPrimitive(StringType)}
+            else
+              new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]])
+        """.children
+
+      case Cast(child @ NumericType(), IntegerType) =>
+        child.castOrNull(c => q"$c.toInt", IntegerType)
+
+      case Cast(child @ NumericType(), LongType) =>
+        child.castOrNull(c => q"$c.toLong", LongType)
+
+      case Cast(child @ NumericType(), DoubleType) =>
+        child.castOrNull(c => q"$c.toDouble", DoubleType)
+
+      case Cast(child @ NumericType(), FloatType) =>
+        child.castOrNull(c => q"$c.toFloat", FloatType)
+
+      case Cast(e, StringType) =>
+        val eval = expressionEvaluator(e)
+        eval.code ++
+        q"""
+          val $nullTerm = ${eval.nullTerm}
+          val $primitiveTerm =
+            if($nullTerm)
+              ${defaultPrimitive(StringType)}
+            else
+              ${eval.primitiveTerm}.toString
+        """.children
+
+      case EqualTo(e1, e2) =>
+        (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" }
+
+      case In(e1, list) if !list.exists(!_.isInstanceOf[expressions.Literal]) =>
+        val eval = expressionEvaluator(e1)
+
+        val checks = list.map {
+          case expressions.Literal(v: String, dataType) =>
+            q"if(${eval.primitiveTerm} == $v) return true"
+          case expressions.Literal(v: Int, dataType) =>
+            q"if(${eval.primitiveTerm} == $v) return true"
+        }
+
+        val funcName =
newTermName(s"isIn${curId.getAndIncrement()}") + + q""" + def $funcName: Boolean = { + ..${eval.code} + if(${eval.nullTerm}) return false + ..$checks + return false + } + val $nullTerm = false + val $primitiveTerm = $funcName + """.children + + case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) => + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 > $eval2" } + case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 >= $eval2" } + case LessThan(e1 @ NumericType(), e2 @ NumericType()) => + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 < $eval2" } + case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 <= $eval2" } + + case And(e1, e2) => + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) + + eval1.code ++ eval2.code ++ + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = false + + if ((!${eval1.nullTerm} && !${eval1.primitiveTerm}) || + (!${eval2.nullTerm} && !${eval2.primitiveTerm})) { + $nullTerm = false + $primitiveTerm = false + } else if (${eval1.nullTerm} || ${eval2.nullTerm} ) { + $nullTerm = true + } else { + $nullTerm = false + $primitiveTerm = true + } + """.children + + case Or(e1, e2) => + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) + + eval1.code ++ eval2.code ++ + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = false + + if ((!${eval1.nullTerm} && ${eval1.primitiveTerm}) || + (!${eval2.nullTerm} && ${eval2.primitiveTerm})) { + $nullTerm = false + $primitiveTerm = true + } else if (${eval1.nullTerm} || ${eval2.nullTerm} ) { + $nullTerm = true + } else { + $nullTerm = false + $primitiveTerm = false + } + """.children + + case Not(child) => + // Uh, bad function name... 
+ child.castOrNull(c => q"!$c", BooleanType) + + case Add(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" } + case Subtract(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 - $eval2" } + case Multiply(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 * $eval2" } + case Divide(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 / $eval2" } + + case IsNotNull(e) => + val eval = expressionEvaluator(e) + q""" + ..${eval.code} + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = !${eval.nullTerm} + """.children + + case IsNull(e) => + val eval = expressionEvaluator(e) + q""" + ..${eval.code} + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = ${eval.nullTerm} + """.children + + case c @ Coalesce(children) => + q""" + var $nullTerm = true + var $primitiveTerm: ${termForType(c.dataType)} = ${defaultPrimitive(c.dataType)} + """.children ++ + children.map { c => + val eval = expressionEvaluator(c) + q""" + if($nullTerm) { + ..${eval.code} + if(!${eval.nullTerm}) { + $nullTerm = false + $primitiveTerm = ${eval.primitiveTerm} + } + } + """ + } + + case i @ expressions.If(condition, trueValue, falseValue) => + val condEval = expressionEvaluator(condition) + val trueEval = expressionEvaluator(trueValue) + val falseEval = expressionEvaluator(falseValue) + + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(i.dataType)} = ${defaultPrimitive(i.dataType)} + ..${condEval.code} + if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) { + ..${trueEval.code} + $nullTerm = ${trueEval.nullTerm} + $primitiveTerm = ${trueEval.primitiveTerm} + } else { + ..${falseEval.code} + $nullTerm = ${falseEval.nullTerm} + $primitiveTerm = ${falseEval.primitiveTerm} + } + """.children + + case SubString(str, start, end) => + val stringEval = expressionEvaluator(str) + val startEval = expressionEvaluator(start) + val endEval = expressionEvaluator(end) + + stringEval.code ++ startEval.code ++ endEval.code ++ + q""" + var $nullTerm = ${stringEval.nullTerm} + var $primitiveTerm: String = + if($nullTerm) { + null + } else { + val len = + if(${endEval.primitiveTerm} <= ${stringEval.primitiveTerm}.length) + ${endEval.primitiveTerm} + else + ${stringEval.primitiveTerm}.length + ${stringEval.primitiveTerm}.substring(${startEval.primitiveTerm}, len) + } + """.children + } + + // If there was no match in the partial function above, we fall back on calling the interpreted + // expression evaluator. 
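+    // For example (editor's illustration): a ScalaUdf has no rule above, so the fallback
+    // below embeds the expression object itself in the generated tree via reify and calls
+    // its interpreted eval(i) at runtime.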
+ val code: Seq[Tree] = + primitiveEvaluation.lift.apply(e) + .getOrElse { + log.debug(s"No rules to generate $e") + val tree = reify { e } + q""" + val $objectTerm = $tree.eval(i) + val $nullTerm = $objectTerm == null + val $primitiveTerm = $objectTerm.asInstanceOf[${termForType(e.dataType)}] + """.children + } + + EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm) + } + + protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { + dataType match { + case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)" + case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]" + } + } + + protected def setColumn( + destinationRow: TermName, + dataType: DataType, + ordinal: Int, + value: TermName) = { + dataType match { + case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" + case _ => q"$destinationRow.update($ordinal, $value)" + } + } + + protected def accessorForType(dt: DataType) = newTermName(s"get${primitiveForType(dt)}") + protected def mutatorForType(dt: DataType) = newTermName(s"set${primitiveForType(dt)}") + + protected def primitiveForType(dt: DataType) = dt match { + case IntegerType => "Int" + case LongType => "Long" + case ShortType => "Short" + case ByteType => "Byte" + case DoubleType => "Double" + case FloatType => "Float" + case BooleanType => "Boolean" + case StringType => "String" + } + + protected def defaultPrimitive(dt: DataType) = dt match { + case BooleanType => ru.Literal(Constant(false)) + case FloatType => ru.Literal(Constant(-1.0.toFloat)) + case StringType => ru.Literal(Constant("")) + case ShortType => ru.Literal(Constant(-1.toShort)) + case LongType => ru.Literal(Constant(1L)) + case ByteType => ru.Literal(Constant(-1.toByte)) + case DoubleType => ru.Literal(Constant(-1.toDouble)) + case DecimalType => ru.Literal(Constant(-1)) // Will get implicity converted as needed. + case IntegerType => ru.Literal(Constant(-1)) + case _ => ru.Literal(Constant(null)) + } + + protected def termForType(dt: DataType) = dt match { + case n: NativeType => n.tag + case _ => typeTag[Any] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala new file mode 100644 index 000000000000..09e9af387c9b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.sql.catalyst.expressions._ + +/** + * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new + * input [[Row]] for a fixed set of [[Expression Expressions]]. + */ +object GenerateMutableProjection extends CodeGenerator { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + // TODO: Should be weak references... bounded in size. + val projectionCache = new collection.mutable.HashMap[Seq[Expression], () => MutableProjection] + + def apply(expressions: Seq[Expression], inputSchema: Seq[Attribute]): (() => MutableProjection) = + apply(expressions.map(BindReferences.bindReference(_, inputSchema))) + + // TODO: Safe to fire up multiple instances of the compiler? + def apply(expressions: Seq[Expression]): () => MutableProjection = + CodeGeneration.synchronized { + val cleanedExpressions = expressions.map(ExpressionCanonicalizer(_)) + projectionCache.getOrElseUpdate(cleanedExpressions, createProjection(cleanedExpressions)) + } + + val mutableRowName = newTermName("mutableRow") + + def createProjection(expressions: Seq[Expression]): (() => MutableProjection) = { + val projectionCode = expressions.zipWithIndex.flatMap { case (e, i) => + val evaluationCode = expressionEvaluator(e) + + evaluationCode.code :+ + q""" + if(${evaluationCode.nullTerm}) + mutableRow.setNullAt($i) + else + ${setColumn(mutableRowName, e.dataType, i, evaluationCode.primitiveTerm)} + """ + } + + val code = + q""" + () => { new $mutableProjectionType { + + private[this] var $mutableRowName: $mutableRowType = + new $genericMutableRowType(${expressions.size}) + + def target(row: $mutableRowType): $mutableProjectionType = { + $mutableRowName = row + this + } + + /* Provide immutable access to the last projected row. */ + def currentValue: $rowType = mutableRow + + def apply(i: $rowType): $rowType = { + ..$projectionCode + mutableRow + } + } } + """ + + log.debug(s"code for ${expressions.mkString(",")}:\n$code") + toolBox.eval(code).asInstanceOf[() => MutableProjection] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala new file mode 100644 index 000000000000..27cb6779ade1 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.sql.catalyst.expressions._
+
+/**
+ * Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of
+ * [[Expression Expressions]].
+ */
+object GenerateOrdering extends CodeGenerator {
+  import scala.reflect.runtime.{universe => ru}
+  import scala.reflect.runtime.universe._
+
+  // TODO: Should be weak references... bounded in size.
+  val orderingCache = new collection.mutable.HashMap[Seq[SortOrder], Ordering[Row]]
+
+  // TODO: Safe to fire up multiple instances of the compiler?
+  def apply(ordering: Seq[SortOrder]): Ordering[Row] = CodeGeneration.synchronized {
+    val cleanedExpression = ordering.map(ExpressionCanonicalizer(_)).asInstanceOf[Seq[SortOrder]]
+    orderingCache.getOrElseUpdate(cleanedExpression, createOrdering(cleanedExpression))
+  }
+
+  def createOrdering(ordering: Seq[SortOrder]): Ordering[Row] = {
+    val a = newTermName("a")
+    val b = newTermName("b")
+    val comparisons = ordering.zipWithIndex.map { case (order, i) =>
+      val evalA = expressionEvaluator(order.child)
+      val evalB = expressionEvaluator(order.child)
+
+      q"""
+        i = $a
+        ..${evalA.code}
+        i = $b
+        ..${evalB.code}
+        if (${evalA.nullTerm} && ${evalB.nullTerm}) {
+          // Nothing
+        } else if (${evalA.nullTerm}) {
+          return ${if (order.direction == Ascending) q"-1" else q"1"}
+        } else if (${evalB.nullTerm}) {
+          return ${if (order.direction == Ascending) q"1" else q"-1"}
+        } else {
+          val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm}
+          if(comp != 0) {
+            return ${if (order.direction == Ascending) q"comp.toInt" else q"-comp.toInt"}
+          }
+        }
+      """
+    }
+
+    val q"class $orderingName extends $orderingType { ..$body }" = reify {
+      class SpecificOrdering extends Ordering[Row] {
+        val o = ordering
+      }
+    }.tree.children.head
+
+    val code = q"""
+      class $orderingName extends $orderingType {
+        ..$body
+        def compare(a: $rowType, b: $rowType): Int = {
+          var i: $rowType = null // Holds current row being evaluated.
+          ..$comparisons
+          return 0
+        }
+      }
+      new $orderingName()
+    """
+    toolBox.eval(code).asInstanceOf[Ordering[Row]]
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
new file mode 100644
index 000000000000..3748e9a27a84
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.sql.catalyst.expressions._
+
+/**
+ * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[Row]].
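+ *
+ * A usage sketch (editor's illustration; `iter` is a hypothetical iterator of already-bound
+ * input rows):
+ * {{{
+ *   val pred: Row => Boolean = GeneratePredicate(EqualTo(Literal(1), Literal(1)))
+ *   val matching = iter.filter(pred)
+ * }}}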
+ */ +object GeneratePredicate extends CodeGenerator { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + // TODO: Should be weak references... bounded in size. + val predicateCache = new collection.mutable.HashMap[Expression, (Row) => Boolean] + + // TODO: Safe to fire up multiple instances of the compiler? + def apply(predicate: Expression): (Row => Boolean) = CodeGeneration.synchronized { + val cleanedExpression = ExpressionCanonicalizer(predicate) + predicateCache.getOrElseUpdate(cleanedExpression, createPredicate(cleanedExpression)) + } + + def apply(predicate: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = + apply(BindReferences.bindReference(predicate, inputSchema)) + + def createPredicate(predicate: Expression): ((Row) => Boolean) = { + val cEval = expressionEvaluator(predicate) + + val code = + q""" + (i: $rowType) => { + ..${cEval.code} + if (${cEval.nullTerm}) false else ${cEval.primitiveTerm} + } + """ + + log.debug(s"Generated predicate '$predicate':\n$code") + toolBox.eval(code).asInstanceOf[Row => Boolean] + } +} \ No newline at end of file diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala new file mode 100644 index 000000000000..8b7ae7e4dbd7 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types._ + + +/** + * Generates bytecode that produces a new [[Row]] object based on a fixed set of input + * [[Expression Expressions]] and a given input [[Row]]. The returned [[Row]] object is custom + * generated based on the output types of the [[Expression]] to avoid boxing of primitive values. + */ +object GenerateProjection extends CodeGenerator { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + // TODO: Should be weak references... bounded in size. + val projectionCache = new collection.mutable.HashMap[Seq[Expression], Projection] + + def apply(expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = + apply(expressions.map(BindReferences.bindReference(_, inputSchema))) + + // TODO: Safe to fire up multiple instances of the compiler? 
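+  // Editor's note: because expressions are canonicalized before the cache lookup below,
+  // two projections that differ only in attribute names (e.g. over a#1 + 1 versus b#2 + 1)
+  // hit the same cache entry and share a single compiled Projection.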
+ def apply(expressions: Seq[Expression]): Projection = CodeGeneration.synchronized { + val cleanedExpressions = expressions.map(ExpressionCanonicalizer(_)) + projectionCache.getOrElseUpdate(cleanedExpressions, createProjection(cleanedExpressions)) + } + + // Make Mutablility optional... + def createProjection(expressions: Seq[Expression]): Projection = { + val tupleLength = ru.Literal(Constant(expressions.length)) + val lengthDef = q"final val length = $tupleLength" + + /* TODO: Configurable... + val nullFunctions = + q""" + private final val nullSet = new org.apache.spark.util.collection.BitSet(length) + final def setNullAt(i: Int) = nullSet.set(i) + final def isNullAt(i: Int) = nullSet.get(i) + """ + */ + + val nullFunctions = + q""" + private[this] var nullBits = new Array[Boolean](${expressions.size}) + final def setNullAt(i: Int) = { nullBits(i) = true } + final def isNullAt(i: Int) = nullBits(i) + """.children + + val tupleElements = expressions.zipWithIndex.flatMap { + case (e, i) => + val elementName = newTermName(s"c$i") + val evaluatedExpression = expressionEvaluator(e) + val iLit = ru.Literal(Constant(i)) + + q""" + var ${newTermName(s"c$i")}: ${termForType(e.dataType)} = _ + { + ..${evaluatedExpression.code} + if(${evaluatedExpression.nullTerm}) + setNullAt($iLit) + else + $elementName = ${evaluatedExpression.primitiveTerm} + } + """.children : Seq[Tree] + } + + val iteratorFunction = { + val allColumns = (0 until expressions.size).map { i => + val iLit = ru.Literal(Constant(i)) + q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }" + } + q"final def iterator = Iterator[Any](..$allColumns)" + } + + val accessorFailure = q"""scala.sys.error("Invalid ordinal:" + i)""" + val applyFunction = { + val cases = (0 until expressions.size).map { i => + val ordinal = ru.Literal(Constant(i)) + val elementName = newTermName(s"c$i") + val iLit = ru.Literal(Constant(i)) + + q"if(i == $ordinal) { if(isNullAt($i)) return null else return $elementName }" + } + q"final def apply(i: Int): Any = { ..$cases; $accessorFailure }" + } + + val updateFunction = { + val cases = expressions.zipWithIndex.map {case (e, i) => + val ordinal = ru.Literal(Constant(i)) + val elementName = newTermName(s"c$i") + val iLit = ru.Literal(Constant(i)) + + q""" + if(i == $ordinal) { + if(value == null) { + setNullAt(i) + } else { + $elementName = value.asInstanceOf[${termForType(e.dataType)}] + return + } + }""" + } + q"final def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }" + } + + val specificAccessorFunctions = NativeType.all.map { dataType => + val ifStatements = expressions.zipWithIndex.flatMap { + case (e, i) if e.dataType == dataType => + val elementName = newTermName(s"c$i") + // TODO: The string of ifs gets pretty inefficient as the row grows in size. + // TODO: Optional null checks? + q"if(i == $i) return $elementName" :: Nil + case _ => Nil + } + + q""" + final def ${accessorForType(dataType)}(i: Int):${termForType(dataType)} = { + ..$ifStatements; + $accessorFailure + }""" + } + + val specificMutatorFunctions = NativeType.all.map { dataType => + val ifStatements = expressions.zipWithIndex.flatMap { + case (e, i) if e.dataType == dataType => + val elementName = newTermName(s"c$i") + // TODO: The string of ifs gets pretty inefficient as the row grows in size. + // TODO: Optional null checks? 
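+          // For example (editor's sketch), two IntegerType columns at ordinals 0 and 2
+          // unroll to roughly:
+          //   final def setInt(i: Int, value: Int): Unit = {
+          //     if(i == 0) { c0 = value; return }
+          //     if(i == 2) { c2 = value; return }
+          //     scala.sys.error("Invalid ordinal:" + i)
+          //   }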
+ q"if(i == $i) { $elementName = value; return }" :: Nil + case _ => Nil + } + + q""" + final def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}): Unit = { + ..$ifStatements; + $accessorFailure + }""" + } + + val hashValues = expressions.zipWithIndex.map { case (e,i) => + val elementName = newTermName(s"c$i") + val nonNull = e.dataType match { + case BooleanType => q"if ($elementName) 0 else 1" + case ByteType | ShortType | IntegerType => q"$elementName.toInt" + case LongType => q"($elementName ^ ($elementName >>> 32)).toInt" + case FloatType => q"java.lang.Float.floatToIntBits($elementName)" + case DoubleType => + q"{ val b = java.lang.Double.doubleToLongBits($elementName); (b ^ (b >>>32)).toInt }" + case _ => q"$elementName.hashCode" + } + q"if (isNullAt($i)) 0 else $nonNull" + } + + val hashUpdates: Seq[Tree] = hashValues.map(v => q"""result = 37 * result + $v""": Tree) + + val hashCodeFunction = + q""" + override def hashCode(): Int = { + var result: Int = 37 + ..$hashUpdates + result + } + """ + + val columnChecks = (0 until expressions.size).map { i => + val elementName = newTermName(s"c$i") + q"if (this.$elementName != specificType.$elementName) return false" + } + + val equalsFunction = + q""" + override def equals(other: Any): Boolean = other match { + case specificType: SpecificRow => + ..$columnChecks + return true + case other => super.equals(other) + } + """ + + val classBody = + nullFunctions ++ ( + lengthDef +: + iteratorFunction +: + applyFunction +: + updateFunction +: + equalsFunction +: + hashCodeFunction +: + (tupleElements ++ specificAccessorFunctions ++ specificMutatorFunctions)) + + val code = q""" + final class SpecificRow(i: $rowType) extends $mutableRowType { + ..$classBody + + // Not safe! + final def copy() = scala.sys.error("Not implemented") + + final def getStringBuilder(ordinal: Int): StringBuilder = ??? + } + + new $projectionType { def apply(r: $rowType) = new SpecificRow(r) } + """ + + log.debug( + s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${toolBox.typeCheck(code)}") + toolBox.eval(code).asInstanceOf[Projection] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala new file mode 100644 index 000000000000..f26d3d3f7b02 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.rules +import org.apache.spark.sql.catalyst.util + +/** + * A collection of generators that build custom bytecode at runtime for performing the evaluation + * of catalyst expression. + */ +package object codegen { + + /** + * A lock to protect invoking the scala compiler at runtime, since it is not thread safe in Scala + * 2.10. + */ + protected val globalLock = new Object() + + /** Canonicalizes an expression so those that differ only by names can reuse the same code. */ + object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] { + val batches = + Batch("CleanExpressions", FixedPoint(20), CleanExpressions) :: Nil + + object CleanExpressions extends rules.Rule[Expression] { + def apply(e: Expression): Expression = e transform { + case BoundReference(o, a) => + BoundReference(o, AttributeReference("a", a.dataType, a.nullable)(exprId = ExprId(0))) + case Alias(c, _) => c + } + } + } + + /** + * :: DeveloperApi :: + * Dumps the bytecode from a class to the screen using javap. + */ + @DeveloperApi + object DumpByteCode { + import scala.sys.process._ + val dumpDirectory = util.getTempFilePath("sparkSqlByteCode") + dumpDirectory.mkdir() + + def apply(obj: Any): Unit = { + val generatedClass = obj.getClass + val classLoader = + generatedClass + .getClassLoader + .asInstanceOf[scala.tools.nsc.interpreter.AbstractFileClassLoader] + val generatedBytes = classLoader.classBytes(generatedClass.getName) + + val packageDir = new java.io.File(dumpDirectory, generatedClass.getPackage.getName) + if (!packageDir.exists()) { packageDir.mkdir() } + + val classFile = + new java.io.File(packageDir, generatedClass.getName.split("\\.").last + ".class") + + val outfile = new java.io.FileOutputStream(classFile) + outfile.write(generatedBytes) + outfile.close() + + println( + s"javap -p -v -classpath ${dumpDirectory.getCanonicalPath} ${generatedClass.getName}".!!) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index b6f2451b52e1..c609e6b92904 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -47,4 +47,14 @@ package org.apache.spark.sql.catalyst * ==Evaluation== * The result of expressions can be evaluated using the `Expression.apply(Row)` method. 
 */
-package object expressions
+package object expressions {
+
+  abstract class Projection extends (Row => Row)
+
+  abstract class MutableProjection extends Projection {
+    def currentValue: Row
+
+    /* Updates the target of this projection to a new MutableRow. */
+    def target(row: MutableRow): MutableProjection
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 347471cebdc7..ccfe3f1a6db8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -23,6 +23,16 @@ import org.apache.spark.sql.catalyst.types.DataType
 import org.apache.spark.sql.catalyst.types.StringType
 import org.apache.spark.sql.catalyst.types.BooleanType
 
+case class SubString(string: Expression, start: Expression, end: Expression) extends Expression {
+  def children = string :: start :: end :: Nil
+  def references = children.flatMap(_.references).toSet
+  def dataType = StringType
+  def nullable = string.nullable
+
+  override def eval(input: Row) = ???
+
+  override def toString = s"substr($string, $start, $end)"
+}
 
 trait StringRegexExpression {
   self: BinaryExpression =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
index e32adb76fe14..144c8530564c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -72,7 +72,9 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
         }
         iteration += 1
         if (iteration > batch.strategy.maxIterations) {
-          logger.info(s"Max iterations ($iteration) reached for batch ${batch.name}")
+          if (iteration != 2) {
+            logger.info(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}")
+          }
           continue = false
         }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index bb77bccf8617..38bb91974760 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference,
 import org.apache.spark.util.Utils
 
 /**
- *
+ * Utility functions for working with DataTypes.
  */
 object DataType extends RegexParsers {
   protected lazy val primitiveType: Parser[DataType] =
@@ -97,6 +97,13 @@ abstract class DataType {
 
 case object NullType extends DataType
 
+object NativeType {
+  def all = Seq(
+    IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)
+
+  def unapply(dt: DataType): Boolean = all.contains(dt)
+}
+
 trait PrimitiveType extends DataType {
   override def isPrimitive = true
 }
@@ -145,6 +152,13 @@ abstract class NumericType extends NativeType with PrimitiveType {
   val numeric: Numeric[JvmType]
 }
+
+object NumericType {
+  def unapply(a: Expression): Boolean = a match {
+    case e: Expression if e.dataType.isInstanceOf[NumericType] => true
+    case _ => false
+  }
+}
+
 /** Matcher for any expressions that evaluate to [[IntegralType]]s */
 object IntegralType {
   def unapply(a: Expression):
Boolean = a match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CodeGenerationSuite.scala new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 84d72814778b..56840d3491b1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -29,7 +29,11 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ class ExpressionEvaluationSuite extends FunSuite { test("literals") { - assert((Literal(1) + Literal(1)).eval(null) === 2) + checkEvaluation(Literal(1), 1) + checkEvaluation(Literal(true), true) + checkEvaluation(Literal(0L), 0L) + checkEvaluation(Literal("test"), "test") + checkEvaluation(Literal(1) + Literal(1), 2) } /** @@ -61,10 +65,8 @@ class ExpressionEvaluationSuite extends FunSuite { test("3VL Not") { notTrueTable.foreach { case (v, answer) => - val expr = ! Literal(v, BooleanType) - val result = expr.eval(null) - if (result != answer) - fail(s"$expr should not evaluate to $result, expected: $answer") } + checkEvaluation(!Literal(v, BooleanType), answer) + } } booleanLogicTest("AND", _ && _, @@ -127,6 +129,13 @@ class ExpressionEvaluationSuite extends FunSuite { } } + test("IN") { + checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) + checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true) + } + test("LIKE literal Regular Expression") { checkEvaluation(Literal(null, StringType).like("a"), null) checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null) @@ -232,21 +241,21 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Literal(false) cast IntegerType, 0) checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1) checkEvaluation(Cast(Literal(0) cast BooleanType, IntegerType), 0) - checkEvaluation("23" cast DoubleType, 23) + checkEvaluation("23" cast DoubleType, 23d) checkEvaluation("23" cast IntegerType, 23) - checkEvaluation("23" cast FloatType, 23) - checkEvaluation("23" cast DecimalType, 23) - checkEvaluation("23" cast ByteType, 23) - checkEvaluation("23" cast ShortType, 23) + checkEvaluation("23" cast FloatType, 23f) + checkEvaluation("23" cast DecimalType, 23: BigDecimal) + checkEvaluation("23" cast ByteType, 23.toByte) + checkEvaluation("23" cast ShortType, 23.toShort) checkEvaluation("2012-12-11" cast DoubleType, null) checkEvaluation(Literal(123) cast IntegerType, 123) - checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24) + checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24d) checkEvaluation(Literal(23) + Cast(true, IntegerType), 24) - checkEvaluation(Literal(23f) + Cast(true, FloatType), 24) - checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24) - checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24) - checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24) + checkEvaluation(Literal(23f) + Cast(true, FloatType), 24f) + checkEvaluation(Literal(BigDecimal(23)) + Cast(true, 
DecimalType), 24: BigDecimal) + checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24.toByte) + checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24.toShort) intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala new file mode 100644 index 000000000000..438735990c11 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ + +/** + * Overrides our expression evaluation tests to use code generation for evaluation. + */ +class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { + val generator = new CodeGenerator() {} + + override def checkEvaluation( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + val plan = try { + GenerateMutableProjection(Alias(expression, s"Optimized($expression)")() :: Nil)() + } catch { + case e: Throwable => + val evaluated = generator.expressionEvaluator(expression) + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code.mkString("\n")} + |$e + """.stripMargin) + } + + val actual = plan(inputRow).apply(0) + if(actual != expected) { + val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + + + test("multithreaded eval") { + import scala.concurrent._ + import ExecutionContext.Implicits.global + import scala.concurrent.duration._ + + val futures = (1 to 20).map { _ => + future { + GeneratePredicate(EqualTo(Literal(1), Literal(1))) + GenerateProjection(EqualTo(Literal(1), Literal(1)) :: Nil) + GenerateMutableProjection(EqualTo(Literal(1), Literal(1)) :: Nil) + GenerateOrdering(Add(Literal(1), Literal(1)).asc :: Nil) + } + } + + futures.foreach(Await.result(_, 10 seconds)) + } +} + +/** + * Overrides our expression evaluation tests to use generated code on mutable rows. 
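+ *
+ * Editor's sketch of the extra check performed here: the generated row must agree with a
+ * [[GenericRow]] holding the same values on both equals() and hashCode(), e.g.
+ * {{{
+ *   checkEvaluation(Literal(1) + Literal(1), 2)  // also compares the two rows' hashCodes
+ * }}}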
+ */ +class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { + val generator = new CodeGenerator() {} + + override def checkEvaluation( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + lazy val evaluated = generator.expressionEvaluator(expression) + + val plan = try { + GenerateProjection(Alias(expression, s"Optimized($expression)")() :: Nil) + } catch { + case e: Throwable => + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code.mkString("\n")} + |$e + """.stripMargin) + } + + val actual = plan(inputRow) + val expectedRow = new GenericRow(Array[Any](expected)) + if (actual.hashCode() != expectedRow.hashCode()) { + fail( + s""" + |Mismatched hashCodes for values: $actual, $expectedRow + |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} + |${evaluated.code.mkString("\n")} + """.stripMargin) + } + if (actual != expectedRow) { + val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } +} \ No newline at end of file diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4abd89955bd2..de656a44919d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -308,9 +308,11 @@ class SQLContext(@transient val sparkContext: SparkContext) } /** + * :: DeveloperApi :: * The primary workflow for executing relational queries using Spark. Designed to allow easy * access to the intermediate phases of query execution for developers. */ + @DeveloperApi protected abstract class QueryExecution { def logical: LogicalPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index c1ced8bfa404..411ae09cee3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -138,7 +138,7 @@ case class Aggregate( i += 1 } } - val resultProjection = new Projection(resultExpressions, computedSchema) + val resultProjection = new InterpretedProjection(resultExpressions, computedSchema) val aggregateResults = new GenericMutableRow(computedAggregates.length) var i = 0 @@ -152,7 +152,7 @@ case class Aggregate( } else { child.execute().mapPartitions { iter => val hashTable = new HashMap[Row, Array[AggregateFunction]] - val groupingProjection = new MutableProjection(groupingExpressions, childOutput) + val groupingProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) var currentRow: Row = null while (iter.hasNext) { @@ -175,7 +175,8 @@ case class Aggregate( private[this] val hashTableIter = hashTable.entrySet().iterator() private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length) private[this] val resultProjection = - new MutableProjection(resultExpressions, computedSchema ++ namedGroups.map(_._2)) + new InterpretedMutableProjection( + resultExpressions, computedSchema ++ namedGroups.map(_._2)) private[this] val joinedRow = new JoinedRow override final def hasNext: Boolean = hashTableIter.hasNext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 00010ef6e798..3d01a29ec43d 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -22,7 +22,7 @@ import org.apache.spark.{HashPartitioner, RangePartitioner, SparkConf} import org.apache.spark.rdd.ShuffledRDD import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.{NoBind, MutableProjection, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.{NoBind, RowOrdering} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair @@ -42,7 +42,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una case HashPartitioning(expressions, numPartitions) => // TODO: Eliminate redundant expressions in grouping key and value. val rdd = child.execute().mapPartitions { iter => - val hashExpressions = new MutableProjection(expressions, child.output) + @transient val hashExpressions = + newMutableProjection(expressions, child.output)() + val mutablePair = new MutablePair[Row, Row]() iter.map(r => mutablePair.update(hashExpressions(r), r)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 47b3d00262db..aaba658b56b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -56,10 +56,11 @@ case class Generate( val nullValues = Seq.fill(generator.output.size)(Literal(null)) // Used to produce rows with no matches when outer = true. val outerProjection = - new Projection(child.output ++ nullValues, child.output) + new InterpretedProjection(child.output ++ nullValues, child.output) val joinProjection = - new Projection(child.output ++ generator.output, child.output ++ generator.output) + new InterpretedProjection( + child.output ++ generator.output, child.output ++ generator.output) val joinedRow = new JoinedRow iter.flatMap {row => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 27dc091b8581..de4a3bf5f916 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -22,7 +22,8 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Logging, Row} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.BaseRelation import org.apache.spark.sql.catalyst.plans.physical._ @@ -51,8 +52,19 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { */ def executeCollect(): Array[Row] = execute().map(_.copy()).collect() - protected def buildRow(values: Seq[Any]): Row = - new GenericRow(values.toArray) + def newProjection(expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = + GenerateProjection(expressions) + + def newMutableProjection( + expressions: Seq[Expression], + inputSchema: Seq[Attribute]): () => MutableProjection = { + 
GenerateMutableProjection(expressions) + } + + + def newPredicate(expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = { + GeneratePredicate(expression, inputSchema) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7080074a69c0..5a88e0085c27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -49,6 +49,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * evaluated by matching hash keys. */ object HashJoin extends Strategy with PredicateHelper { + private[this] def broadcastHashJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], @@ -171,12 +172,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { protected lazy val singleRowRdd = sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1) - def convertToCatalyst(a: Any): Any = a match { - case s: Seq[Any] => s.map(convertToCatalyst) - case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) - case other => other - } - object TakeOrdered extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) => @@ -265,10 +260,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Sample(fraction, withReplacement, seed, child) => execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => - val dataAsRdd = - sparkContext.parallelize(data.map(r => - new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row)) - execution.ExistingRdd(output, dataAsRdd) :: Nil + ExistingRdd( + output, + ExistingRdd.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil case logical.Limit(IntegerLiteral(limit), child) => execution.Limit(limit, planLater(child))(sqlContext) :: Nil case Unions(unionChildren) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala new file mode 100644 index 000000000000..ee62623ecbae --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.SparkContext +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.types._ + +/** + * Attempt to rewrite aggregate to be more efficient. + * + * @param partial if true then aggregation is done partially on local data without shuffling to + * ensure all values where `groupingExpressions` are equal are present. + * @param groupingExpressions expressions that are evaluated to determine grouping. + * @param aggregateExpressions expressions that are computed for each group. + * @param child the input data source. + */ +case class HashAggregate( + partial: Boolean, + groupingExpressions: Seq[Expression], + aggregateExpressions: Seq[NamedExpression], + child: SparkPlan)(@transient sc: SparkContext) + extends UnaryNode with NoBind { + + override def requiredChildDistribution = + if (partial) { + UnspecifiedDistribution :: Nil + } else { + if (groupingExpressions == Nil) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + } + + override def otherCopyArgs = sc :: Nil + + def output = aggregateExpressions.map(_.toAttribute) + + def execute() = { + val aggregatesToCompute = aggregateExpressions.flatMap { a => + a.collect { case agg: AggregateExpression => agg } + } + + // Move these into expressions... have a fall back that uses the standard aggregate interface. + val computeFunctions = aggregatesToCompute.map { + case c @ Count(expr) => + val currentCount = AttributeReference("currentCount", LongType, true)() + val initialValue = Literal(0L) + val updateFunction = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount) + val result = currentCount + + AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result) + + case Sum(expr) => + val currentSum = AttributeReference("currentSum", expr.dataType, true)() + val initialValue = Cast(Literal(0L), expr.dataType) + + // Coalesce avoids double calculation... + // but really, common sub expression elimination would be better....
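+ // Note: if `expr` evaluates to null, Add(expr, currentSum) is also null, so
+ // Coalesce falls back to the previous currentSum and null inputs are skipped
+ // without an explicit IsNull check.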
+ val updateFunction = Coalesce(Add(expr, currentSum) :: currentSum :: Nil) + val result = currentSum + + AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) + + case a @ Average(expr) => + val currentCount = AttributeReference("currentCount", LongType, true)() + val currentSum = AttributeReference("currentSum", expr.dataType, true)() + val initialCount = Literal(0L) + val initialSum = Cast(Literal(0L), expr.dataType) + val updateCount = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount) + val updateSum = Coalesce(Add(expr, currentSum) :: currentSum :: Nil) + + val result = Divide(Cast(currentSum, DoubleType), Cast(currentCount, DoubleType)) + + AggregateEvaluation( + currentCount :: currentSum :: Nil, + initialCount :: initialSum :: Nil, + updateCount :: updateSum :: Nil, + result + ) + + /* + case otherAggregate => + val ref = + AttributeReference("aggregateFunction", otherAggregate.dataType, otherAggregate.nullable) + + AggregateEvaluation( + ref :: Nil, + ScalaUdf(() => otherAggregate.newInstance, NullType, ), + ) + */ + } + + @transient val computationSchema = computeFunctions.flatMap(_.schema) + @transient lazy val newComputationBuffer = + newProjection(computeFunctions.flatMap(_.initialValues), child.output) + @transient lazy val updateProjectionBuilder = + newMutableProjection( + computeFunctions.flatMap(_.update), + computeFunctions.flatMap(_.schema) ++ child.output) + @transient lazy val groupProjection = newProjection(groupingExpressions, child.output) + + @transient val resultMap = aggregatesToCompute.zip(computeFunctions).map { + case (agg, func) => agg.id -> func.result + }.toMap + + val namedGroups = groupingExpressions.zipWithIndex.map { + case (ne: NamedExpression, _) => (ne, ne) + case (e, i) => (e, Alias(e, s"GroupingExpr$i")()) + } + val groupMap = namedGroups.map {case (k,v) => k -> v.toAttribute }.toMap + + @transient val resultExpressions = aggregateExpressions.map(_.transform { + case e: Expression if resultMap.contains(e.id) => resultMap(e.id) + case e: Expression if groupMap.contains(e) => groupMap(e) + }) + @transient lazy val resultProjectionBuilder = + newMutableProjection( + resultExpressions, + (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq) + + child.execute().mapPartitions { iter => + // TODO: Skip hashmap for no grouping exprs... 
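+ // One aggregation buffer is kept per distinct grouping key; each input row
+ // is joined to its group's buffer and the update projection overwrites the
+ // buffer's values in place.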
+ @transient val buffers = new java.util.HashMap[Row, MutableRow]() + @transient val updateProjection = updateProjectionBuilder() + @transient val joinedRow = new JoinedRow + + var currentRow: Row = null + while(iter.hasNext) { + currentRow = iter.next() + val currentGroup = groupProjection(currentRow) + var currentBuffer = buffers.get(currentGroup) + if(currentBuffer == null) { + currentBuffer = newComputationBuffer(EmptyRow).asInstanceOf[MutableRow] + buffers.put(currentGroup, currentBuffer) + } + updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow)) + } + + @transient val resultIterator = buffers.entrySet.iterator() + @transient val resultProjection = resultProjectionBuilder() + new Iterator[Row] { + def hasNext = resultIterator.hasNext + def next() = { + val currentGroup = resultIterator.next() + resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue)) + } + } + } + } +} + +case class AggregateEvaluation( + schema: Seq[Attribute], + initialValues: Seq[Expression], + update: Seq[Expression], + result: Expression) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 97abd636ab5f..7f7984a41efe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -37,9 +37,11 @@ import org.apache.spark.util.MutablePair case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { override def output = projectList.map(_.toAttribute) - override def execute() = child.execute().mapPartitions { iter => - @transient val reusableProjection = new MutableProjection(projectList) - iter.map(reusableProjection) + @transient lazy val buildProjection = newMutableProjection(projectList, child.output) + + def execute() = child.execute().mapPartitions { iter => + val reusableProjection = buildProjection() + iter.map(reusableProjection) } } @@ -50,8 +52,10 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { override def output = child.output - override def execute() = child.execute().mapPartitions { iter => - iter.filter(condition.eval(_).asInstanceOf[Boolean]) + @transient lazy val conditionEvaluator = newPredicate(condition, child.output) + + def execute() = child.execute().mapPartitions { iter => + iter.filter(conditionEvaluator) } } @@ -170,6 +174,7 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) @transient lazy val ordering = new RowOrdering(sortOrder) + // TODO: Is this copying for no reason? override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ordering) // TODO: Terminal split should be implemented differently from non-terminal split.
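A note on the copy() calls above: executeCollect() and TakeOrdered both copy each row before materializing it because a mutable projection hands back the same underlying row object on every call; without the copy, every collected slot would alias the last-written buffer. A minimal, self-contained Scala sketch of the hazard (Array[Any] stands in for MutableRow; all names here are illustrative, not the real classes):

    class ToyMutableProjection {
      private val buffer = new Array[Any](1)               // reused on every call
      def apply(input: Int): Array[Any] = { buffer(0) = input * 2; buffer }
    }

    val project = new ToyMutableProjection
    val aliased = Seq(1, 2, 3).map(project.apply)          // three references to one buffer: all show Array(6)
    val copied  = Seq(1, 2, 3).map(i => project(i).clone)  // copying materializes distinct rows: Array(2), Array(4), Array(6)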
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 7d1f11caae83..8d3d0dc307fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -38,6 +38,8 @@ case object BuildLeft extends BuildSide case object BuildRight extends BuildSide trait HashJoin { + self: SparkPlan => + val leftKeys: Seq[Expression] val rightKeys: Seq[Expression] val buildSide: BuildSide @@ -56,9 +58,9 @@ trait HashJoin { def output = left.output ++ right.output - @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output) + @transient lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output) @transient lazy val streamSideKeyGenerator = - () => new MutableProjection(streamedKeys, streamedPlan.output) + newMutableProjection(streamedKeys, streamedPlan.output) def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = { // TODO: Use Spark's HashMap implementation. @@ -300,8 +302,16 @@ case class LeftSemiJoinBNL( case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { def output = left.output ++ right.output - def execute() = left.execute().map(_.copy()).cartesian(right.execute().map(_.copy())).map { - case (l: Row, r: Row) => buildRow(l ++ r) + def execute() = { + val leftResults = left.execute().map(_.copy()) + val rightResults = right.execute().map(_.copy()) + + leftResults.cartesian(rightResults).mapPartitions { iter => + val joinedRow = new JoinedRow + iter.map { + case (l: Row, r: Row) => joinedRow(l, r) + } + } } } @@ -352,6 +362,7 @@ case class BroadcastNestedLoopJoin( // TODO: Use Spark's BitSet. val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size) val joinedRow = new JoinedRow + val rightNulls = new GenericMutableRow(right.output.size) streamedIter.foreach { streamedRow => var i = 0 @@ -361,7 +372,7 @@ case class BroadcastNestedLoopJoin( // TODO: One bitset per partition instead of per row. val broadcastedRow = broadcastedRelation.value(i) if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { - matchedRows += buildRow(streamedRow ++ broadcastedRow) + matchedRows += joinedRow(streamedRow, broadcastedRow).copy() matched = true includedBroadcastTuples += i } @@ -369,7 +380,7 @@ case class BroadcastNestedLoopJoin( } if (!matched && (joinType == LeftOuter || joinType == FullOuter)) { - matchedRows += buildRow(streamedRow ++ Array.fill(right.output.size)(null)) + matchedRows += joinedRow(streamedRow, rightNulls).copy() } } Iterator((matchedRows, includedBroadcastTuples)) @@ -383,13 +394,13 @@ case class BroadcastNestedLoopJoin( streamedPlusMatches.map(_._2).reduce(_ ++ _) } + val leftNulls = new GenericMutableRow(left.output.size) val rightOuterMatches: Seq[Row] = if (joinType == RightOuter || joinType == FullOuter) { broadcastedRelation.value.zipWithIndex.filter { case (row, i) => !allIncludedBroadcastTuples.contains(i) }.map { - // TODO: Use projection. 
- case (row, _) => buildRow(Vector.fill(left.output.size)(null) ++ row) + case (row, _) => new JoinedRow(leftNulls, row) } } else { Vector() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index ade823b51c9c..1dfaa2505f44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -302,6 +302,7 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) */ private[parquet] class FilteringParquetRowInputFormat extends parquet.hadoop.ParquetInputFormat[Row] with Logging { + override def createRecordReader( inputSplit: InputSplit, taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index b70104dd5be5..d0c51c391ecc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -840,6 +840,7 @@ private[hive] object HiveQl { /* Case insensitive matches */ val COUNT = "(?i)COUNT".r + val SUBSTR = "(?i)SUBSTR".r val AVG = "(?i)AVG".r val SUM = "(?i)SUM".r val MAX = "(?i)MAX".r @@ -852,10 +853,10 @@ private[hive] object HiveQl { val NOT = "(?i)NOT".r val TRUE = "(?i)TRUE".r val FALSE = "(?i)FALSE".r + val IN = "(?i)IN".r val LIKE = "(?i)LIKE".r val RLIKE = "(?i)RLIKE".r val REGEXP = "(?i)REGEXP".r - val IN = "(?i)IN".r val DIV = "(?i)DIV".r val BETWEEN = "(?i)BETWEEN".r val WHEN = "(?i)WHEN".r @@ -949,13 +950,14 @@ private[hive] object HiveQl { IsNull(nodeToExpr(child)) case Token("TOK_FUNCTION", Token(IN(), Nil) :: value :: list) => In(nodeToExpr(value), list.map(nodeToExpr)) + case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: str :: start :: end :: Nil) => + SubString(nodeToExpr(str), nodeToExpr(start), nodeToExpr(end)) case Token("TOK_FUNCTION", Token(BETWEEN(), Nil) :: Token("KW_FALSE", Nil) :: target :: minValue :: maxValue :: Nil) => - val targetExpression = nodeToExpr(target) And( GreaterThanOrEqual(targetExpression, nodeToExpr(minValue)), diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 8258ee5fef0e..7f71fe4e32f4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -67,7 +67,7 @@ case class ScriptTransformation( } } readerThread.start() - val outputProjection = new Projection(input) + val outputProjection = new InterpretedProjection(input) iter .map(outputProjection) // TODO: Use SerDe diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 9b105308ab7c..7aaa8bfb7aa9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -476,7 +476,7 @@ private[hive] case class HiveGenericUdtf( override def eval(input: Row): TraversableOnce[Row] = { outputInspectors // Make sure initialized. 
- val inputProjection = new Projection(children) + val inputProjection = new InterpretedProjection(children) val collector = new UDTFCollector function.setCollector(collector) @@ -530,7 +530,7 @@ private[hive] case class HiveUdafFunction( override def eval(input: Row): Any = unwrapData(function.evaluate(buffer), returnInspector) @transient - val inputProjection = new Projection(exprs) + val inputProjection = new InterpretedProjection(exprs) def update(input: Row): Unit = { val inputs = inputProjection(input).asInstanceOf[Seq[AnyRef]].toArray diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index a623d29b5397..15e5e3da8a47 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -19,17 +19,41 @@ package org.apache.spark.sql.hive.execution import scala.util.Try +import org.apache.spark.sql.{SchemaRDD, Row} +import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{SchemaRDD, Row} case class TestData(a: Int, b: String) +/** + * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. + */ +class HiveCacheSuite extends HiveComparisonTest { + //cache("src", "key" :: "value" :: Nil, "value" :: Nil) + + println(catalog.lookupRelation(None, "src")) + + println(executeSql("SELECT SUM(key) FROM src GROUP BY value")) + + createQueryTest("Simple Average", + "SELECT SUM(key) FROM src GROUP BY value") +} + /** * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. 
*/ class HiveQuerySuite extends HiveComparisonTest { + createQueryTest("single case", + """SELECT case when true then 1 else 2 end FROM src""") + + createQueryTest("double case", + """SELECT case when 1 = 2 then 1 when 2 = 2 then 3 else 2 end FROM src""") + + createQueryTest("case else null", + """SELECT case when 1 = 2 then 1 when 2 = 2 then 3 else null end FROM src""") + test("CREATE TABLE AS runs once") { hql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect() assert(hql("SELECT COUNT(*) FROM foo").collect().head.getLong(0) === 1, From 92e74a45a40461a99faaa41bc7b85bfef549a6ce Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 8 Jul 2014 23:00:42 -0700 Subject: [PATCH 02/28] add overrides --- .../apache/spark/sql/catalyst/expressions/Projection.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 53cae54eb8e8..1bb0367a2f24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -58,12 +58,12 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.size) def currentValue: Row = mutableRow - def target(row: MutableRow): MutableProjection = { + override def target(row: MutableRow): MutableProjection = { mutableRow = row this } - def apply(input: Row): Row = { + override def apply(input: Row): Row = { var i = 0 while (i < exprArray.length) { mutableRow(i) = exprArray(i).eval(input) From efad14f96a6c3d7f6ca38fbf9067f4e1eb6b5c34 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 8 Jul 2014 23:01:45 -0700 Subject: [PATCH 03/28] Remove some half finished functions. --- .../expressions/codegen/CodeGenerator.scala | 23 ++----------------- .../expressions/stringOperations.scala | 11 --------- .../org/apache/spark/sql/hive/HiveQl.scala | 2 -- 3 files changed, 2 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index a7afd3932dfc..131bf900ee0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -216,6 +216,7 @@ abstract class CodeGenerator extends Logging { case EqualTo(e1, e2) => (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" } + /* TODO: Fix null semantics. 
case In(e1, list) if !list.exists(!_.isInstanceOf[expressions.Literal]) => val eval = expressionEvaluator(e1) @@ -238,6 +239,7 @@ abstract class CodeGenerator extends Logging { val $nullTerm = false val $primitiveTerm = $funcName """.children + */ case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) => (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 > $eval2" } @@ -352,27 +354,6 @@ abstract class CodeGenerator extends Logging { $primitiveTerm = ${falseEval.primitiveTerm} } """.children - - case SubString(str, start, end) => - val stringEval = expressionEvaluator(str) - val startEval = expressionEvaluator(start) - val endEval = expressionEvaluator(end) - - stringEval.code ++ startEval.code ++ endEval.code ++ - q""" - var $nullTerm = ${stringEval.nullTerm} - var $primitiveTerm: String = - if($nullTerm) { - null - } else { - val len = - if(${endEval.primitiveTerm} <= ${stringEval.primitiveTerm}.length) - ${endEval.primitiveTerm} - else - ${stringEval.primitiveTerm}.length - ${stringEval.primitiveTerm}.substring(${startEval.primitiveTerm}, len) - } - """.children } // If there was no match in the partial function above, we fall back on calling the interpreted diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index ccfe3f1a6db8..85d5c589ae07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -23,17 +23,6 @@ import org.apache.spark.sql.catalyst.types.DataType import org.apache.spark.sql.catalyst.types.StringType import org.apache.spark.sql.catalyst.types.BooleanType -case class SubString(string: Expression, start: Expression, end: Expression) extends Expression { - def children = string :: start :: end :: Nil - def references = children.flatMap(_.references).toSet - def dataType = StringType - def nullable = string.nullable - - override def eval(input: Row) = ??? - - override def toString = s"substr($string, $start, $end" -} - trait StringRegexExpression { self: BinaryExpression => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index d0c51c391ecc..b840f058d300 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -950,8 +950,6 @@ private[hive] object HiveQl { IsNull(nodeToExpr(child)) case Token("TOK_FUNCTION", Token(IN(), Nil) :: value :: list) => In(nodeToExpr(value), list.map(nodeToExpr)) - case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: str :: start :: end :: Nil) => - SubString(nodeToExpr(str), nodeToExpr(start), nodeToExpr(end)) case Token("TOK_FUNCTION", Token(BETWEEN(), Nil) :: Token("KW_FALSE", Nil) :: From f623ffd9e29a1da6e75b4c871d93e2930519efea Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 8 Jul 2014 23:02:05 -0700 Subject: [PATCH 04/28] Quiet logging from test suite. 
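The Optimize batches in CombiningLimitsSuite ran right up against their FixedPoint bounds (2 and 3), which appears to be what tripped RuleExecutor's "Max iterations reached" message on every test run even though the rules had converged. Raising both bounds to 10 gives the rules room to reach a fixed point before the cap. Roughly, a FixedPoint(n) batch reapplies its rules until the plan stops changing or n iterations pass; a simplified, self-contained sketch of that loop (illustrative names, not the real RuleExecutor):

    case class FixedPoint(maxIterations: Int)

    def runToFixedPoint[T](plan: T, rules: Seq[T => T], strategy: FixedPoint): T = {
      var current = plan
      var iteration = 1
      var continue = true
      while (continue) {
        val next = rules.foldLeft(current)((p, rule) => rule(p))  // one full pass over the batch
        if (next == current || iteration >= strategy.maxIterations) continue = false
        current = next
        iteration += 1
      }
      current
    }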
--- .../spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index 4896f1b955f0..e2ae0d25db1a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -27,9 +27,9 @@ class CombiningLimitsSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = - Batch("Combine Limit", FixedPoint(2), + Batch("Combine Limit", FixedPoint(10), CombineLimits) :: - Batch("Constant Folding", FixedPoint(3), + Batch("Constant Folding", FixedPoint(10), NullPropagation, ConstantFolding, BooleanSimplification) :: Nil From 0e889e8357d82c67205acd89a940e3d5206652ff Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 8 Jul 2014 23:02:24 -0700 Subject: [PATCH 05/28] Use typeOf instead of tq --- .../expressions/codegen/CodeGenerator.scala | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 131bf900ee0f..84f36dd98a29 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -37,14 +37,13 @@ abstract class CodeGenerator extends Logging { val toolBox = runtimeMirror(getClass.getClassLoader).mkToolBox() - // TODO: Use typetags? - val rowType = tq"org.apache.spark.sql.catalyst.expressions.Row" - val mutableRowType = tq"org.apache.spark.sql.catalyst.expressions.MutableRow" - val genericRowType = tq"org.apache.spark.sql.catalyst.expressions.GenericRow" - val genericMutableRowType = tq"org.apache.spark.sql.catalyst.expressions.GenericMutableRow" - - val projectionType = tq"org.apache.spark.sql.catalyst.expressions.Projection" - val mutableProjectionType = tq"org.apache.spark.sql.catalyst.expressions.MutableProjection" + val rowType = typeOf[Row] + val mutableRowType = typeOf[MutableRow] + val genericRowType = typeOf[GenericRow] + val genericMutableRowType = typeOf[GenericMutableRow] + + val projectionType = typeOf[Projection] + val mutableProjectionType = typeOf[MutableProjection] private val curId = new java.util.concurrent.atomic.AtomicInteger() private val javaSeperator = "$" From d81f998a106f99a2515d40134a30b88e323c1d75 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 8 Jul 2014 23:02:36 -0700 Subject: [PATCH 06/28] include schema for binding.
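The generated projections are now handed the input schema so that attribute references can be bound to ordinals before code generation; the generated code then reads input(ordinal) directly instead of resolving attributes by name at runtime. A hedged sketch of what binding amounts to (illustrative names, not the real BindReferences):

    // Resolve an attribute to its position in the input schema so generated
    // code can index into the row directly.
    def bindOrdinal(attribute: String, inputSchema: Seq[String]): Int = {
      val ordinal = inputSchema.indexOf(attribute)
      require(ordinal >= 0, s"$attribute not found in ${inputSchema.mkString(", ")}")
      ordinal
    }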
--- .../main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index de4a3bf5f916..89a537d341ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -53,12 +53,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { def executeCollect(): Array[Row] = execute().map(_.copy()).collect() def newProjection(expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = - GenerateProjection(expressions) + GenerateProjection(expressions, inputSchema) def newMutableProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute]): () => MutableProjection = { - GenerateMutableProjection(expressions) + GenerateMutableProjection(expressions, inputSchema) } From 00933769e3bd5a97d7eb24a9a9cb141a6975a599 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 8 Jul 2014 23:02:49 -0700 Subject: [PATCH 07/28] Comment / indenting cleanup. --- .../catalyst/expressions/codegen/CodeGenerator.scala | 11 ++++------- .../spark/sql/catalyst/expressions/package.scala | 2 +- .../spark/sql/catalyst/rules/RuleExecutor.scala | 1 + 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 84f36dd98a29..cfc1a0596fbb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -69,13 +69,10 @@ abstract class CodeGenerator extends Logging { * @param objectTerm A possibly boxed version of the result of evaluating this expression. */ protected case class EvaluatedExpression( - code: Seq[Tree], - nullTerm: TermName, - primitiveTerm: TermName, - objectTerm: TermName) { - - def withObjectTerm = ???
- } + code: Seq[Tree], + nullTerm: TermName, + primitiveTerm: TermName, + objectTerm: TermName) /** * Given an expression tree returns the code required to determine both if the result is NULL diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index c609e6b92904..91658bcf7e17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -54,7 +54,7 @@ package object expressions { abstract class MutableProjection extends Projection { def currentValue: Row - /* Updates the target of this projection to a new MutableRow */ + /** Updates the target of this projection to a new MutableRow */ def target(row: MutableRow): MutableProjection } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 144c8530564c..e300bdbececb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -72,6 +72,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { } iteration += 1 if (iteration > batch.strategy.maxIterations) { + // Only log if this is a rule that is supposed to run more than once. if (iteration != 2) { logger.info(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}") } From 675e6790a61d518b33c22573bfd352fc58c9aed5 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 8 Jul 2014 23:15:26 -0700 Subject: [PATCH 08/28] Upgrade paradise. --- project/SparkBuild.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index d5f571cf94c0..58c3fc490db2 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -495,9 +495,9 @@ object SparkBuild extends Build { // assumptions about the the expression ids being contiguous. Running tests in parallel breaks // this non-deterministically. TODO: FIX THIS. parallelExecution in Test := false, - addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.0-M8" cross CrossVersion.full), + addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full), libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "scala-compiler" % v ), - libraryDependencies += "org.scalamacros" %% "quasiquotes" % "2.0.0-M8", + libraryDependencies += "org.scalamacros" %% "quasiquotes" % "2.0.1", libraryDependencies ++= Seq( "com.typesafe" %% "scalalogging-slf4j" % "1.0.1" ) From e742640b9a5598e902134959e71e7423a5364edc Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 8 Jul 2014 23:16:08 -0700 Subject: [PATCH 09/28] Remove unneeded changes and code. 
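Backs out the exploratory analyzer and type-coercion tweaks, the unused GeneratedRow file, and the HiveCacheSuite debugging suite, and switches the generator caches from synchronizing on the now-deleted CodeGeneration object to the package-level globalLock, since the Scala 2.10 toolbox compiler is not thread safe. A self-contained sketch of that lock-plus-cache pattern (the compile step is a stand-in, not the real toolBox.eval):

    object CodeGen {
      private val globalLock = new Object()
      private val cache = scala.collection.mutable.HashMap.empty[String, () => Int]

      // Every compilation takes the single global lock because the underlying
      // compiler instance cannot safely be invoked from multiple threads, and
      // the cache lets identical source compile only once.
      def compile(source: String): () => Int = globalLock.synchronized {
        cache.getOrElseUpdate(source, () => source.length)  // stand-in for compilation
      }
    }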
--- .../sql/catalyst/analysis/Analyzer.scala | 4 +-- .../catalyst/analysis/HiveTypeCoercion.scala | 2 +- .../catalyst/expressions/BoundAttribute.scala | 1 - .../catalyst/expressions/GeneratedRow.scala | 26 ------------------- .../codegen/GenerateMutableProjection.scala | 2 +- .../codegen/GenerateOrdering.scala | 2 +- .../codegen/GeneratePredicate.scala | 2 +- .../codegen/GenerateProjection.scala | 2 +- .../expressions/codegen/package.scala | 2 +- .../apache/spark/sql/execution/Generate.scala | 5 ++-- .../sql/parquet/ParquetTableOperations.scala | 1 - .../org/apache/spark/sql/hive/HiveQl.scala | 4 +-- .../sql/hive/execution/HiveQuerySuite.scala | 14 ---------- 13 files changed, 12 insertions(+), 55 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GeneratedRow.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9571a57e8f5d..c7188469bfb8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -158,8 +158,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool */ object ImplicitGenerate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Project(Seq(Alias(g: Generator, alias)), child) => - Generate(g, join = false, outer = false, Some(alias), child) + case Project(Seq(Alias(g: Generator, _)), child) => + Generate(g, join = false, outer = false, None, child) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 8e011acaab7e..76ddeba9cb31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -259,7 +259,7 @@ trait HiveTypeCoercion { case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0))) // Turn true into 1, and false into 0 if casting boolean into other types. 
case Cast(e, dataType) if e.dataType == BooleanType => - If(e, Cast(Literal(1), dataType), Cast(Literal(0), dataType)) + Cast(If(e, Literal(1), Literal(0)), dataType) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 5d6c95eba391..9ce1f0105646 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -62,7 +62,6 @@ class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] { plan.transform { case n: NoBind => n.asInstanceOf[TreeNode] case leafNode if leafNode.children.isEmpty => leafNode - case nb: NoBind => nb.asInstanceOf[TreeNode] case unaryNode if unaryNode.children.size == 1 => unaryNode.transformExpressions { case e => bindReference(e, unaryNode.children.head.output) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GeneratedRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GeneratedRow.scala deleted file mode 100644 index 2147921d0d9d..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GeneratedRow.scala +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst._ - -import org.apache.spark.sql.catalyst.types._ - -object CodeGeneration - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 09e9af387c9b..f1f8eb79e401 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -35,7 +35,7 @@ object GenerateMutableProjection extends CodeGenerator { // TODO: Safe to fire up multiple instances of the compiler? 
def apply(expressions: Seq[Expression]): () => MutableProjection = - CodeGeneration.synchronized { + globalLock.synchronized { val cleanedExpressions = expressions.map(ExpressionCanonicalizer(_)) projectionCache.getOrElseUpdate(cleanedExpressions, createProjection(cleanedExpressions)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 27cb6779ade1..03804e4dc1a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -31,7 +31,7 @@ object GenerateOrdering extends CodeGenerator { val orderingCache = new collection.mutable.HashMap[Seq[SortOrder], Ordering[Row]] // TODO: Safe to fire up multiple instances of the compiler? - def apply(ordering: Seq[SortOrder]): Ordering[Row] = CodeGeneration.synchronized { + def apply(ordering: Seq[SortOrder]): Ordering[Row] = globalLock.synchronized { val cleanedExpression = ordering.map(ExpressionCanonicalizer(_)).asInstanceOf[Seq[SortOrder]] orderingCache.getOrElseUpdate(cleanedExpression, createOrdering(cleanedExpression)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 3748e9a27a84..18031fa98e21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -30,7 +30,7 @@ object GeneratePredicate extends CodeGenerator { val predicateCache = new collection.mutable.HashMap[Expression, (Row) => Boolean] // TODO: Safe to fire up multiple instances of the compiler? - def apply(predicate: Expression): (Row => Boolean) = CodeGeneration.synchronized { + def apply(predicate: Expression): (Row => Boolean) = globalLock.synchronized { val cleanedExpression = ExpressionCanonicalizer(predicate) predicateCache.getOrElseUpdate(cleanedExpression, createPredicate(cleanedExpression)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 8b7ae7e4dbd7..c94227454bb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -37,7 +37,7 @@ object GenerateProjection extends CodeGenerator { apply(expressions.map(BindReferences.bindReference(_, inputSchema))) // TODO: Safe to fire up multiple instances of the compiler? 
- def apply(expressions: Seq[Expression]): Projection = CodeGeneration.synchronized { + def apply(expressions: Seq[Expression]): Projection = globalLock.synchronized { val cleanedExpressions = expressions.map(ExpressionCanonicalizer(_)) projectionCache.getOrElseUpdate(cleanedExpressions, createProjection(cleanedExpressions)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala index f26d3d3f7b02..af2aecf110c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -31,7 +31,7 @@ package object codegen { * A lock to protect invoking the scala compiler at runtime, since it is not thread safe in Scala * 2.10. */ - protected val globalLock = new Object() + protected[codegen] val globalLock = new Object() /** Canonicalizes an expression so those that differ only by names can reuse the same code. */ object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index aaba658b56b5..4aed7f91ce15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -56,11 +56,10 @@ case class Generate( val nullValues = Seq.fill(generator.output.size)(Literal(null)) // Used to produce rows with no matches when outer = true. val outerProjection = - new InterpretedProjection(child.output ++ nullValues, child.output) + newProjection(child.output ++ nullValues, child.output) val joinProjection = - new InterpretedProjection( - child.output ++ generator.output, child.output ++ generator.output) + newProjection(child.output ++ generator.output, child.output ++ generator.output) val joinedRow = new JoinedRow iter.flatMap {row => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 1dfaa2505f44..ade823b51c9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -302,7 +302,6 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) */ private[parquet] class FilteringParquetRowInputFormat extends parquet.hadoop.ParquetInputFormat[Row] with Logging { - override def createRecordReader( inputSplit: InputSplit, taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index b840f058d300..b70104dd5be5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -840,7 +840,6 @@ private[hive] object HiveQl { /* Case insensitive matches */ val COUNT = "(?i)COUNT".r - val SUBSTR = "(?i)SUBSTR".r val AVG = "(?i)AVG".r val SUM = "(?i)SUM".r val MAX = "(?i)MAX".r @@ -853,10 +852,10 @@ private[hive] object HiveQl { val NOT = "(?i)NOT".r val TRUE = "(?i)TRUE".r val FALSE = "(?i)FALSE".r - val IN = "(?i)IN".r val LIKE = "(?i)LIKE".r val RLIKE = "(?i)RLIKE".r val REGEXP = 
"(?i)REGEXP".r + val IN = "(?i)IN".r val DIV = "(?i)DIV".r val BETWEEN = "(?i)BETWEEN".r val WHEN = "(?i)WHEN".r @@ -956,6 +955,7 @@ private[hive] object HiveQl { target :: minValue :: maxValue :: Nil) => + val targetExpression = nodeToExpr(target) And( GreaterThanOrEqual(targetExpression, nodeToExpr(minValue)), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 15e5e3da8a47..15c40e131d93 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -26,20 +26,6 @@ import org.apache.spark.sql.hive.test.TestHive._ case class TestData(a: Int, b: String) -/** - * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. - */ -class HiveCacheSuite extends HiveComparisonTest { - //cache("src", "key" :: "value" :: Nil, "value" :: Nil) - - println(catalog.lookupRelation(None, "src")) - - println(executeSql("SELECT SUM(key) FROM src GROUP BY value")) - - createQueryTest("Simple Average", - "SELECT SUM(key) FROM src GROUP BY value") -} - /** * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. */ From fc522d5c10eccd26e2c2a69103e6ac06ef8aa901 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 9 Jul 2014 00:17:16 -0700 Subject: [PATCH 10/28] Hook generated aggregation in to the planner. --- .../sql/catalyst/planning/patterns.scala | 56 +++++++++++++ .../sql/catalyst/CodeGenerationSuite.scala | 0 .../org/apache/spark/sql/SQLContext.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 83 +++++++++---------- .../spark/sql/execution/aggregates.scala | 13 +-- .../spark/sql/execution/PlannerSuite.scala | 8 +- 6 files changed, 110 insertions(+), 52 deletions(-) delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CodeGenerationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 026692abe067..0660fd9223fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -104,6 +104,62 @@ object PhysicalOperation extends PredicateHelper { } } +object PartialAggregation { + type ReturnType = + (Seq[Attribute], Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan) + + def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { + case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => + // Collect all aggregate expressions. + val allAggregates = + aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a}) + // Collect all aggregate expressions that can be computed partially. + val partialAggregates = + aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p}) + + // Only do partial aggregation if supported by all aggregate expressions. + if (allAggregates.size == partialAggregates.size) { + // Create a map of expressions to their partial evaluations for all aggregate expressions. 
+ val partialEvaluations: Map[Long, SplitEvaluation] = + partialAggregates.map(a => (a.id, a.asPartial)).toMap + + // We need to pass all grouping expressions through so the grouping can happen a second + // time. However, some of them might be unnamed so we alias them allowing them to be + // referenced in the second aggregation. + val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map { + case n: NamedExpression => (n, n) + case other => (other, Alias(other, "PartialGroup")()) + }.toMap + + // Replace aggregations with a new expression that computes the result from the already + // computed partial evaluations and grouping values. + val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { + case e: Expression if partialEvaluations.contains(e.id) => + partialEvaluations(e.id).finalEvaluation + case e: Expression if namedGroupingExpressions.contains(e) => + namedGroupingExpressions(e).toAttribute + }).asInstanceOf[Seq[NamedExpression]] + + val partialComputation = + (namedGroupingExpressions.values ++ + partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq + + val namedGroupingAttributes = namedGroupingExpressions.values.map(_.toAttribute).toSeq + + Some( + (namedGroupingAttributes, + rewrittenAggregateExpressions, + groupingExpressions, + partialComputation, + child)) + } else { + None + } + case _ => None + } +} + + /** * A pattern that finds joins with equality conditions that can be evaluated using equi-join. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CodeGenerationSuite.scala deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index de656a44919d..e8772e0c8cbb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -239,7 +239,7 @@ class SQLContext(@transient val sparkContext: SparkContext) val strategies: Seq[Strategy] = CommandStrategy(self) :: TakeOrdered :: - PartialAggregation :: + HashAggregation :: LeftSemiJoin :: HashJoin :: InMemoryScans :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5a88e0085c27..335b8fa4c198 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -95,58 +95,57 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - object PartialAggregation extends Strategy { + object HashAggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => - // Collect all aggregate expressions. - val allAggregates = - aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a }) - // Collect all aggregate expressions that can be computed partially. - val partialAggregates = - aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p }) + // Aggregations that can be performed in two phases, before and after the shuffle. - // Only do partial aggregation if supported by all aggregate expressions.
- if (allAggregates.size == partialAggregates.size) { - // Create a map of expressions to their partial evaluations for all aggregate expressions. - val partialEvaluations: Map[Long, SplitEvaluation] = - partialAggregates.map(a => (a.id, a.asPartial)).toMap - - // We need to pass all grouping expressions though so the grouping can happen a second - // time. However some of them might be unnamed so we alias them allowing them to be - // referenced in the second aggregation. - val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map { - case n: NamedExpression => (n, n) - case other => (other, Alias(other, "PartialGroup")()) - }.toMap - - // Replace aggregations with a new expression that computes the result from the already - // computed partial evaluations and grouping values. - val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { - case e: Expression if partialEvaluations.contains(e.id) => - partialEvaluations(e.id).finalEvaluation - case e: Expression if namedGroupingExpressions.contains(e) => - namedGroupingExpressions(e).toAttribute - }).asInstanceOf[Seq[NamedExpression]] - - val partialComputation = - (namedGroupingExpressions.values ++ - partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq - - // Construct two phased aggregation. - execution.Aggregate( + // Where all aggregates can be codegened. + case PartialAggregation( + namedGroupingAttributes, + rewrittenAggregateExpressions, + groupingExpressions, + partialComputation, + child) + if canBeCodeGened( + allAggregates(partialComputation) ++ + allAggregates(rewrittenAggregateExpressions))=> + execution.HashAggregate( partial = false, - namedGroupingExpressions.values.map(_.toAttribute).toSeq, + namedGroupingAttributes, rewrittenAggregateExpressions, - execution.Aggregate( + execution.HashAggregate( partial = true, groupingExpressions, partialComputation, planLater(child))(sqlContext))(sqlContext) :: Nil - } else { - Nil - } + + + // Where some aggregate can not be codegened + case PartialAggregation( + namedGroupingAttributes, + rewrittenAggregateExpressions, + groupingExpressions, + partialComputation, + child) => + execution.Aggregate( + partial = false, + namedGroupingAttributes, + rewrittenAggregateExpressions, + execution.Aggregate( + partial = true, + groupingExpressions, + partialComputation, + planLater(child))(sqlContext))(sqlContext) :: Nil case _ => Nil } + + def canBeCodeGened(aggs: Seq[AggregateExpression]) = !aggs.exists { + case _: Sum | _: Count => false + case _ => true + } + + def allAggregates(exprs: Seq[Expression]) = + exprs.flatMap(_.collect { case a: AggregateExpression => a }) } object BroadcastNestedLoopJoin extends Strategy { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala index ee62623ecbae..8bcef1803c8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.SparkContext +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -33,12 +34,14 @@ import org.apache.spark.sql.catalyst.types._ * @param child the input data source. 
*/ case class HashAggregate( - partial: Boolean, - groupingExpressions: Seq[Expression], - aggregateExpressions: Seq[NamedExpression], - child: SparkPlan)(@transient sc: SparkContext) + partial: Boolean, + groupingExpressions: Seq[Expression], + aggregateExpressions: Seq[NamedExpression], + child: SparkPlan)(@transient sqlContext: SQLContext) extends UnaryNode with NoBind { + private def sc = sqlContext.sparkContext + override def requiredChildDistribution = if (partial) { UnspecifiedDistribution :: Nil @@ -50,7 +53,7 @@ case class HashAggregate( } } - override def otherCopyArgs = sc :: Nil + override def otherCopyArgs = sqlContext :: Nil def output = aggregateExpressions.map(_.toAttribute) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 215618e852eb..76b172447144 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -39,22 +39,22 @@ class PlannerSuite extends FunSuite { test("count is partially aggregated") { val query = testData.groupBy('value)(Count('key)).queryExecution.analyzed - val planned = PartialAggregation(query).head - val aggregations = planned.collect { case a: Aggregate => a } + val planned = HashAggregation(query).head + val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } assert(aggregations.size === 2) } test("count distinct is not partially aggregated") { val query = testData.groupBy('value)(CountDistinct('key :: Nil)).queryExecution.analyzed - val planned = PartialAggregation(query) + val planned = HashAggregation(query) assert(planned.isEmpty) } test("mixed aggregates are not partially aggregated") { val query = testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).queryExecution.analyzed - val planned = PartialAggregation(query) + val planned = HashAggregation(query) assert(planned.isEmpty) } } From 9d67d850f6910209c23aa4b46a1154bc9a1db802 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 9 Jul 2014 00:25:01 -0700 Subject: [PATCH 11/28] Fix hive planner --- .../src/main/scala/org/apache/spark/sql/hive/HiveContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 7aedfcd74189..473b0756db21 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -229,7 +229,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { HiveTableScans, DataSinks, Scripts, - PartialAggregation, + HashAggregation, LeftSemiJoin, HashJoin, BasicOperators, From ca6cc6bf489b3b2b5546638eaff26ea8e9bf64ae Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 9 Jul 2014 19:42:40 -0700 Subject: [PATCH 12/28] WIP --- .../GeneratedEvaluationSuite.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 9 +- .../spark/sql/execution/aggregates.scala | 100 +++++++++--------- 3 files changed, 58 insertions(+), 53 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala index 438735990c11..8335269a0345 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala
@@ -66,7 +66,7 @@ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite {
       }
     }
 
-    futures.foreach(Await.result(_, 10 seconds))
+    futures.foreach(Await.result(_, 10.seconds))
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 335b8fa4c198..0fca15a08b52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -99,7 +99,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       // Aggregations that can be performed in two phases, before and after the shuffle.
 
-      // Where all aggregates can be codegened.
+      // Cases where all aggregates can be codegened.
       case PartialAggregation(
           namedGroupingAttributes,
           rewrittenAggregateExpressions,
@@ -109,18 +109,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         if canBeCodeGened(
             allAggregates(partialComputation) ++
             allAggregates(rewrittenAggregateExpressions))=>
-        execution.HashAggregate(
+        execution.GeneratedAggregate(
           partial = false,
          namedGroupingAttributes,
          rewrittenAggregateExpressions,
-          execution.HashAggregate(
+          execution.GeneratedAggregate(
            partial = true,
            groupingExpressions,
            partialComputation,
            planLater(child))(sqlContext))(sqlContext) :: Nil
 
-      // Where some aggregate can not be codegened
+      // Cases where some aggregate can not be codegened
       case PartialAggregation(
           namedGroupingAttributes,
           rewrittenAggregateExpressions,
@@ -136,6 +136,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
             groupingExpressions,
             partialComputation,
             planLater(child))(sqlContext))(sqlContext) :: Nil
+
       case _ => Nil
     }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala
index 8bcef1803c8a..8f9e20ba682a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala
@@ -18,14 +18,24 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.SparkContext
+import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.types._
 
+case class AggregateEvaluation(
+    schema: Seq[Attribute],
+    initialValues: Seq[Expression],
+    update: Seq[Expression],
+    result: Expression)
+
 /**
- * Attempt to rewrite aggregate to be more efficient.
+ * :: DeveloperApi ::
+ * Alternate version of aggregation that leverages projection and thus code generation.
+ * Aggregations are converted into a set of projections from an aggregation buffer tuple back onto
+ * itself. Currently only simple aggregations like SUM, COUNT, or AVERAGE are supported.
 *
 * @param partial if true then aggregation is done partially on local data without shuffling to
 *                ensure all values where `groupingExpressions` are equal are present.
@@ -33,7 +43,8 @@ import org.apache.spark.sql.catalyst.types._ * @param aggregateExpressions expressions that are computed for each group. * @param child the input data source. */ -case class HashAggregate( +@DeveloperApi +case class GeneratedAggregate( partial: Boolean, groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], @@ -59,12 +70,11 @@ case class HashAggregate( def execute() = { val aggregatesToCompute = aggregateExpressions.flatMap { a => - a.collect { case agg: AggregateExpression => agg } + a.collect { case agg: AggregateExpression => agg} } - // Move these into expressions... have fall back that uses standard aggregate interface. val computeFunctions = aggregatesToCompute.map { - case c @ Count(expr) => + case c@Count(expr) => val currentCount = AttributeReference("currentCount", LongType, true)() val initialValue = Literal(0L) val updateFunction = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount) @@ -83,7 +93,7 @@ case class HashAggregate( AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - case a @ Average(expr) => + case a@Average(expr) => val currentCount = AttributeReference("currentCount", LongType, true)() val currentSum = AttributeReference("currentSum", expr.dataType, true)() val initialCount = Literal(0L) @@ -99,29 +109,11 @@ case class HashAggregate( updateCount :: updateSum :: Nil, result ) - - /* - case otherAggregate => - val ref = - AttributeReference("aggregateFunction", otherAggregate.dataType, otherAggregate.nullable) - - AggregateEvaluation( - ref :: Nil, - ScalaUdf(() => otherAggregate.newInstance, NullType, ), - ) - */ } - @transient val computationSchema = computeFunctions.flatMap(_.schema) - @transient lazy val newComputationBuffer = - newProjection(computeFunctions.flatMap(_.initialValues), child.output) - @transient lazy val updateProjectionBuilder = - newMutableProjection( - computeFunctions.flatMap(_.update), - computeFunctions.flatMap(_.schema) ++ child.output) - @transient lazy val groupProjection = newProjection(groupingExpressions, child.output) + val computationSchema = computeFunctions.flatMap(_.schema) - @transient val resultMap = aggregatesToCompute.zip(computeFunctions).map { + val resultMap = aggregatesToCompute.zip(computeFunctions).map { case (agg, func) => agg.id -> func.result }.toMap @@ -129,39 +121,57 @@ case class HashAggregate( case (ne: NamedExpression, _) => (ne, ne) case (e, i) => (e, Alias(e, s"GroupingExpr$i")()) } - val groupMap = namedGroups.map {case (k,v) => k -> v.toAttribute }.toMap - @transient val resultExpressions = aggregateExpressions.map(_.transform { + val groupMap = namedGroups.map { case (k, v) => k -> v.toAttribute}.toMap + + val resultExpressions = aggregateExpressions.map(_.transform { case e: Expression if resultMap.contains(e.id) => resultMap(e.id) case e: Expression if groupMap.contains(e) => groupMap(e) }) - @transient lazy val resultProjectionBuilder = - newMutableProjection( - resultExpressions, - (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq) child.execute().mapPartitions { iter => - // TODO: Skip hashmap for no grouping exprs... - @transient val buffers = new java.util.HashMap[Row, MutableRow]() - @transient val updateProjection = updateProjectionBuilder() - @transient val joinedRow = new JoinedRow + // Builds a new custom class for holding the results of aggregation for a group. 
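+      // The buffer is built by projecting each aggregate's initial values (e.g. 0L for a
+      // COUNT) onto an empty input row.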
+ val newAggregationBuffer = + newProjection(computeFunctions.flatMap(_.initialValues), child.output) + + // A projection that is used to update the aggregate values for a group given a new tuple. + // This projection should be targeted at the current values for the group and then applied + // to a joined row of the current values with the new input row. + val updateProjection = + newMutableProjection( + computeFunctions.flatMap(_.update), + computeFunctions.flatMap(_.schema) ++ child.output)() + + // A projection that computes the group given an input tuple. + val groupProjection = newProjection(groupingExpressions, child.output) + + // A projection that produces the final result, given a computation. + val resultProjectionBuilder = + newMutableProjection( + resultExpressions, + (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq) + + val buffers = new java.util.HashMap[Row, MutableRow]() + val joinedRow = new JoinedRow var currentRow: Row = null - while(iter.hasNext) { + while (iter.hasNext) { currentRow = iter.next() val currentGroup = groupProjection(currentRow) var currentBuffer = buffers.get(currentGroup) - if(currentBuffer == null) { - currentBuffer = newComputationBuffer(EmptyRow).asInstanceOf[MutableRow] + if (currentBuffer == null) { + currentBuffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] buffers.put(currentGroup, currentBuffer) } updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow)) } - @transient val resultIterator = buffers.entrySet.iterator() - @transient val resultProjection = resultProjectionBuilder() new Iterator[Row] { + private[this] val resultIterator = buffers.entrySet.iterator() + private[this] val resultProjection = resultProjectionBuilder() + def hasNext = resultIterator.hasNext + def next() = { val currentGroup = resultIterator.next() resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue)) @@ -169,10 +179,4 @@ case class HashAggregate( } } } -} - -case class AggregateEvaluation( - schema: Seq[Attribute], - initialValues: Seq[Expression], - update: Seq[Expression], - result: Expression) +} \ No newline at end of file From 4220f1e82cea75cc52baec7ffad1ba3da74bde12 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 10 Jul 2014 22:10:28 -0700 Subject: [PATCH 13/28] Better config, docs, etc. --- .../sql/catalyst/expressions/Projection.scala | 2 +- .../sql/catalyst/expressions/predicates.scala | 3 + .../scala/org/apache/spark/sql/SQLConf.scala | 12 +++ .../org/apache/spark/sql/SQLContext.scala | 2 + .../apache/spark/sql/execution/Generate.scala | 3 + .../spark/sql/execution/SparkPlan.scala | 22 +++++- .../spark/sql/execution/SparkStrategies.scala | 6 +- .../spark/sql/execution/aggregates.scala | 76 +++++++++++-------- 8 files changed, 88 insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 1bb0367a2f24..4e170ee8f0f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -22,7 +22,7 @@ package org.apache.spark.sql.catalyst.expressions * new row. If the schema of the input row is specified, then the given expression will be bound to * that schema. 
 */
-class InterpretedProjection(expressions: Seq[Expression]) extends (Row => Row) {
+class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
     this(expressions.map(BindReferences.bindReference(_, inputSchema)))
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index b63406b94a4a..87737fbfebcb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -23,6 +23,9 @@ import org.apache.spark.sql.catalyst.types.BooleanType
 
 object InterpretedPredicate {
+  def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) =
+    apply(BindReferences.bindReference(expression, inputSchema))
+
   def apply(expression: Expression): (Row => Boolean) = {
     (r: Row) => expression.eval(r).asInstanceOf[Boolean]
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 95ed0f28507f..4a6bf6bdd8af 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -35,6 +35,18 @@ trait SQLConf {
   /** Number of partitions to use for shuffle operators. */
   private[spark] def numShufflePartitions: Int = get("spark.sql.shuffle.partitions", "200").toInt
 
+  /**
+   * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode
+   * that evaluates expressions found in queries. In general this custom code runs much faster
+   * than interpreted evaluation, but there are significant start-up costs due to compilation.
+   * As a result codegen is only beneficial when queries run for a long time, or when the same
+   * expressions are used multiple times.
+   *
+   * Defaults to false as this feature is currently experimental.
+   */
+  private[spark] def codegenEnabled: Boolean =
+    get("spark.sql.codegen", "false") == "true"
+
   /**
    * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to
    * a broadcast value during the physical executions of join operations. Setting this to 0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index e8772e0c8cbb..33b8cff336cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -234,6 +234,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
 
     val sqlContext: SQLContext = self
 
+    def codegenEnabled = self.codegenEnabled
+
     def numPartitions = self.numShufflePartitions
 
     val strategies: Seq[Strategy] =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
index 4aed7f91ce15..76e84a79069a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
@@ -50,6 +50,9 @@ case class Generate(
   override def output =
     if (join) child.output ++ generatorOutput else generatorOutput
 
+  /** Codegenned rows are not serializable... */
+  override val codegenEnabled = false
+
   override def execute() = {
     if (join) {
       child.execute().mapPartitions { iter =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 89a537d341ca..d8bb748c3690 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Logging, Row}
+import org.apache.spark.sql.{SQLContext, Logging, Row}
 import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
@@ -35,6 +35,8 @@ import org.apache.spark.sql.catalyst.plans.physical._
 abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging {
   self: Product =>
 
+  val codegenEnabled = true
+
   // TODO: Move to `DistributedPlan`
   /** Specifies how data is partitioned across different nodes in the cluster. */
   def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH!
@@ -53,17 +55,29 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging {
   def executeCollect(): Array[Row] = execute().map(_.copy()).collect()
 
   def newProjection(expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection =
-    GenerateProjection(expressions, inputSchema)
+    if (codegenEnabled) {
+      GenerateProjection(expressions, inputSchema)
+    } else {
+      new InterpretedProjection(expressions, inputSchema)
+    }
 
   def newMutableProjection(
       expressions: Seq[Expression],
       inputSchema: Seq[Attribute]): () => MutableProjection = {
-    GenerateMutableProjection(expressions, inputSchema)
+    if (codegenEnabled) {
+      GenerateMutableProjection(expressions, inputSchema)
+    } else {
+      () => new InterpretedMutableProjection(expressions, inputSchema)
+    }
   }
 
   def newPredicate(expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = {
-    GeneratePredicate(expression, inputSchema)
+    if (codegenEnabled) {
+      GeneratePredicate(expression, inputSchema)
+    } else {
+      InterpretedPredicate(expression, inputSchema)
+    }
  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 0fca15a08b52..d0ff88908b1c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.{SQLContext, execution}
+import org.apache.spark.sql.{SQLConf, SQLContext, execution}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
@@ -108,7 +108,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           child)
         if canBeCodeGened(
             allAggregates(partialComputation) ++
-            allAggregates(rewrittenAggregateExpressions))=>
+            allAggregates(rewrittenAggregateExpressions)) &&
+          codegenEnabled =>
         execution.GeneratedAggregate(
           partial = false,
           namedGroupingAttributes,
           rewrittenAggregateExpressions,
@@ -119,7 +120,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
             groupingExpressions,
             partialComputation,
             planLater(child))(sqlContext))(sqlContext) :: Nil
-
       // Cases where some aggregate can not be codegened
       case PartialAggregation(
           namedGroupingAttributes,
          rewrittenAggregateExpressions,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala
index 8f9e20ba682a..84fd9a91e66b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala
@@ -17,10 +17,8 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.SparkContext
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.types._
@@ -51,8 +49,6 @@ case class GeneratedAggregate(
     child: SparkPlan)(@transient sqlContext: SQLContext)
   extends UnaryNode with NoBind {
 
-  private def sc = sqlContext.sparkContext
-
   override def requiredChildDistribution =
     if (partial) {
       UnspecifiedDistribution :: Nil
@@ -66,16 +62,16 @@ case class GeneratedAggregate(
 
   override def otherCopyArgs = sqlContext :: Nil
 
-  def output = aggregateExpressions.map(_.toAttribute)
+  override def output = aggregateExpressions.map(_.toAttribute)
 
-  def execute() = {
+  override def execute() = {
     val aggregatesToCompute = aggregateExpressions.flatMap { a =>
       a.collect { case agg: AggregateExpression => agg}
     }
 
     val computeFunctions = aggregatesToCompute.map {
-      case c@Count(expr) =>
-        val currentCount = AttributeReference("currentCount", LongType, true)()
+      case c @ Count(expr) =>
+        val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
         val initialValue = Literal(0L)
         val updateFunction = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount)
         val result = currentCount
@@ -83,7 +79,7 @@ case class GeneratedAggregate(
         AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
 
       case Sum(expr) =>
-        val currentSum = AttributeReference("currentSum", expr.dataType, true)()
+        val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)()
         val initialValue = Cast(Literal(0L), expr.dataType)
 
         // Coalesce avoids double calculation...
@@ -93,9 +89,9 @@ case class GeneratedAggregate(
 
         AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
 
-      case a@Average(expr) =>
-        val currentCount = AttributeReference("currentCount", LongType, true)()
-        val currentSum = AttributeReference("currentSum", expr.dataType, true)()
+      case a @ Average(expr) =>
+        val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
+        val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)()
         val initialCount = Literal(0L)
         val initialSum = Cast(Literal(0L), expr.dataType)
         val updateCount = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount)
@@ -131,50 +127,70 @@ case class GeneratedAggregate(
 
     child.execute().mapPartitions { iter =>
       // Builds a new custom class for holding the results of aggregation for a group.
+      @transient
       val newAggregationBuffer =
         newProjection(computeFunctions.flatMap(_.initialValues), child.output)
 
       // A projection that is used to update the aggregate values for a group given a new tuple.
       // This projection should be targeted at the current values for the group and then applied
      // to a joined row of the current values with the new input row.
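      // For AVERAGE, for example, a single update step rewrites (currentCount, currentSum)
      // to (currentCount + 1, currentSum + input) when the input value is non-null.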
+      @transient
       val updateProjection =
         newMutableProjection(
           computeFunctions.flatMap(_.update),
           computeFunctions.flatMap(_.schema) ++ child.output)()
 
       // A projection that computes the group given an input tuple.
+      @transient
       val groupProjection = newProjection(groupingExpressions, child.output)
 
       // A projection that produces the final result, given a computation.
+      @transient
       val resultProjectionBuilder =
         newMutableProjection(
           resultExpressions,
           (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq)
 
-      val buffers = new java.util.HashMap[Row, MutableRow]()
       val joinedRow = new JoinedRow
 
-      var currentRow: Row = null
-      while (iter.hasNext) {
-        currentRow = iter.next()
-        val currentGroup = groupProjection(currentRow)
-        var currentBuffer = buffers.get(currentGroup)
-        if (currentBuffer == null) {
-          currentBuffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
-          buffers.put(currentGroup, currentBuffer)
+      if (groupingExpressions.isEmpty) {
+        // TODO: Codegenning anything other than the updateProjection is probably overkill.
+        val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
+        var currentRow: Row = null
+        while (iter.hasNext) {
+          currentRow = iter.next()
+          updateProjection.target(buffer)(joinedRow(buffer, currentRow))
+        }
+
+        val resultProjection = resultProjectionBuilder()
+        Iterator(resultProjection(buffer))
+      } else {
+        val buffers = new java.util.HashMap[Row, MutableRow]()
+
+        var currentRow: Row = null
+        while (iter.hasNext) {
+          currentRow = iter.next()
+          val currentGroup = groupProjection(currentRow)
+          var currentBuffer = buffers.get(currentGroup)
+          if (currentBuffer == null) {
+            currentBuffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
+            buffers.put(currentGroup, currentBuffer)
+          }
+          // Target the projection at the current aggregation buffer and then project the updated
+          // values.
+ updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow)) } - updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow)) - } - new Iterator[Row] { - private[this] val resultIterator = buffers.entrySet.iterator() - private[this] val resultProjection = resultProjectionBuilder() + new Iterator[Row] { + private[this] val resultIterator = buffers.entrySet.iterator() + private[this] val resultProjection = resultProjectionBuilder() - def hasNext = resultIterator.hasNext + def hasNext = resultIterator.hasNext - def next() = { - val currentGroup = resultIterator.next() - resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue)) + def next() = { + val currentGroup = resultIterator.next() + resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue)) + } } } } From bc88ecdd9d5d120c09283da2a26a45be7748c642 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 11 Jul 2014 15:58:35 -0700 Subject: [PATCH 14/28] Style --- .../sql/catalyst/expressions/codegen/GeneratePredicate.scala | 2 +- .../execution/{aggregates.scala => GeneratedAggregate.scala} | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/{aggregates.scala => GeneratedAggregate.scala} (99%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 18031fa98e21..4f18ac376d8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -52,4 +52,4 @@ object GeneratePredicate extends CodeGenerator { log.debug(s"Generated predicate '$predicate':\n$code") toolBox.eval(code).asInstanceOf[Row => Boolean] } -} \ No newline at end of file +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 84fd9a91e66b..63e9d5e74502 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -195,4 +195,4 @@ case class GeneratedAggregate( } } } -} \ No newline at end of file +} From be2cd6b875378b259244c99cb1a64194219c6a89 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 14 Jul 2014 13:15:44 -0700 Subject: [PATCH 15/28] WIP: Remove old method for reference binding, more work on configuration. 
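A note on the new binding contract, since this commit removes the implicit
BindReferences rule: operators now bind expressions explicitly against their
child's output schema, and a bound expression evaluates by ordinal lookup
rather than by attribute identity. A rough sketch of the idea, using only
classes touched by this series (the sample attributes, row, and values below
are invented for illustration and are not part of the patch):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types.IntegerType

    val a = AttributeReference("a", IntegerType, nullable = true)()
    val b = AttributeReference("b", IntegerType, nullable = true)()
    val inputSchema = Seq(a, b)  // the child operator's output

    // Binding rewrites the reference to `b` into an ordinal access, input[1].
    val bound = BindReferences.bindReference(Add(b, Literal(1)), inputSchema)
    // bound == Add(BoundReference(1, IntegerType, true), Literal(1))

    val row = new GenericRow(Array[Any](10, 32))
    assert(bound.eval(row) == 33)

One consequence visible in the diff below: an attribute that cannot be found
in the input schema is now a hard error (sys.error) instead of being silently
left unbound.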
--- .../spark/sql/catalyst/dsl/package.scala | 2 +- .../catalyst/expressions/BoundAttribute.scala | 50 ++-------- .../expressions/codegen/CodeGenerator.scala | 93 ++++++++++++------- .../codegen/GenerateMutableProjection.scala | 53 +++++------ .../codegen/GenerateOrdering.scala | 15 ++- .../codegen/GeneratePredicate.scala | 17 +--- .../codegen/GenerateProjection.scala | 18 ++-- .../expressions/codegen/package.scala | 2 - .../sql/catalyst/plans/logical/commands.scala | 14 +-- .../ExpressionEvaluationSuite.scala | 16 ++-- .../GeneratedEvaluationSuite.scala | 14 +-- .../org/apache/spark/sql/SQLContext.scala | 20 +++- .../spark/sql/execution/Aggregate.scala | 2 +- .../apache/spark/sql/execution/Exchange.scala | 4 +- .../apache/spark/sql/execution/Generate.scala | 2 +- .../sql/execution/GeneratedAggregate.scala | 8 +- .../spark/sql/execution/SparkPlan.scala | 13 ++- .../spark/sql/execution/basicOperators.scala | 10 +- 18 files changed, 174 insertions(+), 179 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 1b503b957d14..c32f0e9b9f07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -172,7 +172,7 @@ package object dsl { // Protobuf terminology def required = a.withNullability(false) - def at(ordinal: Int) = BoundReference(ordinal, a) + def at(ordinal: Int) = BoundReference(ordinal, a.dataType, a.nullable) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 9ce1f0105646..e50b724bff30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.trees + import org.apache.spark.sql.Logging /** @@ -28,61 +30,27 @@ import org.apache.spark.sql.Logging * to be retrieved more efficiently. However, since operations like column pruning can change * the layout of intermediate tuples, BindReferences should be run after all such transformations. 
*/ -case class BoundReference(ordinal: Int, baseReference: Attribute) - extends Attribute with trees.LeafNode[Expression] { +case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) + extends Expression with trees.LeafNode[Expression] { type EvaluatedType = Any - override def nullable = baseReference.nullable - override def dataType = baseReference.dataType - override def exprId = baseReference.exprId - override def qualifiers = baseReference.qualifiers - override def name = baseReference.name + def references = Set.empty - override def newInstance = BoundReference(ordinal, baseReference.newInstance) - override def withNullability(newNullability: Boolean) = - BoundReference(ordinal, baseReference.withNullability(newNullability)) - override def withQualifiers(newQualifiers: Seq[String]) = - BoundReference(ordinal, baseReference.withQualifiers(newQualifiers)) - - override def toString = s"$baseReference:$ordinal" + override def toString = s"input[$ordinal]" override def eval(input: Row): Any = input(ordinal) } -/** - * Used to denote operators that do their own binding of attributes internally. - */ -trait NoBind { self: trees.TreeNode[_] => } - -class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] { - import BindReferences._ - - def apply(plan: TreeNode): TreeNode = { - plan.transform { - case n: NoBind => n.asInstanceOf[TreeNode] - case leafNode if leafNode.children.isEmpty => leafNode - case unaryNode if unaryNode.children.size == 1 => unaryNode.transformExpressions { case e => - bindReference(e, unaryNode.children.head.output) - } - } - } -} - object BindReferences extends Logging { def bindReference[A <: Expression](expression: A, input: Seq[Attribute]): A = { expression.transform { case a: AttributeReference => attachTree(a, "Binding attribute") { val ordinal = input.indexWhere(_.exprId == a.exprId) if (ordinal == -1) { - // TODO: This fallback is required because some operators (such as ScriptTransform) - // produce new attributes that can't be bound. Likely the right thing to do is remove - // this rule and require all operators to explicitly bind to the input schema that - // they specify. - logger.debug(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") - a + sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") } else { - BoundReference(ordinal, a) + BoundReference(ordinal, a.dataType, a.nullable) } } }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index cfc1a0596fbb..8bf676b803a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import com.google.common.cache.{CacheLoader, CacheBuilder} + import scala.language.existentials import org.apache.spark.Logging @@ -26,27 +28,53 @@ import org.apache.spark.sql.catalyst.types._ /** * A base class for generators of byte code that performs expression evaluation. Includes helpers - * for refering to Catalyst types and building trees that perform evaluation of individual + * for referring to Catalyst types and building trees that perform evaluation of individual * expressions. 
*/ -abstract class CodeGenerator extends Logging { +abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { import scala.reflect.runtime.{universe => ru} import scala.reflect.runtime.universe._ import scala.tools.reflect.ToolBox - val toolBox = runtimeMirror(getClass.getClassLoader).mkToolBox() + protected val toolBox = runtimeMirror(getClass.getClassLoader).mkToolBox() - val rowType = typeOf[Row] - val mutableRowType = typeOf[MutableRow] - val genericRowType = typeOf[GenericRow] - val genericMutableRowType = typeOf[GenericMutableRow] + protected val rowType = typeOf[Row] + protected val mutableRowType = typeOf[MutableRow] + protected val genericRowType = typeOf[GenericRow] + protected val genericMutableRowType = typeOf[GenericMutableRow] - val projectionType = typeOf[Projection] - val mutableProjectionType = typeOf[MutableProjection] + protected val projectionType = typeOf[Projection] + protected val mutableProjectionType = typeOf[MutableProjection] private val curId = new java.util.concurrent.atomic.AtomicInteger() - private val javaSeperator = "$" + private val javaSeparator = "$" + + /** + * Generates a class for a given input expression. Called when there is not a cached code + * already available. + */ + protected def create(in: InType): OutType + + /** Canonicalizes an input expression. */ + protected def canonicalize(in: InType): InType + + /** Binds an input expression to a given input schema */ + protected def bind(in: InType, inputSchema: Seq[Attribute]): InType + + protected val cache = CacheBuilder.newBuilder() + .maximumSize(1000) + .build( + new CacheLoader[InType, OutType]() { + override def load(in: InType): OutType = globalLock.synchronized { + create(in) + } + }) + + def apply(expressions: InType, inputSchema: Seq[Attribute]): OutType= + apply(bind(expressions, inputSchema)) + + def apply(expressions: InType): OutType = cache.get(canonicalize(expressions)) /** * Returns a term name that is unique within this instance of a `CodeGenerator`. @@ -55,7 +83,7 @@ abstract class CodeGenerator extends Logging { * function.) */ protected def freshName(prefix: String): TermName = { - newTermName(s"$prefix$javaSeperator${curId.getAndIncrement}") + newTermName(s"$prefix$javaSeparator${curId.getAndIncrement}") } /** @@ -66,7 +94,7 @@ abstract class CodeGenerator extends Logging { * to null. * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not * valid if `nullTerm` is set to `false`. - * @param objectTerm An possibly boxed version of the result of evaluating this expression. + * @param objectTerm A possibly boxed version of the result of evaluating this expression. */ protected case class EvaluatedExpression( code: Seq[Tree], @@ -87,7 +115,7 @@ abstract class CodeGenerator extends Logging { def castOrNull(f: TermName => Tree, dataType: DataType): Seq[Tree] = { val eval = expressionEvaluator(e) eval.code ++ - q""" + q""" val $nullTerm = ${eval.nullTerm} val $primitiveTerm = if($nullTerm) @@ -119,7 +147,7 @@ abstract class CodeGenerator extends Logging { val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm) eval1.code ++ eval2.code ++ - q""" + q""" val $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm} val $primitiveTerm: ${termForType(resultType)} = if($nullTerm) { @@ -135,14 +163,15 @@ abstract class CodeGenerator extends Logging { // TODO: Skip generation of null handling code when expression are not nullable. 
val primitiveEvaluation: PartialFunction[Expression, Seq[Tree]] = { - case b @ BoundReference(ordinal, _) => + case b @ BoundReference(ordinal, dataType, nullable) => + val nullValue = if (nullable) q"$inputTuple.isNullAt($ordinal)" else q"false" q""" - val $nullTerm: Boolean = $inputTuple.isNullAt($ordinal) - val $primitiveTerm: ${termForType(b.dataType)} = + val $nullTerm: Boolean = $nullValue + val $primitiveTerm: ${termForType(dataType)} = if($nullTerm) - ${defaultPrimitive(e.dataType)} + ${defaultPrimitive(dataType)} else - ${getColumn(inputTuple, b.dataType, ordinal)} + ${getColumn(inputTuple, dataType, ordinal)} """.children case expressions.Literal(null, dataType) => @@ -162,11 +191,13 @@ abstract class CodeGenerator extends Logging { val $nullTerm = ${value == null} val $primitiveTerm: ${termForType(dataType)} = $value """.children + case expressions.Literal(value: Int, dataType) => q""" val $nullTerm = ${value == null} val $primitiveTerm: ${termForType(dataType)} = $value """.children + case expressions.Literal(value: Long, dataType) => q""" val $nullTerm = ${value == null} @@ -176,7 +207,7 @@ abstract class CodeGenerator extends Logging { case Cast(e @ BinaryType(), StringType) => val eval = expressionEvaluator(e) eval.code ++ - q""" + q""" val $nullTerm = ${eval.nullTerm} val $primitiveTerm = if($nullTerm) @@ -200,7 +231,7 @@ abstract class CodeGenerator extends Logging { case Cast(e, StringType) => val eval = expressionEvaluator(e) eval.code ++ - q""" + q""" val $nullTerm = ${eval.nullTerm} val $primitiveTerm = if($nullTerm) @@ -251,7 +282,7 @@ abstract class CodeGenerator extends Logging { val eval2 = expressionEvaluator(e2) eval1.code ++ eval2.code ++ - q""" + q""" var $nullTerm = false var $primitiveTerm: ${termForType(BooleanType)} = false @@ -272,7 +303,7 @@ abstract class CodeGenerator extends Logging { val eval2 = expressionEvaluator(e2) eval1.code ++ eval2.code ++ - q""" + q""" var $nullTerm = false var $primitiveTerm: ${termForType(BooleanType)} = false @@ -360,10 +391,10 @@ abstract class CodeGenerator extends Logging { log.debug(s"No rules to generate $e") val tree = reify { e } q""" - val $objectTerm = $tree.eval(i) - val $nullTerm = $objectTerm == null - val $primitiveTerm = $objectTerm.asInstanceOf[${termForType(e.dataType)}] - """.children + val $objectTerm = $tree.eval(i) + val $nullTerm = $objectTerm == null + val $primitiveTerm = $objectTerm.asInstanceOf[${termForType(e.dataType)}] + """.children } EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm) @@ -377,10 +408,10 @@ abstract class CodeGenerator extends Logging { } protected def setColumn( - destinationRow: TermName, - dataType: DataType, - ordinal: Int, - value: TermName) = { + destinationRow: TermName, + dataType: DataType, + ordinal: Int, + value: TermName) = { dataType match { case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" case _ => q"$destinationRow.update($ordinal, $value)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index f1f8eb79e401..a419fd7ecb39 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -23,31 +23,24 @@ import org.apache.spark.sql.catalyst.expressions._ * 
Generates byte code that produces a [[MutableRow]] object that can update itself based on a new * input [[Row]] for a fixed set of [[Expression Expressions]]. */ -object GenerateMutableProjection extends CodeGenerator { +object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] { import scala.reflect.runtime.{universe => ru} import scala.reflect.runtime.universe._ - // TODO: Should be weak references... bounded in size. - val projectionCache = new collection.mutable.HashMap[Seq[Expression], () => MutableProjection] - - def apply(expressions: Seq[Expression], inputSchema: Seq[Attribute]): (() => MutableProjection) = - apply(expressions.map(BindReferences.bindReference(_, inputSchema))) + val mutableRowName = newTermName("mutableRow") - // TODO: Safe to fire up multiple instances of the compiler? - def apply(expressions: Seq[Expression]): () => MutableProjection = - globalLock.synchronized { - val cleanedExpressions = expressions.map(ExpressionCanonicalizer(_)) - projectionCache.getOrElseUpdate(cleanedExpressions, createProjection(cleanedExpressions)) - } + protected def canonicalize(in: Seq[Expression]): Seq[Expression] = + in.map(ExpressionCanonicalizer(_)) - val mutableRowName = newTermName("mutableRow") + protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = + in.map(BindReferences.bindReference(_, inputSchema)) - def createProjection(expressions: Seq[Expression]): (() => MutableProjection) = { + protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { val projectionCode = expressions.zipWithIndex.flatMap { case (e, i) => val evaluationCode = expressionEvaluator(e) evaluationCode.code :+ - q""" + q""" if(${evaluationCode.nullTerm}) mutableRow.setNullAt($i) else @@ -57,25 +50,25 @@ object GenerateMutableProjection extends CodeGenerator { val code = q""" - () => { new $mutableProjectionType { + () => { new $mutableProjectionType { - private[this] var $mutableRowName: $mutableRowType = - new $genericMutableRowType(${expressions.size}) + private[this] var $mutableRowName: $mutableRowType = + new $genericMutableRowType(${expressions.size}) - def target(row: $mutableRowType): $mutableProjectionType = { - $mutableRowName = row - this - } + def target(row: $mutableRowType): $mutableProjectionType = { + $mutableRowName = row + this + } - /* Provide immutable access to the last projected row. */ - def currentValue: $rowType = mutableRow + /* Provide immutable access to the last projected row. */ + def currentValue: $rowType = mutableRow - def apply(i: $rowType): $rowType = { - ..$projectionCode - mutableRow - } - } } - """ + def apply(i: $rowType): $rowType = { + ..$projectionCode + mutableRow + } + } } + """ log.debug(s"code for ${expressions.mkString(",")}:\n$code") toolBox.eval(code).asInstanceOf[() => MutableProjection] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 03804e4dc1a0..9809fb536a3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -23,20 +23,17 @@ import org.apache.spark.sql.catalyst.expressions._ * Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of * [[Expression Expressions]]. 
*/ -object GenerateOrdering extends CodeGenerator { +object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] { import scala.reflect.runtime.{universe => ru} import scala.reflect.runtime.universe._ - // TODO: Should be weak references... bounded in size. - val orderingCache = new collection.mutable.HashMap[Seq[SortOrder], Ordering[Row]] + protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] = + in.map(ExpressionCanonicalizer(_).asInstanceOf[SortOrder]) - // TODO: Safe to fire up multiple instances of the compiler? - def apply(ordering: Seq[SortOrder]): Ordering[Row] = globalLock.synchronized { - val cleanedExpression = ordering.map(ExpressionCanonicalizer(_)).asInstanceOf[Seq[SortOrder]] - orderingCache.getOrElseUpdate(cleanedExpression, createOrdering(cleanedExpression)) - } + protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] = + in.map(BindReferences.bindReference(_, inputSchema)) - def createOrdering(ordering: Seq[SortOrder]): Ordering[Row] = { + protected def create(ordering: Seq[SortOrder]): Ordering[Row] = { val a = newTermName("a") val b = newTermName("b") val comparisons = ordering.zipWithIndex.map { case (order, i) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 4f18ac376d8f..2a0935c790cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -22,23 +22,16 @@ import org.apache.spark.sql.catalyst.expressions._ /** * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[Row]]. */ -object GeneratePredicate extends CodeGenerator { +object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { import scala.reflect.runtime.{universe => ru} import scala.reflect.runtime.universe._ - // TODO: Should be weak references... bounded in size. - val predicateCache = new collection.mutable.HashMap[Expression, (Row) => Boolean] + protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer(in) - // TODO: Safe to fire up multiple instances of the compiler? 
- def apply(predicate: Expression): (Row => Boolean) = globalLock.synchronized { - val cleanedExpression = ExpressionCanonicalizer(predicate) - predicateCache.getOrElseUpdate(cleanedExpression, createPredicate(cleanedExpression)) - } - - def apply(predicate: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = - apply(BindReferences.bindReference(predicate, inputSchema)) + protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression = + BindReferences.bindReference(in, inputSchema) - def createPredicate(predicate: Expression): ((Row) => Boolean) = { + protected def create(predicate: Expression): ((Row) => Boolean) = { val cEval = expressionEvaluator(predicate) val code = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index c94227454bb5..45bfba819827 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -26,24 +26,18 @@ import org.apache.spark.sql.catalyst.types._ * [[Expression Expressions]] and a given input [[Row]]. The returned [[Row]] object is custom * generated based on the output types of the [[Expression]] to avoid boxing of primitive values. */ -object GenerateProjection extends CodeGenerator { +object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { import scala.reflect.runtime.{universe => ru} import scala.reflect.runtime.universe._ - // TODO: Should be weak references... bounded in size. - val projectionCache = new collection.mutable.HashMap[Seq[Expression], Projection] + protected def canonicalize(in: Seq[Expression]): Seq[Expression] = + in.map(ExpressionCanonicalizer(_)) - def apply(expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = - apply(expressions.map(BindReferences.bindReference(_, inputSchema))) - - // TODO: Safe to fire up multiple instances of the compiler? - def apply(expressions: Seq[Expression]): Projection = globalLock.synchronized { - val cleanedExpressions = expressions.map(ExpressionCanonicalizer(_)) - projectionCache.getOrElseUpdate(cleanedExpressions, createProjection(cleanedExpressions)) - } + protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = + in.map(BindReferences.bindReference(_, inputSchema)) // Make Mutablility optional... 
- def createProjection(expressions: Seq[Expression]): Projection = { + protected def create(expressions: Seq[Expression]): Projection = { val tupleLength = ru.Literal(Constant(expressions.length)) val lengthDef = q"final val length = $tupleLength" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala index af2aecf110c6..846d67b53801 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -40,8 +40,6 @@ package object codegen { object CleanExpressions extends rules.Rule[Expression] { def apply(e: Expression): Expression = e transform { - case BoundReference(o, a) => - BoundReference(o, AttributeReference("a", a.dataType, a.nullable)(exprId = ExprId(0))) case Alias(c, _) => c } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index 1d5f033f0d27..15069726c2c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -35,7 +35,7 @@ abstract class Command extends LeafNode { */ case class NativeCommand(cmd: String) extends Command { override def output = - Seq(BoundReference(0, AttributeReference("result", StringType, nullable = false)())) + Seq(AttributeReference("result", StringType, nullable = false)()) } /** @@ -43,8 +43,8 @@ case class NativeCommand(cmd: String) extends Command { */ case class SetCommand(key: Option[String], value: Option[String]) extends Command { override def output = Seq( - BoundReference(0, AttributeReference("key", StringType, nullable = false)()), - BoundReference(1, AttributeReference("value", StringType, nullable = false)())) + AttributeReference("key", StringType, nullable = false)(), + AttributeReference("value", StringType, nullable = false)()) } /** @@ -53,7 +53,7 @@ case class SetCommand(key: Option[String], value: Option[String]) extends Comman */ case class ExplainCommand(plan: LogicalPlan) extends Command { override def output = - Seq(BoundReference(0, AttributeReference("plan", StringType, nullable = false)())) + Seq(AttributeReference("plan", StringType, nullable = false)()) } /** @@ -72,7 +72,7 @@ case class DescribeCommand( isExtended: Boolean) extends Command { override def output = Seq( // Column names are based on Hive. 
- BoundReference(0, AttributeReference("col_name", StringType, nullable = false)()), - BoundReference(1, AttributeReference("data_type", StringType, nullable = false)()), - BoundReference(2, AttributeReference("comment", StringType, nullable = false)())) + AttributeReference("col_name", StringType, nullable = false)(), + AttributeReference("data_type", StringType, nullable = false)(), + AttributeReference("comment", StringType, nullable = false)()) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 56840d3491b1..8134f376bb95 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -400,21 +400,21 @@ class ExpressionEvaluationSuite extends FunSuite { val typeMap = MapType(StringType, StringType) val typeArray = ArrayType(StringType) - checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()), + checkEvaluation(GetItem(BoundReference(3, typeMap, nullable = true), Literal("aa")), "bb", row) checkEvaluation(GetItem(Literal(null, typeMap), Literal("aa")), null, row) checkEvaluation(GetItem(Literal(null, typeMap), Literal(null, StringType)), null, row) - checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()), + checkEvaluation(GetItem(BoundReference(3, typeMap, nullable = true), Literal(null, StringType)), null, row) - checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()), + checkEvaluation(GetItem(BoundReference(4, typeArray, nullable = true), Literal(1)), "bb", row) checkEvaluation(GetItem(Literal(null, typeArray), Literal(1)), null, row) checkEvaluation(GetItem(Literal(null, typeArray), Literal(null, IntegerType)), null, row) - checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()), + checkEvaluation(GetItem(BoundReference(4, typeArray, nullable = true), Literal(null, IntegerType)), null, row) - checkEvaluation(GetField(BoundReference(2, AttributeReference("c", typeS)()), "a"), "aa", row) + checkEvaluation(GetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row) checkEvaluation(GetField(Literal(null, typeS), "a"), null, row) val typeS_notNullable = StructType( @@ -422,10 +422,8 @@ class ExpressionEvaluationSuite extends FunSuite { :: StructField("b", StringType, nullable = false) :: Nil ) - assert(GetField(BoundReference(2, - AttributeReference("c", typeS)()), "a").nullable === true) - assert(GetField(BoundReference(2, - AttributeReference("c", typeS_notNullable, nullable = false)()), "a").nullable === false) + assert(GetField(BoundReference(2,typeS, nullable = true), "a").nullable === true) + assert(GetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false) assert(GetField(Literal(null, typeS), "a").nullable === true) assert(GetField(Literal(null, typeS_notNullable), "a").nullable === true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala index 8335269a0345..224c2a059a2e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala @@ -25,8 +25,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ * Overrides our expression evaluation tests to use code generation for evaluation. */ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { - val generator = new CodeGenerator() {} - override def checkEvaluation( expression: Expression, expected: Any, @@ -35,7 +33,7 @@ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { GenerateMutableProjection(Alias(expression, s"Optimized($expression)")() :: Nil)() } catch { case e: Throwable => - val evaluated = generator.expressionEvaluator(expression) + val evaluated = GenerateProjection.expressionEvaluator(expression) fail( s""" |Code generation of $expression failed: @@ -74,13 +72,11 @@ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { * Overrides our expression evaluation tests to use generated code on mutable rows. */ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { - val generator = new CodeGenerator() {} - override def checkEvaluation( - expression: Expression, - expected: Any, - inputRow: Row = EmptyRow): Unit = { - lazy val evaluated = generator.expressionEvaluator(expression) + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + lazy val evaluated = GenerateProjection.expressionEvaluator(expression) val plan = try { GenerateProjection(Alias(expression, s"Optimized($expression)")() :: Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 33b8cff336cf..d904e521989b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.dsl.ExpressionConversions import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} import org.apache.spark.sql.columnar.InMemoryRelation import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.SparkStrategies @@ -299,14 +299,23 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[Row], 1) /** - * Prepares a planned SparkPlan for execution by binding references to specific ordinals, and - * inserting shuffle operations as needed. + * Prepares a planned SparkPlan for execution by inserting shuffle operations as needed. 
*/ @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { val batches = Batch("Add exchange", Once, AddExchange(self)) :: - Batch("Prepare Expressions", Once, new BindReferences[SparkPlan]) :: Nil + Batch("CodeGen", Once, TurnOnCodeGen) :: Nil + } + + protected object TurnOnCodeGen extends Rule[SparkPlan] { + def apply(plan: SparkPlan): SparkPlan = { + if (self.codegenEnabled) { + plan.foreach(p => println(p.simpleString)) + plan.foreach(_._codegenEnabled = true) + } + plan + } } /** @@ -341,6 +350,9 @@ class SQLContext(@transient val sparkContext: SparkContext) |${stringOrError(optimizedPlan)} |== Physical Plan == |${stringOrError(executedPlan)} + |Code Generation: ${executedPlan.codegenEnabled} + |== RDD == + |${stringOrError(toRdd.toDebugString)} """.stripMargin.trim } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index 411ae09cee3d..9e04e0845218 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -43,7 +43,7 @@ case class Aggregate( groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], child: SparkPlan)(@transient sqlContext: SQLContext) - extends UnaryNode with NoBind { + extends UnaryNode { override def requiredChildDistribution = if (partial) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 3d01a29ec43d..392a7f3be390 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -22,7 +22,7 @@ import org.apache.spark.{HashPartitioner, RangePartitioner, SparkConf} import org.apache.spark.rdd.ShuffledRDD import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.{NoBind, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.RowOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair @@ -31,7 +31,7 @@ import org.apache.spark.util.MutablePair * :: DeveloperApi :: */ @DeveloperApi -case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode with NoBind { +case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { override def outputPartitioning = newPartitioning diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 76e84a79069a..b81150278f4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -51,7 +51,7 @@ case class Generate( if (join) child.output ++ generatorOutput else generatorOutput /** Codegenned rows are not serializable... 
*/ - override val codegenEnabled = false + override def codegenEnabled = false override def execute() = { if (join) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 63e9d5e74502..394f748cfb33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -47,7 +47,9 @@ case class GeneratedAggregate( groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], child: SparkPlan)(@transient sqlContext: SQLContext) - extends UnaryNode with NoBind { + extends UnaryNode { + + println(s"new $codegenEnabled") override def requiredChildDistribution = if (partial) { @@ -65,6 +67,7 @@ case class GeneratedAggregate( override def output = aggregateExpressions.map(_.toAttribute) override def execute() = { + println(s"codegen: $codegenEnabled") val aggregatesToCompute = aggregateExpressions.flatMap { a => a.collect { case agg: AggregateExpression => agg} } @@ -157,6 +160,8 @@ case class GeneratedAggregate( // TODO: Codegening anything other than the updateProjection is probably over kill. val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] var currentRow: Row = null + println(codegenEnabled) + while (iter.hasNext) { currentRow = iter.next() updateProjection.target(buffer)(joinedRow(buffer, currentRow)) @@ -167,6 +172,7 @@ case class GeneratedAggregate( } else { val buffers = new java.util.HashMap[Row, MutableRow]() + println(codegenEnabled) var currentRow: Row = null while (iter.hasNext) { currentRow = iter.next() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index d8bb748c3690..371d4307f7c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -35,7 +35,10 @@ import org.apache.spark.sql.catalyst.plans.physical._ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { self: Product => - val codegenEnabled = true + def codegenEnabled = _codegenEnabled + + /** Will be set to true during planning if code generation should be used for this operator. */ + private[sql] var _codegenEnabled = false // TODO: Move to `DistributedPlan` /** Specifies how data is partitioned across different nodes in the cluster. 
*/ @@ -79,6 +82,14 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { InterpretedPredicate(expression, inputSchema) } } + + def newOrdering(order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[Row] = { + if (codegenEnabled) { + GenerateOrdering(order, inputSchema) + } else { + new RowOrdering(order, inputSchema) + } + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 7f7984a41efe..7e9478843bf8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -194,15 +194,13 @@ case class Sort( override def requiredChildDistribution = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - @transient - lazy val ordering = new RowOrdering(sortOrder) override def execute() = attachTree(this, "sort") { - // TODO: Optimize sorting operation? child.execute() - .mapPartitions( - iterator => iterator.map(_.copy()).toArray.sorted(ordering).iterator, - preservesPartitioning = true) + .mapPartitions( { iterator => + val ordering = newOrdering(sortOrder, child.output) + iterator.map(_.copy()).toArray.sorted(ordering).iterator + }, preservesPartitioning = true) } override def output = child.output From d2ad5c54c4c309fda49ece99d744839693d6f879 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 22 Jul 2014 00:04:40 -0700 Subject: [PATCH 16/28] Refactor putting SQLContext into SparkPlan. Fix ordering, other test cases. --- .../codegen/GenerateOrdering.scala | 25 +++++++++-- .../org/apache/spark/sql/SQLContext.scala | 18 +++----- .../spark/sql/execution/Aggregate.scala | 4 +- .../apache/spark/sql/execution/Generate.scala | 8 ++-- .../sql/execution/GeneratedAggregate.scala | 9 +--- .../spark/sql/execution/SparkPlan.scala | 44 +++++++++++++++---- .../spark/sql/execution/SparkStrategies.scala | 26 +++++------ .../spark/sql/execution/basicOperators.scala | 19 +++----- .../spark/sql/execution/debug/package.scala | 8 ++-- .../apache/spark/sql/execution/joins.scala | 19 +++----- .../sql/parquet/ParquetTableOperations.scala | 12 ++--- .../org/apache/spark/sql/QueryTest.scala | 1 + .../apache/spark/sql/execution/TgfSuite.scala | 2 +- .../spark/sql/parquet/ParquetQuerySuite.scala | 4 +- 14 files changed, 104 insertions(+), 95 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 9809fb536a3c..4211998f7511 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import com.typesafe.scalalogging.slf4j.Logging import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types.{StringType, NumericType} /** * Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of * [[Expression Expressions]]. 
*/ -object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] { +object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] with Logging { import scala.reflect.runtime.{universe => ru} import scala.reflect.runtime.universe._ @@ -40,6 +42,22 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] { val evalA = expressionEvaluator(order.child) val evalB = expressionEvaluator(order.child) + val compare = order.child.dataType match { + case _: NumericType => + q""" + val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm} + if(comp != 0) { + return ${if (order.direction == Ascending) q"comp.toInt" else q"-comp.toInt"} + } + """ + case StringType => + if (order.direction == Ascending) { + q"""return ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm})""" + } else { + q"""return ${evalB.primitiveTerm}.compare(${evalA.primitiveTerm})""" + } + } + q""" i = $a ..${evalA.code} @@ -52,9 +70,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] { } else if (${evalB.nullTerm}) { return ${if (order.direction == Ascending) q"1" else q"-1"} } else { - i = a - val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm} - if(comp != 0) return comp.toInt + $compare } """ } @@ -76,6 +92,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] { } new $orderingName() """ + logger.debug(s"Generated Ordering: $code") toolBox.eval(code).asInstanceOf[Ordering[Row]] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index d904e521989b..d64f5cba2cb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -304,18 +304,7 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { val batches = - Batch("Add exchange", Once, AddExchange(self)) :: - Batch("CodeGen", Once, TurnOnCodeGen) :: Nil - } - - protected object TurnOnCodeGen extends Rule[SparkPlan] { - def apply(plan: SparkPlan): SparkPlan = { - if (self.codegenEnabled) { - plan.foreach(p => println(p.simpleString)) - plan.foreach(_._codegenEnabled = true) - } - plan - } + Batch("Add exchange", Once, AddExchange(self)) :: Nil } /** @@ -330,7 +319,10 @@ class SQLContext(@transient val sparkContext: SparkContext) lazy val analyzed = analyzer(logical) lazy val optimizedPlan = optimizer(analyzed) // TODO: Don't just pick the first one... - lazy val sparkPlan = planner(optimizedPlan).next() + lazy val sparkPlan = { + SparkPlan.currentContext.set(self) + planner(optimizedPlan).next() + } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. 
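Expanded, the comparison branches GenerateOrdering splices together earlier in this commit behave like the hand-written comparator below. This is an illustrative sketch only: SketchedOrdering and the Seq[Any] stand-in for Row are not patch code, and it covers a single ascending numeric key. Note also that the match above handles only NumericType and StringType, so any other key type would currently hit a MatchError at generation time.

    // Illustrative sketch only: a hand-written equivalent of the generated
    // comparator for one ascending numeric sort key at ordinal 0, with a
    // Seq[Any] standing in for Row.
    class SketchedOrdering extends Ordering[Seq[Any]] {
      def compare(a: Seq[Any], b: Seq[Any]): Int = {
        val aNull = a(0) == null
        val bNull = b(0) == null
        if (aNull && bNull) 0
        else if (aNull) -1    // nulls sort first under Ascending
        else if (bNull) 1
        else {
          // The NumericType branch compares by subtraction. As in the
          // generated code, this can overflow for extreme values and, for
          // fractional types, toInt truncates differences in (-1, 1) to zero.
          a(0).asInstanceOf[Int] - b(0).asInstanceOf[Int]
        }
      }
    }

    // Usage: Seq(Seq(3), Seq(null), Seq(1)).sorted(new SketchedOrdering)
    // yields List(Seq(null), Seq(1), Seq(3)).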
lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index 9e04e0845218..463a1d32d7fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -42,7 +42,7 @@ case class Aggregate( partial: Boolean, groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], - child: SparkPlan)(@transient sqlContext: SQLContext) + child: SparkPlan) extends UnaryNode { override def requiredChildDistribution = @@ -56,8 +56,6 @@ case class Aggregate( } } - override def otherCopyArgs = sqlContext :: Nil - // HACK: Generators don't correctly preserve their output through serializations so we grab // out child's output attributes statically here. private[this] val childOutput = child.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index b81150278f4f..a28ce5869dad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -51,9 +51,11 @@ case class Generate( if (join) child.output ++ generatorOutput else generatorOutput /** Codegenned rows are not serializable... */ - override def codegenEnabled = false + override val codegenEnabled = false override def execute() = { + val boundGenerator = BindReferences.bindReference(generator, child.output) + if (join) { child.execute().mapPartitions { iter => val nullValues = Seq.fill(generator.output.size)(Literal(null)) @@ -66,7 +68,7 @@ case class Generate( val joinedRow = new JoinedRow iter.flatMap {row => - val outputRows = generator.eval(row) + val outputRows = boundGenerator.eval(row) if (outer && outputRows.isEmpty) { outerProjection(row) :: Nil } else { @@ -75,7 +77,7 @@ case class Generate( } } } else { - child.execute().mapPartitions(iter => iter.flatMap(row => generator.eval(row))) + child.execute().mapPartitions(iter => iter.flatMap(row => boundGenerator.eval(row))) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 394f748cfb33..5c4b615a14bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -46,11 +46,9 @@ case class GeneratedAggregate( partial: Boolean, groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], - child: SparkPlan)(@transient sqlContext: SQLContext) + child: SparkPlan) extends UnaryNode { - println(s"new $codegenEnabled") - override def requiredChildDistribution = if (partial) { UnspecifiedDistribution :: Nil @@ -62,12 +60,9 @@ case class GeneratedAggregate( } } - override def otherCopyArgs = sqlContext :: Nil - override def output = aggregateExpressions.map(_.toAttribute) override def execute() = { - println(s"codegen: $codegenEnabled") val aggregatesToCompute = aggregateExpressions.flatMap { a => a.collect { case agg: AggregateExpression => agg} } @@ -160,7 +155,6 @@ case class GeneratedAggregate( // TODO: Codegening anything other than the updateProjection is probably over kill. 
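The `boundGenerator` change above is an instance of the pattern the whole series moves toward: operators keep unbound attributes and resolve them to ordinals against a child's output just before evaluation. A rough sketch of what such binding does, assuming only the `BoundReference(ordinal, dataType, nullable)` shape used throughout this patch (the helper name and error message are invented for illustration, not the actual BindReferences source):

    import org.apache.spark.sql.catalyst.expressions._

    // Hypothetical sketch: rewrite each attribute to a positional reference
    // into the given input schema, by matching expression ids.
    def bindReferenceSketch(expr: Expression, input: Seq[Attribute]): Expression =
      expr transform {
        case a: AttributeReference =>
          val ordinal = input.indexWhere(_.exprId == a.exprId)
          if (ordinal == -1) {
            sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
          } else {
            BoundReference(ordinal, a.dataType, a.nullable)
          }
      }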
val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] var currentRow: Row = null - println(codegenEnabled) while (iter.hasNext) { currentRow = iter.next() @@ -172,7 +166,6 @@ case class GeneratedAggregate( } else { val buffers = new java.util.HashMap[Row, MutableRow]() - println(codegenEnabled) var currentRow: Row = null while (iter.hasNext) { currentRow = iter.next() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 371d4307f7c5..5e5ce21bea65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SQLContext, Logging, Row} +import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ @@ -28,17 +29,35 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.BaseRelation import org.apache.spark.sql.catalyst.plans.physical._ + +object SparkPlan { + protected[sql] val currentContext = new ThreadLocal[SQLContext]() +} + /** * :: DeveloperApi :: */ @DeveloperApi -abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { +abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable { self: Product => - def codegenEnabled = _codegenEnabled + /** + * A handle to the SQL Context that was used to create this plan. Since many operators need + * access to the sqlContext for RDD operations or configuration this field is automatically + * populated by the query planning infrastructure. + */ + @transient + protected val sqlContext = SparkPlan.currentContext.get() - /** Will be set to true during planning if code generation should be used for this operator. */ - private[sql] var _codegenEnabled = false + protected def sparkContext = sqlContext.sparkContext + + def logger = log + + val codegenEnabled: Boolean = if(sqlContext != null) { + sqlContext.codegenEnabled + } else { + false + } // TODO: Move to `DistributedPlan` /** Specifies how data is partitioned across different nodes in the cluster. 
*/ @@ -57,16 +76,22 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { */ def executeCollect(): Array[Row] = execute().map(_.copy()).collect() - def newProjection(expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = + protected def newProjection( + expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { + log.debug( + s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") if (codegenEnabled) { GenerateProjection(expressions, inputSchema) } else { new InterpretedProjection(expressions, inputSchema) } + } - def newMutableProjection( + protected def newMutableProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute]): () => MutableProjection = { + log.debug( + s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") if(codegenEnabled) { GenerateMutableProjection(expressions, inputSchema) } else { @@ -75,7 +100,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { } - def newPredicate(expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = { + protected def newPredicate( + expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = { if (codegenEnabled) { GeneratePredicate(expression, inputSchema) } else { @@ -83,7 +109,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { } } - def newOrdering(order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[Row] = { + protected def newOrdering(order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[Row] = { if (codegenEnabled) { GenerateOrdering(order, inputSchema) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d0ff88908b1c..53ad6b3ed811 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -39,7 +39,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // no predicate can be evaluated by matching hash keys case logical.Join(left, right, LeftSemi, condition) => execution.LeftSemiJoinBNL( - planLater(left), planLater(right), condition)(sqlContext) :: Nil + planLater(left), planLater(right), condition) :: Nil case _ => Nil } } @@ -58,7 +58,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { condition: Option[Expression], side: BuildSide) = { val broadcastHashJoin = execution.BroadcastHashJoin( - leftKeys, rightKeys, side, planLater(left), planLater(right))(sqlContext) + leftKeys, rightKeys, side, planLater(left), planLater(right)) condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil } @@ -118,7 +118,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { partial = true, groupingExpressions, partialComputation, - planLater(child))(sqlContext))(sqlContext) :: Nil + planLater(child))) :: Nil // Cases where some aggregate can not be codegened case PartialAggregation( @@ -135,7 +135,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { partial = true, groupingExpressions, partialComputation, - planLater(child))(sqlContext))(sqlContext) :: Nil + planLater(child))) :: Nil case _ => Nil } @@ -153,7 +153,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case 
logical.Join(left, right, joinType, condition) => execution.BroadcastNestedLoopJoin( - planLater(left), planLater(right), joinType, condition)(sqlContext) :: Nil + planLater(left), planLater(right), joinType, condition) :: Nil case _ => Nil } } @@ -175,7 +175,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object TakeOrdered extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) => - execution.TakeOrdered(limit, order, planLater(child))(sqlContext) :: Nil + execution.TakeOrdered(limit, order, planLater(child)) :: Nil case _ => Nil } } @@ -187,9 +187,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val relation = ParquetRelation.create(path, child, sparkContext.hadoopConfiguration) // Note: overwrite=false because otherwise the metadata we just created will be deleted - InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sqlContext) :: Nil + InsertIntoParquetTable(relation, planLater(child), overwrite=false) :: Nil case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) => - InsertIntoParquetTable(table, planLater(child), overwrite)(sqlContext) :: Nil + InsertIntoParquetTable(table, planLater(child), overwrite) :: Nil case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => val prunePushedDownFilters = if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) { @@ -218,7 +218,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { projectList, filters, prunePushedDownFilters, - ParquetTableScan(_, relation, filters)(sqlContext)) :: Nil + ParquetTableScan(_, relation, filters)) :: Nil case _ => Nil } @@ -243,7 +243,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Distinct(child) => execution.Aggregate( - partial = false, child.output, child.output, planLater(child))(sqlContext) :: Nil + partial = false, child.output, child.output, planLater(child)) :: Nil case logical.Sort(sortExprs, child) => // This sort is a global sort. Its requiredDistribution will be an OrderedDistribution. 
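All of the `(sqlContext)` constructor arguments and `otherCopyArgs` overrides deleted throughout this strategy file become unnecessary because of the thread-local hand-off added to SparkPlan earlier in this commit. The mechanism in miniature, with a String standing in for SQLContext (all names here are illustrative):

    // The planner seeds a ThreadLocal before constructing operators; each
    // operator captures it once in a @transient val, so the context never
    // rides along when the operator is serialized to executors.
    object PlanContext {
      val currentContext = new ThreadLocal[String]()
    }

    class Operator {
      @transient val context: String = PlanContext.currentContext.get()
    }

    PlanContext.currentContext.set("my-context")
    val op = new Operator            // captures "my-context" at construction
    assert(op.context == "my-context")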
execution.Sort(sortExprs, global = true, planLater(child)):: Nil @@ -256,7 +256,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Filter(condition, child) => execution.Filter(condition, planLater(child)) :: Nil case logical.Aggregate(group, agg, child) => - execution.Aggregate(partial = false, group, agg, planLater(child))(sqlContext) :: Nil + execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil case logical.Sample(fraction, withReplacement, seed, child) => execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => @@ -264,9 +264,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { output, ExistingRdd.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil case logical.Limit(IntegerLiteral(limit), child) => - execution.Limit(limit, planLater(child))(sqlContext) :: Nil + execution.Limit(limit, planLater(child)) :: Nil case Unions(unionChildren) => - execution.Union(unionChildren.map(planLater))(sqlContext) :: Nil + execution.Union(unionChildren.map(planLater)) :: Nil case logical.Except(left,right) => execution.Except(planLater(left),planLater(right)) :: Nil case logical.Intersect(left, right) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 7e9478843bf8..31dcf01eb6f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -76,12 +76,10 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: * :: DeveloperApi :: */ @DeveloperApi -case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) extends SparkPlan { +case class Union(children: Seq[SparkPlan]) extends SparkPlan { // TODO: attributes output by union should be distinct for nullability purposes override def output = children.head.output - override def execute() = sqlContext.sparkContext.union(children.map(_.execute())) - - override def otherCopyArgs = sqlContext :: Nil + override def execute() = sparkContext.union(children.map(_.execute())) } /** @@ -93,13 +91,11 @@ case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) ex * repartition all the data to a single partition to compute the global limit. */ @DeveloperApi -case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext) +case class Limit(limit: Int, child: SparkPlan) extends UnaryNode { // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan: // partition local limit -> exchange into one partition -> partition local limit again - override def otherCopyArgs = sqlContext :: Nil - override def output = child.output /** @@ -165,21 +161,18 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion. 
*/ @DeveloperApi -case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) - (@transient sqlContext: SQLContext) extends UnaryNode { - override def otherCopyArgs = sqlContext :: Nil +case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode { override def output = child.output - @transient - lazy val ordering = new RowOrdering(sortOrder) + val ordering = new RowOrdering(sortOrder, child.output) // TODO: Is this copying for no reason? override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ordering) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. - override def execute() = sqlContext.sparkContext.makeRDD(executeCollect(), 1) + override def execute() = sparkContext.makeRDD(executeCollect(), 1) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index c6fbd6d2f693..5ef46c32d44b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -41,13 +41,13 @@ package object debug { */ @DeveloperApi implicit class DebugQuery(query: SchemaRDD) { - def debug(implicit sc: SparkContext): Unit = { + def debug(): Unit = { val plan = query.queryExecution.executedPlan val visited = new collection.mutable.HashSet[Long]() val debugPlan = plan transform { case s: SparkPlan if !visited.contains(s.id) => visited += s.id - DebugNode(sc, s) + DebugNode(s) } println(s"Results returned: ${debugPlan.execute().count()}") debugPlan.foreach { @@ -57,9 +57,7 @@ package object debug { } } - private[sql] case class DebugNode( - @transient sparkContext: SparkContext, - child: SparkPlan) extends UnaryNode { + private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode { def references = Set.empty def output = child.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 8d3d0dc307fe..6d85e0658e59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -219,9 +219,8 @@ case class BroadcastHashJoin( rightKeys: Seq[Expression], buildSide: BuildSide, left: SparkPlan, - right: SparkPlan)(@transient sqlContext: SQLContext) extends BinaryNode with HashJoin { + right: SparkPlan) extends BinaryNode with HashJoin { - override def otherCopyArgs = sqlContext :: Nil override def outputPartitioning: Partitioning = left.outputPartitioning @@ -230,7 +229,7 @@ case class BroadcastHashJoin( @transient lazy val broadcastFuture = future { - sqlContext.sparkContext.broadcast(buildPlan.executeCollect()) + sparkContext.broadcast(buildPlan.executeCollect()) } def execute() = { @@ -250,14 +249,11 @@ case class BroadcastHashJoin( @DeveloperApi case class LeftSemiJoinBNL( streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) - (@transient sqlContext: SQLContext) extends BinaryNode { // TODO: Override requiredChildDistribution. 
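For TakeOrdered above, `executeCollect` defers to `RDD.takeOrdered`, which keeps only the `limit` smallest rows under the ordering instead of sorting the full dataset. A local-collections analog (the row values are made up):

    // Collections stand-in for child.execute().map(_.copy()).takeOrdered(limit)(ordering):
    // only the `limit` smallest elements under the ordering are retained.
    val rows = Seq(Seq[Any]("b", 2), Seq[Any]("a", 3), Seq[Any]("c", 1))
    val ordering: Ordering[Seq[Any]] = Ordering.by(r => r(1).asInstanceOf[Int])
    val top2 = rows.sorted(ordering).take(2)   // Seq(Seq("c", 1), Seq("b", 2))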
override def outputPartitioning: Partitioning = streamed.outputPartitioning - override def otherCopyArgs = sqlContext :: Nil - def output = left.output /** The Streamed Relation */ @@ -273,7 +269,7 @@ case class LeftSemiJoinBNL( def execute() = { val broadcastedRelation = - sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) streamed.execute().mapPartitions { streamedIter => val joinedRow = new JoinedRow @@ -321,14 +317,11 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod @DeveloperApi case class BroadcastNestedLoopJoin( streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression]) - (@transient sqlContext: SQLContext) extends BinaryNode { // TODO: Override requiredChildDistribution. override def outputPartitioning: Partitioning = streamed.outputPartitioning - override def otherCopyArgs = sqlContext :: Nil - override def output = { joinType match { case LeftOuter => @@ -355,7 +348,7 @@ case class BroadcastNestedLoopJoin( def execute() = { val broadcastedRelation = - sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter => val matchedRows = new ArrayBuffer[Row] @@ -407,7 +400,7 @@ case class BroadcastNestedLoopJoin( } // TODO: Breaks lineage. - sqlContext.sparkContext.union( - streamedPlusMatches.flatMap(_._1), sqlContext.sparkContext.makeRDD(rightOuterMatches)) + sparkContext.union( + streamedPlusMatches.flatMap(_._1), sparkContext.makeRDD(rightOuterMatches)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index ade823b51c9c..78b1094abd80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -48,8 +48,7 @@ case class ParquetTableScan( // https://issues.apache.org/jira/browse/SPARK-1367 output: Seq[Attribute], relation: ParquetRelation, - columnPruningPred: Seq[Expression])( - @transient val sqlContext: SQLContext) + columnPruningPred: Seq[Expression]) extends LeafNode { override def execute(): RDD[Row] = { @@ -94,8 +93,6 @@ case class ParquetTableScan( .filter(_ != null) // Parquet's record filters may produce null values } - override def otherCopyArgs = sqlContext :: Nil - /** * Applies a (candidate) projection. 
* @@ -105,7 +102,7 @@ case class ParquetTableScan( def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = { val success = validateProjection(prunedAttributes) if (success) { - ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sqlContext) + ParquetTableScan(prunedAttributes, relation, columnPruningPred) } else { sys.error("Warning: Could not validate Parquet schema projection in pruneColumns") this @@ -152,8 +149,7 @@ case class ParquetTableScan( case class InsertIntoParquetTable( relation: ParquetRelation, child: SparkPlan, - overwrite: Boolean = false)( - @transient val sqlContext: SQLContext) + overwrite: Boolean = false) extends UnaryNode with SparkHadoopMapReduceUtil { /** @@ -205,8 +201,6 @@ case class InsertIntoParquetTable( override def output = child.output - override def otherCopyArgs = sqlContext :: Nil - /** * Stores the given Row RDD as a Hadoop file. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 8e1e1971d968..1fd8d27b34c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -45,6 +45,7 @@ class QueryTest extends PlanTest { |${rdd.queryExecution} |== Exception == |$e + |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} """.stripMargin) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala index e55648b8ed15..2cab5e0c44d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.test.TestSQLContext._ * Note: this is only a rough example of how TGFs can be expressed, the final version will likely * involve a lot more sugar for cleaner use in Scala/Java/etc. */ -case class ExampleTGF(input: Seq[Attribute] = Seq('name, 'age)) extends Generator { +case class ExampleTGF(input: Seq[Expression] = Seq('name, 'age)) extends Generator { def children = input protected def makeOutput() = 'nameAndAge.string :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 8fa143e2deca..af83599871db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.parquet +import org.apache.spark.sql.execution.SparkPlan import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} import parquet.hadoop.ParquetFileWriter @@ -201,10 +202,11 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } test("Projection of simple Parquet file") { + SparkPlan.currentContext.set(TestSQLContext) val scanner = new ParquetTableScan( ParquetTestData.testData.output, ParquetTestData.testData, - Seq())(TestSQLContext) + Seq()) val projected = scanner.pruneColumns(ParquetTypesConverter .convertToAttributes(MessageTypeParser .parseMessageType(ParquetTestData.subTestSchema))) From 4771fabbcaa661045c1707766354e7360e9a2843 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 22 Jul 2014 17:21:42 -0700 Subject: [PATCH 17/28] Docs, more test coverage. 
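The documentation moved around below is worth pinning down with a concrete illustration of the contract it describes: a mutable projection returns the same row object on every call, so a caller that retains results across inputs must copy them first. A self-contained analog of the hazard, with a reused Array standing in for the reused row:

    // Without copying, every retained result aliases the same buffer and
    // ends up showing the last value written.
    val buffer = new Array[Int](1)
    val aliased = Iterator(1, 2, 3).map { i => buffer(0) = i; buffer }.toList
    aliased.map(_.toSeq)   // all three entries show 3: they alias one array

    // Copying eagerly, as Row.copy() does for projections, preserves each value.
    val copied = Iterator(1, 2, 3).map { i => buffer(0) = i; buffer.clone }.toList
    copied.map(_.toSeq)    // Seq(1), Seq(2), Seq(3)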
--- .../sql/catalyst/expressions/Projection.scala | 20 ++++++-------- .../spark/sql/catalyst/expressions/Row.scala | 1 + .../expressions/codegen/CodeGenerator.scala | 27 +++++++++++-------- .../expressions/codegen/package.scala | 2 +- .../sql/catalyst/expressions/package.scala | 16 +++++++++++ .../apache/spark/sql/catalyst/package.scala | 11 ++++++++ .../spark/sql/catalyst/types/dataTypes.scala | 10 ++----- .../apache/spark/sql/execution/Generate.scala | 7 ++--- .../hive/execution/ScriptTransformation.scala | 2 +- ...se null-0-8ef2f741400830ef889a9dd0c817fe3d | 1 + ...le case-0-f513687d17dcb18546fefa75000a52f2 | 1 + ...le case-0-c264e319c52f1840a32959d552b99e73 | 1 + .../sql/hive/execution/HiveQuerySuite.scala | 6 ++--- 13 files changed, 66 insertions(+), 39 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala create mode 100644 sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d create mode 100644 sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 create mode 100644 sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 4e170ee8f0f9..8fc589697443 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.catalyst.expressions + /** - * Converts a [[Row]] to another Row given a sequence of expression that define each column of the - * new row. If the schema of the input row is specified, then the given expression will be bound to - * that schema. + * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. + * @param expressions a sequence of expressions that determine the value of each column of the + * output row. */ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = @@ -40,15 +41,10 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { } /** - * Converts a [[Row]] to another Row given a sequence of expression that define each column of the - * new row. If the schema of the input row is specified, then the given expression will be bound to - * that schema. - * - * In contrast to a normal projection, a MutableProjection reuses the same underlying row object - * each time an input row is added. This significantly reduces the cost of calculating the - * projection, but means that it is not safe to hold on to a reference to a [[Row]] after `next()` - * has been called on the [[Iterator]] that produced it. Instead, the user must call `Row.copy()` - * and hold on to the returned [[Row]] before calling `next()`. + * A [[MutableProjection]] that is calculated by calling `eval` on each of the specified + * expressions. + * @param expressions a sequence of expressions that determine the value of each column of the + * output row. 
*/ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index 1a2ac7285b98..effcfd0841ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -180,6 +180,7 @@ class GenericRow(protected[catalyst] val values: Array[Any]) extends Row { values(i).asInstanceOf[String] } + // Custom hashCode function that matches the efficient code generated version. override def hashCode(): Int = { var result: Int = 37 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 8bf676b803a2..d1c863fe48fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ /** - * A base class for generators of byte code that performs expression evaluation. Includes helpers - * for referring to Catalyst types and building trees that perform evaluation of individual + * A base class for generators of byte code to perform expression evaluation. Includes a set of + * helpers for referring to Catalyst types and building trees that perform evaluation of individual * expressions. */ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { @@ -51,12 +51,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin private val javaSeparator = "$" /** - * Generates a class for a given input expression. Called when there is not a cached code + * Generates a class for a given input expression. Called when there is not cached code * already available. */ protected def create(in: InType): OutType - /** Canonicalizes an input expression. */ + /** + * Canonicalizes an input expression. Used to avoid double caching expressions that differ only + * cosmetically. + */ protected def canonicalize(in: InType): InType /** Binds an input expression to a given input schema */ @@ -103,8 +106,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin objectTerm: TermName) /** - * Given an expression tree returns the code required to determine both if the result is NULL - * as well as the code required to compute the value. + * Given an expression tree returns an [[EvaluatedExpression]], which contains Scala trees that + * can be used to determine the result of evaluating the expression on an input row. */ def expressionEvaluator(e: Expression): EvaluatedExpression = { val primitiveTerm = freshName("primitiveTerm") @@ -130,7 +133,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin /** * Short hand for generating binary evaluation code, which depends on two sub-evaluations of - * the same type. If either of the sub-expressions is null, the results of this computation + * the same type. If either of the sub-expressions is null, the result of this computation * is assumed to be null. 
* * @param f a function from two primitive term names to a tree that evaluates them. @@ -139,8 +142,9 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin evaluateAs(expressions._1.dataType)(f) def evaluateAs(resultType: DataType)(f: (TermName, TermName) => Tree): Seq[Tree] = { - require(expressions._1.dataType == expressions._2.dataType, - s"${expressions._1.dataType} != ${expressions._2.dataType}") + // Right now some timestamp tests fail if we enforce this... + if (expressions._1.dataType != expressions._2.dataType) + log.warn(s"${expressions._1.dataType} != ${expressions._2.dataType}") val eval1 = expressionEvaluator(expressions._1) val eval2 = expressionEvaluator(expressions._2) @@ -164,7 +168,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin // TODO: Skip generation of null handling code when expression are not nullable. val primitiveEvaluation: PartialFunction[Expression, Seq[Tree]] = { case b @ BoundReference(ordinal, dataType, nullable) => - val nullValue = if (nullable) q"$inputTuple.isNullAt($ordinal)" else q"false" + val nullValue = q"$inputTuple.isNullAt($ordinal)" q""" val $nullTerm: Boolean = $nullValue val $primitiveTerm: ${termForType(dataType)} = @@ -228,7 +232,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin case Cast(child @ NumericType(), FloatType) => child.castOrNull(c => q"$c.toFloat", IntegerType) - case Cast(e, StringType) => + // Special handling required for timestamps in hive test cases. + case Cast(e, StringType) if e.dataType != TimestampType => val eval = expressionEvaluator(e) eval.code ++ q""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala index 846d67b53801..80c7dfd376c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -31,7 +31,7 @@ package object codegen { * A lock to protect invoking the scala compiler at runtime, since it is not thread safe in Scala * 2.10. */ - protected[codegen] val globalLock = new Object() + protected[codegen] val globalLock = org.apache.spark.sql.catalyst.ScalaReflectionLock /** Canonicalizes an expression so those that differ only by names can reuse the same code. */ object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 91658bcf7e17..38ee66cccd29 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -49,8 +49,24 @@ package org.apache.spark.sql.catalyst */ package object expressions { + /** + * Converts a [[Row]] to another Row given a sequence of expression that define each column of the + * new row. If the schema of the input row is specified, then the given expression will be bound + * to that schema. + */ abstract class Projection extends (Row => Row) + /** + * Converts a [[Row]] to another Row given a sequence of expression that define each column of the + * new row. If the schema of the input row is specified, then the given expression will be bound + * to that schema. 
+ * + * In contrast to a normal projection, a MutableProjection reuses the same underlying row object + * each time an input row is added. This significantly reduces the cost of calculating the + * projection, but means that it is not safe to hold on to a reference to a [[Row]] after `next()` + * has been called on the [[Iterator]] that produced it. Instead, the user must call `Row.copy()` + * and hold on to the returned [[Row]] before calling `next()`. + */ abstract class MutableProjection extends Projection { def currentValue: Row diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala new file mode 100644 index 000000000000..64f7ca7895aa --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala @@ -0,0 +1,11 @@ +package org.apache.spark.sql + + +package object catalyst { + /** + * A JVM-global lock that should be used to prevent thread safety issues when using things in + * scala.reflect.*. Note that Scala Reflection API is made thread-safe in 2.11, but not yet for + * 2.10.* builds. See SI-6240 for more details. + */ + protected[catalyst] object ScalaReflectionLock +} \ No newline at end of file diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index a1884a785e5e..2f4639cb1a07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -23,18 +23,12 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror} import scala.util.parsing.combinator.RegexParsers +import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.util.Utils /** - * A JVM-global lock that should be used to prevent thread safety issues when using things in - * scala.reflect.*. Note that Scala Reflection API is made thread-safe in 2.11, but not yet for - * 2.10.* builds. See SI-6240 for more details. - */ -protected[catalyst] object ScalaReflectionLock - -/** - * Utility functions for working with DataTypes + * Utility functions for working with DataTypes. */ object DataType extends RegexParsers { protected lazy val primitiveType: Parser[DataType] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index a28ce5869dad..62051cc42539 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -47,15 +47,16 @@ case class Generate( } } - override def output = + // This must be a val since the generator output expr ids are not preserved by serialization. + override val output = if (join) child.output ++ generatorOutput else generatorOutput + val boundGenerator = BindReferences.bindReference(generator, child.output) + /** Codegenned rows are not serializable... 
*/ override val codegenEnabled = false override def execute() = { - val boundGenerator = BindReferences.bindReference(generator, child.output) - if (join) { child.execute().mapPartitions { iter => val nullValues = Seq.fill(generator.output.size)(Literal(null)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 7f71fe4e32f4..0c8f676e9c5c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -67,7 +67,7 @@ case class ScriptTransformation( } } readerThread.start() - val outputProjection = new InterpretedProjection(input) + val outputProjection = new InterpretedProjection(input, child.output) iter .map(outputProjection) // TODO: Use SerDe diff --git a/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d b/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d new file mode 100644 index 000000000000..00750edc07d6 --- /dev/null +++ b/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d @@ -0,0 +1 @@ +3 diff --git a/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 b/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 new file mode 100644 index 000000000000..00750edc07d6 --- /dev/null +++ b/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 @@ -0,0 +1 @@ +3 diff --git a/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 b/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 new file mode 100644 index 000000000000..d00491fd7e5b --- /dev/null +++ b/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index fedcae665cc4..d04b282b196b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -32,13 +32,13 @@ case class TestData(a: Int, b: String) class HiveQuerySuite extends HiveComparisonTest { createQueryTest("single case", - """SELECT case when true then 1 else 2 end FROM src""") + """SELECT case when true then 1 else 2 end FROM src LIMIT 1""") createQueryTest("double case", - """SELECT case when 1 = 2 then 1 when 2 = 2 then 3 else 2 end FROM src""") + """SELECT case when 1 = 2 then 1 when 2 = 2 then 3 else 2 end FROM src LIMIT 1""") createQueryTest("case else null", - """SELECT case when 1 = 2 then 1 when 2 = 2 then 3 else null end FROM src""") + """SELECT case when 1 = 2 then 1 when 2 = 2 then 3 else null end FROM src LIMIT 1""") test("CREATE TABLE AS runs once") { hql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect() From 033abc6fc4560c76d31aa55080271b797a0198e7 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 22 Jul 2014 23:57:35 -0700 Subject: [PATCH 18/28] off by default --- sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 8f532b2615f0..8eca796a940f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -47,7 +47,7 @@ trait SQLConf { * Defaults to false as this feature is currently experimental. */ private[spark] def codegenEnabled: Boolean = - if (get("spark.sql.codegen", "true") == "true") true else false + if (get("spark.sql.codegen", "false") == "true") true else false /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to From 1ec2d6ef6f9b62bb55f206fc606dbdd4ce2436d5 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 23 Jul 2014 16:14:43 -0700 Subject: [PATCH 19/28] Address comments --- pom.xml | 10 ++++++++++ project/SparkBuild.scala | 4 +--- sql/catalyst/pom.xml | 9 +++++++++ .../expressions/codegen/CodeGenerator.scala | 5 +++-- .../apache/spark/sql/catalyst/package.scala | 20 +++++++++++++++++-- .../sql/catalyst/planning/patterns.scala | 15 ++++++++++++++ .../spark/sql/catalyst/types/dataTypes.scala | 5 +---- .../scala/org/apache/spark/sql/SQLConf.scala | 2 +- 8 files changed, 58 insertions(+), 12 deletions(-) diff --git a/pom.xml b/pom.xml index 4e2d64a83364..aaa157a4e776 100644 --- a/pom.xml +++ b/pom.xml @@ -113,6 +113,7 @@ spark 2.10.4 2.10 + 2.0.1 0.18.1 shaded-protobuf org.spark-project.akka @@ -818,6 +819,15 @@ -target ${java.version} + + + + org.scalamacros + paradise_${scala.version} + ${scala.macros.version} + + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index c64525e1b6c3..069d3db15dd9 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -187,9 +187,7 @@ object SparkBuild extends PomBuild { object Catalyst { lazy val settings = Seq( - addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full), - libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "scala-compiler" % v), - libraryDependencies += "org.scalamacros" %% "quasiquotes" % "2.0.1") + addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full)) } object SQL { diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 6decde3fcd62..bd35b8d2fd9b 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -36,10 +36,19 @@ + + org.scala-lang + scala-compiler + org.scala-lang scala-reflect + + org.scalamacros + quasiquotes_${scala.binary.version} + ${scala.macros.version} + org.apache.spark spark-core_${scala.binary.version} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d1c863fe48fd..86e536a0a77c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -142,9 +142,10 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin evaluateAs(expressions._1.dataType)(f) def evaluateAs(resultType: DataType)(f: (TermName, TermName) => Tree): Seq[Tree] = { - // Right now some timestamp tests fail if we enforce this... - if (expressions._1.dataType != expressions._2.dataType) + // TODO: Right now some timestamp tests fail if we enforce this... 
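With the default flipped to "false" above, code generation is now an explicit opt-in, and the getter is just a string comparison against the conf. A self-contained model of the lookup (the mutable Map stands in for the real settings store):

    // Model of the flag lookup: `if (x == "true") true else false` is
    // equivalent to the comparison itself.
    val settings = scala.collection.mutable.Map[String, String]()
    def get(key: String, default: String): String = settings.getOrElse(key, default)
    def codegenEnabled: Boolean = get("spark.sql.codegen", "false") == "true"

    assert(!codegenEnabled)                    // off by default
    settings("spark.sql.codegen") = "true"     // explicit opt-in per context
    assert(codegenEnabled)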
+ if (expressions._1.dataType != expressions._2.dataType) { log.warn(s"${expressions._1.dataType} != ${expressions._2.dataType}") + } val eval1 = expressionEvaluator(expressions._1) val eval2 = expressionEvaluator(expressions._2) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala index 64f7ca7895aa..3b3e206055cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala @@ -1,5 +1,21 @@ -package org.apache.spark.sql +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql package object catalyst { /** @@ -8,4 +24,4 @@ package object catalyst { * 2.10.* builds. See SI-6240 for more details. */ protected[catalyst] object ScalaReflectionLock -} \ No newline at end of file +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 0660fd9223fe..418f8686bfe5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -104,6 +104,21 @@ object PhysicalOperation extends PredicateHelper { } } +/** + * Matches a logical aggregation that can be performed on distributed data in two steps. The first + * operates on the data in each partition performing partial aggregation for each group. The second + * occurs after the shuffle and completes the aggregation. + * + * This pattern will only match if all aggregate expressions can be computed partially and will + * return the rewritten aggregation expressions for both phases. + * + * The returned values for this match are as follows: + * - Grouping attributes for the final aggregation. + * - Aggregates for the final aggregation. + * - Grouping expressions for the partial aggregation. + * - Partial aggregate expressions. + * - Input to the aggregation. 
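+ *
+ * For example (editor's illustration, not part of the original comment): for a query like
+ * SELECT key, AVG(value) FROM src GROUP BY key, the partial phase computes
+ * (SUM(value), COUNT(value)) per key within each partition, and after the shuffle the
+ * final phase computes SUM(partialSum) / SUM(partialCount) for each key.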
+ */ object PartialAggregation { type ReturnType = (Seq[Attribute], Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 2f4639cb1a07..71808f76d632 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -154,10 +154,7 @@ abstract class NumericType extends NativeType with PrimitiveType { } object NumericType { - def unapply(a: Expression): Boolean = a match { - case e: Expression if e.dataType.isInstanceOf[NumericType] => true - case _ => false - } + def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] } /** Matcher for any expressions that evaluate to [[IntegralType]]s */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 8eca796a940f..d5e8c622cb52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -38,7 +38,7 @@ trait SQLConf { private[spark] def numShufflePartitions: Int = get("spark.sql.shuffle.partitions", "200").toInt /** - * When set to true, Spark SQL will use the scala compiler at runtime to generate custom bytecode + * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode * that evaluates expressions found in queries. In general this custom code runs much faster * than interpreted evaluation, but there are significant start-up costs due to compilation. * As a result codegen is only beneficial when queries run for a long time, or when the same From 0672e8a6153b7b82948471533147e4338a735703 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 25 Jul 2014 18:20:37 -0700 Subject: [PATCH 20/28] Address comments. --- .../apache/spark/sql/catalyst/expressions/BoundAttribute.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index e50b724bff30..a3ebec8082cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -35,7 +35,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) type EvaluatedType = Any - def references = Set.empty + override def references = Set.empty override def toString = s"input[$ordinal]" From 1a61293849beba628be56ff114e6ef502c90b597 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sat, 26 Jul 2014 13:07:01 -0700 Subject: [PATCH 21/28] Address review comments.
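(Editor's illustration, not part of the patch series: the generated-class cache documented
in the diff below follows the standard Guava LoadingCache pattern. A minimal self-contained
sketch, with a stand-in loader in place of the real compile step:)

    import com.google.common.cache.{CacheBuilder, CacheLoader}

    // Stand-in for compiling a canonicalized expression into a class; hypothetical.
    def compile(key: String): AnyRef = key.reverse

    val cache = CacheBuilder.newBuilder()
      .maximumSize(1000) // bounds the number of retained compiled classes
      .build(new CacheLoader[String, AnyRef] {
        override def load(key: String): AnyRef = compile(key)
      })

    cache.get("a = 1") // compiled on first request, served from the cache afterwards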
--- .../expressions/codegen/CodeGenerator.scala | 28 ++++++--- .../sql/catalyst/expressions/package.scala | 2 +- .../GeneratedEvaluationSuite.scala | 39 ------------ .../GeneratedMutableEvaluationSuite.scala | 61 +++++++++++++++++++ .../sql/execution/GeneratedAggregate.scala | 14 ++--- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../apache/spark/sql/execution/joins.scala | 4 +- .../spark/sql/parquet/ParquetQuerySuite.scala | 3 +- 8 files changed, 92 insertions(+), 61 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 86e536a0a77c..5b398695bf56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -65,6 +65,14 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin /** Binds an input expression to a given input schema */ protected def bind(in: InType, inputSchema: Seq[Attribute]): InType + /** + * A cache of generated classes. + * + * From the Guava Docs: A Cache is similar to ConcurrentMap, but not quite the same. The most + * fundamental difference is that a ConcurrentMap persists all elements that are added to it until + * they are explicitly removed. A Cache on the other hand is generally configured to evict entries + * automatically, in order to constrain its memory footprint + */ protected val cache = CacheBuilder.newBuilder() .maximumSize(1000) .build( @@ -74,9 +82,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } }) - def apply(expressions: InType, inputSchema: Seq[Attribute]): OutType= + /** Generates the requested evaluator binding the given expression(s) to the inputSchema. */ + def apply(expressions: InType, inputSchema: Seq[Attribute]): OutType = apply(bind(expressions, inputSchema)) + /** Generates the requested evaluator given already bound expression(s). */ def apply(expressions: InType): OutType = cache.get(canonicalize(expressions)) /** @@ -233,7 +243,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin case Cast(child @ NumericType(), FloatType) => child.castOrNull(c => q"$c.toFloat", IntegerType) - // Special handling required for timestamps in hive test cases. + // Special handling required for timestamps in hive test cases since the toString function + // does not match the expected output. 
case Cast(e, StringType) if e.dataType != TimestampType => val eval = expressionEvaluator(e) eval.code ++ @@ -355,9 +366,9 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin var $nullTerm = true var $primitiveTerm: ${termForType(c.dataType)} = ${defaultPrimitive(c.dataType)} """.children ++ - children.map { c => - val eval = expressionEvaluator(c) - q""" + children.map { c => + val eval = expressionEvaluator(c) + q""" if($nullTerm) { ..${eval.code} if(!${eval.nullTerm}) { @@ -365,8 +376,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin $primitiveTerm = ${eval.primitiveTerm} } } - """ - } + """ + } case i @ expressions.If(condition, trueValue, falseValue) => val condEval = expressionEvaluator(condition) @@ -392,8 +403,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin // If there was no match in the partial function above, we fall back on calling the interpreted // expression evaluator. val code: Seq[Tree] = - primitiveEvaluation.lift.apply(e) - .getOrElse { + primitiveEvaluation.lift.apply(e).getOrElse { log.debug(s"No rules to generate $e") val tree = reify { e } q""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 38ee66cccd29..55d95991c5f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -70,7 +70,7 @@ package object expressions { abstract class MutableProjection extends Projection { def currentValue: Row - /** Updates the target of this projection to a new MutableRow */ + /** Uses the given row to store the output of the projection. */ def target(row: MutableRow): MutableProjection } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala index 224c2a059a2e..245a2e148030 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala @@ -67,42 +67,3 @@ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { futures.foreach(Await.result(_, 10.seconds)) } } - -/** - * Overrides our expression evaluation tests to use generated code on mutable rows. 
- */ -class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { - override def checkEvaluation( - expression: Expression, - expected: Any, - inputRow: Row = EmptyRow): Unit = { - lazy val evaluated = GenerateProjection.expressionEvaluator(expression) - - val plan = try { - GenerateProjection(Alias(expression, s"Optimized($expression)")() :: Nil) - } catch { - case e: Throwable => - fail( - s""" - |Code generation of $expression failed: - |${evaluated.code.mkString("\n")} - |$e - """.stripMargin) - } - - val actual = plan(inputRow) - val expectedRow = new GenericRow(Array[Any](expected)) - if (actual.hashCode() != expectedRow.hashCode()) { - fail( - s""" - |Mismatched hashCodes for values: $actual, $expectedRow - |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} - |${evaluated.code.mkString("\n")} - """.stripMargin) - } - if (actual != expectedRow) { - val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") - } - } -} \ No newline at end of file diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala new file mode 100644 index 000000000000..887aabb1d5fb --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ + +/** + * Overrides our expression evaluation tests to use generated code on mutable rows. 
+ */ +class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { + override def checkEvaluation( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + lazy val evaluated = GenerateProjection.expressionEvaluator(expression) + + val plan = try { + GenerateProjection(Alias(expression, s"Optimized($expression)")() :: Nil) + } catch { + case e: Throwable => + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code.mkString("\n")} + |$e + """.stripMargin) + } + + val actual = plan(inputRow) + val expectedRow = new GenericRow(Array[Any](expected)) + if (actual.hashCode() != expectedRow.hashCode()) { + fail( + s""" + |Mismatched hashCodes for values: $actual, $expectedRow + |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} + |${evaluated.code.mkString("\n")} + """.stripMargin) + } + if (actual != expectedRow) { + val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 5c4b615a14bc..e55d6b0fabca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -107,7 +107,7 @@ case class GeneratedAggregate( val computationSchema = computeFunctions.flatMap(_.schema) - val resultMap = aggregatesToCompute.zip(computeFunctions).map { + val resultMap: Map[Long, Expression] = aggregatesToCompute.zip(computeFunctions).map { case (agg, func) => agg.id -> func.result }.toMap @@ -116,8 +116,11 @@ case class GeneratedAggregate( case (e, i) => (e, Alias(e, s"GroupingExpr$i")()) } - val groupMap = namedGroups.map { case (k, v) => k -> v.toAttribute}.toMap + val groupMap: Map[Expression, Attribute] = + namedGroups.map { case (k, v) => k -> v.toAttribute}.toMap + // The set of expressions that produce the final output given the aggregation buffer and the + // grouping expressions. val resultExpressions = aggregateExpressions.map(_.transform { case e: Expression if resultMap.contains(e.id) => resultMap(e.id) case e: Expression if groupMap.contains(e) => groupMap(e) @@ -125,25 +128,21 @@ case class GeneratedAggregate( child.execute().mapPartitions { iter => // Builds a new custom class for holding the results of aggregation for a group. - @transient val newAggregationBuffer = newProjection(computeFunctions.flatMap(_.initialValues), child.output) // A projection that is used to update the aggregate values for a group given a new tuple. // This projection should be targeted at the current values for the group and then applied // to a joined row of the current values with the new input row. - @transient val updateProjection = newMutableProjection( computeFunctions.flatMap(_.update), computeFunctions.flatMap(_.schema) ++ child.output)() // A projection that computes the group given an input tuple. - @transient val groupProjection = newProjection(groupingExpressions, child.output) // A projection that produces the final result, given a computation. - @transient val resultProjectionBuilder = newMutableProjection( resultExpressions, @@ -155,10 +154,11 @@ case class GeneratedAggregate( // TODO: Codegening anything other than the updateProjection is probably over kill. 
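// (Editor's note, not part of the patch) In the surrounding file this loop handles the
// case of no grouping expressions: a single mutable buffer holds the running aggregates,
// every input row is folded into it through the generated update projection, and one
// result row is emitted for the whole partition.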
val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] var currentRow: Row = null + updateProjection.target(buffer) while (iter.hasNext) { currentRow = iter.next() - updateProjection.target(buffer)(joinedRow(buffer, currentRow)) + updateProjection(joinedRow(buffer, currentRow)) } val resultProjection = resultProjectionBuilder() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 813502969bf7..66e78f5efe0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -187,7 +187,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val relation = ParquetRelation.create(path, child, sparkContext.hadoopConfiguration) // Note: overwrite=false because otherwise the metadata we just created will be deleted - InsertIntoParquetTable(relation, planLater(child), overwrite=false) :: Nil + InsertIntoParquetTable(relation, planLater(child), overwrite = false) :: Nil case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) => InsertIntoParquetTable(table, planLater(child), overwrite) :: Nil case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 6d85e0658e59..2750ddbce896 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -304,9 +304,7 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod leftResults.cartesian(rightResults).mapPartitions { iter => val joinedRow = new JoinedRow - iter.map { - case (l: Row, r: Row) => joinedRow(l, r) - } + iter.map(r => joinedRow(r._1, r._2)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 4295233f7f56..561f5b4a4996 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.parquet -import org.apache.spark.sql.execution.SparkPlan import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} import parquet.hadoop.ParquetFileWriter @@ -26,6 +25,7 @@ import parquet.schema.MessageTypeParser import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job + import org.apache.spark.SparkContext import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{SqlLexical, SqlParser} @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType} import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.util.Utils From 3587460ea12b6a9248757ce05655d6c1797a82da Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 27 Jul 2014 14:58:43 -0700 Subject: [PATCH 22/28] Drop unused string builder function. 
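(Editor's illustration, not part of the patch: with getStringBuilder gone, all mutation
goes through the typed setters that remain on MutableRow. A minimal sketch using only
members visible in this diff:)

    val row = new GenericMutableRow(3)
    row.setBoolean(0, true)
    row.setDouble(1, 1.5)
    row.setString(2, "a")
    // getString(2) now returns the stored value directly; there is no longer a mutable
    // per-column string buffer to flush.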
--- .../apache/spark/sql/catalyst/expressions/Row.scala | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index effcfd0841ee..7470cb861b83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -88,15 +88,6 @@ trait MutableRow extends Row { def setByte(ordinal: Int, value: Byte) def setFloat(ordinal: Int, value: Float) def setString(ordinal: Int, value: String) - - /** - * Experimental - * - * Returns a mutable string builder for the specified column. A given row should return the - * result of any mutations made to the returned buffer next time getString is called for the same - * column. - */ - def getStringBuilder(ordinal: Int): StringBuilder } /** @@ -216,8 +207,6 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow { /** No-arg constructor for serialization. */ def this() = this(0) - def getStringBuilder(ordinal: Int): StringBuilder = ??? - override def setBoolean(ordinal: Int,value: Boolean): Unit = { values(ordinal) = value } override def setByte(ordinal: Int,value: Byte): Unit = { values(ordinal) = value } override def setDouble(ordinal: Int,value: Double): Unit = { values(ordinal) = value } From 64b2ee19f3e6b63a984ad1859136931c5e0a1831 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 27 Jul 2014 14:58:57 -0700 Subject: [PATCH 23/28] Implement copy --- .../expressions/codegen/GenerateProjection.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 45bfba819827..77fa02c13de3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -188,6 +188,11 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } """ + val copyFunction = + q""" + final def copy() = new $genericRowType(this.toArray) + """ + val classBody = nullFunctions ++ ( lengthDef +: @@ -196,16 +201,12 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { updateFunction +: equalsFunction +: hashCodeFunction +: + copyFunction +: (tupleElements ++ specificAccessorFunctions ++ specificMutatorFunctions)) val code = q""" final class SpecificRow(i: $rowType) extends $mutableRowType { ..$classBody - - // Not safe! - final def copy() = scala.sys.error("Not implemented") - - final def getStringBuilder(ordinal: Int): StringBuilder = ??? } new $projectionType { def apply(r: $rowType) = new SpecificRow(r) } From 3cd773ef8508fd269af2e8cd20030b8fafb997c6 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 27 Jul 2014 14:59:27 -0700 Subject: [PATCH 24/28] Allow codegen for Generate. 
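(Editor's note, illustrative: re-enabling codegen for Generate appears to rely on the
copy() added in the previous patch. A sketch of what that generated copy() does, written
as a plain function:)

    import org.apache.spark.sql.catalyst.expressions.{GenericRow, Row}

    // Materializes a (possibly codegenned) row into a plain GenericRow, mirroring
    // the generated `final def copy() = new $genericRowType(this.toArray)` above.
    def detach(row: Row): Row = new GenericRow(row.toArray)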
--- .../main/scala/org/apache/spark/sql/execution/Generate.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 62051cc42539..c386fd121c5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -53,9 +53,6 @@ case class Generate( val boundGenerator = BindReferences.bindReference(generator, child.output) - /** Codegenned rows are not serializable... */ - override val codegenEnabled = false - override def execute() = { if (join) { child.execute().mapPartitions { iter => From 533fdfd582183a8f470750c106ec0a256715137c Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 27 Jul 2014 14:59:51 -0700 Subject: [PATCH 25/28] More logging of expression rewriting for GeneratedAggregate. --- .../sql/execution/GeneratedAggregate.scala | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index e55d6b0fabca..bdfa44adf602 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -128,25 +128,28 @@ case class GeneratedAggregate( child.execute().mapPartitions { iter => // Builds a new custom class for holding the results of aggregation for a group. - val newAggregationBuffer = - newProjection(computeFunctions.flatMap(_.initialValues), child.output) + val initialValues = computeFunctions.flatMap(_.initialValues) + val newAggregationBuffer = newProjection(initialValues, child.output) + logger.info(s"Initial values: ${initialValues.mkString(",")}") + + // A projection that computes the group given an input tuple. + val groupProjection = newProjection(groupingExpressions, child.output) + logger.info(s"Grouping Projection: ${groupingExpressions.mkString(",")}") // A projection that is used to update the aggregate values for a group given a new tuple. // This projection should be targeted at the current values for the group and then applied // to a joined row of the current values with the new input row. - val updateProjection = - newMutableProjection( - computeFunctions.flatMap(_.update), - computeFunctions.flatMap(_.schema) ++ child.output)() - - // A projection that computes the group given an input tuple. - val groupProjection = newProjection(groupingExpressions, child.output) + val updateExpressions = computeFunctions.flatMap(_.update) + val updateSchema = computeFunctions.flatMap(_.schema) ++ child.output + val updateProjection = newMutableProjection(updateExpressions, updateSchema)() + logger.info(s"Update Expressions: ${updateExpressions.mkString(",")}") // A projection that produces the final result, given a computation. 
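// (Editor's note, not part of the patch) newMutableProjection returns a
// () => MutableProjection; each invocation of the builder yields a fresh projection
// with its own mutable target row, which is why the builder itself is kept here.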
val resultProjectionBuilder = newMutableProjection( resultExpressions, (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq) + logger.info(s"Result Projection: ${resultExpressions.mkString(",")}") val joinedRow = new JoinedRow From ef8d42bf4b620c5999a32a19a27859bb3a9d808c Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 27 Jul 2014 14:59:58 -0700 Subject: [PATCH 26/28] comments --- .../main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 991421bdec25..cd691a3e2720 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -53,7 +53,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ def logger = log - val codegenEnabled: Boolean = if(sqlContext != null) { + // sqlContext will be null when we are being deserialized on the slaves. In this instance + // the value of codegenEnabled will be set by the deserializer after the constructor has run. + val codegenEnabled: Boolean = if (sqlContext != null) { sqlContext.codegenEnabled } else { false From fed3634816a8ad439d68fecf5494505c296cf5b2 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 27 Jul 2014 15:00:12 -0700 Subject: [PATCH 27/28] Inspectors are not serializable. --- .../src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 96b41bd2c563..7582b4743d40 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -251,8 +251,10 @@ private[hive] case class HiveGenericUdtf( @transient protected lazy val function: GenericUDTF = createFunction() + @transient protected lazy val inputInspectors = children.map(_.dataType).map(toInspector) + @transient protected lazy val outputInspectors = { val structInspector = function.initialize(inputInspectors.toArray) structInspector.getAllStructFieldRefs.map(_.getFieldObjectInspector) From 67b1c48f88de5b1310451451dab610a22f6c4556 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 29 Jul 2014 13:31:11 -0700 Subject: [PATCH 28/28] Use conf variable in SQLConf object --- sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 67b028d38a79..210b9e8ab8a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -48,7 +48,7 @@ trait SQLConf { * Defaults to false as this feature is currently experimental.
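 *
 * For example (editor's illustration): issuing SET spark.sql.codegen=true through the
 * SQL interface, or set("spark.sql.codegen", "true") on the context, flips this flag
 * for subsequently planned queries.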
*/ private[spark] def codegenEnabled: Boolean = - if (get("spark.sql.codegen", "false") == "true") true else false + if (get(CODEGEN_ENABLED, "false") == "true") true else false /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to @@ -107,6 +107,7 @@ object SQLConf { val AUTO_CONVERT_JOIN_SIZE = "spark.sql.auto.convert.join.size" val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" val JOIN_BROADCAST_TABLES = "spark.sql.join.broadcastTables" + val CODEGEN_ENABLED = "spark.sql.codegen" object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
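(Editor's illustration, closing note, not part of the patch series: a sketch of how the
new constant replaces the raw string at call sites, assuming an in-scope SQLContext whose
SQLConf trait provides the string-based get/set used above:)

    import org.apache.spark.sql.SQLConf

    // before this patch: sqlContext.set("spark.sql.codegen", "true")
    sqlContext.set(SQLConf.CODEGEN_ENABLED, "true")
    // codegenEnabled now reads the same key through the shared constant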