Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,10 @@ class SqlParser extends AbstractSparkSQLParser {
joinedRelation | relationFactor

protected lazy val relationFactor: Parser[LogicalPlan] =
( ident ~ (opt(AS) ~> opt(ident)) ^^ {
case tableName ~ alias => UnresolvedRelation(None, tableName, alias)
( rep1sep(ident, ".") ~ (opt(AS) ~> opt(ident)) ^^ {
case tableIdent ~ alias => UnresolvedRelation(tableIdent, alias)
}
| ("(" ~> start <~ ")") ~ (AS.? ~> ident) ^^ { case s ~ a => Subquery(a, s) }
| ("(" ~> start <~ ")") ~ (AS.? ~> ident) ^^ { case s ~ a => Subquery(a, s) }
)

protected lazy val joinedRelation: Parser[LogicalPlan] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,11 @@ class Analyzer(catalog: Catalog,
*/
object ResolveRelations extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case i @ InsertIntoTable(UnresolvedRelation(databaseName, name, alias), _, _, _) =>
case i @ InsertIntoTable(UnresolvedRelation(tableIdentifier, alias), _, _, _) =>
i.copy(
table = EliminateAnalysisOperators(catalog.lookupRelation(databaseName, name, alias)))
case UnresolvedRelation(databaseName, name, alias) =>
catalog.lookupRelation(databaseName, name, alias)
table = EliminateAnalysisOperators(catalog.lookupRelation(tableIdentifier, alias)))
case UnresolvedRelation(tableIdentifier, alias) =>
catalog.lookupRelation(tableIdentifier, alias)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,77 +28,74 @@ trait Catalog {

def caseSensitive: Boolean

def tableExists(db: Option[String], tableName: String): Boolean
def tableExists(tableIdentifier: Seq[String]): Boolean

def lookupRelation(
databaseName: Option[String],
tableName: String,
alias: Option[String] = None): LogicalPlan
tableIdentifier: Seq[String],
alias: Option[String] = None): LogicalPlan

def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit
def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit

def unregisterTable(databaseName: Option[String], tableName: String): Unit
def unregisterTable(tableIdentifier: Seq[String]): Unit

def unregisterAllTables(): Unit

protected def processDatabaseAndTableName(
databaseName: Option[String],
tableName: String): (Option[String], String) = {
protected def processTableIdentifier(tableIdentifier: Seq[String]): Seq[String] = {
if (!caseSensitive) {
(databaseName.map(_.toLowerCase), tableName.toLowerCase)
tableIdentifier.map(_.toLowerCase)
} else {
(databaseName, tableName)
tableIdentifier
}
}

protected def processDatabaseAndTableName(
databaseName: String,
tableName: String): (String, String) = {
if (!caseSensitive) {
(databaseName.toLowerCase, tableName.toLowerCase)
protected def getDbTableName(tableIdent: Seq[String]): String = {
val size = tableIdent.size
if (size <= 2) {
tableIdent.mkString(".")
} else {
(databaseName, tableName)
tableIdent.slice(size - 2, size).mkString(".")
}
}

protected def getDBTable(tableIdent: Seq[String]) : (Option[String], String) = {
(tableIdent.lift(tableIdent.size - 2), tableIdent.last)
}
}

class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
val tables = new mutable.HashMap[String, LogicalPlan]()

override def registerTable(
databaseName: Option[String],
tableName: String,
tableIdentifier: Seq[String],
plan: LogicalPlan): Unit = {
val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
tables += ((tblName, plan))
val tableIdent = processTableIdentifier(tableIdentifier)
tables += ((getDbTableName(tableIdent), plan))
}

override def unregisterTable(
databaseName: Option[String],
tableName: String) = {
val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
tables -= tblName
override def unregisterTable(tableIdentifier: Seq[String]) = {
val tableIdent = processTableIdentifier(tableIdentifier)
tables -= getDbTableName(tableIdent)
}

override def unregisterAllTables() = {
tables.clear()
}

override def tableExists(db: Option[String], tableName: String): Boolean = {
val (dbName, tblName) = processDatabaseAndTableName(db, tableName)
tables.get(tblName) match {
override def tableExists(tableIdentifier: Seq[String]): Boolean = {
val tableIdent = processTableIdentifier(tableIdentifier)
tables.get(getDbTableName(tableIdent)) match {
case Some(_) => true
case None => false
}
}

override def lookupRelation(
databaseName: Option[String],
tableName: String,
tableIdentifier: Seq[String],
alias: Option[String] = None): LogicalPlan = {
val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
val table = tables.getOrElse(tblName, sys.error(s"Table Not Found: $tableName"))
val tableWithQualifiers = Subquery(tblName, table)
val tableIdent = processTableIdentifier(tableIdentifier)
val tableFullName = getDbTableName(tableIdent)
val table = tables.getOrElse(tableFullName, sys.error(s"Table Not Found: $tableFullName"))
val tableWithQualifiers = Subquery(tableIdent.last, table)

// If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
// properly qualified with this alias.
Expand All @@ -117,41 +114,39 @@ trait OverrideCatalog extends Catalog {
// TODO: This doesn't work when the database changes...
val overrides = new mutable.HashMap[(Option[String],String), LogicalPlan]()

abstract override def tableExists(db: Option[String], tableName: String): Boolean = {
val (dbName, tblName) = processDatabaseAndTableName(db, tableName)
overrides.get((dbName, tblName)) match {
abstract override def tableExists(tableIdentifier: Seq[String]): Boolean = {
val tableIdent = processTableIdentifier(tableIdentifier)
overrides.get(getDBTable(tableIdent)) match {
case Some(_) => true
case None => super.tableExists(db, tableName)
case None => super.tableExists(tableIdentifier)
}
}

abstract override def lookupRelation(
databaseName: Option[String],
tableName: String,
tableIdentifier: Seq[String],
alias: Option[String] = None): LogicalPlan = {
val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
val overriddenTable = overrides.get((dbName, tblName))
val tableWithQualifers = overriddenTable.map(r => Subquery(tblName, r))
val tableIdent = processTableIdentifier(tableIdentifier)
val overriddenTable = overrides.get(getDBTable(tableIdent))
val tableWithQualifers = overriddenTable.map(r => Subquery(tableIdent.last, r))

// If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
// properly qualified with this alias.
val withAlias =
tableWithQualifers.map(r => alias.map(a => Subquery(a, r)).getOrElse(r))

withAlias.getOrElse(super.lookupRelation(dbName, tblName, alias))
withAlias.getOrElse(super.lookupRelation(tableIdentifier, alias))
}

override def registerTable(
databaseName: Option[String],
tableName: String,
tableIdentifier: Seq[String],
plan: LogicalPlan): Unit = {
val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
overrides.put((dbName, tblName), plan)
val tableIdent = processTableIdentifier(tableIdentifier)
overrides.put(getDBTable(tableIdent), plan)
}

override def unregisterTable(databaseName: Option[String], tableName: String): Unit = {
val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
overrides.remove((dbName, tblName))
override def unregisterTable(tableIdentifier: Seq[String]): Unit = {
val tableIdent = processTableIdentifier(tableIdentifier)
overrides.remove(getDBTable(tableIdent))
}

override def unregisterAllTables(): Unit = {
Expand All @@ -167,22 +162,21 @@ object EmptyCatalog extends Catalog {

val caseSensitive: Boolean = true

def tableExists(db: Option[String], tableName: String): Boolean = {
def tableExists(tableIdentifier: Seq[String]): Boolean = {
throw new UnsupportedOperationException
}

def lookupRelation(
databaseName: Option[String],
tableName: String,
tableIdentifier: Seq[String],
alias: Option[String] = None) = {
throw new UnsupportedOperationException
}

def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = {
def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = {
throw new UnsupportedOperationException
}

def unregisterTable(databaseName: Option[String], tableName: String): Unit = {
def unregisterTable(tableIdentifier: Seq[String]): Unit = {
throw new UnsupportedOperationException
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: Str
* Holds the name of a relation that has yet to be looked up in a [[Catalog]].
*/
case class UnresolvedRelation(
databaseName: Option[String],
tableName: String,
tableIdentifier: Seq[String],
alias: Option[String] = None) extends LeafNode {
override def output = Nil
override lazy val resolved = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ package object dsl {

def insertInto(tableName: String, overwrite: Boolean = false) =
InsertIntoTable(
analysis.UnresolvedRelation(None, tableName), Map.empty, logicalPlan, overwrite)
analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite)

def analyze = analysis.SimpleAnalyzer(logicalPlan)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
AttributeReference("e", ShortType)())

before {
caseSensitiveCatalog.registerTable(None, "TaBlE", testRelation)
caseInsensitiveCatalog.registerTable(None, "TaBlE", testRelation)
caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
}

test("union project *") {
Expand All @@ -64,45 +64,45 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
assert(
caseSensitiveAnalyze(
Project(Seq(UnresolvedAttribute("TbL.a")),
UnresolvedRelation(None, "TaBlE", Some("TbL")))) ===
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))

val e = intercept[TreeNodeException[_]] {
caseSensitiveAnalyze(
Project(Seq(UnresolvedAttribute("tBl.a")),
UnresolvedRelation(None, "TaBlE", Some("TbL"))))
UnresolvedRelation(Seq("TaBlE"), Some("TbL"))))
}
assert(e.getMessage().toLowerCase.contains("unresolved"))

assert(
caseInsensitiveAnalyze(
Project(Seq(UnresolvedAttribute("TbL.a")),
UnresolvedRelation(None, "TaBlE", Some("TbL")))) ===
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))

assert(
caseInsensitiveAnalyze(
Project(Seq(UnresolvedAttribute("tBl.a")),
UnresolvedRelation(None, "TaBlE", Some("TbL")))) ===
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))
}

test("resolve relations") {
val e = intercept[RuntimeException] {
caseSensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None))
caseSensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None))
}
assert(e.getMessage == "Table Not Found: tAbLe")

assert(
caseSensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) ===
caseSensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) ===
testRelation)

assert(
caseInsensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None)) ===
caseInsensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None)) ===
testRelation)

assert(
caseInsensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) ===
caseInsensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) ===
testRelation)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
val f: Expression = UnresolvedAttribute("f")

before {
catalog.registerTable(None, "table", relation)
catalog.registerTable(Seq("table"), relation)
}

private def checkType(expression: Expression, expectedType: DataType): Unit = {
Expand Down
6 changes: 3 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @group userf
*/
def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = {
catalog.registerTable(None, tableName, rdd.queryExecution.logical)
catalog.registerTable(Seq(tableName), rdd.queryExecution.logical)
}

/**
Expand All @@ -289,7 +289,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
def dropTempTable(tableName: String): Unit = {
tryUncacheQuery(table(tableName))
catalog.unregisterTable(None, tableName)
catalog.unregisterTable(Seq(tableName))
}

/**
Expand All @@ -308,7 +308,7 @@ class SQLContext(@transient val sparkContext: SparkContext)

/** Returns the specified table as a SchemaRDD */
def table(tableName: String): SchemaRDD =
new SchemaRDD(this, catalog.lookupRelation(None, tableName))
new SchemaRDD(this, catalog.lookupRelation(Seq(tableName)))

/**
* :: DeveloperApi ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ private[sql] trait SchemaRDDLike {
*/
@Experimental
def insertInto(tableName: String, overwrite: Boolean): Unit =
sqlContext.executePlan(
InsertIntoTable(UnresolvedRelation(None, tableName), Map.empty, logicalPlan, overwrite)).toRdd
sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)),
Map.empty, logicalPlan, overwrite)).toRdd

/**
* :: Experimental ::
Expand Down
4 changes: 2 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,8 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
upperCaseData.where('N <= 4).registerTempTable("left")
upperCaseData.where('N >= 3).registerTempTable("right")

val left = UnresolvedRelation(None, "left", None)
val right = UnresolvedRelation(None, "right", None)
val left = UnresolvedRelation(Seq("left"), None)
val right = UnresolvedRelation(Seq("right"), None)

checkAnswer(
left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
* in the Hive metastore.
*/
def analyze(tableName: String) {
val relation = EliminateAnalysisOperators(catalog.lookupRelation(None, tableName))
val relation = EliminateAnalysisOperators(catalog.lookupRelation(Seq(tableName)))

relation match {
case relation: MetastoreRelation =>
Expand Down
Loading