Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ import org.apache.spark.sql.AnalysisException
* All public methods should be synchronized for thread-safety.
*/
class InMemoryCatalog extends Catalog {
import Catalog._

private class TableDesc(var table: Table) {
val partitions = new mutable.HashMap[String, TablePartition]
val partitions = new mutable.HashMap[PartitionSpec, TablePartition]
}

private class DatabaseDesc(var db: Database) {
Expand All @@ -46,40 +47,53 @@ class InMemoryCatalog extends Catalog {
}

private def existsFunction(db: String, funcName: String): Boolean = {
assertDbExists(db)
catalog(db).functions.contains(funcName)
}

private def existsTable(db: String, table: String): Boolean = {
assertDbExists(db)
catalog(db).tables.contains(table)
}

private def existsPartition(db: String, table: String, spec: PartitionSpec): Boolean = {
assertTableExists(db, table)
catalog(db).tables(table).partitions.contains(spec)
}

private def assertDbExists(db: String): Unit = {
if (!catalog.contains(db)) {
throw new AnalysisException(s"Database $db does not exist")
}
}

private def assertFunctionExists(db: String, funcName: String): Unit = {
assertDbExists(db)
if (!existsFunction(db, funcName)) {
throw new AnalysisException(s"Function $funcName does not exists in $db database")
throw new AnalysisException(s"Function $funcName does not exist in $db database")
}
}

private def assertTableExists(db: String, table: String): Unit = {
assertDbExists(db)
if (!existsTable(db, table)) {
throw new AnalysisException(s"Table $table does not exists in $db database")
throw new AnalysisException(s"Table $table does not exist in $db database")
}
}

private def assertPartitionExists(db: String, table: String, spec: PartitionSpec): Unit = {
if (!existsPartition(db, table, spec)) {
throw new AnalysisException(s"Partition does not exist in database $db table $table: $spec")
}
}

// --------------------------------------------------------------------------
// Databases
// --------------------------------------------------------------------------

override def createDatabase(dbDefinition: Database, ifNotExists: Boolean): Unit = synchronized {
override def createDatabase(
dbDefinition: Database,
ignoreIfExists: Boolean): Unit = synchronized {
if (catalog.contains(dbDefinition.name)) {
if (!ifNotExists) {
if (!ignoreIfExists) {
throw new AnalysisException(s"Database ${dbDefinition.name} already exists.")
}
} else {
Expand All @@ -88,9 +102,9 @@ class InMemoryCatalog extends Catalog {
}

override def dropDatabase(
db: String,
ignoreIfNotExists: Boolean,
cascade: Boolean): Unit = synchronized {
db: String,
ignoreIfNotExists: Boolean,
cascade: Boolean): Unit = synchronized {
if (catalog.contains(db)) {
if (!cascade) {
// If cascade is false, make sure the database is empty.
Expand Down Expand Up @@ -133,20 +147,24 @@ class InMemoryCatalog extends Catalog {
// Tables
// --------------------------------------------------------------------------

override def createTable(db: String, tableDefinition: Table, ifNotExists: Boolean)
: Unit = synchronized {
override def createTable(
db: String,
tableDefinition: Table,
ignoreIfExists: Boolean): Unit = synchronized {
assertDbExists(db)
if (existsTable(db, tableDefinition.name)) {
if (!ifNotExists) {
if (!ignoreIfExists) {
throw new AnalysisException(s"Table ${tableDefinition.name} already exists in $db database")
}
} else {
catalog(db).tables.put(tableDefinition.name, new TableDesc(tableDefinition))
}
}

override def dropTable(db: String, table: String, ignoreIfNotExists: Boolean)
: Unit = synchronized {
override def dropTable(
db: String,
table: String,
ignoreIfNotExists: Boolean): Unit = synchronized {
assertDbExists(db)
if (existsTable(db, table)) {
catalog(db).tables.remove(table)
Expand Down Expand Up @@ -190,26 +208,80 @@ class InMemoryCatalog extends Catalog {
// Partitions
// --------------------------------------------------------------------------

override def alterPartition(db: String, table: String, part: TablePartition)
: Unit = synchronized {
throw new UnsupportedOperationException
override def createPartitions(
db: String,
table: String,
parts: Seq[TablePartition],
ignoreIfExists: Boolean): Unit = synchronized {
assertTableExists(db, table)
val existingParts = catalog(db).tables(table).partitions
if (!ignoreIfExists) {
val dupSpecs = parts.collect { case p if existingParts.contains(p.spec) => p.spec }
if (dupSpecs.nonEmpty) {
val dupSpecsStr = dupSpecs.mkString("\n===\n")
throw new AnalysisException(
s"The following partitions already exist in database $db table $table:\n$dupSpecsStr")
}
}
parts.foreach { p => existingParts.put(p.spec, p) }
}

override def dropPartitions(
db: String,
table: String,
partSpecs: Seq[PartitionSpec],
ignoreIfNotExists: Boolean): Unit = synchronized {
assertTableExists(db, table)
val existingParts = catalog(db).tables(table).partitions
if (!ignoreIfNotExists) {
val missingSpecs = partSpecs.collect { case s if !existingParts.contains(s) => s }
if (missingSpecs.nonEmpty) {
val missingSpecsStr = missingSpecs.mkString("\n===\n")
throw new AnalysisException(
s"The following partitions do not exist in database $db table $table:\n$missingSpecsStr")
}
}
partSpecs.foreach(existingParts.remove)
}

override def alterPartitions(db: String, table: String, parts: Seq[TablePartition])
: Unit = synchronized {
throw new UnsupportedOperationException
override def alterPartition(
db: String,
table: String,
spec: Map[String, String],
newPart: TablePartition): Unit = synchronized {
assertPartitionExists(db, table, spec)
val existingParts = catalog(db).tables(table).partitions
if (spec != newPart.spec) {
// Also a change in specs; remove the old one and add the new one back
existingParts.remove(spec)
}
existingParts.put(newPart.spec, newPart)
}

override def getPartition(
db: String,
table: String,
spec: Map[String, String]): TablePartition = synchronized {
assertPartitionExists(db, table, spec)
catalog(db).tables(table).partitions(spec)
}

override def listPartitions(db: String, table: String): Seq[TablePartition] = synchronized {
assertTableExists(db, table)
catalog(db).tables(table).partitions.values.toSeq
}

// --------------------------------------------------------------------------
// Functions
// --------------------------------------------------------------------------

override def createFunction(
db: String, func: Function, ifNotExists: Boolean): Unit = synchronized {
db: String,
func: Function,
ignoreIfExists: Boolean): Unit = synchronized {
assertDbExists(db)

if (existsFunction(db, func.name)) {
if (!ifNotExists) {
if (!ignoreIfExists) {
throw new AnalysisException(s"Function $func already exists in $db database")
}
} else {
Expand All @@ -222,14 +294,16 @@ class InMemoryCatalog extends Catalog {
catalog(db).functions.remove(funcName)
}

override def alterFunction(db: String, funcName: String, funcDefinition: Function)
: Unit = synchronized {
override def alterFunction(
db: String,
funcName: String,
funcDefinition: Function): Unit = synchronized {
assertFunctionExists(db, funcName)
if (funcName != funcDefinition.name) {
// Also a rename; remove the old one and add the new one back
catalog(db).functions.remove(funcName)
}
catalog(db).functions.put(funcName, funcDefinition)
catalog(db).functions.put(funcDefinition.name, funcDefinition)
}

override def getFunction(db: String, funcName: String): Function = synchronized {
Expand All @@ -239,7 +313,6 @@ class InMemoryCatalog extends Catalog {

override def listFunctions(db: String, pattern: String): Seq[String] = synchronized {
assertDbExists(db)
val regex = pattern.replaceAll("\\*", ".*").r
filterPattern(catalog(db).functions.keysIterator.toSeq, pattern)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,15 @@ import org.apache.spark.sql.AnalysisException
* Implementations should throw [[AnalysisException]] when table or database don't exist.
*/
abstract class Catalog {
import Catalog._

// --------------------------------------------------------------------------
// Databases
// --------------------------------------------------------------------------

def createDatabase(dbDefinition: Database, ifNotExists: Boolean): Unit
def createDatabase(dbDefinition: Database, ignoreIfExists: Boolean): Unit

def dropDatabase(
db: String,
ignoreIfNotExists: Boolean,
cascade: Boolean): Unit
def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit

def alterDatabase(db: String, dbDefinition: Database): Unit

Expand Down Expand Up @@ -71,11 +69,28 @@ abstract class Catalog {
// Partitions
// --------------------------------------------------------------------------

// TODO: need more functions for partitioning.
def createPartitions(
db: String,
table: String,
parts: Seq[TablePartition],
ignoreIfExists: Boolean): Unit

def alterPartition(db: String, table: String, part: TablePartition): Unit
def dropPartitions(
db: String,
table: String,
parts: Seq[PartitionSpec],
ignoreIfNotExists: Boolean): Unit

def alterPartitions(db: String, table: String, parts: Seq[TablePartition]): Unit
def alterPartition(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add some documentation saying this also allows changing the partition spec.

and let's update alterDatabase / alterTable to say we do not allow renaming there.

db: String,
table: String,
spec: PartitionSpec,
newPart: TablePartition): Unit

def getPartition(db: String, table: String, spec: PartitionSpec): TablePartition

// TODO: support listing by pattern
def listPartitions(db: String, table: String): Seq[TablePartition]

// --------------------------------------------------------------------------
// Functions
Expand Down Expand Up @@ -132,11 +147,11 @@ case class Column(
/**
* A partition (Hive style) defined in the catalog.
*
* @param values values for the partition columns
* @param spec partition spec values indexed by column name
* @param storage storage format of the partition
*/
case class TablePartition(
values: Seq[String],
spec: Catalog.PartitionSpec,
storage: StorageFormat
)

Expand Down Expand Up @@ -176,3 +191,8 @@ case class Database(
locationUri: String,
properties: Map[String, String]
)


object Catalog {
type PartitionSpec = Map[String, String]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to document this is mapping from column names to values.

}
Loading