--- ANTLR grammar (selectItem / namedExpression rules) ---
@@ -131,6 +131,13 @@ selectItem
     :
     (tableAllColumns) => tableAllColumns -> ^(TOK_SELEXPR tableAllColumns)
     |
+    namedExpression
+    ;
+
+namedExpression
+@init { gParent.pushMsg("select named expression", state); }
+@after { gParent.popMsg(state); }
+    :
     ( expression
       ((KW_AS? identifier) | (KW_AS LPAREN identifier (COMMA identifier)* RPAREN))?
     ) -> ^(TOK_SELEXPR expression identifier*)
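
The new `namedExpression` rule makes a single, optionally aliased expression parsable on its own, which the expression entry point added below relies on. A minimal sketch of inputs it should accept, assuming the `CatalystQl.parseExpression` API introduced in this PR:

```scala
val parser = new CatalystQl()
parser.parseExpression("1 + 1")                  // bare expression, no alias
parser.parseExpression("a + b AS c")             // KW_AS? identifier
parser.parseExpression("explode(kv) AS (k, v)")  // KW_AS LPAREN identifier (COMMA identifier)* RPAREN
```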
--- CatalystQl.scala ---
@@ -30,6 +30,12 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
 import org.apache.spark.util.random.RandomSampler
 
+private[sql] object CatalystQl {
+  val parser = new CatalystQl
+  def parseExpression(sql: String): Expression = parser.parseExpression(sql)
+  def parseTableIdentifier(sql: String): TableIdentifier = parser.parseTableIdentifier(sql)
+}
+
 /**
  * This class translates a HQL String to a Catalyst [[LogicalPlan]] or [[Expression]].
  */
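
The companion object gives one-off callers a shared, default-configuration parser instance so they don't have to construct a CatalystQl themselves. Hypothetical call sites (not part of this diff):

```scala
val expr: Expression = CatalystQl.parseExpression("a + 1")
val table: TableIdentifier = CatalystQl.parseTableIdentifier("db.tbl")
```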
@@ -41,43 +47,53 @@ private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) {
     }
   }
 
-
   /**
-   * Returns the AST for the given SQL string.
+   * The safeParse method allows a user to focus on the parsing/AST transformation logic. This
+   * method will take care of possible errors during the parsing process.
    */
-  protected def getAst(sql: String): ASTNode = ParseDriver.parse(sql, conf)
-
-  /** Creates LogicalPlan for a given HiveQL string. */
-  def createPlan(sql: String): LogicalPlan = {
+  protected def safeParse[T](sql: String, ast: ASTNode)(toResult: ASTNode => T): T = {
     try {
-      createPlan(sql, ParseDriver.parse(sql, conf))
+      toResult(ast)
     } catch {
       case e: MatchError => throw e
       case e: AnalysisException => throw e
       case e: Exception =>
         throw new AnalysisException(e.getMessage)
       case e: NotImplementedError =>
         throw new AnalysisException(
-          s"""
-             |Unsupported language features in query: $sql
-             |${getAst(sql).treeString}
+          s"""Unsupported language features in query
+             |== SQL ==
+             |$sql
+             |== AST ==
+             |${ast.treeString}
+             |== Error ==
             |$e
+             |== Stacktrace ==
             |${e.getStackTrace.head}
           """.stripMargin)
     }
   }
 
-  protected def createPlan(sql: String, tree: ASTNode): LogicalPlan = nodeToPlan(tree)
-
-  def parseDdl(ddl: String): Seq[Attribute] = {
-    val tree = getAst(ddl)
-    assert(tree.text == "TOK_CREATETABLE", "Only CREATE TABLE supported.")
-    val tableOps = tree.children
-    val colList = tableOps
-      .find(_.text == "TOK_TABCOLLIST")
-      .getOrElse(sys.error("No columnList!"))
-
-    colList.children.map(nodeToAttribute)
+  /** Creates LogicalPlan for a given SQL string. */
+  def parsePlan(sql: String): LogicalPlan =
+    safeParse(sql, ParseDriver.parsePlan(sql, conf))(nodeToPlan)
+
+  /** Creates Expression for a given SQL string. */
+  def parseExpression(sql: String): Expression =
+    safeParse(sql, ParseDriver.parseExpression(sql, conf))(selExprNodeToExpr(_).get)
+
+  /** Creates TableIdentifier for a given SQL string. */
+  def parseTableIdentifier(sql: String): TableIdentifier =
+    safeParse(sql, ParseDriver.parseTableName(sql, conf))(extractTableIdent)
+
+  def parseDdl(sql: String): Seq[Attribute] = {
+    safeParse(sql, ParseDriver.parseExpression(sql, conf)) { ast =>
+      val Token("TOK_CREATETABLE", children) = ast
+      children
+        .find(_.text == "TOK_TABCOLLIST")
+        .getOrElse(sys.error("No columnList!"))
+        .flatMap(_.children.map(nodeToAttribute))
+    }
   }
 
   protected def getClauses(
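
Every entry point now shares one error contract: parse to an ASTNode with a rule-specific ParseDriver call, then run the AST transformation inside safeParse. A sketch of what a caller sees on failure (illustrative only; the structured report is produced when the transformation hits a NotImplementedError, and whether this particular query triggers it is an assumption):

```scala
try {
  CatalystQl.parser.parsePlan("SELECT TRANSFORM (key) USING 'cat' AS (tKey) FROM src")
} catch {
  case ae: AnalysisException =>
    // The message is laid out in sections:
    // == SQL == (original text), == AST == (parsed tree),
    // == Error == (underlying error), == Stacktrace == (its first frame)
    println(ae.getMessage)
}
```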
@@ -187,7 +203,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
         val keyMap = keyASTs.zipWithIndex.toMap
 
         val bitmasks: Seq[Int] = setASTs.map {
-          case Token("TOK_GROUPING_SETS_EXPRESSION", null) => 0
           case Token("TOK_GROUPING_SETS_EXPRESSION", columns) =>
             columns.foldLeft(0)((bitmap, col) => {
               val keyIndex = keyMap.find(_._1.treeEquals(col)).map(_._2)
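
Dropping the null arm looks safe on the assumption that an empty grouping set now arrives as an empty child list rather than null, so it falls through to the surviving arm, where folding over no columns leaves the seed bitmap:

```scala
// Sketch: with columns == Nil the foldLeft body never runs, so the
// bitmask stays at the seed value 0, which is what the deleted arm returned.
val bitmask = Seq.empty[Int].foldLeft(0)((bitmap, keyIndex) => bitmap | (1 << keyIndex))
assert(bitmask == 0)
```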
--- ParseDriver.scala ---
@@ -28,7 +28,25 @@ import org.apache.spark.sql.AnalysisException
  * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver
  */
 object ParseDriver extends Logging {
-  def parse(command: String, conf: ParserConf): ASTNode = {
+  /** Create a LogicalPlan ASTNode from a SQL command. */
+  def parsePlan(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser =>
+    parser.statement().getTree
+  }

[Review comment, Contributor]: can you add some function doc for these 3 functions, and also cover what the differences are?

+
+  /** Create an Expression ASTNode from a SQL command. */
+  def parseExpression(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser =>
+    parser.namedExpression().getTree
+  }
+
+  /** Create a TableIdentifier ASTNode from a SQL command. */
+  def parseTableName(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser =>
+    parser.tableName().getTree
+  }
+
+  private def parse(
+      command: String,
+      conf: ParserConf)(
+      toTree: SparkSqlParser => CommonTree): ASTNode = {
     logInfo(s"Parsing command: $command")
 
     // Setup error collection.
@@ -44,7 +62,7 @@ object ParseDriver extends Logging {
     parser.configure(conf, reporter)
 
     try {
-      val result = parser.statement()
+      val result = toTree(parser)
 
       // Check errors.
       reporter.checkForErrors()
@@ -57,7 +75,7 @@ object ParseDriver extends Logging {
         if (tree.token != null || tree.getChildCount == 0) tree
         else nonNullToken(tree.getChild(0).asInstanceOf[CommonTree])
       }
-      val tree = nonNullToken(result.getTree)
+      val tree = nonNullToken(result)
 
       // Make sure all boundaries are set.
       tree.setUnknownTokenBoundaries()
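
The three public entry points differ only in the ANTLR start rule they hand to the private `parse` method: `statement()` for complete plans, `namedExpression()` for single select expressions (the grammar rule added above), and `tableName()` for table identifiers; error collection, configuration, and tree normalization are shared. A hypothetical fourth entry point shows the shape (`groupByClause` is an assumed rule name, for illustration only):

```scala
/** Sketch only: any other start rule could be exposed through the same plumbing. */
def parseGroupBy(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser =>
  parser.groupByClause().getTree
}
```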
--- CatalystQlSuite.scala ---
@@ -17,36 +17,157 @@
 
 package org.apache.spark.sql.catalyst
 
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
+import org.apache.spark.unsafe.types.CalendarInterval
 
 class CatalystQlSuite extends PlanTest {
   val parser = new CatalystQl()
 
test("test case insensitive") {
val result = Project(UnresolvedAlias(Literal(1)):: Nil, OneRowRelation)
assert(result === parser.parsePlan("seLect 1"))
assert(result === parser.parsePlan("select 1"))
assert(result === parser.parsePlan("SELECT 1"))
}

test("test NOT operator with comparison operations") {
val parsed = parser.parsePlan("SELECT NOT TRUE > TRUE")
val expected = Project(
UnresolvedAlias(
Not(
GreaterThan(Literal(true), Literal(true)))
) :: Nil,
OneRowRelation)
comparePlans(parsed, expected)
}

test("support hive interval literal") {
def checkInterval(sql: String, result: CalendarInterval): Unit = {
val parsed = parser.parsePlan(sql)
val expected = Project(
UnresolvedAlias(
Literal(result)
) :: Nil,
OneRowRelation)
comparePlans(parsed, expected)
}

def checkYearMonth(lit: String): Unit = {
checkInterval(
s"SELECT INTERVAL '$lit' YEAR TO MONTH",
CalendarInterval.fromYearMonthString(lit))
}

def checkDayTime(lit: String): Unit = {
checkInterval(
s"SELECT INTERVAL '$lit' DAY TO SECOND",
CalendarInterval.fromDayTimeString(lit))
}

def checkSingleUnit(lit: String, unit: String): Unit = {
checkInterval(
s"SELECT INTERVAL '$lit' $unit",
CalendarInterval.fromSingleUnitString(unit, lit))
}

checkYearMonth("123-10")
checkYearMonth("496-0")
checkYearMonth("-2-3")
checkYearMonth("-123-0")

checkDayTime("99 11:22:33.123456789")
checkDayTime("-99 11:22:33.123456789")
checkDayTime("10 9:8:7.123456789")
checkDayTime("1 0:0:0")
checkDayTime("-1 0:0:0")
checkDayTime("1 0:0:1")

for (unit <- Seq("year", "month", "day", "hour", "minute", "second")) {
checkSingleUnit("7", unit)
checkSingleUnit("-7", unit)
checkSingleUnit("0", unit)
}

checkSingleUnit("13.123456789", "second")
checkSingleUnit("-13.123456789", "second")
}

test("support scientific notation") {
def assertRight(input: String, output: Double): Unit = {
val parsed = parser.parsePlan("SELECT " + input)
val expected = Project(
UnresolvedAlias(
Literal(output)
) :: Nil,
OneRowRelation)
comparePlans(parsed, expected)
}

assertRight("9.0e1", 90)
assertRight("0.9e+2", 90)
assertRight("900e-1", 90)
assertRight("900.0E-1", 90)
assertRight("9.e+1", 90)

intercept[AnalysisException](parser.parsePlan("SELECT .e3"))
}

test("parse expressions") {
compareExpressions(
parser.parseExpression("prinln('hello', 'world')"),
UnresolvedFunction(
"prinln", Literal("hello") :: Literal("world") :: Nil, false))

compareExpressions(
parser.parseExpression("1 + r.r As q"),
Alias(Add(Literal(1), UnresolvedAttribute("r.r")), "q")())

compareExpressions(
parser.parseExpression("1 - f('o', o(bar))"),
Subtract(Literal(1),
UnresolvedFunction("f",
Literal("o") ::
UnresolvedFunction("o", UnresolvedAttribute("bar") :: Nil, false) ::
Nil, false)))
}

test("table identifier") {
assert(TableIdentifier("q") === parser.parseTableIdentifier("q"))
assert(TableIdentifier("q", Some("d")) === parser.parseTableIdentifier("d.q"))
intercept[AnalysisException](parser.parseTableIdentifier(""))
// TODO parser swallows third identifier.
// intercept[AnalysisException](parser.parseTableIdentifier("d.q.g"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are we going to support this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should support this. Are there use cases for this? I'll create a fix, that'll throw an AnalysisException when we encounter this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, throw exception seems reasonable to me

}
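
The follow-up fix the author describes would make the commented-out assertion above pass. A sketch of the shape it could take, assuming the `extractTableIdent` and `cleanIdentifier` helpers visible elsewhere in CatalystQl (this code is not part of the diff):

```scala
// Reject qualified names with more than two parts instead of silently
// dropping one, so "d.q.g" fails analysis rather than parsing as d.q.
protected def extractTableIdent(tableNameParts: ASTNode): TableIdentifier = {
  tableNameParts.children.map { case Token(part, Nil) => cleanIdentifier(part) } match {
    case Seq(tableOnly) => TableIdentifier(tableOnly)
    case Seq(databaseName, table) => TableIdentifier(table, Some(databaseName))
    case other => throw new AnalysisException(s"Illegal table name: ${other.mkString(".")}")
  }
}
```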

test("parse union/except/intersect") {
parser.createPlan("select * from t1 union all select * from t2")
parser.createPlan("select * from t1 union distinct select * from t2")
parser.createPlan("select * from t1 union select * from t2")
parser.createPlan("select * from t1 except select * from t2")
parser.createPlan("select * from t1 intersect select * from t2")
parser.createPlan("(select * from t1) union all (select * from t2)")
parser.createPlan("(select * from t1) union distinct (select * from t2)")
parser.createPlan("(select * from t1) union (select * from t2)")
parser.createPlan("select * from ((select * from t1) union (select * from t2)) t")
parser.parsePlan("select * from t1 union all select * from t2")
parser.parsePlan("select * from t1 union distinct select * from t2")
parser.parsePlan("select * from t1 union select * from t2")
parser.parsePlan("select * from t1 except select * from t2")
parser.parsePlan("select * from t1 intersect select * from t2")
parser.parsePlan("(select * from t1) union all (select * from t2)")
parser.parsePlan("(select * from t1) union distinct (select * from t2)")
parser.parsePlan("(select * from t1) union (select * from t2)")
parser.parsePlan("select * from ((select * from t1) union (select * from t2)) t")
}

test("window function: better support of parentheses") {
parser.createPlan("select sum(product + 1) over (partition by ((1) + (product / 2)) " +
parser.parsePlan("select sum(product + 1) over (partition by ((1) + (product / 2)) " +
"order by 2) from windowData")
parser.createPlan("select sum(product + 1) over (partition by (1 + (product / 2)) " +
parser.parsePlan("select sum(product + 1) over (partition by (1 + (product / 2)) " +
"order by 2) from windowData")
parser.createPlan("select sum(product + 1) over (partition by ((product / 2) + 1) " +
parser.parsePlan("select sum(product + 1) over (partition by ((product / 2) + 1) " +
"order by 2) from windowData")

parser.createPlan("select sum(product + 1) over (partition by ((product) + (1)) order by 2) " +
parser.parsePlan("select sum(product + 1) over (partition by ((product) + (1)) order by 2) " +
"from windowData")
parser.createPlan("select sum(product + 1) over (partition by ((product) + 1) order by 2) " +
parser.parsePlan("select sum(product + 1) over (partition by ((product) + 1) order by 2) " +
"from windowData")
parser.createPlan("select sum(product + 1) over (partition by (product + (1)) order by 2) " +
parser.parsePlan("select sum(product + 1) over (partition by (product + (1)) order by 2) " +
"from windowData")
}
}
--- ExtendedHiveQlParser.scala ---
@@ -38,7 +38,7 @@ private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser {
 
   protected lazy val hiveQl: Parser[LogicalPlan] =
     restInput ^^ {
-      case statement => HiveQl.createPlan(statement.trim)
+      case statement => HiveQl.parsePlan(statement.trim)
     }
 
   protected lazy val dfs: Parser[LogicalPlan] =
--- HiveMetastoreCatalog.scala ---
@@ -414,8 +414,8 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
       alias match {
         // because hive use things like `_c0` to build the expanded text
         // currently we cannot support view from "create view v1(c1) as ..."
-        case None => Subquery(table.name, HiveQl.createPlan(viewText))
-        case Some(aliasText) => Subquery(aliasText, HiveQl.createPlan(viewText))
+        case None => Subquery(table.name, HiveQl.parsePlan(viewText))
+        case Some(aliasText) => Subquery(aliasText, HiveQl.parsePlan(viewText))
       }
     } else {
       MetastoreRelation(qualifiedTableName.database, qualifiedTableName.name, alias)(table)(hive)
--- sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala (10 additions, 9 deletions) ---
@@ -229,15 +229,16 @@ private[hive] object HiveQl extends SparkQl with Logging {
     CreateViewAsSelect(tableDesc, nodeToPlan(query), allowExist, replace, sql)
   }
 
-  protected override def createPlan(
-      sql: String,
-      node: ASTNode): LogicalPlan = {
-    if (nativeCommands.contains(node.text)) {
-      HiveNativeCommand(sql)
-    } else {
-      nodeToPlan(node) match {
-        case NativePlaceholder => HiveNativeCommand(sql)
-        case plan => plan
+  /** Creates LogicalPlan for a given SQL string. */
+  override def parsePlan(sql: String): LogicalPlan = {
+    safeParse(sql, ParseDriver.parsePlan(sql, conf)) { ast =>
+      if (nativeCommands.contains(ast.text)) {
+        HiveNativeCommand(sql)
+      } else {
+        nodeToPlan(ast) match {
+          case NativePlaceholder => HiveNativeCommand(sql)
+          case plan => plan
+        }
       }
     }
   }
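
With the protected createPlan hook gone, HiveQl overrides the public parsePlan directly, so Hive-native passthrough now also goes through safeParse's structured error reporting. Illustrative behavior, assuming the first statement's root token is in nativeCommands (both queries are examples, not from this diff):

```scala
HiveQl.parsePlan("SHOW DATABASES")      // root token in nativeCommands: becomes HiveNativeCommand("SHOW DATABASES")
HiveQl.parsePlan("SELECT key FROM src") // everything else: a regular Catalyst LogicalPlan via nodeToPlan
```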
--- ErrorPositionSuite.scala ---
@@ -117,9 +117,8 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd
    * @param token a unique token in the string that should be indicated by the exception
    */
   def positionTest(name: String, query: String, token: String): Unit = {
-    def ast = ParseDriver.parse(query, hiveContext.conf)
-    def parseTree =
-      Try(quietly(ast.treeString)).getOrElse("<failed to parse>")
+    def ast = ParseDriver.parsePlan(query, hiveContext.conf)
+    def parseTree = Try(quietly(ast.treeString)).getOrElse("<failed to parse>")
 
     test(name) {
       val error = intercept[AnalysisException] {
--- HiveQlSuite.scala ---
@@ -28,7 +28,7 @@ import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, HiveTable, M
 
 class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll {
   private def extractTableDesc(sql: String): (HiveTable, Boolean) = {
-    HiveQl.createPlan(sql).collect {
+    HiveQl.parsePlan(sql).collect {
       case CreateTableAsSelect(desc, child, allowExisting) => (desc, allowExisting)
     }.head
   }
Expand Down