224 changes: 103 additions & 121 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -19,10 +19,12 @@ package org.apache.spark.sql.hive

import java.util.concurrent.atomic.AtomicLong

import scala.util.control.NonFatal

import org.apache.spark.Logging
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.optimizer.CollapseProject
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
@@ -37,27 +39,22 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
def this(df: DataFrame) = this(df.queryExecution.analyzed, df.sqlContext)

def toSQL: Option[String] = {
def toSQL: String = {
val canonicalizedPlan = Canonicalizer.execute(logicalPlan)
val maybeSQL = try {
toSQL(canonicalizedPlan)
} catch { case cause: UnsupportedOperationException =>
logInfo(s"Failed to build SQL query string because: ${cause.getMessage}")
None
}

if (maybeSQL.isDefined) {
try {
val generatedSQL = toSQL(canonicalizedPlan)
logDebug(
s"""Built SQL query string successfully from given logical plan:
|
|# Original logical plan:
|${logicalPlan.treeString}
|# Canonicalized logical plan:
|${canonicalizedPlan.treeString}
|# Built SQL query string:
|${maybeSQL.get}
|# Generated SQL:
|$generatedSQL
""".stripMargin)
} else {
generatedSQL
} catch { case NonFatal(e) =>
logDebug(
s"""Failed to build SQL query string from given logical plan:
|
@@ -66,128 +63,113 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
|# Canonicalized logical plan:
|${canonicalizedPlan.treeString}
""".stripMargin)
throw e
}

maybeSQL
}
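
With this change, toSQL returns a plain String and signals failure by throwing, rather than returning Option[String]. Call sites that previously pattern-matched on the Option can recover the old behavior with try/catch. A minimal sketch, assuming a DataFrame named df is in scope:

import scala.util.control.NonFatal

// Old API returned Option[String]; the new API returns the SQL string
// or throws. Wrapping the call restores Option-style handling.
val maybeSQL: Option[String] =
  try Some(new SQLBuilder(df).toSQL)
  catch { case NonFatal(e) => None }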

private def projectToSQL(
projectList: Seq[NamedExpression],
child: LogicalPlan,
isDistinct: Boolean): Option[String] = {
for {
childSQL <- toSQL(child)
listSQL = projectList.map(_.sql).mkString(", ")
maybeFrom = child match {
case OneRowRelation => " "
case _ => " FROM "
}
distinct = if (isDistinct) " DISTINCT " else " "
} yield s"SELECT$distinct$listSQL$maybeFrom$childSQL"
}
private def toSQL(node: LogicalPlan): String = node match {
case Distinct(p: Project) =>
projectToSQL(p, isDistinct = true)

private def aggregateToSQL(
groupingExprs: Seq[Expression],
aggExprs: Seq[Expression],
child: LogicalPlan): Option[String] = {
val aggSQL = aggExprs.map(_.sql).mkString(", ")
val groupingSQL = groupingExprs.map(_.sql).mkString(", ")
val maybeGroupBy = if (groupingSQL.isEmpty) "" else " GROUP BY "
val maybeFrom = child match {
case OneRowRelation => " "
case _ => " FROM "
}
case p: Project =>
projectToSQL(p, isDistinct = false)

toSQL(child).map { childSQL =>
s"SELECT $aggSQL$maybeFrom$childSQL$maybeGroupBy$groupingSQL"
}
}
case p: Aggregate =>
aggregateToSQL(p)

private def toSQL(node: LogicalPlan): Option[String] = node match {
case Distinct(Project(list, child)) =>
projectToSQL(list, child, isDistinct = true)

case Project(list, child) =>
projectToSQL(list, child, isDistinct = false)

case Aggregate(groupingExprs, aggExprs, child) =>
aggregateToSQL(groupingExprs, aggExprs, child)

case Limit(limit, child) =>
for {
childSQL <- toSQL(child)
limitSQL = limit.sql
} yield s"$childSQL LIMIT $limitSQL"

case Filter(condition, child) =>
for {
childSQL <- toSQL(child)
whereOrHaving = child match {
case _: Aggregate => "HAVING"
case _ => "WHERE"
}
conditionSQL = condition.sql
} yield s"$childSQL $whereOrHaving $conditionSQL"

case Union(children) if children.length > 1 =>
val childrenSql = children.map(toSQL(_))
if (childrenSql.exists(_.isEmpty)) {
None
} else {
Some(childrenSql.map(_.get).mkString(" UNION ALL "))
case p: Limit =>
s"${toSQL(p.child)} LIMIT ${p.limitExpr.sql}"

case p: Filter =>
val whereOrHaving = p.child match {
case _: Aggregate => "HAVING"
case _ => "WHERE"
}
build(toSQL(p.child), whereOrHaving, p.condition.sql)

case p: Union if p.children.length > 1 =>
val childrenSql = p.children.map(toSQL(_))
childrenSql.mkString(" UNION ALL ")

case p: Subquery =>
p.child match {
// Persisted data source relation
case LogicalRelation(_, _, Some(TableIdentifier(table, Some(database)))) =>
s"`$database`.`$table`"
// Parentheses are not used for persisted data source relations
// e.g., select x.c1 from (t1) as x inner join (t1) as y on x.c1 = y.c1
case Subquery(_, _: LogicalRelation | _: MetastoreRelation) =>
build(toSQL(p.child), "AS", p.alias)
case _ =>
build("(" + toSQL(p.child) + ")", "AS", p.alias)
}

// Persisted data source relation
case Subquery(alias, LogicalRelation(_, _, Some(TableIdentifier(table, Some(database))))) =>
Some(s"`$database`.`$table`")

case Subquery(alias, child) =>
toSQL(child).map( childSQL =>
child match {
// Parentheses are not used for persisted data source relations
// e.g., select x.c1 from (t1) as x inner join (t1) as y on x.c1 = y.c1
case Subquery(_, _: LogicalRelation | _: MetastoreRelation) =>
s"$childSQL AS $alias"
case _ =>
s"($childSQL) AS $alias"
})

case Join(left, right, joinType, condition) =>
for {
leftSQL <- toSQL(left)
rightSQL <- toSQL(right)
joinTypeSQL = joinType.sql
conditionSQL = condition.map(" ON " + _.sql).getOrElse("")
} yield s"$leftSQL $joinTypeSQL JOIN $rightSQL$conditionSQL"

case MetastoreRelation(database, table, alias) =>
val aliasSQL = alias.map(a => s" AS `$a`").getOrElse("")
Some(s"`$database`.`$table`$aliasSQL")
case p: Join =>
build(
toSQL(p.left),
p.joinType.sql,
"JOIN",
toSQL(p.right),
p.condition.map(" ON " + _.sql).getOrElse(""))

case p: MetastoreRelation =>
build(
s"`${p.databaseName}`.`${p.tableName}`",
p.alias.map(a => s" AS `$a`").getOrElse("")
)

case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _))
if orders.map(_.child) == partitionExprs =>
for {
childSQL <- toSQL(child)
partitionExprsSQL = partitionExprs.map(_.sql).mkString(", ")
} yield s"$childSQL CLUSTER BY $partitionExprsSQL"

case Sort(orders, global, child) =>
for {
childSQL <- toSQL(child)
ordersSQL = orders.map { case SortOrder(e, dir) => s"${e.sql} ${dir.sql}" }.mkString(", ")
orderOrSort = if (global) "ORDER" else "SORT"
} yield s"$childSQL $orderOrSort BY $ordersSQL"

case RepartitionByExpression(partitionExprs, child, _) =>
for {
childSQL <- toSQL(child)
partitionExprsSQL = partitionExprs.map(_.sql).mkString(", ")
} yield s"$childSQL DISTRIBUTE BY $partitionExprsSQL"
build(toSQL(child), "CLUSTER BY", partitionExprs.map(_.sql).mkString(", "))

case p: Sort =>
build(
toSQL(p.child),
if (p.global) "ORDER BY" else "SORT BY",
p.order.map { case SortOrder(e, dir) => s"${e.sql} ${dir.sql}" }.mkString(", ")
)

case p: RepartitionByExpression =>
build(
toSQL(p.child),
"DISTRIBUTE BY",
p.partitionExpressions.map(_.sql).mkString(", ")
)

case OneRowRelation =>
Some("")
""

case _ => None
case _ =>
throw new UnsupportedOperationException(s"unsupported plan $node")
}

/**
* Turns a sequence of string segments into a single string, separating adjacent segments
* with a single space. Segments are trimmed first, so padded input does not produce extra spaces.
* For example, `build("a", " b ", " c")` becomes "a b c".
*/
private def build(segments: String*): String = segments.map(_.trim).mkString(" ")
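
One subtlety of build: segments are trimmed, but an empty segment still contributes a separator, so the assembled string can contain doubled spaces (for example, when there is no DISTINCT or no FROM clause). This is harmless because SQL treats consecutive whitespace as a single separator. An illustrative call:

build("SELECT", "", "a", "FROM", "t")  // "SELECT  a FROM t" (note the two spaces)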

private def projectToSQL(plan: Project, isDistinct: Boolean): String = {
build(
"SELECT",
if (isDistinct) "DISTINCT" else "",
plan.projectList.map(_.sql).mkString(", "),
if (plan.child == OneRowRelation) "" else "FROM",
toSQL(plan.child)
)
}

private def aggregateToSQL(plan: Aggregate): String = {
val groupingSQL = plan.groupingExpressions.map(_.sql).mkString(", ")
build(
"SELECT",
plan.aggregateExpressions.map(_.sql).mkString(", "),
if (plan.child == OneRowRelation) "" else "FROM",
toSQL(plan.child),
if (groupingSQL.isEmpty) "" else "GROUP BY",
groupingSQL
)
}
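
To make the assembly concrete, here is a hand-expanded trace of aggregateToSQL for a simple grouped query; the table and column names are illustrative, not taken from this diff:

// Trace for: SELECT key, count(1) FROM `default`.`t` GROUP BY key
build(
  "SELECT",
  "key, count(1)",  // plan.aggregateExpressions rendered via .sql
  "FROM",           // child is not OneRowRelation
  "`default`.`t`",  // toSQL(plan.child) for a MetastoreRelation
  "GROUP BY",       // groupingSQL is non-empty
  "key"             // groupingSQL
)
// => "SELECT key, count(1) FROM `default`.`t` GROUP BY key"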

object Canonicalizer extends RuleExecutor[LogicalPlan] {
sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala
@@ -17,6 +17,8 @@

package org.apache.spark.sql.hive.execution

import scala.util.control.NonFatal

import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.Alias
@@ -72,7 +74,9 @@ private[hive] case class CreateViewAsSelect(

private def prepareTable(sqlContext: SQLContext): HiveTable = {
val expandedText = if (sqlContext.conf.canonicalView) {
rebuildViewQueryString(sqlContext).getOrElse(wrapViewTextWithSelect)
try rebuildViewQueryString(sqlContext) catch {
case NonFatal(e) => wrapViewTextWithSelect
}
} else {
wrapViewTextWithSelect
}
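
The view code applies the same Option-to-exception migration: where the old code used getOrElse for the fallback, the new code catches non-fatal failures. As a standalone sketch of the idiom (the helper name is illustrative):

import scala.util.control.NonFatal

// Evaluate primary; if it fails non-fatally, fall back to the alternative.
def withFallback(primary: => String, fallback: => String): String =
  try primary catch { case NonFatal(_) => fallback }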
@@ -112,7 +116,7 @@
s"SELECT $viewOutput FROM ($viewText) $viewName"
}

private def rebuildViewQueryString(sqlContext: SQLContext): Option[String] = {
private def rebuildViewQueryString(sqlContext: SQLContext): String = {
val logicalPlan = if (tableDesc.schema.isEmpty) {
child
} else {
sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala
@@ -33,8 +33,7 @@ class ExpressionSQLBuilderSuite extends SQLBuilderTest {
checkSQL(Literal(1.5F), "CAST(1.5 AS FLOAT)")
checkSQL(Literal(2.5D), "2.5")
checkSQL(
Literal(Timestamp.valueOf("2016-01-01 00:00:00")),
"TIMESTAMP('2016-01-01 00:00:00.0')")
Literal(Timestamp.valueOf("2016-01-01 00:00:00")), "TIMESTAMP('2016-01-01 00:00:00.0')")
// TODO tests for decimals
}

sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
@@ -17,6 +17,8 @@

package org.apache.spark.sql.hive

import scala.util.control.NonFatal

import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SQLTestUtils

@@ -46,29 +48,28 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {

private def checkHiveQl(hiveQl: String): Unit = {
val df = sql(hiveQl)
val convertedSQL = new SQLBuilder(df).toSQL

if (convertedSQL.isEmpty) {
fail(
s"""Cannot convert the following HiveQL query plan back to SQL query string:
|
|# Original HiveQL query string:
|$hiveQl
|
|# Resolved query plan:
|${df.queryExecution.analyzed.treeString}
""".stripMargin)
val convertedSQL = try new SQLBuilder(df).toSQL catch {
case NonFatal(e) =>
fail(
s"""Cannot convert the following HiveQL query plan back to SQL query string:
|
|# Original HiveQL query string:
|$hiveQl
|
|# Resolved query plan:
|${df.queryExecution.analyzed.treeString}
""".stripMargin)
}

val sqlString = convertedSQL.get
try {
checkAnswer(sql(sqlString), df)
checkAnswer(sql(convertedSQL), df)
} catch { case cause: Throwable =>
fail(
s"""Failed to execute converted SQL string or got wrong answer:
|
|# Converted SQL query string:
|$sqlString
|$convertedSQL
|
|# Original HiveQL query string:
|$hiveQl
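
The truncated helper above is invoked by the suite's individual tests. A hypothetical example in the suite's style (the query and table name are illustrative):

test("distinct projection round-trips through SQLBuilder") {
  // Parses the HiveQL, converts the analyzed plan back to SQL via SQLBuilder,
  // and checks that both query strings produce the same rows.
  checkHiveQl("SELECT DISTINCT id FROM t0")
}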