diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index 4afce3090f739..ffb859662c8ea 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -1,9 +1,9 @@ /** - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with + (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 @@ -580,7 +580,7 @@ import java.util.HashMap; return header; } - + @Override public String getErrorMessage(RecognitionException e, String[] tokenNames) { String msg = null; @@ -617,7 +617,7 @@ import java.util.HashMap; } return msg; } - + public void pushMsg(String msg, RecognizerSharedState state) { // ANTLR generated code does not wrap the @init code wit this backtracking check, // even if the matching @after has it. If we have parser rules with that are doing @@ -637,7 +637,7 @@ import java.util.HashMap; // counter to generate unique union aliases private int aliasCounter; private String generateUnionAlias() { - return "_u" + (++aliasCounter); + return "u_" + (++aliasCounter); } private char [] excludedCharForColumnName = {'.', ':'}; private boolean containExcludedCharForCreateTableColumnName(String input) { @@ -1233,7 +1233,7 @@ alterTblPartitionStatementSuffixSkewedLocation : KW_SET KW_SKEWED KW_LOCATION skewedLocations -> ^(TOK_ALTERTABLE_SKEWED_LOCATION skewedLocations) ; - + skewedLocations @init { pushMsg("skewed locations", state); } @after { popMsg(state); } @@ -1262,7 +1262,7 @@ alterStatementSuffixLocation -> ^(TOK_ALTERTABLE_LOCATION $newLoc) ; - + alterStatementSuffixSkewedby @init {pushMsg("alter skewed by statement", state);} @after{popMsg(state);} @@ -1334,10 +1334,10 @@ tabTypeExpr (identifier (DOT^ ( (KW_ELEM_TYPE) => KW_ELEM_TYPE - | + | (KW_KEY_TYPE) => KW_KEY_TYPE - | - (KW_VALUE_TYPE) => KW_VALUE_TYPE + | + (KW_VALUE_TYPE) => KW_VALUE_TYPE | identifier ))* )? @@ -1374,7 +1374,7 @@ descStatement analyzeStatement @init { pushMsg("analyze statement", state); } @after { popMsg(state); } - : KW_ANALYZE KW_TABLE (parttype=tableOrPartition) KW_COMPUTE KW_STATISTICS ((noscan=KW_NOSCAN) | (partialscan=KW_PARTIALSCAN) + : KW_ANALYZE KW_TABLE (parttype=tableOrPartition) KW_COMPUTE KW_STATISTICS ((noscan=KW_NOSCAN) | (partialscan=KW_PARTIALSCAN) | (KW_FOR KW_COLUMNS (statsColumnName=columnNameList)?))? -> ^(TOK_ANALYZE $parttype $noscan? $partialscan? KW_COLUMNS? $statsColumnName?) ; @@ -1387,7 +1387,7 @@ showStatement | KW_SHOW KW_COLUMNS (KW_FROM|KW_IN) tableName ((KW_FROM|KW_IN) db_name=identifier)? -> ^(TOK_SHOWCOLUMNS tableName $db_name?) | KW_SHOW KW_FUNCTIONS (KW_LIKE showFunctionIdentifier|showFunctionIdentifier)? -> ^(TOK_SHOWFUNCTIONS KW_LIKE? showFunctionIdentifier?) - | KW_SHOW KW_PARTITIONS tabName=tableName partitionSpec? -> ^(TOK_SHOWPARTITIONS $tabName partitionSpec?) + | KW_SHOW KW_PARTITIONS tabName=tableName partitionSpec? -> ^(TOK_SHOWPARTITIONS $tabName partitionSpec?) | KW_SHOW KW_CREATE ( (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) db_name=identifier -> ^(TOK_SHOW_CREATEDATABASE $db_name) | @@ -1396,7 +1396,7 @@ showStatement | KW_SHOW KW_TABLE KW_EXTENDED ((KW_FROM|KW_IN) db_name=identifier)? KW_LIKE showStmtIdentifier partitionSpec? -> ^(TOK_SHOW_TABLESTATUS showStmtIdentifier $db_name? partitionSpec?) | KW_SHOW KW_TBLPROPERTIES tableName (LPAREN prptyName=StringLiteral RPAREN)? -> ^(TOK_SHOW_TBLPROPERTIES tableName $prptyName?) - | KW_SHOW KW_LOCKS + | KW_SHOW KW_LOCKS ( (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) (isExtended=KW_EXTENDED)? -> ^(TOK_SHOWDBLOCKS $dbName $isExtended?) | @@ -1509,7 +1509,7 @@ showCurrentRole setRole @init {pushMsg("set role", state);} @after {popMsg(state);} - : KW_SET KW_ROLE + : KW_SET KW_ROLE ( (KW_ALL) => (all=KW_ALL) -> ^(TOK_SHOW_SET_ROLE Identifier[$all.text]) | @@ -1964,7 +1964,7 @@ columnNameOrderList skewedValueElement @init { pushMsg("skewed value element", state); } @after { popMsg(state); } - : + : skewedColumnValues | skewedColumnValuePairList ; @@ -1978,8 +1978,8 @@ skewedColumnValuePairList skewedColumnValuePair @init { pushMsg("column value pair", state); } @after { popMsg(state); } - : - LPAREN colValues=skewedColumnValues RPAREN + : + LPAREN colValues=skewedColumnValues RPAREN -> ^(TOK_TABCOLVALUES $colValues) ; @@ -1999,11 +1999,11 @@ skewedColumnValue skewedValueLocationElement @init { pushMsg("skewed value location element", state); } @after { popMsg(state); } - : + : skewedColumnValue | skewedColumnValuePair ; - + columnNameOrder @init { pushMsg("column name order", state); } @after { popMsg(state); } @@ -2116,7 +2116,7 @@ unionType @after { popMsg(state); } : KW_UNIONTYPE LESSTHAN colTypeList GREATERTHAN -> ^(TOK_UNIONTYPE colTypeList) ; - + setOperator @init { pushMsg("set operator", state); } @after { popMsg(state); } @@ -2168,7 +2168,7 @@ fromStatement[boolean topLevel] {adaptor.create(Identifier, generateUnionAlias())} ) ) - ^(TOK_INSERT + ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF)) ) @@ -2391,8 +2391,8 @@ setColumnsClause KW_SET columnAssignmentClause (COMMA columnAssignmentClause)* -> ^(TOK_SET_COLUMNS_CLAUSE columnAssignmentClause* ) ; -/* - UPDATE +/* + UPDATE
SET col1 = val1, col2 = val2... WHERE ... */ updateStatement diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e362b55d80cd1..8a33af8207350 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -86,8 +86,7 @@ class Analyzer( HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, - PullOutNondeterministic, - ComputeCurrentTime), + PullOutNondeterministic), Batch("UDF", Once, HandleNullInputsForUDF), Batch("Cleanup", fixedPoint, @@ -1229,23 +1228,6 @@ object CleanupAliases extends Rule[LogicalPlan] { } } -/** - * Computes the current date and time to make sure we return the same result in a single query. - */ -object ComputeCurrentTime extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = { - val dateExpr = CurrentDate() - val timeExpr = CurrentTimestamp() - val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType) - val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType) - - plan transformAllExpressions { - case CurrentDate() => currentDate - case CurrentTimestamp() => currentTime - } - } -} - /** * Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index e8b2fcf819bf6..a8f89ce6de457 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -110,7 +110,9 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are // properly qualified with this alias. - alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) + alias + .map(a => Subquery(a, tableWithQualifiers)) + .getOrElse(tableWithQualifiers) } override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d82d3edae4e38..6f199cfc5d8cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -931,6 +931,14 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { $evPrim = $result.copy(); """ } + + override def sql: String = dataType match { + // HiveQL doesn't allow casting to complex types. For logical plans translated from HiveQL, this + // type of casting can only be introduced by the analyzer, and can be omitted when converting + // back to SQL query string. + case _: ArrayType | _: MapType | _: StructType => child.sql + case _ => s"CAST(${child.sql} AS ${dataType.sql})" + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 6a9c12127d367..d6219514b752b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, TypeCheckResult, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -223,6 +224,15 @@ abstract class Expression extends TreeNode[Expression] { protected def toCommentSafeString: String = this.toString .replace("*/", "\\*\\/") .replace("\\u", "\\\\u") + + /** + * Returns SQL representation of this expression. For expressions that don't have a SQL + * representation (e.g. `ScalaUDF`), this method should throw an `UnsupportedOperationException`. + */ + @throws[UnsupportedOperationException](cause = "Expression doesn't have a SQL representation") + def sql: String = throw new UnsupportedOperationException( + s"Cannot map expression $this to its SQL representation" + ) } @@ -356,6 +366,8 @@ abstract class UnaryExpression extends Expression { """ } } + + override def sql: String = s"($prettyName(${child.sql}))" } @@ -456,6 +468,8 @@ abstract class BinaryExpression extends Expression { """ } } + + override def sql: String = s"$prettyName(${left.sql}, ${right.sql})" } @@ -492,6 +506,8 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { TypeCheckResult.TypeCheckSuccess } } + + override def sql: String = s"(${left.sql} $symbol ${right.sql})" } @@ -593,4 +609,9 @@ abstract class TernaryExpression extends Expression { """ } } + + override def sql: String = { + val childrenSQL = children.map(_.sql).mkString(", ") + s"$prettyName($childrenSQL)" + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index f33833c3918df..827dce8af100e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -49,4 +49,5 @@ case class InputFileName() extends LeafExpression with Nondeterministic { "org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();" } + override def sql: String = prettyName } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index d0b78e15d99d1..94f8801dec369 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -78,4 +78,8 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with $countTerm++; """ } + + override def prettyName: String = "monotonically_increasing_id" + + override def sql: String = s"$prettyName()" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 3add722da7816..1cb1b9da3049b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -24,9 +24,17 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.BinaryPrefixComparator import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator -abstract sealed class SortDirection -case object Ascending extends SortDirection -case object Descending extends SortDirection +abstract sealed class SortDirection { + def sql: String +} + +case object Ascending extends SortDirection { + override def sql: String = "ASC" +} + +case object Descending extends SortDirection { + override def sql: String = "DESC" +} /** * An expression that can be used to sort a tuple. This class extends expression primarily so that diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index b47f32d1768b9..ddd99c51ab0c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.types._ /** The mode of an [[AggregateFunction]]. */ @@ -93,11 +94,13 @@ private[sql] case class AggregateExpression( override def prettyString: String = aggregateFunction.prettyString - override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)" + override def toString: String = s"($aggregateFunction,mode=$mode,isDistinct=$isDistinct)" + + override def sql: String = aggregateFunction.sql(isDistinct) } /** - * AggregateFunction2 is the superclass of two aggregation function interfaces: + * AggregateFunction is the superclass of two aggregation function interfaces: * * - [[ImperativeAggregate]] is for aggregation functions that are specified in terms of * initialize(), update(), and merge() functions that operate on Row-based aggregation buffers. @@ -163,6 +166,11 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu def toAggregateExpression(isDistinct: Boolean): AggregateExpression = { AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct) } + + def sql(isDistinct: Boolean): String = { + val distinct = if (isDistinct) "DISTINCT " else " " + s"$prettyName($distinct${children.map(_.sql).mkString(", ")})" + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 61a17fd7db0fe..7bd851c059d0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -54,6 +54,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp numeric.negate(input) } } + + override def sql: String = s"(-${child.sql})" } case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes { @@ -67,6 +69,8 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects defineCodeGen(ctx, ev, c => c) protected override def nullSafeEval(input: Any): Any = input + + override def sql: String = s"(+${child.sql})" } /** @@ -91,6 +95,8 @@ case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes } protected override def nullSafeEval(input: Any): Any = numeric.abs(input) + + override def sql: String = s"$prettyName(${child.sql})" } abstract class BinaryArithmetic extends BinaryOperator { @@ -513,4 +519,6 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { val r = a % n if (r.compare(Decimal.ZERO) < 0) {(r + n) % n} else r } + + override def sql: String = s"$prettyName(${left.sql}, ${right.sql})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 9c73239f67ff2..5bd97cc7467ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -130,6 +130,8 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] } }) } + + override def sql: String = child.sql + s".`${childSchema(ordinal).name}`" } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index f79c8676fb58c..19da849d2bec9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.util.{sequenceOption, TypeUtils} import org.apache.spark.sql.types._ @@ -74,6 +74,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } override def toString: String = s"if ($predicate) $trueValue else $falseValue" + + override def sql: String = s"(IF(${predicate.sql}, ${trueValue.sql}, ${falseValue.sql}))" } trait CaseWhenLike extends Expression { @@ -110,7 +112,7 @@ trait CaseWhenLike extends Expression { override def nullable: Boolean = { // If no value is nullable and no elseValue is provided, the whole statement defaults to null. - thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) + thenList.exists(_.nullable) || elseValue.map(_.nullable).getOrElse(true) } } @@ -206,6 +208,23 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { case Seq(elseValue) => s" ELSE $elseValue" }.mkString } + + override def sql: String = { + val branchesSQL = branches.map(_.sql) + val (cases, maybeElse) = if (branches.length % 2 == 0) { + (branchesSQL, None) + } else { + (branchesSQL.init, Some(branchesSQL.last)) + } + + val head = s"CASE " + val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END" + val body = cases.grouped(2).map { + case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr" + }.mkString(" ") + + head + body + tail + } } // scalastyle:off @@ -310,6 +329,24 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW case Seq(elseValue) => s" ELSE $elseValue" }.mkString } + + override def sql: String = { + val keySQL = key.sql + val branchesSQL = branches.map(_.sql) + val (cases, maybeElse) = if (branches.length % 2 == 0) { + (branchesSQL, None) + } else { + (branchesSQL.init, Some(branchesSQL.last)) + } + + val head = s"CASE $keySQL " + val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END" + val body = cases.grouped(2).map { + case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr" + }.mkString(" ") + + head + body + tail + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 3d65946a1bc65..17f1df06f2fad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -44,6 +44,8 @@ case class CurrentDate() extends LeafExpression with CodegenFallback { override def eval(input: InternalRow): Any = { DateTimeUtils.millisToDays(System.currentTimeMillis()) } + + override def prettyName: String = "current_date" } /** @@ -61,6 +63,8 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback { override def eval(input: InternalRow): Any = { System.currentTimeMillis() * 1000L } + + override def prettyName: String = "current_timestamp" } /** @@ -85,6 +89,8 @@ case class DateAdd(startDate: Expression, days: Expression) s"""${ev.value} = $sd + $d;""" }) } + + override def prettyName: String = "date_add" } /** @@ -108,6 +114,8 @@ case class DateSub(startDate: Expression, days: Expression) s"""${ev.value} = $sd - $d;""" }) } + + override def prettyName: String = "date_sub" } case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { @@ -309,6 +317,8 @@ case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends Unix def this(time: Expression) = { this(time, Literal("yyyy-MM-dd HH:mm:ss")) } + + override def prettyName: String = "to_unix_timestamp" } /** @@ -332,6 +342,8 @@ case class UnixTimestamp(timeExp: Expression, format: Expression) extends UnixTi def this() = { this(CurrentTimestamp()) } + + override def prettyName: String = "unix_timestamp" } abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { @@ -437,6 +449,8 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { """ } } + + override def prettyName: String = "unix_time" } /** @@ -451,6 +465,8 @@ case class FromUnixTime(sec: Expression, format: Expression) override def left: Expression = sec override def right: Expression = format + override def prettyName: String = "from_unixtime" + def this(unix: Expression) = { this(unix, Literal("yyyy-MM-dd HH:mm:ss")) } @@ -733,6 +749,8 @@ case class AddMonths(startDate: Expression, numMonths: Expression) s"""$dtu.dateAddMonths($sd, $m)""" }) } + + override def prettyName: String = "add_months" } /** @@ -758,6 +776,8 @@ case class MonthsBetween(date1: Expression, date2: Expression) s"""$dtu.monthsBetween($l, $r)""" }) } + + override def prettyName: String = "months_between" } /** @@ -823,6 +843,8 @@ case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastIn override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, d => d) } + + override def prettyName: String = "to_date" } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index c54bcdd774021..5f8b544edb511 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -73,6 +73,7 @@ case class PromotePrecision(child: Expression) extends UnaryExpression { override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = "" override def prettyName: String = "promote_precision" + override def sql: String = child.sql } /** @@ -107,4 +108,6 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary } override def toString: String = s"CheckOverflow($child, $dataType)" + + override def sql: String = child.sql } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 672cc9c45e0af..0eb915fdc1691 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -21,9 +21,9 @@ import java.sql.{Date, Timestamp} import org.json4s.JsonAST._ -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types._ @@ -214,6 +214,41 @@ case class Literal protected (value: Any, dataType: DataType) } } } + + override def sql: String = (value, dataType) match { + case (_, NullType | _: ArrayType | _: MapType | _: StructType) if value == null => + "NULL" + + case _ if value == null => + s"CAST(NULL AS ${dataType.sql})" + + case (v: UTF8String, StringType) => + // Escapes all backslashes and double quotes. + "\"" + v.toString.replace("\\", "\\\\").replace("\"", "\\\"") + "\"" + + case (v: Byte, ByteType) => + s"CAST($v AS ${ByteType.simpleString.toUpperCase})" + + case (v: Short, ShortType) => + s"CAST($v AS ${ShortType.simpleString.toUpperCase})" + + case (v: Long, LongType) => + s"CAST($v AS ${LongType.simpleString.toUpperCase})" + + case (v: Float, FloatType) => + s"CAST($v AS ${FloatType.simpleString.toUpperCase})" + + case (v: Decimal, DecimalType.Fixed(precision, scale)) => + s"CAST($v AS ${DecimalType.simpleString.toUpperCase}($precision, $scale))" + + case (v: Int, DateType) => + s"DATE '${DateTimeUtils.toJavaDate(v)}'" + + case (v: Long, TimestampType) => + s"TIMESTAMP('${DateTimeUtils.toJavaTimestamp(v)}')" + + case _ => value.toString + } } // TODO: Specialize diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 002f5929cc26b..66d8631a846ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -70,6 +70,8 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)") } + + override def sql: String = s"$name(${child.sql})" } abstract class UnaryLogExpression(f: Double => Double, name: String) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 6697d463614d5..0f229db2b08a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -220,4 +220,8 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression final int ${ev.value} = ${unsafeRow.value}.hashCode($seed); """ } + + override def prettyName: String = "hash" + + override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")}, $seed)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index eefd9c7482553..eee708cb02f9d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -164,6 +164,12 @@ case class Alias(child: Expression, name: String)( explicitMetadata == a.explicitMetadata case _ => false } + + override def sql: String = { + val qualifiersString = + if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".") + s"${child.sql} AS $qualifiersString`$name`" + } } /** @@ -271,6 +277,12 @@ case class AttributeReference( // Since the expression id is not in the first constructor it is missing from the default // tree string. override def simpleString: String = s"$name#${exprId.id}: ${dataType.simpleString}" + + override def sql: String = { + val qualifiersString = + if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".") + s"$qualifiersString`$name`" + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index df4747d4e6f7a..89aec2b20fd0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -83,6 +83,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { """ }.mkString("\n") } + + override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})" } @@ -193,6 +195,8 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { ev.value = eval.isNull eval.code } + + override def sql: String = s"(${child.sql} IS NULL)" } @@ -212,6 +216,8 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { ev.value = s"(!(${eval.isNull}))" eval.code } + + override def sql: String = s"(${child.sql} IS NOT NULL)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 304b438c84ba4..bca12a8d21023 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -101,6 +101,8 @@ case class Not(child: Expression) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"!($c)") } + + override def sql: String = s"(NOT ${child.sql})" } @@ -176,6 +178,13 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate } """ } + + override def sql: String = { + val childrenSQL = children.map(_.sql) + val valueSQL = childrenSQL.head + val listSQL = childrenSQL.tail.mkString(", ") + s"($valueSQL IN ($listSQL))" + } } /** @@ -226,6 +235,12 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } """ } + + override def sql: String = { + val valueSQL = child.sql + val listSQL = hset.toSeq.map(Literal(_).sql).mkString(", ") + s"($valueSQL IN ($listSQL))" + } } case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate { @@ -274,6 +289,8 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with } """ } + + override def sql: String = s"(${left.sql} AND ${right.sql})" } @@ -323,6 +340,8 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P } """ } + + override def sql: String = s"(${left.sql} OR ${right.sql})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 8bde8cb9fe876..8de47e9ddc28d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -49,6 +49,9 @@ abstract class RDG extends LeafExpression with Nondeterministic { override def nullable: Boolean = false override def dataType: DataType = DoubleType + + // NOTE: Even if the user doesn't provide a seed, Spark SQL adds a default seed. + override def sql: String = s"$prettyName($seed)" } /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index adef6050c3565..db266639b8560 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -59,6 +59,8 @@ trait StringRegexExpression extends ImplicitCastInputTypes { matches(regex, input1.asInstanceOf[UTF8String].toString) } } + + override def sql: String = s"${left.sql} ${prettyName.toUpperCase} ${right.sql}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 50c8b9d59847e..931f752b4dc1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -23,6 +23,7 @@ import java.util.{HashMap, Locale, Map => JMap} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -61,6 +62,8 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas } """ } + + override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})" } @@ -153,6 +156,8 @@ case class ConcatWs(children: Seq[Expression]) """ } } + + override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})" } trait String2StringExpression extends ImplicitCastInputTypes { @@ -292,24 +297,24 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac val termDict = ctx.freshName("dict") val classNameDict = classOf[JMap[Character, Character]].getCanonicalName - ctx.addMutableState("UTF8String", termLastMatching, s"${termLastMatching} = null;") - ctx.addMutableState("UTF8String", termLastReplace, s"${termLastReplace} = null;") - ctx.addMutableState(classNameDict, termDict, s"${termDict} = null;") + ctx.addMutableState("UTF8String", termLastMatching, s"$termLastMatching = null;") + ctx.addMutableState("UTF8String", termLastReplace, s"$termLastReplace = null;") + ctx.addMutableState(classNameDict, termDict, s"$termDict = null;") nullSafeCodeGen(ctx, ev, (src, matching, replace) => { val check = if (matchingExpr.foldable && replaceExpr.foldable) { - s"${termDict} == null" + s"$termDict == null" } else { - s"!${matching}.equals(${termLastMatching}) || !${replace}.equals(${termLastReplace})" + s"!$matching.equals($termLastMatching) || !$replace.equals($termLastReplace)" } s"""if ($check) { // Not all of them is literal or matching or replace value changed - ${termLastMatching} = ${matching}.clone(); - ${termLastReplace} = ${replace}.clone(); - ${termDict} = org.apache.spark.sql.catalyst.expressions.StringTranslate - .buildDict(${termLastMatching}, ${termLastReplace}); + $termLastMatching = $matching.clone(); + $termLastReplace = $replace.clone(); + $termDict = org.apache.spark.sql.catalyst.expressions.StringTranslate + .buildDict($termLastMatching, $termLastReplace); } - ${ev.value} = ${src}.translate(${termDict}); + ${ev.value} = $src.translate($termDict); """ }) } @@ -340,6 +345,8 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi } override def dataType: DataType = IntegerType + + override def prettyName: String = "find_in_set" } /** @@ -832,7 +839,6 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn org.apache.commons.codec.binary.Base64.encodeBase64($child)); """}) } - } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0b1c74293bb8b..f8121a733a8d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -37,6 +37,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { // SubQueries are only needed for analysis and can be removed before execution. Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: + Batch("Compute Current Time", Once, + ComputeCurrentTime) :: Batch("Aggregate", FixedPoint(100), ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: @@ -333,6 +335,39 @@ object ProjectCollapsing extends Rule[LogicalPlan] { ) Project(cleanedProjection, child) } + + // TODO Eliminate duplicate code + // This clause is identical to the one above except that the inner operator is an `Aggregate` + // rather than a `Project`. + case p @ Project(projectList1, agg @ Aggregate(_, projectList2, child)) => + // Create a map of Aliases to their values from the child projection. + // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). + val aliasMap = AttributeMap(projectList2.collect { + case a: Alias => (a.toAttribute, a) + }) + + // We only collapse these two Projects if their overlapped expressions are all + // deterministic. + val hasNondeterministic = projectList1.exists(_.collect { + case a: Attribute if aliasMap.contains(a) => aliasMap(a).child + }.exists(!_.deterministic)) + + if (hasNondeterministic) { + p + } else { + // Substitute any attributes that are produced by the child projection, so that we safely + // eliminate it. + // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' + // TODO: Fix TransformBase to avoid the cast below. + val substitutedProjection = projectList1.map(_.transform { + case a: Attribute => aliasMap.getOrElse(a, a) + }).asInstanceOf[Seq[NamedExpression]] + // collapse 2 projects may introduce unnecessary Aliases, trim them here. + val cleanedProjection = substitutedProjection.map(p => + CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression] + ) + agg.copy(aggregateExpressions = cleanedProjection) + } } } @@ -976,3 +1011,20 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] { a.copy(groupingExpressions = newGrouping) } } + +/** + * Computes the current date and time to make sure we return the same result in a single query. + */ +object ComputeCurrentTime extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + val dateExpr = CurrentDate() + val timeExpr = CurrentTimestamp() + val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType) + val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType) + + plan transformAllExpressions { + case CurrentDate() => currentDate + case CurrentTimestamp() => currentTime + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 77dec7ca6e2b5..a5f6764aef7ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -37,14 +37,26 @@ object JoinType { } } -sealed abstract class JoinType +sealed abstract class JoinType { + def sql: String +} -case object Inner extends JoinType +case object Inner extends JoinType { + override def sql: String = "INNER" +} -case object LeftOuter extends JoinType +case object LeftOuter extends JoinType { + override def sql: String = "LEFT OUTER" +} -case object RightOuter extends JoinType +case object RightOuter extends JoinType { + override def sql: String = "RIGHT OUTER" +} -case object FullOuter extends JoinType +case object FullOuter extends JoinType { + override def sql: String = "FULL OUTER" +} -case object LeftSemi extends JoinType +case object LeftSemi extends JoinType { + override def sql: String = "LEFT SEMI" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 79759b5a37b34..64957db6b4013 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -423,6 +423,7 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { } case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 62ea731ab5f38..9ebacb4680dc2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -37,7 +37,7 @@ object RuleExecutor { val maxSize = map.keys.map(_.toString.length).max map.toSeq.sortBy(_._2).reverseMap { case (k, v) => s"${k.padTo(maxSize, " ").mkString} $v" - }.mkString("\n") + }.mkString("\n", "\n", "") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 71293475ca0f9..7a0d0de6328a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -130,6 +130,20 @@ package object util { ret } + /** + * Converts a `Seq` of `Option[T]` to an `Option` of `Seq[T]`. + */ + def sequenceOption[T](seq: Seq[Option[T]]): Option[Seq[T]] = seq match { + case xs if xs.isEmpty => + Option(Seq.empty[T]) + + case xs => + for { + head <- xs.head + tail <- sequenceOption(xs.tail) + } yield head +: tail + } + /* FIX ME implicit class debugLogging(a: Any) { def debugLogging() { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 6533622492d41..520e344361625 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -77,6 +77,8 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT override def simpleString: String = s"array<${elementType.simpleString}>" + override def sql: String = s"ARRAY<${elementType.sql}>" + override private[spark] def asNullable: ArrayType = ArrayType(elementType.asNullable, containsNull = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 136a97e066df7..92cf8d4c46bda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -65,6 +65,8 @@ abstract class DataType extends AbstractDataType { /** Readable string representation for the type with truncation */ private[sql] def simpleString(maxNumberFields: Int): String = simpleString + def sql: String = simpleString.toUpperCase + /** * Check if `this` and `other` are the same data type when ignoring nullability * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 00461e529ca0a..5474954af70e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -62,6 +62,8 @@ case class MapType( override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>" + override def sql: String = s"MAP<${keyType.sql}, ${valueType.sql}>" + override private[spark] def asNullable: MapType = MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 34382bf124eb0..9b5c86a8984be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -279,6 +279,11 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru s"struct<${fieldTypes.mkString(",")}>" } + override def sql: String = { + val fieldTypes = fields.map(f => s"`${f.name}`: ${f.dataType.sql}") + s"STRUCT<${fieldTypes.mkString(", ")}>" + } + private[sql] override def simpleString(maxNumberFields: Int): String = { val builder = new StringBuilder val fieldTypes = fields.take(maxNumberFields).map { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index 4305903616bd9..d7a2c23be8a9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -84,6 +84,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { override private[sql] def acceptsType(dataType: DataType) = this.getClass == dataType.getClass + + override def sql: String = sqlType.sql } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index fa823e3021835..cf84855885a37 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ class AnalysisSuite extends AnalysisTest { @@ -238,43 +237,6 @@ class AnalysisSuite extends AnalysisTest { checkAnalysis(plan, expected) } - test("analyzer should replace current_timestamp with literals") { - val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), - LocalRelation()) - - val min = System.currentTimeMillis() * 1000 - val plan = in.analyze.asInstanceOf[Project] - val max = (System.currentTimeMillis() + 1) * 1000 - - val lits = new scala.collection.mutable.ArrayBuffer[Long] - plan.transformAllExpressions { case e: Literal => - lits += e.value.asInstanceOf[Long] - e - } - assert(lits.size == 2) - assert(lits(0) >= min && lits(0) <= max) - assert(lits(1) >= min && lits(1) <= max) - assert(lits(0) == lits(1)) - } - - test("analyzer should replace current_date with literals") { - val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation()) - - val min = DateTimeUtils.millisToDays(System.currentTimeMillis()) - val plan = in.analyze.asInstanceOf[Project] - val max = DateTimeUtils.millisToDays(System.currentTimeMillis()) - - val lits = new scala.collection.mutable.ArrayBuffer[Int] - plan.transformAllExpressions { case e: Literal => - lits += e.value.asInstanceOf[Int] - e - } - assert(lits.size == 2) - assert(lits(0) >= min && lits(0) <= max) - assert(lits(1) >= min && lits(1) <= max) - assert(lits(0) == lits(1)) - } - test("SPARK-12102: Ignore nullablity when comparing two sides of case") { val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false))) val plan = relation.select(CaseWhen(Seq(Literal(true), 'a, 'b)).as("val")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala new file mode 100644 index 0000000000000..10ed4e46ddd1c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.util.DateTimeUtils + +class ComputeCurrentTimeSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime)) + } + + test("analyzer should replace current_timestamp with literals") { + val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), + LocalRelation()) + + val min = System.currentTimeMillis() * 1000 + val plan = Optimize.execute(in.analyze).asInstanceOf[Project] + val max = (System.currentTimeMillis() + 1) * 1000 + + val lits = new scala.collection.mutable.ArrayBuffer[Long] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[Long] + e + } + assert(lits.size == 2) + assert(lits(0) >= min && lits(0) <= max) + assert(lits(1) >= min && lits(1) <= max) + assert(lits(0) == lits(1)) + } + + test("analyzer should replace current_date with literals") { + val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation()) + + val min = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val plan = Optimize.execute(in.analyze).asInstanceOf[Project] + val max = DateTimeUtils.millisToDays(System.currentTimeMillis()) + + val lits = new scala.collection.mutable.ArrayBuffer[Int] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[Int] + e + } + assert(lits.size == 2) + assert(lits(0) >= min && lits(0) <= max) + assert(lits(1) >= min && lits(1) <= max) + assert(lits(0) == lits(1)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index b998636909a7d..f9f3bd55aa578 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -75,8 +75,7 @@ class FilterPushdownSuite extends PlanTest { val correctAnswer = testRelation .select('a) - .groupBy('a)('a) - .select('a).analyze + .groupBy('a)('a).analyze comparePlans(optimized, correctAnswer) } @@ -91,8 +90,7 @@ class FilterPushdownSuite extends PlanTest { val correctAnswer = testRelation .select('a) - .groupBy('a)('a as 'c) - .select('c).analyze + .groupBy('a)('a as 'c).analyze comparePlans(optimized, correctAnswer) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 4b375de05e9e3..ca8d010090401 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution.datasources.parquet import java.net.URI -import java.util.{List => JList} import java.util.logging.{Logger => JLogger} +import java.util.{List => JList} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -32,24 +32,24 @@ import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.task.JobContextImpl -import org.apache.parquet.{Log => ApacheParquetLog} import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.hadoop._ import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.schema.MessageType +import org.apache.parquet.{Log => ApacheParquetLog} import org.slf4j.bridge.SLF4JBridgeHandler -import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.catalyst.util.LegacyTypeStringParser +import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} +import org.apache.spark.sql.execution.datasources.{PartitionSpec, _} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister { @@ -147,6 +147,12 @@ private[sql] class ParquetRelation( .get(ParquetRelation.METASTORE_SCHEMA) .map(DataType.fromJson(_).asInstanceOf[StructType]) + // If this relation is converted from a Hive metastore table, this method returns the name of the + // original Hive metastore table. + private[sql] def metastoreTableName: Option[TableIdentifier] = { + parameters.get(ParquetRelation.METASTORE_TABLE_NAME).map(SqlParser.parseTableIdentifier) + } + private lazy val metadataCache: MetadataCache = { val meta = new MetadataCache meta.refresh() diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index bd1a52e5f3303..afd2f611580fc 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -41,9 +41,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalColumnBatchSize = TestHive.conf.columnBatchSize private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning - def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) + def testCases: Seq[(String, File)] = { + hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) + } override def beforeAll() { + super.beforeAll() TestHive.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) @@ -68,10 +71,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // For debugging dump some statistics about how much time was spent in various optimizer rules. logWarning(RuleExecutor.dumpTimeSpent()) + super.afterAll() } /** A list of tests deemed out of scope currently and thus completely disregarded. */ - override def blackList = Seq( + override def blackList: Seq[String] = Seq( // These tests use hooks that are not on the classpath and thus break all subsequent execution. "hook_order", "hook_context_cs", @@ -106,7 +110,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "alter_merge", "alter_concatenate_indexed_table", "protectmode2", - //"describe_table", + // "describe_table", "describe_comment_nonascii", "create_merge_compressed", @@ -323,7 +327,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { * The set of tests that are believed to be working in catalyst. Tests not on whiteList or * blacklist are implicitly marked as ignored. */ - override def whiteList = Seq( + override def whiteList: Seq[String] = Seq( "add_part_exist", "add_part_multiple", "add_partition_no_whitelist", diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index 98bbdf0653c2a..bad3ca6da231f 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -104,6 +104,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) TestHive.reset() + super.afterAll() } ///////////////////////////////////////////////////////////////////////////// diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index bf3fe12d5c5d2..5b13dbe47370e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -668,7 +668,8 @@ private[hive] object HiveQl extends SparkQl with Logging { Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse( sys.error(s"Couldn't find function $functionName")) val functionClassName = functionInfo.getFunctionClass.getName - HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr)) + HiveGenericUDTF( + functionName, new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr)) case other => super.nodeToGenerator(node) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala new file mode 100644 index 0000000000000..1c910051faccf --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.util.concurrent.atomic.AtomicLong + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation +import org.apache.spark.sql.{DataFrame, SQLContext} + +/** + * A builder class used to convert a resolved logical plan into a SQL query string. Note that this + * all resolved logical plan are convertible. They either don't have corresponding SQL + * representations (e.g. logical plans that operate on local Scala collections), or are simply not + * supported by this builder (yet). + */ +class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging { + def this(df: DataFrame) = this(df.queryExecution.analyzed, df.sqlContext) + + def toSQL: Option[String] = { + val canonicalizedPlan = Canonicalizer.execute(logicalPlan) + val maybeSQL = try { + toSQL(canonicalizedPlan) + } catch { case cause: UnsupportedOperationException => + logInfo(s"Failed to build SQL query string because: ${cause.getMessage}") + None + } + + if (maybeSQL.isDefined) { + logDebug( + s"""Built SQL query string successfully from given logical plan: + | + |# Original logical plan: + |${logicalPlan.treeString} + |# Canonicalized logical plan: + |${canonicalizedPlan.treeString} + |# Built SQL query string: + |${maybeSQL.get} + """.stripMargin) + } else { + logDebug( + s"""Failed to build SQL query string from given logical plan: + | + |# Original logical plan: + |${logicalPlan.treeString} + |# Canonicalized logical plan: + |${canonicalizedPlan.treeString} + """.stripMargin) + } + + maybeSQL + } + + private def projectToSQL( + projectList: Seq[NamedExpression], + child: LogicalPlan, + isDistinct: Boolean): Option[String] = { + for { + childSQL <- toSQL(child) + listSQL = projectList.map(_.sql).mkString(", ") + maybeFrom = child match { + case OneRowRelation => " " + case _ => " FROM " + } + distinct = if (isDistinct) " DISTINCT " else " " + } yield s"SELECT$distinct$listSQL$maybeFrom$childSQL" + } + + private def aggregateToSQL( + groupingExprs: Seq[Expression], + aggExprs: Seq[Expression], + child: LogicalPlan): Option[String] = { + val aggSQL = aggExprs.map(_.sql).mkString(", ") + val groupingSQL = groupingExprs.map(_.sql).mkString(", ") + val maybeGroupBy = if (groupingSQL.isEmpty) "" else " GROUP BY " + val maybeFrom = child match { + case OneRowRelation => " " + case _ => " FROM " + } + + toSQL(child).map { childSQL => + s"SELECT $aggSQL$maybeFrom$childSQL$maybeGroupBy$groupingSQL" + } + } + + private def toSQL(node: LogicalPlan): Option[String] = node match { + case Distinct(Project(list, child)) => + projectToSQL(list, child, isDistinct = true) + + case Project(list, child) => + projectToSQL(list, child, isDistinct = false) + + case Aggregate(groupingExprs, aggExprs, child) => + aggregateToSQL(groupingExprs, aggExprs, child) + + case Limit(limit, child) => + for { + childSQL <- toSQL(child) + limitSQL = limit.sql + } yield s"$childSQL LIMIT $limitSQL" + + case Filter(condition, child) => + for { + childSQL <- toSQL(child) + whereOrHaving = child match { + case _: Aggregate => "HAVING" + case _ => "WHERE" + } + conditionSQL = condition.sql + } yield s"$childSQL $whereOrHaving $conditionSQL" + + case Union(left, right) => + for { + leftSQL <- toSQL(left) + rightSQL <- toSQL(right) + } yield s"$leftSQL UNION ALL $rightSQL" + + // ParquetRelation converted from Hive metastore table + case Subquery(alias, LogicalRelation(r: ParquetRelation, _)) => + // There seems to be a bug related to `ParquetConversions` analysis rule. The problem is + // that, the metastore database name and table name are not always propagated to converted + // `ParquetRelation` instances via data source options. Here we use subquery alias as a + // workaround. + Some(s"`$alias`") + + case Subquery(alias, child) => + toSQL(child).map(childSQL => s"($childSQL) AS $alias") + + case Join(left, right, joinType, condition) => + for { + leftSQL <- toSQL(left) + rightSQL <- toSQL(right) + joinTypeSQL = joinType.sql + conditionSQL = condition.map(" ON " + _.sql).getOrElse("") + } yield s"$leftSQL $joinTypeSQL JOIN $rightSQL$conditionSQL" + + case MetastoreRelation(database, table, alias) => + val aliasSQL = alias.map(a => s" AS `$a`").getOrElse("") + Some(s"`$database`.`$table`$aliasSQL") + + case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _)) + if orders.map(_.child) == partitionExprs => + for { + childSQL <- toSQL(child) + partitionExprsSQL = partitionExprs.map(_.sql).mkString(", ") + } yield s"$childSQL CLUSTER BY $partitionExprsSQL" + + case Sort(orders, global, child) => + for { + childSQL <- toSQL(child) + ordersSQL = orders.map { case SortOrder(e, dir) => s"${e.sql} ${dir.sql}" }.mkString(", ") + orderOrSort = if (global) "ORDER" else "SORT" + } yield s"$childSQL $orderOrSort BY $ordersSQL" + + case RepartitionByExpression(partitionExprs, child, _) => + for { + childSQL <- toSQL(child) + partitionExprsSQL = partitionExprs.map(_.sql).mkString(", ") + } yield s"$childSQL DISTRIBUTE BY $partitionExprsSQL" + + case OneRowRelation => + Some("") + + case _ => None + } + + object Canonicalizer extends RuleExecutor[LogicalPlan] { + override protected def batches: Seq[Batch] = Seq( + Batch("Canonicalizer", FixedPoint(100), + // The `WidenSetOperationTypes` analysis rule may introduce extra `Project`s over + // `Aggregate`s to perform type casting. This rule merges these `Project`s into + // `Aggregate`s. + ProjectCollapsing, + + // Used to handle other auxiliary `Project`s added by analyzer (e.g. + // `ResolveAggregateFunctions` rule) + RecoverScopingInfo + ) + ) + + object RecoverScopingInfo extends Rule[LogicalPlan] { + override def apply(tree: LogicalPlan): LogicalPlan = tree transform { + // This branch handles aggregate functions within HAVING clauses. For example: + // + // SELECT key FROM src GROUP BY key HAVING max(value) > "val_255" + // + // This kind of query results in query plans of the following form because of analysis rule + // `ResolveAggregateFunctions`: + // + // Project ... + // +- Filter ... + // +- Aggregate ... + // +- MetastoreRelation default, src, None + case plan @ Project(_, Filter(_, _: Aggregate)) => + wrapChildWithSubquery(plan) + + case plan @ Project(_, + _: Subquery | _: Filter | _: Join | _: MetastoreRelation | OneRowRelation | _: Limit + ) => plan + + case plan: Project => + wrapChildWithSubquery(plan) + } + + def wrapChildWithSubquery(project: Project): Project = project match { + case Project(projectList, child) => + val alias = SQLBuilder.newSubqueryName + val childAttributes = child.outputSet + val aliasedProjectList = projectList.map(_.transform { + case a: Attribute if childAttributes.contains(a) => + a.withQualifiers(alias :: Nil) + }.asInstanceOf[NamedExpression]) + + Project(aliasedProjectList, Subquery(alias, child)) + } + } + } +} + +object SQLBuilder { + private val nextSubqueryId = new AtomicLong(0) + + private def newSubqueryName: String = s"gen_subquery_${nextSubqueryId.getAndIncrement()}" +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index b1a6d0ab7df3c..e76c18fa528f3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -17,30 +17,26 @@ package org.apache.spark.sql.hive -import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.util.Try import org.apache.hadoop.hive.ql.exec._ -import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} -import org.apache.hadoop.hive.ql.udf.generic._ -import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper -import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory} +import org.apache.hadoop.hive.ql.udf.generic._ +import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions +import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory} import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.catalyst.util.sequenceOption +import org.apache.spark.sql.catalyst.{InternalRow, analysis} import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client.ClientWrapper import org.apache.spark.sql.types._ @@ -75,19 +71,19 @@ private[hive] class HiveFunctionRegistry( try { if (classOf[GenericUDFMacro].isAssignableFrom(functionInfo.getFunctionClass)) { HiveGenericUDF( - new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children) + name, new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children) } else if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveSimpleUDF(new HiveFunctionWrapper(functionClassName), children) + HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children) } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children) + HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children) } else if ( classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveUDAFFunction(new HiveFunctionWrapper(functionClassName), children) + HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children) } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { HiveUDAFFunction( - new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true) + name, new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true) } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - val udtf = HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children) + val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(functionClassName), children) udtf.elementTypes // Force it to check input data types. udtf } else { @@ -137,7 +133,8 @@ private[hive] class HiveFunctionRegistry( } } -private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) +private[hive] case class HiveSimpleUDF( + name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with CodegenFallback with Logging { override def deterministic: Boolean = isUDFDeterministic @@ -191,6 +188,8 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre override def toString: String = { s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } + + override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})" } // Adapter from Catalyst ExpressionResult to Hive DeferredObject @@ -205,7 +204,8 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp override def get(): AnyRef = wrap(func(), oi, dataType) } -private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) +private[hive] case class HiveGenericUDF( + name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with CodegenFallback with Logging { override def nullable: Boolean = true @@ -257,6 +257,8 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr override def toString: String = { s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } + + override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})" } /** @@ -271,6 +273,7 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr * user defined aggregations, which have clean semantics even in a partitioned execution. */ private[hive] case class HiveGenericUDTF( + name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Generator with HiveInspectors with CodegenFallback { @@ -336,6 +339,8 @@ private[hive] case class HiveGenericUDTF( override def toString: String = { s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } + + override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})" } /** @@ -343,6 +348,7 @@ private[hive] case class HiveGenericUDTF( * performance a lot. */ private[hive] case class HiveUDAFFunction( + name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression], isUDAFBridgeRequired: Boolean = false, @@ -427,5 +433,9 @@ private[hive] case class HiveUDAFFunction( override def supportsPartial: Boolean = false override val dataType: DataType = inspectorToDataType(returnInspector) -} + override def sql(isDistinct: Boolean): String = { + val distinct = if (isDistinct) "DISTINCT " else " " + s"$name($distinct${children.map(_.sql).mkString(", ")})" + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala new file mode 100644 index 0000000000000..3a6eb57add4e3 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.sql.Timestamp + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{If, Literal} + +class ExpressionSQLBuilderSuite extends SQLBuilderTest { + test("literal") { + checkSQL(Literal("foo"), "\"foo\"") + checkSQL(Literal("\"foo\""), "\"\\\"foo\\\"\"") + checkSQL(Literal(1: Byte), "CAST(1 AS TINYINT)") + checkSQL(Literal(2: Short), "CAST(2 AS SMALLINT)") + checkSQL(Literal(4: Int), "4") + checkSQL(Literal(8: Long), "CAST(8 AS BIGINT)") + checkSQL(Literal(1.5F), "CAST(1.5 AS FLOAT)") + checkSQL(Literal(2.5D), "2.5") + checkSQL( + Literal(Timestamp.valueOf("2016-01-01 00:00:00")), + "TIMESTAMP('2016-01-01 00:00:00.0')") + // TODO tests for decimals + } + + test("binary comparisons") { + checkSQL('a.int === 'b.int, "(`a` = `b`)") + checkSQL('a.int <=> 'b.int, "(`a` <=> `b`)") + checkSQL('a.int !== 'b.int, "(NOT (`a` = `b`))") + + checkSQL('a.int < 'b.int, "(`a` < `b`)") + checkSQL('a.int <= 'b.int, "(`a` <= `b`)") + checkSQL('a.int > 'b.int, "(`a` > `b`)") + checkSQL('a.int >= 'b.int, "(`a` >= `b`)") + + checkSQL('a.int in ('b.int, 'c.int), "(`a` IN (`b`, `c`))") + checkSQL('a.int in (1, 2), "(`a` IN (1, 2))") + + checkSQL('a.int.isNull, "(`a` IS NULL)") + checkSQL('a.int.isNotNull, "(`a` IS NOT NULL)") + } + + test("logical operators") { + checkSQL('a.boolean && 'b.boolean, "(`a` AND `b`)") + checkSQL('a.boolean || 'b.boolean, "(`a` OR `b`)") + checkSQL(!'a.boolean, "(NOT `a`)") + checkSQL(If('a.boolean, 'b.int, 'c.int), "(IF(`a`, `b`, `c`))") + } + + test("arithmetic expressions") { + checkSQL('a.int + 'b.int, "(`a` + `b`)") + checkSQL('a.int - 'b.int, "(`a` - `b`)") + checkSQL('a.int * 'b.int, "(`a` * `b`)") + checkSQL('a.int / 'b.int, "(`a` / `b`)") + checkSQL('a.int % 'b.int, "(`a` % `b`)") + + checkSQL(-'a.int, "(-`a`)") + checkSQL(-('a.int + 'b.int), "(-(`a` + `b`))") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala new file mode 100644 index 0000000000000..0e81acf532a03 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.functions._ + +class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { + import testImplicits._ + + protected override def beforeAll(): Unit = { + sqlContext.range(10).write.saveAsTable("t0") + + sqlContext + .range(10) + .select('id as 'key, concat(lit("val_"), 'id) as 'value) + .write + .saveAsTable("t1") + + sqlContext.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2") + } + + override protected def afterAll(): Unit = { + sql("DROP TABLE IF EXISTS t0") + sql("DROP TABLE IF EXISTS t1") + sql("DROP TABLE IF EXISTS t2") + } + + private def checkHiveQl(hiveQl: String): Unit = { + val df = sql(hiveQl) + val convertedSQL = new SQLBuilder(df).toSQL + + if (convertedSQL.isEmpty) { + fail( + s"""Cannot convert the following HiveQL query plan back to SQL query string: + | + |# Original HiveQL query string: + |$hiveQl + | + |# Resolved query plan: + |${df.queryExecution.analyzed.treeString} + """.stripMargin) + } + + val sqlString = convertedSQL.get + try { + checkAnswer(sql(sqlString), df) + } catch { case cause: Throwable => + fail( + s"""Failed to execute converted SQL string or got wrong answer: + | + |# Converted SQL query string: + |$sqlString + | + |# Original HiveQL query string: + |$hiveQl + | + |# Resolved query plan: + |${df.queryExecution.analyzed.treeString} + """.stripMargin, + cause) + } + } + + test("in") { + checkHiveQl("SELECT id FROM t0 WHERE id IN (1, 2, 3)") + } + + test("aggregate function in having clause") { + checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key HAVING MAX(key) > 0") + } + + test("aggregate function in order by clause") { + checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY MAX(key)") + } + + // TODO Fix name collision introduced by ResolveAggregateFunction analysis rule + // When there are multiple aggregate functions in ORDER BY clause, all of them are extracted into + // Aggregate operator and aliased to the same name "aggOrder". This is OK for normal query + // execution since these aliases have different expression ID. But this introduces name collision + // when converting resolved plans back to SQL query strings as expression IDs are stripped. + ignore("aggregate function in order by clause with multiple order keys") { + checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY key, MAX(key)") + } + + test("type widening in union") { + checkHiveQl("SELECT id FROM t0 UNION ALL SELECT CAST(id AS INT) AS id FROM t0") + } + + test("case") { + checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 WHEN id % 2 = 0 THEN 1 END FROM t0") + } + + test("case with else") { + checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 ELSE 1 END FROM t0") + } + + test("case with key") { + checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' END FROM t0") + } + + test("case with key and else") { + checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' ELSE 'baz' END FROM t0") + } + + test("select distinct without aggregate functions") { + checkHiveQl("SELECT DISTINCT id FROM t0") + } + + test("cluster by") { + checkHiveQl("SELECT id FROM t0 CLUSTER BY id") + } + + test("distribute by") { + checkHiveQl("SELECT id FROM t0 DISTRIBUTE BY id") + } + + test("distribute by with sort by") { + checkHiveQl("SELECT id FROM t0 DISTRIBUTE BY id SORT BY id") + } + + test("distinct aggregation") { + checkHiveQl("SELECT COUNT(DISTINCT id) FROM t0") + } + + // TODO Enable this + // Query plans transformed by DistinctAggregationRewriter are not recognized yet + ignore("distinct and non-distinct aggregation") { + checkHiveQl("SELECT a, COUNT(DISTINCT b), COUNT(DISTINCT c), SUM(d) FROM t2 GROUP BY a") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala new file mode 100644 index 0000000000000..cf4a3fdd88806 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.{DataFrame, QueryTest} + +abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton { + protected def checkSQL(e: Expression, expectedSQL: String): Unit = { + val actualSQL = e.sql + try { + assert(actualSQL === expectedSQL) + } catch { + case cause: Throwable => + fail( + s"""Wrong SQL generated for the following expression: + | + |${e.prettyName} + | + |$cause + """.stripMargin) + } + } + + protected def checkSQL(plan: LogicalPlan, expectedSQL: String): Unit = { + val maybeSQL = new SQLBuilder(plan, hiveContext).toSQL + + if (maybeSQL.isEmpty) { + fail( + s"""Cannot convert the following logical query plan to SQL: + | + |${plan.treeString} + """.stripMargin) + } + + val actualSQL = maybeSQL.get + + try { + assert(actualSQL === expectedSQL) + } catch { + case cause: Throwable => + fail( + s"""Wrong SQL generated for the following logical query plan: + | + |${plan.treeString} + | + |$cause + """.stripMargin) + } + + checkAnswer(sqlContext.sql(actualSQL), new DataFrame(sqlContext, plan)) + } + + protected def checkSQL(df: DataFrame, expectedSQL: String): Unit = { + checkSQL(df.queryExecution.analyzed, expectedSQL) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index d7e8ebc8d312f..57358a07840e2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -27,9 +27,10 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.{ExplainCommand, SetCommand} import org.apache.spark.sql.execution.datasources.DescribeCommand +import org.apache.spark.sql.execution.{ExplainCommand, SetCommand} import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.{InsertIntoHiveTable => LogicalInsertIntoHiveTable, SQLBuilder} /** * Allows the creations of tests that execute the same query against both hive @@ -130,6 +131,28 @@ abstract class HiveComparisonTest new java.math.BigInteger(1, digest.digest).toString(16) } + /** Used for testing [[SQLBuilder]] */ + private var numConvertibleQueries: Int = 0 + private var numTotalQueries: Int = 0 + + override protected def afterAll(): Unit = { + logInfo({ + val percentage = if (numTotalQueries > 0) { + numConvertibleQueries.toDouble / numTotalQueries * 100 + } else { + 0D + } + + s"""SQLBuiler statistics: + |- Total query number: $numTotalQueries + |- Number of convertible queries: $numConvertibleQueries + |- Percentage of convertible queries: $percentage% + """.stripMargin + }) + + super.afterAll() + } + protected def prepareAnswer( hiveQuery: TestHive.type#QueryExecution, answer: Seq[String]): Seq[String] = { @@ -372,8 +395,49 @@ abstract class HiveComparisonTest // Run w/ catalyst val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) => - val query = new TestHive.QueryExecution(queryString) - try { (query, prepareAnswer(query, query.stringResult())) } catch { + var query: TestHive.QueryExecution = null + try { + query = { + val originalQuery = new TestHive.QueryExecution(queryString) + val containsCommands = originalQuery.analyzed.collectFirst { + case _: Command => () + case _: LogicalInsertIntoHiveTable => () + }.nonEmpty + + if (containsCommands) { + originalQuery + } else { + numTotalQueries += 1 + new SQLBuilder(originalQuery.analyzed, TestHive).toSQL.map { sql => + numConvertibleQueries += 1 + logInfo( + s""" + |### Running SQL generation round-trip test {{{ + |${originalQuery.analyzed.treeString} + |Original SQL: + |$queryString + | + |Generated SQL: + |$sql + |}}} + """.stripMargin.trim) + new TestHive.QueryExecution(sql) + }.getOrElse { + logInfo( + s""" + |### Cannot convert the following logical plan back to SQL {{{ + |${originalQuery.analyzed.treeString} + |Original SQL: + |$queryString + |}}} + """.stripMargin.trim) + originalQuery + } + } + } + + (query, prepareAnswer(query, query.stringResult())) + } catch { case e: Throwable => val errorMessage = s""" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 98e22c2e2c1b0..5e9e77c4fb925 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -60,6 +60,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) sql("DROP TEMPORARY FUNCTION udtf_count2") + super.afterAll() } test("SPARK-4908: concurrent hive native commands") {