From 91bd85e2637336f3c82252224acd5bcd51270f1c Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 28 Dec 2015 18:12:10 +0800 Subject: [PATCH 01/15] WIP: Converting resolved logical plan back to SQL --- .../spark/sql/catalyst/analysis/Catalog.scala | 4 +- .../spark/sql/catalyst/expressions/Cast.scala | 4 + .../sql/catalyst/expressions/Expression.scala | 15 +++ .../sql/catalyst/expressions/SortOrder.scala | 14 ++- .../expressions/aggregate/Average.scala | 2 + .../expressions/aggregate/Count.scala | 5 + .../catalyst/expressions/aggregate/Max.scala | 2 + .../catalyst/expressions/aggregate/Sum.scala | 2 + .../expressions/aggregate/interfaces.scala | 6 + .../sql/catalyst/expressions/arithmetic.scala | 11 ++ .../expressions/complexTypeExtractors.scala | 2 + .../expressions/conditionalExpressions.scala | 6 + .../expressions/decimalExpressions.scala | 3 + .../sql/catalyst/expressions/literals.scala | 30 ++++- .../expressions/mathExpressions.scala | 2 + .../expressions/namedExpressions.scala | 12 ++ .../expressions/nullExpressions.scala | 4 + .../sql/catalyst/expressions/predicates.scala | 29 ++++- .../spark/sql/catalyst/plans/joinTypes.scala | 24 +++- .../plans/logical/basicOperators.scala | 4 +- .../spark/sql/catalyst/util/package.scala | 11 ++ .../apache/spark/sql/types/ArrayType.scala | 2 + .../org/apache/spark/sql/types/DataType.scala | 2 + .../org/apache/spark/sql/types/MapType.scala | 2 + .../apache/spark/sql/types/StructType.scala | 5 + .../spark/sql/types/UserDefinedType.scala | 2 + .../execution/HiveCompatibilitySuite.scala | 10 +- .../spark/sql/hive/QueryNormalizer.scala | 103 ++++++++++++++++++ .../apache/spark/sql/hive/SQLBuilder.scala | 102 +++++++++++++++++ .../hive/ExpressionSQLGenerationSuite.scala | 72 ++++++++++++ .../hive/LogicalPlanSQLGenerationSuite.scala | 55 ++++++++++ .../spark/sql/hive/SQLGenerationTest.scala | 83 ++++++++++++++ .../hive/execution/HiveComparisonTest.scala | 41 ++++++- 33 files changed, 652 insertions(+), 19 deletions(-) create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/QueryNormalizer.scala create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLGenerationSuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanSQLGenerationSuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLGenerationTest.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index e8b2fcf819bf6..a8f89ce6de457 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -110,7 +110,9 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are // properly qualified with this alias. 
- alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) + alias + .map(a => Subquery(a, tableWithQualifiers)) + .getOrElse(tableWithQualifiers) } override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d82d3edae4e38..64ce03f2e128f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -931,6 +931,10 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { $evPrim = $result.copy(); """ } + + override def sql: Option[String] = { + child.sql.map(childSQL => s"CAST($childSQL AS ${dataType.sql})") + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 6a9c12127d367..27c6b523d6c4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -223,6 +224,8 @@ abstract class Expression extends TreeNode[Expression] { protected def toCommentSafeString: String = this.toString .replace("*/", "\\*\\/") .replace("\\u", "\\\\u") + + def sql: Option[String] = None } @@ -356,6 +359,8 @@ abstract class UnaryExpression extends Expression { """ } } + + override def sql: Option[String] = child.sql.map(childSQL => s"($prettyName($childSQL))") } @@ -492,6 +497,11 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { TypeCheckResult.TypeCheckSuccess } } + + override def sql: Option[String] = for { + lhs <- left.sql + rhs <- right.sql + } yield s"($lhs $symbol $rhs)" } @@ -593,4 +603,9 @@ abstract class TernaryExpression extends Expression { """ } } + + override def sql: Option[String] = sequenceOption(children.map(_.sql)).map { + case Seq(child1, child2, child3) => + s"$prettyName($child1, $child2, $child3)" + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 3add722da7816..1cb1b9da3049b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -24,9 +24,17 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.BinaryPrefixComparator import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator -abstract sealed class SortDirection -case object Ascending extends SortDirection -case object Descending extends SortDirection +abstract sealed class SortDirection { + def sql: String +} + +case object Ascending extends 
SortDirection { + override def sql: String = "ASC" +} + +case object Descending extends SortDirection { + override def sql: String = "DESC" +} /** * An expression that can be used to sort a tuple. This class extends expression primarily so that diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 94ac4bf09b90b..45fdc4f62c3d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -82,4 +82,6 @@ case class Average(child: Expression) extends DeclarativeAggregate { case _ => Cast(sum, resultType) / Cast(count, resultType) } + + override def argumentsSQL: Option[String] = child.sql } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 663c69e799fbd..b953a4be3d320 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.types._ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { @@ -59,6 +60,10 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { override lazy val evaluateExpression = count override def defaultResult: Option[Literal] = Option(Literal(0L)) + + override def argumentsSQL: Option[String] = { + sequenceOption(children.map(_.sql)).map(_.mkString(", ")) + } } object Count { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index 906003188d4ff..ab9032a20fbfa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -56,4 +56,6 @@ case class Max(child: Expression) extends DeclarativeAggregate { } override lazy val evaluateExpression: AttributeReference = max + + override def argumentsSQL: Option[String] = child.sql } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 08a67ea3df51d..a7a53e0aa215d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -77,4 +77,6 @@ case class Sum(child: Expression) extends DeclarativeAggregate { } override lazy val evaluateExpression: Expression = sum + + override def argumentsSQL: Option[String] = child.sql } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index b47f32d1768b9..a970451222fac 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -94,6 +94,8 @@ private[sql] case class AggregateExpression( override def prettyString: String = aggregateFunction.prettyString override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)" + + override def sql: Option[String] = aggregateFunction.sql } /** @@ -321,6 +323,10 @@ abstract class DeclarativeAggregate final lazy val inputAggBufferAttributes: Seq[AttributeReference] = aggBufferAttributes.map(_.newInstance()) + override def sql: Option[String] = argumentsSQL.map(sql => s"${prettyName.toUpperCase}($sql)") + + def argumentsSQL: Option[String] = None + /** * A helper class for representing an attribute used in merging two * aggregation buffers. When merging two buffers, `bufferLeft` and `bufferRight`, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 61a17fd7db0fe..be551cab7d9b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -54,6 +54,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp numeric.negate(input) } } + + override def sql: Option[String] = child.sql.map(sql => s"(-$sql)") } case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes { @@ -67,6 +69,8 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects defineCodeGen(ctx, ev, c => c) protected override def nullSafeEval(input: Any): Any = input + + override def sql: Option[String] = child.sql.map(sql => s"(+$sql)") } /** @@ -91,6 +95,8 @@ case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes } protected override def nullSafeEval(input: Any): Any = numeric.abs(input) + + override def sql: Option[String] = child.sql.map(sql => s"(ABS($sql))") } abstract class BinaryArithmetic extends BinaryOperator { @@ -513,4 +519,9 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { val r = a % n if (r.compare(Decimal.ZERO) < 0) {(r + n) % n} else r } + + override def sql: Option[String] = for { + leftSQL <- left.sql + rightSQL <- right.sql + } yield s"Pmod($leftSQL, $rightSQL)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 9c73239f67ff2..d8bb8fe9e0fd0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -130,6 +130,8 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] } }) } + + override def sql: Option[String] = child.sql.map(_ + s".`${childSchema(ordinal).name}`") } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index f79c8676fb58c..c33608362c82f 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -74,6 +74,12 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } override def toString: String = s"if ($predicate) $trueValue else $falseValue" + + override def sql: Option[String] = for { + predicateSQL <- predicate.sql + trueSQL <- trueValue.sql + falseSQL <- falseValue.sql + } yield s"(IF($predicateSQL, $trueSQL, $falseSQL))" } trait CaseWhenLike extends Expression { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index c54bcdd774021..188a5f4686113 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -73,6 +73,7 @@ case class PromotePrecision(child: Expression) extends UnaryExpression { override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = "" override def prettyName: String = "promote_precision" + override def sql: Option[String] = child.sql } /** @@ -107,4 +108,6 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary } override def toString: String = s"CheckOverflow($child, $dataType)" + + override def sql: Option[String] = child.sql } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 672cc9c45e0af..4ed0d301141da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -21,9 +21,9 @@ import java.sql.{Date, Timestamp} import org.json4s.JsonAST._ -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types._ @@ -214,6 +214,34 @@ case class Literal protected (value: Any, dataType: DataType) } } } + + override def sql: Option[String] = Option((value, dataType) match { + case _ if value == null => + "NULL" + + case (v: UTF8String, StringType) => + "\"" + v.toString.replace("\"", "\\\"") + "\"" + + case (v: Byte, ByteType) => + s"CAST($v AS ${ByteType.simpleString.toUpperCase})" + + case (v: Short, ShortType) => + s"CAST($v AS ${ShortType.simpleString.toUpperCase})" + + case (v: Long, LongType) => + s"CAST($v AS ${LongType.simpleString.toUpperCase})" + + case (v: Float, FloatType) => + s"CAST($v AS ${FloatType.simpleString.toUpperCase})" + + case (v: Decimal, DecimalType.Fixed(precision, scale)) => + s"CAST($v AS ${DecimalType.simpleString.toUpperCase}($precision, $scale))" + + case (v: Int, DateType) => + s"DATE '${DateTimeUtils.toJavaDate(v)}'" + + case _ => value.toString + }) } // TODO: Specialize diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 002f5929cc26b..e10f58cd2d318 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -70,6 +70,8 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)") } + + override def sql: Option[String] = child.sql.map(childSQL => s"$name($childSQL)") } abstract class UnaryLogExpression(f: Double => Double, name: String) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index eefd9c7482553..676e57c5c633a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -164,6 +164,12 @@ case class Alias(child: Expression, name: String)( explicitMetadata == a.explicitMetadata case _ => false } + + override def sql: Option[String] = { + val qualifiersString = + if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".") + child.sql.map(childSQL => s"$childSQL AS $qualifiersString`$name`") + } } /** @@ -271,6 +277,12 @@ case class AttributeReference( // Since the expression id is not in the first constructor it is missing from the default // tree string. override def simpleString: String = s"$name#${exprId.id}: ${dataType.simpleString}" + + override def sql: Option[String] = { + val qualifiersString = + if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".") + Some(s"$qualifiersString`$name`") + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index df4747d4e6f7a..a3c9238285a9d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -193,6 +193,8 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { ev.value = eval.isNull eval.code } + + override def sql: Option[String] = child.sql.map(childSQL => s"($childSQL IS NULL)") } @@ -212,6 +214,8 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { ev.value = s"(!(${eval.isNull}))" eval.code } + + override def sql: Option[String] = child.sql.map(childSQL => s"($childSQL IS NOT NULL)") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 304b438c84ba4..9a112f27c5c6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import 
org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.util.{sequenceOption, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -101,6 +101,8 @@ case class Not(child: Expression) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"!($c)") } + + override def sql: Option[String] = child.sql.map(childSQL => s"(NOT $childSQL)") } @@ -176,6 +178,12 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate } """ } + + override def sql: Option[String] = for { + childrenSQL <- sequenceOption(children.map(_.sql)) + valueSQL = childrenSQL.head + listSQL = childrenSQL.tail + } yield s"($valueSQL IN (${listSQL.mkString(", ")}))" } /** @@ -226,6 +234,10 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } """ } + + override def sql: Option[String] = for { + valueSQL :: listSQL <- sequenceOption(children.map(_.sql)) + } yield s"($valueSQL IN (${listSQL.mkString(", ")}))" } case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate { @@ -274,6 +286,11 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with } """ } + + override def sql: Option[String] = for { + lhs <- left.sql + rhs <- right.sql + } yield s"($lhs AND $rhs)" } @@ -323,6 +340,11 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P } """ } + + override def sql: Option[String] = for { + lhs <- left.sql + rhs <- right.sql + } yield s"($lhs OR $rhs)" } @@ -339,6 +361,11 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { defineCodeGen(ctx, ev, (c1, c2) => s"${ctx.genComp(left.dataType, c1, c2)} $symbol 0") } } + + override def sql: Option[String] = for { + lhs <- left.sql + rhs <- right.sql + } yield s"($lhs $symbol $rhs)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 77dec7ca6e2b5..a5f6764aef7ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -37,14 +37,26 @@ object JoinType { } } -sealed abstract class JoinType +sealed abstract class JoinType { + def sql: String +} -case object Inner extends JoinType +case object Inner extends JoinType { + override def sql: String = "INNER" +} -case object LeftOuter extends JoinType +case object LeftOuter extends JoinType { + override def sql: String = "LEFT OUTER" +} -case object RightOuter extends JoinType +case object RightOuter extends JoinType { + override def sql: String = "RIGHT OUTER" +} -case object FullOuter extends JoinType +case object FullOuter extends JoinType { + override def sql: String = "FULL OUTER" +} -case object LeftSemi extends JoinType +case object LeftSemi extends JoinType { + override def sql: String = "LEFT SEMI" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 79759b5a37b34..dc583cccca6b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -422,7 +422,9 @@ case 
class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { } } -case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { +case class Subquery(alias: String, child: LogicalPlan) + extends UnaryNode { + override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 71293475ca0f9..6bfaa9c14b839 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -130,6 +130,17 @@ package object util { ret } + def sequenceOption[T](seq: Seq[Option[T]]): Option[Seq[T]] = seq match { + case xs if xs.isEmpty => + Option(Seq.empty[T]) + + case xs => + for { + head <- xs.head + tail <- sequenceOption(xs.tail) + } yield head +: tail + } + /* FIX ME implicit class debugLogging(a: Any) { def debugLogging() { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 6533622492d41..520e344361625 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -77,6 +77,8 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT override def simpleString: String = s"array<${elementType.simpleString}>" + override def sql: String = s"ARRAY<${elementType.sql}>" + override private[spark] def asNullable: ArrayType = ArrayType(elementType.asNullable, containsNull = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 136a97e066df7..92cf8d4c46bda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -65,6 +65,8 @@ abstract class DataType extends AbstractDataType { /** Readable string representation for the type with truncation */ private[sql] def simpleString(maxNumberFields: Int): String = simpleString + def sql: String = simpleString.toUpperCase + /** * Check if `this` and `other` are the same data type when ignoring nullability * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 00461e529ca0a..5474954af70e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -62,6 +62,8 @@ case class MapType( override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>" + override def sql: String = s"MAP<${keyType.sql}, ${valueType.sql}>" + override private[spark] def asNullable: MapType = MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 34382bf124eb0..9b5c86a8984be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -279,6 +279,11 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru s"struct<${fieldTypes.mkString(",")}>" } + override def sql: String = { + val fieldTypes = fields.map(f => s"`${f.name}`: ${f.dataType.sql}") + s"STRUCT<${fieldTypes.mkString(", ")}>" + } + private[sql] override def simpleString(maxNumberFields: Int): String = { val builder = new StringBuilder val fieldTypes = fields.take(maxNumberFields).map { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index 4305903616bd9..d7a2c23be8a9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -84,6 +84,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { override private[sql] def acceptsType(dataType: DataType) = this.getClass == dataType.getClass + + override def sql: String = sqlType.sql } /** diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index bd1a52e5f3303..1770d0312671a 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -41,7 +41,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalColumnBatchSize = TestHive.conf.columnBatchSize private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning - def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) + def testCases: Seq[(String, File)] = { + hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) + } override def beforeAll() { TestHive.cacheTables = true @@ -71,7 +73,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { } /** A list of tests deemed out of scope currently and thus completely disregarded. */ - override def blackList = Seq( + override def blackList: Seq[String] = Seq( // These tests use hooks that are not on the classpath and thus break all subsequent execution. 
"hook_order", "hook_context_cs", @@ -106,7 +108,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "alter_merge", "alter_concatenate_indexed_table", "protectmode2", - //"describe_table", + // "describe_table", "describe_comment_nonascii", "create_merge_compressed", @@ -323,7 +325,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { * The set of tests that are believed to be working in catalyst. Tests not on whiteList or * blacklist are implicitly marked as ignored. */ - override def whiteList = Seq( + override def whiteList: Seq[String] = Seq( "add_part_exist", "add_part_multiple", "add_partition_no_whitelist", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/QueryNormalizer.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/QueryNormalizer.scala new file mode 100644 index 0000000000000..fd65a9fbf6728 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/QueryNormalizer.scala @@ -0,0 +1,103 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.analysis.UnresolvedStar +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.optimizer.{CombineFilters, ProjectCollapsing} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} + +case class NamedRelation(databaseName: String, tableName: String, output: Seq[Attribute]) + extends LeafNode + +class QueryNormalizer(sqlContext: SQLContext) extends RuleExecutor[LogicalPlan] { + override protected val batches: Seq[Batch] = Seq( + Batch("Reorder Operators", FixedPoint(100), + ReorderPredicate, + ReorderLimit, + CombineFilters, + ProjectCollapsing + ), + + Batch("Fill Missing Operators", Once, + ProjectStar + ), + + Batch("Recover Scope Information", FixedPoint(100), + ReplaceWithNamedRelation + ) + ) + + object ReorderPredicate extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case project @ Project(_, filter @ Filter(_, child)) => + filter.copy(child = project.copy(child = child)) + + case agg @ Aggregate(_, _, filter @ Filter(_, child)) => + filter.copy(child = agg.copy(child = child)) + + case filter @ Filter(_, sort @ Sort(_, _, child)) => + sort.copy(child = filter.copy(child = child)) + } + } + + object ReorderLimit extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case agg @ Aggregate(_, _, limit @ Limit(_, child)) => + limit.copy(child = agg.copy(child = child)) + + case project @ Project(_, limit @ Limit(_, child)) => + limit.copy(child = project.copy(child = child)) + + case filter @ Filter(_, limit @ Limit(_, child)) => + limit.copy(child = filter.copy(child = child)) + + case sort @ Sort(_, _, limit @ Limit(_, child)) => + limit.copy(child = sort.copy(child = child)) + } + } + + object ProjectStar extends Rule[LogicalPlan] { + def projectStar(plan: LogicalPlan): LogicalPlan = { + sqlContext.analyzer.execute(Project(UnresolvedStar(None) :: Nil, plan)) + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case filter @ Filter(_, _: Aggregate | _: Project) => + filter + + case filter @ Filter(_, child) => + filter.copy(child = projectStar(child)) + + case limit @ Limit(_, child) => + limit.copy(child = projectStar(child)) + } + } + + object ReplaceWithNamedRelation extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case r @ MetastoreRelation(_, _, Some(alias)) => + Subquery(alias, NamedRelation(r.databaseName, r.tableName, r.output)) + + case r @ MetastoreRelation(_, _, None) => + NamedRelation(r.databaseName, r.tableName, r.output) + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala new file mode 100644 index 0000000000000..a0837cdf392bd --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.Logging +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.sequenceOption + +class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging { + def toSQL: Option[String] = toSQL(logicalPlan) + + private def toSQL(node: LogicalPlan): Option[String] = node match { + case plan @ Project(list, child) => + for { + listSQL <- sequenceOption(list.map(_.sql)) + childSQL <- toSQL(child) + from = child match { + case OneRowRelation => "" + case _ => " FROM " + } + } yield s"SELECT ${listSQL.mkString(", ")}$from$childSQL" + + case plan @ Aggregate(groupingExpressions, aggregateExpressions, child) => + for { + aggregateSQL <- sequenceOption(aggregateExpressions.map(_.sql)) + groupingSQL <- sequenceOption(groupingExpressions.map(_.sql)) + maybeGroupBy = if (groupingSQL.isEmpty) "" else " GROUP BY " + maybeFrom = child match { + case OneRowRelation => "" + case _ => " FROM " + } + childSQL <- toSQL(child).map(maybeFrom + _) + } yield { + s"SELECT ${aggregateSQL.mkString(", ")}$childSQL$maybeGroupBy${groupingSQL.mkString(", ")}" + } + + case plan @ Limit(limit, child) => + for { + limitSQL <- limit.sql + childSQL <- toSQL(child) + } yield s"$childSQL LIMIT $limitSQL" + + case plan @ Filter(condition, child) => + for { + conditionSQL <- condition.sql + childSQL <- toSQL(child) + whereOrHaving = child match { + case _: Aggregate => "HAVING" + case _ => "WHERE" + } + } yield s"$childSQL $whereOrHaving $conditionSQL" + + case plan @ Union(left, right) => + for { + leftSQL <- toSQL(left) + rightSQL <- toSQL(right) + } yield s"$leftSQL UNION ALL $rightSQL" + + case plan @ Subquery(alias, child) => + toSQL(child).map(childSQL => s"($childSQL) AS $alias") + + case plan @ Join(left, right, joinType, condition) => + for { + leftSQL <- toSQL(left) + rightSQL <- toSQL(right) + joinTypeSQL = joinType.sql + conditionSQL = condition.flatMap(_.sql).map(" ON " + _).getOrElse("") + } yield s"$leftSQL $joinTypeSQL JOIN $rightSQL$conditionSQL" + + case plan @ MetastoreRelation(database, table, alias) => + val aliasSQL = alias.map(a => s" `$a`").getOrElse("") + Some(s"`$database`.`$table`$aliasSQL") + + case plan @ Sort(orders, global, child) => + for { + childSQL <- toSQL(child) + ordersSQL <- sequenceOption(orders.map { case SortOrder(e, dir) => + e.sql.map(sql => s"$sql ${dir.sql}") + }) + orderOrSort = if (global) "ORDER" else "SORT" + } yield s"$childSQL $orderOrSort BY ${ordersSQL.mkString(", ")}" + + case _ => None + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLGenerationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLGenerationSuite.scala new file mode 100644 index 0000000000000..324ce3ec9cdf3 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLGenerationSuite.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{If, Literal} + +class ExpressionSQLGenerationSuite extends SQLGenerationTest { + + test("literal") { + checkSQL(Literal("foo"), "\"foo\"") + checkSQL(Literal("\"foo\""), "\"\\\"foo\\\"\"") + checkSQL(Literal(1: Byte), "CAST(1 AS TINYINT)") + checkSQL(Literal(2: Short), "CAST(2 AS SMALLINT)") + checkSQL(Literal(4: Int), "4") + checkSQL(Literal(8: Long), "CAST(8 AS BIGINT)") + checkSQL(Literal(1.5F), "CAST(1.5 AS FLOAT)") + checkSQL(Literal(2.5D), "2.5") + // TODO tests for decimals + } + + test("binary comparisons") { + checkSQL('a.int === 'b.int, "(`a` = `b`)") + checkSQL('a.int <=> 'b.int, "(`a` <=> `b`)") + checkSQL('a.int !== 'b.int, "(NOT (`a` = `b`))") + + checkSQL('a.int < 'b.int, "(`a` < `b`)") + checkSQL('a.int <= 'b.int, "(`a` <= `b`)") + checkSQL('a.int > 'b.int, "(`a` > `b`)") + checkSQL('a.int >= 'b.int, "(`a` >= `b`)") + + checkSQL('a.int in ('b.int, 'c.int), "(`a` IN (`b`, `c`))") + checkSQL('a.int in (1, 2), "(`a` IN (1, 2))") + + checkSQL('a.int.isNull, "(`a` IS NULL)") + checkSQL('a.int.isNotNull, "(`a` IS NOT NULL)") + } + + test("logical operators") { + checkSQL('a.boolean && 'b.boolean, "(`a` AND `b`)") + checkSQL('a.boolean || 'b.boolean, "(`a` OR `b`)") + checkSQL(!'a.boolean, "(NOT `a`)") + checkSQL(If('a.boolean, 'b.int, 'c.int), "(IF(`a`, `b`, `c`))") + } + + test("arithmetic expressions") { + checkSQL('a.int + 'b.int, "(`a` + `b`)") + checkSQL('a.int - 'b.int, "(`a` - `b`)") + checkSQL('a.int * 'b.int, "(`a` * `b`)") + checkSQL('a.int / 'b.int, "(`a` / `b`)") + checkSQL('a.int % 'b.int, "(`a` % `b`)") + + checkSQL(-'a.int, "(-`a`)") + checkSQL(-('a.int + 'b.int), "(-(`a` + `b`))") + } + +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanSQLGenerationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanSQLGenerationSuite.scala new file mode 100644 index 0000000000000..9f97d5de59ee9 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanSQLGenerationSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SQLTestUtils + +class LogicalPlanSQLGenerationSuite extends SQLGenerationTest with SQLTestUtils { + import hiveContext.implicits._ + + protected override def beforeAll(): Unit = { + super.beforeAll() + + sqlContext.range(10).select('id alias "a").registerTempTable("t0") + sqlContext.range(10).select('id alias "b").registerTempTable("t1") + } + + protected override def afterAll(): Unit = { + sqlContext.dropTempTable("t0") + sqlContext.dropTempTable("t1") + + super.afterAll() + } + + test("single row project") { + checkSQL(OneRowRelation.select(lit(1)), "SELECT 1 AS `1`") + checkSQL(OneRowRelation.select(lit(1) as 'a), "SELECT 1 AS `a`") + } + + test("project with limit") { + checkSQL(OneRowRelation.select(lit(1)).limit(1), "SELECT 1 AS `1` LIMIT 1") + checkSQL(OneRowRelation.select(lit(1) as 'a).limit(1), "SELECT 1 AS `a` LIMIT 1") + } + + test("table lookup") { + checkSQL(sqlContext.table("t0"), "SELECT `t0`.`a` FROM `t0`") + checkSQL(sqlContext.table("t1").select('b alias "c"), "SELECT `t1`.`id` AS `c` FROM `t1`") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLGenerationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLGenerationTest.scala new file mode 100644 index 0000000000000..2f4220d40beda --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLGenerationTest.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.{DataFrame, QueryTest} + +abstract class SQLGenerationTest extends QueryTest with TestHiveSingleton { + protected def checkSQL(e: Expression, expectedSQL: String): Unit = { + val maybeSQL = e.sql + + if (maybeSQL.isEmpty) { + fail( + s"""Cannot convert the following expression to SQL form: + | + |${e.treeString} + """.stripMargin) + } + + try { + assert(maybeSQL.get === expectedSQL) + } catch { + case cause: Throwable => + fail( + s"""Wrong SQL generated for the following expression: + | + |${e.prettyName} + | + |$cause + """.stripMargin) + } + } + + protected def checkSQL(plan: LogicalPlan, expectedSQL: String): Unit = { + val maybeSQL = new SQLBuilder(plan, hiveContext).toSQL + + if (maybeSQL.isEmpty) { + fail( + s"""Cannot convert the following logical query plan to SQL: + | + |${plan.treeString} + """.stripMargin) + } + + val actualSQL = maybeSQL.get + + try { + assert(actualSQL === expectedSQL) + } catch { + case cause: Throwable => + fail( + s"""Wrong SQL generated for the following logical query plan: + | + |${plan.treeString} + | + |$cause + """.stripMargin) + } + + checkAnswer(sqlContext.sql(actualSQL), new DataFrame(sqlContext, plan)) + } + + protected def checkSQL(df: DataFrame, expectedSQL: String): Unit = { + checkSQL(df.queryExecution.analyzed, expectedSQL) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index d7e8ebc8d312f..23cdc3255cc57 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -27,9 +27,10 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.{ExplainCommand, SetCommand} import org.apache.spark.sql.execution.datasources.DescribeCommand +import org.apache.spark.sql.execution.{ExplainCommand, SetCommand} import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.{InsertIntoHiveTable => LogicalInsertIntoHiveTable, SQLBuilder} /** * Allows the creations of tests that execute the same query against both hive @@ -372,7 +373,43 @@ abstract class HiveComparisonTest // Run w/ catalyst val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) => - val query = new TestHive.QueryExecution(queryString) + val query = { + val originalQuery = new TestHive.QueryExecution(queryString) + val containsCommands = originalQuery.analyzed.collectFirst { + case _: Command => () + case _: LogicalInsertIntoHiveTable => () + }.nonEmpty + + if (containsCommands) { + originalQuery + } else { + new SQLBuilder(originalQuery.analyzed, TestHive).toSQL.map { sql => + logInfo( + s""" + |### Running SQL generation round-trip test {{{ + |${originalQuery.analyzed.treeString} + |Original SQL: + |$queryString + | + |Generated SQL: + |$sql + |}}} + """.stripMargin.trim) + new TestHive.QueryExecution(sql) + }.getOrElse { + logInfo( + s""" + |### Cannot convert the following logical plan back to SQL {{{ + |${originalQuery.analyzed.treeString} + 
|Original SQL: + |$queryString + |}}} + """.stripMargin.trim) + originalQuery + } + } + } + try { (query, prepareAnswer(query, query.stringResult())) } catch { case e: Throwable => val errorMessage = From 7053757de3e31317e63c3286fd7eb17b694e0464 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 4 Jan 2016 19:48:01 +0800 Subject: [PATCH 02/15] Fixes some test failures --- .../expressions/aggregate/interfaces.scala | 19 +++- .../sql/catalyst/expressions/literals.scala | 2 +- .../sql/catalyst/optimizer/Optimizer.scala | 33 ++++++ .../sql/catalyst/rules/RuleExecutor.scala | 2 +- .../spark/sql/hive/QueryNormalizer.scala | 103 ------------------ .../apache/spark/sql/hive/SQLBuilder.scala | 62 ++++++++++- .../hive/LogicalPlanSQLGenerationSuite.scala | 8 +- 7 files changed, 112 insertions(+), 117 deletions(-) delete mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/QueryNormalizer.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index a970451222fac..e0e5e57b359cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -93,9 +93,16 @@ private[sql] case class AggregateExpression( override def prettyString: String = aggregateFunction.prettyString - override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)" + override def toString: String = s"($aggregateFunction,mode=$mode,isDistinct=$isDistinct)" - override def sql: Option[String] = aggregateFunction.sql + override def sql: Option[String] = if (isDistinct) { + aggregateFunction.argumentsSQL.map { argsSQL => + val name = aggregateFunction.prettyName.toUpperCase + s"$name(DISTINCT $argsSQL)" + } + } else { + aggregateFunction.sql + } } /** @@ -165,6 +172,10 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu def toAggregateExpression(isDistinct: Boolean): AggregateExpression = { AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct) } + + override def sql: Option[String] = argumentsSQL.map(sql => s"${prettyName.toUpperCase}($sql)") + + def argumentsSQL: Option[String] = None } /** @@ -323,10 +334,6 @@ abstract class DeclarativeAggregate final lazy val inputAggBufferAttributes: Seq[AttributeReference] = aggBufferAttributes.map(_.newInstance()) - override def sql: Option[String] = argumentsSQL.map(sql => s"${prettyName.toUpperCase}($sql)") - - def argumentsSQL: Option[String] = None - /** * A helper class for representing an attribute used in merging two * aggregation buffers. 
When merging two buffers, `bufferLeft` and `bufferRight`, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 4ed0d301141da..adbca703db2a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -220,7 +220,7 @@ case class Literal protected (value: Any, dataType: DataType) "NULL" case (v: UTF8String, StringType) => - "\"" + v.toString.replace("\"", "\\\"") + "\"" + "\"" + v.toString.replace("\\", "\\\\").replace("\"", "\\\"") + "\"" case (v: Byte, ByteType) => s"CAST($v AS ${ByteType.simpleString.toUpperCase})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0b1c74293bb8b..2a552c8a3aefd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -333,6 +333,39 @@ object ProjectCollapsing extends Rule[LogicalPlan] { ) Project(cleanedProjection, child) } + + // TODO Eliminate duplicate code + // This clause is identical to the one above except that the inner operator is an `Aggregate` + // rather than a `Project`. + case p @ Project(projectList1, agg @ Aggregate(_, projectList2, child)) => + // Create a map of Aliases to their values from the child projection. + // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). + val aliasMap = AttributeMap(projectList2.collect { + case a: Alias => (a.toAttribute, a) + }) + + // We only collapse these two Projects if their overlapped expressions are all + // deterministic. + val hasNondeterministic = projectList1.exists(_.collect { + case a: Attribute if aliasMap.contains(a) => aliasMap(a).child + }.exists(!_.deterministic)) + + if (hasNondeterministic) { + p + } else { + // Substitute any attributes that are produced by the child projection, so that we safely + // eliminate it. + // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' + // TODO: Fix TransformBase to avoid the cast below. + val substitutedProjection = projectList1.map(_.transform { + case a: Attribute => aliasMap.getOrElse(a, a) + }).asInstanceOf[Seq[NamedExpression]] + // collapse 2 projects may introduce unnecessary Aliases, trim them here. 
+ val cleanedProjection = substitutedProjection.map(p => + CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression] + ) + agg.copy(aggregateExpressions = cleanedProjection) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 62ea731ab5f38..9ebacb4680dc2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -37,7 +37,7 @@ object RuleExecutor { val maxSize = map.keys.map(_.toString.length).max map.toSeq.sortBy(_._2).reverseMap { case (k, v) => s"${k.padTo(maxSize, " ").mkString} $v" - }.mkString("\n") + }.mkString("\n", "\n", "") } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/QueryNormalizer.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/QueryNormalizer.scala deleted file mode 100644 index fd65a9fbf6728..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/QueryNormalizer.scala +++ /dev/null @@ -1,103 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*/ - -package org.apache.spark.sql.hive - -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.analysis.UnresolvedStar -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.optimizer.{CombineFilters, ProjectCollapsing} -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} - -case class NamedRelation(databaseName: String, tableName: String, output: Seq[Attribute]) - extends LeafNode - -class QueryNormalizer(sqlContext: SQLContext) extends RuleExecutor[LogicalPlan] { - override protected val batches: Seq[Batch] = Seq( - Batch("Reorder Operators", FixedPoint(100), - ReorderPredicate, - ReorderLimit, - CombineFilters, - ProjectCollapsing - ), - - Batch("Fill Missing Operators", Once, - ProjectStar - ), - - Batch("Recover Scope Information", FixedPoint(100), - ReplaceWithNamedRelation - ) - ) - - object ReorderPredicate extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case project @ Project(_, filter @ Filter(_, child)) => - filter.copy(child = project.copy(child = child)) - - case agg @ Aggregate(_, _, filter @ Filter(_, child)) => - filter.copy(child = agg.copy(child = child)) - - case filter @ Filter(_, sort @ Sort(_, _, child)) => - sort.copy(child = filter.copy(child = child)) - } - } - - object ReorderLimit extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case agg @ Aggregate(_, _, limit @ Limit(_, child)) => - limit.copy(child = agg.copy(child = child)) - - case project @ Project(_, limit @ Limit(_, child)) => - limit.copy(child = project.copy(child = child)) - - case filter @ Filter(_, limit @ Limit(_, child)) => - limit.copy(child = filter.copy(child = child)) - - case sort @ Sort(_, _, limit @ Limit(_, child)) => - limit.copy(child = sort.copy(child = child)) - } - } - - object ProjectStar extends Rule[LogicalPlan] { - def projectStar(plan: LogicalPlan): LogicalPlan = { - sqlContext.analyzer.execute(Project(UnresolvedStar(None) :: Nil, plan)) - } - - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case filter @ Filter(_, _: Aggregate | _: Project) => - filter - - case filter @ Filter(_, child) => - filter.copy(child = projectStar(child)) - - case limit @ Limit(_, child) => - limit.copy(child = projectStar(child)) - } - } - - object ReplaceWithNamedRelation extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case r @ MetastoreRelation(_, _, Some(alias)) => - Subquery(alias, NamedRelation(r.databaseName, r.tableName, r.output)) - - case r @ MetastoreRelation(_, _, None) => - NamedRelation(r.databaseName, r.tableName, r.output) - } - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index a0837cdf392bd..2cc2cf876bcf7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -17,14 +17,31 @@ package org.apache.spark.sql.hive +import java.util.concurrent.atomic.AtomicLong + import org.apache.spark.Logging import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Attribute, SortOrder} +import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing import 
org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} import org.apache.spark.sql.catalyst.util.sequenceOption class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging { - def toSQL: Option[String] = toSQL(logicalPlan) + def toSQL: Option[String] = { + val canonicalizedPlan = Canonicalizer.execute(logicalPlan) + + logDebug( + s"""Building SQL query string from given logical plan: + | + |# Original logical plan: + |${logicalPlan.treeString} + |# Canonicalized logical plan: + |${canonicalizedPlan.treeString} + """.stripMargin) + + toSQL(canonicalizedPlan) + } private def toSQL(node: LogicalPlan): Option[String] = node match { case plan @ Project(list, child) => @@ -85,7 +102,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi } yield s"$leftSQL $joinTypeSQL JOIN $rightSQL$conditionSQL" case plan @ MetastoreRelation(database, table, alias) => - val aliasSQL = alias.map(a => s" `$a`").getOrElse("") + val aliasSQL = alias.map(a => s" AS `$a`").getOrElse("") Some(s"`$database`.`$table`$aliasSQL") case plan @ Sort(orders, global, child) => @@ -99,4 +116,43 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case _ => None } + + object Canonicalizer extends RuleExecutor[LogicalPlan] { + override protected def batches: Seq[Batch] = Seq( + Batch("Normalizer", FixedPoint(100), + // The `WidenSetOperationTypes` analysis rule may introduce extra `Project`s over + // `Aggregate`s to perform type casting. This rule removes them. + ProjectCollapsing, + + // Used to handle other auxiliary `Project`s added by analyzer (e.g. + // `ResolveAggregateFunctions` rule) + RecoverScopingInfo + ) + ) + + object RecoverScopingInfo extends Rule[LogicalPlan] { + override def apply(tree: LogicalPlan): LogicalPlan = tree transform { + case plan @ Project( + list, _: Subquery | _: Join | _: Filter | _: MetastoreRelation | OneRowRelation + ) => + plan + + case Project(projectList, child) => + val alias = SQLBuilder.newSubqueryName + val childAttributes = child.outputSet + val aliasedProjectList = projectList.map(_.transform { + case a: Attribute if childAttributes.contains(a) => + a.withQualifiers(alias :: Nil) + }.asInstanceOf[NamedExpression]) + + Project(aliasedProjectList, Subquery(alias, child)) + } + } + } +} + +object SQLBuilder { + private val nextSubqueryId = new AtomicLong(0) + + private def newSubqueryName: String = s"subquery_${nextSubqueryId.getAndIncrement()}__" } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanSQLGenerationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanSQLGenerationSuite.scala index 9f97d5de59ee9..180955f1d7f94 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanSQLGenerationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanSQLGenerationSuite.scala @@ -21,6 +21,8 @@ import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SQLTestUtils +// All test cases in this test suite are ignored for now because currently `SQLBuilder` only handles +// resolved logical plans parsed directly from HiveQL query strings. 
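// Usage sketch (illustration only, not part of this patch): driving SQLBuilder from a plan that
// was parsed from HiveQL, i.e. the supported path noted above. Assumes a HiveContext named
// `hiveContext` and a Hive table `src(key, value)`; both names are hypothetical examples.
import org.apache.spark.sql.hive.SQLBuilder

val analyzed = hiveContext.sql("SELECT key FROM src WHERE key > 100").queryExecution.analyzed
val regenerated: Option[String] = new SQLBuilder(analyzed, hiveContext).toSQL
// Some(query) when every operator and expression in the plan implements `sql`;
// None as soon as any node cannot be converted, hence the Option return type.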
class LogicalPlanSQLGenerationSuite extends SQLGenerationTest with SQLTestUtils { import hiveContext.implicits._ @@ -38,17 +40,17 @@ class LogicalPlanSQLGenerationSuite extends SQLGenerationTest with SQLTestUtils super.afterAll() } - test("single row project") { + ignore("single row project") { checkSQL(OneRowRelation.select(lit(1)), "SELECT 1 AS `1`") checkSQL(OneRowRelation.select(lit(1) as 'a), "SELECT 1 AS `a`") } - test("project with limit") { + ignore("project with limit") { checkSQL(OneRowRelation.select(lit(1)).limit(1), "SELECT 1 AS `1` LIMIT 1") checkSQL(OneRowRelation.select(lit(1) as 'a).limit(1), "SELECT 1 AS `a` LIMIT 1") } - test("table lookup") { + ignore("table lookup") { checkSQL(sqlContext.table("t0"), "SELECT `t0`.`a` FROM `t0`") checkSQL(sqlContext.table("t1").select('b alias "c"), "SELECT `t1`.`id` AS `c` FROM `t1`") } From 709483e92938aa338d7648bc85b112f88254a9e7 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 4 Jan 2016 20:39:03 +0800 Subject: [PATCH 03/15] Fixes more test failures --- .../sql/catalyst/parser/SparkSqlParser.g | 48 +++++++++---------- .../apache/spark/sql/hive/SQLBuilder.scala | 24 ++++++++-- 2 files changed, 45 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index 4afce3090f739..ffb859662c8ea 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -1,9 +1,9 @@ /** - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with + (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 @@ -580,7 +580,7 @@ import java.util.HashMap; return header; } - + @Override public String getErrorMessage(RecognitionException e, String[] tokenNames) { String msg = null; @@ -617,7 +617,7 @@ import java.util.HashMap; } return msg; } - + public void pushMsg(String msg, RecognizerSharedState state) { // ANTLR generated code does not wrap the @init code wit this backtracking check, // even if the matching @after has it. 
If we have parser rules with that are doing @@ -637,7 +637,7 @@ import java.util.HashMap; // counter to generate unique union aliases private int aliasCounter; private String generateUnionAlias() { - return "_u" + (++aliasCounter); + return "u_" + (++aliasCounter); } private char [] excludedCharForColumnName = {'.', ':'}; private boolean containExcludedCharForCreateTableColumnName(String input) { @@ -1233,7 +1233,7 @@ alterTblPartitionStatementSuffixSkewedLocation : KW_SET KW_SKEWED KW_LOCATION skewedLocations -> ^(TOK_ALTERTABLE_SKEWED_LOCATION skewedLocations) ; - + skewedLocations @init { pushMsg("skewed locations", state); } @after { popMsg(state); } @@ -1262,7 +1262,7 @@ alterStatementSuffixLocation -> ^(TOK_ALTERTABLE_LOCATION $newLoc) ; - + alterStatementSuffixSkewedby @init {pushMsg("alter skewed by statement", state);} @after{popMsg(state);} @@ -1334,10 +1334,10 @@ tabTypeExpr (identifier (DOT^ ( (KW_ELEM_TYPE) => KW_ELEM_TYPE - | + | (KW_KEY_TYPE) => KW_KEY_TYPE - | - (KW_VALUE_TYPE) => KW_VALUE_TYPE + | + (KW_VALUE_TYPE) => KW_VALUE_TYPE | identifier ))* )? @@ -1374,7 +1374,7 @@ descStatement analyzeStatement @init { pushMsg("analyze statement", state); } @after { popMsg(state); } - : KW_ANALYZE KW_TABLE (parttype=tableOrPartition) KW_COMPUTE KW_STATISTICS ((noscan=KW_NOSCAN) | (partialscan=KW_PARTIALSCAN) + : KW_ANALYZE KW_TABLE (parttype=tableOrPartition) KW_COMPUTE KW_STATISTICS ((noscan=KW_NOSCAN) | (partialscan=KW_PARTIALSCAN) | (KW_FOR KW_COLUMNS (statsColumnName=columnNameList)?))? -> ^(TOK_ANALYZE $parttype $noscan? $partialscan? KW_COLUMNS? $statsColumnName?) ; @@ -1387,7 +1387,7 @@ showStatement | KW_SHOW KW_COLUMNS (KW_FROM|KW_IN) tableName ((KW_FROM|KW_IN) db_name=identifier)? -> ^(TOK_SHOWCOLUMNS tableName $db_name?) | KW_SHOW KW_FUNCTIONS (KW_LIKE showFunctionIdentifier|showFunctionIdentifier)? -> ^(TOK_SHOWFUNCTIONS KW_LIKE? showFunctionIdentifier?) - | KW_SHOW KW_PARTITIONS tabName=tableName partitionSpec? -> ^(TOK_SHOWPARTITIONS $tabName partitionSpec?) + | KW_SHOW KW_PARTITIONS tabName=tableName partitionSpec? -> ^(TOK_SHOWPARTITIONS $tabName partitionSpec?) | KW_SHOW KW_CREATE ( (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) db_name=identifier -> ^(TOK_SHOW_CREATEDATABASE $db_name) | @@ -1396,7 +1396,7 @@ showStatement | KW_SHOW KW_TABLE KW_EXTENDED ((KW_FROM|KW_IN) db_name=identifier)? KW_LIKE showStmtIdentifier partitionSpec? -> ^(TOK_SHOW_TABLESTATUS showStmtIdentifier $db_name? partitionSpec?) | KW_SHOW KW_TBLPROPERTIES tableName (LPAREN prptyName=StringLiteral RPAREN)? -> ^(TOK_SHOW_TBLPROPERTIES tableName $prptyName?) - | KW_SHOW KW_LOCKS + | KW_SHOW KW_LOCKS ( (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) (isExtended=KW_EXTENDED)? -> ^(TOK_SHOWDBLOCKS $dbName $isExtended?) 
| @@ -1509,7 +1509,7 @@ showCurrentRole setRole @init {pushMsg("set role", state);} @after {popMsg(state);} - : KW_SET KW_ROLE + : KW_SET KW_ROLE ( (KW_ALL) => (all=KW_ALL) -> ^(TOK_SHOW_SET_ROLE Identifier[$all.text]) | @@ -1964,7 +1964,7 @@ columnNameOrderList skewedValueElement @init { pushMsg("skewed value element", state); } @after { popMsg(state); } - : + : skewedColumnValues | skewedColumnValuePairList ; @@ -1978,8 +1978,8 @@ skewedColumnValuePairList skewedColumnValuePair @init { pushMsg("column value pair", state); } @after { popMsg(state); } - : - LPAREN colValues=skewedColumnValues RPAREN + : + LPAREN colValues=skewedColumnValues RPAREN -> ^(TOK_TABCOLVALUES $colValues) ; @@ -1999,11 +1999,11 @@ skewedColumnValue skewedValueLocationElement @init { pushMsg("skewed value location element", state); } @after { popMsg(state); } - : + : skewedColumnValue | skewedColumnValuePair ; - + columnNameOrder @init { pushMsg("column name order", state); } @after { popMsg(state); } @@ -2116,7 +2116,7 @@ unionType @after { popMsg(state); } : KW_UNIONTYPE LESSTHAN colTypeList GREATERTHAN -> ^(TOK_UNIONTYPE colTypeList) ; - + setOperator @init { pushMsg("set operator", state); } @after { popMsg(state); } @@ -2168,7 +2168,7 @@ fromStatement[boolean topLevel] {adaptor.create(Identifier, generateUnionAlias())} ) ) - ^(TOK_INSERT + ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF)) ) @@ -2391,8 +2391,8 @@ setColumnsClause KW_SET columnAssignmentClause (COMMA columnAssignmentClause)* -> ^(TOK_SET_COLUMNS_CLAUSE columnAssignmentClause* ) ; -/* - UPDATE +/* + UPDATE
SET col1 = val1, col2 = val2... WHERE ... */ updateStatement diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 2cc2cf876bcf7..ca820da89c822 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -132,11 +132,29 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi object RecoverScopingInfo extends Rule[LogicalPlan] { override def apply(tree: LogicalPlan): LogicalPlan = tree transform { + // This branch handles aggregate function within HAVING clauses. For example: + // + // SELECT key FROM src GROUP BY key HAVING max(value) > "val_255" + // + // This kind of query results in query plans of the following form because of analysis rule + // `ResolveAggregateFunctions`: + // + // Project ... + // +- Filter ... + // +- Aggregate ... + // +- MetastoreRelation default, src, None + case plan @ Project(_, Filter(_, _: Aggregate)) => + wrapChildWithSubquery(plan) + case plan @ Project( - list, _: Subquery | _: Join | _: Filter | _: MetastoreRelation | OneRowRelation - ) => - plan + _, _: Subquery | _: Filter | _: Join | _: MetastoreRelation | OneRowRelation + ) => plan + + case plan: Project => + wrapChildWithSubquery(plan) + } + def wrapChildWithSubquery(project: Project): Project = project match { case Project(projectList, child) => val alias = SQLBuilder.newSubqueryName val childAttributes = child.outputSet From af6110e66a37ea3acb74a63ee0151ca6425fa3b2 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 4 Jan 2016 23:32:04 +0800 Subject: [PATCH 04/15] Fixes test failures --- .../scala/org/apache/spark/sql/hive/SQLBuilder.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index ca820da89c822..f21584763f2e5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -114,6 +114,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi orderOrSort = if (global) "ORDER" else "SORT" } yield s"$childSQL $orderOrSort BY ${ordersSQL.mkString(", ")}" + case OneRowRelation => + Some("") + case _ => None } @@ -121,7 +124,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi override protected def batches: Seq[Batch] = Seq( Batch("Normalizer", FixedPoint(100), // The `WidenSetOperationTypes` analysis rule may introduce extra `Project`s over - // `Aggregate`s to perform type casting. This rule removes them. + // `Aggregate`s to perform type casting. This rule merges these `Project`s into + // `Aggregate`s. ProjectCollapsing, // Used to handle other auxiliary `Project`s added by analyzer (e.g. @@ -132,7 +136,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi object RecoverScopingInfo extends Rule[LogicalPlan] { override def apply(tree: LogicalPlan): LogicalPlan = tree transform { - // This branch handles aggregate function within HAVING clauses. For example: + // This branch handles aggregate functions within HAVING clauses. 
For example: // // SELECT key FROM src GROUP BY key HAVING max(value) > "val_255" // @@ -146,8 +150,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case plan @ Project(_, Filter(_, _: Aggregate)) => wrapChildWithSubquery(plan) - case plan @ Project( - _, _: Subquery | _: Filter | _: Join | _: MetastoreRelation | OneRowRelation + case plan @ Project(_, + _: Subquery | _: Filter | _: Join | _: MetastoreRelation | OneRowRelation | _: Limit ) => plan case plan: Project => From c6d6429db7295e9967baa04ec73d8b665069b9a2 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 5 Jan 2016 00:42:39 +0800 Subject: [PATCH 05/15] Prints statistics about SQL generation --- .../hive/execution/HiveCompatibilitySuite.scala | 2 ++ .../HiveWindowFunctionQuerySuite.scala | 1 + .../sql/hive/execution/HiveComparisonTest.scala | 17 +++++++++++++++++ .../sql/hive/execution/HiveQuerySuite.scala | 1 + 4 files changed, 21 insertions(+) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 1770d0312671a..afd2f611580fc 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -46,6 +46,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { } override def beforeAll() { + super.beforeAll() TestHive.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) @@ -70,6 +71,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // For debugging dump some statistics about how much time was spent in various optimizer rules. logWarning(RuleExecutor.dumpTimeSpent()) + super.afterAll() } /** A list of tests deemed out of scope currently and thus completely disregarded. 
*/ diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index 98bbdf0653c2a..bad3ca6da231f 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -104,6 +104,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) TestHive.reset() + super.afterAll() } ///////////////////////////////////////////////////////////////////////////// diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 23cdc3255cc57..d1ee3a80d4f17 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -131,6 +131,21 @@ abstract class HiveComparisonTest new java.math.BigInteger(1, digest.digest).toString(16) } + /** Used for testing [[SQLBuilder]] */ + private var numConvertibleQueries: Int = 0 + private var numTotalQueries: Int = 0 + + override protected def afterAll(): Unit = { + logInfo( + s"""SQLBuiler statistics: + |- Total query number: $numTotalQueries + |- Number of convertible queries: $numConvertibleQueries + |- Percentage of convertible queries: ${numConvertibleQueries.toDouble / numTotalQueries} + """.stripMargin) + + super.afterAll() + } + protected def prepareAnswer( hiveQuery: TestHive.type#QueryExecution, answer: Seq[String]): Seq[String] = { @@ -383,7 +398,9 @@ abstract class HiveComparisonTest if (containsCommands) { originalQuery } else { + numTotalQueries += 1 new SQLBuilder(originalQuery.analyzed, TestHive).toSQL.map { sql => + numConvertibleQueries += 1 logInfo( s""" |### Running SQL generation round-trip test {{{ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 98e22c2e2c1b0..5e9e77c4fb925 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -60,6 +60,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) sql("DROP TEMPORARY FUNCTION udtf_count2") + super.afterAll() } test("SPARK-4908: concurrent hive native commands") { From 018626b63c4e245cf6f8b7f0bb77078a3ebc2194 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 5 Jan 2016 19:51:12 +0800 Subject: [PATCH 06/15] Addresses comments and more tests --- .../sql/catalyst/expressions/literals.scala | 6 +- .../sql/catalyst/expressions/predicates.scala | 3 +- .../plans/logical/basicOperators.scala | 3 +- .../spark/sql/catalyst/util/package.scala | 3 + .../datasources/parquet/ParquetRelation.scala | 16 ++-- .../apache/spark/sql/hive/SQLBuilder.scala | 48 +++++++--- ....scala => ExpressionSQLBuilderSuite.scala} | 9 +- .../sql/hive/HiveQLSQLBuilderSuite.scala | 87 +++++++++++++++++++ ...scala => LogicalPlanSQLBuilderSuite.scala} | 2 +- ...erationTest.scala => SQLBuilderTest.scala} | 2 
+- 10 files changed, 152 insertions(+), 27 deletions(-) rename sql/hive/src/test/scala/org/apache/spark/sql/hive/{ExpressionSQLGenerationSuite.scala => ExpressionSQLBuilderSuite.scala} (92%) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQLSQLBuilderSuite.scala rename sql/hive/src/test/scala/org/apache/spark/sql/hive/{LogicalPlanSQLGenerationSuite.scala => LogicalPlanSQLBuilderSuite.scala} (96%) rename sql/hive/src/test/scala/org/apache/spark/sql/hive/{SQLGenerationTest.scala => SQLBuilderTest.scala} (97%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index adbca703db2a1..544cdc62c919f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -217,9 +217,10 @@ case class Literal protected (value: Any, dataType: DataType) override def sql: Option[String] = Option((value, dataType) match { case _ if value == null => - "NULL" + if (dataType == NullType) "NULL" else s"CAST(NULL AS ${dataType.sql})" case (v: UTF8String, StringType) => + // Escapes all backslashes and double quotes. "\"" + v.toString.replace("\\", "\\\\").replace("\"", "\\\"") + "\"" case (v: Byte, ByteType) => @@ -240,6 +241,9 @@ case class Literal protected (value: Any, dataType: DataType) case (v: Int, DateType) => s"DATE '${DateTimeUtils.toJavaDate(v)}'" + case (v: Long, TimestampType) => + s"CAST('${DateTimeUtils.toJavaTimestamp(v)}' AS TIMESTAMP)" + case _ => value.toString }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 9a112f27c5c6e..2cb42515396b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -236,7 +236,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } override def sql: Option[String] = for { - valueSQL :: listSQL <- sequenceOption(children.map(_.sql)) + valueSQL <- child.sql + listSQL <- sequenceOption(hset.toSeq.map(Literal(_).sql)) } yield s"($valueSQL IN (${listSQL.mkString(", ")}))" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index dc583cccca6b8..64957db6b4013 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -422,8 +422,7 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { } } -case class Subquery(alias: String, child: LogicalPlan) - extends UnaryNode { +case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 6bfaa9c14b839..7a0d0de6328a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -130,6 +130,9 @@ package object util { ret } + /** + * Converts a `Seq` of `Option[T]` to an `Option` of `Seq[T]`. + */ def sequenceOption[T](seq: Seq[Option[T]]): Option[Seq[T]] = seq match { case xs if xs.isEmpty => Option(Seq.empty[T]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 4b375de05e9e3..ca8d010090401 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution.datasources.parquet import java.net.URI -import java.util.{List => JList} import java.util.logging.{Logger => JLogger} +import java.util.{List => JList} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -32,24 +32,24 @@ import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.task.JobContextImpl -import org.apache.parquet.{Log => ApacheParquetLog} import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.hadoop._ import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.schema.MessageType +import org.apache.parquet.{Log => ApacheParquetLog} import org.slf4j.bridge.SLF4JBridgeHandler -import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.catalyst.util.LegacyTypeStringParser +import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} +import org.apache.spark.sql.execution.datasources.{PartitionSpec, _} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister { @@ -147,6 +147,12 @@ private[sql] class ParquetRelation( .get(ParquetRelation.METASTORE_SCHEMA) .map(DataType.fromJson(_).asInstanceOf[StructType]) + // If this relation is converted from a Hive metastore table, this method returns the name of the + // original Hive metastore table. 
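// Behavior sketch for the `sequenceOption` helper documented above (illustration only, not part
// of this patch): SQL generation for a node succeeds only if it succeeded for every child.
import org.apache.spark.sql.catalyst.util.sequenceOption

sequenceOption(Seq(Some("a"), Some("b")))   // Some(Seq("a", "b"))
sequenceOption(Seq(Some("a"), None))        // None: a single inconvertible child poisons the result
sequenceOption(Seq.empty[Option[String]])   // Some(Seq()) per the empty case shown above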
+ private[sql] def metastoreTableName: Option[TableIdentifier] = { + parameters.get(ParquetRelation.METASTORE_TABLE_NAME).map(SqlParser.parseTableIdentifier) + } + private lazy val metadataCache: MetadataCache = { val meta = new MetadataCache meta.refresh() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index f21584763f2e5..6ac09dfa46707 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -20,27 +20,45 @@ package org.apache.spark.sql.hive import java.util.concurrent.atomic.AtomicLong import org.apache.spark.Logging -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Attribute, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} import org.apache.spark.sql.catalyst.util.sequenceOption +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation +import org.apache.spark.sql.{DataFrame, SQLContext} class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging { + def this(df: DataFrame) = this(df.queryExecution.analyzed, df.sqlContext) + def toSQL: Option[String] = { val canonicalizedPlan = Canonicalizer.execute(logicalPlan) + val maybeSQL = toSQL(canonicalizedPlan) + + if (maybeSQL.isDefined) { + logDebug( + s"""Built SQL query string successfully from given logical plan: + | + |# Original logical plan: + |${logicalPlan.treeString} + |# Canonicalized logical plan: + |${canonicalizedPlan.treeString} + |# Built SQL query string: + |${maybeSQL.get} + """.stripMargin) + } else { + logDebug( + s"""Failed to build SQL query string from given logical plan: + | + |# Original logical plan: + |${logicalPlan.treeString} + |# Canonicalized logical plan: + |${canonicalizedPlan.treeString} + """.stripMargin) + } - logDebug( - s"""Building SQL query string from given logical plan: - | - |# Original logical plan: - |${logicalPlan.treeString} - |# Canonicalized logical plan: - |${canonicalizedPlan.treeString} - """.stripMargin) - - toSQL(canonicalizedPlan) + maybeSQL } private def toSQL(node: LogicalPlan): Option[String] = node match { @@ -90,6 +108,10 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi rightSQL <- toSQL(right) } yield s"$leftSQL UNION ALL $rightSQL" + // ParquetRelation converted from Hive metastore table + case Subquery(alias, LogicalRelation(r: ParquetRelation, _)) => + Some(s"`$alias`") + case plan @ Subquery(alias, child) => toSQL(child).map(childSQL => s"($childSQL) AS $alias") @@ -176,5 +198,5 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi object SQLBuilder { private val nextSubqueryId = new AtomicLong(0) - private def newSubqueryName: String = s"subquery_${nextSubqueryId.getAndIncrement()}__" + private def newSubqueryName: String = s"gen_subquery_${nextSubqueryId.getAndIncrement()}" } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLGenerationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala similarity index 92% rename from 
sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLGenerationSuite.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala index 324ce3ec9cdf3..4ab913f5a7d69 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLGenerationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.hive +import java.sql.Timestamp + import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{If, Literal} -class ExpressionSQLGenerationSuite extends SQLGenerationTest { - +class ExpressionSQLBuilderSuite extends SQLBuilderTest { test("literal") { checkSQL(Literal("foo"), "\"foo\"") checkSQL(Literal("\"foo\""), "\"\\\"foo\\\"\"") @@ -31,6 +32,9 @@ class ExpressionSQLGenerationSuite extends SQLGenerationTest { checkSQL(Literal(8: Long), "CAST(8 AS BIGINT)") checkSQL(Literal(1.5F), "CAST(1.5 AS FLOAT)") checkSQL(Literal(2.5D), "2.5") + checkSQL( + Literal(Timestamp.valueOf("2016-01-01 00:00:00")), + "CAST('2016-01-01 00:00:00.0' AS TIMESTAMP)") // TODO tests for decimals } @@ -68,5 +72,4 @@ class ExpressionSQLGenerationSuite extends SQLGenerationTest { checkSQL(-'a.int, "(-`a`)") checkSQL(-('a.int + 'b.int), "(-(`a` + `b`))") } - } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQLSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQLSQLBuilderSuite.scala new file mode 100644 index 0000000000000..022836631e365 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQLSQLBuilderSuite.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import java.sql.Timestamp + +import org.apache.spark.sql.catalyst.expressions.{Literal, Alias} +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.functions._ + +class HiveQLSQLBuilderSuite extends SQLBuilderTest with SQLTestUtils { + import testImplicits._ + + protected override def beforeAll(): Unit = { + sqlContext.range(10).write.saveAsTable("t0") + + sqlContext + .range(10) + .select('id as 'key, concat(lit("val_"), 'id) as 'value) + .write + .saveAsTable("t1") + } + + override protected def afterAll(): Unit = { + sql("DROP TABLE IF EXISTS t0") + } + + private def checkHiveQl(hiveQl: String): Unit = { + val df = sql(hiveQl) + val convertedSQL = new SQLBuilder(df).toSQL + + if (convertedSQL.isEmpty) { + fail( + s"""Cannot convert the following HiveQL query plan back to SQL query string: + | + |Original HiveQL query string: + |$hiveQl + | + |Resolved query plan: + |${df.queryExecution.analyzed.treeString} + """.stripMargin) + } + + checkAnswer(sql(convertedSQL.get), df) + } + + test("in") { + checkHiveQl("SELECT id FROM t0 WHERE id IN (1, 2, 3)") + } + + test("aggregate function in having clause") { + checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key HAVING MAX(key) > 0") + } + + test("aggregate function in order by clause") { + checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY MAX(key)") + } + + // TODO Fix name collision introduced by ResolveAggregateFunction analysis rule + // When there are multiple aggregate functions in ORDER BY clause, all of them are extracted into + // Aggregate operator and aliased to the same name "aggOrder". This is OK for normal query + // execution since these aliases have different expression ID. But this introduces name collision + // when converting resolved plans back to SQL query strings as expression IDs are stripped. + ignore("aggregate function in order by clause with multiple order keys") { + checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY key, MAX(key)") + } + + test("type widening in union") { + checkHiveQl("SELECT id FROM t0 UNION ALL SELECT CAST(id AS INT) AS id FROM t0") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanSQLGenerationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanSQLBuilderSuite.scala similarity index 96% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanSQLGenerationSuite.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanSQLBuilderSuite.scala index 180955f1d7f94..8d6eff5d286d8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanSQLGenerationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanSQLBuilderSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.test.SQLTestUtils // All test cases in this test suite are ignored for now because currently `SQLBuilder` only handles // resolved logical plans parsed directly from HiveQL query strings. 
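// Usage sketch (illustration only, not part of this patch): round-tripping a query outside
// checkHiveQl in the same way, via the DataFrame constructor added to SQLBuilder. Assumes the
// `t1(key, value)` table created in beforeAll() above and the QueryTest helpers available here.
import org.apache.spark.sql.hive.SQLBuilder

val df = sql("SELECT COUNT(value) FROM t1 GROUP BY key HAVING MAX(key) > 0")
new SQLBuilder(df).toSQL match {
  case Some(generated) => checkAnswer(sql(generated), df)   // both forms must return the same rows
  case None => fail(s"SQL generation failed for:\n${df.queryExecution.analyzed.treeString}")
}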
-class LogicalPlanSQLGenerationSuite extends SQLGenerationTest with SQLTestUtils { +class LogicalPlanSQLBuilderSuite extends SQLBuilderTest with SQLTestUtils { import hiveContext.implicits._ protected override def beforeAll(): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLGenerationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala similarity index 97% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLGenerationTest.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala index 2f4220d40beda..a6fa4e76bf910 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLGenerationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.{DataFrame, QueryTest} -abstract class SQLGenerationTest extends QueryTest with TestHiveSingleton { +abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton { protected def checkSQL(e: Expression, expectedSQL: String): Unit = { val maybeSQL = e.sql From bec750bbc994f1adb686ee249f2bf56477e2323a Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 6 Jan 2016 02:34:07 +0800 Subject: [PATCH 07/15] Fixes test failures --- .../hive/execution/HiveComparisonTest.scala | 84 +++++++++++-------- 1 file changed, 47 insertions(+), 37 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index d1ee3a80d4f17..57358a07840e2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -136,12 +136,19 @@ abstract class HiveComparisonTest private var numTotalQueries: Int = 0 override protected def afterAll(): Unit = { - logInfo( + logInfo({ + val percentage = if (numTotalQueries > 0) { + numConvertibleQueries.toDouble / numTotalQueries * 100 + } else { + 0D + } + s"""SQLBuiler statistics: |- Total query number: $numTotalQueries |- Number of convertible queries: $numConvertibleQueries - |- Percentage of convertible queries: ${numConvertibleQueries.toDouble / numTotalQueries} - """.stripMargin) + |- Percentage of convertible queries: $percentage% + """.stripMargin + }) super.afterAll() } @@ -388,46 +395,49 @@ abstract class HiveComparisonTest // Run w/ catalyst val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) => - val query = { - val originalQuery = new TestHive.QueryExecution(queryString) - val containsCommands = originalQuery.analyzed.collectFirst { - case _: Command => () - case _: LogicalInsertIntoHiveTable => () - }.nonEmpty - - if (containsCommands) { - originalQuery - } else { - numTotalQueries += 1 - new SQLBuilder(originalQuery.analyzed, TestHive).toSQL.map { sql => - numConvertibleQueries += 1 - logInfo( - s""" - |### Running SQL generation round-trip test {{{ - |${originalQuery.analyzed.treeString} - |Original SQL: - |$queryString - | + var query: TestHive.QueryExecution = null + try { + query = { + val originalQuery = new TestHive.QueryExecution(queryString) + val containsCommands = originalQuery.analyzed.collectFirst { + case _: Command => () + case _: LogicalInsertIntoHiveTable => () + }.nonEmpty + + if (containsCommands) { + originalQuery + } else { + 
numTotalQueries += 1 + new SQLBuilder(originalQuery.analyzed, TestHive).toSQL.map { sql => + numConvertibleQueries += 1 + logInfo( + s""" + |### Running SQL generation round-trip test {{{ + |${originalQuery.analyzed.treeString} + |Original SQL: + |$queryString + | |Generated SQL: - |$sql - |}}} + |$sql + |}}} """.stripMargin.trim) - new TestHive.QueryExecution(sql) - }.getOrElse { - logInfo( - s""" - |### Cannot convert the following logical plan back to SQL {{{ - |${originalQuery.analyzed.treeString} - |Original SQL: - |$queryString - |}}} + new TestHive.QueryExecution(sql) + }.getOrElse { + logInfo( + s""" + |### Cannot convert the following logical plan back to SQL {{{ + |${originalQuery.analyzed.treeString} + |Original SQL: + |$queryString + |}}} """.stripMargin.trim) - originalQuery + originalQuery + } } } - } - try { (query, prepareAnswer(query, query.stringResult())) } catch { + (query, prepareAnswer(query, query.stringResult())) + } catch { case e: Throwable => val errorMessage = s""" From 2b8d8efe8daca4d9ab0d4fdf089ad0f9c6b98862 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 6 Jan 2016 02:39:05 +0800 Subject: [PATCH 08/15] Fixes FilterPushdownSuite --- .../spark/sql/catalyst/optimizer/FilterPushdownSuite.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index b998636909a7d..f9f3bd55aa578 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -75,8 +75,7 @@ class FilterPushdownSuite extends PlanTest { val correctAnswer = testRelation .select('a) - .groupBy('a)('a) - .select('a).analyze + .groupBy('a)('a).analyze comparePlans(optimized, correctAnswer) } @@ -91,8 +90,7 @@ class FilterPushdownSuite extends PlanTest { val correctAnswer = testRelation .select('a) - .groupBy('a)('a as 'c) - .select('c).analyze + .groupBy('a)('a as 'c).analyze comparePlans(optimized, correctAnswer) } From 0f57a37c342ad1278f016ece3b021fc224d29025 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 7 Jan 2016 00:01:02 +0800 Subject: [PATCH 09/15] Bug fixes and better SQL generation coverage --- .../spark/sql/catalyst/expressions/Cast.scala | 6 +- .../sql/catalyst/expressions/Expression.scala | 2 +- .../expressions/aggregate/Average.scala | 2 - .../expressions/aggregate/Count.scala | 4 - .../catalyst/expressions/aggregate/Max.scala | 2 - .../catalyst/expressions/aggregate/Sum.scala | 2 - .../expressions/aggregate/interfaces.scala | 9 +- .../expressions/conditionalExpressions.scala | 42 ++++++++- .../expressions/datetimeExpressions.scala | 4 + .../sql/catalyst/expressions/literals.scala | 7 +- .../expressions/randomExpressions.scala | 2 + .../expressions/regexpExpressions.scala | 5 + .../expressions/stringExpressions.scala | 90 ++++++++++++++++++ .../org/apache/spark/sql/hive/HiveQl.scala | 3 +- .../apache/spark/sql/hive/SQLBuilder.scala | 94 +++++++++++++------ .../org/apache/spark/sql/hive/hiveUDFs.scala | 61 ++++++++---- .../sql/hive/HiveQLSQLBuilderSuite.scala | 50 +++++++++- 17 files changed, 312 insertions(+), 73 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 
64ce03f2e128f..89203ad28cdda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -933,7 +933,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { } override def sql: Option[String] = { - child.sql.map(childSQL => s"CAST($childSQL AS ${dataType.sql})") + if (foldable) { + Literal.create(eval(), dataType).sql + } else { + child.sql.map(childSQL => s"CAST($childSQL AS ${dataType.sql})") + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 27c6b523d6c4b..3bb23fda9c4cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, TypeCheckResult, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sequenceOption diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 45fdc4f62c3d8..94ac4bf09b90b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -82,6 +82,4 @@ case class Average(child: Expression) extends DeclarativeAggregate { case _ => Cast(sum, resultType) / Cast(count, resultType) } - - override def argumentsSQL: Option[String] = child.sql } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index b953a4be3d320..2d56e3111dacf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -60,10 +60,6 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { override lazy val evaluateExpression = count override def defaultResult: Option[Literal] = Option(Literal(0L)) - - override def argumentsSQL: Option[String] = { - sequenceOption(children.map(_.sql)).map(_.mkString(", ")) - } } object Count { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index ab9032a20fbfa..906003188d4ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -56,6 +56,4 @@ case class Max(child: Expression) extends DeclarativeAggregate { } override lazy val evaluateExpression: AttributeReference = max - - override def argumentsSQL: Option[String] = child.sql } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index a7a53e0aa215d..08a67ea3df51d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -77,6 +77,4 @@ case class Sum(child: Expression) extends DeclarativeAggregate { } override lazy val evaluateExpression: Expression = sum - - override def argumentsSQL: Option[String] = child.sql } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index e0e5e57b359cc..42204b911b151 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.types._ /** The mode of an [[AggregateFunction]]. */ @@ -106,7 +107,7 @@ private[sql] case class AggregateExpression( } /** - * AggregateFunction2 is the superclass of two aggregation function interfaces: + * AggregateFunction is the superclass of two aggregation function interfaces: * * - [[ImperativeAggregate]] is for aggregation functions that are specified in terms of * initialize(), update(), and merge() functions that operate on Row-based aggregation buffers. @@ -175,7 +176,9 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu override def sql: Option[String] = argumentsSQL.map(sql => s"${prettyName.toUpperCase}($sql)") - def argumentsSQL: Option[String] = None + def argumentsSQL: Option[String] = for { + childrenSQL <- sequenceOption(children.map(_.sql)) + } yield childrenSQL.mkString(", ") } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index c33608362c82f..37857de9b578f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.util.{sequenceOption, TypeUtils} import org.apache.spark.sql.types._ @@ -116,7 +116,7 @@ trait CaseWhenLike extends Expression { override def nullable: Boolean = { // If no value is nullable and no elseValue is provided, the whole statement defaults to null. 
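// Illustration only, not part of this patch: with the default `argumentsSQL` above, an aggregate
// function renders as NAME(arg1, arg2, ...) once all of its children can produce SQL. The
// attribute below is a hypothetical example; rendering `'a.int` as `a` follows the expectations
// in ExpressionSQLBuilderSuite.
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Max

Max('a.int).sql   // expected: Some("MAX(`a`)")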
- thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) + thenList.exists(_.nullable) || elseValue.map(_.nullable).getOrElse(true) } } @@ -212,6 +212,25 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { case Seq(elseValue) => s" ELSE $elseValue" }.mkString } + + override def sql: Option[String] = { + sequenceOption(branches.map(_.sql)).map { + case branchesSQL => + val (cases, maybeElse) = if (branches.length % 2 == 0) { + (branchesSQL, None) + } else { + (branchesSQL.init, Some(branchesSQL.last)) + } + + val head = s"CASE " + val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END" + val body = cases.grouped(2).map { + case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr" + }.mkString(" ") + + head + body + tail + } + } } // scalastyle:off @@ -316,6 +335,25 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW case Seq(elseValue) => s" ELSE $elseValue" }.mkString } + + override def sql: Option[String] = for { + keySQL <- key.sql + branchesSQL <- sequenceOption(branches.map(_.sql)) + } yield { + val (cases, maybeElse) = if (branches.length % 2 == 0) { + (branchesSQL, None) + } else { + (branchesSQL.init, Some(branchesSQL.last)) + } + + val head = s"CASE $keySQL " + val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END" + val body = cases.grouped(2).map { + case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr" + }.mkString(" ") + + head + body + tail + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 3d65946a1bc65..0ce62f84a9aa0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -309,6 +309,8 @@ case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends Unix def this(time: Expression) = { this(time, Literal("yyyy-MM-dd HH:mm:ss")) } + + override def prettyName: String = "to_unix_timestamp" } /** @@ -332,6 +334,8 @@ case class UnixTimestamp(timeExp: Expression, format: Expression) extends UnixTi def this() = { this(CurrentTimestamp()) } + + override def prettyName: String = "unix_timestamp" } abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 544cdc62c919f..ab999674567c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -216,8 +216,11 @@ case class Literal protected (value: Any, dataType: DataType) } override def sql: Option[String] = Option((value, dataType) match { + case (_, NullType | _: ArrayType | _: MapType | _: StructType) if value == null => + "NULL" + case _ if value == null => - if (dataType == NullType) "NULL" else s"CAST(NULL AS ${dataType.sql})" + s"CAST(NULL AS ${dataType.sql})" case (v: UTF8String, StringType) => // Escapes all backslashes and double quotes. 
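// Illustration only, not part of this patch: expected `Literal.sql` output for null values under
// the two cases above. The BIGINT rendering is inferred from the CAST(8 AS BIGINT) expectation
// in ExpressionSQLBuilderSuite, not from running this code.
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{LongType, NullType}

Literal.create(null, NullType).sql   // Some("NULL"): untyped (and complex-typed) nulls stay bare
Literal.create(null, LongType).sql   // Some("CAST(NULL AS BIGINT)"): typed nulls keep their type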
@@ -242,7 +245,7 @@ case class Literal protected (value: Any, dataType: DataType) s"DATE '${DateTimeUtils.toJavaDate(v)}'" case (v: Long, TimestampType) => - s"CAST('${DateTimeUtils.toJavaTimestamp(v)}' AS TIMESTAMP)" + s"TIMESTAMP('${DateTimeUtils.toJavaTimestamp(v)}')" case _ => value.toString }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 8bde8cb9fe876..5d2d2b07b3615 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -49,6 +49,8 @@ abstract class RDG extends LeafExpression with Nondeterministic { override def nullable: Boolean = false override def dataType: DataType = DoubleType + + override def sql: Option[String] = Some(s"${prettyName.toUpperCase}($seed)") } /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index adef6050c3565..e6e043047da6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -59,6 +59,11 @@ trait StringRegexExpression extends ImplicitCastInputTypes { matches(regex, input1.asInstanceOf[UTF8String].toString) } } + + override def sql: Option[String] = for { + leftSQL <- left.sql + rightSQL <- right.sql + } yield s"$leftSQL ${prettyName.toUpperCase} $rightSQL" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 50c8b9d59847e..4bb65efcf1b4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -23,6 +23,7 @@ import java.util.{HashMap, Locale, Map => JMap} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -61,6 +62,10 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas } """ } + + override def sql: Option[String] = for { + childrenSQL <- sequenceOption(children.map(_.sql)) + } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" } @@ -153,6 +158,10 @@ case class ConcatWs(children: Seq[Expression]) """ } } + + override def sql: Option[String] = for { + childrenSQL <- sequenceOption(children.map(_.sql)) + } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" } trait String2StringExpression extends ImplicitCastInputTypes { @@ -165,6 +174,10 @@ trait String2StringExpression extends ImplicitCastInputTypes { protected override def nullSafeEval(input: Any): Any = convert(input.asInstanceOf[UTF8String]) + + override def sql: Option[String] = for { + childSQL <- child.sql + } yield s"${prettyName.toUpperCase}($childSQL)" } /** @@ -221,6 +234,10 @@ 
case class Contains(left: Expression, right: Expression) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") } + + override def sql: Option[String] = for { + childrenSQL <- sequenceOption(children.map(_.sql)) + } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" } /** @@ -232,6 +249,10 @@ case class StartsWith(left: Expression, right: Expression) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") } + + override def sql: Option[String] = for { + childrenSQL <- sequenceOption(children.map(_.sql)) + } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" } /** @@ -243,6 +264,10 @@ case class EndsWith(left: Expression, right: Expression) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") } + + override def sql: Option[String] = for { + childrenSQL <- sequenceOption(children.map(_.sql)) + } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" } object StringTranslate { @@ -318,6 +343,10 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType) override def children: Seq[Expression] = srcExpr :: matchingExpr :: replaceExpr :: Nil override def prettyName: String = "translate" + + override def sql: Option[String] = for { + childrenSQL <- sequenceOption(children.map(_.sql)) + } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" } /** @@ -340,6 +369,12 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi } override def dataType: DataType = IntegerType + + override def prettyName: String = "find_in_set" + + override def sql: Option[String] = for { + childrenSQL <- sequenceOption(children.map(_.sql)) + } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" } /** @@ -412,6 +447,10 @@ case class StringInstr(str: Expression, substr: Expression) defineCodeGen(ctx, ev, (l, r) => s"($l).indexOf($r, 0) + 1") } + + override def sql: Option[String] = for { + childrenSQL <- sequenceOption(children.map(_.sql)) + } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" } /** @@ -613,6 +652,10 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC } override def prettyName: String = "format_string" + + override def sql: Option[String] = for { + childrenSQL <- sequenceOption(children.map(_.sql)) + } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" } /** @@ -630,6 +673,10 @@ case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastI override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, str => s"$str.toTitleCase()") } + + override def sql: Option[String] = for { + childrenSQL <- sequenceOption(children.map(_.sql)) + } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" } /** @@ -652,6 +699,10 @@ case class StringRepeat(str: Expression, times: Expression) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (l, r) => s"($l).repeat($r)") } + + override def sql: Option[String] = for { + childrenSQL <- sequenceOption(children.map(_.sql)) + } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" } /** @@ -687,6 +738,10 @@ case class 
StringSpace(child: Expression) } override def prettyName: String = "space" + + override def sql: Option[String] = for { + childSQL <- child.sql + } yield s"${prettyName.toUpperCase}($childSQL)" } /** @@ -745,6 +800,10 @@ case class Length(child: Expression) extends UnaryExpression with ExpectsInputTy case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") } } + + override def sql: Option[String] = for { + childSQL <- child.sql + } yield s"${prettyName.toUpperCase}($childSQL)" } /** @@ -763,6 +822,10 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres nullSafeCodeGen(ctx, ev, (left, right) => s"${ev.value} = $left.levenshteinDistance($right);") } + + override def sql: Option[String] = for { + childrenSQL <- sequenceOption(children.map(_.sql)) + } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" } /** @@ -779,6 +842,10 @@ case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputT override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"$c.soundex()") } + + override def sql: Option[String] = for { + childSQL <- child.sql + } yield s"${prettyName.toUpperCase}($childSQL)" } /** @@ -810,6 +877,10 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp } """}) } + + override def sql: Option[String] = for { + childSQL <- child.sql + } yield s"${prettyName.toUpperCase}($childSQL)" } /** @@ -833,6 +904,9 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn """}) } + override def sql: Option[String] = for { + childSQL <- child.sql + } yield s"${prettyName.toUpperCase}($childSQL)" } /** @@ -852,6 +926,10 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast ${ev.value} = org.apache.commons.codec.binary.Base64.decodeBase64($child.toString()); """}) } + + override def sql: Option[String] = for { + childSQL <- child.sql + } yield s"${prettyName.toUpperCase}($childSQL)" } /** @@ -882,6 +960,10 @@ case class Decode(bin: Expression, charset: Expression) } """) } + + override def sql: Option[String] = for { + childrenSQL <- sequenceOption(children.map(_.sql)) + } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" } /** @@ -911,6 +993,10 @@ case class Encode(value: Expression, charset: Expression) org.apache.spark.unsafe.Platform.throwException(e); }""") } + + override def sql: Option[String] = for { + childrenSQL <- sequenceOption(children.map(_.sql)) + } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" } /** @@ -1026,4 +1112,8 @@ case class FormatNumber(x: Expression, d: Expression) } override def prettyName: String = "format_number" + + override def sql: Option[String] = for { + childrenSQL <- sequenceOption(children.map(_.sql)) + } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index bf3fe12d5c5d2..5b13dbe47370e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -668,7 +668,8 @@ private[hive] object HiveQl extends SparkQl with Logging { Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse( sys.error(s"Couldn't find function $functionName")) val functionClassName = functionInfo.getFunctionClass.getName - HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), 
children.map(nodeToExpr)) + HiveGenericUDTF( + functionName, new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr)) case other => super.nodeToGenerator(node) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 6ac09dfa46707..8460229cac68f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive import java.util.concurrent.atomic.AtomicLong import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} @@ -61,38 +61,59 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi maybeSQL } - private def toSQL(node: LogicalPlan): Option[String] = node match { - case plan @ Project(list, child) => - for { - listSQL <- sequenceOption(list.map(_.sql)) - childSQL <- toSQL(child) - from = child match { - case OneRowRelation => "" - case _ => " FROM " - } - } yield s"SELECT ${listSQL.mkString(", ")}$from$childSQL" + private def projectToSQL( + projectList: Seq[NamedExpression], + child: LogicalPlan, + isDistinct: Boolean): Option[String] = { + for { + listSQL <- sequenceOption(projectList.map(_.sql)) + childSQL <- toSQL(child) + from = child match { + case OneRowRelation => "" + case _ => " FROM " + } + distinct = if (isDistinct) " DISTINCT" else "" + } yield s"SELECT$distinct ${listSQL.mkString(", ")}$from$childSQL" + } - case plan @ Aggregate(groupingExpressions, aggregateExpressions, child) => - for { - aggregateSQL <- sequenceOption(aggregateExpressions.map(_.sql)) - groupingSQL <- sequenceOption(groupingExpressions.map(_.sql)) - maybeGroupBy = if (groupingSQL.isEmpty) "" else " GROUP BY " - maybeFrom = child match { - case OneRowRelation => "" - case _ => " FROM " - } - childSQL <- toSQL(child).map(maybeFrom + _) - } yield { - s"SELECT ${aggregateSQL.mkString(", ")}$childSQL$maybeGroupBy${groupingSQL.mkString(", ")}" + private def aggregateToSQL( + groupingExprs: Seq[Expression], + aggExprs: Seq[Expression], + child: LogicalPlan): Option[String] = { + for { + aggSQL <- sequenceOption(aggExprs.map(_.sql)) + groupingSQL <- sequenceOption(groupingExprs.map(_.sql)) + maybeGroupBy = if (groupingSQL.isEmpty) "" else " GROUP BY " + maybeFrom = child match { + case OneRowRelation => "" + case _ => " FROM " } + childSQL <- toSQL(child).map(maybeFrom + _) + } yield { + s"SELECT ${aggSQL.mkString(", ")}$childSQL$maybeGroupBy${groupingSQL.mkString(", ")}" + } + } - case plan @ Limit(limit, child) => + private def toSQL(node: LogicalPlan): Option[String] = node match { + case Distinct(Project(list, child)) => + projectToSQL(list, child, isDistinct = true) + + case Project(list, child) => + projectToSQL(list, child, isDistinct = false) + + case Aggregate(groupingExprs, aggExprs, Aggregate(_, _, Expand(_, _, child))) => + aggregateToSQL(groupingExprs, aggExprs, child) + + case Aggregate(groupingExprs, aggExprs, child) => + aggregateToSQL(groupingExprs, aggExprs, child) + + case Limit(limit, child) => for { limitSQL <- limit.sql childSQL <- toSQL(child) } yield s"$childSQL LIMIT 
$limitSQL" - case plan @ Filter(condition, child) => + case Filter(condition, child) => for { conditionSQL <- condition.sql childSQL <- toSQL(child) @@ -102,7 +123,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi } } yield s"$childSQL $whereOrHaving $conditionSQL" - case plan @ Union(left, right) => + case Union(left, right) => for { leftSQL <- toSQL(left) rightSQL <- toSQL(right) @@ -112,10 +133,10 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case Subquery(alias, LogicalRelation(r: ParquetRelation, _)) => Some(s"`$alias`") - case plan @ Subquery(alias, child) => + case Subquery(alias, child) => toSQL(child).map(childSQL => s"($childSQL) AS $alias") - case plan @ Join(left, right, joinType, condition) => + case Join(left, right, joinType, condition) => for { leftSQL <- toSQL(left) rightSQL <- toSQL(right) @@ -123,11 +144,18 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi conditionSQL = condition.flatMap(_.sql).map(" ON " + _).getOrElse("") } yield s"$leftSQL $joinTypeSQL JOIN $rightSQL$conditionSQL" - case plan @ MetastoreRelation(database, table, alias) => + case MetastoreRelation(database, table, alias) => val aliasSQL = alias.map(a => s" AS `$a`").getOrElse("") Some(s"`$database`.`$table`$aliasSQL") - case plan @ Sort(orders, global, child) => + case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _)) + if orders.map(_.child) == partitionExprs => + for { + partitionExprsSQL <- sequenceOption(partitionExprs.map(_.sql)) + childSQL <- toSQL(child) + } yield s"$childSQL CLUSTER BY ${partitionExprsSQL.mkString(", ")}" + + case Sort(orders, global, child) => for { childSQL <- toSQL(child) ordersSQL <- sequenceOption(orders.map { case SortOrder(e, dir) => @@ -136,6 +164,12 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi orderOrSort = if (global) "ORDER" else "SORT" } yield s"$childSQL $orderOrSort BY ${ordersSQL.mkString(", ")}" + case RepartitionByExpression(partitionExprs, child, _) => + for { + partitionExprsSQL <- sequenceOption(partitionExprs.map(_.sql)) + childSQL <- toSQL(child) + } yield s"$childSQL DISTRIBUTE BY ${partitionExprsSQL.mkString(", ")}" + case OneRowRelation => Some("") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index b1a6d0ab7df3c..7cb30f0e15680 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -17,30 +17,26 @@ package org.apache.spark.sql.hive -import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.util.Try import org.apache.hadoop.hive.ql.exec._ -import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} -import org.apache.hadoop.hive.ql.udf.generic._ -import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper -import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory} +import org.apache.hadoop.hive.ql.udf.generic._ +import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions +import 
org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory} import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.catalyst.util.sequenceOption +import org.apache.spark.sql.catalyst.{InternalRow, analysis} import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client.ClientWrapper import org.apache.spark.sql.types._ @@ -75,19 +71,19 @@ private[hive] class HiveFunctionRegistry( try { if (classOf[GenericUDFMacro].isAssignableFrom(functionInfo.getFunctionClass)) { HiveGenericUDF( - new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children) + name, new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children) } else if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveSimpleUDF(new HiveFunctionWrapper(functionClassName), children) + HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children) } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children) + HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children) } else if ( classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveUDAFFunction(new HiveFunctionWrapper(functionClassName), children) + HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children) } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { HiveUDAFFunction( - new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true) + name, new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true) } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - val udtf = HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children) + val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(functionClassName), children) udtf.elementTypes // Force it to check input data types. 
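A note on the constructor changes above: each Hive UDF wrapper now carries the name it was registered under, so the generated SQL can refer to the function by that name rather than by the implementing class name that `toString` reports. A rough, standalone sketch of the intended output shape, with a hypothetical function name and argument strings (none of these identifiers come from the patch):

    // Pure-Scala illustration of how the new `sql` methods are expected to combine the
    // registered name with the SQL of each argument; it succeeds only when every argument
    // has a SQL form, mirroring the sequenceOption-based implementations below.
    def hiveUdfSql(name: String, argsSQL: Seq[Option[String]]): Option[String] =
      if (argsSQL.forall(_.isDefined)) Some(s"$name(${argsSQL.flatten.mkString(", ")})")
      else None

    hiveUdfSql("my_upper", Seq(Some("`value`")))       // Some(my_upper(`value`))
    hiveUdfSql("my_concat", Seq(Some("`a`"), None))    // None: one argument is not convertible
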
udtf } else { @@ -137,7 +133,8 @@ private[hive] class HiveFunctionRegistry( } } -private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) +private[hive] case class HiveSimpleUDF( + name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with CodegenFallback with Logging { override def deterministic: Boolean = isUDFDeterministic @@ -191,6 +188,12 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre override def toString: String = { s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } + + override def sql: Option[String] = { + sequenceOption(children.map(_.sql)).map { argsSQL => + s"$name(${argsSQL.mkString(", ")})" + } + } } // Adapter from Catalyst ExpressionResult to Hive DeferredObject @@ -205,7 +208,8 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp override def get(): AnyRef = wrap(func(), oi, dataType) } -private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) +private[hive] case class HiveGenericUDF( + name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with CodegenFallback with Logging { override def nullable: Boolean = true @@ -257,6 +261,12 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr override def toString: String = { s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } + + override def sql: Option[String] = { + sequenceOption(children.map(_.sql)).map { argsSQL => + s"$name(${argsSQL.mkString(", ")})" + } + } } /** @@ -271,6 +281,7 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr * user defined aggregations, which have clean semantics even in a partitioned execution. */ private[hive] case class HiveGenericUDTF( + name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Generator with HiveInspectors with CodegenFallback { @@ -336,6 +347,12 @@ private[hive] case class HiveGenericUDTF( override def toString: String = { s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } + + override def sql: Option[String] = { + sequenceOption(children.map(_.sql)).map { argsSQL => + s"$name(${argsSQL.mkString(", ")})" + } + } } /** @@ -343,6 +360,7 @@ private[hive] case class HiveGenericUDTF( * performance a lot. 
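Nearly every `sql` override in this series funnels its children through the `sequenceOption` helper imported above from `org.apache.spark.sql.catalyst.util`; its body is not shown in this excerpt, but judging from the call sites it behaves like the usual sequence over `Option`. A plausible standalone sketch of that assumed behaviour:

    // Assumed behaviour of sequenceOption: Some(all values) when every element is defined,
    // None as soon as any child expression has no SQL representation.
    def sequenceOption[A](xs: Seq[Option[A]]): Option[Seq[A]] =
      xs.foldRight(Option(Seq.empty[A])) { (x, acc) =>
        for { v <- x; vs <- acc } yield v +: vs
      }

    assert(sequenceOption(Seq(Some(1), Some(2))) == Some(Seq(1, 2)))
    assert(sequenceOption(Seq(Some(1), None)) == None)

This is why a single non-convertible argument makes the whole expression, and ultimately the whole plan, fall back to `None`.
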
*/ private[hive] case class HiveUDAFFunction( + name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression], isUDAFBridgeRequired: Boolean = false, @@ -427,5 +445,10 @@ private[hive] case class HiveUDAFFunction( override def supportsPartial: Boolean = false override val dataType: DataType = inspectorToDataType(returnInspector) -} + override def sql: Option[String] = { + sequenceOption(children.map(_.sql)).map { argsSQL => + s"$name(${argsSQL.mkString(", ")})" + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQLSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQLSQLBuilderSuite.scala index 022836631e365..aa24e64a16ec2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQLSQLBuilderSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQLSQLBuilderSuite.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql.hive -import java.sql.Timestamp - -import org.apache.spark.sql.catalyst.expressions.{Literal, Alias} -import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.functions._ @@ -35,10 +31,14 @@ class HiveQLSQLBuilderSuite extends SQLBuilderTest with SQLTestUtils { .select('id as 'key, concat(lit("val_"), 'id) as 'value) .write .saveAsTable("t1") + + sqlContext.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2") } override protected def afterAll(): Unit = { sql("DROP TABLE IF EXISTS t0") + sql("DROP TABLE IF EXISTS t1") + sql("DROP TABLE IF EXISTS t2") } private def checkHiveQl(hiveQl: String): Unit = { @@ -84,4 +84,46 @@ class HiveQLSQLBuilderSuite extends SQLBuilderTest with SQLTestUtils { test("type widening in union") { checkHiveQl("SELECT id FROM t0 UNION ALL SELECT CAST(id AS INT) AS id FROM t0") } + + test("case") { + checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 WHEN id % 2 = 0 THEN 1 END FROM t0") + } + + test("case with else") { + checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 ELSE 1 END FROM t0") + } + + test("case with key") { + checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' END FROM t0") + } + + test("case with key and else") { + checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' ELSE 'baz' END FROM t0") + } + + test("select distinct without aggregate functions") { + checkHiveQl("SELECT DISTINCT id FROM t0") + } + + test("cluster by") { + checkHiveQl("SELECT id FROM t0 CLUSTER BY id") + } + + test("distribute by") { + checkHiveQl("SELECT id FROM t0 DISTRIBUTE BY id") + } + + test("distribute by with sort by") { + checkHiveQl("SELECT id FROM t0 DISTRIBUTE BY id SORT BY id") + } + + test("distinct aggregation") { + checkHiveQl("SELECT COUNT(DISTINCT id) FROM t0") + } + + // TODO Enable this + // Query plans transformed by DistinctAggregationRewriter are not recognized yet + ignore("distinct and non-distinct aggregation") { + checkHiveQl("SELECT a, COUNT(DISTINCT b), COUNT(DISTINCT c), SUM(d) FROM t2 GROUP BY a") + } } From c118a354a43f0f9eef83ed92b01f754936c9dfa6 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 7 Jan 2016 11:22:03 +0800 Subject: [PATCH 10/15] Fixes test failures --- .../org/apache/spark/sql/catalyst/expressions/Cast.scala | 9 +++------ .../scala/org/apache/spark/sql/hive/SQLBuilder.scala | 3 --- .../spark/sql/hive/ExpressionSQLBuilderSuite.scala | 2 +- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 89203ad28cdda..d8c906e0ab444 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -932,12 +932,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { """ } - override def sql: Option[String] = { - if (foldable) { - Literal.create(eval(), dataType).sql - } else { - child.sql.map(childSQL => s"CAST($childSQL AS ${dataType.sql})") - } + override def sql: Option[String] = dataType match { + case _: ArrayType | _: MapType | _: StructType => None + case _ => child.sql.map(childSQL => s"CAST($childSQL AS ${dataType.sql})") } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 8460229cac68f..2cd28533ded20 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -101,9 +101,6 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case Project(list, child) => projectToSQL(list, child, isDistinct = false) - case Aggregate(groupingExprs, aggExprs, Aggregate(_, _, Expand(_, _, child))) => - aggregateToSQL(groupingExprs, aggExprs, child) - case Aggregate(groupingExprs, aggExprs, child) => aggregateToSQL(groupingExprs, aggExprs, child) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala index 4ab913f5a7d69..3a6eb57add4e3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala @@ -34,7 +34,7 @@ class ExpressionSQLBuilderSuite extends SQLBuilderTest { checkSQL(Literal(2.5D), "2.5") checkSQL( Literal(Timestamp.valueOf("2016-01-01 00:00:00")), - "CAST('2016-01-01 00:00:00.0' AS TIMESTAMP)") + "TIMESTAMP('2016-01-01 00:00:00.0')") // TODO tests for decimals } From 2b0745411f0209af0580aeccda6b6751e3a01c9e Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 7 Jan 2016 14:27:44 +0800 Subject: [PATCH 11/15] Addresses PR comments --- .../sql/catalyst/analysis/Analyzer.scala | 20 +----- .../expressions/aggregate/Count.scala | 1 - .../expressions/aggregate/interfaces.scala | 20 +++--- .../sql/catalyst/expressions/arithmetic.scala | 4 +- .../expressions/randomExpressions.scala | 3 +- .../expressions/stringExpressions.scala | 64 +++++++++---------- .../sql/catalyst/optimizer/Optimizer.scala | 21 +++++- .../apache/spark/sql/hive/SQLBuilder.scala | 6 +- .../sql/hive/LogicalPlanSQLBuilderSuite.scala | 57 ----------------- ...uite.scala => LogicalPlanToSQLSuite.scala} | 2 +- 10 files changed, 71 insertions(+), 127 deletions(-) delete mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanSQLBuilderSuite.scala rename sql/hive/src/test/scala/org/apache/spark/sql/hive/{HiveQLSQLBuilderSuite.scala => LogicalPlanToSQLSuite.scala} (98%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e362b55d80cd1..8a33af8207350 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -86,8 +86,7 @@ class Analyzer( HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, - PullOutNondeterministic, - ComputeCurrentTime), + PullOutNondeterministic), Batch("UDF", Once, HandleNullInputsForUDF), Batch("Cleanup", fixedPoint, @@ -1229,23 +1228,6 @@ object CleanupAliases extends Rule[LogicalPlan] { } } -/** - * Computes the current date and time to make sure we return the same result in a single query. - */ -object ComputeCurrentTime extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = { - val dateExpr = CurrentDate() - val timeExpr = CurrentTimestamp() - val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType) - val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType) - - plan transformAllExpressions { - case CurrentDate() => currentDate - case CurrentTimestamp() => currentTime - } - } -} - /** * Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 2d56e3111dacf..663c69e799fbd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.types._ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 42204b911b151..d94ce35fa3878 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -96,13 +96,15 @@ private[sql] case class AggregateExpression( override def toString: String = s"($aggregateFunction,mode=$mode,isDistinct=$isDistinct)" - override def sql: Option[String] = if (isDistinct) { - aggregateFunction.argumentsSQL.map { argsSQL => - val name = aggregateFunction.prettyName.toUpperCase - s"$name(DISTINCT $argsSQL)" + override def sql: Option[String] = { + val name = aggregateFunction.prettyName + val argsSQL = sequenceOption(aggregateFunction.children.map(_.sql)) + + if (isDistinct) { + argsSQL.map(args => s"$name(DISTINCT ${args.mkString(", ")})") + } else { + argsSQL.map(args => s"$name(${args.mkString(", ")})") } - } else { - aggregateFunction.sql } } @@ -173,12 +175,6 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu def toAggregateExpression(isDistinct: Boolean): AggregateExpression = { AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct) } - - override def sql: Option[String] = argumentsSQL.map(sql => s"${prettyName.toUpperCase}($sql)") - - def argumentsSQL: Option[String] = for { - childrenSQL <- sequenceOption(children.map(_.sql)) - } yield childrenSQL.mkString(", ") } /** diff 
--git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index be551cab7d9b4..0f5e4c949cb25 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -96,7 +96,7 @@ case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes protected override def nullSafeEval(input: Any): Any = numeric.abs(input) - override def sql: Option[String] = child.sql.map(sql => s"(ABS($sql))") + override def sql: Option[String] = child.sql.map(sql => s"$prettyName($sql)") } abstract class BinaryArithmetic extends BinaryOperator { @@ -523,5 +523,5 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { override def sql: Option[String] = for { leftSQL <- left.sql rightSQL <- right.sql - } yield s"Pmod($leftSQL, $rightSQL)" + } yield s"$prettyName($leftSQL, $rightSQL)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 5d2d2b07b3615..f6cdd2bed2513 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -50,7 +50,8 @@ abstract class RDG extends LeafExpression with Nondeterministic { override def dataType: DataType = DoubleType - override def sql: Option[String] = Some(s"${prettyName.toUpperCase}($seed)") + // NOTE: Even if the user doesn't provide a seed, Spark SQL adds a default seed. + override def sql: Option[String] = Some(s"$prettyName($seed)") } /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). 
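One consequence of the reworked `AggregateExpression.sql` a little earlier in this patch is that distinct and non-distinct aggregates now share one code path keyed on `isDistinct`, and the function name follows `prettyName` (lower case) since the `toUpperCase` calls are being dropped throughout. A rough illustration of the intended strings, using a hypothetical backquoted column `id` as the only argument:

    // Illustration only: how the distinct flag is expected to shape the generated call.
    def aggregateSql(name: String, argsSQL: Seq[String], isDistinct: Boolean): String =
      if (isDistinct) s"$name(DISTINCT ${argsSQL.mkString(", ")})"
      else s"$name(${argsSQL.mkString(", ")})"

    aggregateSql("count", Seq("`id`"), isDistinct = true)   // count(DISTINCT `id`)
    aggregateSql("count", Seq("`id`"), isDistinct = false)  // count(`id`)

This is the shape exercised by the `SELECT COUNT(DISTINCT id) FROM t0` round-trip test added earlier in the series.
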
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 4bb65efcf1b4c..2dd26dd5d7e46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -65,7 +65,7 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas override def sql: Option[String] = for { childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" + } yield s"$prettyName(${childrenSQL.mkString(", ")})" } @@ -161,7 +161,7 @@ case class ConcatWs(children: Seq[Expression]) override def sql: Option[String] = for { childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" + } yield s"$prettyName(${childrenSQL.mkString(", ")})" } trait String2StringExpression extends ImplicitCastInputTypes { @@ -177,7 +177,7 @@ trait String2StringExpression extends ImplicitCastInputTypes { override def sql: Option[String] = for { childSQL <- child.sql - } yield s"${prettyName.toUpperCase}($childSQL)" + } yield s"$prettyName($childSQL)" } /** @@ -237,7 +237,7 @@ case class Contains(left: Expression, right: Expression) override def sql: Option[String] = for { childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" + } yield s"$prettyName(${childrenSQL.mkString(", ")})" } /** @@ -252,7 +252,7 @@ case class StartsWith(left: Expression, right: Expression) override def sql: Option[String] = for { childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" + } yield s"$prettyName(${childrenSQL.mkString(", ")})" } /** @@ -267,7 +267,7 @@ case class EndsWith(left: Expression, right: Expression) override def sql: Option[String] = for { childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" + } yield s"$prettyName(${childrenSQL.mkString(", ")})" } object StringTranslate { @@ -317,24 +317,24 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac val termDict = ctx.freshName("dict") val classNameDict = classOf[JMap[Character, Character]].getCanonicalName - ctx.addMutableState("UTF8String", termLastMatching, s"${termLastMatching} = null;") - ctx.addMutableState("UTF8String", termLastReplace, s"${termLastReplace} = null;") - ctx.addMutableState(classNameDict, termDict, s"${termDict} = null;") + ctx.addMutableState("UTF8String", termLastMatching, s"$termLastMatching = null;") + ctx.addMutableState("UTF8String", termLastReplace, s"$termLastReplace = null;") + ctx.addMutableState(classNameDict, termDict, s"$termDict = null;") nullSafeCodeGen(ctx, ev, (src, matching, replace) => { val check = if (matchingExpr.foldable && replaceExpr.foldable) { - s"${termDict} == null" + s"$termDict == null" } else { - s"!${matching}.equals(${termLastMatching}) || !${replace}.equals(${termLastReplace})" + s"!$matching.equals($termLastMatching) || !$replace.equals($termLastReplace)" } s"""if ($check) { // Not all of them is literal or matching or replace value changed - ${termLastMatching} = ${matching}.clone(); - ${termLastReplace} = ${replace}.clone(); - ${termDict} = 
org.apache.spark.sql.catalyst.expressions.StringTranslate - .buildDict(${termLastMatching}, ${termLastReplace}); + $termLastMatching = $matching.clone(); + $termLastReplace = $replace.clone(); + $termDict = org.apache.spark.sql.catalyst.expressions.StringTranslate + .buildDict($termLastMatching, $termLastReplace); } - ${ev.value} = ${src}.translate(${termDict}); + ${ev.value} = $src.translate($termDict); """ }) } @@ -346,7 +346,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac override def sql: Option[String] = for { childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" + } yield s"$prettyName(${childrenSQL.mkString(", ")})" } /** @@ -374,7 +374,7 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi override def sql: Option[String] = for { childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" + } yield s"$prettyName(${childrenSQL.mkString(", ")})" } /** @@ -450,7 +450,7 @@ case class StringInstr(str: Expression, substr: Expression) override def sql: Option[String] = for { childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" + } yield s"$prettyName(${childrenSQL.mkString(", ")})" } /** @@ -655,7 +655,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC override def sql: Option[String] = for { childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" + } yield s"$prettyName(${childrenSQL.mkString(", ")})" } /** @@ -676,7 +676,7 @@ case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastI override def sql: Option[String] = for { childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" + } yield s"$prettyName(${childrenSQL.mkString(", ")})" } /** @@ -702,7 +702,7 @@ case class StringRepeat(str: Expression, times: Expression) override def sql: Option[String] = for { childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" + } yield s"$prettyName(${childrenSQL.mkString(", ")})" } /** @@ -741,7 +741,7 @@ case class StringSpace(child: Expression) override def sql: Option[String] = for { childSQL <- child.sql - } yield s"${prettyName.toUpperCase}($childSQL)" + } yield s"$prettyName($childSQL)" } /** @@ -803,7 +803,7 @@ case class Length(child: Expression) extends UnaryExpression with ExpectsInputTy override def sql: Option[String] = for { childSQL <- child.sql - } yield s"${prettyName.toUpperCase}($childSQL)" + } yield s"$prettyName($childSQL)" } /** @@ -825,7 +825,7 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres override def sql: Option[String] = for { childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" + } yield s"$prettyName(${childrenSQL.mkString(", ")})" } /** @@ -845,7 +845,7 @@ case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputT override def sql: Option[String] = for { childSQL <- child.sql - } yield s"${prettyName.toUpperCase}($childSQL)" + } yield s"$prettyName($childSQL)" } /** @@ -880,7 +880,7 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp override def sql: Option[String] = for { childSQL <- child.sql - } yield 
s"${prettyName.toUpperCase}($childSQL)" + } yield s"$prettyName($childSQL)" } /** @@ -906,7 +906,7 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn override def sql: Option[String] = for { childSQL <- child.sql - } yield s"${prettyName.toUpperCase}($childSQL)" + } yield s"$prettyName($childSQL)" } /** @@ -929,7 +929,7 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast override def sql: Option[String] = for { childSQL <- child.sql - } yield s"${prettyName.toUpperCase}($childSQL)" + } yield s"$prettyName($childSQL)" } /** @@ -963,7 +963,7 @@ case class Decode(bin: Expression, charset: Expression) override def sql: Option[String] = for { childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" + } yield s"$prettyName(${childrenSQL.mkString(", ")})" } /** @@ -996,7 +996,7 @@ case class Encode(value: Expression, charset: Expression) override def sql: Option[String] = for { childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" + } yield s"$prettyName(${childrenSQL.mkString(", ")})" } /** @@ -1115,5 +1115,5 @@ case class FormatNumber(x: Expression, d: Expression) override def sql: Option[String] = for { childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"${prettyName.toUpperCase}(${childrenSQL.mkString(", ")})" + } yield s"$prettyName(${childrenSQL.mkString(", ")})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2a552c8a3aefd..7095078bb189b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -63,7 +63,9 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { RemoveDispensableExpressions, SimplifyFilters, SimplifyCasts, - SimplifyCaseConversionExpressions) :: + SimplifyCaseConversionExpressions, + // Nondeterministic + ComputeCurrentTime) :: Batch("Decimal Optimizations", FixedPoint(100), DecimalAggregates) :: Batch("LocalRelation", FixedPoint(100), @@ -1009,3 +1011,20 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] { a.copy(groupingExpressions = newGrouping) } } + +/** + * Computes the current date and time to make sure we return the same result in a single query. 
+ */ +object ComputeCurrentTime extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + val dateExpr = CurrentDate() + val timeExpr = CurrentTimestamp() + val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType) + val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType) + + plan transformAllExpressions { + case CurrentDate() => currentDate + case CurrentTimestamp() => currentTime + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 2cd28533ded20..f5a1ae5d0941b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -128,6 +128,10 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi // ParquetRelation converted from Hive metastore table case Subquery(alias, LogicalRelation(r: ParquetRelation, _)) => + // There seems to be a bug related to `ParquetConversions` analysis rule. The problem is + // that, the metastore database name and table name are not always propagated to converted + // `ParquetRelation` instances via data source options. Here we use subquery alias as a + // workaround. Some(s"`$alias`") case Subquery(alias, child) => @@ -175,7 +179,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi object Canonicalizer extends RuleExecutor[LogicalPlan] { override protected def batches: Seq[Batch] = Seq( - Batch("Normalizer", FixedPoint(100), + Batch("Canonicalizer", FixedPoint(100), // The `WidenSetOperationTypes` analysis rule may introduce extra `Project`s over // `Aggregate`s to perform type casting. This rule merges these `Project`s into // `Aggregate`s. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanSQLBuilderSuite.scala deleted file mode 100644 index 8d6eff5d286d8..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanSQLBuilderSuite.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestUtils - -// All test cases in this test suite are ignored for now because currently `SQLBuilder` only handles -// resolved logical plans parsed directly from HiveQL query strings. 
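Moving `ComputeCurrentTime` from the analyzer into the optimizer (above) presumably keeps the analyzed plan, which is what `SQLBuilder` consumes, faithful to the original query: `current_date`/`current_timestamp` survive SQL generation, while executed plans still evaluate them exactly once per query. A quick way to observe the rule's effect, assuming a `sqlContext` is in scope:

    // Both columns should be backed by the same literal once the optimizer applies this rule,
    // so t1 and t2 compare equal on every row even if the query takes a while to execute.
    val df = sqlContext.sql("SELECT current_timestamp() AS t1, current_timestamp() AS t2")
    df.queryExecution.optimizedPlan   // expected to carry a single timestamp literal in both positions
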
-class LogicalPlanSQLBuilderSuite extends SQLBuilderTest with SQLTestUtils { - import hiveContext.implicits._ - - protected override def beforeAll(): Unit = { - super.beforeAll() - - sqlContext.range(10).select('id alias "a").registerTempTable("t0") - sqlContext.range(10).select('id alias "b").registerTempTable("t1") - } - - protected override def afterAll(): Unit = { - sqlContext.dropTempTable("t0") - sqlContext.dropTempTable("t1") - - super.afterAll() - } - - ignore("single row project") { - checkSQL(OneRowRelation.select(lit(1)), "SELECT 1 AS `1`") - checkSQL(OneRowRelation.select(lit(1) as 'a), "SELECT 1 AS `a`") - } - - ignore("project with limit") { - checkSQL(OneRowRelation.select(lit(1)).limit(1), "SELECT 1 AS `1` LIMIT 1") - checkSQL(OneRowRelation.select(lit(1) as 'a).limit(1), "SELECT 1 AS `a` LIMIT 1") - } - - ignore("table lookup") { - checkSQL(sqlContext.table("t0"), "SELECT `t0`.`a` FROM `t0`") - checkSQL(sqlContext.table("t1").select('b alias "c"), "SELECT `t1`.`id` AS `c` FROM `t1`") - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQLSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala similarity index 98% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQLSQLBuilderSuite.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index aa24e64a16ec2..9753f64cd389d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQLSQLBuilderSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.functions._ -class HiveQLSQLBuilderSuite extends SQLBuilderTest with SQLTestUtils { +class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { import testImplicits._ protected override def beforeAll(): Unit = { From 11807bdaaedc836da0a75ff2dc6c5efde9d62e9e Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 7 Jan 2016 14:38:14 +0800 Subject: [PATCH 12/15] Adds ScalaDoc to newly introduced interfaces --- .../apache/spark/sql/catalyst/expressions/Expression.scala | 4 ++++ .../main/scala/org/apache/spark/sql/hive/SQLBuilder.scala | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 3bb23fda9c4cb..eb4864bcf2467 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -225,6 +225,10 @@ abstract class Expression extends TreeNode[Expression] { .replace("*/", "\\*\\/") .replace("\\u", "\\\\u") + /** + * Returns SQL representation of this expression if there is one (some expressions, e.g. + * `ScalaUDF`, don't have SQL representations). 
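At this point in the series `Expression.sql` still returns `Option[String]`, and the expression-level suite pins down a few literal encodings. Restating two of those expectations as plain assertions (values taken from the test changes above; patch 13 below switches the return type to `String`):

    import java.sql.Timestamp
    import org.apache.spark.sql.catalyst.expressions.Literal

    assert(Literal(2.5D).sql == Some("2.5"))
    assert(Literal(Timestamp.valueOf("2016-01-01 00:00:00")).sql ==
      Some("TIMESTAMP('2016-01-01 00:00:00.0')"))
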
+ */ def sql: Option[String] = None } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index f5a1ae5d0941b..42f64b652e382 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -29,6 +29,12 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.{DataFrame, SQLContext} +/** + * A builder class used to convert a resolved logical plan into a SQL query string. Note that this + * all resolved logical plan are convertible. They either don't have corresponding SQL + * representations (e.g. logical plans that operate on local Scala collections), or are simply not + * supported by this builder (yet). + */ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging { def this(df: DataFrame) = this(df.queryExecution.analyzed, df.sqlContext) From a304392602160ac5cf84b83eda60170842e85200 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 7 Jan 2016 21:58:37 +0800 Subject: [PATCH 13/15] Makes Expression.sql return String instead of Option[String] --- .../spark/sql/catalyst/expressions/Cast.scala | 9 +- .../sql/catalyst/expressions/Expression.scala | 24 ++--- .../catalyst/expressions/InputFileName.scala | 1 + .../MonotonicallyIncreasingID.scala | 4 + .../expressions/aggregate/interfaces.scala | 16 ++-- .../sql/catalyst/expressions/arithmetic.scala | 11 +-- .../expressions/complexTypeExtractors.scala | 2 +- .../expressions/conditionalExpressions.scala | 41 ++++----- .../expressions/datetimeExpressions.scala | 18 ++++ .../expressions/decimalExpressions.scala | 4 +- .../sql/catalyst/expressions/literals.scala | 4 +- .../expressions/mathExpressions.scala | 2 +- .../spark/sql/catalyst/expressions/misc.scala | 4 + .../expressions/namedExpressions.scala | 8 +- .../expressions/nullExpressions.scala | 6 +- .../sql/catalyst/expressions/predicates.scala | 39 ++++---- .../expressions/randomExpressions.scala | 2 +- .../expressions/regexpExpressions.scala | 5 +- .../expressions/stringExpressions.scala | 88 +------------------ .../apache/spark/sql/hive/SQLBuilder.scala | 61 ++++++------- .../org/apache/spark/sql/hive/hiveUDFs.scala | 25 ++---- .../sql/hive/LogicalPlanToSQLSuite.scala | 23 ++++- .../spark/sql/hive/SQLBuilderTest.scala | 13 +-- 23 files changed, 165 insertions(+), 245 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d8c906e0ab444..6f199cfc5d8cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -932,9 +932,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { """ } - override def sql: Option[String] = dataType match { - case _: ArrayType | _: MapType | _: StructType => None - case _ => child.sql.map(childSQL => s"CAST($childSQL AS ${dataType.sql})") + override def sql: String = dataType match { + // HiveQL doesn't allow casting to complex types. For logical plans translated from HiveQL, this + // type of casting can only be introduced by the analyzer, and can be omitted when converting + // back to SQL query string. 
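For orientation, here is how the `SQLBuilder` documented above is meant to be driven end to end. The public entry point is not visible in this excerpt, so the `toSQL` name and its `Option[String]` result are assumptions inferred from the private helpers and from the way the test suites re-run the generated text:

    import org.apache.spark.sql.hive.SQLBuilder

    // Hypothetical round trip against the `t0` table registered by the test suite:
    // analyze a HiveQL query, regenerate SQL from the analyzed plan, and run it again.
    val df = sqlContext.sql("SELECT id FROM t0 GROUP BY id")
    val regenerated = new SQLBuilder(df).toSQL           // assumed signature: Option[String]
    regenerated.foreach(s => sqlContext.sql(s).collect())
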
+ case _: ArrayType | _: MapType | _: StructType => child.sql + case _ => s"CAST(${child.sql} AS ${dataType.sql})" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index eb4864bcf2467..d6219514b752b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -226,10 +226,13 @@ abstract class Expression extends TreeNode[Expression] { .replace("\\u", "\\\\u") /** - * Returns SQL representation of this expression if there is one (some expressions, e.g. - * `ScalaUDF`, don't have SQL representations). + * Returns SQL representation of this expression. For expressions that don't have a SQL + * representation (e.g. `ScalaUDF`), this method should throw an `UnsupportedOperationException`. */ - def sql: Option[String] = None + @throws[UnsupportedOperationException](cause = "Expression doesn't have a SQL representation") + def sql: String = throw new UnsupportedOperationException( + s"Cannot map expression $this to its SQL representation" + ) } @@ -364,7 +367,7 @@ abstract class UnaryExpression extends Expression { } } - override def sql: Option[String] = child.sql.map(childSQL => s"($prettyName($childSQL))") + override def sql: String = s"($prettyName(${child.sql}))" } @@ -465,6 +468,8 @@ abstract class BinaryExpression extends Expression { """ } } + + override def sql: String = s"$prettyName(${left.sql}, ${right.sql})" } @@ -502,10 +507,7 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { } } - override def sql: Option[String] = for { - lhs <- left.sql - rhs <- right.sql - } yield s"($lhs $symbol $rhs)" + override def sql: String = s"(${left.sql} $symbol ${right.sql})" } @@ -608,8 +610,8 @@ abstract class TernaryExpression extends Expression { } } - override def sql: Option[String] = sequenceOption(children.map(_.sql)).map { - case Seq(child1, child2, child3) => - s"$prettyName($child1, $child2, $child3)" + override def sql: String = { + val childrenSQL = children.map(_.sql).mkString(", ") + s"$prettyName($childrenSQL)" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index f33833c3918df..827dce8af100e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -49,4 +49,5 @@ case class InputFileName() extends LeafExpression with Nondeterministic { "org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();" } + override def sql: String = prettyName } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index d0b78e15d99d1..94f8801dec369 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -78,4 +78,8 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with $countTerm++; """ } + + override def prettyName: String = "monotonically_increasing_id" 
+ + override def sql: String = s"$prettyName()" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index d94ce35fa3878..ddd99c51ab0c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -96,16 +96,7 @@ private[sql] case class AggregateExpression( override def toString: String = s"($aggregateFunction,mode=$mode,isDistinct=$isDistinct)" - override def sql: Option[String] = { - val name = aggregateFunction.prettyName - val argsSQL = sequenceOption(aggregateFunction.children.map(_.sql)) - - if (isDistinct) { - argsSQL.map(args => s"$name(DISTINCT ${args.mkString(", ")})") - } else { - argsSQL.map(args => s"$name(${args.mkString(", ")})") - } - } + override def sql: String = aggregateFunction.sql(isDistinct) } /** @@ -175,6 +166,11 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu def toAggregateExpression(isDistinct: Boolean): AggregateExpression = { AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct) } + + def sql(isDistinct: Boolean): String = { + val distinct = if (isDistinct) "DISTINCT " else " " + s"$prettyName($distinct${children.map(_.sql).mkString(", ")})" + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 0f5e4c949cb25..7bd851c059d0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -55,7 +55,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp } } - override def sql: Option[String] = child.sql.map(sql => s"(-$sql)") + override def sql: String = s"(-${child.sql})" } case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes { @@ -70,7 +70,7 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects protected override def nullSafeEval(input: Any): Any = input - override def sql: Option[String] = child.sql.map(sql => s"(+$sql)") + override def sql: String = s"(+${child.sql})" } /** @@ -96,7 +96,7 @@ case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes protected override def nullSafeEval(input: Any): Any = numeric.abs(input) - override def sql: Option[String] = child.sql.map(sql => s"$prettyName($sql)") + override def sql: String = s"$prettyName(${child.sql})" } abstract class BinaryArithmetic extends BinaryOperator { @@ -520,8 +520,5 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { if (r.compare(Decimal.ZERO) < 0) {(r + n) % n} else r } - override def sql: Option[String] = for { - leftSQL <- left.sql - rightSQL <- right.sql - } yield s"$prettyName($leftSQL, $rightSQL)" + override def sql: String = s"$prettyName(${left.sql}, ${right.sql})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index d8bb8fe9e0fd0..5bd97cc7467ab 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -131,7 +131,7 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] }) } - override def sql: Option[String] = child.sql.map(_ + s".`${childSchema(ordinal).name}`") + override def sql: String = child.sql + s".`${childSchema(ordinal).name}`" } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 37857de9b578f..19da849d2bec9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -75,11 +75,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def toString: String = s"if ($predicate) $trueValue else $falseValue" - override def sql: Option[String] = for { - predicateSQL <- predicate.sql - trueSQL <- trueValue.sql - falseSQL <- falseValue.sql - } yield s"(IF($predicateSQL, $trueSQL, $falseSQL))" + override def sql: String = s"(IF(${predicate.sql}, ${trueValue.sql}, ${falseValue.sql}))" } trait CaseWhenLike extends Expression { @@ -213,23 +209,21 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { }.mkString } - override def sql: Option[String] = { - sequenceOption(branches.map(_.sql)).map { - case branchesSQL => - val (cases, maybeElse) = if (branches.length % 2 == 0) { - (branchesSQL, None) - } else { - (branchesSQL.init, Some(branchesSQL.last)) - } + override def sql: String = { + val branchesSQL = branches.map(_.sql) + val (cases, maybeElse) = if (branches.length % 2 == 0) { + (branchesSQL, None) + } else { + (branchesSQL.init, Some(branchesSQL.last)) + } - val head = s"CASE " - val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END" - val body = cases.grouped(2).map { - case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr" - }.mkString(" ") + val head = s"CASE " + val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END" + val body = cases.grouped(2).map { + case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr" + }.mkString(" ") - head + body + tail - } + head + body + tail } } @@ -336,10 +330,9 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW }.mkString } - override def sql: Option[String] = for { - keySQL <- key.sql - branchesSQL <- sequenceOption(branches.map(_.sql)) - } yield { + override def sql: String = { + val keySQL = key.sql + val branchesSQL = branches.map(_.sql) val (cases, maybeElse) = if (branches.length % 2 == 0) { (branchesSQL, None) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 0ce62f84a9aa0..17f1df06f2fad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -44,6 +44,8 @@ case class CurrentDate() extends LeafExpression with CodegenFallback { override def eval(input: InternalRow): Any = { DateTimeUtils.millisToDays(System.currentTimeMillis()) } + + 
override def prettyName: String = "current_date" } /** @@ -61,6 +63,8 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback { override def eval(input: InternalRow): Any = { System.currentTimeMillis() * 1000L } + + override def prettyName: String = "current_timestamp" } /** @@ -85,6 +89,8 @@ case class DateAdd(startDate: Expression, days: Expression) s"""${ev.value} = $sd + $d;""" }) } + + override def prettyName: String = "date_add" } /** @@ -108,6 +114,8 @@ case class DateSub(startDate: Expression, days: Expression) s"""${ev.value} = $sd - $d;""" }) } + + override def prettyName: String = "date_sub" } case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { @@ -441,6 +449,8 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { """ } } + + override def prettyName: String = "unix_time" } /** @@ -455,6 +465,8 @@ case class FromUnixTime(sec: Expression, format: Expression) override def left: Expression = sec override def right: Expression = format + override def prettyName: String = "from_unixtime" + def this(unix: Expression) = { this(unix, Literal("yyyy-MM-dd HH:mm:ss")) } @@ -737,6 +749,8 @@ case class AddMonths(startDate: Expression, numMonths: Expression) s"""$dtu.dateAddMonths($sd, $m)""" }) } + + override def prettyName: String = "add_months" } /** @@ -762,6 +776,8 @@ case class MonthsBetween(date1: Expression, date2: Expression) s"""$dtu.monthsBetween($l, $r)""" }) } + + override def prettyName: String = "months_between" } /** @@ -827,6 +843,8 @@ case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastIn override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, d => d) } + + override def prettyName: String = "to_date" } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 188a5f4686113..5f8b544edb511 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -73,7 +73,7 @@ case class PromotePrecision(child: Expression) extends UnaryExpression { override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = "" override def prettyName: String = "promote_precision" - override def sql: Option[String] = child.sql + override def sql: String = child.sql } /** @@ -109,5 +109,5 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary override def toString: String = s"CheckOverflow($child, $dataType)" - override def sql: Option[String] = child.sql + override def sql: String = child.sql } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index ab999674567c1..0eb915fdc1691 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -215,7 +215,7 @@ case class Literal protected (value: Any, dataType: DataType) } } - override def sql: Option[String] = Option((value, dataType) match { + override def sql: String = (value, dataType) match { case (_, 
NullType | _: ArrayType | _: MapType | _: StructType) if value == null => "NULL" @@ -248,7 +248,7 @@ case class Literal protected (value: Any, dataType: DataType) s"TIMESTAMP('${DateTimeUtils.toJavaTimestamp(v)}')" case _ => value.toString - }) + } } // TODO: Specialize diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index e10f58cd2d318..66d8631a846ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -71,7 +71,7 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)") } - override def sql: Option[String] = child.sql.map(childSQL => s"$name($childSQL)") + override def sql: String = s"$name(${child.sql})" } abstract class UnaryLogExpression(f: Double => Double, name: String) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 6697d463614d5..0f229db2b08a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -220,4 +220,8 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression final int ${ev.value} = ${unsafeRow.value}.hashCode($seed); """ } + + override def prettyName: String = "hash" + + override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")}, $seed)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 676e57c5c633a..eee708cb02f9d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -165,10 +165,10 @@ case class Alias(child: Expression, name: String)( case _ => false } - override def sql: Option[String] = { + override def sql: String = { val qualifiersString = if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".") - child.sql.map(childSQL => s"$childSQL AS $qualifiersString`$name`") + s"${child.sql} AS $qualifiersString`$name`" } } @@ -278,10 +278,10 @@ case class AttributeReference( // tree string. 
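Alias.sql above and the AttributeReference.sql change just below build the same qualifier prefix: each qualifier and the attribute name are wrapped in backticks and joined with dots. A small standalone sketch of that quoting scheme (quotedIdentifier is an illustrative helper, not part of the patch):

def quotedIdentifier(qualifiers: Seq[String], name: String): String = {
  // Qualifiers render as `q1`.`q2`. ... with a trailing dot so the name can be appended.
  val qualifiersString =
    if (qualifiers.isEmpty) ""
    else qualifiers.map("`" + _ + "`").mkString("", ".", ".")
  s"$qualifiersString`$name`"
}

// quotedIdentifier(Nil, "id")             returns "`id`"
// quotedIdentifier(Seq("db", "t"), "id")  returns "`db`.`t`.`id`"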
override def simpleString: String = s"$name#${exprId.id}: ${dataType.simpleString}" - override def sql: Option[String] = { + override def sql: String = { val qualifiersString = if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".") - Some(s"$qualifiersString`$name`") + s"$qualifiersString`$name`" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index a3c9238285a9d..89aec2b20fd0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -83,6 +83,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { """ }.mkString("\n") } + + override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})" } @@ -194,7 +196,7 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { eval.code } - override def sql: Option[String] = child.sql.map(childSQL => s"($childSQL IS NULL)") + override def sql: String = s"(${child.sql} IS NULL)" } @@ -215,7 +217,7 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { eval.code } - override def sql: Option[String] = child.sql.map(childSQL => s"($childSQL IS NOT NULL)") + override def sql: String = s"(${child.sql} IS NOT NULL)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 2cb42515396b6..bca12a8d21023 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.util.{sequenceOption, TypeUtils} +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -102,7 +102,7 @@ case class Not(child: Expression) defineCodeGen(ctx, ev, c => s"!($c)") } - override def sql: Option[String] = child.sql.map(childSQL => s"(NOT $childSQL)") + override def sql: String = s"(NOT ${child.sql})" } @@ -179,11 +179,12 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate """ } - override def sql: Option[String] = for { - childrenSQL <- sequenceOption(children.map(_.sql)) - valueSQL = childrenSQL.head - listSQL = childrenSQL.tail - } yield s"($valueSQL IN (${listSQL.mkString(", ")}))" + override def sql: String = { + val childrenSQL = children.map(_.sql) + val valueSQL = childrenSQL.head + val listSQL = childrenSQL.tail.mkString(", ") + s"($valueSQL IN ($listSQL))" + } } /** @@ -235,10 +236,11 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with """ } - override def sql: Option[String] = for { - valueSQL <- child.sql - listSQL <- sequenceOption(hset.toSeq.map(Literal(_).sql)) - } yield s"($valueSQL IN (${listSQL.mkString(", ")}))" + override def sql: String = { + val valueSQL = child.sql + val listSQL = hset.toSeq.map(Literal(_).sql).mkString(", ") + s"($valueSQL IN 
($listSQL))" + } } case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate { @@ -288,10 +290,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with """ } - override def sql: Option[String] = for { - lhs <- left.sql - rhs <- right.sql - } yield s"($lhs AND $rhs)" + override def sql: String = s"(${left.sql} AND ${right.sql})" } @@ -342,10 +341,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P """ } - override def sql: Option[String] = for { - lhs <- left.sql - rhs <- right.sql - } yield s"($lhs OR $rhs)" + override def sql: String = s"(${left.sql} OR ${right.sql})" } @@ -362,11 +358,6 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { defineCodeGen(ctx, ev, (c1, c2) => s"${ctx.genComp(left.dataType, c1, c2)} $symbol 0") } } - - override def sql: Option[String] = for { - lhs <- left.sql - rhs <- right.sql - } yield s"($lhs $symbol $rhs)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index f6cdd2bed2513..8de47e9ddc28d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -51,7 +51,7 @@ abstract class RDG extends LeafExpression with Nondeterministic { override def dataType: DataType = DoubleType // NOTE: Even if the user doesn't provide a seed, Spark SQL adds a default seed. - override def sql: Option[String] = Some(s"$prettyName($seed)") + override def sql: String = s"$prettyName($seed)" } /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index e6e043047da6c..db266639b8560 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -60,10 +60,7 @@ trait StringRegexExpression extends ImplicitCastInputTypes { } } - override def sql: Option[String] = for { - leftSQL <- left.sql - rightSQL <- right.sql - } yield s"$leftSQL ${prettyName.toUpperCase} $rightSQL" + override def sql: String = s"${left.sql} ${prettyName.toUpperCase} ${right.sql}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 2dd26dd5d7e46..931f752b4dc1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -63,9 +63,7 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas """ } - override def sql: Option[String] = for { - childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"$prettyName(${childrenSQL.mkString(", ")})" + override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})" } @@ -159,9 +157,7 @@ case class ConcatWs(children: Seq[Expression]) } } - override def sql: Option[String] = for { - childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"$prettyName(${childrenSQL.mkString(", ")})" + override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})" } trait String2StringExpression extends ImplicitCastInputTypes { @@ -174,10 +170,6 @@ trait String2StringExpression extends ImplicitCastInputTypes { protected override def nullSafeEval(input: Any): Any = convert(input.asInstanceOf[UTF8String]) - - override def sql: Option[String] = for { - childSQL <- child.sql - } yield s"$prettyName($childSQL)" } /** @@ -234,10 +226,6 @@ case class Contains(left: Expression, right: Expression) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") } - - override def sql: Option[String] = for { - childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"$prettyName(${childrenSQL.mkString(", ")})" } /** @@ -249,10 +237,6 @@ case class StartsWith(left: Expression, right: Expression) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") } - - override def sql: Option[String] = for { - childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"$prettyName(${childrenSQL.mkString(", ")})" } /** @@ -264,10 +248,6 @@ case class EndsWith(left: Expression, right: Expression) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") } - - override def sql: Option[String] = for { - childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"$prettyName(${childrenSQL.mkString(", ")})" } object StringTranslate { @@ -343,10 +323,6 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac override def inputTypes: Seq[DataType] = Seq(StringType, StringType, 
StringType) override def children: Seq[Expression] = srcExpr :: matchingExpr :: replaceExpr :: Nil override def prettyName: String = "translate" - - override def sql: Option[String] = for { - childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"$prettyName(${childrenSQL.mkString(", ")})" } /** @@ -371,10 +347,6 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi override def dataType: DataType = IntegerType override def prettyName: String = "find_in_set" - - override def sql: Option[String] = for { - childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"$prettyName(${childrenSQL.mkString(", ")})" } /** @@ -447,10 +419,6 @@ case class StringInstr(str: Expression, substr: Expression) defineCodeGen(ctx, ev, (l, r) => s"($l).indexOf($r, 0) + 1") } - - override def sql: Option[String] = for { - childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"$prettyName(${childrenSQL.mkString(", ")})" } /** @@ -652,10 +620,6 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC } override def prettyName: String = "format_string" - - override def sql: Option[String] = for { - childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"$prettyName(${childrenSQL.mkString(", ")})" } /** @@ -673,10 +637,6 @@ case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastI override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, str => s"$str.toTitleCase()") } - - override def sql: Option[String] = for { - childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"$prettyName(${childrenSQL.mkString(", ")})" } /** @@ -699,10 +659,6 @@ case class StringRepeat(str: Expression, times: Expression) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (l, r) => s"($l).repeat($r)") } - - override def sql: Option[String] = for { - childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"$prettyName(${childrenSQL.mkString(", ")})" } /** @@ -738,10 +694,6 @@ case class StringSpace(child: Expression) } override def prettyName: String = "space" - - override def sql: Option[String] = for { - childSQL <- child.sql - } yield s"$prettyName($childSQL)" } /** @@ -800,10 +752,6 @@ case class Length(child: Expression) extends UnaryExpression with ExpectsInputTy case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") } } - - override def sql: Option[String] = for { - childSQL <- child.sql - } yield s"$prettyName($childSQL)" } /** @@ -822,10 +770,6 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres nullSafeCodeGen(ctx, ev, (left, right) => s"${ev.value} = $left.levenshteinDistance($right);") } - - override def sql: Option[String] = for { - childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"$prettyName(${childrenSQL.mkString(", ")})" } /** @@ -842,10 +786,6 @@ case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputT override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"$c.soundex()") } - - override def sql: Option[String] = for { - childSQL <- child.sql - } yield s"$prettyName($childSQL)" } /** @@ -877,10 +817,6 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp } """}) } - - override def sql: Option[String] = for { - childSQL <- child.sql - } yield s"$prettyName($childSQL)" } /** @@ -903,10 +839,6 @@ case class Base64(child: Expression) extends 
UnaryExpression with ImplicitCastIn org.apache.commons.codec.binary.Base64.encodeBase64($child)); """}) } - - override def sql: Option[String] = for { - childSQL <- child.sql - } yield s"$prettyName($childSQL)" } /** @@ -926,10 +858,6 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast ${ev.value} = org.apache.commons.codec.binary.Base64.decodeBase64($child.toString()); """}) } - - override def sql: Option[String] = for { - childSQL <- child.sql - } yield s"$prettyName($childSQL)" } /** @@ -960,10 +888,6 @@ case class Decode(bin: Expression, charset: Expression) } """) } - - override def sql: Option[String] = for { - childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"$prettyName(${childrenSQL.mkString(", ")})" } /** @@ -993,10 +917,6 @@ case class Encode(value: Expression, charset: Expression) org.apache.spark.unsafe.Platform.throwException(e); }""") } - - override def sql: Option[String] = for { - childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"$prettyName(${childrenSQL.mkString(", ")})" } /** @@ -1112,8 +1032,4 @@ case class FormatNumber(x: Expression, d: Expression) } override def prettyName: String = "format_number" - - override def sql: Option[String] = for { - childrenSQL <- sequenceOption(children.map(_.sql)) - } yield s"$prettyName(${childrenSQL.mkString(", ")})" } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 42f64b652e382..1c910051faccf 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -20,11 +20,10 @@ package org.apache.spark.sql.hive import java.util.concurrent.atomic.AtomicLong import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} -import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.{DataFrame, SQLContext} @@ -40,7 +39,12 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi def toSQL: Option[String] = { val canonicalizedPlan = Canonicalizer.execute(logicalPlan) - val maybeSQL = toSQL(canonicalizedPlan) + val maybeSQL = try { + toSQL(canonicalizedPlan) + } catch { case cause: UnsupportedOperationException => + logInfo(s"Failed to build SQL query string because: ${cause.getMessage}") + None + } if (maybeSQL.isDefined) { logDebug( @@ -72,31 +76,30 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi child: LogicalPlan, isDistinct: Boolean): Option[String] = { for { - listSQL <- sequenceOption(projectList.map(_.sql)) childSQL <- toSQL(child) - from = child match { - case OneRowRelation => "" + listSQL = projectList.map(_.sql).mkString(", ") + maybeFrom = child match { + case OneRowRelation => " " case _ => " FROM " } - distinct = if (isDistinct) " DISTINCT" else "" - } yield s"SELECT$distinct ${listSQL.mkString(", ")}$from$childSQL" + distinct = if (isDistinct) " DISTINCT " else " " + } yield 
s"SELECT$distinct$listSQL$maybeFrom$childSQL" } private def aggregateToSQL( groupingExprs: Seq[Expression], aggExprs: Seq[Expression], child: LogicalPlan): Option[String] = { - for { - aggSQL <- sequenceOption(aggExprs.map(_.sql)) - groupingSQL <- sequenceOption(groupingExprs.map(_.sql)) - maybeGroupBy = if (groupingSQL.isEmpty) "" else " GROUP BY " - maybeFrom = child match { - case OneRowRelation => "" - case _ => " FROM " - } - childSQL <- toSQL(child).map(maybeFrom + _) - } yield { - s"SELECT ${aggSQL.mkString(", ")}$childSQL$maybeGroupBy${groupingSQL.mkString(", ")}" + val aggSQL = aggExprs.map(_.sql).mkString(", ") + val groupingSQL = groupingExprs.map(_.sql).mkString(", ") + val maybeGroupBy = if (groupingSQL.isEmpty) "" else " GROUP BY " + val maybeFrom = child match { + case OneRowRelation => " " + case _ => " FROM " + } + + toSQL(child).map { childSQL => + s"SELECT $aggSQL$maybeFrom$childSQL$maybeGroupBy$groupingSQL" } } @@ -112,18 +115,18 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case Limit(limit, child) => for { - limitSQL <- limit.sql childSQL <- toSQL(child) + limitSQL = limit.sql } yield s"$childSQL LIMIT $limitSQL" case Filter(condition, child) => for { - conditionSQL <- condition.sql childSQL <- toSQL(child) whereOrHaving = child match { case _: Aggregate => "HAVING" case _ => "WHERE" } + conditionSQL = condition.sql } yield s"$childSQL $whereOrHaving $conditionSQL" case Union(left, right) => @@ -148,7 +151,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi leftSQL <- toSQL(left) rightSQL <- toSQL(right) joinTypeSQL = joinType.sql - conditionSQL = condition.flatMap(_.sql).map(" ON " + _).getOrElse("") + conditionSQL = condition.map(" ON " + _.sql).getOrElse("") } yield s"$leftSQL $joinTypeSQL JOIN $rightSQL$conditionSQL" case MetastoreRelation(database, table, alias) => @@ -158,24 +161,22 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _)) if orders.map(_.child) == partitionExprs => for { - partitionExprsSQL <- sequenceOption(partitionExprs.map(_.sql)) childSQL <- toSQL(child) - } yield s"$childSQL CLUSTER BY ${partitionExprsSQL.mkString(", ")}" + partitionExprsSQL = partitionExprs.map(_.sql).mkString(", ") + } yield s"$childSQL CLUSTER BY $partitionExprsSQL" case Sort(orders, global, child) => for { childSQL <- toSQL(child) - ordersSQL <- sequenceOption(orders.map { case SortOrder(e, dir) => - e.sql.map(sql => s"$sql ${dir.sql}") - }) + ordersSQL = orders.map { case SortOrder(e, dir) => s"${e.sql} ${dir.sql}" }.mkString(", ") orderOrSort = if (global) "ORDER" else "SORT" - } yield s"$childSQL $orderOrSort BY ${ordersSQL.mkString(", ")}" + } yield s"$childSQL $orderOrSort BY $ordersSQL" case RepartitionByExpression(partitionExprs, child, _) => for { - partitionExprsSQL <- sequenceOption(partitionExprs.map(_.sql)) childSQL <- toSQL(child) - } yield s"$childSQL DISTRIBUTE BY ${partitionExprsSQL.mkString(", ")}" + partitionExprsSQL = partitionExprs.map(_.sql).mkString(", ") + } yield s"$childSQL DISTRIBUTE BY $partitionExprsSQL" case OneRowRelation => Some("") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 7cb30f0e15680..e76c18fa528f3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -189,11 
+189,7 @@ private[hive] case class HiveSimpleUDF( s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } - override def sql: Option[String] = { - sequenceOption(children.map(_.sql)).map { argsSQL => - s"$name(${argsSQL.mkString(", ")})" - } - } + override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})" } // Adapter from Catalyst ExpressionResult to Hive DeferredObject @@ -262,11 +258,7 @@ private[hive] case class HiveGenericUDF( s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } - override def sql: Option[String] = { - sequenceOption(children.map(_.sql)).map { argsSQL => - s"$name(${argsSQL.mkString(", ")})" - } - } + override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})" } /** @@ -348,11 +340,7 @@ private[hive] case class HiveGenericUDTF( s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } - override def sql: Option[String] = { - sequenceOption(children.map(_.sql)).map { argsSQL => - s"$name(${argsSQL.mkString(", ")})" - } - } + override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})" } /** @@ -446,9 +434,8 @@ private[hive] case class HiveUDAFFunction( override val dataType: DataType = inspectorToDataType(returnInspector) - override def sql: Option[String] = { - sequenceOption(children.map(_.sql)).map { argsSQL => - s"$name(${argsSQL.mkString(", ")})" - } + override def sql(isDistinct: Boolean): String = { + val distinct = if (isDistinct) "DISTINCT " else " " + s"$name($distinct${children.map(_.sql).mkString(", ")})" } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index 9753f64cd389d..0e81acf532a03 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -49,15 +49,32 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { fail( s"""Cannot convert the following HiveQL query plan back to SQL query string: | - |Original HiveQL query string: + |# Original HiveQL query string: |$hiveQl | - |Resolved query plan: + |# Resolved query plan: |${df.queryExecution.analyzed.treeString} """.stripMargin) } - checkAnswer(sql(convertedSQL.get), df) + val sqlString = convertedSQL.get + try { + checkAnswer(sql(sqlString), df) + } catch { case cause: Throwable => + fail( + s"""Failed to execute converted SQL string or got wrong answer: + | + |# Converted SQL query string: + |$sqlString + | + |# Original HiveQL query string: + |$hiveQl + | + |# Resolved query plan: + |${df.queryExecution.analyzed.treeString} + """.stripMargin, + cause) + } } test("in") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala index a6fa4e76bf910..cf4a3fdd88806 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala @@ -24,18 +24,9 @@ import org.apache.spark.sql.{DataFrame, QueryTest} abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton { protected def checkSQL(e: Expression, expectedSQL: String): Unit = { - val maybeSQL = e.sql - - if (maybeSQL.isEmpty) { - fail( - s"""Cannot convert the following expression to SQL form: - | - |${e.treeString} - """.stripMargin) - } - + val actualSQL = e.sql try { - assert(maybeSQL.get === 
expectedSQL) + assert(actualSQL === expectedSQL) } catch { case cause: Throwable => fail( From 2073e305f994d1f3b850fcf18e37002761979670 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 8 Jan 2016 01:05:36 +0800 Subject: [PATCH 14/15] Migrates test cases for ComputeCurrentTime --- .../sql/catalyst/analysis/AnalysisSuite.scala | 38 ----------- .../optimizer/ComputeCurrentTimeSuite.scala | 68 +++++++++++++++++++ 2 files changed, 68 insertions(+), 38 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index fa823e3021835..cf84855885a37 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ class AnalysisSuite extends AnalysisTest { @@ -238,43 +237,6 @@ class AnalysisSuite extends AnalysisTest { checkAnalysis(plan, expected) } - test("analyzer should replace current_timestamp with literals") { - val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), - LocalRelation()) - - val min = System.currentTimeMillis() * 1000 - val plan = in.analyze.asInstanceOf[Project] - val max = (System.currentTimeMillis() + 1) * 1000 - - val lits = new scala.collection.mutable.ArrayBuffer[Long] - plan.transformAllExpressions { case e: Literal => - lits += e.value.asInstanceOf[Long] - e - } - assert(lits.size == 2) - assert(lits(0) >= min && lits(0) <= max) - assert(lits(1) >= min && lits(1) <= max) - assert(lits(0) == lits(1)) - } - - test("analyzer should replace current_date with literals") { - val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation()) - - val min = DateTimeUtils.millisToDays(System.currentTimeMillis()) - val plan = in.analyze.asInstanceOf[Project] - val max = DateTimeUtils.millisToDays(System.currentTimeMillis()) - - val lits = new scala.collection.mutable.ArrayBuffer[Int] - plan.transformAllExpressions { case e: Literal => - lits += e.value.asInstanceOf[Int] - e - } - assert(lits.size == 2) - assert(lits(0) >= min && lits(0) <= max) - assert(lits(1) >= min && lits(1) <= max) - assert(lits(0) == lits(1)) - } - test("SPARK-12102: Ignore nullablity when comparing two sides of case") { val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false))) val plan = relation.select(CaseWhen(Seq(Literal(true), 'a, 'b)).as("val")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala new file mode 100644 index 0000000000000..10ed4e46ddd1c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.util.DateTimeUtils + +class ComputeCurrentTimeSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime)) + } + + test("analyzer should replace current_timestamp with literals") { + val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), + LocalRelation()) + + val min = System.currentTimeMillis() * 1000 + val plan = Optimize.execute(in.analyze).asInstanceOf[Project] + val max = (System.currentTimeMillis() + 1) * 1000 + + val lits = new scala.collection.mutable.ArrayBuffer[Long] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[Long] + e + } + assert(lits.size == 2) + assert(lits(0) >= min && lits(0) <= max) + assert(lits(1) >= min && lits(1) <= max) + assert(lits(0) == lits(1)) + } + + test("analyzer should replace current_date with literals") { + val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation()) + + val min = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val plan = Optimize.execute(in.analyze).asInstanceOf[Project] + val max = DateTimeUtils.millisToDays(System.currentTimeMillis()) + + val lits = new scala.collection.mutable.ArrayBuffer[Int] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[Int] + e + } + assert(lits.size == 2) + assert(lits(0) >= min && lits(0) <= max) + assert(lits(1) >= min && lits(1) <= max) + assert(lits(0) == lits(1)) + } +} From 97cd39e146ce1a4b49e3ca01b8a44906d7b19351 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 8 Jan 2016 12:17:17 -0800 Subject: [PATCH 15/15] Makes CreateCurrentTime a separate batch --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 7095078bb189b..f8121a733a8d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -37,6 +37,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { // SubQueries are only needed for analysis and can be removed before execution. 
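The ComputeCurrentTimeSuite above and the dedicated "Compute Current Time" batch added just below both rely on the same property: current_timestamp/current_date are evaluated once per plan, so every occurrence collapses to one and the same literal (the suite asserts lits(0) == lits(1)). A rough standalone sketch of that compute-once, substitute-everywhere idea, with illustrative types that are not the Catalyst implementation:

sealed trait Expr
case object CurrentTimestampExpr extends Expr
final case class MicrosLiteral(value: Long) extends Expr
final case class SimplePlan(expressions: Seq[Expr])

def computeCurrentTime(plan: SimplePlan): SimplePlan = {
  // Read the clock a single time for the whole plan ...
  val nowMicros = System.currentTimeMillis() * 1000L
  // ... then substitute that one value for every placeholder.
  SimplePlan(plan.expressions.map {
    case CurrentTimestampExpr => MicrosLiteral(nowMicros)
    case other                => other
  })
}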
Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: + Batch("Compute Current Time", Once, + ComputeCurrentTime) :: Batch("Aggregate", FixedPoint(100), ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: @@ -63,9 +65,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { RemoveDispensableExpressions, SimplifyFilters, SimplifyCasts, - SimplifyCaseConversionExpressions, - // Nondeterministic - ComputeCurrentTime) :: + SimplifyCaseConversionExpressions) :: Batch("Decimal Optimizations", FixedPoint(100), DecimalAggregates) :: Batch("LocalRelation", FixedPoint(100),