Skip to content

Commit fc3c168

Browse files
author
Andrew Or
committed
Move things into new ParserUtils object
1 parent b7d4147 commit fc3c168

File tree

4 files changed

+177
-143
lines changed

4 files changed

+177
-143
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala

Lines changed: 1 addition & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.expressions._
2626
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
2727
import org.apache.spark.sql.catalyst.plans._
2828
import org.apache.spark.sql.catalyst.plans.logical._
29-
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
3029
import org.apache.spark.sql.types._
3130
import org.apache.spark.unsafe.types.CalendarInterval
3231
import org.apache.spark.util.random.RandomSampler
@@ -36,12 +35,7 @@ import org.apache.spark.util.random.RandomSampler
3635
* This class translates SQL to Catalyst [[LogicalPlan]]s or [[Expression]]s.
3736
*/
3837
private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) extends ParserInterface {
39-
object Token {
40-
def unapply(node: ASTNode): Some[(String, List[ASTNode])] = {
41-
CurrentOrigin.setPosition(node.line, node.positionInLine)
42-
node.pattern
43-
}
44-
}
38+
import ParserUtils._
4539

4640
/**
4741
* The safeParse method allows a user to focus on the parsing/AST transformation logic. This
@@ -82,102 +76,6 @@ private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) extends
8276
def parseTableIdentifier(sql: String): TableIdentifier =
8377
safeParse(sql, ParseDriver.parseTableName(sql, conf))(extractTableIdent)
8478

85-
def parseDdl(sql: String): Seq[Attribute] = {
86-
safeParse(sql, ParseDriver.parseExpression(sql, conf)) { ast =>
87-
val Token("TOK_CREATETABLE", children) = ast
88-
children
89-
.find(_.text == "TOK_TABCOLLIST")
90-
.getOrElse(sys.error("No columnList!"))
91-
.flatMap(_.children.map(nodeToAttribute))
92-
}
93-
}
94-
95-
protected def getClauses(
96-
clauseNames: Seq[String],
97-
nodeList: Seq[ASTNode]): Seq[Option[ASTNode]] = {
98-
var remainingNodes = nodeList
99-
val clauses = clauseNames.map { clauseName =>
100-
val (matches, nonMatches) = remainingNodes.partition(_.text.toUpperCase == clauseName)
101-
remainingNodes = nonMatches ++ (if (matches.nonEmpty) matches.tail else Nil)
102-
matches.headOption
103-
}
104-
105-
if (remainingNodes.nonEmpty) {
106-
sys.error(
107-
s"""Unhandled clauses: ${remainingNodes.map(_.treeString).mkString("\n")}.
108-
|You are likely trying to use an unsupported Hive feature."""".stripMargin)
109-
}
110-
clauses
111-
}
112-
113-
protected def getClause(clauseName: String, nodeList: Seq[ASTNode]): ASTNode =
114-
getClauseOption(clauseName, nodeList).getOrElse(sys.error(
115-
s"Expected clause $clauseName missing from ${nodeList.map(_.treeString).mkString("\n")}"))
116-
117-
protected def getClauseOption(clauseName: String, nodeList: Seq[ASTNode]): Option[ASTNode] = {
118-
nodeList.filter { case ast: ASTNode => ast.text == clauseName } match {
119-
case Seq(oneMatch) => Some(oneMatch)
120-
case Seq() => None
121-
case _ => sys.error(s"Found multiple instances of clause $clauseName")
122-
}
123-
}
124-
125-
protected def nodeToAttribute(node: ASTNode): Attribute = node match {
126-
case Token("TOK_TABCOL", Token(colName, Nil) :: dataType :: Nil) =>
127-
AttributeReference(colName, nodeToDataType(dataType), nullable = true)()
128-
case _ =>
129-
noParseRule("Attribute", node)
130-
}
131-
132-
protected def nodeToDataType(node: ASTNode): DataType = node match {
133-
case Token("TOK_DECIMAL", precision :: scale :: Nil) =>
134-
DecimalType(precision.text.toInt, scale.text.toInt)
135-
case Token("TOK_DECIMAL", precision :: Nil) =>
136-
DecimalType(precision.text.toInt, 0)
137-
case Token("TOK_DECIMAL", Nil) => DecimalType.USER_DEFAULT
138-
case Token("TOK_BIGINT", Nil) => LongType
139-
case Token("TOK_INT", Nil) => IntegerType
140-
case Token("TOK_TINYINT", Nil) => ByteType
141-
case Token("TOK_SMALLINT", Nil) => ShortType
142-
case Token("TOK_BOOLEAN", Nil) => BooleanType
143-
case Token("TOK_STRING", Nil) => StringType
144-
case Token("TOK_VARCHAR", Token(_, Nil) :: Nil) => StringType
145-
case Token("TOK_CHAR", Token(_, Nil) :: Nil) => StringType
146-
case Token("TOK_FLOAT", Nil) => FloatType
147-
case Token("TOK_DOUBLE", Nil) => DoubleType
148-
case Token("TOK_DATE", Nil) => DateType
149-
case Token("TOK_TIMESTAMP", Nil) => TimestampType
150-
case Token("TOK_BINARY", Nil) => BinaryType
151-
case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType))
152-
case Token("TOK_STRUCT", Token("TOK_TABCOLLIST", fields) :: Nil) =>
153-
StructType(fields.map(nodeToStructField))
154-
case Token("TOK_MAP", keyType :: valueType :: Nil) =>
155-
MapType(nodeToDataType(keyType), nodeToDataType(valueType))
156-
case _ =>
157-
noParseRule("DataType", node)
158-
}
159-
160-
protected def nodeToStructField(node: ASTNode): StructField = node match {
161-
case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: Nil) =>
162-
StructField(cleanIdentifier(fieldName), nodeToDataType(dataType), nullable = true)
163-
case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: comment :: Nil) =>
164-
val meta = new MetadataBuilder().putString("comment", unquoteString(comment.text)).build()
165-
StructField(cleanIdentifier(fieldName), nodeToDataType(dataType), nullable = true, meta)
166-
case _ =>
167-
noParseRule("StructField", node)
168-
}
169-
170-
protected def extractTableIdent(tableNameParts: ASTNode): TableIdentifier = {
171-
tableNameParts.children.map {
172-
case Token(part, Nil) => cleanIdentifier(part)
173-
} match {
174-
case Seq(tableOnly) => TableIdentifier(tableOnly)
175-
case Seq(databaseName, table) => TableIdentifier(table, Some(databaseName))
176-
case other => sys.error("Hive only supports tables names like 'tableName' " +
177-
s"or 'databaseName.tableName', found '$other'")
178-
}
179-
}
180-
18179
/**
18280
* SELECT MAX(value) FROM src GROUP BY k1, k2, k3 GROUPING SETS((k1, k2), (k2))
18381
* is equivalent to
@@ -625,42 +523,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
625523
noParseRule("Select", node)
626524
}
627525

628-
protected val escapedIdentifier = "`(.+)`".r
629-
protected val doubleQuotedString = "\"([^\"]+)\"".r
630-
protected val singleQuotedString = "'([^']+)'".r
631-
632-
protected def unquoteString(str: String) = str match {
633-
case singleQuotedString(s) => s
634-
case doubleQuotedString(s) => s
635-
case other => other
636-
}
637-
638-
/** Strips backticks from ident if present */
639-
protected def cleanIdentifier(ident: String): String = ident match {
640-
case escapedIdentifier(i) => i
641-
case plainIdent => plainIdent
642-
}
643-
644-
/* Case insensitive matches */
645-
val COUNT = "(?i)COUNT".r
646-
val SUM = "(?i)SUM".r
647-
val AND = "(?i)AND".r
648-
val OR = "(?i)OR".r
649-
val NOT = "(?i)NOT".r
650-
val TRUE = "(?i)TRUE".r
651-
val FALSE = "(?i)FALSE".r
652-
val LIKE = "(?i)LIKE".r
653-
val RLIKE = "(?i)RLIKE".r
654-
val REGEXP = "(?i)REGEXP".r
655-
val IN = "(?i)IN".r
656-
val DIV = "(?i)DIV".r
657-
val BETWEEN = "(?i)BETWEEN".r
658-
val WHEN = "(?i)WHEN".r
659-
val CASE = "(?i)CASE".r
660-
661-
val INTEGRAL = "[+-]?\\d+".r
662-
val DECIMAL = "[+-]?((\\d+(\\.\\d*)?)|(\\.\\d+))".r
663-
664526
protected def nodeToExpr(node: ASTNode): Expression = node match {
665527
/* Attribute References */
666528
case Token("TOK_TABLE_OR_COL", Token(name, Nil) :: Nil) =>
@@ -1007,6 +869,4 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
1007869

1008870
protected def nodeToGenerator(node: ASTNode): Generator = noParseRule("Generator", node)
1009871

1010-
protected def noParseRule(msg: String, node: ASTNode): Nothing = throw new NotImplementedError(
1011-
s"[$msg]: No parse rules for ASTNode type: ${node.tokenType}, tree:\n${node.treeString}")
1012872
}
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.parser
19+
20+
import org.apache.spark.sql.catalyst.TableIdentifier
21+
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
22+
import org.apache.spark.sql.types._
23+
24+
25+
/**
26+
* A collection of utility methods and patterns for parsing query texts.
27+
*/
28+
// TODO: merge with ParseUtils
29+
object ParserUtils {
30+
31+
object Token {
32+
def unapply(node: ASTNode): Some[(String, List[ASTNode])] = {
33+
CurrentOrigin.setPosition(node.line, node.positionInLine)
34+
node.pattern
35+
}
36+
}
37+
38+
private val escapedIdentifier = "`(.+)`".r
39+
private val doubleQuotedString = "\"([^\"]+)\"".r
40+
private val singleQuotedString = "'([^']+)'".r
41+
42+
// Token patterns
43+
val COUNT = "(?i)COUNT".r
44+
val SUM = "(?i)SUM".r
45+
val AND = "(?i)AND".r
46+
val OR = "(?i)OR".r
47+
val NOT = "(?i)NOT".r
48+
val TRUE = "(?i)TRUE".r
49+
val FALSE = "(?i)FALSE".r
50+
val LIKE = "(?i)LIKE".r
51+
val RLIKE = "(?i)RLIKE".r
52+
val REGEXP = "(?i)REGEXP".r
53+
val IN = "(?i)IN".r
54+
val DIV = "(?i)DIV".r
55+
val BETWEEN = "(?i)BETWEEN".r
56+
val WHEN = "(?i)WHEN".r
57+
val CASE = "(?i)CASE".r
58+
val INTEGRAL = "[+-]?\\d+".r
59+
val DECIMAL = "[+-]?((\\d+(\\.\\d*)?)|(\\.\\d+))".r
60+
61+
/**
62+
* Strip quotes, if any, from the string.
63+
*/
64+
def unquoteString(str: String): String = {
65+
str match {
66+
case singleQuotedString(s) => s
67+
case doubleQuotedString(s) => s
68+
case other => other
69+
}
70+
}
71+
72+
/**
73+
* Strip backticks, if any, from the string.
74+
*/
75+
def cleanIdentifier(ident: String): String = {
76+
ident match {
77+
case escapedIdentifier(i) => i
78+
case plainIdent => plainIdent
79+
}
80+
}
81+
82+
def getClauses(
83+
clauseNames: Seq[String],
84+
nodeList: Seq[ASTNode]): Seq[Option[ASTNode]] = {
85+
var remainingNodes = nodeList
86+
val clauses = clauseNames.map { clauseName =>
87+
val (matches, nonMatches) = remainingNodes.partition(_.text.toUpperCase == clauseName)
88+
remainingNodes = nonMatches ++ (if (matches.nonEmpty) matches.tail else Nil)
89+
matches.headOption
90+
}
91+
92+
if (remainingNodes.nonEmpty) {
93+
sys.error(
94+
s"""Unhandled clauses: ${remainingNodes.map(_.treeString).mkString("\n")}.
95+
|You are likely trying to use an unsupported Hive feature."""".stripMargin)
96+
}
97+
clauses
98+
}
99+
100+
def getClause(clauseName: String, nodeList: Seq[ASTNode]): ASTNode = {
101+
getClauseOption(clauseName, nodeList).getOrElse(sys.error(
102+
s"Expected clause $clauseName missing from ${nodeList.map(_.treeString).mkString("\n")}"))
103+
}
104+
105+
def getClauseOption(clauseName: String, nodeList: Seq[ASTNode]): Option[ASTNode] = {
106+
nodeList.filter { case ast: ASTNode => ast.text == clauseName } match {
107+
case Seq(oneMatch) => Some(oneMatch)
108+
case Seq() => None
109+
case _ => sys.error(s"Found multiple instances of clause $clauseName")
110+
}
111+
}
112+
113+
def extractTableIdent(tableNameParts: ASTNode): TableIdentifier = {
114+
tableNameParts.children.map {
115+
case Token(part, Nil) => cleanIdentifier(part)
116+
} match {
117+
case Seq(tableOnly) => TableIdentifier(tableOnly)
118+
case Seq(databaseName, table) => TableIdentifier(table, Some(databaseName))
119+
case other => sys.error("Hive only supports tables names like 'tableName' " +
120+
s"or 'databaseName.tableName', found '$other'")
121+
}
122+
}
123+
124+
def nodeToDataType(node: ASTNode): DataType = node match {
125+
case Token("TOK_DECIMAL", precision :: scale :: Nil) =>
126+
DecimalType(precision.text.toInt, scale.text.toInt)
127+
case Token("TOK_DECIMAL", precision :: Nil) =>
128+
DecimalType(precision.text.toInt, 0)
129+
case Token("TOK_DECIMAL", Nil) => DecimalType.USER_DEFAULT
130+
case Token("TOK_BIGINT", Nil) => LongType
131+
case Token("TOK_INT", Nil) => IntegerType
132+
case Token("TOK_TINYINT", Nil) => ByteType
133+
case Token("TOK_SMALLINT", Nil) => ShortType
134+
case Token("TOK_BOOLEAN", Nil) => BooleanType
135+
case Token("TOK_STRING", Nil) => StringType
136+
case Token("TOK_VARCHAR", Token(_, Nil) :: Nil) => StringType
137+
case Token("TOK_CHAR", Token(_, Nil) :: Nil) => StringType
138+
case Token("TOK_FLOAT", Nil) => FloatType
139+
case Token("TOK_DOUBLE", Nil) => DoubleType
140+
case Token("TOK_DATE", Nil) => DateType
141+
case Token("TOK_TIMESTAMP", Nil) => TimestampType
142+
case Token("TOK_BINARY", Nil) => BinaryType
143+
case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType))
144+
case Token("TOK_STRUCT", Token("TOK_TABCOLLIST", fields) :: Nil) =>
145+
StructType(fields.map(nodeToStructField))
146+
case Token("TOK_MAP", keyType :: valueType :: Nil) =>
147+
MapType(nodeToDataType(keyType), nodeToDataType(valueType))
148+
case _ =>
149+
noParseRule("DataType", node)
150+
}
151+
152+
def nodeToStructField(node: ASTNode): StructField = node match {
153+
case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: Nil) =>
154+
StructField(cleanIdentifier(fieldName), nodeToDataType(dataType), nullable = true)
155+
case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: comment :: Nil) =>
156+
val meta = new MetadataBuilder().putString("comment", unquoteString(comment.text)).build()
157+
StructField(cleanIdentifier(fieldName), nodeToDataType(dataType), nullable = true, meta)
158+
case _ =>
159+
noParseRule("StructField", node)
160+
}
161+
162+
/**
163+
* Throw an exception because we cannot parse the given node.
164+
*/
165+
def noParseRule(msg: String, node: ASTNode): Nothing = {
166+
throw new NotImplementedError(
167+
s"[$msg]: No parse rules for ASTNode type: ${node.tokenType}, tree:\n${node.treeString}")
168+
}
169+
170+
}

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@ package org.apache.spark.sql.execution
1919
import org.apache.spark.sql.{AnalysisException, SaveMode}
2020
import org.apache.spark.sql.catalyst.TableIdentifier
2121
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
22-
import org.apache.spark.sql.catalyst.parser.{ASTNode, CatalystQl, ParserConf, SimpleParserConf}
22+
import org.apache.spark.sql.catalyst.parser._
2323
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
2424
import org.apache.spark.sql.execution.command._
2525
import org.apache.spark.sql.execution.datasources._
2626
import org.apache.spark.sql.types.StructType
2727

2828
private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends CatalystQl(conf) {
29+
import ParserUtils._
30+
2931
/** Check if a command should not be explained. */
3032
protected def isNoExplainCommand(command: String): Boolean = "TOK_DESCTABLE" == command
3133

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier
3535
import org.apache.spark.sql.catalyst.catalog._
3636
import org.apache.spark.sql.catalyst.expressions._
3737
import org.apache.spark.sql.catalyst.parser._
38-
import org.apache.spark.sql.catalyst.parser.ParseUtils._
3938
import org.apache.spark.sql.catalyst.plans._
4039
import org.apache.spark.sql.catalyst.plans.logical._
4140
import org.apache.spark.sql.execution.SparkQl
@@ -81,6 +80,9 @@ private[hive] case class CreateViewAsSelect(
8180

8281
/** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */
8382
private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging {
83+
import ParseUtils._
84+
import ParserUtils._
85+
8486
protected val nativeCommands = Seq(
8587
"TOK_ALTERDATABASE_OWNER",
8688
"TOK_ALTERDATABASE_PROPERTIES",

0 commit comments

Comments
 (0)