diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index c75025e79af4a..3095e640f1d79 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -272,8 +272,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
   }
 
   protected[sql] override def parseSql(sql: String): LogicalPlan = {
-    var state = SessionState.get()
-    if (state == null) {
+    if (SessionState.get() == null) {
       SessionState.setCurrentSessionState(tlSession.get().asInstanceOf[SQLSession].sessionState)
     }
     super.parseSql(substitutor.substitute(hiveconf, sql))
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index ea1521a48c8a7..5089e8c208b34 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -588,12 +588,39 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
       // Wait until children are resolved.
       case p: LogicalPlan if !p.childrenResolved => p
       case p: LogicalPlan if p.resolved => p
+
+      case p @ CreateTableAsSelect(table, child, allowExisting) if table.tableType == VirtualView =>
+        val childSchema = child.output.map { attr =>
+          HiveColumn(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), null)
+        }
+
+        val newNames = table.schema.map(_.name)
+
+        val schema = if (table.schema.nonEmpty) {
+          assert(newNames.length == childSchema.length)
+          assert(newNames.map(_.toLowerCase).distinct.length == newNames.length)
+          childSchema.zip(table.schema).map {
+            case (f1, f2) => HiveColumn(f1.name, f1.hiveType, f2.comment)
+          }
+        } else childSchema
+
+        val (dbName, tblName) = processDatabaseAndTableName(
+          table.specifiedDatabase.getOrElse(client.currentDatabase), table.name)
+
+        execution.CreateViewAsSelect(
+          table.copy(
+            specifiedDatabase = Some(dbName),
+            name = tblName,
+            schema = schema),
+          newNames,
+          allowExisting)
+
       case p @ CreateTableAsSelect(table, child, allowExisting) =>
         val schema = if (table.schema.nonEmpty) {
           table.schema
         } else {
           child.output.map {
-            attr => new HiveColumn(
+            attr => HiveColumn(
               attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), null)
           }
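For readers skimming the patch: the new `VirtualView` rule above takes column names and types from the child plan's output and keeps only the comments from the user-written column list; the user's names are carried separately in `newNames` and applied later by `CreateViewAsSelect`. A minimal standalone sketch of that merge, assuming a simplified `HiveColumn` (the real one lives in `client/ClientInterface.scala`):

```scala
// Simplified stand-in for org.apache.spark.sql.hive.client.HiveColumn.
case class HiveColumn(name: String, hiveType: String, comment: String)

// Derived from the child plan's output.
val childSchema = Seq(HiveColumn("id", "bigint", null), HiveColumn("data", "string", null))
// What the user wrote: CREATE VIEW v (key COMMENT 'the key', value) AS ...
val userSchema = Seq(HiveColumn("key", null, "the key"), HiveColumn("value", null, null))

val newNames = userSchema.map(_.name)
assert(newNames.length == childSchema.length)

// Keep the child's names and types, attach the user's comments; the user's
// names go into `newNames` for the aliasing step in CreateViewAsSelect.
val merged = childSchema.zip(userSchema).map {
  case (child, user) => HiveColumn(child.name, child.hiveType, user.comment)
}
// merged == Seq(HiveColumn("id", "bigint", "the key"), HiveColumn("data", "string", null))
```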
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 256440a9a2e97..278ba47db9083 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -110,7 +110,6 @@ private[hive] object HiveQl extends Logging {
     "TOK_CREATEFUNCTION",
     "TOK_CREATEINDEX",
     "TOK_CREATEROLE",
-    "TOK_CREATEVIEW",
 
     "TOK_DESCDATABASE",
     "TOK_DESCFUNCTION",
@@ -248,7 +247,7 @@ private[hive] object HiveQl extends Logging {
   /**
    * Returns the AST for the given SQL string.
    */
-  def getAst(sql: String): ASTNode = {
+  def getAst(sql: String): (Context, ASTNode) = {
     /*
      * Context has to be passed in hive0.13.1.
      * Otherwise, there will be Null pointer exception,
@@ -256,8 +255,7 @@ private[hive] object HiveQl extends Logging {
      */
     val hContext = new Context(SessionState.get().getConf())
     val node = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql, hContext))
-    hContext.clear()
-    node
+    hContext -> node
   }
 
   /**
@@ -280,15 +278,18 @@ private[hive] object HiveQl extends Logging {
   /** Creates LogicalPlan for a given HiveQL string. */
   def createPlan(sql: String): LogicalPlan = {
     try {
-      val tree = getAst(sql)
-      if (nativeCommands contains tree.getText) {
+      val (ctx, tree) = getAst(sql)
+      val result = if (nativeCommands contains tree.getText) {
         HiveNativeCommand(sql)
       } else {
+        implicit val _ctx = ctx
         nodeToPlan(tree) match {
           case NativePlaceholder => HiveNativeCommand(sql)
           case other => other
         }
       }
+      ctx.clear()
+      result
     } catch {
       case pe: org.apache.hadoop.hive.ql.parse.ParseException =>
         pe.getMessage match {
@@ -304,7 +305,7 @@ private[hive] object HiveQl extends Logging {
           throw new AnalysisException(
             s"""
               |Unsupported language features in query: $sql
-              |${dumpTree(getAst(sql))}
+              |${dumpTree(getAst(sql)._2)}
               |$e
               |${e.getStackTrace.head}
             """.stripMargin)
@@ -342,7 +343,8 @@ private[hive] object HiveQl extends Logging {
     }
   }
 
-  protected def getClauses(clauseNames: Seq[String], nodeList: Seq[ASTNode]): Seq[Option[Node]] = {
+  protected def getClauses(
+      clauseNames: Seq[String], nodeList: Seq[ASTNode]): Seq[Option[ASTNode]] = {
     var remainingNodes = nodeList
     val clauses = clauseNames.map { clauseName =>
       val (matches, nonMatches) = remainingNodes.partition(_.getText.toUpperCase == clauseName)
@@ -423,15 +425,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
   }
 
   protected def extractDbNameTableName(tableNameParts: Node): (Option[String], String) = {
-    val (db, tableName) =
-      tableNameParts.getChildren.asScala.map {
-        case Token(part, Nil) => cleanIdentifier(part)
-      } match {
-        case Seq(tableOnly) => (None, tableOnly)
-        case Seq(databaseName, table) => (Some(databaseName), table)
-      }
-
-    (db, tableName)
+    tableNameParts.getChildren.asScala.map {
+      case Token(part, Nil) => cleanIdentifier(part)
+    } match {
+      case Seq(tableOnly) => (None, tableOnly)
+      case Seq(databaseName, table) => (Some(databaseName), table)
+    }
   }
 
   protected def extractTableIdent(tableNameParts: Node): Seq[String] = {
@@ -489,7 +488,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
     }
   }
 
-  protected def nodeToPlan(node: Node): LogicalPlan = node match {
+  protected def nodeToPlan(node: Node)(implicit ctx: Context): LogicalPlan = node match {
     // Special drop table that also uncaches.
     case Token("TOK_DROPTABLE",
            Token("TOK_TABNAME", tableNameParts) ::
@@ -563,6 +562,77 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
       }
     }
 
+    case Token("TOK_CREATEVIEW", children)
+        if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty =>
+
+      val Seq(
+        Some(viewNameParts),
+        Some(query),
+        maybeComment,
+        allowExisting,
+        maybeProperties,
+        maybeColumns,
+        maybePartCols
+      ) = getClauses(
+        Seq(
+          "TOK_TABNAME",
+          "TOK_QUERY",
+          "TOK_TABLECOMMENT",
+          "TOK_IFNOTEXISTS",
+          "TOK_TABLEPROPERTIES",
+          "TOK_TABCOLNAME",
+          "TOK_VIEWPARTCOLS"),
+        children)
+
+      // Partitioned views are not supported yet, fall back to Hive.
+      if (maybePartCols.isDefined) {
+        val sql = ctx.getTokenRewriteStream
+          .toString(query.parent.getTokenStartIndex, query.parent.getTokenStopIndex)
+        HiveNativeCommand(sql)
+      } else {
+        val (db, viewName) = extractDbNameTableName(viewNameParts)
+
+        val originalText = ctx.getTokenRewriteStream
+          .toString(query.getTokenStartIndex, query.getTokenStopIndex)
+
+        val schema = maybeColumns.map { cols =>
+          BaseSemanticAnalyzer.getColumns(cols, true).asScala.map { field =>
+            HiveColumn(field.getName, field.getType, field.getComment)
+          }
+        }.getOrElse(Seq.empty[HiveColumn])
+
+        val properties = scala.collection.mutable.Map.empty[String, String]
+
+        maybeProperties.foreach {
+          case Token("TOK_TABLEPROPERTIES", list :: Nil) =>
+            properties ++= getProperties(list)
+        }
+
+        maybeComment.foreach {
+          case Token("TOK_TABLECOMMENT", child :: Nil) =>
+            val comment = BaseSemanticAnalyzer.unescapeSQLString(child.getText)
+            if (comment ne null) {
+              properties += ("comment" -> comment)
+            }
+        }
+
+        val tableDesc = HiveTable(
+          specifiedDatabase = db,
+          name = viewName,
+          schema = schema,
+          partitionColumns = Seq.empty[HiveColumn],
+          properties = properties.toMap,
+          serdeProperties = Map[String, String](),
+          tableType = VirtualView,
+          location = None,
+          inputFormat = None,
+          outputFormat = None,
+          serde = None,
+          viewText = Some(originalText))
+
+        CreateTableAsSelect(tableDesc, nodeToPlan(query), allowExisting.isDefined)
+      }
+
     case Token("TOK_CREATETABLE", children)
       if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty =>
       // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
@@ -1102,7 +1172,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
   val allJoinTokens = "(TOK_.*JOIN)".r
   val laterViewToken = "TOK_LATERAL_VIEW(.*)".r
 
-  def nodeToRelation(node: Node): LogicalPlan = node match {
+  def nodeToRelation(node: Node)(implicit ctx: Context): LogicalPlan = node match {
     case Token("TOK_SUBQUERY", query :: Token(alias, Nil) :: Nil) =>
       Subquery(cleanIdentifier(alias), nodeToPlan(query))
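The new `TOK_CREATEVIEW` case above destructures its children with `getClauses`: clauses are matched by token text in the order requested, so clauses the user omitted surface as `None`. A self-contained sketch of that idiom, with a toy `Node` standing in for Hive's `ASTNode` (names here are simplified, not the real parser types):

```scala
case class Node(text: String) // toy stand-in for Hive's ASTNode

def getClauses(clauseNames: Seq[String], nodeList: Seq[Node]): Seq[Option[Node]] = {
  var remaining = nodeList
  clauseNames.map { name =>
    // Peel off the child whose token text matches this clause name, if any.
    val (matches, rest) = remaining.partition(_.text == name)
    remaining = rest
    matches.headOption
  }
}

// A view without IF NOT EXISTS or partition columns:
val children = Seq(Node("TOK_TABNAME"), Node("TOK_QUERY"))
val Seq(Some(tabName), Some(query), allowExisting, partCols) =
  getClauses(Seq("TOK_TABNAME", "TOK_QUERY", "TOK_IFNOTEXISTS", "TOK_VIEWPARTCOLS"), children)
// allowExisting == None, partCols == None
```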
case Token("TOK_DROPTABLE", Token("TOK_TABNAME", tableNameParts) :: @@ -563,6 +562,77 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } } + case Token("TOK_CREATEVIEW", children) + if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty => + + val Seq( + Some(viewNameParts), + Some(query), + maybeComment, + allowExisting, + maybeProperties, + maybeColumns, + maybePartCols + ) = getClauses( + Seq( + "TOK_TABNAME", + "TOK_QUERY", + "TOK_TABLECOMMENT", + "TOK_IFNOTEXISTS", + "TOK_TABLEPROPERTIES", + "TOK_TABCOLNAME", + "TOK_VIEWPARTCOLS"), + children) + + if (maybePartCols.isDefined) { + val sql = ctx.getTokenRewriteStream + .toString(query.parent.getTokenStartIndex, query.parent.getTokenStopIndex) + println(sql) + HiveNativeCommand(sql) + } else { + val (db, viewName) = extractDbNameTableName(viewNameParts) + + val originalText = ctx.getTokenRewriteStream + .toString(query.getTokenStartIndex, query.getTokenStopIndex) + + val schema = maybeColumns.map { cols => + BaseSemanticAnalyzer.getColumns(cols, true).asScala.map { field => + HiveColumn(field.getName, field.getType, field.getComment) + } + }.getOrElse(Seq.empty[HiveColumn]) + + val properties = scala.collection.mutable.Map.empty[String, String] + + maybeProperties.foreach { + case Token("TOK_TABLEPROPERTIES", list :: Nil) => + properties ++= getProperties(list) + } + + maybeComment.foreach { + case Token("TOK_TABLECOMMENT", child :: Nil) => + val comment = BaseSemanticAnalyzer.unescapeSQLString(child.getText) + if (comment ne null) { + properties += ("comment" -> comment) + } + } + + val tableDesc = HiveTable( + specifiedDatabase = db, + name = viewName, + schema = schema, + partitionColumns = Seq.empty[HiveColumn], + properties = properties.toMap, + serdeProperties = Map[String, String](), + tableType = VirtualView, + location = None, + inputFormat = None, + outputFormat = None, + serde = None, + viewText = Some(originalText)) + + CreateTableAsSelect(tableDesc, nodeToPlan(query), allowExisting.isDefined) + } + case Token("TOK_CREATETABLE", children) if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty => // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL @@ -1102,7 +1172,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val allJoinTokens = "(TOK_.*JOIN)".r val laterViewToken = "TOK_LATERAL_VIEW(.*)".r - def nodeToRelation(node: Node): LogicalPlan = node match { + def nodeToRelation(node: Node)(implicit ctx: Context): LogicalPlan = node match { case Token("TOK_SUBQUERY", query :: Token(alias, Nil) :: Nil) => Subquery(cleanIdentifier(alias), nodeToPlan(query)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala index 3811c152a7ae6..dd9dbce76612a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala @@ -23,9 +23,7 @@ import java.util.{Map => JMap} import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException} import org.apache.spark.sql.catalyst.expressions.Expression -private[hive] case class HiveDatabase( - name: String, - location: String) +private[hive] case class HiveDatabase(name: String, location: String) private[hive] abstract class TableType { val name: String } private[hive] case object ExternalTable extends TableType { override 
val name = "EXTERNAL_TABLE" } @@ -126,6 +124,9 @@ private[hive] trait ClientInterface { /** Returns the metadata for the specified table or None if it doens't exist. */ def getTableOption(dbName: String, tableName: String): Option[HiveTable] + /** Creates a view with the given metadata. */ + def createView(view: HiveTable): Unit + /** Creates a table with the given metadata. */ def createTable(table: HiveTable): Unit diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 4d1e3ed9198e6..79a72aa7d86b6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -354,6 +354,18 @@ private[hive] class ClientWrapper( qlTable } + override def createView(view: HiveTable): Unit = withHiveState { + val tbl = new metadata.Table(view.database, view.name) + tbl.setTableType(HTableType.VIRTUAL_VIEW) + tbl.setSerializationLib(null) + tbl.clearSerDeInfo() + tbl.setViewOriginalText(view.viewText.get) + tbl.setViewExpandedText(view.viewText.get) + tbl.setFields(view.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) + view.properties.foreach { case (k, v) => tbl.setProperty(k, v) } + client.createTable(tbl) + } + override def createTable(table: HiveTable): Unit = withHiveState { val qlTable = toQlTable(table) client.createTable(qlTable) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala new file mode 100644 index 0000000000000..535fa2d13879a --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala @@ -0,0 +1,50 @@ +package org.apache.spark.sql.hive.execution + +import org.apache.hadoop.hive.ql.metadata.HiveUtils +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.hive.client.HiveTable + +private[hive] +case class CreateViewAsSelect( + tableDesc: HiveTable, + newNames: Seq[String], + allowExisting: Boolean) extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val hiveContext = sqlContext.asInstanceOf[HiveContext] + val database = tableDesc.database + val viewName = tableDesc.name + + if (hiveContext.catalog.tableExists(Seq(database, viewName))) { + if (allowExisting) { + // view already exists, will do nothing, to keep consistent with Hive + } else { + throw new AnalysisException(s"$database.$viewName already exists.") + } + } else { + val tbl = if (newNames.nonEmpty) { + val sb = new StringBuilder + sb.append("SELECT ") + for (i <- 0 until newNames.length) { + if (i > 0) { + sb.append(", ") + } + sb.append(HiveUtils.unparseIdentifier(tableDesc.schema(i).name)) + sb.append(" AS ") + sb.append(HiveUtils.unparseIdentifier(newNames(i))) + } + sb.append(" FROM (") + sb.append(tableDesc.viewText.get) + sb.append(") ") + sb.append(HiveUtils.unparseIdentifier(tableDesc.name)) + tableDesc.copy(viewText = Some(sb.toString)) + } else tableDesc + + hiveContext.catalog.client.createView(tbl) + } + + Seq.empty[Row] + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
index cf737836939f9..fb142378902eb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
@@ -118,7 +118,7 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd
    */
   def positionTest(name: String, query: String, token: String): Unit = {
     def parseTree =
-      Try(quietly(HiveQl.dumpTree(HiveQl.getAst(query)))).getOrElse("")
+      Try(quietly(HiveQl.dumpTree(HiveQl.getAst(query)._2))).getOrElse("")
 
     test(name) {
       val error = intercept[AnalysisException] {
@@ -142,7 +142,7 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd
       val actualStart = error.startPosition.getOrElse {
         fail(
           s"start not returned for error on token $token\n" +
-            HiveQl.dumpTree(HiveQl.getAst(query))
+            HiveQl.dumpTree(HiveQl.getAst(query)._2)
         )
       }
       assert(expectedStart === actualStart,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 8c3f9ac202637..a321efcd58447 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -1248,4 +1248,10 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
         """.stripMargin), Row("b", 6.0) :: Row("a", 7.0) :: Nil)
     }
   }
+
+  test("SPARK-10337: correctly handle hive views") {
+    sqlContext.range(1, 10).write.format("json").saveAsTable("jt")
+    sql("CREATE VIEW testView AS SELECT id FROM jt")
+    checkAnswer(sql("SELECT * FROM testView ORDER BY id"), (1 to 9).map(i => Row(i)))
+  }
 }
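A hedged end-to-end sketch of what the patch enables, in the style of the new regression test above (assuming a HiveContext-backed `sqlContext` with `sql` in scope, as in the suite; the view name and comment are made up):

```scala
sqlContext.range(1, 10).write.format("json").saveAsTable("jt")

// Views can now rename the underlying columns and attach comments to them.
sql("CREATE VIEW v1 (renamed_id COMMENT 'same as jt.id') AS SELECT id FROM jt")

// IF NOT EXISTS is a no-op when the view already exists, matching Hive.
sql("CREATE VIEW IF NOT EXISTS v1 AS SELECT id FROM jt")

sql("SELECT renamed_id FROM v1 WHERE renamed_id > 5").show()
```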