Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.util.StringKeyHashMap


/** A catalog for looking up user defined functions, used by an [[Analyzer]]. */
/**
* A catalog for looking up user defined functions, used by an [[Analyzer]].
*
* Note: The implementation should be thread-safe to allow concurrent access.
*/
trait FunctionRegistry {

final def registerFunction(name: String, builder: FunctionBuilder): Unit = {
Expand Down Expand Up @@ -62,7 +66,7 @@ trait FunctionRegistry {

class SimpleFunctionRegistry extends FunctionRegistry {

private[sql] val functionBuilders =
protected val functionBuilders =
StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false)

override def registerFunction(
Expand Down Expand Up @@ -97,7 +101,7 @@ class SimpleFunctionRegistry extends FunctionRegistry {
functionBuilders.remove(name).isDefined
}

override def clear(): Unit = {
override def clear(): Unit = synchronized {
functionBuilders.clear()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ class InMemoryCatalog extends ExternalCatalog {
catalog(db).functions(funcName)
}

override def functionExists(db: String, funcName: String): Boolean = {
override def functionExists(db: String, funcName: String): Boolean = synchronized {
requireDbExists(db)
catalog(db).functions.contains(funcName)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.catalog

import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable

import org.apache.hadoop.conf.Configuration
Expand All @@ -37,7 +39,7 @@ import org.apache.spark.sql.catalyst.util.StringUtils
* proxy to the underlying metastore (e.g. Hive Metastore) and it also manages temporary
* tables and functions of the Spark Session that it belongs to.
*
* This class is not thread-safe.
* This class must be thread-safe.
*/
class SessionCatalog(
externalCatalog: ExternalCatalog,
Expand Down Expand Up @@ -66,12 +68,14 @@ class SessionCatalog(
}

/** List of temporary tables, mapping from table name to their logical plan. */
@GuardedBy("this")
protected val tempTables = new mutable.HashMap[String, LogicalPlan]

// Note: we track current database here because certain operations do not explicitly
// specify the database (e.g. DROP TABLE my_table). In these cases we must first
// check whether the temporary table or function exists, then, if not, operate on
// the corresponding item in the current database.
@GuardedBy("this")
protected var currentDb = {
val defaultName = "default"
val defaultDbDefinition =
Expand Down Expand Up @@ -137,13 +141,13 @@ class SessionCatalog(
externalCatalog.listDatabases(pattern)
}

def getCurrentDatabase: String = currentDb
def getCurrentDatabase: String = synchronized { currentDb }

def setCurrentDatabase(db: String): Unit = {
if (!databaseExists(db)) {
throw new AnalysisException(s"Database '$db' does not exist.")
}
currentDb = db
synchronized { currentDb = db }
}

def getDefaultDBPath(db: String): String = {
Expand All @@ -169,7 +173,7 @@ class SessionCatalog(
* If no such database is specified, create it in the current database.
*/
def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = {
val db = tableDefinition.identifier.database.getOrElse(currentDb)
val db = tableDefinition.identifier.database.getOrElse(getCurrentDatabase)
val table = formatTableName(tableDefinition.identifier.table)
val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db)))
externalCatalog.createTable(db, newTableDefinition, ignoreIfExists)
Expand All @@ -185,7 +189,7 @@ class SessionCatalog(
* this becomes a no-op.
*/
def alterTable(tableDefinition: CatalogTable): Unit = {
val db = tableDefinition.identifier.database.getOrElse(currentDb)
val db = tableDefinition.identifier.database.getOrElse(getCurrentDatabase)
val table = formatTableName(tableDefinition.identifier.table)
val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db)))
externalCatalog.alterTable(db, newTableDefinition)
Expand All @@ -197,7 +201,7 @@ class SessionCatalog(
* If the specified table is not found in the database then an [[AnalysisException]] is thrown.
*/
def getTableMetadata(name: TableIdentifier): CatalogTable = {
val db = name.database.getOrElse(currentDb)
val db = name.database.getOrElse(getCurrentDatabase)
val table = formatTableName(name.table)
externalCatalog.getTable(db, table)
}
Expand All @@ -208,7 +212,7 @@ class SessionCatalog(
* If the specified table is not found in the database then return None if it doesn't exist.
*/
def getTableMetadataOption(name: TableIdentifier): Option[CatalogTable] = {
val db = name.database.getOrElse(currentDb)
val db = name.database.getOrElse(getCurrentDatabase)
val table = formatTableName(name.table)
externalCatalog.getTableOption(db, table)
}
Expand All @@ -223,7 +227,7 @@ class SessionCatalog(
loadPath: String,
isOverwrite: Boolean,
holdDDLTime: Boolean): Unit = {
val db = name.database.getOrElse(currentDb)
val db = name.database.getOrElse(getCurrentDatabase)
val table = formatTableName(name.table)
externalCatalog.loadTable(db, table, loadPath, isOverwrite, holdDDLTime)
}
Expand All @@ -241,14 +245,14 @@ class SessionCatalog(
holdDDLTime: Boolean,
inheritTableSpecs: Boolean,
isSkewedStoreAsSubdir: Boolean): Unit = {
val db = name.database.getOrElse(currentDb)
val db = name.database.getOrElse(getCurrentDatabase)
val table = formatTableName(name.table)
externalCatalog.loadPartition(db, table, loadPath, partition, isOverwrite, holdDDLTime,
inheritTableSpecs, isSkewedStoreAsSubdir)
}

def defaultTablePath(tableIdent: TableIdentifier): String = {
val dbName = tableIdent.database.getOrElse(currentDb)
val dbName = tableIdent.database.getOrElse(getCurrentDatabase)
val dbLocation = getDatabaseMetadata(dbName).locationUri

new Path(new Path(dbLocation), formatTableName(tableIdent.table)).toString
Expand All @@ -264,7 +268,7 @@ class SessionCatalog(
def createTempTable(
name: String,
tableDefinition: LogicalPlan,
overrideIfExists: Boolean): Unit = {
overrideIfExists: Boolean): Unit = synchronized {
val table = formatTableName(name)
if (tempTables.contains(table) && !overrideIfExists) {
throw new AnalysisException(s"Temporary table '$name' already exists.")
Expand All @@ -281,7 +285,7 @@ class SessionCatalog(
*
* This assumes the database specified in `oldName` matches the one specified in `newName`.
*/
def renameTable(oldName: TableIdentifier, newName: TableIdentifier): Unit = {
def renameTable(oldName: TableIdentifier, newName: TableIdentifier): Unit = synchronized {
val db = oldName.database.getOrElse(currentDb)
val newDb = newName.database.getOrElse(currentDb)
if (db != newDb) {
Expand All @@ -306,7 +310,7 @@ class SessionCatalog(
* If no database is specified, this will first attempt to drop a temporary table with
* the same name, then, if that does not exist, drop the table from the current database.
*/
def dropTable(name: TableIdentifier, ignoreIfNotExists: Boolean): Unit = {
def dropTable(name: TableIdentifier, ignoreIfNotExists: Boolean): Unit = synchronized {
val db = name.database.getOrElse(currentDb)
val table = formatTableName(name.table)
if (name.database.isDefined || !tempTables.contains(table)) {
Expand All @@ -330,19 +334,21 @@ class SessionCatalog(
* the same name, then, if that does not exist, return the table from the current database.
*/
def lookupRelation(name: TableIdentifier, alias: Option[String] = None): LogicalPlan = {
val db = name.database.getOrElse(currentDb)
val table = formatTableName(name.table)
val relation =
if (name.database.isDefined || !tempTables.contains(table)) {
val metadata = externalCatalog.getTable(db, table)
SimpleCatalogRelation(db, metadata, alias)
} else {
tempTables(table)
}
val qualifiedTable = SubqueryAlias(table, relation)
// If an alias was specified by the lookup, wrap the plan in a subquery so that
// attributes are properly qualified with this alias.
alias.map(a => SubqueryAlias(a, qualifiedTable)).getOrElse(qualifiedTable)
synchronized {
val db = name.database.getOrElse(currentDb)
val table = formatTableName(name.table)
val relation =
if (name.database.isDefined || !tempTables.contains(table)) {
val metadata = externalCatalog.getTable(db, table)
SimpleCatalogRelation(db, metadata, alias)
} else {
tempTables(table)
}
val qualifiedTable = SubqueryAlias(table, relation)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: look like this call and alias.map() can be outside the synchronized block.

// If an alias was specified by the lookup, wrap the plan in a subquery so that
// attributes are properly qualified with this alias.
alias.map(a => SubqueryAlias(a, qualifiedTable)).getOrElse(qualifiedTable)
}
}

/**
Expand All @@ -353,7 +359,7 @@ class SessionCatalog(
* table with the same name, we will return false if the specified database does not
* contain the table.
*/
def tableExists(name: TableIdentifier): Boolean = {
def tableExists(name: TableIdentifier): Boolean = synchronized {
val db = name.database.getOrElse(currentDb)
val table = formatTableName(name.table)
if (name.database.isDefined || !tempTables.contains(table)) {
Expand All @@ -369,7 +375,7 @@ class SessionCatalog(
* Note: The temporary table cache is checked only when database is not
* explicitly specified.
*/
def isTemporaryTable(name: TableIdentifier): Boolean = {
def isTemporaryTable(name: TableIdentifier): Boolean = synchronized {
name.database.isEmpty && tempTables.contains(formatTableName(name.table))
}

Expand All @@ -384,9 +390,11 @@ class SessionCatalog(
def listTables(db: String, pattern: String): Seq[TableIdentifier] = {
val dbTables =
externalCatalog.listTables(db, pattern).map { t => TableIdentifier(t, Some(db)) }
val _tempTables = StringUtils.filterPattern(tempTables.keys.toSeq, pattern)
.map { t => TableIdentifier(t) }
dbTables ++ _tempTables
synchronized {
val _tempTables = StringUtils.filterPattern(tempTables.keys.toSeq, pattern)
.map { t => TableIdentifier(t) }
dbTables ++ _tempTables
}
}

// TODO: It's strange that we have both refresh and invalidate here.
Expand All @@ -405,15 +413,15 @@ class SessionCatalog(
* Drop all existing temporary tables.
* For testing only.
*/
def clearTempTables(): Unit = {
def clearTempTables(): Unit = synchronized {
tempTables.clear()
}

/**
* Return a temporary table exactly as it was stored.
* For testing only.
*/
private[catalog] def getTempTable(name: String): Option[LogicalPlan] = {
private[catalog] def getTempTable(name: String): Option[LogicalPlan] = synchronized {
tempTables.get(name)
}

Expand All @@ -437,7 +445,7 @@ class SessionCatalog(
tableName: TableIdentifier,
parts: Seq[CatalogTablePartition],
ignoreIfExists: Boolean): Unit = {
val db = tableName.database.getOrElse(currentDb)
val db = tableName.database.getOrElse(getCurrentDatabase)
val table = formatTableName(tableName.table)
externalCatalog.createPartitions(db, table, parts, ignoreIfExists)
}
Expand All @@ -450,7 +458,7 @@ class SessionCatalog(
tableName: TableIdentifier,
parts: Seq[TablePartitionSpec],
ignoreIfNotExists: Boolean): Unit = {
val db = tableName.database.getOrElse(currentDb)
val db = tableName.database.getOrElse(getCurrentDatabase)
val table = formatTableName(tableName.table)
externalCatalog.dropPartitions(db, table, parts, ignoreIfNotExists)
}
Expand All @@ -465,7 +473,7 @@ class SessionCatalog(
tableName: TableIdentifier,
specs: Seq[TablePartitionSpec],
newSpecs: Seq[TablePartitionSpec]): Unit = {
val db = tableName.database.getOrElse(currentDb)
val db = tableName.database.getOrElse(getCurrentDatabase)
val table = formatTableName(tableName.table)
externalCatalog.renamePartitions(db, table, specs, newSpecs)
}
Expand All @@ -480,7 +488,7 @@ class SessionCatalog(
* this becomes a no-op.
*/
def alterPartitions(tableName: TableIdentifier, parts: Seq[CatalogTablePartition]): Unit = {
val db = tableName.database.getOrElse(currentDb)
val db = tableName.database.getOrElse(getCurrentDatabase)
val table = formatTableName(tableName.table)
externalCatalog.alterPartitions(db, table, parts)
}
Expand All @@ -490,7 +498,7 @@ class SessionCatalog(
* If no database is specified, assume the table is in the current database.
*/
def getPartition(tableName: TableIdentifier, spec: TablePartitionSpec): CatalogTablePartition = {
val db = tableName.database.getOrElse(currentDb)
val db = tableName.database.getOrElse(getCurrentDatabase)
val table = formatTableName(tableName.table)
externalCatalog.getPartition(db, table, spec)
}
Expand All @@ -505,7 +513,7 @@ class SessionCatalog(
def listPartitions(
tableName: TableIdentifier,
partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = {
val db = tableName.database.getOrElse(currentDb)
val db = tableName.database.getOrElse(getCurrentDatabase)
val table = formatTableName(tableName.table)
externalCatalog.listPartitions(db, table, partialSpec)
}
Expand All @@ -528,7 +536,7 @@ class SessionCatalog(
* If no such database is specified, create it in the current database.
*/
def createFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = {
val db = funcDefinition.identifier.database.getOrElse(currentDb)
val db = funcDefinition.identifier.database.getOrElse(getCurrentDatabase)
val identifier = FunctionIdentifier(funcDefinition.identifier.funcName, Some(db))
val newFuncDefinition = funcDefinition.copy(identifier = identifier)
if (!functionExists(identifier)) {
Expand All @@ -543,7 +551,7 @@ class SessionCatalog(
* If no database is specified, assume the function is in the current database.
*/
def dropFunction(name: FunctionIdentifier, ignoreIfNotExists: Boolean): Unit = {
val db = name.database.getOrElse(currentDb)
val db = name.database.getOrElse(getCurrentDatabase)
val identifier = name.copy(database = Some(db))
if (functionExists(identifier)) {
// TODO: registry should just take in FunctionIdentifier for type safety
Expand All @@ -567,15 +575,15 @@ class SessionCatalog(
* If no database is specified, this will return the function in the current database.
*/
def getFunctionMetadata(name: FunctionIdentifier): CatalogFunction = {
val db = name.database.getOrElse(currentDb)
val db = name.database.getOrElse(getCurrentDatabase)
externalCatalog.getFunction(db, name.funcName)
}

/**
* Check if the specified function exists.
*/
def functionExists(name: FunctionIdentifier): Boolean = {
val db = name.database.getOrElse(currentDb)
val db = name.database.getOrElse(getCurrentDatabase)
functionRegistry.functionExists(name.unquotedString) ||
externalCatalog.functionExists(db, name.funcName)
}
Expand Down Expand Up @@ -640,7 +648,7 @@ class SessionCatalog(
/**
* Look up the [[ExpressionInfo]] associated with the specified function, assuming it exists.
*/
private[spark] def lookupFunctionInfo(name: FunctionIdentifier): ExpressionInfo = {
private[spark] def lookupFunctionInfo(name: FunctionIdentifier): ExpressionInfo = synchronized {
// TODO: just make function registry take in FunctionIdentifier instead of duplicating this
val qualifiedName = name.copy(database = name.database.orElse(Some(currentDb)))
functionRegistry.lookupFunction(name.funcName)
Expand Down Expand Up @@ -669,7 +677,9 @@ class SessionCatalog(
* based on the function class and put the builder into the FunctionRegistry.
* The name of this function in the FunctionRegistry will be `databaseName.functionName`.
*/
def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = {
def lookupFunction(
name: FunctionIdentifier,
children: Seq[Expression]): Expression = synchronized {
// Note: the implementation of this function is a little bit convoluted.
// We probably shouldn't use a single FunctionRegistry to register all three kinds of functions
// (built-in, temp, and external).
Expand Down Expand Up @@ -737,7 +747,7 @@ class SessionCatalog(
*
* This is mainly used for tests.
*/
private[sql] def reset(): Unit = {
private[sql] def reset(): Unit = synchronized {
val default = "default"
listDatabases().filter(_ != default).foreach { db =>
dropDatabase(db, ignoreIfNotExists = false, cascade = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ class ExperimentalMethods private[sql]() {
* @since 1.3.0
*/
@Experimental
var extraStrategies: Seq[Strategy] = Nil
@volatile var extraStrategies: Seq[Strategy] = Nil

@Experimental
var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil
@volatile var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil

}
Loading