diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
index 4afce3090f739..ffb859662c8ea 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
@@ -1,9 +1,9 @@
/**
- Licensed to the Apache Software Foundation (ASF) under one or more
- contributor license agreements. See the NOTICE file distributed with
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
- (the "License"); you may not use this file except in compliance with
+ (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
@@ -580,7 +580,7 @@ import java.util.HashMap;
return header;
}
-
+
@Override
public String getErrorMessage(RecognitionException e, String[] tokenNames) {
String msg = null;
@@ -617,7 +617,7 @@ import java.util.HashMap;
}
return msg;
}
-
+
public void pushMsg(String msg, RecognizerSharedState state) {
// ANTLR generated code does not wrap the @init code with this backtracking check,
// even if the matching @after has it. If we have parser rules that are doing
@@ -637,7 +637,7 @@ import java.util.HashMap;
// counter to generate unique union aliases
private int aliasCounter;
private String generateUnionAlias() {
- return "_u" + (++aliasCounter);
+ return "u_" + (++aliasCounter);
}
private char [] excludedCharForColumnName = {'.', ':'};
private boolean containExcludedCharForCreateTableColumnName(String input) {
@@ -1233,7 +1233,7 @@ alterTblPartitionStatementSuffixSkewedLocation
: KW_SET KW_SKEWED KW_LOCATION skewedLocations
-> ^(TOK_ALTERTABLE_SKEWED_LOCATION skewedLocations)
;
-
+
skewedLocations
@init { pushMsg("skewed locations", state); }
@after { popMsg(state); }
@@ -1262,7 +1262,7 @@ alterStatementSuffixLocation
-> ^(TOK_ALTERTABLE_LOCATION $newLoc)
;
-
+
alterStatementSuffixSkewedby
@init {pushMsg("alter skewed by statement", state);}
@after{popMsg(state);}
@@ -1334,10 +1334,10 @@ tabTypeExpr
(identifier (DOT^
(
(KW_ELEM_TYPE) => KW_ELEM_TYPE
- |
+ |
(KW_KEY_TYPE) => KW_KEY_TYPE
- |
- (KW_VALUE_TYPE) => KW_VALUE_TYPE
+ |
+ (KW_VALUE_TYPE) => KW_VALUE_TYPE
| identifier
))*
)?
@@ -1374,7 +1374,7 @@ descStatement
analyzeStatement
@init { pushMsg("analyze statement", state); }
@after { popMsg(state); }
- : KW_ANALYZE KW_TABLE (parttype=tableOrPartition) KW_COMPUTE KW_STATISTICS ((noscan=KW_NOSCAN) | (partialscan=KW_PARTIALSCAN)
+ : KW_ANALYZE KW_TABLE (parttype=tableOrPartition) KW_COMPUTE KW_STATISTICS ((noscan=KW_NOSCAN) | (partialscan=KW_PARTIALSCAN)
| (KW_FOR KW_COLUMNS (statsColumnName=columnNameList)?))?
-> ^(TOK_ANALYZE $parttype $noscan? $partialscan? KW_COLUMNS? $statsColumnName?)
;
@@ -1387,7 +1387,7 @@ showStatement
| KW_SHOW KW_COLUMNS (KW_FROM|KW_IN) tableName ((KW_FROM|KW_IN) db_name=identifier)?
-> ^(TOK_SHOWCOLUMNS tableName $db_name?)
| KW_SHOW KW_FUNCTIONS (KW_LIKE showFunctionIdentifier|showFunctionIdentifier)? -> ^(TOK_SHOWFUNCTIONS KW_LIKE? showFunctionIdentifier?)
- | KW_SHOW KW_PARTITIONS tabName=tableName partitionSpec? -> ^(TOK_SHOWPARTITIONS $tabName partitionSpec?)
+ | KW_SHOW KW_PARTITIONS tabName=tableName partitionSpec? -> ^(TOK_SHOWPARTITIONS $tabName partitionSpec?)
| KW_SHOW KW_CREATE (
(KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) db_name=identifier -> ^(TOK_SHOW_CREATEDATABASE $db_name)
|
@@ -1396,7 +1396,7 @@ showStatement
| KW_SHOW KW_TABLE KW_EXTENDED ((KW_FROM|KW_IN) db_name=identifier)? KW_LIKE showStmtIdentifier partitionSpec?
-> ^(TOK_SHOW_TABLESTATUS showStmtIdentifier $db_name? partitionSpec?)
| KW_SHOW KW_TBLPROPERTIES tableName (LPAREN prptyName=StringLiteral RPAREN)? -> ^(TOK_SHOW_TBLPROPERTIES tableName $prptyName?)
- | KW_SHOW KW_LOCKS
+ | KW_SHOW KW_LOCKS
(
(KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) (isExtended=KW_EXTENDED)? -> ^(TOK_SHOWDBLOCKS $dbName $isExtended?)
|
@@ -1509,7 +1509,7 @@ showCurrentRole
setRole
@init {pushMsg("set role", state);}
@after {popMsg(state);}
- : KW_SET KW_ROLE
+ : KW_SET KW_ROLE
(
(KW_ALL) => (all=KW_ALL) -> ^(TOK_SHOW_SET_ROLE Identifier[$all.text])
|
@@ -1964,7 +1964,7 @@ columnNameOrderList
skewedValueElement
@init { pushMsg("skewed value element", state); }
@after { popMsg(state); }
- :
+ :
skewedColumnValues
| skewedColumnValuePairList
;
@@ -1978,8 +1978,8 @@ skewedColumnValuePairList
skewedColumnValuePair
@init { pushMsg("column value pair", state); }
@after { popMsg(state); }
- :
- LPAREN colValues=skewedColumnValues RPAREN
+ :
+ LPAREN colValues=skewedColumnValues RPAREN
-> ^(TOK_TABCOLVALUES $colValues)
;
@@ -1999,11 +1999,11 @@ skewedColumnValue
skewedValueLocationElement
@init { pushMsg("skewed value location element", state); }
@after { popMsg(state); }
- :
+ :
skewedColumnValue
| skewedColumnValuePair
;
-
+
columnNameOrder
@init { pushMsg("column name order", state); }
@after { popMsg(state); }
@@ -2116,7 +2116,7 @@ unionType
@after { popMsg(state); }
: KW_UNIONTYPE LESSTHAN colTypeList GREATERTHAN -> ^(TOK_UNIONTYPE colTypeList)
;
-
+
setOperator
@init { pushMsg("set operator", state); }
@after { popMsg(state); }
@@ -2168,7 +2168,7 @@ fromStatement[boolean topLevel]
{adaptor.create(Identifier, generateUnionAlias())}
)
)
- ^(TOK_INSERT
+ ^(TOK_INSERT
^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE))
^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF))
)
@@ -2391,8 +2391,8 @@ setColumnsClause
KW_SET columnAssignmentClause (COMMA columnAssignmentClause)* -> ^(TOK_SET_COLUMNS_CLAUSE columnAssignmentClause* )
;
-/*
- UPDATE
+/*
+ UPDATE
SET col1 = val1, col2 = val2... WHERE ...
*/
updateStatement
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e362b55d80cd1..8a33af8207350 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -86,8 +86,7 @@ class Analyzer(
HiveTypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Nondeterministic", Once,
- PullOutNondeterministic,
- ComputeCurrentTime),
+ PullOutNondeterministic),
Batch("UDF", Once,
HandleNullInputsForUDF),
Batch("Cleanup", fixedPoint,
@@ -1229,23 +1228,6 @@ object CleanupAliases extends Rule[LogicalPlan] {
}
}
-/**
- * Computes the current date and time to make sure we return the same result in a single query.
- */
-object ComputeCurrentTime extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = {
- val dateExpr = CurrentDate()
- val timeExpr = CurrentTimestamp()
- val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType)
- val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType)
-
- plan transformAllExpressions {
- case CurrentDate() => currentDate
- case CurrentTimestamp() => currentTime
- }
- }
-}
-
/**
* Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
index e8b2fcf819bf6..a8f89ce6de457 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
@@ -110,7 +110,9 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog {
// If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
// properly qualified with this alias.
- alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers)
+ alias
+ .map(a => Subquery(a, tableWithQualifiers))
+ .getOrElse(tableWithQualifiers)
}
override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index d82d3edae4e38..6f199cfc5d8cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -931,6 +931,14 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
$evPrim = $result.copy();
"""
}
+
+ override def sql: String = dataType match {
+    // HiveQL doesn't allow casting to complex types. For logical plans translated from HiveQL,
+    // this type of casting can only be introduced by the analyzer, and can be omitted when
+    // converting back to SQL query string.
+ case _: ArrayType | _: MapType | _: StructType => child.sql
+ case _ => s"CAST(${child.sql} AS ${dataType.sql})"
+ }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 6a9c12127d367..d6219514b752b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -18,9 +18,10 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, TypeCheckResult, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.util.sequenceOption
import org.apache.spark.sql.types._
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -223,6 +224,15 @@ abstract class Expression extends TreeNode[Expression] {
protected def toCommentSafeString: String = this.toString
.replace("*/", "\\*\\/")
.replace("\\u", "\\\\u")
+
+ /**
+ * Returns SQL representation of this expression. For expressions that don't have a SQL
+ * representation (e.g. `ScalaUDF`), this method should throw an `UnsupportedOperationException`.
+ */
+ @throws[UnsupportedOperationException](cause = "Expression doesn't have a SQL representation")
+ def sql: String = throw new UnsupportedOperationException(
+ s"Cannot map expression $this to its SQL representation"
+ )
}
@@ -356,6 +366,8 @@ abstract class UnaryExpression extends Expression {
"""
}
}
+
+ override def sql: String = s"($prettyName(${child.sql}))"
}
@@ -456,6 +468,8 @@ abstract class BinaryExpression extends Expression {
"""
}
}
+
+ override def sql: String = s"$prettyName(${left.sql}, ${right.sql})"
}
@@ -492,6 +506,8 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
TypeCheckResult.TypeCheckSuccess
}
}
+
+ override def sql: String = s"(${left.sql} $symbol ${right.sql})"
}
@@ -593,4 +609,9 @@ abstract class TernaryExpression extends Expression {
"""
}
}
+
+ override def sql: String = {
+ val childrenSQL = children.map(_.sql).mkString(", ")
+ s"$prettyName($childrenSQL)"
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
index f33833c3918df..827dce8af100e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
@@ -49,4 +49,5 @@ case class InputFileName() extends LeafExpression with Nondeterministic {
"org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();"
}
+ override def sql: String = prettyName
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
index d0b78e15d99d1..94f8801dec369 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
@@ -78,4 +78,8 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with
$countTerm++;
"""
}
+
+ override def prettyName: String = "monotonically_increasing_id"
+
+ override def sql: String = s"$prettyName()"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index 3add722da7816..1cb1b9da3049b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -24,9 +24,17 @@ import org.apache.spark.sql.types._
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.BinaryPrefixComparator
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator
-abstract sealed class SortDirection
-case object Ascending extends SortDirection
-case object Descending extends SortDirection
+abstract sealed class SortDirection {
+ def sql: String
+}
+
+case object Ascending extends SortDirection {
+ override def sql: String = "ASC"
+}
+
+case object Descending extends SortDirection {
+ override def sql: String = "DESC"
+}
/**
* An expression that can be used to sort a tuple. This class extends expression primarily so that
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index b47f32d1768b9..ddd99c51ab0c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, GeneratedExpressionCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.util.sequenceOption
import org.apache.spark.sql.types._
/** The mode of an [[AggregateFunction]]. */
@@ -93,11 +94,13 @@ private[sql] case class AggregateExpression(
override def prettyString: String = aggregateFunction.prettyString
- override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)"
+ override def toString: String = s"($aggregateFunction,mode=$mode,isDistinct=$isDistinct)"
+
+ override def sql: String = aggregateFunction.sql(isDistinct)
}
/**
- * AggregateFunction2 is the superclass of two aggregation function interfaces:
+ * AggregateFunction is the superclass of two aggregation function interfaces:
*
* - [[ImperativeAggregate]] is for aggregation functions that are specified in terms of
* initialize(), update(), and merge() functions that operate on Row-based aggregation buffers.
@@ -163,6 +166,11 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu
def toAggregateExpression(isDistinct: Boolean): AggregateExpression = {
AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct)
}
+
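+  /**
+   * Returns the SQL representation of this aggregate function as a function call over its
+   * children, adding the DISTINCT keyword when `isDistinct` is true.
+   */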
+ def sql(isDistinct: Boolean): String = {
+ val distinct = if (isDistinct) "DISTINCT " else " "
+ s"$prettyName($distinct${children.map(_.sql).mkString(", ")})"
+ }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 61a17fd7db0fe..7bd851c059d0e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -54,6 +54,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp
numeric.negate(input)
}
}
+
+ override def sql: String = s"(-${child.sql})"
}
case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes {
@@ -67,6 +69,8 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects
defineCodeGen(ctx, ev, c => c)
protected override def nullSafeEval(input: Any): Any = input
+
+ override def sql: String = s"(+${child.sql})"
}
/**
@@ -91,6 +95,8 @@ case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes
}
protected override def nullSafeEval(input: Any): Any = numeric.abs(input)
+
+ override def sql: String = s"$prettyName(${child.sql})"
}
abstract class BinaryArithmetic extends BinaryOperator {
@@ -513,4 +519,6 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
val r = a % n
if (r.compare(Decimal.ZERO) < 0) {(r + n) % n} else r
}
+
+ override def sql: String = s"$prettyName(${left.sql}, ${right.sql})"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 9c73239f67ff2..5bd97cc7467ab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -130,6 +130,8 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String]
}
})
}
+
+ override def sql: String = child.sql + s".`${childSchema(ordinal).name}`"
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index f79c8676fb58c..19da849d2bec9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.catalyst.util.{sequenceOption, TypeUtils}
import org.apache.spark.sql.types._
@@ -74,6 +74,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
}
override def toString: String = s"if ($predicate) $trueValue else $falseValue"
+
+ override def sql: String = s"(IF(${predicate.sql}, ${trueValue.sql}, ${falseValue.sql}))"
}
trait CaseWhenLike extends Expression {
@@ -110,7 +112,7 @@ trait CaseWhenLike extends Expression {
override def nullable: Boolean = {
// If no value is nullable and no elseValue is provided, the whole statement defaults to null.
- thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true))
+ thenList.exists(_.nullable) || elseValue.map(_.nullable).getOrElse(true)
}
}
@@ -206,6 +208,23 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
case Seq(elseValue) => s" ELSE $elseValue"
}.mkString
}
+
+ override def sql: String = {
+ val branchesSQL = branches.map(_.sql)
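+    // `branches` stores WHEN conditions and THEN values as a flattened sequence; an odd
+    // length means the last element is the ELSE expression.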
+ val (cases, maybeElse) = if (branches.length % 2 == 0) {
+ (branchesSQL, None)
+ } else {
+ (branchesSQL.init, Some(branchesSQL.last))
+ }
+
+ val head = s"CASE "
+ val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END"
+ val body = cases.grouped(2).map {
+ case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr"
+ }.mkString(" ")
+
+ head + body + tail
+ }
}
// scalastyle:off
@@ -310,6 +329,24 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
case Seq(elseValue) => s" ELSE $elseValue"
}.mkString
}
+
+ override def sql: String = {
+ val keySQL = key.sql
+ val branchesSQL = branches.map(_.sql)
+ val (cases, maybeElse) = if (branches.length % 2 == 0) {
+ (branchesSQL, None)
+ } else {
+ (branchesSQL.init, Some(branchesSQL.last))
+ }
+
+ val head = s"CASE $keySQL "
+ val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END"
+ val body = cases.grouped(2).map {
+ case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr"
+ }.mkString(" ")
+
+ head + body + tail
+ }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 3d65946a1bc65..17f1df06f2fad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -44,6 +44,8 @@ case class CurrentDate() extends LeafExpression with CodegenFallback {
override def eval(input: InternalRow): Any = {
DateTimeUtils.millisToDays(System.currentTimeMillis())
}
+
+ override def prettyName: String = "current_date"
}
/**
@@ -61,6 +63,8 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback {
override def eval(input: InternalRow): Any = {
System.currentTimeMillis() * 1000L
}
+
+ override def prettyName: String = "current_timestamp"
}
/**
@@ -85,6 +89,8 @@ case class DateAdd(startDate: Expression, days: Expression)
s"""${ev.value} = $sd + $d;"""
})
}
+
+ override def prettyName: String = "date_add"
}
/**
@@ -108,6 +114,8 @@ case class DateSub(startDate: Expression, days: Expression)
s"""${ev.value} = $sd - $d;"""
})
}
+
+ override def prettyName: String = "date_sub"
}
case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
@@ -309,6 +317,8 @@ case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends Unix
def this(time: Expression) = {
this(time, Literal("yyyy-MM-dd HH:mm:ss"))
}
+
+ override def prettyName: String = "to_unix_timestamp"
}
/**
@@ -332,6 +342,8 @@ case class UnixTimestamp(timeExp: Expression, format: Expression) extends UnixTi
def this() = {
this(CurrentTimestamp())
}
+
+ override def prettyName: String = "unix_timestamp"
}
abstract class UnixTime extends BinaryExpression with ExpectsInputTypes {
@@ -437,6 +449,8 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes {
"""
}
}
+
+ override def prettyName: String = "unix_time"
}
/**
@@ -451,6 +465,8 @@ case class FromUnixTime(sec: Expression, format: Expression)
override def left: Expression = sec
override def right: Expression = format
+ override def prettyName: String = "from_unixtime"
+
def this(unix: Expression) = {
this(unix, Literal("yyyy-MM-dd HH:mm:ss"))
}
@@ -733,6 +749,8 @@ case class AddMonths(startDate: Expression, numMonths: Expression)
s"""$dtu.dateAddMonths($sd, $m)"""
})
}
+
+ override def prettyName: String = "add_months"
}
/**
@@ -758,6 +776,8 @@ case class MonthsBetween(date1: Expression, date2: Expression)
s"""$dtu.monthsBetween($l, $r)"""
})
}
+
+ override def prettyName: String = "months_between"
}
/**
@@ -823,6 +843,8 @@ case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastIn
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, d => d)
}
+
+ override def prettyName: String = "to_date"
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index c54bcdd774021..5f8b544edb511 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -73,6 +73,7 @@ case class PromotePrecision(child: Expression) extends UnaryExpression {
override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx)
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = ""
override def prettyName: String = "promote_precision"
+ override def sql: String = child.sql
}
/**
@@ -107,4 +108,6 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary
}
override def toString: String = s"CheckOverflow($child, $dataType)"
+
+ override def sql: String = child.sql
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 672cc9c45e0af..0eb915fdc1691 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -21,9 +21,9 @@ import java.sql.{Date, Timestamp}
import org.json4s.JsonAST._
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types._
@@ -214,6 +214,41 @@ case class Literal protected (value: Any, dataType: DataType)
}
}
}
+
+ override def sql: String = (value, dataType) match {
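+    // A null of NullType or of a complex type becomes a bare NULL; nulls of other types are
+    // wrapped in a CAST so the literal keeps its data type.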
+ case (_, NullType | _: ArrayType | _: MapType | _: StructType) if value == null =>
+ "NULL"
+
+ case _ if value == null =>
+ s"CAST(NULL AS ${dataType.sql})"
+
+ case (v: UTF8String, StringType) =>
+ // Escapes all backslashes and double quotes.
+ "\"" + v.toString.replace("\\", "\\\\").replace("\"", "\\\"") + "\""
+
+ case (v: Byte, ByteType) =>
+ s"CAST($v AS ${ByteType.simpleString.toUpperCase})"
+
+ case (v: Short, ShortType) =>
+ s"CAST($v AS ${ShortType.simpleString.toUpperCase})"
+
+ case (v: Long, LongType) =>
+ s"CAST($v AS ${LongType.simpleString.toUpperCase})"
+
+ case (v: Float, FloatType) =>
+ s"CAST($v AS ${FloatType.simpleString.toUpperCase})"
+
+ case (v: Decimal, DecimalType.Fixed(precision, scale)) =>
+ s"CAST($v AS ${DecimalType.simpleString.toUpperCase}($precision, $scale))"
+
+ case (v: Int, DateType) =>
+ s"DATE '${DateTimeUtils.toJavaDate(v)}'"
+
+ case (v: Long, TimestampType) =>
+ s"TIMESTAMP('${DateTimeUtils.toJavaTimestamp(v)}')"
+
+ case _ => value.toString
+ }
}
// TODO: Specialize
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 002f5929cc26b..66d8631a846ab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -70,6 +70,8 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)")
}
+
+ override def sql: String = s"$name(${child.sql})"
}
abstract class UnaryLogExpression(f: Double => Double, name: String)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 6697d463614d5..0f229db2b08a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -220,4 +220,8 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
final int ${ev.value} = ${unsafeRow.value}.hashCode($seed);
"""
}
+
+ override def prettyName: String = "hash"
+
+ override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")}, $seed)"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index eefd9c7482553..eee708cb02f9d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -164,6 +164,12 @@ case class Alias(child: Expression, name: String)(
explicitMetadata == a.explicitMetadata
case _ => false
}
+
+ override def sql: String = {
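+    // Qualifiers (e.g. a table name) are rendered as a dot-separated, backquoted prefix.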
+ val qualifiersString =
+ if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".")
+ s"${child.sql} AS $qualifiersString`$name`"
+ }
}
/**
@@ -271,6 +277,12 @@ case class AttributeReference(
// Since the expression id is not in the first constructor it is missing from the default
// tree string.
override def simpleString: String = s"$name#${exprId.id}: ${dataType.simpleString}"
+
+ override def sql: String = {
+ val qualifiersString =
+ if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".")
+ s"$qualifiersString`$name`"
+ }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index df4747d4e6f7a..89aec2b20fd0c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -83,6 +83,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
"""
}.mkString("\n")
}
+
+ override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})"
}
@@ -193,6 +195,8 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
ev.value = eval.isNull
eval.code
}
+
+ override def sql: String = s"(${child.sql} IS NULL)"
}
@@ -212,6 +216,8 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
ev.value = s"(!(${eval.isNull}))"
eval.code
}
+
+ override def sql: String = s"(${child.sql} IS NOT NULL)"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 304b438c84ba4..bca12a8d21023 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -101,6 +101,8 @@ case class Not(child: Expression)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, c => s"!($c)")
}
+
+ override def sql: String = s"(NOT ${child.sql})"
}
@@ -176,6 +178,13 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate
}
"""
}
+
+ override def sql: String = {
+ val childrenSQL = children.map(_.sql)
+ val valueSQL = childrenSQL.head
+ val listSQL = childrenSQL.tail.mkString(", ")
+ s"($valueSQL IN ($listSQL))"
+ }
}
/**
@@ -226,6 +235,12 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
}
"""
}
+
+ override def sql: String = {
+ val valueSQL = child.sql
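+    // Wrap each element in a Literal so values are rendered as quoted SQL literals.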
+ val listSQL = hset.toSeq.map(Literal(_).sql).mkString(", ")
+ s"($valueSQL IN ($listSQL))"
+ }
}
case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate {
@@ -274,6 +289,8 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
}
"""
}
+
+ override def sql: String = s"(${left.sql} AND ${right.sql})"
}
@@ -323,6 +340,8 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
}
"""
}
+
+ override def sql: String = s"(${left.sql} OR ${right.sql})"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index 8bde8cb9fe876..8de47e9ddc28d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -49,6 +49,9 @@ abstract class RDG extends LeafExpression with Nondeterministic {
override def nullable: Boolean = false
override def dataType: DataType = DoubleType
+
+ // NOTE: Even if the user doesn't provide a seed, Spark SQL adds a default seed.
+ override def sql: String = s"$prettyName($seed)"
}
/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index adef6050c3565..db266639b8560 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -59,6 +59,8 @@ trait StringRegexExpression extends ImplicitCastInputTypes {
matches(regex, input1.asInstanceOf[UTF8String].toString)
}
}
+
+ override def sql: String = s"${left.sql} ${prettyName.toUpperCase} ${right.sql}"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 50c8b9d59847e..931f752b4dc1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -23,6 +23,7 @@ import java.util.{HashMap, Locale, Map => JMap}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.catalyst.util.sequenceOption
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@@ -61,6 +62,8 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
}
"""
}
+
+ override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})"
}
@@ -153,6 +156,8 @@ case class ConcatWs(children: Seq[Expression])
"""
}
}
+
+ override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})"
}
trait String2StringExpression extends ImplicitCastInputTypes {
@@ -292,24 +297,24 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
val termDict = ctx.freshName("dict")
val classNameDict = classOf[JMap[Character, Character]].getCanonicalName
- ctx.addMutableState("UTF8String", termLastMatching, s"${termLastMatching} = null;")
- ctx.addMutableState("UTF8String", termLastReplace, s"${termLastReplace} = null;")
- ctx.addMutableState(classNameDict, termDict, s"${termDict} = null;")
+ ctx.addMutableState("UTF8String", termLastMatching, s"$termLastMatching = null;")
+ ctx.addMutableState("UTF8String", termLastReplace, s"$termLastReplace = null;")
+ ctx.addMutableState(classNameDict, termDict, s"$termDict = null;")
nullSafeCodeGen(ctx, ev, (src, matching, replace) => {
val check = if (matchingExpr.foldable && replaceExpr.foldable) {
- s"${termDict} == null"
+ s"$termDict == null"
} else {
- s"!${matching}.equals(${termLastMatching}) || !${replace}.equals(${termLastReplace})"
+ s"!$matching.equals($termLastMatching) || !$replace.equals($termLastReplace)"
}
s"""if ($check) {
// Not all of them is literal or matching or replace value changed
- ${termLastMatching} = ${matching}.clone();
- ${termLastReplace} = ${replace}.clone();
- ${termDict} = org.apache.spark.sql.catalyst.expressions.StringTranslate
- .buildDict(${termLastMatching}, ${termLastReplace});
+ $termLastMatching = $matching.clone();
+ $termLastReplace = $replace.clone();
+ $termDict = org.apache.spark.sql.catalyst.expressions.StringTranslate
+ .buildDict($termLastMatching, $termLastReplace);
}
- ${ev.value} = ${src}.translate(${termDict});
+ ${ev.value} = $src.translate($termDict);
"""
})
}
@@ -340,6 +345,8 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi
}
override def dataType: DataType = IntegerType
+
+ override def prettyName: String = "find_in_set"
}
/**
@@ -832,7 +839,6 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn
org.apache.commons.codec.binary.Base64.encodeBase64($child));
"""})
}
-
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 0b1c74293bb8b..f8121a733a8d2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -37,6 +37,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
// SubQueries are only needed for analysis and can be removed before execution.
Batch("Remove SubQueries", FixedPoint(100),
EliminateSubQueries) ::
+ Batch("Compute Current Time", Once,
+ ComputeCurrentTime) ::
Batch("Aggregate", FixedPoint(100),
ReplaceDistinctWithAggregate,
RemoveLiteralFromGroupExpressions) ::
@@ -333,6 +335,39 @@ object ProjectCollapsing extends Rule[LogicalPlan] {
)
Project(cleanedProjection, child)
}
+
+ // TODO Eliminate duplicate code
+ // This clause is identical to the one above except that the inner operator is an `Aggregate`
+ // rather than a `Project`.
+ case p @ Project(projectList1, agg @ Aggregate(_, projectList2, child)) =>
+ // Create a map of Aliases to their values from the child projection.
+ // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
+ val aliasMap = AttributeMap(projectList2.collect {
+ case a: Alias => (a.toAttribute, a)
+ })
+
+ // We only collapse these two Projects if their overlapped expressions are all
+ // deterministic.
+ val hasNondeterministic = projectList1.exists(_.collect {
+ case a: Attribute if aliasMap.contains(a) => aliasMap(a).child
+ }.exists(!_.deterministic))
+
+ if (hasNondeterministic) {
+ p
+ } else {
+ // Substitute any attributes that are produced by the child projection, so that we safely
+ // eliminate it.
+ // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...'
+ // TODO: Fix TransformBase to avoid the cast below.
+ val substitutedProjection = projectList1.map(_.transform {
+ case a: Attribute => aliasMap.getOrElse(a, a)
+ }).asInstanceOf[Seq[NamedExpression]]
+ // collapse 2 projects may introduce unnecessary Aliases, trim them here.
+ val cleanedProjection = substitutedProjection.map(p =>
+ CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
+ )
+ agg.copy(aggregateExpressions = cleanedProjection)
+ }
}
}
@@ -976,3 +1011,20 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
a.copy(groupingExpressions = newGrouping)
}
}
+
+/**
+ * Computes the current date and time to make sure we return the same result in a single query.
+ */
+object ComputeCurrentTime extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ val dateExpr = CurrentDate()
+ val timeExpr = CurrentTimestamp()
+ val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType)
+ val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType)
+
+ plan transformAllExpressions {
+ case CurrentDate() => currentDate
+ case CurrentTimestamp() => currentTime
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
index 77dec7ca6e2b5..a5f6764aef7ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
@@ -37,14 +37,26 @@ object JoinType {
}
}
-sealed abstract class JoinType
+sealed abstract class JoinType {
+ def sql: String
+}
-case object Inner extends JoinType
+case object Inner extends JoinType {
+ override def sql: String = "INNER"
+}
-case object LeftOuter extends JoinType
+case object LeftOuter extends JoinType {
+ override def sql: String = "LEFT OUTER"
+}
-case object RightOuter extends JoinType
+case object RightOuter extends JoinType {
+ override def sql: String = "RIGHT OUTER"
+}
-case object FullOuter extends JoinType
+case object FullOuter extends JoinType {
+ override def sql: String = "FULL OUTER"
+}
-case object LeftSemi extends JoinType
+case object LeftSemi extends JoinType {
+ override def sql: String = "LEFT SEMI"
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 79759b5a37b34..64957db6b4013 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -423,6 +423,7 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
}
case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
+
override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
index 62ea731ab5f38..9ebacb4680dc2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -37,7 +37,7 @@ object RuleExecutor {
val maxSize = map.keys.map(_.toString.length).max
map.toSeq.sortBy(_._2).reverseMap { case (k, v) =>
s"${k.padTo(maxSize, " ").mkString} $v"
- }.mkString("\n")
+ }.mkString("\n", "\n", "")
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
index 71293475ca0f9..7a0d0de6328a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
@@ -130,6 +130,20 @@ package object util {
ret
}
+ /**
+ * Converts a `Seq` of `Option[T]` to an `Option` of `Seq[T]`.
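+   * Returns `None` if the input contains any `None` element.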
+ */
+ def sequenceOption[T](seq: Seq[Option[T]]): Option[Seq[T]] = seq match {
+ case xs if xs.isEmpty =>
+ Option(Seq.empty[T])
+
+ case xs =>
+ for {
+ head <- xs.head
+ tail <- sequenceOption(xs.tail)
+ } yield head +: tail
+ }
+
/* FIX ME
implicit class debugLogging(a: Any) {
def debugLogging() {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
index 6533622492d41..520e344361625 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -77,6 +77,8 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
override def simpleString: String = s"array<${elementType.simpleString}>"
+ override def sql: String = s"ARRAY<${elementType.sql}>"
+
override private[spark] def asNullable: ArrayType =
ArrayType(elementType.asNullable, containsNull = true)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 136a97e066df7..92cf8d4c46bda 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -65,6 +65,8 @@ abstract class DataType extends AbstractDataType {
/** Readable string representation for the type with truncation */
private[sql] def simpleString(maxNumberFields: Int): String = simpleString
+ def sql: String = simpleString.toUpperCase
+
/**
* Check if `this` and `other` are the same data type when ignoring nullability
* (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
index 00461e529ca0a..5474954af70e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
@@ -62,6 +62,8 @@ case class MapType(
override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>"
+ override def sql: String = s"MAP<${keyType.sql}, ${valueType.sql}>"
+
override private[spark] def asNullable: MapType =
MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 34382bf124eb0..9b5c86a8984be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -279,6 +279,11 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
s"struct<${fieldTypes.mkString(",")}>"
}
+ override def sql: String = {
+ val fieldTypes = fields.map(f => s"`${f.name}`: ${f.dataType.sql}")
+ s"STRUCT<${fieldTypes.mkString(", ")}>"
+ }
+
private[sql] override def simpleString(maxNumberFields: Int): String = {
val builder = new StringBuilder
val fieldTypes = fields.take(maxNumberFields).map {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index 4305903616bd9..d7a2c23be8a9a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -84,6 +84,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
override private[sql] def acceptsType(dataType: DataType) =
this.getClass == dataType.getClass
+
+ override def sql: String = sqlType.sql
}
/**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index fa823e3021835..cf84855885a37 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
class AnalysisSuite extends AnalysisTest {
@@ -238,43 +237,6 @@ class AnalysisSuite extends AnalysisTest {
checkAnalysis(plan, expected)
}
- test("analyzer should replace current_timestamp with literals") {
- val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()),
- LocalRelation())
-
- val min = System.currentTimeMillis() * 1000
- val plan = in.analyze.asInstanceOf[Project]
- val max = (System.currentTimeMillis() + 1) * 1000
-
- val lits = new scala.collection.mutable.ArrayBuffer[Long]
- plan.transformAllExpressions { case e: Literal =>
- lits += e.value.asInstanceOf[Long]
- e
- }
- assert(lits.size == 2)
- assert(lits(0) >= min && lits(0) <= max)
- assert(lits(1) >= min && lits(1) <= max)
- assert(lits(0) == lits(1))
- }
-
- test("analyzer should replace current_date with literals") {
- val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())
-
- val min = DateTimeUtils.millisToDays(System.currentTimeMillis())
- val plan = in.analyze.asInstanceOf[Project]
- val max = DateTimeUtils.millisToDays(System.currentTimeMillis())
-
- val lits = new scala.collection.mutable.ArrayBuffer[Int]
- plan.transformAllExpressions { case e: Literal =>
- lits += e.value.asInstanceOf[Int]
- e
- }
- assert(lits.size == 2)
- assert(lits(0) >= min && lits(0) <= max)
- assert(lits(1) >= min && lits(1) <= max)
- assert(lits(0) == lits(1))
- }
-
test("SPARK-12102: Ignore nullablity when comparing two sides of case") {
val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false)))
val plan = relation.select(CaseWhen(Seq(Literal(true), 'a, 'b)).as("val"))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
new file mode 100644
index 0000000000000..10ed4e46ddd1c
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+
+class ComputeCurrentTimeSuite extends PlanTest {
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime))
+ }
+
+ test("analyzer should replace current_timestamp with literals") {
+ val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()),
+ LocalRelation())
+
+ val min = System.currentTimeMillis() * 1000
+ val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
+ val max = (System.currentTimeMillis() + 1) * 1000
+
+ val lits = new scala.collection.mutable.ArrayBuffer[Long]
+ plan.transformAllExpressions { case e: Literal =>
+ lits += e.value.asInstanceOf[Long]
+ e
+ }
+ assert(lits.size == 2)
+ assert(lits(0) >= min && lits(0) <= max)
+ assert(lits(1) >= min && lits(1) <= max)
+ assert(lits(0) == lits(1))
+ }
+
+ test("analyzer should replace current_date with literals") {
+ val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())
+
+ val min = DateTimeUtils.millisToDays(System.currentTimeMillis())
+ val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
+ val max = DateTimeUtils.millisToDays(System.currentTimeMillis())
+
+ val lits = new scala.collection.mutable.ArrayBuffer[Int]
+ plan.transformAllExpressions { case e: Literal =>
+ lits += e.value.asInstanceOf[Int]
+ e
+ }
+ assert(lits.size == 2)
+ assert(lits(0) >= min && lits(0) <= max)
+ assert(lits(1) >= min && lits(1) <= max)
+ assert(lits(0) == lits(1))
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index b998636909a7d..f9f3bd55aa578 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -75,8 +75,7 @@ class FilterPushdownSuite extends PlanTest {
val correctAnswer =
testRelation
.select('a)
- .groupBy('a)('a)
- .select('a).analyze
+ .groupBy('a)('a).analyze
comparePlans(optimized, correctAnswer)
}
@@ -91,8 +90,7 @@ class FilterPushdownSuite extends PlanTest {
val correctAnswer =
testRelation
.select('a)
- .groupBy('a)('a as 'c)
- .select('c).analyze
+ .groupBy('a)('a as 'c).analyze
comparePlans(optimized, correctAnswer)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index 4b375de05e9e3..ca8d010090401 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -18,8 +18,8 @@
package org.apache.spark.sql.execution.datasources.parquet
import java.net.URI
-import java.util.{List => JList}
import java.util.logging.{Logger => JLogger}
+import java.util.{List => JList}
import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -32,24 +32,24 @@ import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.hadoop.mapreduce.task.JobContextImpl
-import org.apache.parquet.{Log => ApacheParquetLog}
import org.apache.parquet.filter2.predicate.FilterApi
import org.apache.parquet.hadoop._
import org.apache.parquet.hadoop.metadata.CompressionCodecName
import org.apache.parquet.hadoop.util.ContextUtil
import org.apache.parquet.schema.MessageType
+import org.apache.parquet.{Log => ApacheParquetLog}
import org.slf4j.bridge.SLF4JBridgeHandler
-import org.apache.spark.{Logging, Partition => SparkPartition, SparkException}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD}
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.catalyst.util.LegacyTypeStringParser
+import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier}
+import org.apache.spark.sql.execution.datasources.{PartitionSpec, _}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.{Logging, Partition => SparkPartition, SparkException}
private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister {
@@ -147,6 +147,12 @@ private[sql] class ParquetRelation(
.get(ParquetRelation.METASTORE_SCHEMA)
.map(DataType.fromJson(_).asInstanceOf[StructType])
+ // If this relation is converted from a Hive metastore table, this method returns the name of the
+ // original Hive metastore table.
+ private[sql] def metastoreTableName: Option[TableIdentifier] = {
+ parameters.get(ParquetRelation.METASTORE_TABLE_NAME).map(SqlParser.parseTableIdentifier)
+ }
+
private lazy val metadataCache: MetadataCache = {
val meta = new MetadataCache
meta.refresh()
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index bd1a52e5f3303..afd2f611580fc 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -41,9 +41,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
private val originalColumnBatchSize = TestHive.conf.columnBatchSize
private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning
- def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f)
+ def testCases: Seq[(String, File)] = {
+ hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f)
+ }
override def beforeAll() {
+ super.beforeAll()
TestHive.cacheTables = true
// Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*)
TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
@@ -68,10 +71,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
// For debugging dump some statistics about how much time was spent in various optimizer rules.
logWarning(RuleExecutor.dumpTimeSpent())
+ super.afterAll()
}
/** A list of tests deemed out of scope currently and thus completely disregarded. */
- override def blackList = Seq(
+ override def blackList: Seq[String] = Seq(
// These tests use hooks that are not on the classpath and thus break all subsequent execution.
"hook_order",
"hook_context_cs",
@@ -106,7 +110,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"alter_merge",
"alter_concatenate_indexed_table",
"protectmode2",
- //"describe_table",
+ // "describe_table",
"describe_comment_nonascii",
"create_merge_compressed",
@@ -323,7 +327,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
* The set of tests that are believed to be working in catalyst. Tests not on whiteList or
* blacklist are implicitly marked as ignored.
*/
- override def whiteList = Seq(
+ override def whiteList: Seq[String] = Seq(
"add_part_exist",
"add_part_multiple",
"add_partition_no_whitelist",
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
index 98bbdf0653c2a..bad3ca6da231f 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
@@ -104,6 +104,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte
TimeZone.setDefault(originalTimeZone)
Locale.setDefault(originalLocale)
TestHive.reset()
+ super.afterAll()
}
/////////////////////////////////////////////////////////////////////////////
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index bf3fe12d5c5d2..5b13dbe47370e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -668,7 +668,8 @@ private[hive] object HiveQl extends SparkQl with Logging {
Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse(
sys.error(s"Couldn't find function $functionName"))
val functionClassName = functionInfo.getFunctionClass.getName
- HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr))
+ HiveGenericUDTF(
+ functionName, new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr))
case other => super.nodeToGenerator(node)
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
new file mode 100644
index 0000000000000..1c910051faccf
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder}
+import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
+/**
+ * A builder class used to convert a resolved logical plan into a SQL query string. Note that not
+ * all resolved logical plans are convertible: they either don't have corresponding SQL
+ * representations (e.g. logical plans that operate on local Scala collections), or are simply not
+ * supported by this builder (yet).
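+ *
+ * For example, `new SQLBuilder(df).toSQL` returns `Some(sqlString)` when the analyzed plan of
+ * the `DataFrame` `df` is supported, and `None` otherwise.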
+ */
+class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
+ def this(df: DataFrame) = this(df.queryExecution.analyzed, df.sqlContext)
+
+ def toSQL: Option[String] = {
+ val canonicalizedPlan = Canonicalizer.execute(logicalPlan)
+ val maybeSQL = try {
+ toSQL(canonicalizedPlan)
+ } catch { case cause: UnsupportedOperationException =>
+ logInfo(s"Failed to build SQL query string because: ${cause.getMessage}")
+ None
+ }
+
+ if (maybeSQL.isDefined) {
+ logDebug(
+ s"""Built SQL query string successfully from given logical plan:
+ |
+ |# Original logical plan:
+ |${logicalPlan.treeString}
+ |# Canonicalized logical plan:
+ |${canonicalizedPlan.treeString}
+ |# Built SQL query string:
+ |${maybeSQL.get}
+ """.stripMargin)
+ } else {
+ logDebug(
+ s"""Failed to build SQL query string from given logical plan:
+ |
+ |# Original logical plan:
+ |${logicalPlan.treeString}
+ |# Canonicalized logical plan:
+ |${canonicalizedPlan.treeString}
+ """.stripMargin)
+ }
+
+ maybeSQL
+ }
+
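+ // Builds a SELECT statement for a `Project` node: renders the project list, optionally prefixes
+ // it with DISTINCT, and omits the FROM clause when the child is `OneRowRelation` (e.g. for
+ // constant-only projections like SELECT 1).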
+ private def projectToSQL(
+ projectList: Seq[NamedExpression],
+ child: LogicalPlan,
+ isDistinct: Boolean): Option[String] = {
+ for {
+ childSQL <- toSQL(child)
+ listSQL = projectList.map(_.sql).mkString(", ")
+ maybeFrom = child match {
+ case OneRowRelation => " "
+ case _ => " FROM "
+ }
+ distinct = if (isDistinct) " DISTINCT " else " "
+ } yield s"SELECT$distinct$listSQL$maybeFrom$childSQL"
+ }
+
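+ // Builds a SELECT ... [FROM ...] [GROUP BY ...] statement for an `Aggregate` node; the GROUP BY
+ // clause is omitted when there are no grouping expressions (i.e. a global aggregation).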
+ private def aggregateToSQL(
+ groupingExprs: Seq[Expression],
+ aggExprs: Seq[Expression],
+ child: LogicalPlan): Option[String] = {
+ val aggSQL = aggExprs.map(_.sql).mkString(", ")
+ val groupingSQL = groupingExprs.map(_.sql).mkString(", ")
+ val maybeGroupBy = if (groupingSQL.isEmpty) "" else " GROUP BY "
+ val maybeFrom = child match {
+ case OneRowRelation => " "
+ case _ => " FROM "
+ }
+
+ toSQL(child).map { childSQL =>
+ s"SELECT $aggSQL$maybeFrom$childSQL$maybeGroupBy$groupingSQL"
+ }
+ }
+
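+ // Recursively converts a canonicalized plan node to its SQL string. Returns None whenever an
+ // operator without a supported SQL representation is encountered, which short-circuits the
+ // surrounding for-comprehensions.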
+ private def toSQL(node: LogicalPlan): Option[String] = node match {
+ case Distinct(Project(list, child)) =>
+ projectToSQL(list, child, isDistinct = true)
+
+ case Project(list, child) =>
+ projectToSQL(list, child, isDistinct = false)
+
+ case Aggregate(groupingExprs, aggExprs, child) =>
+ aggregateToSQL(groupingExprs, aggExprs, child)
+
+ case Limit(limit, child) =>
+ for {
+ childSQL <- toSQL(child)
+ limitSQL = limit.sql
+ } yield s"$childSQL LIMIT $limitSQL"
+
+ case Filter(condition, child) =>
+ for {
+ childSQL <- toSQL(child)
+ whereOrHaving = child match {
+ case _: Aggregate => "HAVING"
+ case _ => "WHERE"
+ }
+ conditionSQL = condition.sql
+ } yield s"$childSQL $whereOrHaving $conditionSQL"
+
+ case Union(left, right) =>
+ for {
+ leftSQL <- toSQL(left)
+ rightSQL <- toSQL(right)
+ } yield s"$leftSQL UNION ALL $rightSQL"
+
+ // ParquetRelation converted from Hive metastore table
+ case Subquery(alias, LogicalRelation(r: ParquetRelation, _)) =>
+ // There seems to be a bug related to the `ParquetConversions` analysis rule: the metastore
+ // database name and table name are not always propagated to converted `ParquetRelation`
+ // instances via data source options. Here we use the subquery alias as a workaround.
+ Some(s"`$alias`")
+
+ case Subquery(alias, child) =>
+ toSQL(child).map(childSQL => s"($childSQL) AS $alias")
+
+ case Join(left, right, joinType, condition) =>
+ for {
+ leftSQL <- toSQL(left)
+ rightSQL <- toSQL(right)
+ joinTypeSQL = joinType.sql
+ conditionSQL = condition.map(" ON " + _.sql).getOrElse("")
+ } yield s"$leftSQL $joinTypeSQL JOIN $rightSQL$conditionSQL"
+
+ case MetastoreRelation(database, table, alias) =>
+ val aliasSQL = alias.map(a => s" AS `$a`").getOrElse("")
+ Some(s"`$database`.`$table`$aliasSQL")
+
+ case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _))
+ if orders.map(_.child) == partitionExprs =>
+ for {
+ childSQL <- toSQL(child)
+ partitionExprsSQL = partitionExprs.map(_.sql).mkString(", ")
+ } yield s"$childSQL CLUSTER BY $partitionExprsSQL"
+
+ case Sort(orders, global, child) =>
+ for {
+ childSQL <- toSQL(child)
+ ordersSQL = orders.map { case SortOrder(e, dir) => s"${e.sql} ${dir.sql}" }.mkString(", ")
+ orderOrSort = if (global) "ORDER" else "SORT"
+ } yield s"$childSQL $orderOrSort BY $ordersSQL"
+
+ case RepartitionByExpression(partitionExprs, child, _) =>
+ for {
+ childSQL <- toSQL(child)
+ partitionExprsSQL = partitionExprs.map(_.sql).mkString(", ")
+ } yield s"$childSQL DISTRIBUTE BY $partitionExprsSQL"
+
+ case OneRowRelation =>
+ Some("")
+
+ case _ => None
+ }
+
+ object Canonicalizer extends RuleExecutor[LogicalPlan] {
+ override protected def batches: Seq[Batch] = Seq(
+ Batch("Canonicalizer", FixedPoint(100),
+ // The `WidenSetOperationTypes` analysis rule may introduce extra `Project`s over
+ // `Aggregate`s to perform type casting. This rule merges these `Project`s into
+ // `Aggregate`s.
+ ProjectCollapsing,
+
+ // Used to handle other auxiliary `Project`s added by the analyzer (e.g. by the
+ // `ResolveAggregateFunctions` rule)
+ RecoverScopingInfo
+ )
+ )
+
+ object RecoverScopingInfo extends Rule[LogicalPlan] {
+ override def apply(tree: LogicalPlan): LogicalPlan = tree transform {
+ // This branch handles aggregate functions within HAVING clauses. For example:
+ //
+ // SELECT key FROM src GROUP BY key HAVING max(value) > "val_255"
+ //
+ // This kind of query results in query plans of the following form because of the
+ // `ResolveAggregateFunctions` analysis rule:
+ //
+ // Project ...
+ // +- Filter ...
+ // +- Aggregate ...
+ // +- MetastoreRelation default, src, None
+ case plan @ Project(_, Filter(_, _: Aggregate)) =>
+ wrapChildWithSubquery(plan)
+
+ case plan @ Project(_,
+ _: Subquery | _: Filter | _: Join | _: MetastoreRelation | OneRowRelation | _: Limit
+ ) => plan
+
+ case plan: Project =>
+ wrapChildWithSubquery(plan)
+ }
+
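+ // Aliases the child of a `Project` as a fresh subquery (`gen_subquery_N`) and re-qualifies the
+ // attributes in the project list with that alias, so that the projected columns still resolve
+ // against the introduced subquery once the plan is rendered as SQL.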
+ def wrapChildWithSubquery(project: Project): Project = project match {
+ case Project(projectList, child) =>
+ val alias = SQLBuilder.newSubqueryName
+ val childAttributes = child.outputSet
+ val aliasedProjectList = projectList.map(_.transform {
+ case a: Attribute if childAttributes.contains(a) =>
+ a.withQualifiers(alias :: Nil)
+ }.asInstanceOf[NamedExpression])
+
+ Project(aliasedProjectList, Subquery(alias, child))
+ }
+ }
+ }
+}
+
+object SQLBuilder {
+ private val nextSubqueryId = new AtomicLong(0)
+
+ private def newSubqueryName: String = s"gen_subquery_${nextSubqueryId.getAndIncrement()}"
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index b1a6d0ab7df3c..e76c18fa528f3 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -17,30 +17,26 @@
package org.apache.spark.sql.hive
-import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
import scala.util.Try
import org.apache.hadoop.hive.ql.exec._
-import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
-import org.apache.hadoop.hive.ql.udf.generic._
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
-import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory}
+import org.apache.hadoop.hive.ql.udf.generic._
+import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
+import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory}
import org.apache.spark.Logging
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.catalyst.util.sequenceOption
+import org.apache.spark.sql.catalyst.{InternalRow, analysis}
import org.apache.spark.sql.hive.HiveShim._
import org.apache.spark.sql.hive.client.ClientWrapper
import org.apache.spark.sql.types._
@@ -75,19 +71,19 @@ private[hive] class HiveFunctionRegistry(
try {
if (classOf[GenericUDFMacro].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUDF(
- new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children)
+ name, new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children)
} else if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveSimpleUDF(new HiveFunctionWrapper(functionClassName), children)
+ HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children)
} else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children)
+ HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children)
} else if (
classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveUDAFFunction(new HiveFunctionWrapper(functionClassName), children)
+ HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children)
} else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveUDAFFunction(
- new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true)
+ name, new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true)
} else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
- val udtf = HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children)
+ val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(functionClassName), children)
udtf.elementTypes // Force it to check input data types.
udtf
} else {
@@ -137,7 +133,8 @@ private[hive] class HiveFunctionRegistry(
}
}
-private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
+private[hive] case class HiveSimpleUDF(
+ name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
extends Expression with HiveInspectors with CodegenFallback with Logging {
override def deterministic: Boolean = isUDFDeterministic
@@ -191,6 +188,8 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre
override def toString: String = {
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}
+
+ override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
}
// Adapter from Catalyst ExpressionResult to Hive DeferredObject
@@ -205,7 +204,8 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp
override def get(): AnyRef = wrap(func(), oi, dataType)
}
-private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
+private[hive] case class HiveGenericUDF(
+ name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
extends Expression with HiveInspectors with CodegenFallback with Logging {
override def nullable: Boolean = true
@@ -257,6 +257,8 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr
override def toString: String = {
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}
+
+ override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
}
/**
@@ -271,6 +273,7 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr
* user defined aggregations, which have clean semantics even in a partitioned execution.
*/
private[hive] case class HiveGenericUDTF(
+ name: String,
funcWrapper: HiveFunctionWrapper,
children: Seq[Expression])
extends Generator with HiveInspectors with CodegenFallback {
@@ -336,6 +339,8 @@ private[hive] case class HiveGenericUDTF(
override def toString: String = {
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}
+
+ override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
}
/**
@@ -343,6 +348,7 @@ private[hive] case class HiveGenericUDTF(
* performance a lot.
*/
private[hive] case class HiveUDAFFunction(
+ name: String,
funcWrapper: HiveFunctionWrapper,
children: Seq[Expression],
isUDAFBridgeRequired: Boolean = false,
@@ -427,5 +433,9 @@ private[hive] case class HiveUDAFFunction(
override def supportsPartial: Boolean = false
override val dataType: DataType = inspectorToDataType(returnInspector)
-}
+ override def sql(isDistinct: Boolean): String = {
+ val distinct = if (isDistinct) "DISTINCT " else ""
+ s"$name($distinct${children.map(_.sql).mkString(", ")})"
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala
new file mode 100644
index 0000000000000..3a6eb57add4e3
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import java.sql.Timestamp
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.{If, Literal}
+
+class ExpressionSQLBuilderSuite extends SQLBuilderTest {
+ test("literal") {
+ checkSQL(Literal("foo"), "\"foo\"")
+ checkSQL(Literal("\"foo\""), "\"\\\"foo\\\"\"")
+ checkSQL(Literal(1: Byte), "CAST(1 AS TINYINT)")
+ checkSQL(Literal(2: Short), "CAST(2 AS SMALLINT)")
+ checkSQL(Literal(4: Int), "4")
+ checkSQL(Literal(8: Long), "CAST(8 AS BIGINT)")
+ checkSQL(Literal(1.5F), "CAST(1.5 AS FLOAT)")
+ checkSQL(Literal(2.5D), "2.5")
+ checkSQL(
+ Literal(Timestamp.valueOf("2016-01-01 00:00:00")),
+ "TIMESTAMP('2016-01-01 00:00:00.0')")
+ // TODO tests for decimals
+ }
+
+ test("binary comparisons") {
+ checkSQL('a.int === 'b.int, "(`a` = `b`)")
+ checkSQL('a.int <=> 'b.int, "(`a` <=> `b`)")
+ checkSQL('a.int !== 'b.int, "(NOT (`a` = `b`))")
+
+ checkSQL('a.int < 'b.int, "(`a` < `b`)")
+ checkSQL('a.int <= 'b.int, "(`a` <= `b`)")
+ checkSQL('a.int > 'b.int, "(`a` > `b`)")
+ checkSQL('a.int >= 'b.int, "(`a` >= `b`)")
+
+ checkSQL('a.int in ('b.int, 'c.int), "(`a` IN (`b`, `c`))")
+ checkSQL('a.int in (1, 2), "(`a` IN (1, 2))")
+
+ checkSQL('a.int.isNull, "(`a` IS NULL)")
+ checkSQL('a.int.isNotNull, "(`a` IS NOT NULL)")
+ }
+
+ test("logical operators") {
+ checkSQL('a.boolean && 'b.boolean, "(`a` AND `b`)")
+ checkSQL('a.boolean || 'b.boolean, "(`a` OR `b`)")
+ checkSQL(!'a.boolean, "(NOT `a`)")
+ checkSQL(If('a.boolean, 'b.int, 'c.int), "(IF(`a`, `b`, `c`))")
+ }
+
+ test("arithmetic expressions") {
+ checkSQL('a.int + 'b.int, "(`a` + `b`)")
+ checkSQL('a.int - 'b.int, "(`a` - `b`)")
+ checkSQL('a.int * 'b.int, "(`a` * `b`)")
+ checkSQL('a.int / 'b.int, "(`a` / `b`)")
+ checkSQL('a.int % 'b.int, "(`a` % `b`)")
+
+ checkSQL(-'a.int, "(-`a`)")
+ checkSQL(-('a.int + 'b.int), "(-(`a` + `b`))")
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
new file mode 100644
index 0000000000000..0e81acf532a03
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SQLTestUtils
+
+class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
+ import testImplicits._
+
+ protected override def beforeAll(): Unit = {
+ sqlContext.range(10).write.saveAsTable("t0")
+
+ sqlContext
+ .range(10)
+ .select('id as 'key, concat(lit("val_"), 'id) as 'value)
+ .write
+ .saveAsTable("t1")
+
+ sqlContext.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2")
+ }
+
+ override protected def afterAll(): Unit = {
+ sql("DROP TABLE IF EXISTS t0")
+ sql("DROP TABLE IF EXISTS t1")
+ sql("DROP TABLE IF EXISTS t2")
+ }
+
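+ // Round-trip check: run the original HiveQL query, convert its resolved plan back to SQL with
+ // SQLBuilder, then execute the generated SQL and verify that both produce the same result.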
+ private def checkHiveQl(hiveQl: String): Unit = {
+ val df = sql(hiveQl)
+ val convertedSQL = new SQLBuilder(df).toSQL
+
+ if (convertedSQL.isEmpty) {
+ fail(
+ s"""Cannot convert the following HiveQL query plan back to SQL query string:
+ |
+ |# Original HiveQL query string:
+ |$hiveQl
+ |
+ |# Resolved query plan:
+ |${df.queryExecution.analyzed.treeString}
+ """.stripMargin)
+ }
+
+ val sqlString = convertedSQL.get
+ try {
+ checkAnswer(sql(sqlString), df)
+ } catch { case cause: Throwable =>
+ fail(
+ s"""Failed to execute converted SQL string or got wrong answer:
+ |
+ |# Converted SQL query string:
+ |$sqlString
+ |
+ |# Original HiveQL query string:
+ |$hiveQl
+ |
+ |# Resolved query plan:
+ |${df.queryExecution.analyzed.treeString}
+ """.stripMargin,
+ cause)
+ }
+ }
+
+ test("in") {
+ checkHiveQl("SELECT id FROM t0 WHERE id IN (1, 2, 3)")
+ }
+
+ test("aggregate function in having clause") {
+ checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key HAVING MAX(key) > 0")
+ }
+
+ test("aggregate function in order by clause") {
+ checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY MAX(key)")
+ }
+
+ // TODO Fix name collision introduced by the ResolveAggregateFunctions analysis rule
+ // When there are multiple aggregate functions in the ORDER BY clause, all of them are extracted
+ // into the Aggregate operator and aliased to the same name "aggOrder". This is OK for normal
+ // query execution since these aliases have different expression IDs. But it introduces a name
+ // collision when converting resolved plans back to SQL query strings, as expression IDs are
+ // stripped.
+ ignore("aggregate function in order by clause with multiple order keys") {
+ checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY key, MAX(key)")
+ }
+
+ test("type widening in union") {
+ checkHiveQl("SELECT id FROM t0 UNION ALL SELECT CAST(id AS INT) AS id FROM t0")
+ }
+
+ test("case") {
+ checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 WHEN id % 2 = 0 THEN 1 END FROM t0")
+ }
+
+ test("case with else") {
+ checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 ELSE 1 END FROM t0")
+ }
+
+ test("case with key") {
+ checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' END FROM t0")
+ }
+
+ test("case with key and else") {
+ checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' ELSE 'baz' END FROM t0")
+ }
+
+ test("select distinct without aggregate functions") {
+ checkHiveQl("SELECT DISTINCT id FROM t0")
+ }
+
+ test("cluster by") {
+ checkHiveQl("SELECT id FROM t0 CLUSTER BY id")
+ }
+
+ test("distribute by") {
+ checkHiveQl("SELECT id FROM t0 DISTRIBUTE BY id")
+ }
+
+ test("distribute by with sort by") {
+ checkHiveQl("SELECT id FROM t0 DISTRIBUTE BY id SORT BY id")
+ }
+
+ test("distinct aggregation") {
+ checkHiveQl("SELECT COUNT(DISTINCT id) FROM t0")
+ }
+
+ // TODO Enable this
+ // Query plans transformed by DistinctAggregationRewriter are not recognized yet
+ ignore("distinct and non-distinct aggregation") {
+ checkHiveQl("SELECT a, COUNT(DISTINCT b), COUNT(DISTINCT c), SUM(d) FROM t2 GROUP BY a")
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala
new file mode 100644
index 0000000000000..cf4a3fdd88806
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.{DataFrame, QueryTest}
+
+abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton {
+ protected def checkSQL(e: Expression, expectedSQL: String): Unit = {
+ val actualSQL = e.sql
+ try {
+ assert(actualSQL === expectedSQL)
+ } catch {
+ case cause: Throwable =>
+ fail(
+ s"""Wrong SQL generated for the following expression:
+ |
+ |${e.prettyName}
+ |
+ |$cause
+ """.stripMargin)
+ }
+ }
+
+ protected def checkSQL(plan: LogicalPlan, expectedSQL: String): Unit = {
+ val maybeSQL = new SQLBuilder(plan, hiveContext).toSQL
+
+ if (maybeSQL.isEmpty) {
+ fail(
+ s"""Cannot convert the following logical query plan to SQL:
+ |
+ |${plan.treeString}
+ """.stripMargin)
+ }
+
+ val actualSQL = maybeSQL.get
+
+ try {
+ assert(actualSQL === expectedSQL)
+ } catch {
+ case cause: Throwable =>
+ fail(
+ s"""Wrong SQL generated for the following logical query plan:
+ |
+ |${plan.treeString}
+ |
+ |$cause
+ """.stripMargin)
+ }
+
+ checkAnswer(sqlContext.sql(actualSQL), new DataFrame(sqlContext, plan))
+ }
+
+ protected def checkSQL(df: DataFrame, expectedSQL: String): Unit = {
+ checkSQL(df.queryExecution.analyzed, expectedSQL)
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index d7e8ebc8d312f..57358a07840e2 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -27,9 +27,10 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.execution.{ExplainCommand, SetCommand}
import org.apache.spark.sql.execution.datasources.DescribeCommand
+import org.apache.spark.sql.execution.{ExplainCommand, SetCommand}
import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.{InsertIntoHiveTable => LogicalInsertIntoHiveTable, SQLBuilder}
/**
* Allows the creations of tests that execute the same query against both hive
@@ -130,6 +131,28 @@ abstract class HiveComparisonTest
new java.math.BigInteger(1, digest.digest).toString(16)
}
+ /** Used for testing [[SQLBuilder]] */
+ private var numConvertibleQueries: Int = 0
+ private var numTotalQueries: Int = 0
+
+ override protected def afterAll(): Unit = {
+ logInfo({
+ val percentage = if (numTotalQueries > 0) {
+ numConvertibleQueries.toDouble / numTotalQueries * 100
+ } else {
+ 0D
+ }
+
+ s"""SQLBuilder statistics:
+ |- Total query number: $numTotalQueries
+ |- Number of convertible queries: $numConvertibleQueries
+ |- Percentage of convertible queries: $percentage%
+ """.stripMargin
+ })
+
+ super.afterAll()
+ }
+
protected def prepareAnswer(
hiveQuery: TestHive.type#QueryExecution,
answer: Seq[String]): Seq[String] = {
@@ -372,8 +395,49 @@ abstract class HiveComparisonTest
// Run w/ catalyst
val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) =>
- val query = new TestHive.QueryExecution(queryString)
- try { (query, prepareAnswer(query, query.stringResult())) } catch {
+ var query: TestHive.QueryExecution = null
+ try {
+ query = {
+ val originalQuery = new TestHive.QueryExecution(queryString)
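+ // Commands and Hive insertions are executed as-is; only plain queries go through the
+ // SQLBuilder round-trip below.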
+ val containsCommands = originalQuery.analyzed.collectFirst {
+ case _: Command => ()
+ case _: LogicalInsertIntoHiveTable => ()
+ }.nonEmpty
+
+ if (containsCommands) {
+ originalQuery
+ } else {
+ numTotalQueries += 1
+ new SQLBuilder(originalQuery.analyzed, TestHive).toSQL.map { sql =>
+ numConvertibleQueries += 1
+ logInfo(
+ s"""
+ |### Running SQL generation round-trip test {{{
+ |${originalQuery.analyzed.treeString}
+ |Original SQL:
+ |$queryString
+ |
+ |Generated SQL:
+ |$sql
+ |}}}
+ """.stripMargin.trim)
+ new TestHive.QueryExecution(sql)
+ }.getOrElse {
+ logInfo(
+ s"""
+ |### Cannot convert the following logical plan back to SQL {{{
+ |${originalQuery.analyzed.treeString}
+ |Original SQL:
+ |$queryString
+ |}}}
+ """.stripMargin.trim)
+ originalQuery
+ }
+ }
+ }
+
+ (query, prepareAnswer(query, query.stringResult()))
+ } catch {
case e: Throwable =>
val errorMessage =
s"""
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 98e22c2e2c1b0..5e9e77c4fb925 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -60,6 +60,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
TimeZone.setDefault(originalTimeZone)
Locale.setDefault(originalLocale)
sql("DROP TEMPORARY FUNCTION udtf_count2")
+ super.afterAll()
}
test("SPARK-4908: concurrent hive native commands") {